jarvis-ai-assistant 0.1.218__py3-none-any.whl → 0.1.220__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +37 -92
- jarvis/jarvis_agent/shell_input_handler.py +1 -1
- jarvis/jarvis_code_agent/code_agent.py +5 -3
- jarvis/jarvis_data/config_schema.json +30 -0
- jarvis/jarvis_git_squash/main.py +2 -1
- jarvis/jarvis_platform/human.py +2 -7
- jarvis/jarvis_platform/yuanbao.py +3 -1
- jarvis/jarvis_rag/__init__.py +11 -0
- jarvis/jarvis_rag/cache.py +87 -0
- jarvis/jarvis_rag/cli.py +297 -0
- jarvis/jarvis_rag/embedding_manager.py +109 -0
- jarvis/jarvis_rag/llm_interface.py +130 -0
- jarvis/jarvis_rag/query_rewriter.py +63 -0
- jarvis/jarvis_rag/rag_pipeline.py +177 -0
- jarvis/jarvis_rag/reranker.py +56 -0
- jarvis/jarvis_rag/retriever.py +201 -0
- jarvis/jarvis_tools/search_web.py +127 -11
- jarvis/jarvis_utils/config.py +71 -0
- jarvis/jarvis_utils/git_utils.py +27 -18
- jarvis/jarvis_utils/input.py +21 -10
- jarvis/jarvis_utils/utils.py +43 -20
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/METADATA +87 -5
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/RECORD +28 -19
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/entry_points.txt +1 -0
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/cli.py
ADDED
@@ -0,0 +1,297 @@
import os
import sys
from pathlib import Path
from typing import Optional, List, Literal, cast
import mimetypes

import typer
from langchain.docstore.document import Document
from langchain_community.document_loaders import (
    TextLoader,
    UnstructuredMarkdownLoader,
)
from langchain_core.document_loaders.base import BaseLoader
from rich.markdown import Markdown

from jarvis.jarvis_utils.utils import init_env


def is_likely_text_file(file_path: Path) -> bool:
    """
    Checks if a file is likely to be a text file by reading its beginning.
    Avoids loading large binary files into memory.
    """
    try:
        # Heuristic 1: Check MIME type if available
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type and mime_type.startswith("text/"):
            return True
        if mime_type and any(x in mime_type for x in ["json", "xml", "javascript"]):
            return True

        # Heuristic 2: Check for null bytes in the first few KB
        with open(file_path, "rb") as f:
            chunk = f.read(4096)  # Read first 4KB
            if b"\x00" in chunk:
                return False  # Null bytes are a strong indicator of a binary file
        return True
    except Exception:
        return False


# Ensure the project root is in the Python path to allow absolute imports
# This makes the script runnable as a module.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "..")
)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

from jarvis.jarvis_platform.base import BasePlatform
from jarvis.jarvis_platform.registry import PlatformRegistry
from jarvis.jarvis_rag.llm_interface import LLMInterface
from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline

app = typer.Typer(
    name="jarvis-rag",
    help="A command-line tool to interact with the Jarvis RAG framework.",
    add_completion=False,
)


class _CustomPlatformLLM(LLMInterface):
    """A simple wrapper to make a BasePlatform instance compatible with LLMInterface."""

    def __init__(self, platform: BasePlatform):
        self.platform = platform
        print(
            f"✅ 使用自定义LLM: 平台='{platform.platform_name()}', 模型='{platform.name()}'"
        )

    def generate(self, prompt: str, **kwargs) -> str:
        return self.platform.chat_until_success(prompt)


def _create_custom_llm(platform_name: str, model_name: str) -> Optional[LLMInterface]:
    """Creates an LLM interface from a specific platform and model."""
    if not platform_name or not model_name:
        return None
    try:
        registry = PlatformRegistry.get_global_platform_registry()
        platform_instance = registry.create_platform(platform_name)
        if not platform_instance:
            print(f"❌ 错误: 平台 '{platform_name}' 未找到。")
            return None
        platform_instance.set_model_name(model_name)
        platform_instance.set_suppress_output(True)
        return _CustomPlatformLLM(platform_instance)
    except Exception as e:
        print(f"❌ 创建自定义LLM时出错: {e}")
        return None


@app.command(
    "add",
    help="Add documents from files, directories, or glob patterns (e.g., 'src/**/*.py').",
)
def add_documents(
    paths: List[Path] = typer.Argument(
        ...,
        help="File/directory paths or glob patterns. Shell expansion is supported.",
    ),
    collection_name: str = typer.Option(
        "jarvis_rag_collection",
        "--collection",
        "-c",
        help="Name of the collection in the vector database.",
    ),
    embedding_mode: Optional[str] = typer.Option(
        None,
        "--embedding-mode",
        "-e",
        help="Embedding mode ('performance' or 'accuracy'). Overrides global config.",
    ),
    db_path: Optional[Path] = typer.Option(
        None, "--db-path", help="Path to the vector database. Overrides global config."
    ),
):
    """Adds documents to the RAG knowledge base from various sources."""
    files_to_process = set()

    for path_str in paths:
        # Typer with List[Path] might not expand globs, so we do it manually
        from glob import glob

        expanded_paths = glob(str(path_str), recursive=True)

        for p_str in expanded_paths:
            path = Path(p_str)
            if not path.exists():
                continue

            if path.is_dir():
                print(f"🔍 正在扫描目录: {path}")
                for item in path.rglob("*"):
                    if item.is_file() and is_likely_text_file(item):
                        files_to_process.add(item)
            elif path.is_file():
                if is_likely_text_file(path):
                    files_to_process.add(path)
                else:
                    print(f"⚠️ 跳过可能的二进制文件: {path}")

    if not files_to_process:
        print("⚠️ 在指定路径中未找到任何文本文件。")
        return

    print(f"✅ 发现 {len(files_to_process)} 个独立文件待处理。")

    try:
        pipeline = JarvisRAGPipeline(
            embedding_mode=cast(
                Optional[Literal["performance", "accuracy"]], embedding_mode
            ),
            db_path=str(db_path) if db_path else None,
            collection_name=collection_name,
        )

        docs: List[Document] = []
        loader: BaseLoader
        for file_path in sorted(list(files_to_process)):
            try:
                if file_path.suffix.lower() == ".md":
                    loader = UnstructuredMarkdownLoader(str(file_path))
                else:  # Default to TextLoader for .txt and all code files
                    loader = TextLoader(str(file_path), encoding="utf-8")

                docs.extend(loader.load())
                print(f"✅ 已加载: {file_path}")
            except Exception as e:
                print(f"⚠️ 加载失败 {file_path}: {e}")

        if not docs:
            print("❌ 未能成功加载任何文档。")
            raise typer.Exit(code=1)

        pipeline.add_documents(docs)
        print(f"✅ 成功将 {len(docs)} 个文档的内容添加至集合 '{collection_name}'。")

    except Exception as e:
        print(f"❌ 发生严重错误: {e}")
        raise typer.Exit(code=1)


@app.command("list-docs", help="List all unique documents in the knowledge base.")
def list_documents(
    collection_name: str = typer.Option(
        "jarvis_rag_collection",
        "--collection",
        "-c",
        help="Name of the collection in the vector database.",
    ),
    db_path: Optional[Path] = typer.Option(
        None, "--db-path", help="Path to the vector database. Overrides global config."
    ),
):
    """Lists all unique documents in the specified collection."""
    try:
        pipeline = JarvisRAGPipeline(
            db_path=str(db_path) if db_path else None,
            collection_name=collection_name,
        )

        collection = pipeline.retriever.collection
        results = collection.get()  # Get all items in the collection

        if not results or not results["metadatas"]:
            print("ℹ️ 知识库中没有找到任何文档。")
            return

        # Extract unique source file paths from metadata
        sources = set()
        for metadata in results["metadatas"]:
            if metadata:
                source = metadata.get("source")
                if isinstance(source, str):
                    sources.add(source)

        if not sources:
            print("ℹ️ 知识库中没有找到任何带有源信息的文档。")
            return

        print(f"📚 知识库 '{collection_name}' 中共有 {len(sources)} 个独立文档:")
        for i, source in enumerate(sorted(list(sources)), 1):
            print(f" {i}. {source}")

    except Exception as e:
        print(f"❌ 发生错误: {e}")
        raise typer.Exit(code=1)


@app.command("query", help="Ask a question to the knowledge base.")
def query(
    question: str = typer.Argument(..., help="The question to ask."),
    collection_name: str = typer.Option(
        "jarvis_rag_collection",
        "--collection",
        "-c",
        help="Name of the collection in the vector database.",
    ),
    embedding_mode: Optional[str] = typer.Option(
        None,
        "--embedding-mode",
        "-e",
        help="Embedding mode ('performance' or 'accuracy'). Overrides global config.",
    ),
    db_path: Optional[Path] = typer.Option(
        None, "--db-path", help="Path to the vector database. Overrides global config."
    ),
    platform: Optional[str] = typer.Option(
        None,
        "--platform",
        "-p",
        help="Specify a platform name for the LLM. Overrides the default thinking model.",
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Specify a model name for the LLM. Requires --platform.",
    ),
):
    """Queries the RAG knowledge base and prints the answer."""
    if model and not platform:
        print("❌ 错误: --model 需要指定 --platform。")
        raise typer.Exit(code=1)

    try:
        custom_llm = _create_custom_llm(platform, model) if platform and model else None
        if (platform or model) and not custom_llm:
            raise typer.Exit(code=1)

        pipeline = JarvisRAGPipeline(
            llm=custom_llm,
            embedding_mode=cast(
                Optional[Literal["performance", "accuracy"]], embedding_mode
            ),
            db_path=str(db_path) if db_path else None,
            collection_name=collection_name,
        )

        print(f"🤔 正在查询: '{question}'")
        answer = pipeline.query(question)

        print("💬 答案:")
        # We can still use rich.markdown.Markdown as PrettyOutput uses rich underneath
        from jarvis.jarvis_utils.globals import console

        console.print(Markdown(answer))

    except Exception as e:
        print(f"❌ 发生错误: {e}")
        raise typer.Exit(code=1)


def main():
    init_env(welcome_str="Jarvis RAG")
    app()
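Taken together, `add`, `list-docs`, and `query` cover the full ingest-and-ask loop. A minimal sketch of driving them from Python via Typer's in-process test runner (the glob pattern and question are placeholders; the commands and flags are exactly those defined above):

```python
# Smoke-test sketch for the new CLI, assuming the package and its RAG
# dependencies are installed. CliRunner is Typer's standard test runner.
from typer.testing import CliRunner

from jarvis.jarvis_rag.cli import app

runner = CliRunner()

# Ingest all Markdown files under docs/ (placeholder path) into the default collection.
result = runner.invoke(app, ["add", "docs/**/*.md", "--embedding-mode", "accuracy"])
print(result.stdout)

# List the indexed sources, then ask a question against them.
print(runner.invoke(app, ["list-docs"]).stdout)
print(runner.invoke(app, ["query", "How is the embedding cache salted?"]).stdout)
```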
jarvis/jarvis_rag/embedding_manager.py
ADDED
@@ -0,0 +1,109 @@
from typing import List, Literal, cast
from langchain_huggingface import HuggingFaceEmbeddings

from jarvis.jarvis_utils.config import (
    get_rag_embedding_models,
    get_rag_embedding_cache_path,
)
from .cache import EmbeddingCache


class EmbeddingManager:
    """
    Manages the loading and usage of local embedding models with caching.

    This class handles the selection of embedding models based on a specified
    mode ('performance' or 'accuracy'), loads the model from Hugging Face,
    and uses a disk-based cache to avoid re-computing embeddings for the
    same text.
    """

    def __init__(
        self,
        mode: Literal["performance", "accuracy"],
        cache_dir: str,
    ):
        """
        Initializes the EmbeddingManager.

        Args:
            mode: The desired mode, either 'performance' or 'accuracy'.
            cache_dir: The directory to store the embedding cache.
        """
        self.mode = mode
        self.embedding_models = get_rag_embedding_models()
        if mode not in self.embedding_models:
            raise ValueError(
                f"Invalid mode '{mode}'. Must be one of {list(self.embedding_models.keys())}"
            )

        self.model_config = self.embedding_models[self.mode]
        self.model_name = self.model_config["model_name"]

        print(f"🚀 初始化嵌入管理器,模式: '{self.mode}', 模型: '{self.model_name}'...")

        # The salt for the cache is the model name to prevent collisions
        self.cache = EmbeddingCache(cache_dir=cache_dir, salt=str(self.model_name))
        self.model = self._load_model()

    def _load_model(self) -> HuggingFaceEmbeddings:
        """Loads the Hugging Face embedding model based on the configuration."""
        try:
            return HuggingFaceEmbeddings(
                model_name=self.model_name,
                model_kwargs=self.model_config.get("model_kwargs"),
                encode_kwargs=self.model_config.get("encode_kwargs"),
                show_progress=self.model_config.get("show_progress", False),
            )
        except Exception as e:
            print(f"❌ 加载嵌入模型 '{self.model_name}' 时出错: {e}")
            print("请确保您已安装 'sentence_transformers' 和 'torch'。")
            raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Computes embeddings for a list of documents, using the cache.

        Args:
            texts: A list of documents (strings) to embed.

        Returns:
            A list of embeddings, one for each document.
        """
        if not texts:
            return []

        # Check cache for existing embeddings
        cached_embeddings = self.cache.get_batch(texts)

        texts_to_embed = []
        indices_to_embed = []
        for i, (text, cached) in enumerate(zip(texts, cached_embeddings)):
            if cached is None:
                texts_to_embed.append(text)
                indices_to_embed.append(i)

        # Compute embeddings for texts that were not in the cache
        if texts_to_embed:
            print(
                f"🔎 缓存未命中。正在为 {len(texts_to_embed)}/{len(texts)} 个文档计算嵌入。"
            )
            new_embeddings = self.model.embed_documents(texts_to_embed)

            # Store new embeddings in the cache
            self.cache.set_batch(texts_to_embed, new_embeddings)

            # Place new embeddings back into the results list
            for i, embedding in zip(indices_to_embed, new_embeddings):
                cached_embeddings[i] = embedding
        else:
            print(f"✅ 缓存命中。所有 {len(texts)} 个文档的嵌入均从缓存中检索。")

        return cast(List[List[float]], cached_embeddings)

    def embed_query(self, text: str) -> List[float]:
        """
        Computes the embedding for a single query.
        Queries are typically not cached, but we can add it if needed.
        """
        return self.model.embed_query(text)
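`EmbeddingManager` leans on `EmbeddingCache.get_batch`/`set_batch` from `cache.py` (added in this release but not shown in this section). For orientation only, a hand-written sketch of a cache with the semantics the manager assumes: misses come back as `None`, and the salt (the model name) is folded into the key so different models never share entries. This is an illustration, not the actual `cache.py`:

```python
# Hypothetical stand-in for jarvis_rag/cache.py, matching only the
# get_batch/set_batch contract that EmbeddingManager above relies on.
import hashlib
import pickle
from pathlib import Path
from typing import List, Optional


class HashedEmbeddingCache:
    def __init__(self, cache_dir: str, salt: str):
        self.dir = Path(cache_dir)
        self.dir.mkdir(parents=True, exist_ok=True)
        self.salt = salt

    def _path(self, text: str) -> Path:
        # Salted content hash: the same text embedded by two models gets two entries.
        digest = hashlib.sha256((self.salt + "\x00" + text).encode("utf-8")).hexdigest()
        return self.dir / f"{digest}.pkl"

    def get_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
        results: List[Optional[List[float]]] = []
        for text in texts:
            path = self._path(text)
            # A miss is reported as None, which is what embed_documents() checks for.
            results.append(pickle.loads(path.read_bytes()) if path.exists() else None)
        return results

    def set_batch(self, texts: List[str], embeddings: List[List[float]]) -> None:
        for text, embedding in zip(texts, embeddings):
            self._path(text).write_bytes(pickle.dumps(embedding))
```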
jarvis/jarvis_rag/llm_interface.py
ADDED
@@ -0,0 +1,130 @@
import os
from abc import ABC, abstractmethod

from jarvis.jarvis_agent import Agent as JarvisAgent
from jarvis.jarvis_platform.base import BasePlatform
from jarvis.jarvis_platform.registry import PlatformRegistry


class LLMInterface(ABC):
    """
    Abstract Base Class for Large Language Model interfaces.

    This class defines the standard interface for interacting with a remote LLM.
    Any LLM provider (OpenAI, Anthropic, etc.) should be implemented as a
    subclass of this interface.
    """

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generates a response from the LLM based on a given prompt.

        Args:
            prompt: The input prompt to send to the LLM.
            **kwargs: Additional keyword arguments for the LLM API call
                      (e.g., temperature, max_tokens).

        Returns:
            The text response generated by the LLM.
        """
        pass


class ToolAgent_LLM(LLMInterface):
    """
    An implementation of the LLMInterface that uses a tool-wielding JarvisAgent
    to generate the final response.
    """

    def __init__(self):
        """
        Initializes the Tool-Agent LLM wrapper.
        """
        print("🤖 已初始化工具 Agent 作为最终应答者。")
        self.allowed_tools = ["read_code", "execute_script"]
        # A generic system prompt for the agent
        self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
        self.summary_prompt = """
<report>
请为本次问答任务生成一个总结报告,包含以下内容:

1. **原始问题**: 重述用户最开始提出的问题。
2. **关键信息来源**: 总结你是基于哪些关键信息或文件得出的结论。
3. **最终答案**: 给出最终的、精炼的回答。
</report>
"""

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Runs the JarvisAgent with a restricted toolset to generate an answer.

        Args:
            prompt: The full prompt, including context, to be sent to the agent.
            **kwargs: Ignored, kept for interface compatibility.

        Returns:
            The final answer generated by the agent.
        """
        try:
            # Initialize the agent with specific settings for RAG context
            agent = JarvisAgent(
                system_prompt=self.system_prompt,
                use_tools=self.allowed_tools,
                auto_complete=True,
                use_methodology=False,
                use_analysis=False,
                need_summary=True,
                summary_prompt=self.summary_prompt,
            )

            # The agent's run method expects the 'user_input' parameter
            final_answer = agent.run(user_input=prompt)
            return str(final_answer)

        except Exception as e:
            print(f"❌ Agent 在执行过程中发生错误: {e}")
            return "错误: Agent 未能成功生成回答。"


class JarvisPlatform_LLM(LLMInterface):
    """
    An implementation of the LLMInterface for the project's internal platform.

    This class uses the PlatformRegistry to get the configured "normal" model.
    """

    def __init__(self):
        """
        Initializes the Jarvis Platform LLM client.
        """
        try:
            self.registry = PlatformRegistry.get_global_platform_registry()
            self.platform: BasePlatform = self.registry.get_normal_platform()
            # Keep the model's console output visible
            self.platform.set_suppress_output(False)
            print(f"🚀 已初始化 Jarvis 平台 LLM,模型: {self.platform.name()}")
        except Exception as e:
            print(f"❌ 初始化 Jarvis 平台 LLM 失败: {e}")
            raise

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Sends a prompt to the local platform model and returns the response.

        Args:
            prompt: The user's prompt.
            **kwargs: Ignored, kept for interface compatibility.

        Returns:
            The response generated by the platform model.
        """
        try:
            # Use the robust chat_until_success method
            return self.platform.chat_until_success(prompt)
        except Exception as e:
            print(f"❌ 调用 Jarvis 平台模型时发生错误: {e}")
            return "错误: 无法从本地LLM获取响应。"
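Because `LLMInterface` is a single abstract `generate` method, swapping in an outside model is a few lines of adapter code. A minimal sketch, where the wrapped callable is a placeholder rather than a real provider client:

```python
# Adapter sketch: anything that maps str -> str can back the RAG pipeline
# through LLMInterface. `my_client.complete` below is hypothetical.
from typing import Callable

from jarvis.jarvis_rag.llm_interface import LLMInterface


class CallableLLM(LLMInterface):
    """Adapts a plain `str -> str` callable to the LLMInterface contract."""

    def __init__(self, fn: Callable[[str], str]):
        self.fn = fn

    def generate(self, prompt: str, **kwargs) -> str:
        # kwargs accepted but ignored, mirroring the implementations above.
        return self.fn(prompt)


# Usage (hypothetical client):
# pipeline = JarvisRAGPipeline(llm=CallableLLM(my_client.complete))
```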
jarvis/jarvis_rag/query_rewriter.py
ADDED
@@ -0,0 +1,63 @@
from typing import List
from .llm_interface import LLMInterface


class QueryRewriter:
    """
    Uses an LLM to rewrite a user's query into multiple, diverse search
    queries to enhance retrieval recall.
    """

    def __init__(self, llm: LLMInterface):
        """
        Initializes the QueryRewriter.

        Args:
            llm: An instance of a class implementing LLMInterface.
        """
        self.llm = llm
        self.rewrite_prompt_template = self._create_prompt_template()

    def _create_prompt_template(self) -> str:
        """Creates the prompt template for the multi-query rewriting task."""
        return """
你是一个精通检索的AI助手。你的任务是将以下这个单一的用户问题,从不同角度改写成 3 个不同的、但语义上相关的搜索查询。这有助于在知识库中进行更全面的搜索。

请遵循以下原则:
1. **多样性**:生成的查询应尝试使用不同的关键词和表述方式。
2. **保留核心意图**:所有查询都必须围绕原始问题的核心意图。
3. **简洁性**:每个查询都应该是独立的、可以直接用于搜索的短语或问题。
4. **格式要求**:请直接输出 3 个查询,每个查询占一行,用换行符分隔。不要添加任何编号、前缀或解释。

原始问题:
---
{query}
---

3个改写后的查询 (每行一个):
"""

    def rewrite(self, query: str) -> List[str]:
        """
        Rewrites the user query into multiple queries using the LLM.

        Args:
            query: The original user query.

        Returns:
            A list of rewritten, search-optimized queries.
        """
        prompt = self.rewrite_prompt_template.format(query=query)
        print("✍️ 正在将原始查询重写为多个搜索查询...")

        response_text = self.llm.generate(prompt)
        rewritten_queries = [
            line.strip() for line in response_text.strip().split("\n") if line.strip()
        ]

        # Also include the original query for robustness
        if query not in rewritten_queries:
            rewritten_queries.insert(0, query)

        print(f"✅ 生成了 {len(rewritten_queries)} 个查询变体。")
        return rewritten_queries
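The rewriter's contract is easy to verify with a stubbed LLM: three lines back from the model, plus the original query prepended when the model did not repeat it, yields four variants:

```python
# Contract check with a canned LLM response; the stub stands in for a real model.
from jarvis.jarvis_rag.llm_interface import LLMInterface
from jarvis.jarvis_rag.query_rewriter import QueryRewriter


class StubLLM(LLMInterface):
    def generate(self, prompt: str, **kwargs) -> str:
        return (
            "how is the embedding cache salted\n"
            "embedding cache salt model name\n"
            "EmbeddingCache salt collision prevention"
        )


queries = QueryRewriter(StubLLM()).rewrite("Why does the embedding cache use a salt?")
assert len(queries) == 4  # original query + 3 rewrites
print(queries)
```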