jarvis-ai-assistant 0.1.219__py3-none-any.whl → 0.1.221__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +98 -440
  3. jarvis/jarvis_agent/edit_file_handler.py +32 -185
  4. jarvis/jarvis_agent/prompt_builder.py +57 -0
  5. jarvis/jarvis_agent/prompts.py +188 -0
  6. jarvis/jarvis_agent/protocols.py +30 -0
  7. jarvis/jarvis_agent/session_manager.py +84 -0
  8. jarvis/jarvis_agent/tool_executor.py +49 -0
  9. jarvis/jarvis_code_agent/code_agent.py +4 -4
  10. jarvis/jarvis_data/config_schema.json +20 -0
  11. jarvis/jarvis_platform/yuanbao.py +3 -1
  12. jarvis/jarvis_rag/__init__.py +11 -0
  13. jarvis/jarvis_rag/cache.py +85 -0
  14. jarvis/jarvis_rag/cli.py +386 -0
  15. jarvis/jarvis_rag/embedding_manager.py +95 -0
  16. jarvis/jarvis_rag/llm_interface.py +128 -0
  17. jarvis/jarvis_rag/query_rewriter.py +62 -0
  18. jarvis/jarvis_rag/rag_pipeline.py +174 -0
  19. jarvis/jarvis_rag/reranker.py +56 -0
  20. jarvis/jarvis_rag/retriever.py +201 -0
  21. jarvis/jarvis_tools/edit_file.py +11 -36
  22. jarvis/jarvis_utils/config.py +56 -0
  23. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/METADATA +90 -8
  24. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/RECORD +28 -14
  25. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/entry_points.txt +1 -0
  26. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/WHEEL +0 -0
  27. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/licenses/LICENSE +0 -0
  28. {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,49 @@
1
+ # -*- coding: utf-8 -*-
2
+ from typing import Any, Tuple, TYPE_CHECKING
3
+
4
+ from jarvis.jarvis_utils.input import user_confirm
5
+ from jarvis.jarvis_utils.output import OutputType, PrettyOutput
6
+
7
+ if TYPE_CHECKING:
8
+ from jarvis.jarvis_agent import Agent
9
+
10
+
11
+ def execute_tool_call(response: str, agent: "Agent") -> Tuple[bool, Any]:
12
+ """
13
+ Parses the model's response, identifies the appropriate tool, and executes it.
14
+
15
+ Args:
16
+ response: The response string from the model, potentially containing a tool call.
17
+ agent: The agent instance, providing context like output handlers and settings.
18
+
19
+ Returns:
20
+ A tuple containing:
21
+ - A boolean indicating if the tool's result should be returned to the user.
22
+ - The result of the tool execution or an error message.
23
+ """
24
+ tool_list = []
25
+ for handler in agent.output_handler:
26
+ if handler.can_handle(response):
27
+ tool_list.append(handler)
28
+
29
+ if len(tool_list) > 1:
30
+ error_message = (
31
+ f"操作失败:检测到多个操作。一次只能执行一个操作。"
32
+ f"尝试执行的操作:{', '.join([handler.name() for handler in tool_list])}"
33
+ )
34
+ PrettyOutput.print(error_message, OutputType.WARNING)
35
+ return False, error_message
36
+
37
+ if not tool_list:
38
+ return False, ""
39
+
40
+ tool_to_execute = tool_list[0]
41
+ if not agent.execute_tool_confirm or user_confirm(
42
+ f"需要执行{tool_to_execute.name()}确认执行?", True
43
+ ):
44
+ print(f"🔧 正在执行{tool_to_execute.name()}...")
45
+ result = tool_to_execute.handle(response, agent)
46
+ print(f"✅ {tool_to_execute.name()}执行完成")
47
+ return result
48
+
49
+ return False, ""
@@ -392,19 +392,19 @@ class CodeAgent:
392
392
  return
393
393
  # 用户确认最终结果
394
394
  if commited:
395
- agent.prompt += final_ret
395
+ agent.session.prompt += final_ret
396
396
  return
397
397
  PrettyOutput.print(final_ret, OutputType.USER, lang="markdown")
398
398
  if not is_confirm_before_apply_patch() or user_confirm(
399
399
  "是否使用此回复?", default=True
400
400
  ):
401
- agent.prompt += final_ret
401
+ agent.session.prompt += final_ret
402
402
  return
403
- agent.prompt += final_ret
403
+ agent.session.prompt += final_ret
404
404
  custom_reply = get_multiline_input("请输入自定义回复")
405
405
  if custom_reply.strip(): # 如果自定义回复为空,返回空字符串
406
406
  agent.set_addon_prompt(custom_reply)
407
- agent.prompt += final_ret
407
+ agent.session.prompt += final_ret
408
408
 
409
409
 
410
410
  def main() -> None:
@@ -181,6 +181,26 @@
181
181
  "description": "是否打印提示",
182
182
  "default": false
183
183
  },
184
+ "JARVIS_RAG": {
185
+ "type": "object",
186
+ "description": "RAG框架的配置",
187
+ "properties": {
188
+ "embedding_model": {
189
+ "type": "string",
190
+ "default": "BAAI/bge-base-zh-v1.5",
191
+ "description": "用于RAG的嵌入模型的名称, 默认为 'BAAI/bge-base-zh-v1.5'"
192
+ },
193
+ "rerank_model": {
194
+ "type": "string",
195
+ "default": "BAAI/bge-reranker-base",
196
+ "description": "用于RAG的rerank模型的名称, 默认为 'BAAI/bge-reranker-base'"
197
+ }
198
+ },
199
+ "default": {
200
+ "embedding_model": "BAAI/bge-base-zh-v1.5",
201
+ "rerank_model": "BAAI/bge-reranker-base"
202
+ }
203
+ },
184
204
  "JARVIS_REPLACE_MAP": {
185
205
  "type": "object",
186
206
  "description": "自定义替换映射表配置",
@@ -38,7 +38,9 @@ class YuanbaoPlatform(BasePlatform):
38
38
  self.agent_id = "naQivTmsDa"
39
39
 
40
40
  if not self.cookies:
41
- PrettyOutput.print("YUANBAO_COOKIES 未设置", OutputType.WARNING)
41
+ raise ValueError(
42
+ "YUANBAO_COOKIES environment variable not set. Please provide your cookies to use the Yuanbao platform."
43
+ )
42
44
 
43
45
  self.system_message = "" # 系统消息,用于初始化对话
44
46
  self.first_chat = True # 标识是否为第一次对话
@@ -0,0 +1,11 @@
1
+ """
2
+ Jarvis RAG 框架
3
+
4
+ 一个灵活的RAG管道,具有可插拔的远程LLM和本地带缓存的嵌入模型。
5
+ """
6
+
7
+ from .rag_pipeline import JarvisRAGPipeline
8
+ from .llm_interface import LLMInterface
9
+ from .embedding_manager import EmbeddingManager
10
+
11
+ __all__ = ["JarvisRAGPipeline", "LLMInterface", "EmbeddingManager"]
@@ -0,0 +1,85 @@
1
+ import hashlib
2
+ from typing import List, Optional, Any
3
+
4
+ from diskcache import Cache
5
+
6
+
7
+ class EmbeddingCache:
8
+ """
9
+ 一个用于存储和检索文本嵌入的基于磁盘的缓存。
10
+
11
+ 该类使用diskcache创建一个持久化的本地缓存。它根据每个文本内容的
12
+ SHA256哈希值为其生成一个键,使得查找过程具有确定性和高效性。
13
+ """
14
+
15
+ def __init__(self, cache_dir: str, salt: str = ""):
16
+ """
17
+ 初始化EmbeddingCache。
18
+
19
+ 参数:
20
+ cache_dir (str): 缓存将要存储的目录。
21
+ salt (str): 添加到哈希中的盐值。这对于确保由不同模型生成的
22
+ 嵌入不会发生冲突至关重要。例如,可以使用模型名称作为盐值。
23
+ """
24
+ self.cache = Cache(cache_dir)
25
+ self.salt = salt
26
+
27
+ def _get_key(self, text: str) -> str:
28
+ """为一个给定的文本和盐值生成一个唯一的缓存键。"""
29
+ hash_object = hashlib.sha256((self.salt + text).encode("utf-8"))
30
+ return hash_object.hexdigest()
31
+
32
+ def get(self, text: str) -> Optional[Any]:
33
+ """
34
+ 从缓存中检索一个嵌入。
35
+
36
+ 参数:
37
+ text (str): 要查找的文本。
38
+
39
+ 返回:
40
+ 缓存的嵌入,如果不在缓存中则返回None。
41
+ """
42
+ key = self._get_key(text)
43
+ return self.cache.get(key)
44
+
45
+ def set(self, text: str, embedding: Any) -> None:
46
+ """
47
+ 在缓存中存储一个嵌入。
48
+
49
+ 参数:
50
+ text (str): 与嵌入相对应的文本。
51
+ embedding (Any): 要存储的嵌入向量。
52
+ """
53
+ key = self._get_key(text)
54
+ self.cache.set(key, embedding)
55
+
56
+ def get_batch(self, texts: List[str]) -> List[Optional[Any]]:
57
+ """
58
+ 从缓存中检索一批嵌入。
59
+
60
+ 参数:
61
+ texts (List[str]): 要查找的文本列表。
62
+
63
+ 返回:
64
+ 一个列表,其中包含缓存的嵌入,对于缓存未命中的情况则为None。
65
+ """
66
+ return [self.get(text) for text in texts]
67
+
68
+ def set_batch(self, texts: List[str], embeddings: List[Any]) -> None:
69
+ """
70
+ 在缓存中存储一批嵌入。
71
+
72
+ 参数:
73
+ texts (List[str]): 文本列表。
74
+ embeddings (List[Any]): 相应的嵌入列表。
75
+ """
76
+ if len(texts) != len(embeddings):
77
+ raise ValueError("Length of texts and embeddings must be the same.")
78
+
79
+ with self.cache.transact():
80
+ for text, embedding in zip(texts, embeddings):
81
+ self.set(text, embedding)
82
+
83
+ def close(self):
84
+ """关闭缓存连接。"""
85
+ self.cache.close()
@@ -0,0 +1,386 @@
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+ from typing import Optional, List, Literal, cast
5
+ import mimetypes
6
+
7
+ import pathspec
8
+ import typer
9
+ from langchain.docstore.document import Document
10
+ from langchain_community.document_loaders import (
11
+ TextLoader,
12
+ UnstructuredMarkdownLoader,
13
+ )
14
+ from langchain_core.document_loaders.base import BaseLoader
15
+ from rich.markdown import Markdown
16
+
17
+ from jarvis.jarvis_utils.utils import init_env
18
+
19
+
20
+ def is_likely_text_file(file_path: Path) -> bool:
21
+ """
22
+ 通过读取文件开头部分,检查文件是否可能为文本文件。
23
+ 此方法可以避免将大型二进制文件加载到内存中。
24
+ """
25
+ try:
26
+ # 启发式方法1:检查MIME类型(如果可用)
27
+ mime_type, _ = mimetypes.guess_type(file_path)
28
+ if mime_type and mime_type.startswith("text/"):
29
+ return True
30
+ if mime_type and any(x in mime_type for x in ["json", "xml", "javascript"]):
31
+ return True
32
+
33
+ # 启发式方法2:检查文件的前几KB中是否包含空字节
34
+ with open(file_path, "rb") as f:
35
+ chunk = f.read(4096) # 读取前4KB
36
+ if b"\x00" in chunk:
37
+ return False # 空字节是二进制文件的强指示符
38
+ return True
39
+ except Exception:
40
+ return False
41
+
42
+
43
+ # 确保项目根目录在Python路径中,以允许绝对导入
44
+ # 这使得脚本可以作为模块运行。
45
+ _project_root = os.path.abspath(
46
+ os.path.join(os.path.dirname(__file__), "..", "..", "..")
47
+ )
48
+ if _project_root not in sys.path:
49
+ sys.path.insert(0, _project_root)
50
+
51
+ from jarvis.jarvis_platform.base import BasePlatform
52
+ from jarvis.jarvis_platform.registry import PlatformRegistry
53
+ from jarvis.jarvis_rag.llm_interface import LLMInterface
54
+ from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline
55
+
56
+ app = typer.Typer(
57
+ name="jarvis-rag",
58
+ help="一个与Jarvis RAG框架交互的命令行工具。",
59
+ add_completion=False,
60
+ )
61
+
62
+
63
+ class _CustomPlatformLLM(LLMInterface):
64
+ """一个简单的包装器,使BasePlatform实例与LLMInterface兼容。"""
65
+
66
+ def __init__(self, platform: BasePlatform):
67
+ self.platform = platform
68
+ print(
69
+ f"✅ 使用自定义LLM: 平台='{platform.platform_name()}', 模型='{platform.name()}'"
70
+ )
71
+
72
+ def generate(self, prompt: str, **kwargs) -> str:
73
+ return self.platform.chat_until_success(prompt)
74
+
75
+
76
+ def _create_custom_llm(platform_name: str, model_name: str) -> Optional[LLMInterface]:
77
+ """从指定的平台和模型创建LLM接口。"""
78
+ if not platform_name or not model_name:
79
+ return None
80
+ try:
81
+ registry = PlatformRegistry.get_global_platform_registry()
82
+ platform_instance = registry.create_platform(platform_name)
83
+ if not platform_instance:
84
+ print(f"❌ 错误: 平台 '{platform_name}' 未找到。")
85
+ return None
86
+ platform_instance.set_model_name(model_name)
87
+ platform_instance.set_suppress_output(True)
88
+ return _CustomPlatformLLM(platform_instance)
89
+ except Exception as e:
90
+ print(f"❌ 创建自定义LLM时出错: {e}")
91
+ return None
92
+
93
+
94
+ def _load_ragignore_spec() -> tuple[Optional[pathspec.PathSpec], Optional[Path]]:
95
+ """
96
+ 从项目根目录加载忽略模式。
97
+ 首先查找 `.jarvis/rag/.ragignore`,如果未找到,则回退到 `.gitignore`。
98
+ """
99
+ project_root_path = Path(_project_root)
100
+ ragignore_file = project_root_path / ".jarvis" / "rag" / ".ragignore"
101
+ gitignore_file = project_root_path / ".gitignore"
102
+
103
+ ignore_file_to_use = None
104
+ if ragignore_file.is_file():
105
+ ignore_file_to_use = ragignore_file
106
+ elif gitignore_file.is_file():
107
+ ignore_file_to_use = gitignore_file
108
+
109
+ if ignore_file_to_use:
110
+ try:
111
+ with open(ignore_file_to_use, "r", encoding="utf-8") as f:
112
+ patterns = f.read().splitlines()
113
+ spec = pathspec.PathSpec.from_lines("gitwildmatch", patterns)
114
+ print(f"✅ 加载忽略规则: {ignore_file_to_use}")
115
+ return spec, project_root_path
116
+ except Exception as e:
117
+ print(f"⚠️ 加载 {ignore_file_to_use.name} 文件失败: {e}")
118
+
119
+ return None, None
120
+
121
+
122
+ @app.command(
123
+ "add",
124
+ help="从文件、目录或glob模式(例如 'src/**/*.py')添加文档。",
125
+ )
126
+ def add_documents(
127
+ paths: List[Path] = typer.Argument(
128
+ ...,
129
+ help="文件/目录路径或glob模式。支持Shell扩展。",
130
+ ),
131
+ collection_name: str = typer.Option(
132
+ "jarvis_rag_collection",
133
+ "--collection",
134
+ "-c",
135
+ help="向量数据库中集合的名称。",
136
+ ),
137
+ embedding_model: Optional[str] = typer.Option(
138
+ None,
139
+ "--embedding-model",
140
+ "-e",
141
+ help="嵌入模型的名称。覆盖全局配置。",
142
+ ),
143
+ db_path: Optional[Path] = typer.Option(
144
+ None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
145
+ ),
146
+ batch_size: int = typer.Option(
147
+ 500,
148
+ "--batch-size",
149
+ "-b",
150
+ help="单个批次中要处理的文档数。",
151
+ ),
152
+ ):
153
+ """从不同来源向RAG知识库添加文档。"""
154
+ files_to_process = set()
155
+
156
+ for path_str in paths:
157
+ # Typer的List[Path]可能不会扩展glob,所以我们手动处理
158
+ from glob import glob
159
+
160
+ expanded_paths = glob(str(path_str), recursive=True)
161
+
162
+ for p_str in expanded_paths:
163
+ path = Path(p_str)
164
+ if not path.exists():
165
+ continue
166
+
167
+ if path.is_dir():
168
+ print(f"🔍 正在扫描目录: {path}")
169
+ for item in path.rglob("*"):
170
+ if item.is_file() and is_likely_text_file(item):
171
+ files_to_process.add(item)
172
+ elif path.is_file():
173
+ if is_likely_text_file(path):
174
+ files_to_process.add(path)
175
+ else:
176
+ print(f"⚠️ 跳过可能的二进制文件: {path}")
177
+
178
+ if not files_to_process:
179
+ print("⚠️ 在指定路径中未找到任何文本文件。")
180
+ return
181
+
182
+ # 使用 .ragignore 过滤文件
183
+ ragignore_spec, ragignore_root = _load_ragignore_spec()
184
+ if ragignore_spec and ragignore_root:
185
+ initial_count = len(files_to_process)
186
+ retained_files = set()
187
+ for file_path in files_to_process:
188
+ try:
189
+ # 将文件路径解析为绝对路径以确保正确比较
190
+ resolved_path = file_path.resolve()
191
+ relative_path = str(resolved_path.relative_to(ragignore_root))
192
+ if not ragignore_spec.match_file(relative_path):
193
+ retained_files.add(file_path)
194
+ except ValueError:
195
+ # 文件不在项目根目录下,保留它
196
+ retained_files.add(file_path)
197
+
198
+ ignored_count = initial_count - len(retained_files)
199
+ if ignored_count > 0:
200
+ print(f"ℹ️ 根据 .ragignore 规则过滤掉 {ignored_count} 个文件。")
201
+ files_to_process = retained_files
202
+
203
+ if not files_to_process:
204
+ print("⚠️ 所有找到的文本文件都被忽略规则过滤掉了。")
205
+ return
206
+
207
+ print(f"✅ 发现 {len(files_to_process)} 个独立文件待处理。")
208
+
209
+ try:
210
+ pipeline = JarvisRAGPipeline(
211
+ embedding_model=embedding_model,
212
+ db_path=str(db_path) if db_path else None,
213
+ collection_name=collection_name,
214
+ )
215
+
216
+ docs_batch: List[Document] = []
217
+ total_docs_added = 0
218
+ loader: BaseLoader
219
+
220
+ sorted_files = sorted(list(files_to_process))
221
+ total_files = len(sorted_files)
222
+
223
+ for i, file_path in enumerate(sorted_files):
224
+ try:
225
+ if file_path.suffix.lower() == ".md":
226
+ loader = UnstructuredMarkdownLoader(str(file_path))
227
+ else: # 对.txt和所有代码文件默认使用TextLoader
228
+ loader = TextLoader(str(file_path), encoding="utf-8")
229
+
230
+ docs_batch.extend(loader.load())
231
+ print(f"✅ 已加载: {file_path} (文件 {i + 1}/{total_files})")
232
+ except Exception as e:
233
+ print(f"⚠️ 加载失败 {file_path}: {e}")
234
+
235
+ # 当批处理已满或是最后一个文件时处理批处理
236
+ if docs_batch and (len(docs_batch) >= batch_size or (i + 1) == total_files):
237
+ print(f"⚙️ 正在处理批次,包含 {len(docs_batch)} 个文档...")
238
+ pipeline.add_documents(docs_batch)
239
+ total_docs_added += len(docs_batch)
240
+ print(f"✅ 成功添加 {len(docs_batch)} 个文档。")
241
+ docs_batch = [] # 清空批处理
242
+
243
+ if total_docs_added == 0:
244
+ print("❌ 未能成功加载任何文档。")
245
+ raise typer.Exit(code=1)
246
+
247
+ print(
248
+ f"✅ 成功将 {total_docs_added} 个文档的内容添加至集合 '{collection_name}'。"
249
+ )
250
+
251
+ except Exception as e:
252
+ print(f"❌ 发生严重错误: {e}")
253
+ raise typer.Exit(code=1)
254
+
255
+
256
+ @app.command("list-docs", help="列出知识库中所有唯一的文档。")
257
+ def list_documents(
258
+ collection_name: str = typer.Option(
259
+ "jarvis_rag_collection",
260
+ "--collection",
261
+ "-c",
262
+ help="向量数据库中集合的名称。",
263
+ ),
264
+ db_path: Optional[Path] = typer.Option(
265
+ None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
266
+ ),
267
+ ):
268
+ """列出指定集合中的所有唯一文档。"""
269
+ try:
270
+ pipeline = JarvisRAGPipeline(
271
+ db_path=str(db_path) if db_path else None,
272
+ collection_name=collection_name,
273
+ )
274
+
275
+ collection = pipeline.retriever.collection
276
+ results = collection.get() # 获取集合中的所有项目
277
+
278
+ if not results or not results["metadatas"]:
279
+ print("ℹ️ 知识库中没有找到任何文档。")
280
+ return
281
+
282
+ # 从元数据中提取唯一的源文件路径
283
+ sources = set()
284
+ for metadata in results["metadatas"]:
285
+ if metadata:
286
+ source = metadata.get("source")
287
+ if isinstance(source, str):
288
+ sources.add(source)
289
+
290
+ if not sources:
291
+ print("ℹ️ 知识库中没有找到任何带有源信息的文档。")
292
+ return
293
+
294
+ print(f"📚 知识库 '{collection_name}' 中共有 {len(sources)} 个独立文档:")
295
+ for i, source in enumerate(sorted(list(sources)), 1):
296
+ print(f" {i}. {source}")
297
+
298
+ except Exception as e:
299
+ print(f"❌ 发生错误: {e}")
300
+ raise typer.Exit(code=1)
301
+
302
+
303
+ @app.command("query", help="向知识库提问。")
304
+ def query(
305
+ question: str = typer.Argument(..., help="要提出的问题。"),
306
+ collection_name: str = typer.Option(
307
+ "jarvis_rag_collection",
308
+ "--collection",
309
+ "-c",
310
+ help="向量数据库中集合的名称。",
311
+ ),
312
+ embedding_model: Optional[str] = typer.Option(
313
+ None,
314
+ "--embedding-model",
315
+ "-e",
316
+ help="嵌入模型的名称。覆盖全局配置。",
317
+ ),
318
+ db_path: Optional[Path] = typer.Option(
319
+ None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
320
+ ),
321
+ platform: Optional[str] = typer.Option(
322
+ None,
323
+ "--platform",
324
+ "-p",
325
+ help="为LLM指定平台名称。覆盖默认的思考模型。",
326
+ ),
327
+ model: Optional[str] = typer.Option(
328
+ None,
329
+ "--model",
330
+ "-m",
331
+ help="为LLM指定模型名称。需要 --platform。",
332
+ ),
333
+ ):
334
+ """查询RAG知识库并打印答案。"""
335
+ if model and not platform:
336
+ print("❌ 错误: --model 需要指定 --platform。")
337
+ raise typer.Exit(code=1)
338
+
339
+ try:
340
+ custom_llm = _create_custom_llm(platform, model) if platform and model else None
341
+ if (platform or model) and not custom_llm:
342
+ raise typer.Exit(code=1)
343
+
344
+ pipeline = JarvisRAGPipeline(
345
+ llm=custom_llm,
346
+ embedding_model=embedding_model,
347
+ db_path=str(db_path) if db_path else None,
348
+ collection_name=collection_name,
349
+ )
350
+
351
+ print(f"🤔 正在查询: '{question}'")
352
+ answer = pipeline.query(question)
353
+
354
+ print("💬 答案:")
355
+ # 我们仍然可以使用 rich.markdown.Markdown,因为 PrettyOutput 底层使用了 rich
356
+ from jarvis.jarvis_utils.globals import console
357
+
358
+ console.print(Markdown(answer))
359
+
360
+ except Exception as e:
361
+ print(f"❌ 发生错误: {e}")
362
+ raise typer.Exit(code=1)
363
+
364
+
365
+ _RAG_INSTALLED = False
366
+ try:
367
+ import langchain # noqa
368
+
369
+ _RAG_INSTALLED = True
370
+ except ImportError:
371
+ pass
372
+
373
+
374
+ def _check_rag_dependencies():
375
+ if not _RAG_INSTALLED:
376
+ print(
377
+ "❌ RAG依赖项未安装。"
378
+ "请运行 'pip install \"jarvis-ai-assistant[rag]\"' 来使用此命令。"
379
+ )
380
+ raise typer.Exit(code=1)
381
+
382
+
383
+ def main():
384
+ _check_rag_dependencies()
385
+ init_env(welcome_str="Jarvis RAG")
386
+ app()
@@ -0,0 +1,95 @@
1
+ import torch
2
+ from typing import List, cast
3
+ from langchain_huggingface import HuggingFaceEmbeddings
4
+
5
+ from .cache import EmbeddingCache
6
+
7
+
8
+ class EmbeddingManager:
9
+ """
10
+ 管理本地嵌入模型的加载和使用,并带有缓存功能。
11
+
12
+ 该类负责从Hugging Face加载指定的模型,并使用基于磁盘的缓存
13
+ 来避免为相同文本重新计算嵌入。
14
+ """
15
+
16
+ def __init__(self, model_name: str, cache_dir: str):
17
+ """
18
+ 初始化EmbeddingManager。
19
+
20
+ 参数:
21
+ model_name: 要加载的Hugging Face模型的名称。
22
+ cache_dir: 用于存储嵌入缓存的目录。
23
+ """
24
+ self.model_name = model_name
25
+
26
+ print(f"🚀 初始化嵌入管理器, 模型: '{self.model_name}'...")
27
+
28
+ # 缓存的salt是模型名称,以防止冲突
29
+ self.cache = EmbeddingCache(cache_dir=cache_dir, salt=self.model_name)
30
+ self.model = self._load_model()
31
+
32
+ def _load_model(self) -> HuggingFaceEmbeddings:
33
+ """根据配置加载Hugging Face嵌入模型。"""
34
+ model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
35
+ encode_kwargs = {"normalize_embeddings": True}
36
+
37
+ try:
38
+ return HuggingFaceEmbeddings(
39
+ model_name=self.model_name,
40
+ model_kwargs=model_kwargs,
41
+ encode_kwargs=encode_kwargs,
42
+ show_progress=True,
43
+ )
44
+ except Exception as e:
45
+ print(f"❌ 加载嵌入模型 '{self.model_name}' 时出错: {e}")
46
+ print("请确保您已安装 'sentence_transformers' 和 'torch'。")
47
+ raise
48
+
49
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
50
+ """
51
+ 使用缓存为文档列表计算嵌入。
52
+
53
+ 参数:
54
+ texts: 要嵌入的文档(字符串)列表。
55
+
56
+ 返回:
57
+ 一个嵌入列表,每个文档对应一个嵌入。
58
+ """
59
+ if not texts:
60
+ return []
61
+
62
+ # 检查缓存中是否已存在嵌入
63
+ cached_embeddings = self.cache.get_batch(texts)
64
+
65
+ texts_to_embed = []
66
+ indices_to_embed = []
67
+ for i, (text, cached) in enumerate(zip(texts, cached_embeddings)):
68
+ if cached is None:
69
+ texts_to_embed.append(text)
70
+ indices_to_embed.append(i)
71
+
72
+ # 为不在缓存中的文本计算嵌入
73
+ if texts_to_embed:
74
+ print(
75
+ f"🔎 缓存未命中。正在为 {len(texts_to_embed)}/{len(texts)} 个文档计算嵌入。"
76
+ )
77
+ new_embeddings = self.model.embed_documents(texts_to_embed)
78
+
79
+ # 将新的嵌入存储在缓存中
80
+ self.cache.set_batch(texts_to_embed, new_embeddings)
81
+
82
+ # 将新的嵌入放回结果列表中
83
+ for i, embedding in zip(indices_to_embed, new_embeddings):
84
+ cached_embeddings[i] = embedding
85
+ else:
86
+ print(f"✅ 缓存命中。所有 {len(texts)} 个文档的嵌入均从缓存中检索。")
87
+
88
+ return cast(List[List[float]], cached_embeddings)
89
+
90
+ def embed_query(self, text: str) -> List[float]:
91
+ """
92
+ 为单个查询计算嵌入。
93
+ 查询通常不被缓存,但如果需要可以添加。
94
+ """
95
+ return self.model.embed_query(text)