jarvis-ai-assistant 0.1.218__py3-none-any.whl → 0.1.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +37 -92
  3. jarvis/jarvis_agent/shell_input_handler.py +1 -1
  4. jarvis/jarvis_code_agent/code_agent.py +5 -3
  5. jarvis/jarvis_data/config_schema.json +30 -0
  6. jarvis/jarvis_git_squash/main.py +2 -1
  7. jarvis/jarvis_platform/human.py +2 -7
  8. jarvis/jarvis_platform/yuanbao.py +3 -1
  9. jarvis/jarvis_rag/__init__.py +11 -0
  10. jarvis/jarvis_rag/cache.py +87 -0
  11. jarvis/jarvis_rag/cli.py +297 -0
  12. jarvis/jarvis_rag/embedding_manager.py +109 -0
  13. jarvis/jarvis_rag/llm_interface.py +130 -0
  14. jarvis/jarvis_rag/query_rewriter.py +63 -0
  15. jarvis/jarvis_rag/rag_pipeline.py +177 -0
  16. jarvis/jarvis_rag/reranker.py +56 -0
  17. jarvis/jarvis_rag/retriever.py +201 -0
  18. jarvis/jarvis_tools/search_web.py +127 -11
  19. jarvis/jarvis_utils/config.py +71 -0
  20. jarvis/jarvis_utils/git_utils.py +27 -18
  21. jarvis/jarvis_utils/input.py +21 -10
  22. jarvis/jarvis_utils/utils.py +43 -20
  23. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/METADATA +87 -5
  24. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/RECORD +28 -19
  25. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/entry_points.txt +1 -0
  26. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/WHEEL +0 -0
  27. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/licenses/LICENSE +0 -0
  28. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,297 @@
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+ from typing import Optional, List, Literal, cast
5
+ import mimetypes
6
+
7
+ import typer
8
+ from langchain.docstore.document import Document
9
+ from langchain_community.document_loaders import (
10
+ TextLoader,
11
+ UnstructuredMarkdownLoader,
12
+ )
13
+ from langchain_core.document_loaders.base import BaseLoader
14
+ from rich.markdown import Markdown
15
+
16
+ from jarvis.jarvis_utils.utils import init_env
17
+
18
+
19
def is_likely_text_file(file_path: Path) -> bool:
    """
    Guess whether a file holds text, reading at most its first 4 KB.

    Two heuristics are applied in order:

    1. The MIME type guessed from the file name. ``text/*`` types and the
       structured-text subtypes JSON, XML and JavaScript (including
       ``+json`` / ``+xml`` suffixed types such as ``image/svg+xml``) are
       accepted immediately, without opening the file.
    2. Otherwise the first 4 KB are read; a NUL byte marks the file as binary.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True if the file looks like text, False if it looks binary or if it
        could not be inspected (any error is treated as "not text").
    """
    try:
        # Heuristic 1: MIME type derived from the file name, if any.
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type:
            if mime_type.startswith("text/"):
                return True
            # Compare the subtype exactly (or by structured-syntax suffix)
            # rather than by substring: a plain '"xml" in mime_type' test also
            # matches the zipped Office formats
            # ("application/vnd.openxmlformats-officedocument..."), which are
            # binary. Anything not matched here still gets the content check.
            subtype = mime_type.partition("/")[2]
            if subtype in (
                "json",
                "xml",
                "javascript",
                "x-javascript",
            ) or subtype.endswith(("+json", "+xml")):
                return True

        # Heuristic 2: NUL bytes in the first 4 KB strongly indicate binary.
        with open(file_path, "rb") as f:
            chunk = f.read(4096)
        return b"\x00" not in chunk
    except Exception:
        # Deliberately best-effort: an unreadable or odd path is simply
        # reported as "not text" so directory scans never abort on one file.
        return False
40
+
41
+
42
# Put the directory three levels above this file at the front of sys.path
# (unless it is already listed) so that absolute ``jarvis.*`` imports can be
# resolved when this file is executed directly.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *([os.pardir] * 3))
)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
49
+
50
+ from jarvis.jarvis_platform.base import BasePlatform
51
+ from jarvis.jarvis_platform.registry import PlatformRegistry
52
+ from jarvis.jarvis_rag.llm_interface import LLMInterface
53
+ from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline
54
+
55
# Typer application object; the ``@app.command`` functions in this module
# register themselves on it. Shell-completion installation options are off.
app = typer.Typer(
    add_completion=False,
    name="jarvis-rag",
    help="A command-line tool to interact with the Jarvis RAG framework.",
)
60
+
61
+
62
class _CustomPlatformLLM(LLMInterface):
    """Adapter that exposes a BasePlatform instance through LLMInterface."""

    def __init__(self, platform: BasePlatform):
        """Keep a reference to ``platform`` and print which platform/model is used."""
        self.platform = platform
        platform_label = platform.platform_name()
        model_label = platform.name()
        print(f"✅ 使用自定义LLM: 平台='{platform_label}', 模型='{model_label}'")

    def generate(self, prompt: str, **kwargs) -> str:
        """Return the platform's reply to ``prompt``; extra keyword arguments are ignored."""
        reply = self.platform.chat_until_success(prompt)
        return reply
73
+
74
+
75
def _create_custom_llm(platform_name: str, model_name: str) -> Optional[LLMInterface]:
    """
    Build an LLMInterface backed by the named platform and model.

    Args:
        platform_name: Name passed to the global platform registry.
        model_name: Model name set on the created platform.

    Returns:
        A wrapper around the configured platform, or None when either name is
        empty, the registry returns no platform, or any error occurs (the
        error is printed rather than raised).
    """
    if not (platform_name and model_name):
        return None

    try:
        registry = PlatformRegistry.get_global_platform_registry()
        backend = registry.create_platform(platform_name)
        if backend:
            backend.set_model_name(model_name)
            backend.set_suppress_output(True)
            return _CustomPlatformLLM(backend)

        print(f"❌ 错误: 平台 '{platform_name}' 未找到。")
        return None
    except Exception as exc:
        print(f"❌ 创建自定义LLM时出错: {exc}")
        return None
91
+
92
+
93
def _collect_text_files(paths: List[Path]) -> set:
    """Expand files, directories and glob patterns into a set of likely-text file paths."""
    # Imported once here rather than on every loop iteration.
    from glob import glob

    found = set()
    for raw_path in paths:
        # Typer with List[Path] might not expand globs, so it is done manually.
        for match in glob(str(raw_path), recursive=True):
            path = Path(match)
            if path.is_dir():
                print(f"🔍 正在扫描目录: {path}")
                # Non-text files inside a scanned directory are skipped silently.
                found.update(
                    item
                    for item in path.rglob("*")
                    if item.is_file() and is_likely_text_file(item)
                )
            elif path.is_file():
                if is_likely_text_file(path):
                    found.add(path)
                else:
                    # Explicitly named files get a visible notice when skipped.
                    print(f"⚠️ 跳过可能的二进制文件: {path}")
    return found


@app.command(
    "add",
    help="Add documents from files, directories, or glob patterns (e.g., 'src/**/*.py').",
)
def add_documents(
    paths: List[Path] = typer.Argument(
        ...,
        help="File/directory paths or glob patterns. Shell expansion is supported.",
    ),
    collection_name: str = typer.Option(
        "jarvis_rag_collection",
        "--collection",
        "-c",
        help="Name of the collection in the vector database.",
    ),
    embedding_mode: Optional[str] = typer.Option(
        None,
        "--embedding-mode",
        "-e",
        help="Embedding mode ('performance' or 'accuracy'). Overrides global config.",
    ),
    db_path: Optional[Path] = typer.Option(
        None, "--db-path", help="Path to the vector database. Overrides global config."
    ),
):
    """
    Add documents to the RAG knowledge base.

    Every path is glob-expanded; directories are walked recursively and only
    files that pass ``is_likely_text_file`` are kept. Markdown files are loaded
    with ``UnstructuredMarkdownLoader``, everything else with a UTF-8
    ``TextLoader``. A file that fails to load is reported and skipped.

    Raises:
        typer.Exit: With code 1 when no document could be loaded or when the
            pipeline raises an error.
    """
    files_to_process = _collect_text_files(paths)

    if not files_to_process:
        print("⚠️ 在指定路径中未找到任何文本文件。")
        return

    print(f"✅ 发现 {len(files_to_process)} 个独立文件待处理。")

    try:
        pipeline = JarvisRAGPipeline(
            embedding_mode=cast(
                Optional[Literal["performance", "accuracy"]], embedding_mode
            ),
            db_path=str(db_path) if db_path else None,
            collection_name=collection_name,
        )

        docs: List[Document] = []
        loader: BaseLoader
        # Sorted so files are loaded (and reported) in a stable order.
        for file_path in sorted(files_to_process):
            try:
                if file_path.suffix.lower() == ".md":
                    loader = UnstructuredMarkdownLoader(str(file_path))
                else:  # Default to TextLoader for .txt and all code files
                    loader = TextLoader(str(file_path), encoding="utf-8")

                docs.extend(loader.load())
                print(f"✅ 已加载: {file_path}")
            except Exception as e:
                print(f"⚠️ 加载失败 {file_path}: {e}")

        if not docs:
            print("❌ 未能成功加载任何文档。")
            raise typer.Exit(code=1)

        pipeline.add_documents(docs)
        print(f"✅ 成功将 {len(docs)} 个文档的内容添加至集合 '{collection_name}'。")

    except typer.Exit:
        # typer.Exit derives from RuntimeError; without this clause the
        # generic handler below would catch the deliberate exit above and
        # print a spurious "serious error" line with an empty message.
        raise
    except Exception as e:
        print(f"❌ 发生严重错误: {e}")
        raise typer.Exit(code=1)
182
+
183
+
184
@app.command("list-docs", help="List all unique documents in the knowledge base.")
def list_documents(
    collection_name: str = typer.Option(
        "jarvis_rag_collection",
        "--collection",
        "-c",
        help="Name of the collection in the vector database.",
    ),
    db_path: Optional[Path] = typer.Option(
        None, "--db-path", help="Path to the vector database. Overrides global config."
    ),
):
    """
    Print the distinct ``source`` values stored in a collection's metadata.

    Fetches every item of the collection, gathers the string-valued ``source``
    metadata entries, and prints them sorted and numbered from 1.

    Raises:
        typer.Exit: With code 1 when any error occurs.
    """
    try:
        pipeline = JarvisRAGPipeline(
            db_path=str(db_path) if db_path else None,
            collection_name=collection_name,
        )

        # get() without arguments fetches every item in the collection.
        stored = pipeline.retriever.collection.get()

        if not stored or not stored["metadatas"]:
            print("ℹ️ 知识库中没有找到任何文档。")
            return

        # Keep only metadata entries that carry a string "source".
        candidates = (meta.get("source") for meta in stored["metadatas"] if meta)
        sources = {src for src in candidates if isinstance(src, str)}

        if not sources:
            print("ℹ️ 知识库中没有找到任何带有源信息的文档。")
            return

        print(f"📚 知识库 '{collection_name}' 中共有 {len(sources)} 个独立文档:")
        for index, source in enumerate(sorted(sources), start=1):
            print(f" {index}. {source}")

    except Exception as exc:
        print(f"❌ 发生错误: {exc}")
        raise typer.Exit(code=1)
229
+
230
+
231
@app.command("query", help="Ask a question to the knowledge base.")
def query(
    question: str = typer.Argument(..., help="The question to ask."),
    collection_name: str = typer.Option(
        "jarvis_rag_collection",
        "--collection",
        "-c",
        help="Name of the collection in the vector database.",
    ),
    embedding_mode: Optional[str] = typer.Option(
        None,
        "--embedding-mode",
        "-e",
        help="Embedding mode ('performance' or 'accuracy'). Overrides global config.",
    ),
    db_path: Optional[Path] = typer.Option(
        None, "--db-path", help="Path to the vector database. Overrides global config."
    ),
    platform: Optional[str] = typer.Option(
        None,
        "--platform",
        "-p",
        help="Specify a platform name for the LLM. Overrides the default thinking model.",
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Specify a model name for the LLM. Requires --platform.",
    ),
):
    """
    Query the RAG knowledge base and print the answer rendered as Markdown.

    ``--platform`` and ``--model`` must be given together; when both are
    present a custom LLM is built from them, otherwise the pipeline is created
    with ``llm=None``.

    Raises:
        typer.Exit: With code 1 when only one of ``--platform`` / ``--model``
            is given, the custom LLM cannot be created, or the query fails.
    """
    if model and not platform:
        print("❌ 错误: --model 需要指定 --platform。")
        raise typer.Exit(code=1)
    if platform and not model:
        # This combination already exited with code 1, but without saying why.
        print("❌ 错误: --platform 需要指定 --model。")
        raise typer.Exit(code=1)

    try:
        custom_llm = _create_custom_llm(platform, model) if platform and model else None
        if (platform or model) and not custom_llm:
            # _create_custom_llm has already printed the reason.
            raise typer.Exit(code=1)

        pipeline = JarvisRAGPipeline(
            llm=custom_llm,
            embedding_mode=cast(
                Optional[Literal["performance", "accuracy"]], embedding_mode
            ),
            db_path=str(db_path) if db_path else None,
            collection_name=collection_name,
        )

        print(f"🤔 正在查询: '{question}'")
        answer = pipeline.query(question)

        print("💬 答案:")
        # Rendered through the shared rich console so Markdown is formatted.
        from jarvis.jarvis_utils.globals import console

        console.print(Markdown(answer))

    except typer.Exit:
        # typer.Exit derives from RuntimeError; let the deliberate exit above
        # pass through instead of being reported as an error with an empty
        # message by the generic handler below.
        raise
    except Exception as e:
        print(f"❌ 发生错误: {e}")
        raise typer.Exit(code=1)
293
+
294
+
295
def main():
    """Call ``init_env`` with the "Jarvis RAG" welcome string, then run the Typer app."""
    welcome = "Jarvis RAG"
    init_env(welcome_str=welcome)
    app()
@@ -0,0 +1,109 @@
1
+ from typing import List, Literal, cast
2
+ from langchain_huggingface import HuggingFaceEmbeddings
3
+
4
+ from jarvis.jarvis_utils.config import (
5
+ get_rag_embedding_models,
6
+ get_rag_embedding_cache_path,
7
+ )
8
+ from .cache import EmbeddingCache
9
+
10
+
11
class EmbeddingManager:
    """
    Manages the loading and usage of local embedding models with caching.

    The model is selected from the configured embedding models by mode
    ('performance' or 'accuracy'), loaded through ``HuggingFaceEmbeddings``,
    and paired with an ``EmbeddingCache`` so that document texts already
    embedded are not embedded again.
    """

    def __init__(
        self,
        mode: Literal["performance", "accuracy"],
        cache_dir: str,
    ):
        """
        Initializes the EmbeddingManager.

        Args:
            mode: The desired mode, either 'performance' or 'accuracy'.
            cache_dir: The directory to store the embedding cache.

        Raises:
            ValueError: If ``mode`` is not a key of the configured models.
        """
        self.mode = mode
        self.embedding_models = get_rag_embedding_models()
        if mode not in self.embedding_models:
            raise ValueError(
                f"Invalid mode '{mode}'. Must be one of {list(self.embedding_models.keys())}"
            )

        self.model_config = self.embedding_models[self.mode]
        self.model_name = self.model_config["model_name"]

        print(f"🚀 初始化嵌入管理器,模式: '{self.mode}', 模型: '{self.model_name}'...")

        # The salt for the cache is the model name to prevent collisions
        # between embeddings produced by different models.
        self.cache = EmbeddingCache(cache_dir=cache_dir, salt=str(self.model_name))
        self.model = self._load_model()

    def _load_model(self) -> "HuggingFaceEmbeddings":
        """
        Loads the Hugging Face embedding model described by ``self.model_config``.

        Raises:
            Exception: Whatever the model constructor raises, re-raised after
                an explanatory message has been printed.
        """
        try:
            return HuggingFaceEmbeddings(
                model_name=self.model_name,
                # Fall back to empty dicts when the config omits these keys:
                # both parameters are dict-typed, so an explicit None would be
                # passed through instead of the library's own empty default.
                model_kwargs=self.model_config.get("model_kwargs") or {},
                encode_kwargs=self.model_config.get("encode_kwargs") or {},
                show_progress=self.model_config.get("show_progress", False),
            )
        except Exception as e:
            print(f"❌ 加载嵌入模型 '{self.model_name}' 时出错: {e}")
            print("请确保您已安装 'sentence_transformers' 和 'torch'。")
            raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Computes embeddings for a list of documents, using the cache.

        Args:
            texts: A list of documents (strings) to embed.

        Returns:
            A list of embeddings, one for each document, in input order.
        """
        if not texts:
            return []

        # One slot per input text; None marks a cache miss.
        cached_embeddings = self.cache.get_batch(texts)

        texts_to_embed = []
        indices_to_embed = []
        for i, (text, cached) in enumerate(zip(texts, cached_embeddings)):
            if cached is None:
                texts_to_embed.append(text)
                indices_to_embed.append(i)

        if texts_to_embed:
            print(
                f"🔎 缓存未命中。正在为 {len(texts_to_embed)}/{len(texts)} 个文档计算嵌入。"
            )
            new_embeddings = self.model.embed_documents(texts_to_embed)

            # Store new embeddings in the cache for later calls.
            self.cache.set_batch(texts_to_embed, new_embeddings)

            # Fill the miss slots so the result lines up with ``texts``.
            for i, embedding in zip(indices_to_embed, new_embeddings):
                cached_embeddings[i] = embedding
        else:
            print(f"✅ 缓存命中。所有 {len(texts)} 个文档的嵌入均从缓存中检索。")

        return cast(List[List[float]], cached_embeddings)

    def embed_query(self, text: str) -> List[float]:
        """
        Computes the embedding for a single query.

        The query is passed straight to the model; the cache is not consulted.
        """
        return self.model.embed_query(text)
@@ -0,0 +1,130 @@
1
+ from abc import ABC, abstractmethod
2
+ import os
3
+ import os
4
+ from abc import ABC, abstractmethod
5
+
6
+ from jarvis.jarvis_agent import Agent as JarvisAgent
7
+ from jarvis.jarvis_platform.base import BasePlatform
8
+ from jarvis.jarvis_platform.registry import PlatformRegistry
9
+
10
+
11
class LLMInterface(ABC):
    """
    Abstract base class describing how the RAG code talks to a language model.

    Each LLM provider (OpenAI, Anthropic, etc.) is expected to be wrapped in
    a subclass that implements :meth:`generate`.
    """

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """
        Produce the model's text response for ``prompt``.

        Args:
            prompt: The input prompt to send to the LLM.
            **kwargs: Additional keyword arguments for the LLM API call
                (e.g., temperature, max_tokens).

        Returns:
            The text response generated by the LLM.
        """
34
+
35
+
36
class ToolAgent_LLM(LLMInterface):
    """
    LLMInterface implementation that answers by running a JarvisAgent
    restricted to a small set of tools.
    """

    def __init__(self):
        """Set up the tool whitelist and the prompts handed to each agent."""
        print("🤖 已初始化工具 Agent 作为最终应答者。")
        # Only these tools are passed to the agent.
        self.allowed_tools = ["read_code", "execute_script"]
        # Generic system prompt for the answering agent.
        self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
        # Prompt used for the agent's closing summary report.
        self.summary_prompt = """
<report>
请为本次问答任务生成一个总结报告,包含以下内容:

1. **原始问题**: 重述用户最开始提出的问题。
2. **关键信息来源**: 总结你是基于哪些关键信息或文件得出的结论。
3. **最终答案**: 给出最终的、精炼的回答。
</report>
"""

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Run a freshly created JarvisAgent on ``prompt`` and return its answer.

        Args:
            prompt: The full prompt, including context, to be sent to the agent.
            **kwargs: Ignored, kept for interface compatibility.

        Returns:
            The agent's result converted to ``str``, or a fixed error message
            if anything raises while creating or running the agent.
        """
        agent_settings = dict(
            system_prompt=self.system_prompt,
            use_tools=self.allowed_tools,
            auto_complete=True,
            use_methodology=False,
            use_analysis=False,
            need_summary=True,
            summary_prompt=self.summary_prompt,
        )
        try:
            agent = JarvisAgent(**agent_settings)
            # The agent's run method takes the prompt as 'user_input'.
            return str(agent.run(user_input=prompt))
        except Exception as exc:
            print(f"❌ Agent 在执行过程中发生错误: {exc}")
            return "错误: Agent 未能成功生成回答。"
90
+
91
+
92
class JarvisPlatform_LLM(LLMInterface):
    """
    LLMInterface implementation backed by the project's own platform layer.

    The platform is obtained from the global PlatformRegistry via
    ``get_normal_platform()``.
    """

    def __init__(self):
        """
        Fetch the registry's "normal" platform and configure its output flag.

        Raises:
            Exception: Any error raised during set-up, re-raised after a
                message has been printed.
        """
        try:
            registry = PlatformRegistry.get_global_platform_registry()
            self.registry = registry
            self.platform: BasePlatform = registry.get_normal_platform()
            # NOTE(review): the original comment here read "Ensure no console
            # output from the model", yet the flag passed is False (the CLI's
            # _create_custom_llm passes True). Confirm which is intended.
            self.platform.set_suppress_output(False)
            print(f"🚀 已初始化 Jarvis 平台 LLM,模型: {self.platform.name()}")
        except Exception as exc:
            print(f"❌ 初始化 Jarvis 平台 LLM 失败: {exc}")
            raise

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Send ``prompt`` to the platform model and return its response.

        Args:
            prompt: The user's prompt.
            **kwargs: Ignored, kept for interface compatibility.

        Returns:
            The platform's response, or a fixed error message if the call
            raises.
        """
        try:
            reply = self.platform.chat_until_success(prompt)
        except Exception as exc:
            print(f"❌ 调用 Jarvis 平台模型时发生错误: {exc}")
            return "错误: 无法从本地LLM获取响应。"
        return reply
@@ -0,0 +1,63 @@
1
+ from typing import List
2
+ from .llm_interface import LLMInterface
3
+
4
+
5
class QueryRewriter:
    """
    Uses an LLM to rewrite a user's query into multiple, diverse search
    queries to enhance retrieval recall.
    """

    def __init__(self, llm: "LLMInterface"):
        """
        Initializes the QueryRewriter.

        Args:
            llm: An instance of a class implementing LLMInterface.
        """
        self.llm = llm
        self.rewrite_prompt_template = self._create_prompt_template()

    def _create_prompt_template(self) -> str:
        """Creates the prompt template for the multi-query rewriting task."""
        return """
你是一个精通检索的AI助手。你的任务是将以下这个单一的用户问题,从不同角度改写成 3 个不同的、但语义上相关的搜索查询。这有助于在知识库中进行更全面的搜索。

请遵循以下原则:
1. **多样性**:生成的查询应尝试使用不同的关键词和表述方式。
2. **保留核心意图**:所有查询都必须围绕原始问题的核心意图。
3. **简洁性**:每个查询都应该是独立的、可以直接用于搜索的短语或问题。
4. **格式要求**:请直接输出 3 个查询,每个查询占一行,用换行符分隔。不要添加任何编号、前缀或解释。

原始问题:
---
{query}
---

3个改写后的查询 (每行一个):
"""

    @staticmethod
    def _strip_list_marker(line: str) -> str:
        """Remove a leading list marker ("1. ", "2) ", "3、", "- ", "* ", "• ") from a line."""
        import re  # local import keeps this helper self-contained

        # The ASCII "." / ")" forms require trailing whitespace so that text
        # such as "3.5 inch disk" is left untouched.
        cleaned = re.sub(r"^(?:\d+[.)]\s+|\d+[、)]\s*|[-*•]\s+)", "", line.strip())
        return cleaned.strip()

    def rewrite(self, query: str) -> List[str]:
        """
        Rewrites the user query into multiple queries using the LLM.

        Each non-empty line of the LLM response becomes one query. Leading
        numbering or bullet markers are stripped (models often add them
        despite the prompt) and duplicate lines are dropped, keeping the
        first occurrence.

        Args:
            query: The original user query.

        Returns:
            The rewritten queries in response order. The original query is
            inserted at the front if the LLM did not return it.
        """
        prompt = self.rewrite_prompt_template.format(query=query)
        print("✍️ 正在将原始查询重写为多个搜索查询...")

        response_text = self.llm.generate(prompt)

        rewritten_queries: List[str] = []
        seen = set()
        for line in response_text.splitlines():
            candidate = self._strip_list_marker(line)
            if candidate and candidate not in seen:
                seen.add(candidate)
                rewritten_queries.append(candidate)

        # Also include the original query for robustness.
        if query not in seen:
            rewritten_queries.insert(0, query)

        print(f"✅ 生成了 {len(rewritten_queries)} 个查询变体。")
        return rewritten_queries