jarvis-ai-assistant 0.3.23__py3-none-any.whl → 0.3.25__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +96 -13
- jarvis/jarvis_agent/agent_manager.py +0 -3
- jarvis/jarvis_agent/jarvis.py +19 -34
- jarvis/jarvis_agent/main.py +2 -8
- jarvis/jarvis_code_agent/code_agent.py +5 -11
- jarvis/jarvis_code_analysis/code_review.py +12 -40
- jarvis/jarvis_data/config_schema.json +11 -18
- jarvis/jarvis_git_utils/git_commiter.py +11 -25
- jarvis/jarvis_mcp/sse_mcp_client.py +4 -3
- jarvis/jarvis_mcp/streamable_mcp_client.py +9 -8
- jarvis/jarvis_memory_organizer/memory_organizer.py +46 -53
- jarvis/jarvis_methodology/main.py +4 -2
- jarvis/jarvis_platform/base.py +90 -21
- jarvis/jarvis_platform/kimi.py +16 -22
- jarvis/jarvis_platform/registry.py +7 -14
- jarvis/jarvis_platform/tongyi.py +21 -32
- jarvis/jarvis_platform/yuanbao.py +15 -17
- jarvis/jarvis_platform_manager/main.py +14 -51
- jarvis/jarvis_rag/cli.py +21 -13
- jarvis/jarvis_rag/embedding_manager.py +138 -6
- jarvis/jarvis_rag/llm_interface.py +0 -2
- jarvis/jarvis_rag/rag_pipeline.py +41 -17
- jarvis/jarvis_rag/reranker.py +24 -2
- jarvis/jarvis_rag/retriever.py +21 -23
- jarvis/jarvis_smart_shell/main.py +1 -10
- jarvis/jarvis_tools/cli/main.py +22 -15
- jarvis/jarvis_tools/edit_file.py +6 -6
- jarvis/jarvis_tools/execute_script.py +1 -2
- jarvis/jarvis_tools/file_analyzer.py +12 -6
- jarvis/jarvis_tools/registry.py +13 -10
- jarvis/jarvis_tools/sub_agent.py +5 -8
- jarvis/jarvis_tools/sub_code_agent.py +5 -5
- jarvis/jarvis_utils/config.py +24 -10
- jarvis/jarvis_utils/input.py +8 -5
- jarvis/jarvis_utils/methodology.py +11 -6
- jarvis/jarvis_utils/utils.py +29 -12
- {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/METADATA +10 -3
- {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/RECORD +43 -43
- {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.3.23.dist-info → jarvis_ai_assistant-0.3.25.dist-info}/top_level.txt +0 -0
jarvis/jarvis_platform/yuanbao.py
CHANGED
@@ -134,10 +134,11 @@ class YuanbaoPlatform(BasePlatform):
 
         for file_path in file_list:
             file_name = os.path.basename(file_path)
-            PrettyOutput.print(f"上传文件 {file_name}", OutputType.INFO)
+            log_lines: list[str] = []
+            log_lines.append(f"上传文件 {file_name}")
             try:
                 # 1. Prepare the file information
-                PrettyOutput.print(f"准备文件信息: {file_name}", OutputType.INFO)
+                log_lines.append(f"准备文件信息: {file_name}")
                 file_size = os.path.getsize(file_path)
                 file_extension = os.path.splitext(file_path)[1].lower().lstrip(".")
 
@@ -192,23 +193,23 @@ class YuanbaoPlatform(BasePlatform):
                     file_type = "code"
 
                 # 2. Generate upload information
-                PrettyOutput.print(f"获取上传信息: {file_name}", OutputType.INFO)
+                log_lines.append(f"获取上传信息: {file_name}")
                 upload_info = self._generate_upload_info(file_name)
                 if not upload_info:
-                    PrettyOutput.print(
-                        f"无法获取文件 {file_name} 的上传信息", OutputType.ERROR
-                    )
+                    log_lines.append(f"无法获取文件 {file_name} 的上传信息")
+                    PrettyOutput.print("\n".join(log_lines), OutputType.ERROR)
                     return False
 
                 # 3. Upload the file to COS
-                PrettyOutput.print(f"上传文件到云存储: {file_name}", OutputType.INFO)
+                log_lines.append(f"上传文件到云存储: {file_name}")
                 upload_success = self._upload_file_to_cos(file_path, upload_info)
                 if not upload_success:
-                    PrettyOutput.print(f"上传文件 {file_name} 失败", OutputType.ERROR)
+                    log_lines.append(f"上传文件 {file_name} 失败")
+                    PrettyOutput.print("\n".join(log_lines), OutputType.ERROR)
                     return False
 
                 # 4. Create file metadata for chat
-                PrettyOutput.print(f"生成文件元数据: {file_name}", OutputType.INFO)
+                log_lines.append(f"生成文件元数据: {file_name}")
                 file_metadata = {
                     "type": file_type,
                     "docType": file_extension if file_extension else file_type,
@@ -226,19 +227,16 @@ class YuanbaoPlatform(BasePlatform):
                             file_metadata["width"] = img.width
                             file_metadata["height"] = img.height
                     except Exception as e:
-                        PrettyOutput.print(
-                            f"无法获取图片 {file_name} 的尺寸: {str(e)}",
-                            OutputType.WARNING,
-                        )
+                        log_lines.append(f"无法获取图片 {file_name} 的尺寸: {str(e)}")
 
                 uploaded_files.append(file_metadata)
-                PrettyOutput.print(f"文件 {file_name} 上传成功", OutputType.INFO)
+                log_lines.append(f"文件 {file_name} 上传成功")
+                PrettyOutput.print("\n".join(log_lines), OutputType.INFO)
                 time.sleep(3)  # 上传成功后等待3秒
 
             except Exception as e:
-                PrettyOutput.print(
-                    f"上传文件 {file_path} 时出错: {str(e)}", OutputType.ERROR
-                )
+                log_lines.append(f"上传文件 {file_path} 时出错: {str(e)}")
+                PrettyOutput.print("\n".join(log_lines), OutputType.ERROR)
                 return False
 
         self.multimedia = uploaded_files
jarvis/jarvis_platform_manager/main.py
CHANGED
@@ -11,8 +11,6 @@ import typer
 from jarvis.jarvis_utils.config import (
     get_normal_platform_name,
     get_normal_model_name,
-    get_thinking_platform_name,
-    get_thinking_model_name,
 )
 
 from jarvis.jarvis_platform.registry import PlatformRegistry
@@ -66,7 +64,7 @@ def list_platforms(
 
 
 def chat_with_model(
-    platform_name: str, model_name: str, system_prompt: str, llm_type: str = "normal"
+    platform_name: str, model_name: str, system_prompt: str
 ) -> None:
     """与指定平台和模型进行对话。
 
@@ -74,7 +72,7 @@ def chat_with_model(
         platform_name: 平台名称
         model_name: 模型名称
         system_prompt: 系统提示语
-        llm_type: LLM类型
+
     """
     registry = PlatformRegistry.get_global_platform_registry()
     conversation_history: List[Dict[str, str]] = []  # 存储对话记录
@@ -360,32 +358,19 @@ def chat_command(
         None, "--platform", "-p", help="指定要使用的平台"
     ),
    model: Optional[str] = typer.Option(None, "--model", "-m", help="指定要使用的模型"),
-    llm_type: str = typer.Option(
-        "normal",
-        "-t",
-        "--llm-type",
-        help="使用的LLM类型,可选值:'normal'(普通)或 'thinking'(思考模式)",
-    ),
+
     llm_group: Optional[str] = typer.Option(
         None, "-g", "--llm-group", help="使用的模型组,覆盖配置文件中的设置"
     ),
 ) -> None:
     """与指定平台和模型聊天。"""
     # 如果未提供平台或模型参数,则从config获取默认值
-    platform = platform or (
-        get_thinking_platform_name(llm_group)
-        if llm_type == "thinking"
-        else get_normal_platform_name(llm_group)
-    )
-    model = model or (
-        get_thinking_model_name(llm_group)
-        if llm_type == "thinking"
-        else get_normal_model_name(llm_group)
-    )
+    platform = platform or get_normal_platform_name(llm_group)
+    model = model or get_normal_model_name(llm_group)
 
     if not validate_platform_model(platform, model):
         return
-    chat_with_model(platform, model, "", llm_type)
+    chat_with_model(platform, model, "")
 
 
 @app.command("service")
@@ -444,12 +429,7 @@ def role_command(
     model: Optional[str] = typer.Option(
         None, "--model", "-m", help="指定要使用的模型,覆盖角色配置"
     ),
-    llm_type: Optional[str] = typer.Option(
-        None,
-        "-t",
-        "--llm-type",
-        help="使用的LLM类型,可选值:'normal'(普通)或 'thinking'(思考模式),覆盖角色配置",
-    ),
+
     llm_group: Optional[str] = typer.Option(
         None, "-g", "--llm-group", help="使用的模型组,覆盖配置文件中的设置"
     ),
@@ -483,54 +463,37 @@ def role_command(
         PrettyOutput.print("无效的选择", OutputType.ERROR)
         return
 
-
-    role_llm_type = llm_type or selected_role.get("llm_type", "normal")
+
 
     # 初始化平台和模型
     # 如果提供了platform或model参数,优先使用命令行参数
-    # 否则,如果提供了llm_group
+    # 否则,如果提供了 llm_group,则从配置中获取
     # 最后才使用角色配置中的platform和model
     if platform:
         platform_name = platform
     elif llm_group:
-        platform_name = (
-            get_thinking_platform_name(llm_group)
-            if role_llm_type == "thinking"
-            else get_normal_platform_name(llm_group)
-        )
+        platform_name = get_normal_platform_name(llm_group)
     else:
         platform_name = selected_role.get("platform")
         if not platform_name:
             # 如果角色配置中没有platform,使用默认配置
-            platform_name = (
-                get_thinking_platform_name()
-                if role_llm_type == "thinking"
-                else get_normal_platform_name()
-            )
+            platform_name = get_normal_platform_name()
 
     if model:
         model_name = model
     elif llm_group:
-        model_name = (
-            get_thinking_model_name(llm_group)
-            if role_llm_type == "thinking"
-            else get_normal_model_name(llm_group)
-        )
+        model_name = get_normal_model_name(llm_group)
     else:
         model_name = selected_role.get("model")
         if not model_name:
             # 如果角色配置中没有model,使用默认配置
-            model_name = (
-                get_thinking_model_name()
-                if role_llm_type == "thinking"
-                else get_normal_model_name()
-            )
+            model_name = get_normal_model_name()
 
     system_prompt = selected_role.get("system_prompt", "")
 
     # 开始对话
     PrettyOutput.print(f"已选择角色: {selected_role['name']}", OutputType.SUCCESS)
-    chat_with_model(platform_name, model_name, system_prompt, role_llm_type)
+    chat_with_model(platform_name, model_name, system_prompt)
 
 
 def main() -> None:
jarvis/jarvis_rag/cli.py
CHANGED
@@ -240,6 +240,7 @@ def add_documents(
 
     sorted_files = sorted(list(files_to_process))
     total_files = len(sorted_files)
+    loaded_msgs: List[str] = []
 
     for i, file_path in enumerate(sorted_files):
         try:
@@ -249,14 +250,15 @@ def add_documents(
                 loader = TextLoader(str(file_path), encoding="utf-8")
 
             docs_batch.extend(loader.load())
-            PrettyOutput.print(
-                f"已加载: {file_path} (文件 {i + 1}/{total_files})", OutputType.INFO
-            )
+            loaded_msgs.append(f"已加载: {file_path} (文件 {i + 1}/{total_files})")
         except Exception as e:
             PrettyOutput.print(f"加载失败 {file_path}: {e}", OutputType.WARNING)
 
         # 当批处理已满或是最后一个文件时处理批处理
         if docs_batch and (len(docs_batch) >= batch_size or (i + 1) == total_files):
+            if loaded_msgs:
+                PrettyOutput.print("\n".join(loaded_msgs), OutputType.INFO)
+                loaded_msgs = []
             PrettyOutput.print(
                 f"正在处理批次,包含 {len(docs_batch)} 个文档...", OutputType.INFO
            )
@@ -267,6 +269,10 @@ def add_documents(
             )
             docs_batch = []  # 清空批处理
 
+    # 最后统一打印可能残留的“已加载”信息
+    if loaded_msgs:
+        PrettyOutput.print("\n".join(loaded_msgs), OutputType.INFO)
+        loaded_msgs = []
    if total_docs_added == 0:
         PrettyOutput.print("未能成功加载任何文档。", OutputType.ERROR)
         raise typer.Exit(code=1)
@@ -321,12 +327,11 @@ def list_documents(
             )
             return
 
-        PrettyOutput.print(
-            f"知识库 '{collection_name}' 中共有 {len(sources)} 个独立文档:",
-            OutputType.INFO,
-        )
+        # 避免在循环中逐条打印,先拼接后统一打印
+        lines = [f"知识库 '{collection_name}' 中共有 {len(sources)} 个独立文档:"]
         for i, source in enumerate(sorted(list(sources)), 1):
-            PrettyOutput.print(f"  {i}. {source}", OutputType.INFO)
+            lines.append(f"  {i}. {source}")
+        PrettyOutput.print("\n".join(lines), OutputType.INFO)
 
     except Exception as e:
         PrettyOutput.print(f"发生错误: {e}", OutputType.ERROR)
@@ -352,6 +357,12 @@ def retrieve(
         None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
     ),
     n_results: int = typer.Option(5, "--top-n", help="要检索的文档数量。"),
+    rewrite: bool = typer.Option(
+        True,
+        "--rewrite/--no-rewrite",
+        help="是否对查询进行LLM重写以提升召回,默认开启。",
+        show_default=True,
+    ),
 ):
     """仅从RAG知识库检索文档并打印结果。"""
     try:
@@ -366,6 +377,7 @@ def retrieve(
             collection_name=collection_name,
             use_bm25=use_bm25,
             use_rerank=use_rerank,
+            use_query_rewrite=rewrite,
         )
 
         PrettyOutput.print(f"正在为问题检索文档: '{question}'", OutputType.INFO)
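
The new --rewrite/--no-rewrite option uses typer's paired boolean flag syntax: one parameter, two CLI spellings. A minimal runnable example of the same declaration:

import typer

app = typer.Typer()

@app.command()
def retrieve(
    question: str,
    rewrite: bool = typer.Option(
        True,
        "--rewrite/--no-rewrite",
        help="Toggle LLM query rewriting.",
        show_default=True,
    ),
) -> None:
    typer.echo(f"rewrite={'on' if rewrite else 'off'} for question: {question}")

if __name__ == "__main__":
    app()

Invoking `retrieve "some question" --no-rewrite` sets rewrite=False, which the CLI in the diff forwards to the pipeline as use_query_rewrite.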
@@ -450,11 +462,7 @@ def query(
         PrettyOutput.print(f"正在查询: '{question}'", OutputType.INFO)
         answer = pipeline.query(question)
 
-        PrettyOutput.print(
-        # 我们仍然可以使用 rich.markdown.Markdown,因为 PrettyOutput 底层使用了 rich
-        from jarvis.jarvis_utils.globals import console
-
-        console.print(Markdown(answer))
+        PrettyOutput.print(answer, OutputType.SUCCESS)
 
     except Exception as e:
         PrettyOutput.print(f"发生错误: {e}", OutputType.ERROR)
jarvis/jarvis_rag/embedding_manager.py
CHANGED
@@ -1,6 +1,9 @@
 import torch
+import os
 from typing import List, cast
 from langchain_huggingface import HuggingFaceEmbeddings
+from huggingface_hub import snapshot_download
+
 
 from .cache import EmbeddingCache
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
@@ -38,12 +41,141 @@ class EmbeddingManager:
         encode_kwargs = {"normalize_embeddings": True}
 
         try:
-            return HuggingFaceEmbeddings(
-                model_name=self.model_name,
-                model_kwargs=model_kwargs,
-                encode_kwargs=encode_kwargs,
-                show_progress=True,
-            )
+            # First try to load model from local cache without any network access
+            try:
+                from sentence_transformers import SentenceTransformer
+                local_dir = None
+                # Prefer explicit local dir via env or direct path
+
+                if os.path.isdir(self.model_name):
+                    return HuggingFaceEmbeddings(
+                        model_name=self.model_name,
+                        model_kwargs=model_kwargs,
+                        encode_kwargs=encode_kwargs,
+                        show_progress=False,
+                    )
+
+                # Try common local cache directories for sentence-transformers and HF hub
+                try:
+                    home = os.path.expanduser("~")
+                    st_home = os.path.join(home, ".cache", "sentence_transformers")
+                    torch_st_home = os.path.join(home, ".cache", "torch", "sentence_transformers")
+                    # Build common name variants found in local caches
+                    org, name = (
+                        self.model_name.split("/", 1)
+                        if "/" in self.model_name
+                        else ("", self.model_name)
+                    )
+                    san1 = self.model_name.replace("/", "_")
+                    san2 = self.model_name.replace("/", "__")
+                    san3 = self.model_name.replace("/", "--")
+                    # include plain 'name' for caches that drop org prefix
+                    name_variants = list(dict.fromkeys([self.model_name, san1, san2, san3, name]))
+                    candidates = []
+                    for base in [st_home, torch_st_home]:
+                        for nv in name_variants:
+                            p = os.path.join(base, nv)
+                            if os.path.isdir(p):
+                                candidates.append(p)
+                        # Fuzzy scan cache directory for entries that include variants
+                        try:
+                            for entry in os.listdir(base):
+                                ep = os.path.join(base, entry)
+                                if not os.path.isdir(ep):
+                                    continue
+                                if (
+                                    (org and entry.startswith(f"{org}__") and name in entry)
+                                    or (san1 in entry)
+                                    or (name in entry)
+                                ):
+                                    candidates.append(ep)
+                        except Exception:
+                            pass
+
+                    # Hugging Face Hub cache snapshots
+                    hf_cache = os.path.join(home, ".cache", "huggingface", "hub")
+                    if "/" in self.model_name:
+                        org, name = self.model_name.split("/", 1)
+                        models_dir = os.path.join(hf_cache, f"models--{org}--{name}", "snapshots")
+                        if os.path.isdir(models_dir):
+                            try:
+                                snaps = sorted(
+                                    [os.path.join(models_dir, d) for d in os.listdir(models_dir)],
+                                    key=lambda p: os.path.getmtime(p),
+                                    reverse=True,
+                                )
+                            except Exception:
+                                snaps = [os.path.join(models_dir, d) for d in os.listdir(models_dir)]
+                            for sp in snaps:
+                                if os.path.isdir(sp):
+                                    candidates.append(sp)
+                                    break
+
+                    for cand in candidates:
+                        try:
+                            return HuggingFaceEmbeddings(
+                                model_name=cand,
+                                model_kwargs=model_kwargs,
+                                encode_kwargs=encode_kwargs,
+                                show_progress=False,
+                            )
+                        except Exception:
+                            continue
+                except Exception:
+                    pass
+
+                try:
+                    # Try resolve local cached directory; do not hit network
+                    local_dir = snapshot_download(repo_id=self.model_name, local_files_only=True)
+                except Exception:
+                    local_dir = None
+
+                if local_dir:
+                    return HuggingFaceEmbeddings(
+                        model_name=local_dir,
+                        model_kwargs=model_kwargs,
+                        encode_kwargs=encode_kwargs,
+                        show_progress=False,
+                    )
+
+
+
+                # Fall back to remote download if local cache not found and not offline
+                return HuggingFaceEmbeddings(
+                    model_name=self.model_name,
+                    model_kwargs=model_kwargs,
+                    encode_kwargs=encode_kwargs,
+                    show_progress=True,
+                )
+            except Exception as _e:
+                # 如果已检测到本地候选路径(直接目录 / 本地缓存快照),则视为本地加载失败,
+                # 为避免在用户期望“本地优先不联网”的情况下触发联网,直接抛错并给出修复建议。
+                had_local_candidate = False
+                try:
+                    had_local_candidate = (
+                        os.path.isdir(self.model_name)
+                        # 如果上面 snapshot_download 命中了本地缓存,会将 local_dir 设为非 None
+                        or (locals().get("local_dir") is not None)
+                    )
+                except Exception:
+                    pass
+
+                if had_local_candidate:
+                    PrettyOutput.print(
+                        "检测到本地模型路径但加载失败。为避免触发网络访问,已中止远程回退。\n"
+                        "请确认本地目录包含完整的 Transformers/Tokenizer 文件(如 config.json、model.safetensors、tokenizer.json/merges.txt 等),\n"
+                        "或在配置中将 embedding_model 设置为该本地目录,或将模型放置到默认的 Hugging Face 缓存目录(例如 ~/.cache/huggingface/hub)。",
+                        OutputType.ERROR,
+                    )
+                    raise
+
+                # 未发现任何本地候选,则保持原有行为:回退至远程下载
+                return HuggingFaceEmbeddings(
+                    model_name=self.model_name,
+                    model_kwargs=model_kwargs,
+                    encode_kwargs=encode_kwargs,
+                    show_progress=True,
+                )
         except Exception as e:
             PrettyOutput.print(
                 f"加载嵌入模型 '{self.model_name}' 时出错: {e}", OutputType.ERROR
jarvis/jarvis_rag/llm_interface.py
CHANGED
@@ -47,13 +47,11 @@ class ToolAgent_LLM(LLMInterface):
         # 为代理提供一个通用的系统提示
         self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
         self.summary_prompt = """
-<report>
 请为本次问答任务生成一个总结报告,包含以下内容:
 
 1. **原始问题**: 重述用户最开始提出的问题。
 2. **关键信息来源**: 总结你是基于哪些关键信息或文件得出的结论。
 3. **最终答案**: 给出最终的、精炼的回答。
-</report>
 """
 
     def generate(self, prompt: str, **kwargs) -> str:
jarvis/jarvis_rag/rag_pipeline.py
CHANGED
@@ -34,6 +34,7 @@ class JarvisRAGPipeline:
         collection_name: str = "jarvis_rag_collection",
         use_bm25: bool = True,
         use_rerank: bool = True,
+        use_query_rewrite: bool = True,
     ):
         """
         初始化RAG管道。
@@ -69,6 +70,8 @@ class JarvisRAGPipeline:
         self.collection_name = collection_name
         self.use_bm25 = use_bm25
         self.use_rerank = use_rerank
+        # 查询重写开关(默认开启,可由CLI控制)
+        self.use_query_rewrite = use_query_rewrite
 
         # 延迟加载的组件
         self._embedding_manager: Optional[EmbeddingManager] = None
@@ -161,14 +164,15 @@ class JarvisRAGPipeline:
         if not changed and not deleted:
             return
         # 打印摘要
-
-
-
-
-
-
-
-
+        # 先拼接列表信息再统一打印,避免循环中逐条打印
+        lines = [
+            f"检测到索引可能不一致:变更 {len(changed)} 个,删除 {len(deleted)} 个。"
+        ]
+        if changed:
+            lines.extend([f"  变更: {p}" for p in changed[:3]])
+        if deleted:
+            lines.extend([f"  删除: {p}" for p in deleted[:3]])
+        PrettyOutput.print("\n".join(lines), OutputType.WARNING)
         # 询问用户
         if get_yes_no(
             "检测到索引变更,是否现在更新索引后再开始检索?", default=True
@@ -228,13 +232,23 @@ class JarvisRAGPipeline:
         """
         # 0. 检测索引变更并可选更新(在重写query之前)
         self._pre_search_update_index_if_needed()
-        # 1. 将原始查询重写为多个查询
-        rewritten_queries = self._get_query_rewriter().rewrite(query_text)
+        # 1. 将原始查询重写为多个查询(可配置)
+        if self.use_query_rewrite:
+            rewritten_queries = self._get_query_rewriter().rewrite(query_text)
+        else:
+            PrettyOutput.print(
+                "已关闭查询重写,将直接使用原始查询进行检索。",
+                OutputType.INFO,
+            )
+            rewritten_queries = [query_text]
 
         # 2. 为每个重写的查询检索初始候选文档
+        PrettyOutput.print(
+            "将为以下查询变体进行混合检索:\n" + "\n".join([f"  - {q}" for q in rewritten_queries]),
+            OutputType.INFO,
+        )
         all_candidate_docs = []
         for q in rewritten_queries:
-            PrettyOutput.print(f"正在为查询变体 '{q}' 进行混合检索...", OutputType.INFO)
             candidates = self._get_retriever().retrieve(
                 q, n_results=n_results * 2, use_bm25=self.use_bm25
             )
@@ -273,9 +287,9 @@ class JarvisRAGPipeline:
                 )
             )
         if sources:
-
-                for source in sources
-
+            # 合并来源列表后一次性打印,避免多次加框
+            lines = ["根据以下文档回答:"] + [f"  - {source}" for source in sources]
+            PrettyOutput.print("\n".join(lines), OutputType.INFO)
 
         # 4. 创建最终提示并生成答案
         # 我们使用原始的query_text作为给LLM的最终提示
@@ -299,13 +313,23 @@ class JarvisRAGPipeline:
         """
         # 0. 检测索引变更并可选更新(在重写query之前)
         self._pre_search_update_index_if_needed()
-        # 1. 重写查询
-        rewritten_queries = self._get_query_rewriter().rewrite(query_text)
+        # 1. 重写查询(可配置)
+        if self.use_query_rewrite:
+            rewritten_queries = self._get_query_rewriter().rewrite(query_text)
+        else:
+            PrettyOutput.print(
+                "已关闭查询重写,将直接使用原始查询进行检索。",
+                OutputType.INFO,
+            )
+            rewritten_queries = [query_text]
 
         # 2. 检索候选文档
+        PrettyOutput.print(
+            "将为以下查询变体进行混合检索:\n" + "\n".join([f"  - {q}" for q in rewritten_queries]),
+            OutputType.INFO,
+        )
         all_candidate_docs = []
         for q in rewritten_queries:
-            PrettyOutput.print(f"正在为查询变体 '{q}' 进行混合检索...", OutputType.INFO)
             candidates = self._get_retriever().retrieve(
                 q, n_results=n_results * 2, use_bm25=self.use_bm25
            )
jarvis/jarvis_rag/reranker.py
CHANGED
@@ -1,9 +1,11 @@
 from typing import List
+import os
 
 from langchain.docstore.document import Document
 from sentence_transformers.cross_encoder import (  # type: ignore
     CrossEncoder,
 )
+from huggingface_hub import snapshot_download
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
 
 
@@ -21,8 +23,28 @@ class Reranker:
             model_name (str): 要使用的Cross-Encoder模型的名称。
         """
         PrettyOutput.print(f"正在初始化重排模型: {model_name}...", OutputType.INFO)
-        self.model = CrossEncoder(model_name)
-        PrettyOutput.print("重排模型初始化成功。", OutputType.SUCCESS)
+        try:
+            local_dir = None
+
+            if os.path.isdir(model_name):
+                self.model = CrossEncoder(model_name)
+                PrettyOutput.print("重排模型初始化成功。", OutputType.SUCCESS)
+                return
+            try:
+                # Prefer local cache; avoid any network access
+                local_dir = snapshot_download(repo_id=model_name, local_files_only=True)
+            except Exception:
+                local_dir = None
+
+            if local_dir:
+                self.model = CrossEncoder(local_dir)
+            else:
+                self.model = CrossEncoder(model_name)
+
+            PrettyOutput.print("重排模型初始化成功。", OutputType.SUCCESS)
+        except Exception as e:
+            PrettyOutput.print(f"初始化重排模型失败: {e}", OutputType.ERROR)
+            raise
 
     def rerank(
         self, query: str, documents: List[Document], top_n: int = 5