jarvis-ai-assistant 0.1.219__py3-none-any.whl → 0.1.221__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +98 -440
  3. jarvis/jarvis_agent/edit_file_handler.py +32 -185
  4. jarvis/jarvis_agent/prompt_builder.py +57 -0
  5. jarvis/jarvis_agent/prompts.py +188 -0
  6. jarvis/jarvis_agent/protocols.py +30 -0
  7. jarvis/jarvis_agent/session_manager.py +84 -0
  8. jarvis/jarvis_agent/tool_executor.py +49 -0
  9. jarvis/jarvis_code_agent/code_agent.py +4 -4
  10. jarvis/jarvis_data/config_schema.json +20 -0
  11. jarvis/jarvis_platform/yuanbao.py +3 -1
  12. jarvis/jarvis_rag/__init__.py +11 -0
  13. jarvis/jarvis_rag/cache.py +85 -0
  14. jarvis/jarvis_rag/cli.py +386 -0
  15. jarvis/jarvis_rag/embedding_manager.py +95 -0
  16. jarvis/jarvis_rag/llm_interface.py +128 -0
  17. jarvis/jarvis_rag/query_rewriter.py +62 -0
  18. jarvis/jarvis_rag/rag_pipeline.py +174 -0
  19. jarvis/jarvis_rag/reranker.py +56 -0
  20. jarvis/jarvis_rag/retriever.py +201 -0
  21. jarvis/jarvis_tools/edit_file.py +11 -36
  22. jarvis/jarvis_utils/config.py +56 -0
  23. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/METADATA +90 -8
  24. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/RECORD +28 -14
  25. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/entry_points.txt +1 -0
  26. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/WHEEL +0 -0
  27. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/licenses/LICENSE +0 -0
  28. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,128 @@
1
+ from abc import ABC, abstractmethod
2
+ import os
3
+ import os
4
+ from abc import ABC, abstractmethod
5
+
6
+ from jarvis.jarvis_agent import Agent as JarvisAgent
7
+ from jarvis.jarvis_platform.base import BasePlatform
8
+ from jarvis.jarvis_platform.registry import PlatformRegistry
9
+
10
+
11
+ class LLMInterface(ABC):
12
+ """
13
+ 大型语言模型接口的抽象基类。
14
+
15
+ 该类定义了与远程LLM交互的标准接口。
16
+ 任何LLM提供商(如OpenAI、Anthropic等)都应作为该接口的子类来实现。
17
+ """
18
+
19
+ @abstractmethod
20
+ def generate(self, prompt: str, **kwargs) -> str:
21
+ """
22
+ 根据给定的提示从LLM生成响应。
23
+
24
+ 参数:
25
+ prompt: 发送给LLM的输入提示。
26
+ **kwargs: LLM API调用的其他关键字参数
27
+ (例如,temperature, max_tokens)。
28
+
29
+ 返回:
30
+ 由LLM生成的文本响应。
31
+ """
32
+ pass
33
+
34
+
35
+ class ToolAgent_LLM(LLMInterface):
36
+ """
37
+ LLMInterface的一个实现,它使用一个能操作工具的JarvisAgent来生成最终响应。
38
+ """
39
+
40
+ def __init__(self):
41
+ """
42
+ 初始化工具-代理 LLM 包装器。
43
+ """
44
+ print("🤖 已初始化工具 Agent 作为最终应答者。")
45
+ self.allowed_tools = ["read_code", "execute_script"]
46
+ # 为代理提供一个通用的系统提示
47
+ self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
48
+ self.summary_prompt = """
49
+ <report>
50
+ 请为本次问答任务生成一个总结报告,包含以下内容:
51
+
52
+ 1. **原始问题**: 重述用户最开始提出的问题。
53
+ 2. **关键信息来源**: 总结你是基于哪些关键信息或文件得出的结论。
54
+ 3. **最终答案**: 给出最终的、精炼的回答。
55
+ </report>
56
+ """
57
+
58
+ def generate(self, prompt: str, **kwargs) -> str:
59
+ """
60
+ 使用受限的工具集运行JarvisAgent以生成答案。
61
+
62
+ 参数:
63
+ prompt: 要发送给代理的完整提示,包括上下文。
64
+ **kwargs: 已忽略,为保持接口兼容性而保留。
65
+
66
+ 返回:
67
+ 由代理生成的最终答案。
68
+ """
69
+ try:
70
+ # 使用RAG上下文的特定设置初始化代理
71
+ agent = JarvisAgent(
72
+ system_prompt=self.system_prompt,
73
+ use_tools=self.allowed_tools,
74
+ auto_complete=True,
75
+ use_methodology=False,
76
+ use_analysis=False,
77
+ need_summary=True,
78
+ summary_prompt=self.summary_prompt,
79
+ )
80
+
81
+ # 代理的run方法需要'user_input'参数
82
+ final_answer = agent.run(user_input=prompt)
83
+ return str(final_answer)
84
+
85
+ except Exception as e:
86
+ print(f"❌ Agent 在执行过程中发生错误: {e}")
87
+ return "错误: Agent 未能成功生成回答。"
88
+
89
+
90
+ class JarvisPlatform_LLM(LLMInterface):
91
+ """
92
+ 项目内部平台的LLMInterface实现。
93
+
94
+ 该类使用PlatformRegistry来获取配置的“普通”模型。
95
+ """
96
+
97
+ def __init__(self):
98
+ """
99
+ 初始化Jarvis平台LLM客户端。
100
+ """
101
+ try:
102
+ self.registry = PlatformRegistry.get_global_platform_registry()
103
+ self.platform: BasePlatform = self.registry.get_normal_platform()
104
+ self.platform.set_suppress_output(
105
+ False
106
+ ) # 确保模型没有控制台输出
107
+ print(f"🚀 已初始化 Jarvis 平台 LLM,模型: {self.platform.name()}")
108
+ except Exception as e:
109
+ print(f"❌ 初始化 Jarvis 平台 LLM 失败: {e}")
110
+ raise
111
+
112
+ def generate(self, prompt: str, **kwargs) -> str:
113
+ """
114
+ 向本地平台模型发送提示并返回响应。
115
+
116
+ 参数:
117
+ prompt: 用户的提示。
118
+ **kwargs: 已忽略,为保持接口兼容性而保留。
119
+
120
+ 返回:
121
+ 由平台模型生成的响应。
122
+ """
123
+ try:
124
+ # 使用健壮的chat_until_success方法
125
+ return self.platform.chat_until_success(prompt)
126
+ except Exception as e:
127
+ print(f"❌ 调用 Jarvis 平台模型时发生错误: {e}")
128
+ return "错误: 无法从本地LLM获取响应。"
@@ -0,0 +1,62 @@
1
+ from typing import List
2
+ from .llm_interface import LLMInterface
3
+
4
+
5
+ class QueryRewriter:
6
+ """
7
+ 使用LLM将用户的查询重写为多个不同的搜索查询,以提高检索召回率。
8
+ """
9
+
10
+ def __init__(self, llm: LLMInterface):
11
+ """
12
+ 初始化QueryRewriter。
13
+
14
+ 参数:
15
+ llm: 实现LLMInterface接口的类的实例。
16
+ """
17
+ self.llm = llm
18
+ self.rewrite_prompt_template = self._create_prompt_template()
19
+
20
+ def _create_prompt_template(self) -> str:
21
+ """为多查询重写任务创建提示模板。"""
22
+ return """
23
+ 你是一个精通检索的AI助手。你的任务是将以下这个单一的用户问题,从不同角度改写成 3 个不同的、但语义上相关的搜索查询。这有助于在知识库中进行更全面的搜索。
24
+
25
+ 请遵循以下原则:
26
+ 1. **多样性**:生成的查询应尝试使用不同的关键词和表述方式。
27
+ 2. **保留核心意图**:所有查询都必须围绕原始问题的核心意图。
28
+ 3. **简洁性**:每个查询都应该是独立的、可以直接用于搜索的短语或问题。
29
+ 4. **格式要求**:请直接输出 3 个查询,每个查询占一行,用换行符分隔。不要添加任何编号、前缀或解释。
30
+
31
+ 原始问题:
32
+ ---
33
+ {query}
34
+ ---
35
+
36
+ 3个改写后的查询 (每行一个):
37
+ """
38
+
39
+ def rewrite(self, query: str) -> List[str]:
40
+ """
41
+ 使用LLM将用户查询重写为多个查询。
42
+
43
+ 参数:
44
+ query: 原始用户查询。
45
+
46
+ 返回:
47
+ 一个经过重写、搜索优化的查询列表。
48
+ """
49
+ prompt = self.rewrite_prompt_template.format(query=query)
50
+ print(f"✍️ 正在将原始查询重写为多个搜索查询...")
51
+
52
+ response_text = self.llm.generate(prompt)
53
+ rewritten_queries = [
54
+ line.strip() for line in response_text.strip().split("\n") if line.strip()
55
+ ]
56
+
57
+ # 同时包含原始查询以保证鲁棒性
58
+ if query not in rewritten_queries:
59
+ rewritten_queries.insert(0, query)
60
+
61
+ print(f"✅ 生成了 {len(rewritten_queries)} 个查询变体。")
62
+ return rewritten_queries
@@ -0,0 +1,174 @@
1
+ import os
2
+ from typing import List, Optional
3
+
4
+ from langchain.docstore.document import Document
5
+
6
+ from .embedding_manager import EmbeddingManager
7
+ from .llm_interface import JarvisPlatform_LLM, LLMInterface, ToolAgent_LLM
8
+ from .query_rewriter import QueryRewriter
9
+ from .reranker import Reranker
10
+ from .retriever import ChromaRetriever
11
+ from jarvis.jarvis_utils.config import (
12
+ get_rag_embedding_model,
13
+ get_rag_rerank_model,
14
+ get_rag_vector_db_path,
15
+ get_rag_embedding_cache_path,
16
+ )
17
+
18
+
19
+ class JarvisRAGPipeline:
20
+ """
21
+ RAG管道的主要协调器。
22
+
23
+ 该类集成了嵌入管理器、检索器和LLM,为添加文档和查询
24
+ 提供了一个完整的管道。
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ llm: Optional[LLMInterface] = None,
30
+ embedding_model: Optional[str] = None,
31
+ db_path: Optional[str] = None,
32
+ collection_name: str = "jarvis_rag_collection",
33
+ ):
34
+ """
35
+ 初始化RAG管道。
36
+
37
+ 参数:
38
+ llm: 实现LLMInterface接口的类的实例。
39
+ 如果为None,则默认为ToolAgent_LLM。
40
+ embedding_model: 嵌入模型的名称。如果为None,则使用配置值。
41
+ db_path: 持久化向量数据库的路径。如果为None,则使用配置值。
42
+ collection_name: 向量数据库中集合的名称。
43
+ """
44
+ # 确定嵌入模型以隔离数据路径
45
+ model_name = embedding_model or get_rag_embedding_model()
46
+ sanitized_model_name = model_name.replace("/", "_").replace("\\", "_")
47
+
48
+ # 如果给定了特定的db_path,则使用它。否则,创建一个特定于模型的路径。
49
+ _final_db_path = (
50
+ str(db_path)
51
+ if db_path
52
+ else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
53
+ )
54
+ # 始终创建一个特定于模型的缓存路径。
55
+ _final_cache_path = os.path.join(
56
+ get_rag_embedding_cache_path(), sanitized_model_name
57
+ )
58
+
59
+ self.embedding_manager = EmbeddingManager(
60
+ model_name=model_name,
61
+ cache_dir=_final_cache_path,
62
+ )
63
+ self.retriever = ChromaRetriever(
64
+ embedding_manager=self.embedding_manager,
65
+ db_path=_final_db_path,
66
+ collection_name=collection_name,
67
+ )
68
+ # 除非提供了特定的LLM,否则默认为ToolAgent_LLM
69
+ self.llm = llm if llm is not None else ToolAgent_LLM()
70
+ self.reranker = Reranker(model_name=get_rag_rerank_model())
71
+ # 使用标准LLM执行查询重写任务,而不是代理
72
+ self.query_rewriter = QueryRewriter(JarvisPlatform_LLM())
73
+
74
+ print("✅ JarvisRAGPipeline 初始化成功。")
75
+
76
+ def add_documents(self, documents: List[Document]):
77
+ """
78
+ 将文档添加到向量知识库。
79
+
80
+ 参数:
81
+ documents: 要添加的LangChain文档对象列表。
82
+ """
83
+ self.retriever.add_documents(documents)
84
+
85
+ def _create_prompt(
86
+ self, query: str, context_docs: List[Document], source_files: List[str]
87
+ ) -> str:
88
+ """为LLM或代理创建最终的提示。"""
89
+ context = "\n\n".join([doc.page_content for doc in context_docs])
90
+ sources_text = "\n".join([f"- {source}" for source in source_files])
91
+
92
+ prompt_template = f"""
93
+ 你是一个专家助手。请根据用户的问题,结合下面提供的参考信息来回答。
94
+
95
+ **重要**: 提供的上下文和文件列表**仅供参考**,可能不完整或已过时。在回答前,你应该**优先使用工具(如 read_code)来获取最新、最准确的信息**。
96
+
97
+ 参考文件列表:
98
+ ---
99
+ {sources_text}
100
+ ---
101
+
102
+ 参考上下文:
103
+ ---
104
+ {context}
105
+ ---
106
+
107
+ 问题: {query}
108
+
109
+ 回答:
110
+ """
111
+ return prompt_template.strip()
112
+
113
+ def query(self, query_text: str, n_results: int = 5) -> str:
114
+ """
115
+ 使用多查询检索和重排管道对知识库执行查询。
116
+
117
+ 参数:
118
+ query_text: 用户的原始问题。
119
+ n_results: 要检索的最终相关块的数量。
120
+
121
+ 返回:
122
+ 由LLM生成的答案。
123
+ """
124
+ # 1. 将原始查询重写为多个查询
125
+ rewritten_queries = self.query_rewriter.rewrite(query_text)
126
+
127
+ # 2. 为每个重写的查询检索初始候选文档
128
+ all_candidate_docs = []
129
+ for q in rewritten_queries:
130
+ print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
131
+ candidates = self.retriever.retrieve(q, n_results=n_results * 2)
132
+ all_candidate_docs.extend(candidates)
133
+
134
+ # 对候选文档进行去重
135
+ unique_docs_dict = {doc.page_content: doc for doc in all_candidate_docs}
136
+ unique_candidate_docs = list(unique_docs_dict.values())
137
+
138
+ if not unique_candidate_docs:
139
+ return "我在提供的文档中找不到任何相关信息来回答您的问题。"
140
+
141
+ # 3. 根据*原始*查询对统一的候选池进行重排
142
+ print(
143
+ f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)..."
144
+ )
145
+ retrieved_docs = self.reranker.rerank(
146
+ query_text, unique_candidate_docs, top_n=n_results
147
+ )
148
+
149
+ if not retrieved_docs:
150
+ return "我在提供的文档中找不到任何相关信息来回答您的问题。"
151
+
152
+ # 打印最终检索到的文档的来源
153
+ sources = sorted(
154
+ list(
155
+ {
156
+ doc.metadata["source"]
157
+ for doc in retrieved_docs
158
+ if "source" in doc.metadata
159
+ }
160
+ )
161
+ )
162
+ if sources:
163
+ print(f"📚 根据以下文档回答:")
164
+ for source in sources:
165
+ print(f" - {source}")
166
+
167
+ # 4. 创建最终提示并生成答案
168
+ # 我们使用原始的query_text作为给LLM的最终提示
169
+ prompt = self._create_prompt(query_text, retrieved_docs, sources)
170
+
171
+ print("🤖 正在从LLM生成答案...")
172
+ answer = self.llm.generate(prompt)
173
+
174
+ return answer
@@ -0,0 +1,56 @@
1
+ from typing import List
2
+
3
+ from langchain.docstore.document import Document
4
+ from sentence_transformers.cross_encoder import ( # type: ignore
5
+ CrossEncoder,
6
+ )
7
+
8
+
9
+ class Reranker:
10
+ """
11
+ 一个重排器类,使用Cross-Encoder模型根据文档与给定查询的相关性
12
+ 对文档进行重新评分和排序。
13
+ """
14
+
15
+ def __init__(self, model_name: str):
16
+ """
17
+ 初始化重排器。
18
+
19
+ 参数:
20
+ model_name (str): 要使用的Cross-Encoder模型的名称。
21
+ """
22
+ print(f"🔍 正在初始化重排模型: {model_name}...")
23
+ self.model = CrossEncoder(model_name)
24
+ print("✅ 重排模型初始化成功。")
25
+
26
+ def rerank(
27
+ self, query: str, documents: List[Document], top_n: int = 5
28
+ ) -> List[Document]:
29
+ """
30
+ 根据文档与查询的相关性对文档列表进行重排。
31
+
32
+ 参数:
33
+ query (str): 用户的查询。
34
+ documents (List[Document]): 从初始搜索中检索到的文档列表。
35
+ top_n (int): 重排后要返回的顶部文档数。
36
+
37
+ 返回:
38
+ List[Document]: 一个已排序的最相关文档列表。
39
+ """
40
+ if not documents:
41
+ return []
42
+
43
+ # 创建 [查询, 文档内容] 对用于评分
44
+ pairs = [[query, doc.page_content] for doc in documents]
45
+
46
+ # 从Cross-Encoder模型获取分数
47
+ scores = self.model.predict(pairs)
48
+
49
+ # 将文档与它们的分数结合并排序
50
+ doc_with_scores = list(zip(documents, scores))
51
+ doc_with_scores.sort(key=lambda x: x[1], reverse=True) # type: ignore
52
+
53
+ # 返回前N个文档
54
+ reranked_docs = [doc for doc, score in doc_with_scores[:top_n]]
55
+
56
+ return reranked_docs
@@ -0,0 +1,201 @@
1
+ import os
2
+ import pickle
3
+ from typing import Any, Dict, List, cast
4
+
5
+ import chromadb
6
+ from langchain.docstore.document import Document
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from rank_bm25 import BM25Okapi # type: ignore
9
+
10
+ from .embedding_manager import EmbeddingManager
11
+
12
+
13
+ class ChromaRetriever:
14
+ """
15
+ 一个检索器类,它结合了密集向量搜索(ChromaDB)和稀疏关键字搜索(BM25)
16
+ 以实现混合检索。
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ embedding_manager: EmbeddingManager,
22
+ db_path: str,
23
+ collection_name: str = "jarvis_rag_collection",
24
+ ):
25
+ """
26
+ 初始化ChromaRetriever。
27
+
28
+ 参数:
29
+ embedding_manager: EmbeddingManager的实例。
30
+ db_path: ChromaDB持久化存储的文件路径。
31
+ collection_name: ChromaDB中集合的名称。
32
+ """
33
+ self.embedding_manager = embedding_manager
34
+ self.db_path = db_path
35
+ self.collection_name = collection_name
36
+
37
+ # 初始化ChromaDB客户端
38
+ self.client = chromadb.PersistentClient(path=self.db_path)
39
+ self.collection = self.client.get_or_create_collection(
40
+ name=self.collection_name
41
+ )
42
+ print(
43
+ f"✅ ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。"
44
+ )
45
+
46
+ # BM25索引设置
47
+ self.bm25_index_path = os.path.join(self.db_path, f"{collection_name}_bm25.pkl")
48
+ self._load_or_initialize_bm25()
49
+
50
+ def _load_or_initialize_bm25(self):
51
+ """从磁盘加载BM25索引或初始化一个新索引。"""
52
+ if os.path.exists(self.bm25_index_path):
53
+ print("🔍 正在加载现有的 BM25 索引...")
54
+ with open(self.bm25_index_path, "rb") as f:
55
+ data = pickle.load(f)
56
+ self.bm25_corpus = data["corpus"]
57
+ self.bm25_index = BM25Okapi(self.bm25_corpus)
58
+ print("✅ BM25 索引加载成功。")
59
+ else:
60
+ print("⚠️ 未找到 BM25 索引,将初始化一个新的。")
61
+ self.bm25_corpus = []
62
+ self.bm25_index = None
63
+
64
+ def _save_bm25_index(self):
65
+ """将BM25索引保存到磁盘。"""
66
+ if self.bm25_index:
67
+ print("💾 正在保存 BM25 索引...")
68
+ with open(self.bm25_index_path, "wb") as f:
69
+ pickle.dump({"corpus": self.bm25_corpus, "index": self.bm25_index}, f)
70
+ print("✅ BM25 索引保存成功。")
71
+
72
+ def add_documents(
73
+ self, documents: List[Document], chunk_size=1000, chunk_overlap=100
74
+ ):
75
+ """
76
+ 将文档拆分、嵌入,并添加到ChromaDB和BM25索引中。
77
+ """
78
+ text_splitter = RecursiveCharacterTextSplitter(
79
+ chunk_size=chunk_size, chunk_overlap=chunk_overlap
80
+ )
81
+ chunks = text_splitter.split_documents(documents)
82
+
83
+ print(f"📄 已将 {len(documents)} 个文档拆分为 {len(chunks)} 个块。")
84
+
85
+ if not chunks:
86
+ return
87
+
88
+ # 提取内容、元数据并生成ID
89
+ chunk_texts = [chunk.page_content for chunk in chunks]
90
+ metadatas = [chunk.metadata for chunk in chunks]
91
+ start_id = self.collection.count()
92
+ ids = [f"doc_{i}" for i in range(start_id, start_id + len(chunks))]
93
+
94
+ # 添加到ChromaDB
95
+ embeddings = self.embedding_manager.embed_documents(chunk_texts)
96
+ self.collection.add(
97
+ ids=ids,
98
+ embeddings=cast(Any, embeddings),
99
+ documents=chunk_texts,
100
+ metadatas=cast(Any, metadatas),
101
+ )
102
+ print(f"✅ 成功将 {len(chunks)} 个块添加到 ChromaDB 集合中。")
103
+
104
+ # 更新并保存BM25索引
105
+ tokenized_chunks = [doc.split() for doc in chunk_texts]
106
+ self.bm25_corpus.extend(tokenized_chunks)
107
+ self.bm25_index = BM25Okapi(self.bm25_corpus)
108
+ self._save_bm25_index()
109
+
110
+ def retrieve(self, query: str, n_results: int = 5) -> List[Document]:
111
+ """
112
+ 使用向量搜索和BM25执行混合检索,然后使用倒数排序融合(RRF)
113
+ 对结果进行融合。
114
+ """
115
+ # 1. 向量搜索 (ChromaDB)
116
+ query_embedding = self.embedding_manager.embed_query(query)
117
+ vector_results = self.collection.query(
118
+ query_embeddings=cast(Any, [query_embedding]),
119
+ n_results=n_results * 2, # 检索更多结果用于融合
120
+ )
121
+
122
+ # 2. 关键字搜索 (BM25)
123
+ bm25_docs = []
124
+ if self.bm25_index:
125
+ tokenized_query = query.split()
126
+ doc_scores = self.bm25_index.get_scores(tokenized_query)
127
+
128
+ # 从Chroma获取所有文档以匹配BM25分数
129
+ all_docs_in_collection = self.collection.get()
130
+ all_documents = all_docs_in_collection.get("documents")
131
+ all_metadatas = all_docs_in_collection.get("metadatas")
132
+
133
+ bm25_results_with_docs = []
134
+ if all_documents and all_metadatas:
135
+ # 创建从索引到文档的映射
136
+ bm25_results_with_docs = [
137
+ (
138
+ all_documents[i],
139
+ all_metadatas[i],
140
+ score,
141
+ )
142
+ for i, score in enumerate(doc_scores)
143
+ if score > 0
144
+ ]
145
+
146
+ # 按分数排序并取最高结果
147
+ bm25_results_with_docs.sort(key=lambda x: x[2], reverse=True)
148
+
149
+ for doc_text, metadata, _ in bm25_results_with_docs[: n_results * 2]:
150
+ bm25_docs.append(Document(page_content=doc_text, metadata=metadata))
151
+
152
+ # 3. 倒数排序融合 (RRF)
153
+ fused_scores: Dict[str, float] = {}
154
+ k = 60 # RRF排名常数
155
+
156
+ # 处理向量结果
157
+ if vector_results and vector_results["ids"] and vector_results["documents"]:
158
+ vec_ids = vector_results["ids"][0]
159
+ vec_texts = vector_results["documents"][0]
160
+
161
+ for rank, doc_id in enumerate(vec_ids):
162
+ fused_scores[doc_id] = fused_scores.get(doc_id, 0) + 1 / (k + rank)
163
+
164
+ # 为BM25融合创建从文档文本到其ID的映射
165
+ doc_text_to_id = {text: doc_id for text, doc_id in zip(vec_texts, vec_ids)}
166
+
167
+ for rank, doc in enumerate(bm25_docs):
168
+ bm25_doc_id = doc_text_to_id.get(doc.page_content)
169
+ if bm25_doc_id:
170
+ fused_scores[bm25_doc_id] = fused_scores.get(bm25_doc_id, 0) + 1 / (
171
+ k + rank
172
+ )
173
+
174
+ # 对融合结果进行排序
175
+ sorted_fused_results = sorted(
176
+ fused_scores.items(), key=lambda x: x[1], reverse=True
177
+ )
178
+
179
+ # 根据融合排名从ChromaDB获取最终文档
180
+ final_doc_ids = [item[0] for item in sorted_fused_results[:n_results]]
181
+
182
+ if not final_doc_ids:
183
+ return []
184
+
185
+ final_docs_data = self.collection.get(ids=final_doc_ids)
186
+
187
+ retrieved_docs = []
188
+ if final_docs_data:
189
+ final_documents = final_docs_data.get("documents")
190
+ final_metadatas = final_docs_data.get("metadatas")
191
+
192
+ if final_documents and final_metadatas:
193
+ for doc_text, metadata in zip(final_documents, final_metadatas):
194
+ if doc_text is not None and metadata is not None:
195
+ retrieved_docs.append(
196
+ Document(
197
+ page_content=cast(str, doc_text), metadata=metadata
198
+ )
199
+ )
200
+
201
+ return retrieved_docs