jarvis-ai-assistant 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -30,6 +30,8 @@ class JarvisRAGPipeline:
30
30
  embedding_model: Optional[str] = None,
31
31
  db_path: Optional[str] = None,
32
32
  collection_name: str = "jarvis_rag_collection",
33
+ use_bm25: bool = True,
34
+ use_rerank: bool = True,
33
35
  ):
34
36
  """
35
37
  初始化RAG管道。
@@ -40,6 +42,8 @@ class JarvisRAGPipeline:
40
42
  embedding_model: 嵌入模型的名称。如果为None,则使用配置值。
41
43
  db_path: 持久化向量数据库的路径。如果为None,则使用配置值。
42
44
  collection_name: 向量数据库中集合的名称。
45
+ use_bm25: 是否在检索中使用BM25。
46
+ use_rerank: 是否在检索后使用重排器。
43
47
  """
44
48
  # 确定嵌入模型以隔离数据路径
45
49
  model_name = embedding_model or get_rag_embedding_model()
@@ -56,22 +60,87 @@ class JarvisRAGPipeline:
56
60
  get_rag_embedding_cache_path(), sanitized_model_name
57
61
  )
58
62
 
59
- self.embedding_manager = EmbeddingManager(
60
- model_name=model_name,
61
- cache_dir=_final_cache_path,
63
+ # 存储初始化参数以供延迟加载
64
+ self.llm = llm if llm is not None else ToolAgent_LLM()
65
+ self.embedding_model_name = embedding_model or get_rag_embedding_model()
66
+ self.db_path = db_path
67
+ self.collection_name = collection_name
68
+ self.use_bm25 = use_bm25
69
+ self.use_rerank = use_rerank
70
+
71
+ # 延迟加载的组件
72
+ self._embedding_manager: Optional[EmbeddingManager] = None
73
+ self._retriever: Optional[ChromaRetriever] = None
74
+ self._reranker: Optional[Reranker] = None
75
+ self._query_rewriter: Optional[QueryRewriter] = None
76
+
77
+ print("✅ JarvisRAGPipeline 初始化成功 (模型按需加载).")
78
+
79
+ def _get_embedding_manager(self) -> EmbeddingManager:
80
+ if self._embedding_manager is None:
81
+ sanitized_model_name = self.embedding_model_name.replace("/", "_").replace(
82
+ "\\", "_"
83
+ )
84
+ _final_cache_path = os.path.join(
85
+ get_rag_embedding_cache_path(), sanitized_model_name
86
+ )
87
+ self._embedding_manager = EmbeddingManager(
88
+ model_name=self.embedding_model_name,
89
+ cache_dir=_final_cache_path,
90
+ )
91
+ return self._embedding_manager
92
+
93
+ def _get_retriever(self) -> ChromaRetriever:
94
+ if self._retriever is None:
95
+ sanitized_model_name = self.embedding_model_name.replace("/", "_").replace(
96
+ "\\", "_"
97
+ )
98
+ _final_db_path = (
99
+ str(self.db_path)
100
+ if self.db_path
101
+ else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
102
+ )
103
+ self._retriever = ChromaRetriever(
104
+ embedding_manager=self._get_embedding_manager(),
105
+ db_path=_final_db_path,
106
+ collection_name=self.collection_name,
107
+ )
108
+ return self._retriever
109
+
110
+ def _get_collection(self):
111
+ """
112
+ 在不加载嵌入模型的情况下,直接获取并返回Chroma集合对象。
113
+ 这对于仅需要访问集合元数据(如列出文档)而无需嵌入功能的操作非常有用。
114
+ """
115
+ # 为了避免初始化embedding_manager,我们直接构建db_path
116
+ if self._retriever:
117
+ return self._retriever.collection
118
+
119
+ sanitized_model_name = self.embedding_model_name.replace("/", "_").replace(
120
+ "\\", "_"
62
121
  )
63
- self.retriever = ChromaRetriever(
64
- embedding_manager=self.embedding_manager,
65
- db_path=_final_db_path,
66
- collection_name=collection_name,
122
+ _final_db_path = (
123
+ str(self.db_path)
124
+ if self.db_path
125
+ else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
67
126
  )
68
- # 除非提供了特定的LLM,否则默认为ToolAgent_LLM
69
- self.llm = llm if llm is not None else ToolAgent_LLM()
70
- self.reranker = Reranker(model_name=get_rag_rerank_model())
71
- # 使用标准LLM执行查询重写任务,而不是代理
72
- self.query_rewriter = QueryRewriter(JarvisPlatform_LLM())
73
127
 
74
- print("✅ JarvisRAGPipeline 初始化成功。")
128
+ # 直接创建ChromaRetriever所使用的chroma_client,但绕过embedding_manager
129
+ import chromadb
130
+
131
+ chroma_client = chromadb.PersistentClient(path=_final_db_path)
132
+ return chroma_client.get_collection(name=self.collection_name)
133
+
134
+ def _get_reranker(self) -> Reranker:
135
+ if self._reranker is None:
136
+ self._reranker = Reranker(model_name=get_rag_rerank_model())
137
+ return self._reranker
138
+
139
+ def _get_query_rewriter(self) -> QueryRewriter:
140
+ if self._query_rewriter is None:
141
+ # 使用标准LLM执行查询重写任务,而不是代理
142
+ self._query_rewriter = QueryRewriter(JarvisPlatform_LLM())
143
+ return self._query_rewriter
75
144
 
76
145
  def add_documents(self, documents: List[Document]):
77
146
  """
@@ -80,24 +149,21 @@ class JarvisRAGPipeline:
80
149
  参数:
81
150
  documents: 要添加的LangChain文档对象列表。
82
151
  """
83
- self.retriever.add_documents(documents)
152
+ self._get_retriever().add_documents(documents)
84
153
 
85
- def _create_prompt(
86
- self, query: str, context_docs: List[Document], source_files: List[str]
87
- ) -> str:
154
+ def _create_prompt(self, query: str, context_docs: List[Document]) -> str:
88
155
  """为LLM或代理创建最终的提示。"""
89
- context = "\n\n".join([doc.page_content for doc in context_docs])
90
- sources_text = "\n".join([f"- {source}" for source in source_files])
156
+ context_details = []
157
+ for doc in context_docs:
158
+ source = doc.metadata.get("source", "未知来源")
159
+ content = doc.page_content
160
+ context_details.append(f"来源: {source}\n\n---\n{content}\n---")
161
+ context = "\n\n".join(context_details)
91
162
 
92
163
  prompt_template = f"""
93
164
  你是一个专家助手。请根据用户的问题,结合下面提供的参考信息来回答。
94
165
 
95
- **重要**: 提供的上下文和文件列表**仅供参考**,可能不完整或已过时。在回答前,你应该**优先使用工具(如 read_code)来获取最新、最准确的信息**。
96
-
97
- 参考文件列表:
98
- ---
99
- {sources_text}
100
- ---
166
+ **重要**: 提供的上下文**仅供参考**,可能不完整或已过时。在回答前,你应该**优先使用工具(如 read_code)来获取最新、最准确的信息**。
101
167
 
102
168
  参考上下文:
103
169
  ---
@@ -122,13 +188,15 @@ class JarvisRAGPipeline:
122
188
  由LLM生成的答案。
123
189
  """
124
190
  # 1. 将原始查询重写为多个查询
125
- rewritten_queries = self.query_rewriter.rewrite(query_text)
191
+ rewritten_queries = self._get_query_rewriter().rewrite(query_text)
126
192
 
127
193
  # 2. 为每个重写的查询检索初始候选文档
128
194
  all_candidate_docs = []
129
195
  for q in rewritten_queries:
130
196
  print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
131
- candidates = self.retriever.retrieve(q, n_results=n_results * 2)
197
+ candidates = self._get_retriever().retrieve(
198
+ q, n_results=n_results * 2, use_bm25=self.use_bm25
199
+ )
132
200
  all_candidate_docs.extend(candidates)
133
201
 
134
202
  # 对候选文档进行去重
@@ -139,12 +207,13 @@ class JarvisRAGPipeline:
139
207
  return "我在提供的文档中找不到任何相关信息来回答您的问题。"
140
208
 
141
209
  # 3. 根据*原始*查询对统一的候选池进行重排
142
- print(
143
- f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)..."
144
- )
145
- retrieved_docs = self.reranker.rerank(
146
- query_text, unique_candidate_docs, top_n=n_results
147
- )
210
+ if self.use_rerank:
211
+ print(f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)...")
212
+ retrieved_docs = self._get_reranker().rerank(
213
+ query_text, unique_candidate_docs, top_n=n_results
214
+ )
215
+ else:
216
+ retrieved_docs = unique_candidate_docs[:n_results]
148
217
 
149
218
  if not retrieved_docs:
150
219
  return "我在提供的文档中找不到任何相关信息来回答您的问题。"
@@ -166,9 +235,49 @@ class JarvisRAGPipeline:
166
235
 
167
236
  # 4. 创建最终提示并生成答案
168
237
  # 我们使用原始的query_text作为给LLM的最终提示
169
- prompt = self._create_prompt(query_text, retrieved_docs, sources)
238
+ prompt = self._create_prompt(query_text, retrieved_docs)
170
239
 
171
240
  print("🤖 正在从LLM生成答案...")
172
241
  answer = self.llm.generate(prompt)
173
242
 
174
243
  return answer
244
+
245
+ def retrieve_only(self, query_text: str, n_results: int = 5) -> List[Document]:
246
+ """
247
+ 仅执行检索和重排,不生成答案。
248
+
249
+ 参数:
250
+ query_text: 用户的原始问题。
251
+ n_results: 要检索的最终相关块的数量。
252
+
253
+ 返回:
254
+ 检索到的文档列表。
255
+ """
256
+ # 1. 重写查询
257
+ rewritten_queries = self._get_query_rewriter().rewrite(query_text)
258
+
259
+ # 2. 检索候选文档
260
+ all_candidate_docs = []
261
+ for q in rewritten_queries:
262
+ print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
263
+ candidates = self._get_retriever().retrieve(
264
+ q, n_results=n_results * 2, use_bm25=self.use_bm25
265
+ )
266
+ all_candidate_docs.extend(candidates)
267
+
268
+ unique_docs_dict = {doc.page_content: doc for doc in all_candidate_docs}
269
+ unique_candidate_docs = list(unique_docs_dict.values())
270
+
271
+ if not unique_candidate_docs:
272
+ return []
273
+
274
+ # 3. 重排
275
+ if self.use_rerank:
276
+ print(f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排...")
277
+ retrieved_docs = self._get_reranker().rerank(
278
+ query_text, unique_candidate_docs, top_n=n_results
279
+ )
280
+ else:
281
+ retrieved_docs = unique_candidate_docs[:n_results]
282
+
283
+ return retrieved_docs
@@ -39,9 +39,7 @@ class ChromaRetriever:
39
39
  self.collection = self.client.get_or_create_collection(
40
40
  name=self.collection_name
41
41
  )
42
- print(
43
- f"✅ ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。"
44
- )
42
+ print(f"✅ ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。")
45
43
 
46
44
  # BM25索引设置
47
45
  self.bm25_index_path = os.path.join(self.db_path, f"{collection_name}_bm25.pkl")
@@ -107,7 +105,9 @@ class ChromaRetriever:
107
105
  self.bm25_index = BM25Okapi(self.bm25_corpus)
108
106
  self._save_bm25_index()
109
107
 
110
- def retrieve(self, query: str, n_results: int = 5) -> List[Document]:
108
+ def retrieve(
109
+ self, query: str, n_results: int = 5, use_bm25: bool = True
110
+ ) -> List[Document]:
111
111
  """
112
112
  使用向量搜索和BM25执行混合检索,然后使用倒数排序融合(RRF)
113
113
  对结果进行融合。
@@ -121,7 +121,7 @@ class ChromaRetriever:
121
121
 
122
122
  # 2. 关键字搜索 (BM25)
123
123
  bm25_docs = []
124
- if self.bm25_index:
124
+ if self.bm25_index and use_bm25:
125
125
  tokenized_query = query.split()
126
126
  doc_scores = self.bm25_index.get_scores(tokenized_query)
127
127
 
@@ -12,7 +12,9 @@ class generate_new_tool:
12
12
  生成并注册新的Jarvis工具。该工具会在用户数据目录下创建新的工具文件,
13
13
  并自动注册到当前的工具注册表中。适用场景:1. 需要创建新的自定义工具;
14
14
  2. 扩展Jarvis功能;3. 自动化重复性操作;4. 封装特定领域的功能。
15
- 重要提示:在编写工具代码时,应尽量将工具执行的过程和结果打印出来,方便追踪工具的执行状态。
15
+ 重要提示:
16
+ 1. `tool_name` 参数必须与 `tool_code` 中定义的 `name` 属性完全一致。
17
+ 2. 在编写工具代码时,应尽量将工具执行的过程和结果打印出来,方便追踪工具的执行状态。
16
18
  """
17
19
 
18
20
  parameters = {
@@ -75,6 +77,25 @@ class generate_new_tool:
75
77
  "stderr": f"工具名称 '{tool_name}' 不是有效的Python标识符",
76
78
  }
77
79
 
80
+ # 验证工具代码中的名称是否与tool_name一致
81
+ import re
82
+
83
+ match = re.search(r"^\s*name\s*=\s*[\"'](.+?)[\"']", tool_code, re.MULTILINE)
84
+ if not match:
85
+ return {
86
+ "success": False,
87
+ "stdout": "",
88
+ "stderr": "无法在工具代码中找到 'name' 属性。请确保工具类中包含 'name = \"your_tool_name\"'。",
89
+ }
90
+
91
+ code_name = match.group(1)
92
+ if tool_name != code_name:
93
+ return {
94
+ "success": False,
95
+ "stdout": "",
96
+ "stderr": f"工具名称不一致:参数 'tool_name' ('{tool_name}') 与代码中的 'name' 属性 ('{code_name}') 必须相同。",
97
+ }
98
+
78
99
  # 准备工具目录
79
100
  tools_dir = Path(get_data_dir()) / "tools"
80
101
  tools_dir.mkdir(parents=True, exist_ok=True)
@@ -115,22 +115,26 @@ def get_shell_name() -> str:
115
115
  return os.path.basename(shell_path).lower()
116
116
 
117
117
 
118
- def _get_resolved_model_config(model_group_override: Optional[str] = None) -> Dict[str, Any]:
118
+ def _get_resolved_model_config(
119
+ model_group_override: Optional[str] = None,
120
+ ) -> Dict[str, Any]:
119
121
  """
120
122
  解析并合并模型配置,处理模型组。
121
123
 
122
124
  优先级顺序:
123
125
  1. 单独的环境变量 (JARVIS_PLATFORM, JARVIS_MODEL, etc.)
124
- 2. JARVIS_MODEL_GROUP 中定义的组配置
126
+ 2. JARVIS_LLM_GROUP 中定义的组配置
125
127
  3. 代码中的默认值
126
128
 
127
129
  返回:
128
130
  Dict[str, Any]: 解析后的模型配置字典
129
131
  """
130
132
  group_config = {}
131
- model_group_name = model_group_override or GLOBAL_CONFIG_DATA.get("JARVIS_MODEL_GROUP")
133
+ model_group_name = model_group_override or GLOBAL_CONFIG_DATA.get(
134
+ "JARVIS_LLM_GROUP"
135
+ )
132
136
  # The format is a list of single-key dicts: [{'group_name': {...}}, ...]
133
- model_groups = GLOBAL_CONFIG_DATA.get("JARVIS_MODEL_GROUPS", [])
137
+ model_groups = GLOBAL_CONFIG_DATA.get("JARVIS_LLM_GROUPS", [])
134
138
 
135
139
  if model_group_name and isinstance(model_groups, list):
136
140
  for group_item in model_groups:
@@ -202,7 +206,9 @@ def get_thinking_model_name(model_group_override: Optional[str] = None) -> str:
202
206
  """
203
207
  config = _get_resolved_model_config(model_group_override)
204
208
  # Fallback to normal model if thinking model is not specified
205
- return config.get("JARVIS_THINKING_MODEL", get_normal_model_name(model_group_override))
209
+ return config.get(
210
+ "JARVIS_THINKING_MODEL", get_normal_model_name(model_group_override)
211
+ )
206
212
 
207
213
 
208
214
  def is_execute_tool_confirm() -> bool:
@@ -334,14 +340,65 @@ def get_mcp_config() -> List[Dict[str, Any]]:
334
340
  # ==============================================================================
335
341
 
336
342
 
337
- def get_rag_config() -> Dict[str, Any]:
343
+ DEFAULT_RAG_GROUPS = [
344
+ {
345
+ "text": {
346
+ "embedding_model": "BAAI/bge-m3",
347
+ "rerank_model": "BAAI/bge-reranker-v2-m3",
348
+ "use_bm25": True,
349
+ "use_rerank": True,
350
+ }
351
+ },
352
+ {
353
+ "code": {
354
+ "embedding_model": "Qodo/Qodo-Embed-1-7B",
355
+ "use_bm25": False,
356
+ "use_rerank": False,
357
+ }
358
+ },
359
+ ]
360
+
361
+
362
+ def _get_resolved_rag_config(
363
+ rag_group_override: Optional[str] = None,
364
+ ) -> Dict[str, Any]:
338
365
  """
339
- 获取RAG框架的配置。
366
+ 解析并合并RAG配置,处理RAG组。
367
+
368
+ 优先级顺序:
369
+ 1. JARVIS_RAG 中的顶级设置 (embedding_model, etc.)
370
+ 2. JARVIS_RAG_GROUP 中定义的组配置
371
+ 3. 代码中的默认值
340
372
 
341
373
  返回:
342
- Dict[str, Any]: RAG配置字典
374
+ Dict[str, Any]: 解析后的RAG配置字典
343
375
  """
344
- return GLOBAL_CONFIG_DATA.get("JARVIS_RAG", {})
376
+ group_config = {}
377
+ rag_group_name = rag_group_override or GLOBAL_CONFIG_DATA.get("JARVIS_RAG_GROUP")
378
+ rag_groups = GLOBAL_CONFIG_DATA.get("JARVIS_RAG_GROUPS", DEFAULT_RAG_GROUPS)
379
+
380
+ if rag_group_name and isinstance(rag_groups, list):
381
+ for group_item in rag_groups:
382
+ if isinstance(group_item, dict) and rag_group_name in group_item:
383
+ group_config = group_item[rag_group_name]
384
+ break
385
+
386
+ # Start with group config
387
+ resolved_config = group_config.copy()
388
+
389
+ # Override with specific settings from the top-level JARVIS_RAG dict
390
+ top_level_rag_config = GLOBAL_CONFIG_DATA.get("JARVIS_RAG", {})
391
+ if isinstance(top_level_rag_config, dict):
392
+ for key in [
393
+ "embedding_model",
394
+ "rerank_model",
395
+ "use_bm25",
396
+ "use_rerank",
397
+ ]:
398
+ if key in top_level_rag_config:
399
+ resolved_config[key] = top_level_rag_config[key]
400
+
401
+ return resolved_config
345
402
 
346
403
 
347
404
  def get_rag_embedding_model() -> str:
@@ -351,7 +408,8 @@ def get_rag_embedding_model() -> str:
351
408
  返回:
352
409
  str: 嵌入模型的名称
353
410
  """
354
- return get_rag_config().get("embedding_model", "BAAI/bge-base-zh-v1.5")
411
+ config = _get_resolved_rag_config()
412
+ return config.get("embedding_model", "BAAI/bge-m3")
355
413
 
356
414
 
357
415
  def get_rag_rerank_model() -> str:
@@ -361,7 +419,8 @@ def get_rag_rerank_model() -> str:
361
419
  返回:
362
420
  str: rerank模型的名称
363
421
  """
364
- return get_rag_config().get("rerank_model", "BAAI/bge-reranker-base")
422
+ config = _get_resolved_rag_config()
423
+ return config.get("rerank_model", "BAAI/bge-reranker-v2-m3")
365
424
 
366
425
 
367
426
  def get_rag_embedding_cache_path() -> str:
@@ -382,3 +441,25 @@ def get_rag_vector_db_path() -> str:
382
441
  str: 数据库路径
383
442
  """
384
443
  return ".jarvis/rag/vectordb"
444
+
445
+
446
+ def get_rag_use_bm25() -> bool:
447
+ """
448
+ 获取RAG是否使用BM25。
449
+
450
+ 返回:
451
+ bool: 如果使用BM25则返回True,默认为True
452
+ """
453
+ config = _get_resolved_rag_config()
454
+ return config.get("use_bm25", True) is True
455
+
456
+
457
+ def get_rag_use_rerank() -> bool:
458
+ """
459
+ 获取RAG是否使用rerank。
460
+
461
+ 返回:
462
+ bool: 如果使用rerank则返回True,默认为True
463
+ """
464
+ config = _get_resolved_rag_config()
465
+ return config.get("use_rerank", True) is True
@@ -9,9 +9,11 @@
9
9
  """
10
10
  import os
11
11
 
12
- # 全局变量:保存最后一条消息
13
- last_message: str = ""
14
- from typing import Any, Dict, Set
12
+ # 全局变量:保存消息历史
13
+ from typing import Any, Dict, Set, List
14
+
15
+ message_history: List[str] = []
16
+ MAX_HISTORY_SIZE = 50
15
17
 
16
18
  import colorama
17
19
  from rich.console import Console
@@ -169,13 +171,18 @@ def get_interrupt() -> int:
169
171
 
170
172
  def set_last_message(message: str) -> None:
171
173
  """
172
- 设置最后一条消息。
174
+ 将消息添加到历史记录中。
173
175
 
174
176
  参数:
175
177
  message: 要保存的消息
176
178
  """
177
- global last_message
178
- last_message = message
179
+ global message_history
180
+ if message:
181
+ # 避免重复添加
182
+ if not message_history or message_history[-1] != message:
183
+ message_history.append(message)
184
+ if len(message_history) > MAX_HISTORY_SIZE:
185
+ message_history.pop(0)
179
186
 
180
187
 
181
188
  def get_last_message() -> str:
@@ -183,6 +190,20 @@ def get_last_message() -> str:
183
190
  获取最后一条消息。
184
191
 
185
192
  返回:
186
- str: 最后一条消息
193
+ str: 最后一条消息,如果历史记录为空则返回空字符串
194
+ """
195
+ global message_history
196
+ if message_history:
197
+ return message_history[-1]
198
+ return ""
199
+
200
+
201
+ def get_message_history() -> List[str]:
202
+ """
203
+ 获取完整的消息历史记录。
204
+
205
+ 返回:
206
+ List[str]: 消息历史列表
187
207
  """
188
- return last_message
208
+ global message_history
209
+ return message_history