jarvis_ai_assistant-0.1.218-py3-none-any.whl → jarvis_ai_assistant-0.1.220-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +37 -92
- jarvis/jarvis_agent/shell_input_handler.py +1 -1
- jarvis/jarvis_code_agent/code_agent.py +5 -3
- jarvis/jarvis_data/config_schema.json +30 -0
- jarvis/jarvis_git_squash/main.py +2 -1
- jarvis/jarvis_platform/human.py +2 -7
- jarvis/jarvis_platform/yuanbao.py +3 -1
- jarvis/jarvis_rag/__init__.py +11 -0
- jarvis/jarvis_rag/cache.py +87 -0
- jarvis/jarvis_rag/cli.py +297 -0
- jarvis/jarvis_rag/embedding_manager.py +109 -0
- jarvis/jarvis_rag/llm_interface.py +130 -0
- jarvis/jarvis_rag/query_rewriter.py +63 -0
- jarvis/jarvis_rag/rag_pipeline.py +177 -0
- jarvis/jarvis_rag/reranker.py +56 -0
- jarvis/jarvis_rag/retriever.py +201 -0
- jarvis/jarvis_tools/search_web.py +127 -11
- jarvis/jarvis_utils/config.py +71 -0
- jarvis/jarvis_utils/git_utils.py +27 -18
- jarvis/jarvis_utils/input.py +21 -10
- jarvis/jarvis_utils/utils.py +43 -20
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/METADATA +87 -5
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/RECORD +28 -19
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/entry_points.txt +1 -0
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.218.dist-info → jarvis_ai_assistant-0.1.220.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/rag_pipeline.py
ADDED
@@ -0,0 +1,177 @@
+import os
+from typing import List, Literal, Optional, cast
+
+from langchain.docstore.document import Document
+
+from .embedding_manager import EmbeddingManager
+from .llm_interface import JarvisPlatform_LLM, LLMInterface, ToolAgent_LLM
+from .query_rewriter import QueryRewriter
+from .reranker import Reranker
+from .retriever import ChromaRetriever
+from jarvis.jarvis_utils.config import (
+    get_rag_embedding_mode,
+    get_rag_vector_db_path,
+    get_rag_embedding_cache_path,
+    get_rag_embedding_models,
+)
+
+
+class JarvisRAGPipeline:
+    """
+    The main orchestrator for the RAG pipeline.
+
+    This class integrates the embedding manager, retriever, and LLM to provide
+    a complete pipeline for adding documents and querying them.
+    """
+
+    def __init__(
+        self,
+        llm: Optional[LLMInterface] = None,
+        embedding_mode: Optional[Literal["performance", "accuracy"]] = None,
+        db_path: Optional[str] = None,
+        collection_name: str = "jarvis_rag_collection",
+    ):
+        """
+        Initializes the RAG pipeline.
+
+        Args:
+            llm: An instance of a class implementing LLMInterface.
+                 If None, defaults to the ToolAgent_LLM.
+            embedding_mode: The mode for the local embedding model. If None, uses config value.
+            db_path: Path to the persistent vector database. If None, uses config value.
+            collection_name: Name of the collection in the vector database.
+        """
+        # Determine the embedding model to isolate data paths
+        _embedding_mode = embedding_mode or get_rag_embedding_mode()
+        embedding_models = get_rag_embedding_models()
+        model_name = embedding_models[_embedding_mode]["model_name"]
+        sanitized_model_name = model_name.replace("/", "_").replace("\\", "_")
+
+        # If a specific db_path is given, use it. Otherwise, create a model-specific path.
+        _final_db_path = (
+            str(db_path)
+            if db_path
+            else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
+        )
+        # Always create a model-specific cache path.
+        _final_cache_path = os.path.join(
+            get_rag_embedding_cache_path(), sanitized_model_name
+        )
+
+        self.embedding_manager = EmbeddingManager(
+            mode=cast(Literal["performance", "accuracy"], _embedding_mode),
+            cache_dir=_final_cache_path,
+        )
+        self.retriever = ChromaRetriever(
+            embedding_manager=self.embedding_manager,
+            db_path=_final_db_path,
+            collection_name=collection_name,
+        )
+        # Default to the ToolAgent_LLM unless a specific LLM is provided
+        self.llm = llm if llm is not None else ToolAgent_LLM()
+        self.reranker = Reranker()
+        # Use a standard LLM for the query rewriting task, not the agent
+        self.query_rewriter = QueryRewriter(JarvisPlatform_LLM())
+
+        print("✅ JarvisRAGPipeline 初始化成功。")
+
+    def add_documents(self, documents: List[Document]):
+        """
+        Adds documents to the vector knowledge base.
+
+        Args:
+            documents: A list of LangChain Document objects to add.
+        """
+        self.retriever.add_documents(documents)
+
+    def _create_prompt(
+        self, query: str, context_docs: List[Document], source_files: List[str]
+    ) -> str:
+        """Creates the final prompt for the LLM or Agent."""
+        context = "\n\n".join([doc.page_content for doc in context_docs])
+        sources_text = "\n".join([f"- {source}" for source in source_files])
+
+        prompt_template = f"""
+你是一个专家助手。请根据用户的问题,结合下面提供的参考信息来回答。
+
+**重要**: 提供的上下文和文件列表**仅供参考**,可能不完整或已过时。在回答前,你应该**优先使用工具(如 read_code)来获取最新、最准确的信息**。
+
+参考文件列表:
+---
+{sources_text}
+---
+
+参考上下文:
+---
+{context}
+---
+
+问题: {query}
+
+回答:
+"""
+        return prompt_template.strip()
+
+    def query(self, query_text: str, n_results: int = 5) -> str:
+        """
+        Performs a query against the knowledge base using a multi-query
+        retrieval and reranking pipeline.
+
+        Args:
+            query_text: The user's original question.
+            n_results: The number of final relevant chunks to retrieve.
+
+        Returns:
+            The answer generated by the LLM.
+        """
+        # 1. Rewrite the original query into multiple queries
+        rewritten_queries = self.query_rewriter.rewrite(query_text)
+
+        # 2. Retrieve initial candidates for each rewritten query
+        all_candidate_docs = []
+        for q in rewritten_queries:
+            print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
+            candidates = self.retriever.retrieve(q, n_results=n_results * 2)
+            all_candidate_docs.extend(candidates)
+
+        # De-duplicate the candidate documents
+        unique_docs_dict = {doc.page_content: doc for doc in all_candidate_docs}
+        unique_candidate_docs = list(unique_docs_dict.values())
+
+        if not unique_candidate_docs:
+            return "我在提供的文档中找不到任何相关信息来回答您的问题。"
+
+        # 3. Rerank the unified candidate pool against the *original* query
+        print(
+            f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)..."
+        )
+        retrieved_docs = self.reranker.rerank(
+            query_text, unique_candidate_docs, top_n=n_results
+        )
+
+        if not retrieved_docs:
+            return "我在提供的文档中找不到任何相关信息来回答您的问题。"
+
+        # Print the sources of the final retrieved documents
+        sources = sorted(
+            list(
+                {
+                    doc.metadata["source"]
+                    for doc in retrieved_docs
+                    if "source" in doc.metadata
+                }
+            )
+        )
+        if sources:
+            print(f"📚 根据以下文档回答:")
+            for source in sources:
+                print(f" - {source}")
+
+        # 4. Create the final prompt and generate the answer
+        # We use the original query_text for the final prompt to the LLM
+        prompt = self._create_prompt(query_text, retrieved_docs, sources)
+
+        print("🤖 正在从LLM生成答案...")
+        answer = self.llm.generate(prompt)
+
+        return answer
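For orientation, a minimal usage sketch of the new pipeline. The module path comes from the file list above; the document content and metadata here are hypothetical placeholders, and a first run will download the configured BGE models:

from langchain.docstore.document import Document

from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline

# Build the pipeline with the default ToolAgent_LLM and config-driven paths.
pipeline = JarvisRAGPipeline(embedding_mode="performance")

# Index a (hypothetical) document, then query against it.
pipeline.add_documents(
    [Document(page_content="Jarvis is an AI assistant.", metadata={"source": "README.md"})]
)
print(pipeline.query("What is Jarvis?", n_results=5))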
jarvis/jarvis_rag/reranker.py
ADDED
@@ -0,0 +1,56 @@
+from typing import List
+
+from langchain.docstore.document import Document
+from sentence_transformers.cross_encoder import ( # type: ignore
+    CrossEncoder,
+)
+
+
+class Reranker:
+    """
+    A reranker class that uses a Cross-Encoder model to re-score and sort
+    documents based on their relevance to a given query.
+    """
+
+    def __init__(self, model_name: str = "BAAI/bge-reranker-base"):
+        """
+        Initializes the Reranker.
+
+        Args:
+            model_name (str): The name of the Cross-Encoder model to use.
+        """
+        print(f"🔍 正在初始化重排模型: {model_name}...")
+        self.model = CrossEncoder(model_name)
+        print("✅ 重排模型初始化成功。")
+
+    def rerank(
+        self, query: str, documents: List[Document], top_n: int = 5
+    ) -> List[Document]:
+        """
+        Reranks a list of documents based on their relevance to the query.
+
+        Args:
+            query (str): The user's query.
+            documents (List[Document]): The list of documents retrieved from the initial search.
+            top_n (int): The number of top documents to return after reranking.
+
+        Returns:
+            List[Document]: A sorted list of the most relevant documents.
+        """
+        if not documents:
+            return []
+
+        # Create pairs of [query, document_content] for scoring
+        pairs = [[query, doc.page_content] for doc in documents]
+
+        # Get scores from the Cross-Encoder model
+        scores = self.model.predict(pairs)
+
+        # Combine documents with their scores and sort
+        doc_with_scores = list(zip(documents, scores))
+        doc_with_scores.sort(key=lambda x: x[1], reverse=True)
+
+        # Return the top N documents
+        reranked_docs = [doc for doc, score in doc_with_scores[:top_n]]
+
+        return reranked_docs
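Unlike the bi-encoder embeddings used for retrieval, the Cross-Encoder scores each (query, document) pair jointly, which is slower but more precise; that is why it only runs on the small candidate pool. A standalone sketch (the first call downloads BAAI/bge-reranker-base from Hugging Face; the sample texts are made up):

from langchain.docstore.document import Document

from jarvis.jarvis_rag.reranker import Reranker

docs = [
    Document(page_content="BM25 is a sparse keyword ranking function."),
    Document(page_content="Paris is the capital of France."),
]
reranker = Reranker()
# Expect the geography sentence to score highest for this query.
best = reranker.rerank("What is the capital of France?", docs, top_n=1)
print(best[0].page_content)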
jarvis/jarvis_rag/retriever.py
ADDED
@@ -0,0 +1,201 @@
+import os
+import pickle
+from typing import Any, Dict, List, cast
+
+import chromadb
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from rank_bm25 import BM25Okapi # type: ignore
+
+from .embedding_manager import EmbeddingManager
+
+
+class ChromaRetriever:
+    """
+    A retriever class that combines dense vector search (ChromaDB) and
+    sparse keyword search (BM25) for hybrid retrieval.
+    """
+
+    def __init__(
+        self,
+        embedding_manager: EmbeddingManager,
+        db_path: str,
+        collection_name: str = "jarvis_rag_collection",
+    ):
+        """
+        Initializes the ChromaRetriever.
+
+        Args:
+            embedding_manager: An instance of EmbeddingManager.
+            db_path: The file path for ChromaDB's persistent storage.
+            collection_name: The name of the collection within ChromaDB.
+        """
+        self.embedding_manager = embedding_manager
+        self.db_path = db_path
+        self.collection_name = collection_name
+
+        # Initialize ChromaDB client
+        self.client = chromadb.PersistentClient(path=self.db_path)
+        self.collection = self.client.get_or_create_collection(
+            name=self.collection_name
+        )
+        print(
+            f"✅ ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。"
+        )
+
+        # BM25 Index setup
+        self.bm25_index_path = os.path.join(self.db_path, f"{collection_name}_bm25.pkl")
+        self._load_or_initialize_bm25()
+
+    def _load_or_initialize_bm25(self):
+        """Loads the BM25 index from disk or initializes a new one."""
+        if os.path.exists(self.bm25_index_path):
+            print("🔍 正在加载现有的 BM25 索引...")
+            with open(self.bm25_index_path, "rb") as f:
+                data = pickle.load(f)
+                self.bm25_corpus = data["corpus"]
+                self.bm25_index = BM25Okapi(self.bm25_corpus)
+            print("✅ BM25 索引加载成功。")
+        else:
+            print("⚠️ 未找到 BM25 索引,将初始化一个新的。")
+            self.bm25_corpus = []
+            self.bm25_index = None
+
+    def _save_bm25_index(self):
+        """Saves the BM25 index to disk."""
+        if self.bm25_index:
+            print("💾 正在保存 BM25 索引...")
+            with open(self.bm25_index_path, "wb") as f:
+                pickle.dump({"corpus": self.bm25_corpus, "index": self.bm25_index}, f)
+            print("✅ BM25 索引保存成功。")
+
+    def add_documents(
+        self, documents: List[Document], chunk_size=1000, chunk_overlap=100
+    ):
+        """
+        Splits, embeds, and adds documents to both ChromaDB and the BM25 index.
+        """
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=chunk_size, chunk_overlap=chunk_overlap
+        )
+        chunks = text_splitter.split_documents(documents)
+
+        print(f"📄 已将 {len(documents)} 个文档拆分为 {len(chunks)} 个块。")
+
+        if not chunks:
+            return
+
+        # Extract content, metadata, and generate IDs
+        chunk_texts = [chunk.page_content for chunk in chunks]
+        metadatas = [chunk.metadata for chunk in chunks]
+        start_id = self.collection.count()
+        ids = [f"doc_{i}" for i in range(start_id, start_id + len(chunks))]
+
+        # Add to ChromaDB
+        embeddings = self.embedding_manager.embed_documents(chunk_texts)
+        self.collection.add(
+            ids=ids,
+            embeddings=cast(Any, embeddings),
+            documents=chunk_texts,
+            metadatas=cast(Any, metadatas),
+        )
+        print(f"✅ 成功将 {len(chunks)} 个块添加到 ChromaDB 集合中。")
+
+        # Update and save BM25 index
+        tokenized_chunks = [doc.split() for doc in chunk_texts]
+        self.bm25_corpus.extend(tokenized_chunks)
+        self.bm25_index = BM25Okapi(self.bm25_corpus)
+        self._save_bm25_index()
+
+    def retrieve(self, query: str, n_results: int = 5) -> List[Document]:
+        """
+        Performs hybrid retrieval using both vector search and BM25,
+        then fuses the results using Reciprocal Rank Fusion (RRF).
+        """
+        # 1. Vector Search (ChromaDB)
+        query_embedding = self.embedding_manager.embed_query(query)
+        vector_results = self.collection.query(
+            query_embeddings=cast(Any, [query_embedding]),
+            n_results=n_results * 2, # Retrieve more results for fusion
+        )
+
+        # 2. Keyword Search (BM25)
+        bm25_docs = []
+        if self.bm25_index:
+            tokenized_query = query.split()
+            doc_scores = self.bm25_index.get_scores(tokenized_query)
+
+            # Get all documents from Chroma to match with BM25 scores
+            all_docs_in_collection = self.collection.get()
+            all_documents = all_docs_in_collection.get("documents")
+            all_metadatas = all_docs_in_collection.get("metadatas")
+
+            bm25_results_with_docs = []
+            if all_documents and all_metadatas:
+                # Create a mapping from index to document
+                bm25_results_with_docs = [
+                    (
+                        all_documents[i],
+                        all_metadatas[i],
+                        score,
+                    )
+                    for i, score in enumerate(doc_scores)
+                    if score > 0
+                ]
+
+            # Sort by score and take top results
+            bm25_results_with_docs.sort(key=lambda x: x[2], reverse=True)
+
+            for doc_text, metadata, _ in bm25_results_with_docs[: n_results * 2]:
+                bm25_docs.append(Document(page_content=doc_text, metadata=metadata))
+
+        # 3. Reciprocal Rank Fusion (RRF)
+        fused_scores: Dict[str, float] = {}
+        k = 60 # RRF ranking constant
+
+        # Process vector results
+        if vector_results and vector_results["ids"] and vector_results["documents"]:
+            vec_ids = vector_results["ids"][0]
+            vec_texts = vector_results["documents"][0]
+
+            for rank, doc_id in enumerate(vec_ids):
+                fused_scores[doc_id] = fused_scores.get(doc_id, 0) + 1 / (k + rank)
+
+            # Create a map from document text to its ID for BM25 fusion
+            doc_text_to_id = {text: doc_id for text, doc_id in zip(vec_texts, vec_ids)}
+
+            for rank, doc in enumerate(bm25_docs):
+                bm25_doc_id = doc_text_to_id.get(doc.page_content)
+                if bm25_doc_id:
+                    fused_scores[bm25_doc_id] = fused_scores.get(bm25_doc_id, 0) + 1 / (
+                        k + rank
+                    )
+
+        # Sort fused results
+        sorted_fused_results = sorted(
+            fused_scores.items(), key=lambda x: x[1], reverse=True
+        )
+
+        # Get the final documents from ChromaDB based on fused ranking
+        final_doc_ids = [item[0] for item in sorted_fused_results[:n_results]]
+
+        if not final_doc_ids:
+            return []
+
+        final_docs_data = self.collection.get(ids=final_doc_ids)
+
+        retrieved_docs = []
+        if final_docs_data:
+            final_documents = final_docs_data.get("documents")
+            final_metadatas = final_docs_data.get("metadatas")
+
+            if final_documents and final_metadatas:
+                for doc_text, metadata in zip(final_documents, final_metadatas):
+                    if doc_text is not None and metadata is not None:
+                        retrieved_docs.append(
+                            Document(
+                                page_content=cast(str, doc_text), metadata=metadata
+                            )
+                        )
+
+        return retrieved_docs
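The fusion step is plain Reciprocal Rank Fusion: each ranked list contributes 1/(k + rank) per document and the sums are sorted, with k = 60 damping the influence of rank differences. A toy calculation with hypothetical IDs, independent of the classes above:

k = 60
vector_ranking = ["doc_2", "doc_0", "doc_1"]  # hypothetical chunk IDs
bm25_ranking = ["doc_2", "doc_1"]

fused = {}
for ranking in (vector_ranking, bm25_ranking):
    for rank, doc_id in enumerate(ranking):  # rank is 0-based, as in retrieve()
        fused[doc_id] = fused.get(doc_id, 0.0) + 1 / (k + rank)

# doc_2: 1/60 + 1/60 ≈ 0.0333; doc_1: 1/62 + 1/61 ≈ 0.0325; doc_0: 1/61 ≈ 0.0164
print(sorted(fused.items(), key=lambda x: x[1], reverse=True))

One observable consequence of this implementation: BM25 hits only contribute when the same chunk text also appears in the vector results, because IDs are recovered through doc_text_to_id, which is built from the vector results alone; BM25-only hits are dropped.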
jarvis/jarvis_tools/search_web.py
CHANGED
@@ -1,10 +1,20 @@
 # -*- coding: utf-8 -*-
+"""A tool for searching the web."""
 from typing import Any, Dict
 
+import httpx
+from bs4 import BeautifulSoup
+from ddgs import DDGS
+
+from jarvis.jarvis_agent import Agent
 from jarvis.jarvis_platform.registry import PlatformRegistry
+from jarvis.jarvis_utils.http import get as http_get
+from jarvis.jarvis_utils.output import OutputType, PrettyOutput
 
 
 class SearchWebTool:
+    """A class to handle web searches."""
+
     name = "search_web"
     description = "搜索互联网上的信息"
     parameters = {
@@ -12,18 +22,124 @@ class SearchWebTool:
         "properties": {"query": {"type": "string", "description": "具体的问题"}},
     }
 
-    def
+    def _search_with_ddgs(self, query: str, agent: Agent) -> Dict[str, Any]:
+        # pylint: disable=too-many-locals, broad-except
+        """Performs a web search, scrapes content, and summarizes the results."""
+        try:
+            PrettyOutput.print("▶️ 使用 DuckDuckGo 开始网页搜索...", OutputType.INFO)
+            results = list(DDGS().text(query, max_results=5))
+
+            if not results:
+                return {
+                    "stdout": "未找到搜索结果。",
+                    "stderr": "未找到搜索结果。",
+                    "success": False,
+                }
+
+            urls = [r["href"] for r in results]
+            full_content = ""
+            visited_urls = []
+
+            for url in urls:
+                try:
+                    PrettyOutput.print(f"📄 正在抓取内容: {url}", OutputType.INFO)
+                    response = http_get(url, timeout=10.0, follow_redirects=True)
+                    soup = BeautifulSoup(response.text, "lxml")
+                    body = soup.find("body")
+                    if body:
+                        full_content += body.get_text(" ", strip=True) + "\n\n"
+                        visited_urls.append(url)
+                except httpx.HTTPStatusError as e:
+                    PrettyOutput.print(
+                        f"⚠️ HTTP错误 {e.response.status_code} 访问 {url}",
+                        OutputType.WARNING,
+                    )
+                except httpx.RequestError as e:
+                    PrettyOutput.print(f"⚠️ 请求错误: {e}", OutputType.WARNING)
+
+            if not full_content.strip():
+                return {
+                    "stdout": "无法从任何URL抓取有效内容。",
+                    "stderr": "抓取内容失败。",
+                    "success": False,
+                }
+
+            url_list_str = "\n".join(f" - {u}" for u in visited_urls)
+            PrettyOutput.print(
+                f"🔍 已成功访问并处理以下URL:\n{url_list_str}", OutputType.INFO
+            )
+
+            PrettyOutput.print("🧠 正在总结内容...", OutputType.INFO)
+            summary_prompt = f"请为查询“{query}”总结以下内容:\n\n{full_content}"
+
+            if not agent.model:
+                return {
+                    "stdout": "",
+                    "stderr": "用于总结的Agent模型未找到。",
+                    "success": False,
+                }
+
+            platform_name = agent.model.platform_name()
+            model_name = agent.model.name()
+
+            model = PlatformRegistry().create_platform(platform_name)
+            if not model:
+                return {
+                    "stdout": "",
+                    "stderr": "无法创建用于总结的模型。",
+                    "success": False,
+                }
+
+            model.set_model_name(model_name)
+            model.set_suppress_output(False)
+            summary = model.chat_until_success(summary_prompt)
+
+            return {"stdout": summary, "stderr": "", "success": True}
+
+        except Exception as e:
+            PrettyOutput.print(f"❌ 网页搜索过程中发生错误: {e}", OutputType.ERROR)
+            return {
+                "stdout": "",
+                "stderr": f"网页搜索过程中发生错误: {e}",
+                "success": False,
+            }
+
+    def execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Executes the web search.
+
+        If the agent's model supports a native web search, it uses it.
+        Otherwise, it falls back to using DuckDuckGo Search and scraping pages.
+        """
         query = args.get("query")
-
-
-
-
-
-
-
-
+        agent = args.get("agent")
+
+        if not query:
+            return {"stdout": "", "stderr": "缺少查询参数。", "success": False}
+
+        if not isinstance(agent, Agent) or not agent.model:
+            return {
+                "stdout": "",
+                "stderr": "Agent或Agent模型未找到。",
+                "success": False,
+            }
+
+        if agent.model.support_web():
+            model = PlatformRegistry().create_platform(agent.model.platform_name())
+            if not model:
+                return {"stdout": "", "stderr": "无法创建模型。", "success": False}
+            model.set_model_name(agent.model.name())
+            model.set_web(True)
+            model.set_suppress_output(False)
+            return {
+                "stdout": model.chat_until_success(query),
+                "stderr": "",
+                "success": True,
+            }
+
+        return self._search_with_ddgs(query, agent)
 
     @staticmethod
     def check() -> bool:
-        """
-        return
+        """Check if the tool is available."""
+        return True
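The fallback path can be reproduced outside the tool with the same libraries it imports (ddgs, httpx, BeautifulSoup). A standalone sketch without the Agent plumbing or summarization, where fetch_snippets is a hypothetical helper name; it needs lxml installed for the parser:

import httpx
from bs4 import BeautifulSoup
from ddgs import DDGS

def fetch_snippets(query: str, max_results: int = 5) -> str:
    """Search DuckDuckGo, scrape each result's <body>, and concatenate the text."""
    text = ""
    for r in DDGS().text(query, max_results=max_results):
        try:
            resp = httpx.get(r["href"], timeout=10.0, follow_redirects=True)
            body = BeautifulSoup(resp.text, "lxml").find("body")
            if body:
                text += body.get_text(" ", strip=True) + "\n\n"
        except httpx.HTTPError:
            continue  # skip pages that fail to load
    return text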
jarvis/jarvis_utils/config.py
CHANGED
@@ -3,6 +3,7 @@ import os
 from functools import lru_cache
 from typing import Any, Dict, List
 
+import torch
 import yaml # type: ignore
 
 from jarvis.jarvis_utils.builtin_replace_map import BUILTIN_REPLACE_MAP
@@ -248,3 +249,73 @@ def get_mcp_config() -> List[Dict[str, Any]]:
         List[Dict[str, Any]]: MCP配置项列表,如果未配置则返回空列表
     """
     return GLOBAL_CONFIG_DATA.get("JARVIS_MCP", [])
+
+
+# ==============================================================================
+# RAG Framework Configuration
+# ==============================================================================
+
+EMBEDDING_MODELS = {
+    "performance": {
+        "model_name": "BAAI/bge-base-zh-v1.5",
+        "model_kwargs": {"device": "cuda" if torch.cuda.is_available() else "cpu"},
+        "encode_kwargs": {"normalize_embeddings": True},
+        "show_progress": True,
+    },
+    "accuracy": {
+        "model_name": "BAAI/bge-large-zh-v1.5",
+        "model_kwargs": {"device": "cuda" if torch.cuda.is_available() else "cpu"},
+        "encode_kwargs": {"normalize_embeddings": True},
+        "show_progress": True,
+    },
+}
+
+
+def get_rag_config() -> Dict[str, Any]:
+    """
+    获取RAG框架的配置。
+
+    返回:
+        Dict[str, Any]: RAG配置字典
+    """
+    return GLOBAL_CONFIG_DATA.get("JARVIS_RAG", {})
+
+
+def get_rag_embedding_models() -> Dict[str, Any]:
+    """
+    获取RAG嵌入模型的定义。
+
+    返回:
+        Dict[str, Any]: 嵌入模型配置字典
+    """
+    return EMBEDDING_MODELS
+
+
+def get_rag_embedding_mode() -> str:
+    """
+    获取RAG嵌入模型的模式。
+
+    返回:
+        str: 'performance' 或 'accuracy'
+    """
+    return get_rag_config().get("embedding_mode", "performance")
+
+
+def get_rag_embedding_cache_path() -> str:
+    """
+    获取RAG嵌入缓存的路径。
+
+    返回:
+        str: 缓存路径
+    """
+    return get_rag_config().get("embedding_cache_path", ".jarvis/rag/embeddings")
+
+
+def get_rag_vector_db_path() -> str:
+    """
+    获取RAG向量数据库的路径。
+
+    返回:
+        str: 数据库路径
+    """
+    return get_rag_config().get("vector_db_path", ".jarvis/rag/vectordb")
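Taken together, these getters imply the RAG behavior is driven by a JARVIS_RAG mapping in the same YAML configuration the other JARVIS_* keys come from. A hypothetical fragment using the key names visible above (the defaults are the ones hard-coded in the getters; the file location is wherever Jarvis already loads GLOBAL_CONFIG_DATA from):

JARVIS_RAG:
  embedding_mode: accuracy                      # "performance" (default) or "accuracy"
  embedding_cache_path: .jarvis/rag/embeddings  # default from get_rag_embedding_cache_path
  vector_db_path: .jarvis/rag/vectordb          # default from get_rag_vector_db_path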