jarvis-ai-assistant 0.1.218__py3-none-any.whl → 0.1.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +37 -92
  3. jarvis/jarvis_agent/shell_input_handler.py +1 -1
  4. jarvis/jarvis_code_agent/code_agent.py +5 -3
  5. jarvis/jarvis_data/config_schema.json +30 -0
  6. jarvis/jarvis_git_squash/main.py +2 -1
  7. jarvis/jarvis_platform/human.py +2 -7
  8. jarvis/jarvis_platform/yuanbao.py +3 -1
  9. jarvis/jarvis_rag/__init__.py +11 -0
  10. jarvis/jarvis_rag/cache.py +87 -0
  11. jarvis/jarvis_rag/cli.py +297 -0
  12. jarvis/jarvis_rag/embedding_manager.py +109 -0
  13. jarvis/jarvis_rag/llm_interface.py +130 -0
  14. jarvis/jarvis_rag/query_rewriter.py +63 -0
  15. jarvis/jarvis_rag/rag_pipeline.py +177 -0
  16. jarvis/jarvis_rag/reranker.py +56 -0
  17. jarvis/jarvis_rag/retriever.py +201 -0
  18. jarvis/jarvis_tools/search_web.py +127 -11
  19. jarvis/jarvis_utils/config.py +71 -0
  20. jarvis/jarvis_utils/git_utils.py +27 -18
  21. jarvis/jarvis_utils/input.py +21 -10
  22. jarvis/jarvis_utils/utils.py +43 -20
  23. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/METADATA +87 -5
  24. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/RECORD +28 -19
  25. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/entry_points.txt +1 -0
  26. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/WHEEL +0 -0
  27. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/licenses/LICENSE +0 -0
  28. {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/rag_pipeline.py
@@ -0,0 +1,177 @@
+ import os
+ from typing import List, Literal, Optional, cast
+
+ from langchain.docstore.document import Document
+
+ from .embedding_manager import EmbeddingManager
+ from .llm_interface import JarvisPlatform_LLM, LLMInterface, ToolAgent_LLM
+ from .query_rewriter import QueryRewriter
+ from .reranker import Reranker
+ from .retriever import ChromaRetriever
+ from jarvis.jarvis_utils.config import (
+     get_rag_embedding_mode,
+     get_rag_vector_db_path,
+     get_rag_embedding_cache_path,
+     get_rag_embedding_models,
+ )
+
+
+ class JarvisRAGPipeline:
+     """
+     The main orchestrator for the RAG pipeline.
+
+     This class integrates the embedding manager, retriever, and LLM to provide
+     a complete pipeline for adding documents and querying them.
+     """
+
+     def __init__(
+         self,
+         llm: Optional[LLMInterface] = None,
+         embedding_mode: Optional[Literal["performance", "accuracy"]] = None,
+         db_path: Optional[str] = None,
+         collection_name: str = "jarvis_rag_collection",
+     ):
+         """
+         Initializes the RAG pipeline.
+
+         Args:
+             llm: An instance of a class implementing LLMInterface.
+                 If None, defaults to the ToolAgent_LLM.
+             embedding_mode: The mode for the local embedding model. If None, uses config value.
+             db_path: Path to the persistent vector database. If None, uses config value.
+             collection_name: Name of the collection in the vector database.
+         """
+         # Determine the embedding model to isolate data paths
+         _embedding_mode = embedding_mode or get_rag_embedding_mode()
+         embedding_models = get_rag_embedding_models()
+         model_name = embedding_models[_embedding_mode]["model_name"]
+         sanitized_model_name = model_name.replace("/", "_").replace("\\", "_")
+
+         # If a specific db_path is given, use it. Otherwise, create a model-specific path.
+         _final_db_path = (
+             str(db_path)
+             if db_path
+             else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
+         )
+         # Always create a model-specific cache path.
+         _final_cache_path = os.path.join(
+             get_rag_embedding_cache_path(), sanitized_model_name
+         )
+
+         self.embedding_manager = EmbeddingManager(
+             mode=cast(Literal["performance", "accuracy"], _embedding_mode),
+             cache_dir=_final_cache_path,
+         )
+         self.retriever = ChromaRetriever(
+             embedding_manager=self.embedding_manager,
+             db_path=_final_db_path,
+             collection_name=collection_name,
+         )
+         # Default to the ToolAgent_LLM unless a specific LLM is provided
+         self.llm = llm if llm is not None else ToolAgent_LLM()
+         self.reranker = Reranker()
+         # Use a standard LLM for the query rewriting task, not the agent
+         self.query_rewriter = QueryRewriter(JarvisPlatform_LLM())
+
+         print("✅ JarvisRAGPipeline 初始化成功。")
+
+     def add_documents(self, documents: List[Document]):
+         """
+         Adds documents to the vector knowledge base.
+
+         Args:
+             documents: A list of LangChain Document objects to add.
+         """
+         self.retriever.add_documents(documents)
+
+     def _create_prompt(
+         self, query: str, context_docs: List[Document], source_files: List[str]
+     ) -> str:
+         """Creates the final prompt for the LLM or Agent."""
+         context = "\n\n".join([doc.page_content for doc in context_docs])
+         sources_text = "\n".join([f"- {source}" for source in source_files])
+
+         prompt_template = f"""
+ 你是一个专家助手。请根据用户的问题,结合下面提供的参考信息来回答。
+
+ **重要**: 提供的上下文和文件列表**仅供参考**,可能不完整或已过时。在回答前,你应该**优先使用工具(如 read_code)来获取最新、最准确的信息**。
+
+ 参考文件列表:
+ ---
+ {sources_text}
+ ---
+
+ 参考上下文:
+ ---
+ {context}
+ ---
+
+ 问题: {query}
+
+ 回答:
+ """
+         return prompt_template.strip()
+
+     def query(self, query_text: str, n_results: int = 5) -> str:
+         """
+         Performs a query against the knowledge base using a multi-query
+         retrieval and reranking pipeline.
+
+         Args:
+             query_text: The user's original question.
+             n_results: The number of final relevant chunks to retrieve.
+
+         Returns:
+             The answer generated by the LLM.
+         """
+         # 1. Rewrite the original query into multiple queries
+         rewritten_queries = self.query_rewriter.rewrite(query_text)
+
+         # 2. Retrieve initial candidates for each rewritten query
+         all_candidate_docs = []
+         for q in rewritten_queries:
+             print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
+             candidates = self.retriever.retrieve(q, n_results=n_results * 2)
+             all_candidate_docs.extend(candidates)
+
+         # De-duplicate the candidate documents
+         unique_docs_dict = {doc.page_content: doc for doc in all_candidate_docs}
+         unique_candidate_docs = list(unique_docs_dict.values())
+
+         if not unique_candidate_docs:
+             return "我在提供的文档中找不到任何相关信息来回答您的问题。"
+
+         # 3. Rerank the unified candidate pool against the *original* query
+         print(
+             f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)..."
+         )
+         retrieved_docs = self.reranker.rerank(
+             query_text, unique_candidate_docs, top_n=n_results
+         )
+
+         if not retrieved_docs:
+             return "我在提供的文档中找不到任何相关信息来回答您的问题。"
+
+         # Print the sources of the final retrieved documents
+         sources = sorted(
+             list(
+                 {
+                     doc.metadata["source"]
+                     for doc in retrieved_docs
+                     if "source" in doc.metadata
+                 }
+             )
+         )
+         if sources:
+             print(f"📚 根据以下文档回答:")
+             for source in sources:
+                 print(f"  - {source}")
+
+         # 4. Create the final prompt and generate the answer
+         # We use the original query_text for the final prompt to the LLM
+         prompt = self._create_prompt(query_text, retrieved_docs, sources)
+
+         print("🤖 正在从LLM生成答案...")
+         answer = self.llm.generate(prompt)
+
+         return answer
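
For reviewers, a minimal usage sketch of the new pipeline API (the document contents and "source" metadata below are hypothetical; assumes the RAG dependencies and a configured Jarvis platform are available):

from langchain.docstore.document import Document
from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline

# Defaults: embedding mode from config, model-specific DB/cache paths,
# and the ToolAgent_LLM answering backend.
pipeline = JarvisRAGPipeline()

# Index two toy documents (contents and metadata are made up).
pipeline.add_documents(
    [
        Document(page_content="Jarvis is an AI assistant.", metadata={"source": "README.md"}),
        Document(page_content="Hybrid retrieval fuses BM25 and vector search.", metadata={"source": "docs/rag.md"}),
    ]
)

# Rewrite -> hybrid retrieve -> rerank -> prompt -> answer.
print(pipeline.query("How does retrieval work?", n_results=3))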
jarvis/jarvis_rag/reranker.py
@@ -0,0 +1,56 @@
+ from typing import List
+
+ from langchain.docstore.document import Document
+ from sentence_transformers.cross_encoder import (  # type: ignore
+     CrossEncoder,
+ )
+
+
+ class Reranker:
+     """
+     A reranker class that uses a Cross-Encoder model to re-score and sort
+     documents based on their relevance to a given query.
+     """
+
+     def __init__(self, model_name: str = "BAAI/bge-reranker-base"):
+         """
+         Initializes the Reranker.
+
+         Args:
+             model_name (str): The name of the Cross-Encoder model to use.
+         """
+         print(f"🔍 正在初始化重排模型: {model_name}...")
+         self.model = CrossEncoder(model_name)
+         print("✅ 重排模型初始化成功。")
+
+     def rerank(
+         self, query: str, documents: List[Document], top_n: int = 5
+     ) -> List[Document]:
+         """
+         Reranks a list of documents based on their relevance to the query.
+
+         Args:
+             query (str): The user's query.
+             documents (List[Document]): The list of documents retrieved from the initial search.
+             top_n (int): The number of top documents to return after reranking.
+
+         Returns:
+             List[Document]: A sorted list of the most relevant documents.
+         """
+         if not documents:
+             return []
+
+         # Create pairs of [query, document_content] for scoring
+         pairs = [[query, doc.page_content] for doc in documents]
+
+         # Get scores from the Cross-Encoder model
+         scores = self.model.predict(pairs)
+
+         # Combine documents with their scores and sort
+         doc_with_scores = list(zip(documents, scores))
+         doc_with_scores.sort(key=lambda x: x[1], reverse=True)
+
+         # Return the top N documents
+         reranked_docs = [doc for doc, score in doc_with_scores[:top_n]]
+
+         return reranked_docs
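
A standalone sketch of calling the reranker directly (toy documents; assumes sentence-transformers can fetch BAAI/bge-reranker-base on first use):

from langchain.docstore.document import Document
from jarvis.jarvis_rag.reranker import Reranker

reranker = Reranker()  # defaults to BAAI/bge-reranker-base
docs = [
    Document(page_content="BM25 is a sparse keyword ranking function."),
    Document(page_content="Cross-encoders score (query, document) pairs jointly."),
]
# The cross-encoder scores each (query, text) pair; the best match comes first.
top = reranker.rerank("what is a cross-encoder?", docs, top_n=1)
print(top[0].page_content)

Unlike a bi-encoder, a cross-encoder reads query and document together, which is slower but markedly more accurate; that is why it is applied only to the small candidate pool rather than the whole corpus.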
jarvis/jarvis_rag/retriever.py
@@ -0,0 +1,201 @@
+ import os
+ import pickle
+ from typing import Any, Dict, List, cast
+
+ import chromadb
+ from langchain.docstore.document import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from rank_bm25 import BM25Okapi  # type: ignore
+
+ from .embedding_manager import EmbeddingManager
+
+
+ class ChromaRetriever:
+     """
+     A retriever class that combines dense vector search (ChromaDB) and
+     sparse keyword search (BM25) for hybrid retrieval.
+     """
+
+     def __init__(
+         self,
+         embedding_manager: EmbeddingManager,
+         db_path: str,
+         collection_name: str = "jarvis_rag_collection",
+     ):
+         """
+         Initializes the ChromaRetriever.
+
+         Args:
+             embedding_manager: An instance of EmbeddingManager.
+             db_path: The file path for ChromaDB's persistent storage.
+             collection_name: The name of the collection within ChromaDB.
+         """
+         self.embedding_manager = embedding_manager
+         self.db_path = db_path
+         self.collection_name = collection_name
+
+         # Initialize ChromaDB client
+         self.client = chromadb.PersistentClient(path=self.db_path)
+         self.collection = self.client.get_or_create_collection(
+             name=self.collection_name
+         )
+         print(
+             f"✅ ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。"
+         )
+
+         # BM25 Index setup
+         self.bm25_index_path = os.path.join(self.db_path, f"{collection_name}_bm25.pkl")
+         self._load_or_initialize_bm25()
+
+     def _load_or_initialize_bm25(self):
+         """Loads the BM25 index from disk or initializes a new one."""
+         if os.path.exists(self.bm25_index_path):
+             print("🔍 正在加载现有的 BM25 索引...")
+             with open(self.bm25_index_path, "rb") as f:
+                 data = pickle.load(f)
+             self.bm25_corpus = data["corpus"]
+             self.bm25_index = BM25Okapi(self.bm25_corpus)
+             print("✅ BM25 索引加载成功。")
+         else:
+             print("⚠️ 未找到 BM25 索引,将初始化一个新的。")
+             self.bm25_corpus = []
+             self.bm25_index = None
+
+     def _save_bm25_index(self):
+         """Saves the BM25 index to disk."""
+         if self.bm25_index:
+             print("💾 正在保存 BM25 索引...")
+             with open(self.bm25_index_path, "wb") as f:
+                 pickle.dump({"corpus": self.bm25_corpus, "index": self.bm25_index}, f)
+             print("✅ BM25 索引保存成功。")
+
+     def add_documents(
+         self, documents: List[Document], chunk_size=1000, chunk_overlap=100
+     ):
+         """
+         Splits, embeds, and adds documents to both ChromaDB and the BM25 index.
+         """
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=chunk_size, chunk_overlap=chunk_overlap
+         )
+         chunks = text_splitter.split_documents(documents)
+
+         print(f"📄 已将 {len(documents)} 个文档拆分为 {len(chunks)} 个块。")
+
+         if not chunks:
+             return
+
+         # Extract content, metadata, and generate IDs
+         chunk_texts = [chunk.page_content for chunk in chunks]
+         metadatas = [chunk.metadata for chunk in chunks]
+         start_id = self.collection.count()
+         ids = [f"doc_{i}" for i in range(start_id, start_id + len(chunks))]
+
+         # Add to ChromaDB
+         embeddings = self.embedding_manager.embed_documents(chunk_texts)
+         self.collection.add(
+             ids=ids,
+             embeddings=cast(Any, embeddings),
+             documents=chunk_texts,
+             metadatas=cast(Any, metadatas),
+         )
+         print(f"✅ 成功将 {len(chunks)} 个块添加到 ChromaDB 集合中。")
+
+         # Update and save BM25 index
+         tokenized_chunks = [doc.split() for doc in chunk_texts]
+         self.bm25_corpus.extend(tokenized_chunks)
+         self.bm25_index = BM25Okapi(self.bm25_corpus)
+         self._save_bm25_index()
+
+     def retrieve(self, query: str, n_results: int = 5) -> List[Document]:
+         """
+         Performs hybrid retrieval using both vector search and BM25,
+         then fuses the results using Reciprocal Rank Fusion (RRF).
+         """
+         # 1. Vector Search (ChromaDB)
+         query_embedding = self.embedding_manager.embed_query(query)
+         vector_results = self.collection.query(
+             query_embeddings=cast(Any, [query_embedding]),
+             n_results=n_results * 2,  # Retrieve more results for fusion
+         )
+
+         # 2. Keyword Search (BM25)
+         bm25_docs = []
+         if self.bm25_index:
+             tokenized_query = query.split()
+             doc_scores = self.bm25_index.get_scores(tokenized_query)
+
+             # Get all documents from Chroma to match with BM25 scores
+             all_docs_in_collection = self.collection.get()
+             all_documents = all_docs_in_collection.get("documents")
+             all_metadatas = all_docs_in_collection.get("metadatas")
+
+             bm25_results_with_docs = []
+             if all_documents and all_metadatas:
+                 # Create a mapping from index to document
+                 bm25_results_with_docs = [
+                     (
+                         all_documents[i],
+                         all_metadatas[i],
+                         score,
+                     )
+                     for i, score in enumerate(doc_scores)
+                     if score > 0
+                 ]
+
+             # Sort by score and take top results
+             bm25_results_with_docs.sort(key=lambda x: x[2], reverse=True)
+
+             for doc_text, metadata, _ in bm25_results_with_docs[: n_results * 2]:
+                 bm25_docs.append(Document(page_content=doc_text, metadata=metadata))
+
+         # 3. Reciprocal Rank Fusion (RRF)
+         fused_scores: Dict[str, float] = {}
+         k = 60  # RRF ranking constant
+
+         # Process vector results
+         if vector_results and vector_results["ids"] and vector_results["documents"]:
+             vec_ids = vector_results["ids"][0]
+             vec_texts = vector_results["documents"][0]
+
+             for rank, doc_id in enumerate(vec_ids):
+                 fused_scores[doc_id] = fused_scores.get(doc_id, 0) + 1 / (k + rank)
+
+             # Create a map from document text to its ID for BM25 fusion
+             doc_text_to_id = {text: doc_id for text, doc_id in zip(vec_texts, vec_ids)}
+
+             for rank, doc in enumerate(bm25_docs):
+                 bm25_doc_id = doc_text_to_id.get(doc.page_content)
+                 if bm25_doc_id:
+                     fused_scores[bm25_doc_id] = fused_scores.get(bm25_doc_id, 0) + 1 / (
+                         k + rank
+                     )
+
+         # Sort fused results
+         sorted_fused_results = sorted(
+             fused_scores.items(), key=lambda x: x[1], reverse=True
+         )
+
+         # Get the final documents from ChromaDB based on fused ranking
+         final_doc_ids = [item[0] for item in sorted_fused_results[:n_results]]
+
+         if not final_doc_ids:
+             return []
+
+         final_docs_data = self.collection.get(ids=final_doc_ids)
+
+         retrieved_docs = []
+         if final_docs_data:
+             final_documents = final_docs_data.get("documents")
+             final_metadatas = final_docs_data.get("metadatas")
+
+             if final_documents and final_metadatas:
+                 for doc_text, metadata in zip(final_documents, final_metadatas):
+                     if doc_text is not None and metadata is not None:
+                         retrieved_docs.append(
+                             Document(
+                                 page_content=cast(str, doc_text), metadata=metadata
+                             )
+                         )
+
+         return retrieved_docs
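
To make the fusion step concrete, here is a self-contained sketch of the RRF arithmetic used in retrieve() (toy document IDs; k = 60 as in the code):

# Each ranked list contributes 1 / (k + rank) to a document's fused score.
k = 60
vector_ranking = ["doc_2", "doc_0", "doc_1"]  # dense (ChromaDB) order
bm25_ranking = ["doc_1", "doc_2"]             # sparse (BM25) order

fused = {}
for ranking in (vector_ranking, bm25_ranking):
    for rank, doc_id in enumerate(ranking):
        fused[doc_id] = fused.get(doc_id, 0.0) + 1 / (k + rank)

# doc_2 wins (1/60 + 1/61 ≈ 0.0331) over doc_1 (1/60 + 1/62 ≈ 0.0328).
print(sorted(fused.items(), key=lambda item: item[1], reverse=True))

Note that in the shipped implementation a BM25 hit is mapped back to its ID through the vector result set (doc_text_to_id), so a keyword match that the vector search did not also return contributes nothing to the fusion.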
jarvis/jarvis_tools/search_web.py
@@ -1,10 +1,20 @@
  # -*- coding: utf-8 -*-
+ """A tool for searching the web."""
  from typing import Any, Dict

+ import httpx
+ from bs4 import BeautifulSoup
+ from ddgs import DDGS
+
+ from jarvis.jarvis_agent import Agent
  from jarvis.jarvis_platform.registry import PlatformRegistry
+ from jarvis.jarvis_utils.http import get as http_get
+ from jarvis.jarvis_utils.output import OutputType, PrettyOutput


  class SearchWebTool:
+     """A class to handle web searches."""
+
      name = "search_web"
      description = "搜索互联网上的信息"
      parameters = {
@@ -12,18 +22,124 @@ class SearchWebTool:
          "properties": {"query": {"type": "string", "description": "具体的问题"}},
      }

-     def execute(self, args: Dict[str, Any]) -> Dict[str, Any]:  # type: ignore
+     def _search_with_ddgs(self, query: str, agent: Agent) -> Dict[str, Any]:
+         # pylint: disable=too-many-locals, broad-except
+         """Performs a web search, scrapes content, and summarizes the results."""
+         try:
+             PrettyOutput.print("▶️ 使用 DuckDuckGo 开始网页搜索...", OutputType.INFO)
+             results = list(DDGS().text(query, max_results=5))
+
+             if not results:
+                 return {
+                     "stdout": "未找到搜索结果。",
+                     "stderr": "未找到搜索结果。",
+                     "success": False,
+                 }
+
+             urls = [r["href"] for r in results]
+             full_content = ""
+             visited_urls = []
+
+             for url in urls:
+                 try:
+                     PrettyOutput.print(f"📄 正在抓取内容: {url}", OutputType.INFO)
+                     response = http_get(url, timeout=10.0, follow_redirects=True)
+                     soup = BeautifulSoup(response.text, "lxml")
+                     body = soup.find("body")
+                     if body:
+                         full_content += body.get_text(" ", strip=True) + "\n\n"
+                         visited_urls.append(url)
+                 except httpx.HTTPStatusError as e:
+                     PrettyOutput.print(
+                         f"⚠️ HTTP错误 {e.response.status_code} 访问 {url}",
+                         OutputType.WARNING,
+                     )
+                 except httpx.RequestError as e:
+                     PrettyOutput.print(f"⚠️ 请求错误: {e}", OutputType.WARNING)
+
+             if not full_content.strip():
+                 return {
+                     "stdout": "无法从任何URL抓取有效内容。",
+                     "stderr": "抓取内容失败。",
+                     "success": False,
+                 }
+
+             url_list_str = "\n".join(f"  - {u}" for u in visited_urls)
+             PrettyOutput.print(
+                 f"🔍 已成功访问并处理以下URL:\n{url_list_str}", OutputType.INFO
+             )
+
+             PrettyOutput.print("🧠 正在总结内容...", OutputType.INFO)
+             summary_prompt = f"请为查询“{query}”总结以下内容:\n\n{full_content}"
+
+             if not agent.model:
+                 return {
+                     "stdout": "",
+                     "stderr": "用于总结的Agent模型未找到。",
+                     "success": False,
+                 }
+
+             platform_name = agent.model.platform_name()
+             model_name = agent.model.name()
+
+             model = PlatformRegistry().create_platform(platform_name)
+             if not model:
+                 return {
+                     "stdout": "",
+                     "stderr": "无法创建用于总结的模型。",
+                     "success": False,
+                 }
+
+             model.set_model_name(model_name)
+             model.set_suppress_output(False)
+             summary = model.chat_until_success(summary_prompt)
+
+             return {"stdout": summary, "stderr": "", "success": True}
+
+         except Exception as e:
+             PrettyOutput.print(f"❌ 网页搜索过程中发生错误: {e}", OutputType.ERROR)
+             return {
+                 "stdout": "",
+                 "stderr": f"网页搜索过程中发生错误: {e}",
+                 "success": False,
+             }
+
+     def execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Executes the web search.
+
+         If the agent's model supports a native web search, it uses it.
+         Otherwise, it falls back to using DuckDuckGo Search and scraping pages.
+         """
          query = args.get("query")
-         model = PlatformRegistry().get_normal_platform()
-         model.set_web(True)
-         model.set_suppress_output(False)  # type: ignore
-         return {
-             "stdout": model.chat_until_success(query),  # type: ignore
-             "stderr": "",
-             "success": True,
-         }
+         agent = args.get("agent")
+
+         if not query:
+             return {"stdout": "", "stderr": "缺少查询参数。", "success": False}
+
+         if not isinstance(agent, Agent) or not agent.model:
+             return {
+                 "stdout": "",
+                 "stderr": "Agent或Agent模型未找到。",
+                 "success": False,
+             }
+
+         if agent.model.support_web():
+             model = PlatformRegistry().create_platform(agent.model.platform_name())
+             if not model:
+                 return {"stdout": "", "stderr": "无法创建模型。", "success": False}
+             model.set_model_name(agent.model.name())
+             model.set_web(True)
+             model.set_suppress_output(False)
+             return {
+                 "stdout": model.chat_until_success(query),
+                 "stderr": "",
+                 "success": True,
+             }
+
+         return self._search_with_ddgs(query, agent)

      @staticmethod
      def check() -> bool:
-         """检查当前平台是否支持web功能"""
-         return PlatformRegistry().get_normal_platform().support_web()
+         """Check if the tool is available."""
+         return True
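
A sketch of how a caller is now expected to invoke the tool; execute() requires the calling Agent in args (the wrapper function below is illustrative, not part of the package):

from typing import Any, Dict

from jarvis.jarvis_agent import Agent
from jarvis.jarvis_tools.search_web import SearchWebTool

def web_search(agent: Agent, query: str) -> Dict[str, Any]:
    # If the agent's model reports support_web(), the platform's native
    # search is used; otherwise the DDGS + BeautifulSoup fallback runs.
    return SearchWebTool().execute({"query": query, "agent": agent})

Because check() now always returns True, the tool registers even on platforms without native web support, relying on the DuckDuckGo fallback instead.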
jarvis/jarvis_utils/config.py
@@ -3,6 +3,7 @@ import os
  from functools import lru_cache
  from typing import Any, Dict, List

+ import torch
  import yaml  # type: ignore

  from jarvis.jarvis_utils.builtin_replace_map import BUILTIN_REPLACE_MAP
@@ -248,3 +249,73 @@ def get_mcp_config() -> List[Dict[str, Any]]:
          List[Dict[str, Any]]: MCP配置项列表,如果未配置则返回空列表
      """
      return GLOBAL_CONFIG_DATA.get("JARVIS_MCP", [])
+
+
+ # ==============================================================================
+ # RAG Framework Configuration
+ # ==============================================================================
+
+ EMBEDDING_MODELS = {
+     "performance": {
+         "model_name": "BAAI/bge-base-zh-v1.5",
+         "model_kwargs": {"device": "cuda" if torch.cuda.is_available() else "cpu"},
+         "encode_kwargs": {"normalize_embeddings": True},
+         "show_progress": True,
+     },
+     "accuracy": {
+         "model_name": "BAAI/bge-large-zh-v1.5",
+         "model_kwargs": {"device": "cuda" if torch.cuda.is_available() else "cpu"},
+         "encode_kwargs": {"normalize_embeddings": True},
+         "show_progress": True,
+     },
+ }
+
+
+ def get_rag_config() -> Dict[str, Any]:
+     """
+     获取RAG框架的配置。
+
+     返回:
+         Dict[str, Any]: RAG配置字典
+     """
+     return GLOBAL_CONFIG_DATA.get("JARVIS_RAG", {})
+
+
+ def get_rag_embedding_models() -> Dict[str, Any]:
+     """
+     获取RAG嵌入模型的定义。
+
+     返回:
+         Dict[str, Any]: 嵌入模型配置字典
+     """
+     return EMBEDDING_MODELS
+
+
+ def get_rag_embedding_mode() -> str:
+     """
+     获取RAG嵌入模型的模式。
+
+     返回:
+         str: 'performance' 或 'accuracy'
+     """
+     return get_rag_config().get("embedding_mode", "performance")
+
+
+ def get_rag_embedding_cache_path() -> str:
+     """
+     获取RAG嵌入缓存的路径。
+
+     返回:
+         str: 缓存路径
+     """
+     return get_rag_config().get("embedding_cache_path", ".jarvis/rag/embeddings")
+
+
+ def get_rag_vector_db_path() -> str:
+     """
+     获取RAG向量数据库的路径。
+
+     返回:
+         str: 数据库路径
+     """
+     return get_rag_config().get("vector_db_path", ".jarvis/rag/vectordb")
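
The new getters read an optional JARVIS_RAG mapping from the global config. A sketch of the recognized keys and their defaults (the YAML fragment in the comment is illustrative, not shipped with the package):

# Illustrative YAML in the Jarvis config file:
#
#   JARVIS_RAG:
#     embedding_mode: accuracy               # default: "performance"
#     embedding_cache_path: .jarvis/rag/embeddings
#     vector_db_path: .jarvis/rag/vectordb

from jarvis.jarvis_utils.config import (
    get_rag_embedding_mode,
    get_rag_embedding_models,
)

mode = get_rag_embedding_mode()  # "performance" when unset
model_name = get_rag_embedding_models()[mode]["model_name"]
print(mode, model_name)  # e.g. performance BAAI/bge-base-zh-v1.5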