jarvis-ai-assistant 0.1.220__py3-none-any.whl → 0.1.221__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,59 +1,45 @@
-from typing import List, Literal, cast
+import torch
+from typing import List, cast
 from langchain_huggingface import HuggingFaceEmbeddings

-from jarvis.jarvis_utils.config import (
-    get_rag_embedding_models,
-    get_rag_embedding_cache_path,
-)
 from .cache import EmbeddingCache


 class EmbeddingManager:
     """
-    Manages the loading and usage of local embedding models with caching.
+    Manages the loading and use of local embedding models, with caching.

-    This class handles the selection of embedding models based on a specified
-    mode ('performance' or 'accuracy'), loads the model from Hugging Face,
-    and uses a disk-based cache to avoid re-computing embeddings for the
-    same text.
+    This class loads the specified model from Hugging Face and uses a
+    disk-based cache to avoid re-computing embeddings for the same text.
     """

-    def __init__(
-        self,
-        mode: Literal["performance", "accuracy"],
-        cache_dir: str,
-    ):
+    def __init__(self, model_name: str, cache_dir: str):
         """
-        Initializes the EmbeddingManager.
+        Initializes the EmbeddingManager.

-        Args:
-            mode: The desired mode, either 'performance' or 'accuracy'.
-            cache_dir: The directory to store the embedding cache.
+        Args:
+            model_name: The name of the Hugging Face model to load.
+            cache_dir: The directory for storing the embedding cache.
         """
-        self.mode = mode
-        self.embedding_models = get_rag_embedding_models()
-        if mode not in self.embedding_models:
-            raise ValueError(
-                f"Invalid mode '{mode}'. Must be one of {list(self.embedding_models.keys())}"
-            )
-
-        self.model_config = self.embedding_models[self.mode]
-        self.model_name = self.model_config["model_name"]
+        self.model_name = model_name

-        print(f"🚀 Initializing embedding manager, mode: '{self.mode}', model: '{self.model_name}'...")
+        print(f"🚀 Initializing embedding manager, model: '{self.model_name}'...")

-        # The salt for the cache is the model name to prevent collisions
-        self.cache = EmbeddingCache(cache_dir=cache_dir, salt=str(self.model_name))
+        # The cache salt is the model name, to prevent collisions
+        self.cache = EmbeddingCache(cache_dir=cache_dir, salt=self.model_name)
         self.model = self._load_model()

     def _load_model(self) -> HuggingFaceEmbeddings:
-        """Loads the Hugging Face embedding model based on the configuration."""
+        """Loads the Hugging Face embedding model based on the configuration."""
+        model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
+        encode_kwargs = {"normalize_embeddings": True}
+
         try:
             return HuggingFaceEmbeddings(
                 model_name=self.model_name,
-                model_kwargs=self.model_config.get("model_kwargs"),
-                encode_kwargs=self.model_config.get("encode_kwargs"),
-                show_progress=self.model_config.get("show_progress", False),
+                model_kwargs=model_kwargs,
+                encode_kwargs=encode_kwargs,
+                show_progress=True,
             )
         except Exception as e:
             print(f"❌ Error loading embedding model '{self.model_name}': {e}")
@@ -62,18 +48,18 @@ class EmbeddingManager:

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """
-        Computes embeddings for a list of documents, using the cache.
+        Computes embeddings for a list of documents, using the cache.

-        Args:
-            texts: A list of documents (strings) to embed.
+        Args:
+            texts: A list of documents (strings) to embed.

-        Returns:
-            A list of embeddings, one for each document.
+        Returns:
+            A list of embeddings, one per document.
         """
         if not texts:
             return []

-        # Check cache for existing embeddings
+        # Check the cache for existing embeddings
         cached_embeddings = self.cache.get_batch(texts)

         texts_to_embed = []
@@ -83,17 +69,17 @@ class EmbeddingManager:
                 texts_to_embed.append(text)
                 indices_to_embed.append(i)

-        # Compute embeddings for texts that were not in the cache
+        # Compute embeddings for texts not found in the cache
         if texts_to_embed:
             print(
                 f"🔎 Cache miss. Computing embeddings for {len(texts_to_embed)}/{len(texts)} documents."
             )
             new_embeddings = self.model.embed_documents(texts_to_embed)

-            # Store new embeddings in the cache
+            # Store the new embeddings in the cache
             self.cache.set_batch(texts_to_embed, new_embeddings)

-            # Place new embeddings back into the results list
+            # Put the new embeddings back into the results list
             for i, embedding in zip(indices_to_embed, new_embeddings):
                 cached_embeddings[i] = embedding
         else:
@@ -103,7 +89,7 @@ class EmbeddingManager:

     def embed_query(self, text: str) -> List[float]:
         """
-        Computes the embedding for a single query.
-        Queries are typically not cached, but we can add it if needed.
+        Computes the embedding for a single query.
+        Queries are typically not cached, but caching could be added if needed.
         """
         return self.model.embed_query(text)
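
For orientation, here is a minimal usage sketch of the new `EmbeddingManager` interface. The import path, model name, and cache directory are assumptions for illustration; only the `__init__(model_name, cache_dir)` signature and the `embed_documents`/`embed_query` methods come from the diff above.

```python
# Hypothetical usage sketch; the module path, model name, and cache
# directory below are assumptions, not values shipped with the package.
from jarvis.jarvis_rag.embedding_manager import EmbeddingManager

manager = EmbeddingManager(
    model_name="BAAI/bge-m3",  # any Hugging Face embedding model name
    cache_dir="/tmp/jarvis_embedding_cache",
)

# First call computes embeddings and writes them to the disk cache.
vectors = manager.embed_documents(["hello world", "goodbye world"])

# Identical texts are now served from the cache instead of the model.
vectors_again = manager.embed_documents(["hello world"])

# Queries are embedded directly and not cached (see embed_query above).
query_vector = manager.embed_query("greeting phrases")
```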
@@ -10,42 +10,40 @@ from jarvis.jarvis_platform.registry import PlatformRegistry

 class LLMInterface(ABC):
     """
-    Abstract Base Class for Large Language Model interfaces.
+    Abstract base class for large language model interfaces.

-    This class defines the standard interface for interacting with a remote LLM.
-    Any LLM provider (OpenAI, Anthropic, etc.) should be implemented as a
-    subclass of this interface.
+    This class defines the standard interface for interacting with a remote LLM.
+    Any LLM provider (e.g., OpenAI, Anthropic) should be implemented as a
+    subclass of this interface.
     """

     @abstractmethod
     def generate(self, prompt: str, **kwargs) -> str:
         """
-        Generates a response from the LLM based on a given prompt.
+        Generates a response from the LLM based on the given prompt.

-        Args:
-            prompt: The input prompt to send to the LLM.
-            **kwargs: Additional keyword arguments for the LLM API call
-                (e.g., temperature, max_tokens).
+        Args:
+            prompt: The input prompt to send to the LLM.
+            **kwargs: Additional keyword arguments for the LLM API call
+                (e.g., temperature, max_tokens).

-        Returns:
-            The text response generated by the LLM.
+        Returns:
+            The text response generated by the LLM.
         """
         pass


 class ToolAgent_LLM(LLMInterface):
     """
-    An implementation of the LLMInterface that uses a tool-wielding JarvisAgent
-    to generate the final response.
+    An implementation of LLMInterface that uses a tool-wielding JarvisAgent
+    to generate the final response.
     """

     def __init__(self):
         """
-        Initializes the Tool-Agent LLM wrapper.
+        Initializes the tool-agent LLM wrapper.
         """
         print("🤖 Initialized the tool agent as the final responder.")
         self.allowed_tools = ["read_code", "execute_script"]
-        # A generic system prompt for the agent
+        # A generic system prompt for the agent
         self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
         self.summary_prompt = """
 <report>
@@ -59,17 +57,17 @@ class ToolAgent_LLM(LLMInterface):

     def generate(self, prompt: str, **kwargs) -> str:
         """
-        Runs the JarvisAgent with a restricted toolset to generate an answer.
+        Runs the JarvisAgent with a restricted toolset to generate an answer.

-        Args:
-            prompt: The full prompt, including context, to be sent to the agent.
-            **kwargs: Ignored, kept for interface compatibility.
+        Args:
+            prompt: The full prompt, including context, to send to the agent.
+            **kwargs: Ignored; kept for interface compatibility.

-        Returns:
-            The final answer generated by the agent.
+        Returns:
+            The final answer generated by the agent.
         """
         try:
-            # Initialize the agent with specific settings for RAG context
+            # Initialize the agent with settings specific to the RAG context
             agent = JarvisAgent(
                 system_prompt=self.system_prompt,
                 use_tools=self.allowed_tools,
@@ -80,7 +78,7 @@ class ToolAgent_LLM(LLMInterface):
                 summary_prompt=self.summary_prompt,
             )

-            # The agent's run method expects the 'user_input' parameter
+            # The agent's run method expects the 'user_input' parameter
             final_answer = agent.run(user_input=prompt)
             return str(final_answer)

@@ -91,21 +89,21 @@ class ToolAgent_LLM(LLMInterface):

 class JarvisPlatform_LLM(LLMInterface):
     """
-    An implementation of the LLMInterface for the project's internal platform.
+    An LLMInterface implementation for the project's internal platform.

-    This class uses the PlatformRegistry to get the configured "normal" model.
+    This class uses the PlatformRegistry to obtain the configured "normal" model.
     """

     def __init__(self):
         """
-        Initializes the Jarvis Platform LLM client.
+        Initializes the Jarvis platform LLM client.
         """
         try:
             self.registry = PlatformRegistry.get_global_platform_registry()
             self.platform: BasePlatform = self.registry.get_normal_platform()
             self.platform.set_suppress_output(
                 False
-            )  # Ensure no console output from the model
+            )  # Ensure no console output from the model
             print(f"🚀 Initialized Jarvis platform LLM, model: {self.platform.name()}")
         except Exception as e:
             print(f"❌ Failed to initialize Jarvis platform LLM: {e}")
@@ -113,17 +111,17 @@ class JarvisPlatform_LLM(LLMInterface):

     def generate(self, prompt: str, **kwargs) -> str:
         """
-        Sends a prompt to the local platform model and returns the response.
+        Sends a prompt to the local platform model and returns the response.

-        Args:
-            prompt: The user's prompt.
-            **kwargs: Ignored, kept for interface compatibility.
+        Args:
+            prompt: The user's prompt.
+            **kwargs: Ignored; kept for interface compatibility.

-        Returns:
-            The response generated by the platform model.
+        Returns:
+            The response generated by the platform model.
         """
         try:
-            # Use the robust chat_until_success method
+            # Use the robust chat_until_success method
             return self.platform.chat_until_success(prompt)
         except Exception as e:
             print(f"❌ Error while calling the Jarvis platform model: {e}")
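
Because `LLMInterface` only requires a single `generate` method, a stub backend is easy to write, for example for unit tests. A minimal sketch follows; the import path is an assumption.

```python
# Minimal sketch of a custom LLMInterface backend; the import path is assumed.
from jarvis.jarvis_rag.llm_interface import LLMInterface


class EchoLLM(LLMInterface):
    """Test double that returns the prompt unchanged."""

    def generate(self, prompt: str, **kwargs) -> str:
        # kwargs such as temperature or max_tokens are accepted but unused.
        return prompt


assert EchoLLM().generate("ping") == "ping"
```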
@@ -4,22 +4,21 @@ from .llm_interface import LLMInterface

 class QueryRewriter:
     """
-    Uses an LLM to rewrite a user's query into multiple, diverse search
-    queries to enhance retrieval recall.
+    Uses an LLM to rewrite a user's query into multiple diverse search
+    queries to improve retrieval recall.
     """

     def __init__(self, llm: LLMInterface):
         """
-        Initializes the QueryRewriter.
+        Initializes the QueryRewriter.

-        Args:
-            llm: An instance of a class implementing LLMInterface.
+        Args:
+            llm: An instance of a class implementing LLMInterface.
         """
         self.llm = llm
         self.rewrite_prompt_template = self._create_prompt_template()

     def _create_prompt_template(self) -> str:
-        """Creates the prompt template for the multi-query rewriting task."""
+        """Creates the prompt template for the multi-query rewriting task."""
         return """
 You are an AI assistant skilled at retrieval. Your task is to rewrite the following single user question, from different angles, into 3 distinct but semantically related search queries. This helps perform a more comprehensive search of the knowledge base.

@@ -39,13 +38,13 @@ class QueryRewriter:

     def rewrite(self, query: str) -> List[str]:
         """
-        Rewrites the user query into multiple queries using the LLM.
+        Rewrites the user query into multiple queries using the LLM.

-        Args:
-            query: The original user query.
+        Args:
+            query: The original user query.

-        Returns:
-            A list of rewritten, search-optimized queries.
+        Returns:
+            A list of rewritten, search-optimized queries.
         """
         prompt = self.rewrite_prompt_template.format(query=query)
         print("✍️ Rewriting the original query into multiple search queries...")
@@ -55,7 +54,7 @@ class QueryRewriter:
55
54
  line.strip() for line in response_text.strip().split("\n") if line.strip()
56
55
  ]
57
56
 
58
- # Also include the original query for robustness
57
+ # 同时包含原始查询以保证鲁棒性
59
58
  if query not in rewritten_queries:
60
59
  rewritten_queries.insert(0, query)
61
60
 
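The parsing contract in `rewrite` is simply one query per non-empty line of the LLM response, with the original query prepended when the model omits it. A sketch against a canned backend (import paths assumed):

```python
# Sketch of the rewrite contract; import paths are assumptions.
from jarvis.jarvis_rag.llm_interface import LLMInterface
from jarvis.jarvis_rag.query_rewriter import QueryRewriter


class CannedLLM(LLMInterface):
    """Returns a fixed two-line rewrite, one query per line."""

    def generate(self, prompt: str, **kwargs) -> str:
        return "jarvis embedding cache design\nhow does jarvis cache embeddings"


queries = QueryRewriter(CannedLLM()).rewrite("How is the embedding cache built?")
# The original query is inserted at position 0 because the canned response
# did not already contain it:
# ['How is the embedding cache built?',
#  'jarvis embedding cache design',
#  'how does jarvis cache embeddings']
print(queries)
```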
@@ -1,5 +1,5 @@
 import os
-from typing import List, Literal, Optional, cast
+from typing import List, Optional

 from langchain.docstore.document import Document

@@ -9,57 +9,55 @@ from .query_rewriter import QueryRewriter
 from .reranker import Reranker
 from .retriever import ChromaRetriever
 from jarvis.jarvis_utils.config import (
-    get_rag_embedding_mode,
+    get_rag_embedding_model,
+    get_rag_rerank_model,
     get_rag_vector_db_path,
     get_rag_embedding_cache_path,
-    get_rag_embedding_models,
 )


 class JarvisRAGPipeline:
     """
-    The main orchestrator for the RAG pipeline.
+    The main orchestrator of the RAG pipeline.

-    This class integrates the embedding manager, retriever, and LLM to provide
-    a complete pipeline for adding documents and querying them.
+    This class integrates the embedding manager, retriever, and LLM into a
+    complete pipeline for adding documents and querying them.
     """

     def __init__(
         self,
         llm: Optional[LLMInterface] = None,
-        embedding_mode: Optional[Literal["performance", "accuracy"]] = None,
+        embedding_model: Optional[str] = None,
         db_path: Optional[str] = None,
         collection_name: str = "jarvis_rag_collection",
     ):
         """
-        Initializes the RAG pipeline.
-
-        Args:
-            llm: An instance of a class implementing LLMInterface.
-                If None, defaults to the ToolAgent_LLM.
-            embedding_mode: The mode for the local embedding model. If None, uses config value.
-            db_path: Path to the persistent vector database. If None, uses config value.
-            collection_name: Name of the collection in the vector database.
+        Initializes the RAG pipeline.
+
+        Args:
+            llm: An instance of a class implementing LLMInterface.
+                If None, defaults to ToolAgent_LLM.
+            embedding_model: The name of the embedding model. If None, the config value is used.
+            db_path: Path to the persistent vector database. If None, the config value is used.
+            collection_name: Name of the collection in the vector database.
         """
-        # Determine the embedding model to isolate data paths
-        _embedding_mode = embedding_mode or get_rag_embedding_mode()
-        embedding_models = get_rag_embedding_models()
-        model_name = embedding_models[_embedding_mode]["model_name"]
+        # Determine the embedding model to isolate data paths
+        model_name = embedding_model or get_rag_embedding_model()
         sanitized_model_name = model_name.replace("/", "_").replace("\\", "_")

-        # If a specific db_path is given, use it. Otherwise, create a model-specific path.
+        # If a specific db_path is given, use it. Otherwise, create a model-specific path.
         _final_db_path = (
             str(db_path)
             if db_path
             else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
         )
-        # Always create a model-specific cache path.
+        # Always create a model-specific cache path.
         _final_cache_path = os.path.join(
             get_rag_embedding_cache_path(), sanitized_model_name
         )

         self.embedding_manager = EmbeddingManager(
-            mode=cast(Literal["performance", "accuracy"], _embedding_mode),
+            model_name=model_name,
             cache_dir=_final_cache_path,
         )
         self.retriever = ChromaRetriever(
@@ -67,27 +65,27 @@ class JarvisRAGPipeline:
             db_path=_final_db_path,
             collection_name=collection_name,
         )
-        # Default to the ToolAgent_LLM unless a specific LLM is provided
+        # Default to ToolAgent_LLM unless a specific LLM is provided
         self.llm = llm if llm is not None else ToolAgent_LLM()
-        self.reranker = Reranker()
-        # Use a standard LLM for the query rewriting task, not the agent
+        self.reranker = Reranker(model_name=get_rag_rerank_model())
+        # Use a standard LLM for the query rewriting task, not the agent
         self.query_rewriter = QueryRewriter(JarvisPlatform_LLM())

         print("✅ JarvisRAGPipeline initialized successfully.")

     def add_documents(self, documents: List[Document]):
         """
-        Adds documents to the vector knowledge base.
+        Adds documents to the vector knowledge base.

-        Args:
-            documents: A list of LangChain Document objects to add.
+        Args:
+            documents: A list of LangChain Document objects to add.
         """
         self.retriever.add_documents(documents)

     def _create_prompt(
         self, query: str, context_docs: List[Document], source_files: List[str]
     ) -> str:
-        """Creates the final prompt for the LLM or Agent."""
+        """Creates the final prompt for the LLM or agent."""
         context = "\n\n".join([doc.page_content for doc in context_docs])
         sources_text = "\n".join([f"- {source}" for source in source_files])

@@ -114,34 +112,33 @@ class JarvisRAGPipeline:

     def query(self, query_text: str, n_results: int = 5) -> str:
         """
-        Performs a query against the knowledge base using a multi-query
-        retrieval and reranking pipeline.
+        Performs a query against the knowledge base using a multi-query
+        retrieval and reranking pipeline.

-        Args:
-            query_text: The user's original question.
-            n_results: The number of final relevant chunks to retrieve.
+        Args:
+            query_text: The user's original question.
+            n_results: The number of final relevant chunks to retrieve.

-        Returns:
-            The answer generated by the LLM.
+        Returns:
+            The answer generated by the LLM.
         """
-        # 1. Rewrite the original query into multiple queries
+        # 1. Rewrite the original query into multiple queries
         rewritten_queries = self.query_rewriter.rewrite(query_text)

-        # 2. Retrieve initial candidates for each rewritten query
+        # 2. Retrieve initial candidates for each rewritten query
         all_candidate_docs = []
         for q in rewritten_queries:
             print(f"🔍 Running hybrid retrieval for query variant '{q}'...")
             candidates = self.retriever.retrieve(q, n_results=n_results * 2)
             all_candidate_docs.extend(candidates)

-        # De-duplicate the candidate documents
+        # De-duplicate the candidate documents
         unique_docs_dict = {doc.page_content: doc for doc in all_candidate_docs}
         unique_candidate_docs = list(unique_docs_dict.values())

         if not unique_candidate_docs:
             return "I could not find any relevant information in the provided documents to answer your question."

-        # 3. Rerank the unified candidate pool against the *original* query
+        # 3. Rerank the unified candidate pool against the *original* query
         print(
             f"🔍 Reranking {len(unique_candidate_docs)} candidate documents (against the original question)..."
         )
@@ -152,7 +149,7 @@ class JarvisRAGPipeline:
         if not retrieved_docs:
             return "I could not find any relevant information in the provided documents to answer your question."

-        # Print the sources of the final retrieved documents
+        # Print the sources of the final retrieved documents
         sources = sorted(
             list(
                 {
@@ -167,8 +164,8 @@ class JarvisRAGPipeline:
         for source in sources:
             print(f" - {source}")

-        # 4. Create the final prompt and generate the answer
-        # We use the original query_text for the final prompt to the LLM
+        # 4. Create the final prompt and generate the answer
+        # We use the original query_text for the final prompt to the LLM
         prompt = self._create_prompt(query_text, retrieved_docs, sources)

         print("🤖 Generating the answer from the LLM...")
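
Under the new configuration model, callers pass an embedding model name instead of a performance/accuracy mode; when `embedding_model` is None, `get_rag_embedding_model()` supplies the default. A minimal end-to-end sketch, with an assumed import path and an illustrative model name:

```python
# End-to-end usage sketch; the import path and model name are assumptions.
from langchain.docstore.document import Document
from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline

pipeline = JarvisRAGPipeline(
    llm=None,                       # None -> ToolAgent_LLM answers
    embedding_model="BAAI/bge-m3",  # None -> get_rag_embedding_model()
    db_path=None,                   # None -> model-specific path under the DB dir
)

pipeline.add_documents(
    [
        Document(
            page_content="Jarvis caches document embeddings on disk.",
            metadata={"source": "notes/rag.md"},
        )
    ]
)

# Runs rewrite -> hybrid retrieval per variant -> de-dup -> rerank -> answer.
print(pipeline.query("How does Jarvis avoid recomputing embeddings?"))
```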
@@ -8,16 +8,16 @@ from sentence_transformers.cross_encoder import ( # type: ignore

 class Reranker:
     """
-    A reranker class that uses a Cross-Encoder model to re-score and sort
-    documents based on their relevance to a given query.
+    A reranker class that uses a Cross-Encoder model to re-score and sort
+    documents by their relevance to a given query.
     """

-    def __init__(self, model_name: str = "BAAI/bge-reranker-base"):
+    def __init__(self, model_name: str):
         """
-        Initializes the Reranker.
+        Initializes the Reranker.

-        Args:
-            model_name (str): The name of the Cross-Encoder model to use.
+        Args:
+            model_name (str): The name of the Cross-Encoder model to use.
         """
         print(f"🔍 Initializing reranker model: {model_name}...")
         self.model = CrossEncoder(model_name)
@@ -27,30 +27,30 @@ class Reranker:
         self, query: str, documents: List[Document], top_n: int = 5
     ) -> List[Document]:
         """
-        Reranks a list of documents based on their relevance to the query.
+        Reranks a list of documents by their relevance to the query.

-        Args:
-            query (str): The user's query.
-            documents (List[Document]): The list of documents retrieved from the initial search.
-            top_n (int): The number of top documents to return after reranking.
+        Args:
+            query (str): The user's query.
+            documents (List[Document]): The list of documents retrieved from the initial search.
+            top_n (int): The number of top documents to return after reranking.

-        Returns:
-            List[Document]: A sorted list of the most relevant documents.
+        Returns:
+            List[Document]: A sorted list of the most relevant documents.
         """
         if not documents:
             return []

-        # Create pairs of [query, document_content] for scoring
+        # Create [query, document_content] pairs for scoring
         pairs = [[query, doc.page_content] for doc in documents]

-        # Get scores from the Cross-Encoder model
+        # Get scores from the Cross-Encoder model
         scores = self.model.predict(pairs)

-        # Combine documents with their scores and sort
+        # Combine the documents with their scores and sort
         doc_with_scores = list(zip(documents, scores))
-        doc_with_scores.sort(key=lambda x: x[1], reverse=True)
+        doc_with_scores.sort(key=lambda x: x[1], reverse=True)  # type: ignore

-        # Return the top N documents
+        # Return the top N documents
         reranked_docs = [doc for doc, score in doc_with_scores[:top_n]]

         return reranked_docs
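
Since the default model name was removed from `Reranker.__init__`, callers must now pass one explicitly (the pipeline reads it from `get_rag_rerank_model()`). A usage sketch reusing the removed default purely as an example, with an assumed import path:

```python
# Usage sketch; the import path is an assumption, and "BAAI/bge-reranker-base"
# (the removed default) is reused here only as an example model name.
from langchain.docstore.document import Document
from jarvis.jarvis_rag.reranker import Reranker

reranker = Reranker(model_name="BAAI/bge-reranker-base")

docs = [
    Document(page_content="Embeddings are cached on disk to avoid recomputation."),
    Document(page_content="The reranker scores (query, document) pairs jointly."),
]

# Cross-encoder scoring re-orders candidates by relevance to the query.
top = reranker.rerank("how are candidate documents rescored?", docs, top_n=1)
print(top[0].page_content)
```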