jarvis-ai-assistant 0.1.219__py3-none-any.whl → 0.1.220__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
@@ -0,0 +1,108 @@
+ from typing import List, Literal, cast
+ from langchain_huggingface import HuggingFaceEmbeddings
+
+ from jarvis.jarvis_utils.config import (
+     get_rag_embedding_models,
+ )
+ from .cache import EmbeddingCache
+
+
+ class EmbeddingManager:
+     """
+     Manages the loading and usage of local embedding models with caching.
+
+     This class handles the selection of embedding models based on a specified
+     mode ('performance' or 'accuracy'), loads the model from Hugging Face,
+     and uses a disk-based cache to avoid re-computing embeddings for the
+     same text.
+     """
+
+     def __init__(
+         self,
+         mode: Literal["performance", "accuracy"],
+         cache_dir: str,
+     ):
+         """
+         Initializes the EmbeddingManager.
+
+         Args:
+             mode: The desired mode, either 'performance' or 'accuracy'.
+             cache_dir: The directory to store the embedding cache.
+         """
+         self.mode = mode
+         self.embedding_models = get_rag_embedding_models()
+         if mode not in self.embedding_models:
+             raise ValueError(
+                 f"Invalid mode '{mode}'. Must be one of {list(self.embedding_models.keys())}"
+             )
+
+         self.model_config = self.embedding_models[self.mode]
+         self.model_name = self.model_config["model_name"]
+
+         print(f"🚀 Initializing embedding manager, mode: '{self.mode}', model: '{self.model_name}'...")
+
+         # The salt for the cache is the model name to prevent collisions
+         self.cache = EmbeddingCache(cache_dir=cache_dir, salt=str(self.model_name))
+         self.model = self._load_model()
+
+     def _load_model(self) -> HuggingFaceEmbeddings:
+         """Loads the Hugging Face embedding model based on the configuration."""
+         try:
+             return HuggingFaceEmbeddings(
+                 model_name=self.model_name,
+                 model_kwargs=self.model_config.get("model_kwargs"),
+                 encode_kwargs=self.model_config.get("encode_kwargs"),
+                 show_progress=self.model_config.get("show_progress", False),
+             )
+         except Exception as e:
+             print(f"❌ Error loading embedding model '{self.model_name}': {e}")
+             print("Please make sure 'sentence_transformers' and 'torch' are installed.")
+             raise
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         """
+         Computes embeddings for a list of documents, using the cache.
+
+         Args:
+             texts: A list of documents (strings) to embed.
+
+         Returns:
+             A list of embeddings, one for each document.
+         """
+         if not texts:
+             return []
+
+         # Check cache for existing embeddings
+         cached_embeddings = self.cache.get_batch(texts)
+
+         texts_to_embed = []
+         indices_to_embed = []
+         for i, (text, cached) in enumerate(zip(texts, cached_embeddings)):
+             if cached is None:
+                 texts_to_embed.append(text)
+                 indices_to_embed.append(i)
+
+         # Compute embeddings for texts that were not in the cache
+         if texts_to_embed:
+             print(
+                 f"🔎 Cache miss. Computing embeddings for {len(texts_to_embed)}/{len(texts)} documents."
+             )
+             new_embeddings = self.model.embed_documents(texts_to_embed)
+
+             # Store new embeddings in the cache
+             self.cache.set_batch(texts_to_embed, new_embeddings)
+
+             # Place new embeddings back into the results list
+             for i, embedding in zip(indices_to_embed, new_embeddings):
+                 cached_embeddings[i] = embedding
+         else:
+             print(f"✅ Cache hit. Embeddings for all {len(texts)} documents were retrieved from the cache.")
+
+         return cast(List[List[float]], cached_embeddings)
+
+     def embed_query(self, text: str) -> List[float]:
+         """
+         Computes the embedding for a single query.
+         Queries are typically not cached, but caching can be added if needed.
+         """
+         return self.model.embed_query(text)
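
For orientation, a minimal usage sketch of EmbeddingManager follows. It is not part of the diff: the `jarvis.jarvis_rag` module path, the existence of a 'performance' entry in the configured embedding models, and the cache directory are all assumptions.

    # Hypothetical usage sketch; module path and config entries are assumed.
    from jarvis.jarvis_rag.embedding_manager import EmbeddingManager

    manager = EmbeddingManager(mode="performance", cache_dir="/tmp/jarvis_embed_cache")
    vectors = manager.embed_documents(["first document", "second document"])
    # A second call with the same texts should be served from the disk cache.
    vectors_again = manager.embed_documents(["first document", "second document"])
    query_vector = manager.embed_query("what do the documents say?")
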
@@ -0,0 +1,126 @@
+ from abc import ABC, abstractmethod
+
+ from jarvis.jarvis_agent import Agent as JarvisAgent
+ from jarvis.jarvis_platform.base import BasePlatform
+ from jarvis.jarvis_platform.registry import PlatformRegistry
+
+
+ class LLMInterface(ABC):
+     """
+     Abstract base class for large language model interfaces.
+
+     This class defines the standard interface for interacting with a remote LLM.
+     Any LLM provider (OpenAI, Anthropic, etc.) should be implemented as a
+     subclass of this interface.
+     """
+
+     @abstractmethod
+     def generate(self, prompt: str, **kwargs) -> str:
+         """
+         Generates a response from the LLM based on a given prompt.
+
+         Args:
+             prompt: The input prompt to send to the LLM.
+             **kwargs: Additional keyword arguments for the LLM API call
+                 (e.g., temperature, max_tokens).
+
+         Returns:
+             The text response generated by the LLM.
+         """
+         pass
+
+
+ class ToolAgent_LLM(LLMInterface):
+     """
+     An implementation of the LLMInterface that uses a tool-wielding JarvisAgent
+     to generate the final response.
+     """
+
+     def __init__(self):
+         """
+         Initializes the tool-agent LLM wrapper.
+         """
+         print("🤖 Initialized the tool agent as the final responder.")
+         self.allowed_tools = ["read_code", "execute_script"]
+         # A generic system prompt for the agent
+         self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
+         self.summary_prompt = """
+ <report>
+ Please generate a summary report for this question-answering task, covering:
+
+ 1. **Original question**: Restate the question the user originally asked.
+ 2. **Key information sources**: Summarize the key information or files your conclusion is based on.
+ 3. **Final answer**: Give the final, concise answer.
+ </report>
+ """
+
+     def generate(self, prompt: str, **kwargs) -> str:
+         """
+         Runs the JarvisAgent with a restricted toolset to generate an answer.
+
+         Args:
+             prompt: The full prompt, including context, to be sent to the agent.
+             **kwargs: Ignored, kept for interface compatibility.
+
+         Returns:
+             The final answer generated by the agent.
+         """
+         try:
+             # Initialize the agent with specific settings for RAG context
+             agent = JarvisAgent(
+                 system_prompt=self.system_prompt,
+                 use_tools=self.allowed_tools,
+                 auto_complete=True,
+                 use_methodology=False,
+                 use_analysis=False,
+                 need_summary=True,
+                 summary_prompt=self.summary_prompt,
+             )
+
+             # The agent's run method expects the 'user_input' parameter
+             final_answer = agent.run(user_input=prompt)
+             return str(final_answer)
+
+         except Exception as e:
+             print(f"❌ The agent raised an error during execution: {e}")
+             return "Error: the agent failed to generate an answer."
+
+
+ class JarvisPlatform_LLM(LLMInterface):
+     """
+     An implementation of the LLMInterface for the project's internal platform.
+
+     This class uses the PlatformRegistry to get the configured "normal" model.
+     """
+
+     def __init__(self):
+         """
+         Initializes the Jarvis platform LLM client.
+         """
+         try:
+             self.registry = PlatformRegistry.get_global_platform_registry()
+             self.platform: BasePlatform = self.registry.get_normal_platform()
+             # Do not suppress the model's console output
+             self.platform.set_suppress_output(False)
+             print(f"🚀 Initialized Jarvis platform LLM, model: {self.platform.name()}")
+         except Exception as e:
+             print(f"❌ Failed to initialize the Jarvis platform LLM: {e}")
+             raise
+
+     def generate(self, prompt: str, **kwargs) -> str:
+         """
+         Sends a prompt to the local platform model and returns the response.
+
+         Args:
+             prompt: The user's prompt.
+             **kwargs: Ignored, kept for interface compatibility.
+
+         Returns:
+             The response generated by the platform model.
+         """
+         try:
+             # Use the robust chat_until_success method
+             return self.platform.chat_until_success(prompt)
+         except Exception as e:
+             print(f"❌ Error while calling the Jarvis platform model: {e}")
+             return "Error: could not get a response from the local LLM."
@@ -0,0 +1,63 @@
+ from typing import List
+ from .llm_interface import LLMInterface
+
+
+ class QueryRewriter:
+     """
+     Uses an LLM to rewrite a user's query into multiple, diverse search
+     queries to enhance retrieval recall.
+     """
+
+     def __init__(self, llm: LLMInterface):
+         """
+         Initializes the QueryRewriter.
+
+         Args:
+             llm: An instance of a class implementing LLMInterface.
+         """
+         self.llm = llm
+         self.rewrite_prompt_template = self._create_prompt_template()
+
+     def _create_prompt_template(self) -> str:
+         """Creates the prompt template for the multi-query rewriting task."""
+         return """
+ You are an AI assistant skilled at retrieval. Your task is to rewrite the single user question below, from different angles, into 3 distinct but semantically related search queries. This enables a more comprehensive search of the knowledge base.
+
+ Please follow these principles:
+ 1. **Diversity**: The generated queries should use different keywords and phrasings.
+ 2. **Preserve the core intent**: Every query must stay centered on the core intent of the original question.
+ 3. **Conciseness**: Each query should be a standalone phrase or question that can be used directly for search.
+ 4. **Format**: Output the 3 queries directly, one per line, separated by newlines. Do not add any numbering, prefixes, or explanations.
+
+ Original question:
+ ---
+ {query}
+ ---
+
+ 3 rewritten queries (one per line):
+ """
+
+     def rewrite(self, query: str) -> List[str]:
+         """
+         Rewrites the user query into multiple queries using the LLM.
+
+         Args:
+             query: The original user query.
+
+         Returns:
+             A list of rewritten, search-optimized queries.
+         """
+         prompt = self.rewrite_prompt_template.format(query=query)
+         print("✍️ Rewriting the original query into multiple search queries...")
+
+         response_text = self.llm.generate(prompt)
+         rewritten_queries = [
+             line.strip() for line in response_text.strip().split("\n") if line.strip()
+         ]
+
+         # Also include the original query for robustness
+         if query not in rewritten_queries:
+             rewritten_queries.insert(0, query)
+
+         print(f"✅ Generated {len(rewritten_queries)} query variants.")
+         return rewritten_queries
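
A hedged sketch of how the rewriter behaves when driven by the hypothetical CannedLLM stub from the previous sketch; the resulting list shape follows directly from the code above.

    from jarvis.jarvis_rag.query_rewriter import QueryRewriter  # module path assumed

    rewriter = QueryRewriter(
        llm=CannedLLM("cache design\nembedding cache internals\nhow embeddings are cached")
    )
    variants = rewriter.rewrite("How does the embedding cache work?")
    # variants == ["How does the embedding cache work?",
    #              "cache design",
    #              "embedding cache internals",
    #              "how embeddings are cached"]
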
@@ -0,0 +1,175 @@
+ import os
+ from typing import List, Literal, Optional, cast
+
+ from langchain.docstore.document import Document
+
+ from .embedding_manager import EmbeddingManager
+ from .llm_interface import JarvisPlatform_LLM, LLMInterface, ToolAgent_LLM
+ from .query_rewriter import QueryRewriter
+ from .reranker import Reranker
+ from .retriever import ChromaRetriever
+ from jarvis.jarvis_utils.config import (
+     get_rag_embedding_mode,
+     get_rag_vector_db_path,
+     get_rag_embedding_cache_path,
+     get_rag_embedding_models,
+ )
+
+
+ class JarvisRAGPipeline:
+     """
+     The main orchestrator for the RAG pipeline.
+
+     This class integrates the embedding manager, retriever, and LLM to provide
+     a complete pipeline for adding documents and querying them.
+     """
+
+     def __init__(
+         self,
+         llm: Optional[LLMInterface] = None,
+         embedding_mode: Optional[Literal["performance", "accuracy"]] = None,
+         db_path: Optional[str] = None,
+         collection_name: str = "jarvis_rag_collection",
+     ):
+         """
+         Initializes the RAG pipeline.
+
+         Args:
+             llm: An instance of a class implementing LLMInterface.
+                 If None, defaults to the ToolAgent_LLM.
+             embedding_mode: The mode for the local embedding model. If None, uses the config value.
+             db_path: Path to the persistent vector database. If None, uses the config value.
+             collection_name: Name of the collection in the vector database.
+         """
+         # Determine the embedding model to isolate data paths
+         _embedding_mode = embedding_mode or get_rag_embedding_mode()
+         embedding_models = get_rag_embedding_models()
+         model_name = embedding_models[_embedding_mode]["model_name"]
+         sanitized_model_name = model_name.replace("/", "_").replace("\\", "_")
+
+         # If a specific db_path is given, use it. Otherwise, create a model-specific path.
+         _final_db_path = (
+             str(db_path)
+             if db_path
+             else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
+         )
+         # Always create a model-specific cache path.
+         _final_cache_path = os.path.join(
+             get_rag_embedding_cache_path(), sanitized_model_name
+         )
+
+         self.embedding_manager = EmbeddingManager(
+             mode=cast(Literal["performance", "accuracy"], _embedding_mode),
+             cache_dir=_final_cache_path,
+         )
+         self.retriever = ChromaRetriever(
+             embedding_manager=self.embedding_manager,
+             db_path=_final_db_path,
+             collection_name=collection_name,
+         )
+         # Default to the ToolAgent_LLM unless a specific LLM is provided
+         self.llm = llm if llm is not None else ToolAgent_LLM()
+         self.reranker = Reranker()
+         # Use a standard LLM for the query rewriting task, not the agent
+         self.query_rewriter = QueryRewriter(JarvisPlatform_LLM())
+
+         print("✅ JarvisRAGPipeline initialized successfully.")
+
+     def add_documents(self, documents: List[Document]):
+         """
+         Adds documents to the vector knowledge base.
+
+         Args:
+             documents: A list of LangChain Document objects to add.
+         """
+         self.retriever.add_documents(documents)
+
+     def _create_prompt(
+         self, query: str, context_docs: List[Document], source_files: List[str]
+     ) -> str:
+         """Creates the final prompt for the LLM or agent."""
+         context = "\n\n".join([doc.page_content for doc in context_docs])
+         sources_text = "\n".join([f"- {source}" for source in source_files])
+
+         prompt_template = f"""
+ You are an expert assistant. Answer the user's question using the reference information provided below.
+
+ **Important**: The provided context and file list are **for reference only** and may be incomplete or outdated. Before answering, you should **prefer using tools (such as read_code) to obtain the latest, most accurate information**.
+
+ Reference file list:
+ ---
+ {sources_text}
+ ---
+
+ Reference context:
+ ---
+ {context}
+ ---
+
+ Question: {query}
+
+ Answer:
+ """
+         return prompt_template.strip()
+
+     def query(self, query_text: str, n_results: int = 5) -> str:
+         """
+         Performs a query against the knowledge base using a multi-query
+         retrieval and reranking pipeline.
+
+         Args:
+             query_text: The user's original question.
+             n_results: The number of final relevant chunks to retrieve.
+
+         Returns:
+             The answer generated by the LLM.
+         """
+         # 1. Rewrite the original query into multiple queries
+         rewritten_queries = self.query_rewriter.rewrite(query_text)
+
+         # 2. Retrieve initial candidates for each rewritten query
+         all_candidate_docs = []
+         for q in rewritten_queries:
+             print(f"🔍 Running hybrid retrieval for query variant '{q}'...")
+             candidates = self.retriever.retrieve(q, n_results=n_results * 2)
+             all_candidate_docs.extend(candidates)
+
+         # De-duplicate the candidate documents
+         unique_docs_dict = {doc.page_content: doc for doc in all_candidate_docs}
+         unique_candidate_docs = list(unique_docs_dict.values())
+
+         if not unique_candidate_docs:
+             return "I could not find any relevant information in the provided documents to answer your question."
+
+         # 3. Rerank the unified candidate pool against the *original* query
+         print(
+             f"🔍 Reranking {len(unique_candidate_docs)} candidate documents (against the original question)..."
+         )
+         retrieved_docs = self.reranker.rerank(
+             query_text, unique_candidate_docs, top_n=n_results
+         )
+
+         if not retrieved_docs:
+             return "I could not find any relevant information in the provided documents to answer your question."
+
+         # Print the sources of the final retrieved documents
+         sources = sorted(
+             {
+                 doc.metadata["source"]
+                 for doc in retrieved_docs
+                 if "source" in doc.metadata
+             }
+         )
+         if sources:
+             print("📚 Answering based on the following documents:")
+             for source in sources:
+                 print(f"  - {source}")
+
+         # 4. Create the final prompt and generate the answer
+         # We use the original query_text for the final prompt to the LLM
+         prompt = self._create_prompt(query_text, retrieved_docs, sources)
+
+         print("🤖 Generating the answer from the LLM...")
+         answer = self.llm.generate(prompt)
+
+         return answer
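
End to end, the pipeline is driven through add_documents and query. A hedged usage sketch follows; the `jarvis.jarvis_rag` module path is an assumption, and the default constructor requires a configured Jarvis installation (platform registry and RAG settings).

    # Hypothetical end-to-end sketch; assumes a configured Jarvis installation.
    from langchain.docstore.document import Document
    from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline  # module path assumed

    pipeline = JarvisRAGPipeline()  # defaults to ToolAgent_LLM and config-driven paths
    pipeline.add_documents(
        [Document(page_content="The cache salt is the model name.", metadata={"source": "docs/cache.md"})]
    )
    print(pipeline.query("What is used as the cache salt?"))
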
@@ -0,0 +1,54 @@
+ from typing import List
+
+ from langchain.docstore.document import Document
+ from sentence_transformers.cross_encoder import CrossEncoder  # type: ignore
+
+
+ class Reranker:
+     """
+     A reranker that uses a cross-encoder model to re-score and sort
+     documents based on their relevance to a given query.
+     """
+
+     def __init__(self, model_name: str = "BAAI/bge-reranker-base"):
+         """
+         Initializes the Reranker.
+
+         Args:
+             model_name (str): The name of the cross-encoder model to use.
+         """
+         print(f"🔍 Initializing reranker model: {model_name}...")
+         self.model = CrossEncoder(model_name)
+         print("✅ Reranker model initialized successfully.")
+
+     def rerank(
+         self, query: str, documents: List[Document], top_n: int = 5
+     ) -> List[Document]:
+         """
+         Reranks a list of documents based on their relevance to the query.
+
+         Args:
+             query (str): The user's query.
+             documents (List[Document]): The documents retrieved by the initial search.
+             top_n (int): The number of top documents to return after reranking.
+
+         Returns:
+             List[Document]: A sorted list of the most relevant documents.
+         """
+         if not documents:
+             return []
+
+         # Create [query, document_content] pairs for scoring
+         pairs = [[query, doc.page_content] for doc in documents]
+
+         # Get relevance scores from the cross-encoder model
+         scores = self.model.predict(pairs)
+
+         # Combine documents with their scores and sort by score, descending
+         doc_with_scores = list(zip(documents, scores))
+         doc_with_scores.sort(key=lambda x: x[1], reverse=True)
+
+         # Return the top N documents
+         reranked_docs = [doc for doc, _score in doc_with_scores[:top_n]]
+
+         return reranked_docs
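
The reranker can also be exercised on its own. A short sketch (the first run downloads the BAAI/bge-reranker-base weights; the documents here are illustrative):

    from langchain.docstore.document import Document
    from jarvis.jarvis_rag.reranker import Reranker  # module path assumed

    docs = [
        Document(page_content="The cache salt is the model name.", metadata={"source": "docs/cache.md"}),
        Document(page_content="Queries are rewritten into three variants.", metadata={"source": "docs/query.md"}),
    ]
    reranker = Reranker()
    top_doc = reranker.rerank("what seeds the cache salt?", docs, top_n=1)[0]
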