jarvis-ai-assistant 0.1.219__py3-none-any.whl → 0.1.221__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +98 -440
- jarvis/jarvis_agent/edit_file_handler.py +32 -185
- jarvis/jarvis_agent/prompt_builder.py +57 -0
- jarvis/jarvis_agent/prompts.py +188 -0
- jarvis/jarvis_agent/protocols.py +30 -0
- jarvis/jarvis_agent/session_manager.py +84 -0
- jarvis/jarvis_agent/tool_executor.py +49 -0
- jarvis/jarvis_code_agent/code_agent.py +4 -4
- jarvis/jarvis_data/config_schema.json +20 -0
- jarvis/jarvis_platform/yuanbao.py +3 -1
- jarvis/jarvis_rag/__init__.py +11 -0
- jarvis/jarvis_rag/cache.py +85 -0
- jarvis/jarvis_rag/cli.py +386 -0
- jarvis/jarvis_rag/embedding_manager.py +95 -0
- jarvis/jarvis_rag/llm_interface.py +128 -0
- jarvis/jarvis_rag/query_rewriter.py +62 -0
- jarvis/jarvis_rag/rag_pipeline.py +174 -0
- jarvis/jarvis_rag/reranker.py +56 -0
- jarvis/jarvis_rag/retriever.py +201 -0
- jarvis/jarvis_tools/edit_file.py +11 -36
- jarvis/jarvis_utils/config.py +56 -0
- {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/METADATA +90 -8
- {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/RECORD +28 -14
- {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/entry_points.txt +1 -0
- {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.219.dist-info → jarvis_ai_assistant-0.1.221.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,49 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
from typing import Any, Tuple, TYPE_CHECKING
|
3
|
+
|
4
|
+
from jarvis.jarvis_utils.input import user_confirm
|
5
|
+
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from jarvis.jarvis_agent import Agent
|
9
|
+
|
10
|
+
|
11
|
+
def execute_tool_call(response: str, agent: "Agent") -> Tuple[bool, Any]:
    """
    Dispatch a model response to the single output handler that claims it.

    Every handler in ``agent.output_handler`` is asked whether it can handle
    ``response``. Exactly one match is executed (after an optional user
    confirmation); no match is a no-op; several matches are reported as an
    error and nothing is executed.

    Args:
        response: The response string from the model, potentially containing
            a tool call.
        agent: The agent instance; ``output_handler`` and
            ``execute_tool_confirm`` are read from it, and it is passed on to
            the handler.

    Returns:
        A ``(flag, payload)`` tuple:
            - whatever the chosen handler's ``handle`` returns, when a handler
              ran;
            - ``(False, error_message)`` when more than one handler matched;
            - ``(False, "")`` when no handler matched or the user declined.
    """
    matching = [h for h in agent.output_handler if h.can_handle(response)]

    if not matching:
        return False, ""

    if len(matching) > 1:
        # Only one operation may run per response; report every contender.
        names = ", ".join([h.name() for h in matching])
        error_message = (
            "操作失败:检测到多个操作。一次只能执行一个操作。"
            f"尝试执行的操作:{names}"
        )
        PrettyOutput.print(error_message, OutputType.WARNING)
        return False, error_message

    chosen = matching[0]
    if agent.execute_tool_confirm:
        # Confirmation is only requested when the agent is configured for it.
        approved = user_confirm(f"需要执行{chosen.name()}确认执行?", True)
        if not approved:
            return False, ""

    print(f"🔧 正在执行{chosen.name()}...")
    outcome = chosen.handle(response, agent)
    print(f"✅ {chosen.name()}执行完成")
    return outcome
|
@@ -392,19 +392,19 @@ class CodeAgent:
|
|
392
392
|
return
|
393
393
|
# 用户确认最终结果
|
394
394
|
if commited:
|
395
|
-
agent.prompt += final_ret
|
395
|
+
agent.session.prompt += final_ret
|
396
396
|
return
|
397
397
|
PrettyOutput.print(final_ret, OutputType.USER, lang="markdown")
|
398
398
|
if not is_confirm_before_apply_patch() or user_confirm(
|
399
399
|
"是否使用此回复?", default=True
|
400
400
|
):
|
401
|
-
agent.prompt += final_ret
|
401
|
+
agent.session.prompt += final_ret
|
402
402
|
return
|
403
|
-
agent.prompt += final_ret
|
403
|
+
agent.session.prompt += final_ret
|
404
404
|
custom_reply = get_multiline_input("请输入自定义回复")
|
405
405
|
if custom_reply.strip(): # 如果自定义回复为空,返回空字符串
|
406
406
|
agent.set_addon_prompt(custom_reply)
|
407
|
-
agent.prompt += final_ret
|
407
|
+
agent.session.prompt += final_ret
|
408
408
|
|
409
409
|
|
410
410
|
def main() -> None:
|
@@ -181,6 +181,26 @@
|
|
181
181
|
"description": "是否打印提示",
|
182
182
|
"default": false
|
183
183
|
},
|
184
|
+
"JARVIS_RAG": {
|
185
|
+
"type": "object",
|
186
|
+
"description": "RAG框架的配置",
|
187
|
+
"properties": {
|
188
|
+
"embedding_model": {
|
189
|
+
"type": "string",
|
190
|
+
"default": "BAAI/bge-base-zh-v1.5",
|
191
|
+
"description": "用于RAG的嵌入模型的名称, 默认为 'BAAI/bge-base-zh-v1.5'"
|
192
|
+
},
|
193
|
+
"rerank_model": {
|
194
|
+
"type": "string",
|
195
|
+
"default": "BAAI/bge-reranker-base",
|
196
|
+
"description": "用于RAG的rerank模型的名称, 默认为 'BAAI/bge-reranker-base'"
|
197
|
+
}
|
198
|
+
},
|
199
|
+
"default": {
|
200
|
+
"embedding_model": "BAAI/bge-base-zh-v1.5",
|
201
|
+
"rerank_model": "BAAI/bge-reranker-base"
|
202
|
+
}
|
203
|
+
},
|
184
204
|
"JARVIS_REPLACE_MAP": {
|
185
205
|
"type": "object",
|
186
206
|
"description": "自定义替换映射表配置",
|
@@ -38,7 +38,9 @@ class YuanbaoPlatform(BasePlatform):
|
|
38
38
|
self.agent_id = "naQivTmsDa"
|
39
39
|
|
40
40
|
if not self.cookies:
|
41
|
-
|
41
|
+
raise ValueError(
|
42
|
+
"YUANBAO_COOKIES environment variable not set. Please provide your cookies to use the Yuanbao platform."
|
43
|
+
)
|
42
44
|
|
43
45
|
self.system_message = "" # 系统消息,用于初始化对话
|
44
46
|
self.first_chat = True # 标识是否为第一次对话
|
@@ -0,0 +1,11 @@
|
|
1
|
+
"""
|
2
|
+
Jarvis RAG 框架
|
3
|
+
|
4
|
+
一个灵活的RAG管道,具有可插拔的远程LLM和本地带缓存的嵌入模型。
|
5
|
+
"""
|
6
|
+
|
7
|
+
from .rag_pipeline import JarvisRAGPipeline
|
8
|
+
from .llm_interface import LLMInterface
|
9
|
+
from .embedding_manager import EmbeddingManager
|
10
|
+
|
11
|
+
__all__ = ["JarvisRAGPipeline", "LLMInterface", "EmbeddingManager"]
|
@@ -0,0 +1,85 @@
|
|
1
|
+
import hashlib
|
2
|
+
from typing import List, Optional, Any
|
3
|
+
|
4
|
+
from diskcache import Cache
|
5
|
+
|
6
|
+
|
7
|
+
class EmbeddingCache:
    """
    Disk-backed cache that maps a text to its embedding.

    Entries are stored in a ``diskcache.Cache``. The key for a text is the
    SHA-256 hex digest of ``salt + text`` (UTF-8 encoded), so lookups are
    deterministic for a given salt.
    """

    def __init__(self, cache_dir: str, salt: str = ""):
        """
        Open (or create) the cache.

        Args:
            cache_dir (str): Directory in which the cache is stored.
            salt (str): Prefix mixed into every key. Use something that
                identifies the embedding model (e.g. its name) so that
                embeddings produced by different models do not collide.
        """
        self.salt = salt
        self.cache = Cache(cache_dir)

    def _get_key(self, text: str) -> str:
        """Return the cache key for *text* under the configured salt."""
        salted = self.salt + text
        return hashlib.sha256(salted.encode("utf-8")).hexdigest()

    def get(self, text: str) -> Optional[Any]:
        """
        Look up the embedding stored for *text*.

        Args:
            text (str): The text to look up.

        Returns:
            The cached embedding, or None when the text is not cached.
        """
        return self.cache.get(self._get_key(text))

    def set(self, text: str, embedding: Any) -> None:
        """
        Store *embedding* as the cached value for *text*.

        Args:
            text (str): The text the embedding belongs to.
            embedding (Any): The embedding vector to store.
        """
        self.cache.set(self._get_key(text), embedding)

    def get_batch(self, texts: List[str]) -> List[Optional[Any]]:
        """
        Look up several texts at once.

        Args:
            texts (List[str]): The texts to look up.

        Returns:
            A list aligned with *texts*: the cached embedding for each text,
            or None at the positions that missed.
        """
        return list(map(self.get, texts))

    def set_batch(self, texts: List[str], embeddings: List[Any]) -> None:
        """
        Store several embeddings inside one cache transaction.

        Args:
            texts (List[str]): The texts.
            embeddings (List[Any]): The embeddings, aligned with *texts*.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(texts) != len(embeddings):
            raise ValueError("Length of texts and embeddings must be the same.")

        pairs = zip(texts, embeddings)
        with self.cache.transact():
            for text, embedding in pairs:
                self.set(text, embedding)

    def close(self):
        """Close the underlying cache."""
        self.cache.close()
|
jarvis/jarvis_rag/cli.py
ADDED
@@ -0,0 +1,386 @@
|
|
1
|
+
import os
|
2
|
+
import sys
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Optional, List, Literal, cast
|
5
|
+
import mimetypes
|
6
|
+
|
7
|
+
import pathspec
|
8
|
+
import typer
|
9
|
+
from langchain.docstore.document import Document
|
10
|
+
from langchain_community.document_loaders import (
|
11
|
+
TextLoader,
|
12
|
+
UnstructuredMarkdownLoader,
|
13
|
+
)
|
14
|
+
from langchain_core.document_loaders.base import BaseLoader
|
15
|
+
from rich.markdown import Markdown
|
16
|
+
|
17
|
+
from jarvis.jarvis_utils.utils import init_env
|
18
|
+
|
19
|
+
|
20
|
+
def is_likely_text_file(file_path: Path) -> bool:
    """
    Guess whether *file_path* is a text file without reading it in full.

    Two heuristics are applied in order:

    1. If the MIME type guessed from the file name starts with ``text/`` or
       contains ``json``, ``xml`` or ``javascript``, the file is accepted
       without being opened.
    2. Otherwise the first 4 KB are read; a NUL byte in that chunk marks the
       file as binary.

    Returns:
        True when the file looks like text; False when it looks binary or
        when any error occurs while inspecting it.
    """
    try:
        guessed_type = mimetypes.guess_type(file_path)[0]
        if guessed_type:
            if guessed_type.startswith("text/"):
                return True
            for marker in ("json", "xml", "javascript"):
                if marker in guessed_type:
                    return True

        # Only the head of the file is read, so large binaries stay cheap.
        with open(file_path, "rb") as stream:
            head = stream.read(4096)
        # A NUL byte is a strong indicator of binary content.
        return b"\x00" not in head
    except Exception:
        return False
|
41
|
+
|
42
|
+
|
43
|
+
# Put the directory three levels above this file at the front of sys.path so
# that absolute imports resolve when this file is run as a script/module.
# NOTE(review): counted from jarvis/jarvis_rag/cli.py, three levels up is the
# parent of the directory that contains the `jarvis` package — confirm this is
# the intended "project root", in particular for an installed wheel.
# NOTE(review): `jarvis.jarvis_utils.utils` is already imported above this
# point, so this tweak cannot be what makes that import work — confirm it is
# still needed.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
|
50
|
+
|
51
|
+
from jarvis.jarvis_platform.base import BasePlatform
|
52
|
+
from jarvis.jarvis_platform.registry import PlatformRegistry
|
53
|
+
from jarvis.jarvis_rag.llm_interface import LLMInterface
|
54
|
+
from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline
|
55
|
+
|
56
|
+
# Typer application object; the commands defined in this module register
# themselves on it through the @app.command decorator. Shell-completion
# installation options are disabled.
app = typer.Typer(
    add_completion=False,
    name="jarvis-rag",
    help="一个与Jarvis RAG框架交互的命令行工具。",
)
|
61
|
+
|
62
|
+
|
63
|
+
class _CustomPlatformLLM(LLMInterface):
    """Adapter that exposes a ``BasePlatform`` instance through ``LLMInterface``."""

    def __init__(self, platform: BasePlatform):
        """Keep a reference to *platform* and print which platform/model is used."""
        self.platform = platform
        platform_label = platform.platform_name()
        model_label = platform.name()
        print(f"✅ 使用自定义LLM: 平台='{platform_label}', 模型='{model_label}'")

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Return the platform's answer to *prompt*.

        Extra keyword arguments are accepted for interface compatibility but
        are not forwarded to the platform.
        """
        reply = self.platform.chat_until_success(prompt)
        return reply
|
74
|
+
|
75
|
+
|
76
|
+
def _create_custom_llm(platform_name: str, model_name: str) -> Optional[LLMInterface]:
    """
    Build an LLM interface for the given platform and model.

    Args:
        platform_name: Name of the platform to look up in the global
            platform registry.
        model_name: Model name to set on the created platform.

    Returns:
        A ``_CustomPlatformLLM`` wrapping the configured platform, or None
        when either name is empty, the platform is unknown, or any error
        occurs while creating it (the error is printed, not raised).
    """
    if not (platform_name and model_name):
        return None

    try:
        platform_instance = (
            PlatformRegistry.get_global_platform_registry().create_platform(
                platform_name
            )
        )
        if platform_instance:
            platform_instance.set_model_name(model_name)
            platform_instance.set_suppress_output(True)
            return _CustomPlatformLLM(platform_instance)

        print(f"❌ 错误: 平台 '{platform_name}' 未找到。")
        return None
    except Exception as e:
        print(f"❌ 创建自定义LLM时出错: {e}")
        return None
|
92
|
+
|
93
|
+
|
94
|
+
def _load_ragignore_spec() -> tuple[Optional[pathspec.PathSpec], Optional[Path]]:
    """
    Load ignore patterns from the directory named by ``_project_root``.

    ``.jarvis/rag/.ragignore`` is preferred; ``.gitignore`` is the fallback.
    Patterns are parsed with pathspec's ``gitwildmatch`` syntax.

    Returns:
        ``(spec, root)`` on success, where *root* is the directory the
        patterns are relative to; ``(None, None)`` when neither file exists
        or the chosen file cannot be read or parsed (the failure is printed).
    """
    root = Path(_project_root)
    candidates = (
        root / ".jarvis" / "rag" / ".ragignore",
        root / ".gitignore",
    )
    # First existing candidate wins; later ones are not probed.
    ignore_file = next((c for c in candidates if c.is_file()), None)
    if ignore_file is None:
        return None, None

    try:
        with open(ignore_file, "r", encoding="utf-8") as f:
            patterns = f.read().splitlines()
        spec = pathspec.PathSpec.from_lines("gitwildmatch", patterns)
        print(f"✅ 加载忽略规则: {ignore_file}")
        return spec, root
    except Exception as e:
        print(f"⚠️ 加载 {ignore_file.name} 文件失败: {e}")

    return None, None
|
120
|
+
|
121
|
+
|
122
|
+
def _collect_text_files(paths: List[Path]) -> set[Path]:
    """Expand files, directories and glob patterns into a set of likely-text files."""
    # Imported once per call rather than on every loop iteration.
    from glob import glob

    found: set[Path] = set()
    for raw_path in paths:
        # Typer's List[Path] may hand over unexpanded glob patterns, so they
        # are expanded here.
        for match in glob(str(raw_path), recursive=True):
            path = Path(match)
            if not path.exists():
                continue

            if path.is_dir():
                print(f"🔍 正在扫描目录: {path}")
                found.update(
                    item
                    for item in path.rglob("*")
                    if item.is_file() and is_likely_text_file(item)
                )
            elif path.is_file():
                if is_likely_text_file(path):
                    found.add(path)
                else:
                    print(f"⚠️ 跳过可能的二进制文件: {path}")
    return found


def _filter_ignored_files(files: set[Path]) -> set[Path]:
    """Drop files matched by the loaded ignore rules; return *files* as-is when no rules load."""
    ragignore_spec, ragignore_root = _load_ragignore_spec()
    if not (ragignore_spec and ragignore_root):
        return files

    retained: set[Path] = set()
    for file_path in files:
        try:
            # Resolve to an absolute path so it can be made relative to the
            # ignore root before matching.
            relative_path = str(file_path.resolve().relative_to(ragignore_root))
            if not ragignore_spec.match_file(relative_path):
                retained.add(file_path)
        except ValueError:
            # The file is not under the ignore root: keep it.
            retained.add(file_path)

    ignored_count = len(files) - len(retained)
    if ignored_count > 0:
        print(f"ℹ️ 根据 .ragignore 规则过滤掉 {ignored_count} 个文件。")
    return retained


@app.command(
    "add",
    help="从文件、目录或glob模式(例如 'src/**/*.py')添加文档。",
)
def add_documents(
    paths: List[Path] = typer.Argument(
        ...,
        help="文件/目录路径或glob模式。支持Shell扩展。",
    ),
    collection_name: str = typer.Option(
        "jarvis_rag_collection",
        "--collection",
        "-c",
        help="向量数据库中集合的名称。",
    ),
    embedding_model: Optional[str] = typer.Option(
        None,
        "--embedding-model",
        "-e",
        help="嵌入模型的名称。覆盖全局配置。",
    ),
    db_path: Optional[Path] = typer.Option(
        None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
    ),
    batch_size: int = typer.Option(
        500,
        "--batch-size",
        "-b",
        help="单个批次中要处理的文档数。",
    ),
):
    """
    Add documents from files, directories or glob patterns to the RAG knowledge base.

    Steps: collect likely-text files, drop those matched by the ignore rules,
    then load the files in sorted order and hand the loaded documents to the
    pipeline in batches.

    Args:
        paths: Files, directories or glob patterns to ingest.
        collection_name: Name of the collection in the vector database.
        embedding_model: Embedding model name passed to the pipeline.
        db_path: Vector database path passed to the pipeline (None when unset).
        batch_size: Number of loaded documents that triggers a flush to the
            pipeline; the remainder is flushed after the last file.

    Raises:
        typer.Exit: With code 1 when no document could be added or when an
            unexpected error occurs.
    """
    files_to_process = _collect_text_files(paths)
    if not files_to_process:
        print("⚠️ 在指定路径中未找到任何文本文件。")
        return

    files_to_process = _filter_ignored_files(files_to_process)
    if not files_to_process:
        print("⚠️ 所有找到的文本文件都被忽略规则过滤掉了。")
        return

    print(f"✅ 发现 {len(files_to_process)} 个独立文件待处理。")

    try:
        pipeline = JarvisRAGPipeline(
            embedding_model=embedding_model,
            db_path=str(db_path) if db_path else None,
            collection_name=collection_name,
        )

        docs_batch: List[Document] = []
        total_docs_added = 0
        loader: BaseLoader

        sorted_files = sorted(files_to_process)
        total_files = len(sorted_files)

        for position, file_path in enumerate(sorted_files, start=1):
            try:
                if file_path.suffix.lower() == ".md":
                    loader = UnstructuredMarkdownLoader(str(file_path))
                else:  # .txt and all code files go through TextLoader
                    loader = TextLoader(str(file_path), encoding="utf-8")

                docs_batch.extend(loader.load())
                print(f"✅ 已加载: {file_path} (文件 {position}/{total_files})")
            except Exception as e:
                # A single unreadable file must not abort the whole run.
                print(f"⚠️ 加载失败 {file_path}: {e}")

            # Flush when the batch is full or this was the last file.
            if docs_batch and (
                len(docs_batch) >= batch_size or position == total_files
            ):
                print(f"⚙️ 正在处理批次,包含 {len(docs_batch)} 个文档...")
                pipeline.add_documents(docs_batch)
                total_docs_added += len(docs_batch)
                print(f"✅ 成功添加 {len(docs_batch)} 个文档。")
                docs_batch = []

        if total_docs_added == 0:
            print("❌ 未能成功加载任何文档。")
            raise typer.Exit(code=1)

        print(
            f"✅ 成功将 {total_docs_added} 个文档的内容添加至集合 '{collection_name}'。"
        )

    except typer.Exit:
        # typer.Exit derives from RuntimeError; without this clause the
        # deliberate exit above would be reported as a "serious error" with
        # an empty message by the generic handler below.
        raise
    except Exception as e:
        print(f"❌ 发生严重错误: {e}")
        raise typer.Exit(code=1)
|
254
|
+
|
255
|
+
|
256
|
+
@app.command("list-docs", help="列出知识库中所有唯一的文档。")
def list_documents(
    collection_name: str = typer.Option(
        "jarvis_rag_collection",
        "--collection",
        "-c",
        help="向量数据库中集合的名称。",
    ),
    db_path: Optional[Path] = typer.Option(
        None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
    ),
):
    """
    Print the unique source documents stored in a collection.

    Every item of the collection is fetched, and the distinct string values
    of the ``source`` metadata key are printed in sorted order, numbered
    from 1.

    Args:
        collection_name: Name of the collection in the vector database.
        db_path: Vector database path passed to the pipeline (None when unset).

    Raises:
        typer.Exit: With code 1 when any error occurs.
    """
    try:
        pipeline = JarvisRAGPipeline(
            db_path=str(db_path) if db_path else None,
            collection_name=collection_name,
        )

        # Fetch every item stored in the collection.
        results = pipeline.retriever.collection.get()

        if not results or not results["metadatas"]:
            print("ℹ️ 知识库中没有找到任何文档。")
            return

        # Collect the distinct source paths recorded in the metadata.
        sources = set()
        for metadata in results["metadatas"]:
            if not metadata:
                continue
            source = metadata.get("source")
            if isinstance(source, str):
                sources.add(source)

        if not sources:
            print("ℹ️ 知识库中没有找到任何带有源信息的文档。")
            return

        print(f"📚 知识库 '{collection_name}' 中共有 {len(sources)} 个独立文档:")
        for number, source in enumerate(sorted(sources), start=1):
            print(f"  {number}. {source}")

    except Exception as e:
        print(f"❌ 发生错误: {e}")
        raise typer.Exit(code=1)
|
301
|
+
|
302
|
+
|
303
|
+
@app.command("query", help="向知识库提问。")
def query(
    question: str = typer.Argument(..., help="要提出的问题。"),
    collection_name: str = typer.Option(
        "jarvis_rag_collection",
        "--collection",
        "-c",
        help="向量数据库中集合的名称。",
    ),
    embedding_model: Optional[str] = typer.Option(
        None,
        "--embedding-model",
        "-e",
        help="嵌入模型的名称。覆盖全局配置。",
    ),
    db_path: Optional[Path] = typer.Option(
        None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
    ),
    platform: Optional[str] = typer.Option(
        None,
        "--platform",
        "-p",
        help="为LLM指定平台名称。覆盖默认的思考模型。",
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="为LLM指定模型名称。需要 --platform。",
    ),
):
    """
    Query the RAG knowledge base and print the answer as Markdown.

    Args:
        question: The question to ask.
        collection_name: Name of the collection in the vector database.
        embedding_model: Embedding model name passed to the pipeline.
        db_path: Vector database path passed to the pipeline (None when unset).
        platform: Platform name for a custom LLM; must be given together
            with ``model``.
        model: Model name for a custom LLM; must be given together with
            ``platform``.

    Raises:
        typer.Exit: With code 1 when only one of ``platform``/``model`` is
            given, when the custom LLM cannot be created, or when any error
            occurs while querying.
    """
    if model and not platform:
        print("❌ 错误: --model 需要指定 --platform。")
        raise typer.Exit(code=1)

    try:
        custom_llm = _create_custom_llm(platform, model) if platform and model else None
        if (platform or model) and not custom_llm:
            if not model:
                # --platform without --model already exited with code 1, but
                # without any explanation; say why. When both were given,
                # _create_custom_llm has already printed the reason.
                print("❌ 错误: --platform 需要同时指定 --model。")
            raise typer.Exit(code=1)

        pipeline = JarvisRAGPipeline(
            llm=custom_llm,
            embedding_model=embedding_model,
            db_path=str(db_path) if db_path else None,
            collection_name=collection_name,
        )

        print(f"🤔 正在查询: '{question}'")
        answer = pipeline.query(question)

        print("💬 答案:")
        # The shared rich console is used to render the Markdown answer.
        from jarvis.jarvis_utils.globals import console

        console.print(Markdown(answer))

    except typer.Exit:
        # typer.Exit derives from RuntimeError; without this clause the
        # deliberate exit above would be caught below and reported as an
        # error with an empty message.
        raise
    except Exception as e:
        print(f"❌ 发生错误: {e}")
        raise typer.Exit(code=1)
|
363
|
+
|
364
|
+
|
365
|
+
# Probe for the optional RAG dependencies: the flag is True only when
# `langchain` can be imported.
# NOTE(review): this module already imports from `langchain` unconditionally
# near the top of the file, so a missing package would raise ImportError
# there before this flag is ever consulted — confirm whether the probe is
# still reachable with the dependency absent.
try:
    import langchain  # noqa
except ImportError:
    _RAG_INSTALLED = False
else:
    _RAG_INSTALLED = True
|
372
|
+
|
373
|
+
|
374
|
+
def _check_rag_dependencies():
    """
    Abort the command when the optional RAG dependencies are missing.

    Does nothing when ``_RAG_INSTALLED`` is true; otherwise prints an
    installation hint and raises ``typer.Exit`` with code 1.
    """
    if _RAG_INSTALLED:
        return

    print(
        "❌ RAG依赖项未安装。"
        "请运行 'pip install \"jarvis-ai-assistant[rag]\"' 来使用此命令。"
    )
    raise typer.Exit(code=1)
|
381
|
+
|
382
|
+
|
383
|
+
def main():
    """
    Run the RAG command-line interface.

    Checks the optional RAG dependencies, initialises the Jarvis environment
    with the "Jarvis RAG" welcome string, then hands control to the Typer app.
    NOTE(review): presumably the target of the console-script entry point
    added to entry_points.txt in this release — confirm.
    """
    _check_rag_dependencies()
    init_env(welcome_str="Jarvis RAG")
    app()
|
@@ -0,0 +1,95 @@
|
|
1
|
+
import torch
|
2
|
+
from typing import List, cast
|
3
|
+
from langchain_huggingface import HuggingFaceEmbeddings
|
4
|
+
|
5
|
+
from .cache import EmbeddingCache
|
6
|
+
|
7
|
+
|
8
|
+
class EmbeddingManager:
    """
    Loads a local embedding model and serves embeddings through a cache.

    The model is a ``HuggingFaceEmbeddings`` instance; document embeddings
    are stored in an ``EmbeddingCache`` salted with the model name, so the
    same text is not embedded twice for the same model.
    """

    def __init__(self, model_name: str, cache_dir: str):
        """
        Create the cache and load the model.

        Args:
            model_name: Name of the Hugging Face model to load.
            cache_dir: Directory in which the embedding cache is stored.
        """
        self.model_name = model_name

        print(f"🚀 初始化嵌入管理器, 模型: '{self.model_name}'...")

        # The model name is the cache salt, keeping different models' entries apart.
        self.cache = EmbeddingCache(cache_dir=cache_dir, salt=self.model_name)
        self.model = self._load_model()

    def _load_model(self) -> HuggingFaceEmbeddings:
        """
        Instantiate the Hugging Face embedding model.

        CUDA is used when ``torch`` reports it as available, otherwise the
        CPU; embeddings are requested in normalized form. Any loading error
        is printed together with an installation hint and then re-raised.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"

        try:
            return HuggingFaceEmbeddings(
                model_name=self.model_name,
                model_kwargs={"device": device},
                encode_kwargs={"normalize_embeddings": True},
                show_progress=True,
            )
        except Exception as e:
            print(f"❌ 加载嵌入模型 '{self.model_name}' 时出错: {e}")
            print("请确保您已安装 'sentence_transformers' 和 'torch'。")
            raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Return one embedding per text, computing only the cache misses.

        Args:
            texts: The documents (strings) to embed.

        Returns:
            A list of embeddings aligned with *texts*; an empty list when
            *texts* is empty.
        """
        if not texts:
            return []

        results = self.cache.get_batch(texts)

        # (position, text) for every entry the cache could not supply.
        misses = [
            (position, text)
            for position, (text, hit) in enumerate(zip(texts, results))
            if hit is None
        ]

        if not misses:
            print(f"✅ 缓存命中。所有 {len(texts)} 个文档的嵌入均从缓存中检索。")
            return cast(List[List[float]], results)

        missing_positions = [position for position, _ in misses]
        missing_texts = [text for _, text in misses]

        print(
            f"🔎 缓存未命中。正在为 {len(missing_texts)}/{len(texts)} 个文档计算嵌入。"
        )
        fresh = self.model.embed_documents(missing_texts)

        # Persist the new embeddings, then slot them into their positions.
        self.cache.set_batch(missing_texts, fresh)
        for position, embedding in zip(missing_positions, fresh):
            results[position] = embedding

        return cast(List[List[float]], results)

    def embed_query(self, text: str) -> List[float]:
        """
        Return the embedding for a single query string.

        The cache is not consulted for queries.
        """
        return self.model.embed_query(text)
|