jarvis-ai-assistant 0.1.220__py3-none-any.whl → 0.1.221__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +93 -382
- jarvis/jarvis_agent/edit_file_handler.py +32 -185
- jarvis/jarvis_agent/prompt_builder.py +57 -0
- jarvis/jarvis_agent/prompts.py +188 -0
- jarvis/jarvis_agent/protocols.py +30 -0
- jarvis/jarvis_agent/session_manager.py +84 -0
- jarvis/jarvis_agent/tool_executor.py +49 -0
- jarvis/jarvis_code_agent/code_agent.py +4 -4
- jarvis/jarvis_data/config_schema.json +8 -18
- jarvis/jarvis_rag/__init__.py +2 -2
- jarvis/jarvis_rag/cache.py +28 -30
- jarvis/jarvis_rag/cli.py +141 -52
- jarvis/jarvis_rag/embedding_manager.py +32 -46
- jarvis/jarvis_rag/llm_interface.py +32 -34
- jarvis/jarvis_rag/query_rewriter.py +11 -12
- jarvis/jarvis_rag/rag_pipeline.py +40 -43
- jarvis/jarvis_rag/reranker.py +18 -18
- jarvis/jarvis_rag/retriever.py +29 -29
- jarvis/jarvis_tools/edit_file.py +11 -36
- jarvis/jarvis_utils/config.py +10 -25
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/METADATA +15 -12
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/RECORD +27 -22
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/top_level.txt +0 -0
@@ -1,59 +1,45 @@
|
|
1
|
-
|
1
|
+
import torch
|
2
|
+
from typing import List, cast
|
2
3
|
from langchain_huggingface import HuggingFaceEmbeddings
|
3
4
|
|
4
|
-
from jarvis.jarvis_utils.config import (
|
5
|
-
get_rag_embedding_models,
|
6
|
-
get_rag_embedding_cache_path,
|
7
|
-
)
|
8
5
|
from .cache import EmbeddingCache
|
9
6
|
|
10
7
|
|
11
8
|
class EmbeddingManager:
|
12
9
|
"""
|
13
|
-
|
10
|
+
管理本地嵌入模型的加载和使用,并带有缓存功能。
|
14
11
|
|
15
|
-
|
16
|
-
|
17
|
-
and uses a disk-based cache to avoid re-computing embeddings for the
|
18
|
-
same text.
|
12
|
+
该类负责从Hugging Face加载指定的模型,并使用基于磁盘的缓存
|
13
|
+
来避免为相同文本重新计算嵌入。
|
19
14
|
"""
|
20
15
|
|
21
|
-
def __init__(
|
22
|
-
self,
|
23
|
-
mode: Literal["performance", "accuracy"],
|
24
|
-
cache_dir: str,
|
25
|
-
):
|
16
|
+
def __init__(self, model_name: str, cache_dir: str):
|
26
17
|
"""
|
27
|
-
|
18
|
+
初始化EmbeddingManager。
|
28
19
|
|
29
|
-
|
30
|
-
|
31
|
-
cache_dir:
|
20
|
+
参数:
|
21
|
+
model_name: 要加载的Hugging Face模型的名称。
|
22
|
+
cache_dir: 用于存储嵌入缓存的目录。
|
32
23
|
"""
|
33
|
-
self.
|
34
|
-
self.embedding_models = get_rag_embedding_models()
|
35
|
-
if mode not in self.embedding_models:
|
36
|
-
raise ValueError(
|
37
|
-
f"Invalid mode '{mode}'. Must be one of {list(self.embedding_models.keys())}"
|
38
|
-
)
|
39
|
-
|
40
|
-
self.model_config = self.embedding_models[self.mode]
|
41
|
-
self.model_name = self.model_config["model_name"]
|
24
|
+
self.model_name = model_name
|
42
25
|
|
43
|
-
print(f"🚀
|
26
|
+
print(f"🚀 初始化嵌入管理器, 模型: '{self.model_name}'...")
|
44
27
|
|
45
|
-
#
|
46
|
-
self.cache = EmbeddingCache(cache_dir=cache_dir, salt=
|
28
|
+
# 缓存的salt是模型名称,以防止冲突
|
29
|
+
self.cache = EmbeddingCache(cache_dir=cache_dir, salt=self.model_name)
|
47
30
|
self.model = self._load_model()
|
48
31
|
|
49
32
|
def _load_model(self) -> HuggingFaceEmbeddings:
|
50
|
-
"""
|
33
|
+
"""根据配置加载Hugging Face嵌入模型。"""
|
34
|
+
model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
|
35
|
+
encode_kwargs = {"normalize_embeddings": True}
|
36
|
+
|
51
37
|
try:
|
52
38
|
return HuggingFaceEmbeddings(
|
53
39
|
model_name=self.model_name,
|
54
|
-
model_kwargs=
|
55
|
-
encode_kwargs=
|
56
|
-
show_progress=
|
40
|
+
model_kwargs=model_kwargs,
|
41
|
+
encode_kwargs=encode_kwargs,
|
42
|
+
show_progress=True,
|
57
43
|
)
|
58
44
|
except Exception as e:
|
59
45
|
print(f"❌ 加载嵌入模型 '{self.model_name}' 时出错: {e}")
|
@@ -62,18 +48,18 @@ class EmbeddingManager:
|
|
62
48
|
|
63
49
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
64
50
|
"""
|
65
|
-
|
51
|
+
使用缓存为文档列表计算嵌入。
|
66
52
|
|
67
|
-
|
68
|
-
texts:
|
53
|
+
参数:
|
54
|
+
texts: 要嵌入的文档(字符串)列表。
|
69
55
|
|
70
|
-
|
71
|
-
|
56
|
+
返回:
|
57
|
+
一个嵌入列表,每个文档对应一个嵌入。
|
72
58
|
"""
|
73
59
|
if not texts:
|
74
60
|
return []
|
75
61
|
|
76
|
-
#
|
62
|
+
# 检查缓存中是否已存在嵌入
|
77
63
|
cached_embeddings = self.cache.get_batch(texts)
|
78
64
|
|
79
65
|
texts_to_embed = []
|
@@ -83,17 +69,17 @@ class EmbeddingManager:
|
|
83
69
|
texts_to_embed.append(text)
|
84
70
|
indices_to_embed.append(i)
|
85
71
|
|
86
|
-
#
|
72
|
+
# 为不在缓存中的文本计算嵌入
|
87
73
|
if texts_to_embed:
|
88
74
|
print(
|
89
75
|
f"🔎 缓存未命中。正在为 {len(texts_to_embed)}/{len(texts)} 个文档计算嵌入。"
|
90
76
|
)
|
91
77
|
new_embeddings = self.model.embed_documents(texts_to_embed)
|
92
78
|
|
93
|
-
#
|
79
|
+
# 将新的嵌入存储在缓存中
|
94
80
|
self.cache.set_batch(texts_to_embed, new_embeddings)
|
95
81
|
|
96
|
-
#
|
82
|
+
# 将新的嵌入放回结果列表中
|
97
83
|
for i, embedding in zip(indices_to_embed, new_embeddings):
|
98
84
|
cached_embeddings[i] = embedding
|
99
85
|
else:
|
@@ -103,7 +89,7 @@ class EmbeddingManager:
|
|
103
89
|
|
104
90
|
def embed_query(self, text: str) -> List[float]:
|
105
91
|
"""
|
106
|
-
|
107
|
-
|
92
|
+
为单个查询计算嵌入。
|
93
|
+
查询通常不被缓存,但如果需要可以添加。
|
108
94
|
"""
|
109
95
|
return self.model.embed_query(text)
|
@@ -10,42 +10,40 @@ from jarvis.jarvis_platform.registry import PlatformRegistry
|
|
10
10
|
|
11
11
|
class LLMInterface(ABC):
|
12
12
|
"""
|
13
|
-
|
13
|
+
大型语言模型接口的抽象基类。
|
14
14
|
|
15
|
-
|
16
|
-
|
17
|
-
subclass of this interface.
|
15
|
+
该类定义了与远程LLM交互的标准接口。
|
16
|
+
任何LLM提供商(如OpenAI、Anthropic等)都应作为该接口的子类来实现。
|
18
17
|
"""
|
19
18
|
|
20
19
|
@abstractmethod
|
21
20
|
def generate(self, prompt: str, **kwargs) -> str:
|
22
21
|
"""
|
23
|
-
|
22
|
+
根据给定的提示从LLM生成响应。
|
24
23
|
|
25
|
-
|
26
|
-
prompt:
|
27
|
-
**kwargs:
|
28
|
-
|
24
|
+
参数:
|
25
|
+
prompt: 发送给LLM的输入提示。
|
26
|
+
**kwargs: LLM API调用的其他关键字参数
|
27
|
+
(例如,temperature, max_tokens)。
|
29
28
|
|
30
|
-
|
31
|
-
|
29
|
+
返回:
|
30
|
+
由LLM生成的文本响应。
|
32
31
|
"""
|
33
32
|
pass
|
34
33
|
|
35
34
|
|
36
35
|
class ToolAgent_LLM(LLMInterface):
|
37
36
|
"""
|
38
|
-
|
39
|
-
to generate the final response.
|
37
|
+
LLMInterface的一个实现,它使用一个能操作工具的JarvisAgent来生成最终响应。
|
40
38
|
"""
|
41
39
|
|
42
40
|
def __init__(self):
|
43
41
|
"""
|
44
|
-
|
42
|
+
初始化工具-代理 LLM 包装器。
|
45
43
|
"""
|
46
44
|
print("🤖 已初始化工具 Agent 作为最终应答者。")
|
47
45
|
self.allowed_tools = ["read_code", "execute_script"]
|
48
|
-
#
|
46
|
+
# 为代理提供一个通用的系统提示
|
49
47
|
self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
|
50
48
|
self.summary_prompt = """
|
51
49
|
<report>
|
@@ -59,17 +57,17 @@ class ToolAgent_LLM(LLMInterface):
|
|
59
57
|
|
60
58
|
def generate(self, prompt: str, **kwargs) -> str:
|
61
59
|
"""
|
62
|
-
|
60
|
+
使用受限的工具集运行JarvisAgent以生成答案。
|
63
61
|
|
64
|
-
|
65
|
-
prompt:
|
66
|
-
**kwargs:
|
62
|
+
参数:
|
63
|
+
prompt: 要发送给代理的完整提示,包括上下文。
|
64
|
+
**kwargs: 已忽略,为保持接口兼容性而保留。
|
67
65
|
|
68
|
-
|
69
|
-
|
66
|
+
返回:
|
67
|
+
由代理生成的最终答案。
|
70
68
|
"""
|
71
69
|
try:
|
72
|
-
#
|
70
|
+
# 使用RAG上下文的特定设置初始化代理
|
73
71
|
agent = JarvisAgent(
|
74
72
|
system_prompt=self.system_prompt,
|
75
73
|
use_tools=self.allowed_tools,
|
@@ -80,7 +78,7 @@ class ToolAgent_LLM(LLMInterface):
|
|
80
78
|
summary_prompt=self.summary_prompt,
|
81
79
|
)
|
82
80
|
|
83
|
-
#
|
81
|
+
# 代理的run方法需要'user_input'参数
|
84
82
|
final_answer = agent.run(user_input=prompt)
|
85
83
|
return str(final_answer)
|
86
84
|
|
@@ -91,21 +89,21 @@ class ToolAgent_LLM(LLMInterface):
|
|
91
89
|
|
92
90
|
class JarvisPlatform_LLM(LLMInterface):
|
93
91
|
"""
|
94
|
-
|
92
|
+
项目内部平台的LLMInterface实现。
|
95
93
|
|
96
|
-
|
94
|
+
该类使用PlatformRegistry来获取配置的“普通”模型。
|
97
95
|
"""
|
98
96
|
|
99
97
|
def __init__(self):
|
100
98
|
"""
|
101
|
-
|
99
|
+
初始化Jarvis平台LLM客户端。
|
102
100
|
"""
|
103
101
|
try:
|
104
102
|
self.registry = PlatformRegistry.get_global_platform_registry()
|
105
103
|
self.platform: BasePlatform = self.registry.get_normal_platform()
|
106
104
|
self.platform.set_suppress_output(
|
107
105
|
False
|
108
|
-
) #
|
106
|
+
) # 确保模型没有控制台输出
|
109
107
|
print(f"🚀 已初始化 Jarvis 平台 LLM,模型: {self.platform.name()}")
|
110
108
|
except Exception as e:
|
111
109
|
print(f"❌ 初始化 Jarvis 平台 LLM 失败: {e}")
|
@@ -113,17 +111,17 @@ class JarvisPlatform_LLM(LLMInterface):
|
|
113
111
|
|
114
112
|
def generate(self, prompt: str, **kwargs) -> str:
|
115
113
|
"""
|
116
|
-
|
114
|
+
向本地平台模型发送提示并返回响应。
|
117
115
|
|
118
|
-
|
119
|
-
prompt:
|
120
|
-
**kwargs:
|
116
|
+
参数:
|
117
|
+
prompt: 用户的提示。
|
118
|
+
**kwargs: 已忽略,为保持接口兼容性而保留。
|
121
119
|
|
122
|
-
|
123
|
-
|
120
|
+
返回:
|
121
|
+
由平台模型生成的响应。
|
124
122
|
"""
|
125
123
|
try:
|
126
|
-
#
|
124
|
+
# 使用健壮的chat_until_success方法
|
127
125
|
return self.platform.chat_until_success(prompt)
|
128
126
|
except Exception as e:
|
129
127
|
print(f"❌ 调用 Jarvis 平台模型时发生错误: {e}")
|
@@ -4,22 +4,21 @@ from .llm_interface import LLMInterface
|
|
4
4
|
|
5
5
|
class QueryRewriter:
|
6
6
|
"""
|
7
|
-
|
8
|
-
queries to enhance retrieval recall.
|
7
|
+
使用LLM将用户的查询重写为多个不同的搜索查询,以提高检索召回率。
|
9
8
|
"""
|
10
9
|
|
11
10
|
def __init__(self, llm: LLMInterface):
|
12
11
|
"""
|
13
|
-
|
12
|
+
初始化QueryRewriter。
|
14
13
|
|
15
|
-
|
16
|
-
llm:
|
14
|
+
参数:
|
15
|
+
llm: 实现LLMInterface接口的类的实例。
|
17
16
|
"""
|
18
17
|
self.llm = llm
|
19
18
|
self.rewrite_prompt_template = self._create_prompt_template()
|
20
19
|
|
21
20
|
def _create_prompt_template(self) -> str:
|
22
|
-
"""
|
21
|
+
"""为多查询重写任务创建提示模板。"""
|
23
22
|
return """
|
24
23
|
你是一个精通检索的AI助手。你的任务是将以下这个单一的用户问题,从不同角度改写成 3 个不同的、但语义上相关的搜索查询。这有助于在知识库中进行更全面的搜索。
|
25
24
|
|
@@ -39,13 +38,13 @@ class QueryRewriter:
|
|
39
38
|
|
40
39
|
def rewrite(self, query: str) -> List[str]:
|
41
40
|
"""
|
42
|
-
|
41
|
+
使用LLM将用户查询重写为多个查询。
|
43
42
|
|
44
|
-
|
45
|
-
query:
|
43
|
+
参数:
|
44
|
+
query: 原始用户查询。
|
46
45
|
|
47
|
-
|
48
|
-
|
46
|
+
返回:
|
47
|
+
一个经过重写、搜索优化的查询列表。
|
49
48
|
"""
|
50
49
|
prompt = self.rewrite_prompt_template.format(query=query)
|
51
50
|
print(f"✍️ 正在将原始查询重写为多个搜索查询...")
|
@@ -55,7 +54,7 @@ class QueryRewriter:
|
|
55
54
|
line.strip() for line in response_text.strip().split("\n") if line.strip()
|
56
55
|
]
|
57
56
|
|
58
|
-
#
|
57
|
+
# 同时包含原始查询以保证鲁棒性
|
59
58
|
if query not in rewritten_queries:
|
60
59
|
rewritten_queries.insert(0, query)
|
61
60
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import List,
|
2
|
+
from typing import List, Optional
|
3
3
|
|
4
4
|
from langchain.docstore.document import Document
|
5
5
|
|
@@ -9,57 +9,55 @@ from .query_rewriter import QueryRewriter
|
|
9
9
|
from .reranker import Reranker
|
10
10
|
from .retriever import ChromaRetriever
|
11
11
|
from jarvis.jarvis_utils.config import (
|
12
|
-
|
12
|
+
get_rag_embedding_model,
|
13
|
+
get_rag_rerank_model,
|
13
14
|
get_rag_vector_db_path,
|
14
15
|
get_rag_embedding_cache_path,
|
15
|
-
get_rag_embedding_models,
|
16
16
|
)
|
17
17
|
|
18
18
|
|
19
19
|
class JarvisRAGPipeline:
|
20
20
|
"""
|
21
|
-
|
21
|
+
RAG管道的主要协调器。
|
22
22
|
|
23
|
-
|
24
|
-
|
23
|
+
该类集成了嵌入管理器、检索器和LLM,为添加文档和查询
|
24
|
+
提供了一个完整的管道。
|
25
25
|
"""
|
26
26
|
|
27
27
|
def __init__(
|
28
28
|
self,
|
29
29
|
llm: Optional[LLMInterface] = None,
|
30
|
-
|
30
|
+
embedding_model: Optional[str] = None,
|
31
31
|
db_path: Optional[str] = None,
|
32
32
|
collection_name: str = "jarvis_rag_collection",
|
33
33
|
):
|
34
34
|
"""
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
llm:
|
39
|
-
|
40
|
-
|
41
|
-
db_path:
|
42
|
-
collection_name:
|
35
|
+
初始化RAG管道。
|
36
|
+
|
37
|
+
参数:
|
38
|
+
llm: 实现LLMInterface接口的类的实例。
|
39
|
+
如果为None,则默认为ToolAgent_LLM。
|
40
|
+
embedding_model: 嵌入模型的名称。如果为None,则使用配置值。
|
41
|
+
db_path: 持久化向量数据库的路径。如果为None,则使用配置值。
|
42
|
+
collection_name: 向量数据库中集合的名称。
|
43
43
|
"""
|
44
|
-
#
|
45
|
-
|
46
|
-
embedding_models = get_rag_embedding_models()
|
47
|
-
model_name = embedding_models[_embedding_mode]["model_name"]
|
44
|
+
# 确定嵌入模型以隔离数据路径
|
45
|
+
model_name = embedding_model or get_rag_embedding_model()
|
48
46
|
sanitized_model_name = model_name.replace("/", "_").replace("\\", "_")
|
49
47
|
|
50
|
-
#
|
48
|
+
# 如果给定了特定的db_path,则使用它。否则,创建一个特定于模型的路径。
|
51
49
|
_final_db_path = (
|
52
50
|
str(db_path)
|
53
51
|
if db_path
|
54
52
|
else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
|
55
53
|
)
|
56
|
-
#
|
54
|
+
# 始终创建一个特定于模型的缓存路径。
|
57
55
|
_final_cache_path = os.path.join(
|
58
56
|
get_rag_embedding_cache_path(), sanitized_model_name
|
59
57
|
)
|
60
58
|
|
61
59
|
self.embedding_manager = EmbeddingManager(
|
62
|
-
|
60
|
+
model_name=model_name,
|
63
61
|
cache_dir=_final_cache_path,
|
64
62
|
)
|
65
63
|
self.retriever = ChromaRetriever(
|
@@ -67,27 +65,27 @@ class JarvisRAGPipeline:
|
|
67
65
|
db_path=_final_db_path,
|
68
66
|
collection_name=collection_name,
|
69
67
|
)
|
70
|
-
#
|
68
|
+
# 除非提供了特定的LLM,否则默认为ToolAgent_LLM
|
71
69
|
self.llm = llm if llm is not None else ToolAgent_LLM()
|
72
|
-
self.reranker = Reranker()
|
73
|
-
#
|
70
|
+
self.reranker = Reranker(model_name=get_rag_rerank_model())
|
71
|
+
# 使用标准LLM执行查询重写任务,而不是代理
|
74
72
|
self.query_rewriter = QueryRewriter(JarvisPlatform_LLM())
|
75
73
|
|
76
74
|
print("✅ JarvisRAGPipeline 初始化成功。")
|
77
75
|
|
78
76
|
def add_documents(self, documents: List[Document]):
|
79
77
|
"""
|
80
|
-
|
78
|
+
将文档添加到向量知识库。
|
81
79
|
|
82
|
-
|
83
|
-
documents:
|
80
|
+
参数:
|
81
|
+
documents: 要添加的LangChain文档对象列表。
|
84
82
|
"""
|
85
83
|
self.retriever.add_documents(documents)
|
86
84
|
|
87
85
|
def _create_prompt(
|
88
86
|
self, query: str, context_docs: List[Document], source_files: List[str]
|
89
87
|
) -> str:
|
90
|
-
"""
|
88
|
+
"""为LLM或代理创建最终的提示。"""
|
91
89
|
context = "\n\n".join([doc.page_content for doc in context_docs])
|
92
90
|
sources_text = "\n".join([f"- {source}" for source in source_files])
|
93
91
|
|
@@ -114,34 +112,33 @@ class JarvisRAGPipeline:
|
|
114
112
|
|
115
113
|
def query(self, query_text: str, n_results: int = 5) -> str:
|
116
114
|
"""
|
117
|
-
|
118
|
-
retrieval and reranking pipeline.
|
115
|
+
使用多查询检索和重排管道对知识库执行查询。
|
119
116
|
|
120
|
-
|
121
|
-
query_text:
|
122
|
-
n_results:
|
117
|
+
参数:
|
118
|
+
query_text: 用户的原始问题。
|
119
|
+
n_results: 要检索的最终相关块的数量。
|
123
120
|
|
124
|
-
|
125
|
-
|
121
|
+
返回:
|
122
|
+
由LLM生成的答案。
|
126
123
|
"""
|
127
|
-
# 1.
|
124
|
+
# 1. 将原始查询重写为多个查询
|
128
125
|
rewritten_queries = self.query_rewriter.rewrite(query_text)
|
129
126
|
|
130
|
-
# 2.
|
127
|
+
# 2. 为每个重写的查询检索初始候选文档
|
131
128
|
all_candidate_docs = []
|
132
129
|
for q in rewritten_queries:
|
133
130
|
print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
|
134
131
|
candidates = self.retriever.retrieve(q, n_results=n_results * 2)
|
135
132
|
all_candidate_docs.extend(candidates)
|
136
133
|
|
137
|
-
#
|
134
|
+
# 对候选文档进行去重
|
138
135
|
unique_docs_dict = {doc.page_content: doc for doc in all_candidate_docs}
|
139
136
|
unique_candidate_docs = list(unique_docs_dict.values())
|
140
137
|
|
141
138
|
if not unique_candidate_docs:
|
142
139
|
return "我在提供的文档中找不到任何相关信息来回答您的问题。"
|
143
140
|
|
144
|
-
# 3.
|
141
|
+
# 3. 根据*原始*查询对统一的候选池进行重排
|
145
142
|
print(
|
146
143
|
f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)..."
|
147
144
|
)
|
@@ -152,7 +149,7 @@ class JarvisRAGPipeline:
|
|
152
149
|
if not retrieved_docs:
|
153
150
|
return "我在提供的文档中找不到任何相关信息来回答您的问题。"
|
154
151
|
|
155
|
-
#
|
152
|
+
# 打印最终检索到的文档的来源
|
156
153
|
sources = sorted(
|
157
154
|
list(
|
158
155
|
{
|
@@ -167,8 +164,8 @@ class JarvisRAGPipeline:
|
|
167
164
|
for source in sources:
|
168
165
|
print(f" - {source}")
|
169
166
|
|
170
|
-
# 4.
|
171
|
-
#
|
167
|
+
# 4. 创建最终提示并生成答案
|
168
|
+
# 我们使用原始的query_text作为给LLM的最终提示
|
172
169
|
prompt = self._create_prompt(query_text, retrieved_docs, sources)
|
173
170
|
|
174
171
|
print("🤖 正在从LLM生成答案...")
|
jarvis/jarvis_rag/reranker.py
CHANGED
@@ -8,16 +8,16 @@ from sentence_transformers.cross_encoder import ( # type: ignore
|
|
8
8
|
|
9
9
|
class Reranker:
|
10
10
|
"""
|
11
|
-
|
12
|
-
|
11
|
+
一个重排器类,使用Cross-Encoder模型根据文档与给定查询的相关性
|
12
|
+
对文档进行重新评分和排序。
|
13
13
|
"""
|
14
14
|
|
15
|
-
def __init__(self, model_name: str
|
15
|
+
def __init__(self, model_name: str):
|
16
16
|
"""
|
17
|
-
|
17
|
+
初始化重排器。
|
18
18
|
|
19
|
-
|
20
|
-
model_name (str):
|
19
|
+
参数:
|
20
|
+
model_name (str): 要使用的Cross-Encoder模型的名称。
|
21
21
|
"""
|
22
22
|
print(f"🔍 正在初始化重排模型: {model_name}...")
|
23
23
|
self.model = CrossEncoder(model_name)
|
@@ -27,30 +27,30 @@ class Reranker:
|
|
27
27
|
self, query: str, documents: List[Document], top_n: int = 5
|
28
28
|
) -> List[Document]:
|
29
29
|
"""
|
30
|
-
|
30
|
+
根据文档与查询的相关性对文档列表进行重排。
|
31
31
|
|
32
|
-
|
33
|
-
query (str):
|
34
|
-
documents (List[Document]):
|
35
|
-
top_n (int):
|
32
|
+
参数:
|
33
|
+
query (str): 用户的查询。
|
34
|
+
documents (List[Document]): 从初始搜索中检索到的文档列表。
|
35
|
+
top_n (int): 重排后要返回的顶部文档数。
|
36
36
|
|
37
|
-
|
38
|
-
List[Document]:
|
37
|
+
返回:
|
38
|
+
List[Document]: 一个已排序的最相关文档列表。
|
39
39
|
"""
|
40
40
|
if not documents:
|
41
41
|
return []
|
42
42
|
|
43
|
-
#
|
43
|
+
# 创建 [查询, 文档内容] 对用于评分
|
44
44
|
pairs = [[query, doc.page_content] for doc in documents]
|
45
45
|
|
46
|
-
#
|
46
|
+
# 从Cross-Encoder模型获取分数
|
47
47
|
scores = self.model.predict(pairs)
|
48
48
|
|
49
|
-
#
|
49
|
+
# 将文档与它们的分数结合并排序
|
50
50
|
doc_with_scores = list(zip(documents, scores))
|
51
|
-
doc_with_scores.sort(key=lambda x: x[1], reverse=True)
|
51
|
+
doc_with_scores.sort(key=lambda x: x[1], reverse=True) # type: ignore
|
52
52
|
|
53
|
-
#
|
53
|
+
# 返回前N个文档
|
54
54
|
reranked_docs = [doc for doc, score in doc_with_scores[:top_n]]
|
55
55
|
|
56
56
|
return reranked_docs
|