jarvis-ai-assistant 0.1.220__py3-none-any.whl → 0.1.222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +110 -395
- jarvis/jarvis_agent/edit_file_handler.py +32 -185
- jarvis/jarvis_agent/jarvis.py +14 -9
- jarvis/jarvis_agent/main.py +13 -6
- jarvis/jarvis_agent/prompt_builder.py +57 -0
- jarvis/jarvis_agent/prompts.py +188 -0
- jarvis/jarvis_agent/protocols.py +30 -0
- jarvis/jarvis_agent/session_manager.py +84 -0
- jarvis/jarvis_agent/tool_executor.py +49 -0
- jarvis/jarvis_code_agent/code_agent.py +14 -23
- jarvis/jarvis_code_analysis/code_review.py +1 -1
- jarvis/jarvis_data/config_schema.json +13 -18
- jarvis/jarvis_git_details/main.py +1 -1
- jarvis/jarvis_platform/kimi.py +4 -2
- jarvis/jarvis_rag/__init__.py +2 -2
- jarvis/jarvis_rag/cache.py +28 -30
- jarvis/jarvis_rag/cli.py +141 -52
- jarvis/jarvis_rag/embedding_manager.py +32 -46
- jarvis/jarvis_rag/llm_interface.py +32 -34
- jarvis/jarvis_rag/query_rewriter.py +11 -12
- jarvis/jarvis_rag/rag_pipeline.py +40 -43
- jarvis/jarvis_rag/reranker.py +18 -18
- jarvis/jarvis_rag/retriever.py +29 -29
- jarvis/jarvis_tools/edit_file.py +11 -36
- jarvis/jarvis_utils/config.py +20 -25
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.222.dist-info}/METADATA +25 -20
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.222.dist-info}/RECORD +32 -27
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.222.dist-info}/entry_points.txt +9 -0
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.222.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.222.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.220.dist-info → jarvis_ai_assistant-0.1.222.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/cli.py
CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
|
|
4
4
|
from typing import Optional, List, Literal, cast
|
5
5
|
import mimetypes
|
6
6
|
|
7
|
+
import pathspec
|
7
8
|
import typer
|
8
9
|
from langchain.docstore.document import Document
|
9
10
|
from langchain_community.document_loaders import (
|
@@ -18,29 +19,29 @@ from jarvis.jarvis_utils.utils import init_env
|
|
18
19
|
|
19
20
|
def is_likely_text_file(file_path: Path) -> bool:
    """
    Heuristically decide whether ``file_path`` contains text.

    A MIME type guessed from the file name is trusted first (``text/*`` or
    anything mentioning json/xml/javascript counts as text). Otherwise only the
    first 4 KiB are read and checked for NUL bytes, so large binary files are
    never loaded into memory. Any error (e.g. the file cannot be opened)
    yields ``False``.
    """
    textual_markers = ("json", "xml", "javascript")
    try:
        guessed, _encoding = mimetypes.guess_type(file_path)
        if guessed:
            if guessed.startswith("text/") or any(
                marker in guessed for marker in textual_markers
            ):
                return True

        with open(file_path, "rb") as handle:
            head = handle.read(4096)  # only the first 4 KiB
        # A NUL byte is a strong indicator of binary content.
        return b"\x00" not in head
    except Exception:
        return False
|
40
41
|
|
41
42
|
|
42
|
-
#
|
43
|
-
#
|
43
|
+
# Compute the project root so it can be put on the Python path, allowing
# absolute imports (this lets the script be run as a module).
# NOTE(review): this is derived from this file's own location (three levels
# above jarvis/jarvis_rag/), i.e. where the package lives — it is not the
# directory the user is working in.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *([os.pardir] * 3))
)
|
@@ -54,13 +55,13 @@ from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline
|
|
54
55
|
|
55
56
|
# Top-level Typer application; the sub-commands below register onto it.
app = typer.Typer(
    add_completion=False,
    name="jarvis-rag",
    help="一个与Jarvis RAG框架交互的命令行工具。",
)
|
60
61
|
|
61
62
|
|
62
63
|
class _CustomPlatformLLM(LLMInterface):
|
63
|
-
"""
|
64
|
+
"""一个简单的包装器,使BasePlatform实例与LLMInterface兼容。"""
|
64
65
|
|
65
66
|
def __init__(self, platform: BasePlatform):
|
66
67
|
self.platform = platform
|
@@ -73,7 +74,7 @@ class _CustomPlatformLLM(LLMInterface):
|
|
73
74
|
|
74
75
|
|
75
76
|
def _create_custom_llm(platform_name: str, model_name: str) -> Optional[LLMInterface]:
|
76
|
-
"""
|
77
|
+
"""从指定的平台和模型创建LLM接口。"""
|
77
78
|
if not platform_name or not model_name:
|
78
79
|
return None
|
79
80
|
try:
|
@@ -90,36 +91,70 @@ def _create_custom_llm(platform_name: str, model_name: str) -> Optional[LLMInter
|
|
90
91
|
return None
|
91
92
|
|
92
93
|
|
94
|
+
def _load_ragignore_spec(
    root: Optional[Path] = None,
) -> tuple[Optional["pathspec.PathSpec"], Optional[Path]]:
    """
    Load gitignore-style ignore patterns for the project being indexed.

    Looks for ``.jarvis/rag/.ragignore`` under ``root`` first and falls back to
    ``.gitignore`` when it is absent. Patterns are parsed with pathspec's
    ``"gitwildmatch"`` (gitignore) syntax.

    Args:
        root: Directory whose ignore files should be used. Defaults to the
            current working directory, i.e. the project the user runs the
            command from. The module-level ``_project_root`` is deliberately
            not used: it is derived from this file's ``__file__`` and points at
            the package's location, not at the user's project.

    Returns:
        ``(spec, resolved_root)`` when an ignore file was loaded. The root is
        ``resolve()``-d so callers can safely compute
        ``file_path.resolve().relative_to(resolved_root)``. Returns
        ``(None, None)`` when no ignore file exists or it could not be read.
    """
    project_root_path = (Path(root) if root is not None else Path.cwd()).resolve()
    candidates = (
        project_root_path / ".jarvis" / "rag" / ".ragignore",
        project_root_path / ".gitignore",
    )
    # First existing candidate wins (.ragignore has priority over .gitignore).
    ignore_file_to_use = next((p for p in candidates if p.is_file()), None)
    if ignore_file_to_use is None:
        return None, None

    try:
        with open(ignore_file_to_use, "r", encoding="utf-8") as f:
            patterns = f.read().splitlines()
        spec = pathspec.PathSpec.from_lines("gitwildmatch", patterns)
    except Exception as e:
        # Best effort: a broken ignore file disables filtering, not the command.
        print(f"⚠️ 加载 {ignore_file_to_use.name} 文件失败: {e}")
        return None, None

    print(f"✅ 加载忽略规则: {ignore_file_to_use}")
    return spec, project_root_path
|
120
|
+
|
121
|
+
|
93
122
|
@app.command(
|
94
123
|
"add",
|
95
|
-
help="
|
124
|
+
help="从文件、目录或glob模式(例如 'src/**/*.py')添加文档。",
|
96
125
|
)
|
97
126
|
def add_documents(
|
98
127
|
paths: List[Path] = typer.Argument(
|
99
128
|
...,
|
100
|
-
help="
|
129
|
+
help="文件/目录路径或glob模式。支持Shell扩展。",
|
101
130
|
),
|
102
131
|
collection_name: str = typer.Option(
|
103
132
|
"jarvis_rag_collection",
|
104
133
|
"--collection",
|
105
134
|
"-c",
|
106
|
-
help="
|
135
|
+
help="向量数据库中集合的名称。",
|
107
136
|
),
|
108
|
-
|
137
|
+
embedding_model: Optional[str] = typer.Option(
|
109
138
|
None,
|
110
|
-
"--embedding-
|
139
|
+
"--embedding-model",
|
111
140
|
"-e",
|
112
|
-
help="
|
141
|
+
help="嵌入模型的名称。覆盖全局配置。",
|
113
142
|
),
|
114
143
|
db_path: Optional[Path] = typer.Option(
|
115
|
-
None, "--db-path", help="
|
144
|
+
None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
|
145
|
+
),
|
146
|
+
batch_size: int = typer.Option(
|
147
|
+
500,
|
148
|
+
"--batch-size",
|
149
|
+
"-b",
|
150
|
+
help="单个批次中要处理的文档数。",
|
116
151
|
),
|
117
152
|
):
|
118
|
-
"""
|
153
|
+
"""从不同来源向RAG知识库添加文档。"""
|
119
154
|
files_to_process = set()
|
120
155
|
|
121
156
|
for path_str in paths:
|
122
|
-
# Typer
|
157
|
+
# Typer的List[Path]可能不会扩展glob,所以我们手动处理
|
123
158
|
from glob import glob
|
124
159
|
|
125
160
|
expanded_paths = glob(str(path_str), recursive=True)
|
@@ -141,59 +176,96 @@ def add_documents(
|
|
141
176
|
print(f"⚠️ 跳过可能的二进制文件: {path}")
|
142
177
|
|
143
178
|
if not files_to_process:
|
144
|
-
print(
|
179
|
+
print("⚠️ 在指定路径中未找到任何文本文件。")
|
180
|
+
return
|
181
|
+
|
182
|
+
# 使用 .ragignore 过滤文件
|
183
|
+
ragignore_spec, ragignore_root = _load_ragignore_spec()
|
184
|
+
if ragignore_spec and ragignore_root:
|
185
|
+
initial_count = len(files_to_process)
|
186
|
+
retained_files = set()
|
187
|
+
for file_path in files_to_process:
|
188
|
+
try:
|
189
|
+
# 将文件路径解析为绝对路径以确保正确比较
|
190
|
+
resolved_path = file_path.resolve()
|
191
|
+
relative_path = str(resolved_path.relative_to(ragignore_root))
|
192
|
+
if not ragignore_spec.match_file(relative_path):
|
193
|
+
retained_files.add(file_path)
|
194
|
+
except ValueError:
|
195
|
+
# 文件不在项目根目录下,保留它
|
196
|
+
retained_files.add(file_path)
|
197
|
+
|
198
|
+
ignored_count = initial_count - len(retained_files)
|
199
|
+
if ignored_count > 0:
|
200
|
+
print(f"ℹ️ 根据 .ragignore 规则过滤掉 {ignored_count} 个文件。")
|
201
|
+
files_to_process = retained_files
|
202
|
+
|
203
|
+
if not files_to_process:
|
204
|
+
print("⚠️ 所有找到的文本文件都被忽略规则过滤掉了。")
|
145
205
|
return
|
146
206
|
|
147
207
|
print(f"✅ 发现 {len(files_to_process)} 个独立文件待处理。")
|
148
208
|
|
149
209
|
try:
|
150
210
|
pipeline = JarvisRAGPipeline(
|
151
|
-
|
152
|
-
Optional[Literal["performance", "accuracy"]], embedding_mode
|
153
|
-
),
|
211
|
+
embedding_model=embedding_model,
|
154
212
|
db_path=str(db_path) if db_path else None,
|
155
213
|
collection_name=collection_name,
|
156
214
|
)
|
157
215
|
|
158
|
-
|
216
|
+
docs_batch: List[Document] = []
|
217
|
+
total_docs_added = 0
|
159
218
|
loader: BaseLoader
|
160
|
-
|
219
|
+
|
220
|
+
sorted_files = sorted(list(files_to_process))
|
221
|
+
total_files = len(sorted_files)
|
222
|
+
|
223
|
+
for i, file_path in enumerate(sorted_files):
|
161
224
|
try:
|
162
225
|
if file_path.suffix.lower() == ".md":
|
163
226
|
loader = UnstructuredMarkdownLoader(str(file_path))
|
164
|
-
else: #
|
227
|
+
else: # 对.txt和所有代码文件默认使用TextLoader
|
165
228
|
loader = TextLoader(str(file_path), encoding="utf-8")
|
166
229
|
|
167
|
-
|
168
|
-
print(f"✅ 已加载: {file_path}")
|
230
|
+
docs_batch.extend(loader.load())
|
231
|
+
print(f"✅ 已加载: {file_path} (文件 {i + 1}/{total_files})")
|
169
232
|
except Exception as e:
|
170
233
|
print(f"⚠️ 加载失败 {file_path}: {e}")
|
171
234
|
|
172
|
-
|
235
|
+
# 当批处理已满或是最后一个文件时处理批处理
|
236
|
+
if docs_batch and (len(docs_batch) >= batch_size or (i + 1) == total_files):
|
237
|
+
print(f"⚙️ 正在处理批次,包含 {len(docs_batch)} 个文档...")
|
238
|
+
pipeline.add_documents(docs_batch)
|
239
|
+
total_docs_added += len(docs_batch)
|
240
|
+
print(f"✅ 成功添加 {len(docs_batch)} 个文档。")
|
241
|
+
docs_batch = [] # 清空批处理
|
242
|
+
|
243
|
+
if total_docs_added == 0:
|
173
244
|
print("❌ 未能成功加载任何文档。")
|
174
245
|
raise typer.Exit(code=1)
|
175
246
|
|
176
|
-
|
177
|
-
|
247
|
+
print(
|
248
|
+
f"✅ 成功将 {total_docs_added} 个文档的内容添加至集合 '{collection_name}'。"
|
249
|
+
)
|
178
250
|
|
179
251
|
except Exception as e:
|
180
252
|
print(f"❌ 发生严重错误: {e}")
|
181
253
|
raise typer.Exit(code=1)
|
182
254
|
|
183
255
|
|
184
|
-
@app.command("list-docs", help="
|
256
|
+
@app.command("list-docs", help="列出知识库中所有唯一的文档。")
|
185
257
|
def list_documents(
|
186
258
|
collection_name: str = typer.Option(
|
187
259
|
"jarvis_rag_collection",
|
188
260
|
"--collection",
|
189
261
|
"-c",
|
190
|
-
help="
|
262
|
+
help="向量数据库中集合的名称。",
|
191
263
|
),
|
192
264
|
db_path: Optional[Path] = typer.Option(
|
193
|
-
None, "--db-path", help="
|
265
|
+
None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
|
194
266
|
),
|
195
267
|
):
|
196
|
-
"""
|
268
|
+
"""列出指定集合中的所有唯一文档。"""
|
197
269
|
try:
|
198
270
|
pipeline = JarvisRAGPipeline(
|
199
271
|
db_path=str(db_path) if db_path else None,
|
@@ -201,13 +273,13 @@ def list_documents(
|
|
201
273
|
)
|
202
274
|
|
203
275
|
collection = pipeline.retriever.collection
|
204
|
-
results = collection.get() #
|
276
|
+
results = collection.get() # 获取集合中的所有项目
|
205
277
|
|
206
278
|
if not results or not results["metadatas"]:
|
207
279
|
print("ℹ️ 知识库中没有找到任何文档。")
|
208
280
|
return
|
209
281
|
|
210
|
-
#
|
282
|
+
# 从元数据中提取唯一的源文件路径
|
211
283
|
sources = set()
|
212
284
|
for metadata in results["metadatas"]:
|
213
285
|
if metadata:
|
@@ -228,38 +300,38 @@ def list_documents(
|
|
228
300
|
raise typer.Exit(code=1)
|
229
301
|
|
230
302
|
|
231
|
-
@app.command("query", help="
|
303
|
+
@app.command("query", help="向知识库提问。")
|
232
304
|
def query(
|
233
|
-
question: str = typer.Argument(..., help="
|
305
|
+
question: str = typer.Argument(..., help="要提出的问题。"),
|
234
306
|
collection_name: str = typer.Option(
|
235
307
|
"jarvis_rag_collection",
|
236
308
|
"--collection",
|
237
309
|
"-c",
|
238
|
-
help="
|
310
|
+
help="向量数据库中集合的名称。",
|
239
311
|
),
|
240
|
-
|
312
|
+
embedding_model: Optional[str] = typer.Option(
|
241
313
|
None,
|
242
|
-
"--embedding-
|
314
|
+
"--embedding-model",
|
243
315
|
"-e",
|
244
|
-
help="
|
316
|
+
help="嵌入模型的名称。覆盖全局配置。",
|
245
317
|
),
|
246
318
|
db_path: Optional[Path] = typer.Option(
|
247
|
-
None, "--db-path", help="
|
319
|
+
None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
|
248
320
|
),
|
249
321
|
platform: Optional[str] = typer.Option(
|
250
322
|
None,
|
251
323
|
"--platform",
|
252
324
|
"-p",
|
253
|
-
help="
|
325
|
+
help="为LLM指定平台名称。覆盖默认的思考模型。",
|
254
326
|
),
|
255
327
|
model: Optional[str] = typer.Option(
|
256
328
|
None,
|
257
329
|
"--model",
|
258
330
|
"-m",
|
259
|
-
help="
|
331
|
+
help="为LLM指定模型名称。需要 --platform。",
|
260
332
|
),
|
261
333
|
):
|
262
|
-
"""
|
334
|
+
"""查询RAG知识库并打印答案。"""
|
263
335
|
if model and not platform:
|
264
336
|
print("❌ 错误: --model 需要指定 --platform。")
|
265
337
|
raise typer.Exit(code=1)
|
@@ -271,9 +343,7 @@ def query(
|
|
271
343
|
|
272
344
|
pipeline = JarvisRAGPipeline(
|
273
345
|
llm=custom_llm,
|
274
|
-
|
275
|
-
Optional[Literal["performance", "accuracy"]], embedding_mode
|
276
|
-
),
|
346
|
+
embedding_model=embedding_model,
|
277
347
|
db_path=str(db_path) if db_path else None,
|
278
348
|
collection_name=collection_name,
|
279
349
|
)
|
@@ -282,7 +352,7 @@ def query(
|
|
282
352
|
answer = pipeline.query(question)
|
283
353
|
|
284
354
|
print("💬 答案:")
|
285
|
-
#
|
355
|
+
# 我们仍然可以使用 rich.markdown.Markdown,因为 PrettyOutput 底层使用了 rich
|
286
356
|
from jarvis.jarvis_utils.globals import console
|
287
357
|
|
288
358
|
console.print(Markdown(answer))
|
@@ -292,6 +362,25 @@ def query(
|
|
292
362
|
raise typer.Exit(code=1)
|
293
363
|
|
294
364
|
|
365
|
+
# Record whether the optional RAG extras (probed via `langchain`) are importable.
try:
    import langchain  # noqa
except ImportError:
    _RAG_INSTALLED = False
else:
    _RAG_INSTALLED = True
|
372
|
+
|
373
|
+
|
374
|
+
def _check_rag_dependencies():
    """
    Abort the CLI with exit code 1 if the optional RAG dependencies are missing.

    NOTE(review): this module already imports ``langchain.docstore.document``
    (and other RAG packages) at top level, so when langchain is absent the
    import of this module fails before this check can run. For the friendly
    message to be reachable, those imports would need to be made lazy.
    """
    if _RAG_INSTALLED:
        return
    print(
        "❌ RAG依赖项未安装。"
        "请运行 'pip install \"jarvis-ai-assistant[rag]\"' 来使用此命令。"
    )
    raise typer.Exit(code=1)
|
381
|
+
|
382
|
+
|
295
383
|
def main():
    """Run the ``jarvis-rag`` Typer app after checking dependencies and init."""
    _check_rag_dependencies()  # exits with code 1 if RAG extras are missing
    init_env(welcome_str="Jarvis RAG")
    app()
|
@@ -1,59 +1,45 @@
|
|
1
|
-
|
1
|
+
import torch
|
2
|
+
from typing import List, cast
|
2
3
|
from langchain_huggingface import HuggingFaceEmbeddings
|
3
4
|
|
4
|
-
from jarvis.jarvis_utils.config import (
|
5
|
-
get_rag_embedding_models,
|
6
|
-
get_rag_embedding_cache_path,
|
7
|
-
)
|
8
5
|
from .cache import EmbeddingCache
|
9
6
|
|
10
7
|
|
11
8
|
class EmbeddingManager:
|
12
9
|
"""
|
13
|
-
|
10
|
+
管理本地嵌入模型的加载和使用,并带有缓存功能。
|
14
11
|
|
15
|
-
|
16
|
-
|
17
|
-
and uses a disk-based cache to avoid re-computing embeddings for the
|
18
|
-
same text.
|
12
|
+
该类负责从Hugging Face加载指定的模型,并使用基于磁盘的缓存
|
13
|
+
来避免为相同文本重新计算嵌入。
|
19
14
|
"""
|
20
15
|
|
21
|
-
def __init__(self, model_name: str, cache_dir: str):
    """
    Set up the embedding manager for a Hugging Face model with a disk cache.

    Args:
        model_name: Name of the Hugging Face model to load.
        cache_dir: Directory used to store cached embeddings.
    """
    self.model_name = model_name
    print(f"🚀 初始化嵌入管理器, 模型: '{self.model_name}'...")

    # The model name doubles as the cache salt so embeddings produced by
    # different models do not collide in the same cache directory.
    self.cache = EmbeddingCache(salt=self.model_name, cache_dir=cache_dir)
    self.model = self._load_model()
|
48
31
|
|
49
32
|
def _load_model(self) -> HuggingFaceEmbeddings:
|
50
|
-
"""
|
33
|
+
"""根据配置加载Hugging Face嵌入模型。"""
|
34
|
+
model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
|
35
|
+
encode_kwargs = {"normalize_embeddings": True}
|
36
|
+
|
51
37
|
try:
|
52
38
|
return HuggingFaceEmbeddings(
|
53
39
|
model_name=self.model_name,
|
54
|
-
model_kwargs=
|
55
|
-
encode_kwargs=
|
56
|
-
show_progress=
|
40
|
+
model_kwargs=model_kwargs,
|
41
|
+
encode_kwargs=encode_kwargs,
|
42
|
+
show_progress=True,
|
57
43
|
)
|
58
44
|
except Exception as e:
|
59
45
|
print(f"❌ 加载嵌入模型 '{self.model_name}' 时出错: {e}")
|
@@ -62,18 +48,18 @@ class EmbeddingManager:
|
|
62
48
|
|
63
49
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
64
50
|
"""
|
65
|
-
|
51
|
+
使用缓存为文档列表计算嵌入。
|
66
52
|
|
67
|
-
|
68
|
-
texts:
|
53
|
+
参数:
|
54
|
+
texts: 要嵌入的文档(字符串)列表。
|
69
55
|
|
70
|
-
|
71
|
-
|
56
|
+
返回:
|
57
|
+
一个嵌入列表,每个文档对应一个嵌入。
|
72
58
|
"""
|
73
59
|
if not texts:
|
74
60
|
return []
|
75
61
|
|
76
|
-
#
|
62
|
+
# 检查缓存中是否已存在嵌入
|
77
63
|
cached_embeddings = self.cache.get_batch(texts)
|
78
64
|
|
79
65
|
texts_to_embed = []
|
@@ -83,17 +69,17 @@ class EmbeddingManager:
|
|
83
69
|
texts_to_embed.append(text)
|
84
70
|
indices_to_embed.append(i)
|
85
71
|
|
86
|
-
#
|
72
|
+
# 为不在缓存中的文本计算嵌入
|
87
73
|
if texts_to_embed:
|
88
74
|
print(
|
89
75
|
f"🔎 缓存未命中。正在为 {len(texts_to_embed)}/{len(texts)} 个文档计算嵌入。"
|
90
76
|
)
|
91
77
|
new_embeddings = self.model.embed_documents(texts_to_embed)
|
92
78
|
|
93
|
-
#
|
79
|
+
# 将新的嵌入存储在缓存中
|
94
80
|
self.cache.set_batch(texts_to_embed, new_embeddings)
|
95
81
|
|
96
|
-
#
|
82
|
+
# 将新的嵌入放回结果列表中
|
97
83
|
for i, embedding in zip(indices_to_embed, new_embeddings):
|
98
84
|
cached_embeddings[i] = embedding
|
99
85
|
else:
|
@@ -103,7 +89,7 @@ class EmbeddingManager:
|
|
103
89
|
|
104
90
|
def embed_query(self, text: str) -> List[float]:
    """
    Embed a single query string.

    Unlike ``embed_documents``, queries bypass the embedding cache and go
    straight to the underlying model; caching could be added if needed.
    """
    query_embedding = self.model.embed_query(text)
    return query_embedding
|
@@ -10,42 +10,40 @@ from jarvis.jarvis_platform.registry import PlatformRegistry
|
|
10
10
|
|
11
11
|
class LLMInterface(ABC):
    """
    Abstract base class for large-language-model backends.

    Defines the standard interface for talking to an LLM; each concrete
    provider (e.g. OpenAI, Anthropic) is expected to subclass it and
    implement ``generate``.
    """

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """
        Produce the model's reply to ``prompt``.

        Args:
            prompt: Input prompt sent to the LLM.
            **kwargs: Extra keyword arguments for the LLM API call
                (for example ``temperature`` or ``max_tokens``).

        Returns:
            The text response generated by the LLM.
        """
        ...
|
34
33
|
|
35
34
|
|
36
35
|
class ToolAgent_LLM(LLMInterface):
|
37
36
|
"""
|
38
|
-
|
39
|
-
to generate the final response.
|
37
|
+
LLMInterface的一个实现,它使用一个能操作工具的JarvisAgent来生成最终响应。
|
40
38
|
"""
|
41
39
|
|
42
40
|
def __init__(self):
|
43
41
|
"""
|
44
|
-
|
42
|
+
初始化工具-代理 LLM 包装器。
|
45
43
|
"""
|
46
44
|
print("🤖 已初始化工具 Agent 作为最终应答者。")
|
47
45
|
self.allowed_tools = ["read_code", "execute_script"]
|
48
|
-
#
|
46
|
+
# 为代理提供一个通用的系统提示
|
49
47
|
self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
|
50
48
|
self.summary_prompt = """
|
51
49
|
<report>
|
@@ -59,17 +57,17 @@ class ToolAgent_LLM(LLMInterface):
|
|
59
57
|
|
60
58
|
def generate(self, prompt: str, **kwargs) -> str:
|
61
59
|
"""
|
62
|
-
|
60
|
+
使用受限的工具集运行JarvisAgent以生成答案。
|
63
61
|
|
64
|
-
|
65
|
-
prompt:
|
66
|
-
**kwargs:
|
62
|
+
参数:
|
63
|
+
prompt: 要发送给代理的完整提示,包括上下文。
|
64
|
+
**kwargs: 已忽略,为保持接口兼容性而保留。
|
67
65
|
|
68
|
-
|
69
|
-
|
66
|
+
返回:
|
67
|
+
由代理生成的最终答案。
|
70
68
|
"""
|
71
69
|
try:
|
72
|
-
#
|
70
|
+
# 使用RAG上下文的特定设置初始化代理
|
73
71
|
agent = JarvisAgent(
|
74
72
|
system_prompt=self.system_prompt,
|
75
73
|
use_tools=self.allowed_tools,
|
@@ -80,7 +78,7 @@ class ToolAgent_LLM(LLMInterface):
|
|
80
78
|
summary_prompt=self.summary_prompt,
|
81
79
|
)
|
82
80
|
|
83
|
-
#
|
81
|
+
# 代理的run方法需要'user_input'参数
|
84
82
|
final_answer = agent.run(user_input=prompt)
|
85
83
|
return str(final_answer)
|
86
84
|
|
@@ -91,21 +89,21 @@ class ToolAgent_LLM(LLMInterface):
|
|
91
89
|
|
92
90
|
class JarvisPlatform_LLM(LLMInterface):
|
93
91
|
"""
|
94
|
-
|
92
|
+
项目内部平台的LLMInterface实现。
|
95
93
|
|
96
|
-
|
94
|
+
该类使用PlatformRegistry来获取配置的“普通”模型。
|
97
95
|
"""
|
98
96
|
|
99
97
|
def __init__(self):
|
100
98
|
"""
|
101
|
-
|
99
|
+
初始化Jarvis平台LLM客户端。
|
102
100
|
"""
|
103
101
|
try:
|
104
102
|
self.registry = PlatformRegistry.get_global_platform_registry()
|
105
103
|
self.platform: BasePlatform = self.registry.get_normal_platform()
|
106
104
|
self.platform.set_suppress_output(
|
107
105
|
False
|
108
|
-
) #
|
106
|
+
) # 确保模型没有控制台输出
|
109
107
|
print(f"🚀 已初始化 Jarvis 平台 LLM,模型: {self.platform.name()}")
|
110
108
|
except Exception as e:
|
111
109
|
print(f"❌ 初始化 Jarvis 平台 LLM 失败: {e}")
|
@@ -113,17 +111,17 @@ class JarvisPlatform_LLM(LLMInterface):
|
|
113
111
|
|
114
112
|
def generate(self, prompt: str, **kwargs) -> str:
|
115
113
|
"""
|
116
|
-
|
114
|
+
向本地平台模型发送提示并返回响应。
|
117
115
|
|
118
|
-
|
119
|
-
prompt:
|
120
|
-
**kwargs:
|
116
|
+
参数:
|
117
|
+
prompt: 用户的提示。
|
118
|
+
**kwargs: 已忽略,为保持接口兼容性而保留。
|
121
119
|
|
122
|
-
|
123
|
-
|
120
|
+
返回:
|
121
|
+
由平台模型生成的响应。
|
124
122
|
"""
|
125
123
|
try:
|
126
|
-
#
|
124
|
+
# 使用健壮的chat_until_success方法
|
127
125
|
return self.platform.chat_until_success(prompt)
|
128
126
|
except Exception as e:
|
129
127
|
print(f"❌ 调用 Jarvis 平台模型时发生错误: {e}")
|