jarvis-ai-assistant 0.2.2-py3-none-any.whl → 0.2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/edit_file_handler.py +5 -0
- jarvis/jarvis_agent/jarvis.py +22 -25
- jarvis/jarvis_agent/main.py +6 -6
- jarvis/jarvis_agent/prompts.py +26 -4
- jarvis/jarvis_code_agent/code_agent.py +279 -11
- jarvis/jarvis_code_analysis/code_review.py +21 -19
- jarvis/jarvis_data/config_schema.json +86 -18
- jarvis/jarvis_git_squash/main.py +3 -3
- jarvis/jarvis_git_utils/git_commiter.py +32 -11
- jarvis/jarvis_mcp/sse_mcp_client.py +4 -6
- jarvis/jarvis_mcp/streamable_mcp_client.py +5 -9
- jarvis/jarvis_platform/tongyi.py +9 -9
- jarvis/jarvis_rag/cli.py +79 -23
- jarvis/jarvis_rag/query_rewriter.py +61 -12
- jarvis/jarvis_rag/rag_pipeline.py +143 -34
- jarvis/jarvis_rag/retriever.py +6 -6
- jarvis/jarvis_smart_shell/main.py +2 -2
- jarvis/jarvis_stats/__init__.py +13 -0
- jarvis/jarvis_stats/cli.py +337 -0
- jarvis/jarvis_stats/stats.py +433 -0
- jarvis/jarvis_stats/storage.py +329 -0
- jarvis/jarvis_stats/visualizer.py +443 -0
- jarvis/jarvis_tools/cli/main.py +84 -15
- jarvis/jarvis_tools/generate_new_tool.py +22 -1
- jarvis/jarvis_tools/registry.py +35 -16
- jarvis/jarvis_tools/search_web.py +3 -3
- jarvis/jarvis_tools/virtual_tty.py +315 -26
- jarvis/jarvis_utils/config.py +98 -11
- jarvis/jarvis_utils/git_utils.py +8 -16
- jarvis/jarvis_utils/globals.py +29 -8
- jarvis/jarvis_utils/input.py +114 -121
- jarvis/jarvis_utils/utils.py +213 -37
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.4.dist-info}/METADATA +99 -9
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.4.dist-info}/RECORD +39 -34
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.4.dist-info}/entry_points.txt +2 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.4.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.4.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/cli.py
CHANGED
@@ -1,7 +1,7 @@
 import os
 import sys
 from pathlib import Path
-from typing import Optional, List, Literal, cast
+from typing import Optional, List, Literal, cast, Tuple
 import mimetypes
 
 import pathspec  # type: ignore
@@ -15,6 +15,11 @@ from langchain_core.document_loaders.base import BaseLoader
 from rich.markdown import Markdown
 
 from jarvis.jarvis_utils.utils import init_env
+from jarvis.jarvis_utils.config import (
+    get_rag_embedding_model,
+    get_rag_use_bm25,
+    get_rag_use_rerank,
+)
 
 
 def is_likely_text_file(file_path: Path) -> bool:
@@ -65,9 +70,7 @@ class _CustomPlatformLLM(LLMInterface):
 
     def __init__(self, platform: BasePlatform):
         self.platform = platform
-        print(
-            f"✅ 使用自定义LLM: 平台='{platform.platform_name()}', 模型='{platform.name()}'"
-        )
+        print(f"✅ 使用自定义LLM: 平台='{platform.platform_name()}', 模型='{platform.name()}'")
 
     def generate(self, prompt: str, **kwargs) -> str:
         return self.platform.chat_until_success(prompt)
@@ -91,7 +94,7 @@ def _create_custom_llm(platform_name: str, model_name: str) -> Optional[LLMInter
         return None
 
 
-def _load_ragignore_spec() ->
+def _load_ragignore_spec() -> Tuple[Optional[pathspec.PathSpec], Optional[Path]]:
     """
     从项目根目录加载忽略模式。
     首先查找 `.jarvis/rag/.ragignore`,如果未找到,则回退到 `.gitignore`。
@@ -140,9 +143,7 @@ def add_documents(
         "-e",
         help="嵌入模型的名称。覆盖全局配置。",
     ),
-    db_path: Optional[Path] = typer.Option(
-        None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
-    ),
+    db_path: Optional[Path] = typer.Option(None, "--db-path", help="向量数据库的路径。覆盖全局配置。"),
     batch_size: int = typer.Option(
         500,
         "--batch-size",
@@ -244,9 +245,7 @@ def add_documents(
             print("❌ 未能成功加载任何文档。")
             raise typer.Exit(code=1)
 
-        print(
-            f"✅ 成功将 {total_docs_added} 个文档的内容添加至集合 '{collection_name}'。"
-        )
+        print(f"✅ 成功将 {total_docs_added} 个文档的内容添加至集合 '{collection_name}'。")
 
     except Exception as e:
         print(f"❌ 发生严重错误: {e}")
@@ -261,9 +260,7 @@ def list_documents(
         "-c",
         help="向量数据库中集合的名称。",
     ),
-    db_path: Optional[Path] = typer.Option(
-        None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
-    ),
+    db_path: Optional[Path] = typer.Option(None, "--db-path", help="向量数据库的路径。覆盖全局配置。"),
 ):
     """列出指定集合中的所有唯一文档。"""
     try:
@@ -272,7 +269,7 @@ def list_documents(
             collection_name=collection_name,
         )
 
-        collection = pipeline.
+        collection = pipeline._get_collection()
         results = collection.get()  # 获取集合中的所有项目
 
         if not results or not results["metadatas"]:
@@ -300,6 +297,63 @@ def list_documents(
         raise typer.Exit(code=1)
 
 
+@app.command("retrieve", help="仅从知识库检索相关文档,不生成答案。")
+def retrieve(
+    question: str = typer.Argument(..., help="要提出的问题。"),
+    collection_name: str = typer.Option(
+        "jarvis_rag_collection",
+        "--collection",
+        "-c",
+        help="向量数据库中集合的名称。",
+    ),
+    embedding_model: Optional[str] = typer.Option(
+        None,
+        "--embedding-model",
+        "-e",
+        help="嵌入模型的名称。覆盖全局配置。",
+    ),
+    db_path: Optional[Path] = typer.Option(None, "--db-path", help="向量数据库的路径。覆盖全局配置。"),
+    n_results: int = typer.Option(5, "--top-n", help="要检索的文档数量。"),
+):
+    """仅从RAG知识库检索文档并打印结果。"""
+    try:
+        # 如果未在命令行中指定,则从配置中加载RAG设置
+        final_embedding_model = embedding_model or get_rag_embedding_model()
+        use_bm25 = get_rag_use_bm25()
+        use_rerank = get_rag_use_rerank()
+
+        pipeline = JarvisRAGPipeline(
+            embedding_model=final_embedding_model,
+            db_path=str(db_path) if db_path else None,
+            collection_name=collection_name,
+            use_bm25=use_bm25,
+            use_rerank=use_rerank,
+        )
+
+        print(f"🤔 正在为问题检索文档: '{question}'")
+        retrieved_docs = pipeline.retrieve_only(question, n_results=n_results)
+
+        if not retrieved_docs:
+            print("ℹ️ 未找到相关文档。")
+            return
+
+        print(f"✅ 成功检索到 {len(retrieved_docs)} 个文档:")
+        from jarvis.jarvis_utils.globals import console
+
+        for i, doc in enumerate(retrieved_docs, 1):
+            source = doc.metadata.get("source", "未知来源")
+            content = doc.page_content
+            panel_title = f"文档 {i} | 来源: {source}"
+            console.print(
+                f"\n[bold magenta]{panel_title}[/bold magenta]"
+            )
+            console.print(Markdown(f"```\n{content}\n```"))
+
+    except Exception as e:
+        print(f"❌ 发生错误: {e}")
+        raise typer.Exit(code=1)
+
+
 @app.command("query", help="向知识库提问。")
 def query(
     question: str = typer.Argument(..., help="要提出的问题。"),
@@ -315,9 +369,7 @@ def query(
         "-e",
         help="嵌入模型的名称。覆盖全局配置。",
     ),
-    db_path: Optional[Path] = typer.Option(
-        None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
-    ),
+    db_path: Optional[Path] = typer.Option(None, "--db-path", help="向量数据库的路径。覆盖全局配置。"),
     platform: Optional[str] = typer.Option(
         None,
         "--platform",
@@ -341,11 +393,18 @@ def query(
     if (platform or model) and not custom_llm:
         raise typer.Exit(code=1)
 
+    # 如果未在命令行中指定,则从配置中加载RAG设置
+    final_embedding_model = embedding_model or get_rag_embedding_model()
+    use_bm25 = get_rag_use_bm25()
+    use_rerank = get_rag_use_rerank()
+
     pipeline = JarvisRAGPipeline(
         llm=custom_llm,
-        embedding_model=
+        embedding_model=final_embedding_model,
        db_path=str(db_path) if db_path else None,
        collection_name=collection_name,
+        use_bm25=use_bm25,
+        use_rerank=use_rerank,
     )
 
     print(f"🤔 正在查询: '{question}'")
@@ -373,10 +432,7 @@ except ImportError:
 
 def _check_rag_dependencies():
     if not _RAG_INSTALLED:
-        print(
-            "❌ RAG依赖项未安装。"
-            "请运行 'pip install \"jarvis-ai-assistant[rag]\"' 来使用此命令。"
-        )
+        print("❌ RAG依赖项未安装。" "请运行 'pip install \"jarvis-ai-assistant[rag]\"' 来使用此命令。")
         raise typer.Exit(code=1)
 
 
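The new retrieve subcommand above can be exercised in-process with Typer's test runner. A minimal sketch, assuming the [rag] extra is installed and that app is the module-level typer.Typer instance the @app.command decorators attach to:

from typer.testing import CliRunner

from jarvis.jarvis_rag.cli import app

runner = CliRunner()
# Retrieve the top 3 documents for a question without generating an answer.
result = runner.invoke(app, ["retrieve", "如何配置RAG检索?", "--top-n", "3"])
print(result.exit_code)
print(result.output)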
jarvis/jarvis_rag/query_rewriter.py
CHANGED
@@ -1,5 +1,6 @@
 from typing import List
 from .llm_interface import LLMInterface
+from jarvis.jarvis_utils.output import PrettyOutput, OutputType
 
 
 class QueryRewriter:
@@ -20,20 +21,29 @@ class QueryRewriter:
     def _create_prompt_template(self) -> str:
         """为多查询重写任务创建提示模板。"""
         return """
-
+你是一个精通检索和语言的AI助手。你的任务是将以下这个单一的用户问题,改写为几个语义相关但表达方式不同的搜索查询,并提供英文翻译。这有助于在多语言知识库中进行更全面的搜索。
 
 请遵循以下原则:
-1.
-2.
-
-
+1. **保留核心意图**: 所有查询都必须围绕原始问题的核心意图。
+2. **查询类型**:
+  - **同义词/相关术语查询**: 使用原始语言,通过替换同义词或相关术语来生成1-2个新的查询。
+  - **英文翻译查询**: 将原始问题翻译成一个简洁的英文搜索查询。
+3. **简洁性**: 每个查询都应该是独立的、可以直接用于搜索的短语或问题。
+4. **严格格式要求**: 你必须将所有重写后的查询放置在 `<REWRITE>` 和 `</REWRITE>` 标签之间。每个查询占一行。不要在标签内外添加任何编号、前缀或解释。
+
+示例输出格式:
+<REWRITE>
+使用不同表述的中文查询
+另一个中文查询
+English version of the query
+</REWRITE>
 
 原始问题:
 ---
 {query}
 ---
 
-
+请将改写后的查询包裹在 `<REWRITE>` 标签内:
 """
 
     def rewrite(self, query: str) -> List[str]:
@@ -47,16 +57,55 @@ class QueryRewriter:
             一个经过重写、搜索优化的查询列表。
         """
         prompt = self.rewrite_prompt_template.format(query=query)
-        print(
+        PrettyOutput.print(
+            "正在将原始查询重写为多个搜索查询...", output_type=OutputType.INFO, timestamp=False
+        )
+
+        import re
+
+        max_retries = 3
+        attempts = 0
+        rewritten_queries = []
+        response_text = ""
+
+        while attempts < max_retries:
+            attempts += 1
+            response_text = self.llm.generate(prompt)
+            match = re.search(r"<REWRITE>(.*?)</REWRITE>", response_text, re.DOTALL)
+
+            if match:
+                content = match.group(1).strip()
+                rewritten_queries = [
+                    line.strip() for line in content.split("\n") if line.strip()
+                ]
+                PrettyOutput.print(
+                    f"成功从LLM响应中提取到内容 (尝试 {attempts}/{max_retries})。",
+                    output_type=OutputType.SUCCESS,
+                    timestamp=False,
+                )
+                break  # 提取成功,退出循环
+            else:
+                PrettyOutput.print(
+                    f"未能从LLM响应中提取内容。正在重试... ({attempts}/{max_retries})",
+                    output_type=OutputType.WARNING,
+                    timestamp=False,
+                )
 
-
-
-
-
+        # 如果所有重试都失败,则跳过重写步骤
+        if not rewritten_queries:
+            PrettyOutput.print(
+                "所有重试均失败。跳过查询重写,将仅使用原始查询。",
+                output_type=OutputType.ERROR,
+                timestamp=False,
+            )
 
         # 同时包含原始查询以保证鲁棒性
         if query not in rewritten_queries:
             rewritten_queries.insert(0, query)
 
-        print(
+        PrettyOutput.print(
+            f"生成了 {len(rewritten_queries)} 个查询变体。",
+            output_type=OutputType.SUCCESS,
+            timestamp=False,
+        )
         return rewritten_queries
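The retry loop introduced above reduces to a small, reusable pattern: ask the LLM for output wrapped in <REWRITE> tags, extract the tag body with a DOTALL regex, and give up after a few attempts. A minimal self-contained sketch (the function name and the stubbed LLM below are illustrative, not part of the package):

import re
from typing import Callable, List

def extract_rewrites(generate: Callable[[str], str], prompt: str, max_retries: int = 3) -> List[str]:
    """Call the model up to max_retries times and pull one query per line out of <REWRITE>...</REWRITE>."""
    for _ in range(max_retries):
        response = generate(prompt)
        match = re.search(r"<REWRITE>(.*?)</REWRITE>", response, re.DOTALL)
        if match:
            return [line.strip() for line in match.group(1).splitlines() if line.strip()]
    return []  # caller falls back to the original query

# Stubbed "LLM" for demonstration:
fake_llm = lambda _prompt: "<REWRITE>\n查询变体一\nEnglish variant\n</REWRITE>"
print(extract_rewrites(fake_llm, "ignored"))  # ['查询变体一', 'English variant']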
jarvis/jarvis_rag/rag_pipeline.py
CHANGED
@@ -30,6 +30,8 @@ class JarvisRAGPipeline:
         embedding_model: Optional[str] = None,
         db_path: Optional[str] = None,
         collection_name: str = "jarvis_rag_collection",
+        use_bm25: bool = True,
+        use_rerank: bool = True,
     ):
         """
         初始化RAG管道。
@@ -40,6 +42,8 @@ class JarvisRAGPipeline:
             embedding_model: 嵌入模型的名称。如果为None,则使用配置值。
             db_path: 持久化向量数据库的路径。如果为None,则使用配置值。
             collection_name: 向量数据库中集合的名称。
+            use_bm25: 是否在检索中使用BM25。
+            use_rerank: 是否在检索后使用重排器。
         """
         # 确定嵌入模型以隔离数据路径
         model_name = embedding_model or get_rag_embedding_model()
@@ -56,22 +60,87 @@ class JarvisRAGPipeline:
             get_rag_embedding_cache_path(), sanitized_model_name
         )
 
-
-
-
+        # 存储初始化参数以供延迟加载
+        self.llm = llm if llm is not None else ToolAgent_LLM()
+        self.embedding_model_name = embedding_model or get_rag_embedding_model()
+        self.db_path = db_path
+        self.collection_name = collection_name
+        self.use_bm25 = use_bm25
+        self.use_rerank = use_rerank
+
+        # 延迟加载的组件
+        self._embedding_manager: Optional[EmbeddingManager] = None
+        self._retriever: Optional[ChromaRetriever] = None
+        self._reranker: Optional[Reranker] = None
+        self._query_rewriter: Optional[QueryRewriter] = None
+
+        print("✅ JarvisRAGPipeline 初始化成功 (模型按需加载).")
+
+    def _get_embedding_manager(self) -> EmbeddingManager:
+        if self._embedding_manager is None:
+            sanitized_model_name = self.embedding_model_name.replace("/", "_").replace(
+                "\\", "_"
+            )
+            _final_cache_path = os.path.join(
+                get_rag_embedding_cache_path(), sanitized_model_name
+            )
+            self._embedding_manager = EmbeddingManager(
+                model_name=self.embedding_model_name,
+                cache_dir=_final_cache_path,
+            )
+        return self._embedding_manager
+
+    def _get_retriever(self) -> ChromaRetriever:
+        if self._retriever is None:
+            sanitized_model_name = self.embedding_model_name.replace("/", "_").replace(
+                "\\", "_"
+            )
+            _final_db_path = (
+                str(self.db_path)
+                if self.db_path
+                else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
+            )
+            self._retriever = ChromaRetriever(
+                embedding_manager=self._get_embedding_manager(),
+                db_path=_final_db_path,
+                collection_name=self.collection_name,
+            )
+        return self._retriever
+
+    def _get_collection(self):
+        """
+        在不加载嵌入模型的情况下,直接获取并返回Chroma集合对象。
+        这对于仅需要访问集合元数据(如列出文档)而无需嵌入功能的操作非常有用。
+        """
+        # 为了避免初始化embedding_manager,我们直接构建db_path
+        if self._retriever:
+            return self._retriever.collection
+
+        sanitized_model_name = self.embedding_model_name.replace("/", "_").replace(
+            "\\", "_"
         )
-
-
-        db_path
-
+        _final_db_path = (
+            str(self.db_path)
+            if self.db_path
+            else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
         )
-        # 除非提供了特定的LLM,否则默认为ToolAgent_LLM
-        self.llm = llm if llm is not None else ToolAgent_LLM()
-        self.reranker = Reranker(model_name=get_rag_rerank_model())
-        # 使用标准LLM执行查询重写任务,而不是代理
-        self.query_rewriter = QueryRewriter(JarvisPlatform_LLM())
 
-
+        # 直接创建ChromaRetriever所使用的chroma_client,但绕过embedding_manager
+        import chromadb
+
+        chroma_client = chromadb.PersistentClient(path=_final_db_path)
+        return chroma_client.get_collection(name=self.collection_name)
+
+    def _get_reranker(self) -> Reranker:
+        if self._reranker is None:
+            self._reranker = Reranker(model_name=get_rag_rerank_model())
+        return self._reranker
+
+    def _get_query_rewriter(self) -> QueryRewriter:
+        if self._query_rewriter is None:
+            # 使用标准LLM执行查询重写任务,而不是代理
+            self._query_rewriter = QueryRewriter(JarvisPlatform_LLM())
+        return self._query_rewriter
 
     def add_documents(self, documents: List[Document]):
         """
@@ -80,24 +149,21 @@ class JarvisRAGPipeline:
         参数:
             documents: 要添加的LangChain文档对象列表。
         """
-        self.
+        self._get_retriever().add_documents(documents)
 
-    def _create_prompt(
-        self, query: str, context_docs: List[Document], source_files: List[str]
-    ) -> str:
+    def _create_prompt(self, query: str, context_docs: List[Document]) -> str:
         """为LLM或代理创建最终的提示。"""
-
-
+        context_details = []
+        for doc in context_docs:
+            source = doc.metadata.get("source", "未知来源")
+            content = doc.page_content
+            context_details.append(f"来源: {source}\n\n---\n{content}\n---")
+        context = "\n\n".join(context_details)
 
         prompt_template = f"""
 你是一个专家助手。请根据用户的问题,结合下面提供的参考信息来回答。
 
-**重要**:
-
-参考文件列表:
----
-{sources_text}
----
+**重要**: 提供的上下文**仅供参考**,可能不完整或已过时。在回答前,你应该**优先使用工具(如 read_code)来获取最新、最准确的信息**。
 
 参考上下文:
 ---
@@ -122,13 +188,15 @@ class JarvisRAGPipeline:
             由LLM生成的答案。
         """
         # 1. 将原始查询重写为多个查询
-        rewritten_queries = self.
+        rewritten_queries = self._get_query_rewriter().rewrite(query_text)
 
         # 2. 为每个重写的查询检索初始候选文档
         all_candidate_docs = []
        for q in rewritten_queries:
            print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
-            candidates = self.
+            candidates = self._get_retriever().retrieve(
+                q, n_results=n_results * 2, use_bm25=self.use_bm25
+            )
             all_candidate_docs.extend(candidates)
 
         # 对候选文档进行去重
@@ -139,12 +207,13 @@ class JarvisRAGPipeline:
             return "我在提供的文档中找不到任何相关信息来回答您的问题。"
 
         # 3. 根据*原始*查询对统一的候选池进行重排
-
-            f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)..."
-
-
-
-
+        if self.use_rerank:
+            print(f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)...")
+            retrieved_docs = self._get_reranker().rerank(
+                query_text, unique_candidate_docs, top_n=n_results
+            )
+        else:
+            retrieved_docs = unique_candidate_docs[:n_results]
 
         if not retrieved_docs:
             return "我在提供的文档中找不到任何相关信息来回答您的问题。"
@@ -166,9 +235,49 @@ class JarvisRAGPipeline:
 
         # 4. 创建最终提示并生成答案
         # 我们使用原始的query_text作为给LLM的最终提示
-        prompt = self._create_prompt(query_text, retrieved_docs
+        prompt = self._create_prompt(query_text, retrieved_docs)
 
         print("🤖 正在从LLM生成答案...")
         answer = self.llm.generate(prompt)
 
         return answer
+
+    def retrieve_only(self, query_text: str, n_results: int = 5) -> List[Document]:
+        """
+        仅执行检索和重排,不生成答案。
+
+        参数:
+            query_text: 用户的原始问题。
+            n_results: 要检索的最终相关块的数量。
+
+        返回:
+            检索到的文档列表。
+        """
+        # 1. 重写查询
+        rewritten_queries = self._get_query_rewriter().rewrite(query_text)
+
+        # 2. 检索候选文档
+        all_candidate_docs = []
+        for q in rewritten_queries:
+            print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
+            candidates = self._get_retriever().retrieve(
+                q, n_results=n_results * 2, use_bm25=self.use_bm25
+            )
+            all_candidate_docs.extend(candidates)
+
+        unique_docs_dict = {doc.page_content: doc for doc in all_candidate_docs}
+        unique_candidate_docs = list(unique_docs_dict.values())
+
+        if not unique_candidate_docs:
+            return []
+
+        # 3. 重排
+        if self.use_rerank:
+            print(f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排...")
+            retrieved_docs = self._get_reranker().rerank(
+                query_text, unique_candidate_docs, top_n=n_results
+            )
+        else:
+            retrieved_docs = unique_candidate_docs[:n_results]
+
+        return retrieved_docs
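The lazy-loading refactor above means constructing JarvisRAGPipeline no longer loads the embedding model, retriever, or reranker; each is built on first use. A minimal sketch of driving retrieve_only directly, assuming the [rag] extra is installed and a Jarvis platform/model is configured (the import path simply mirrors the file layout shown in this diff):

from jarvis.jarvis_rag.rag_pipeline import JarvisRAGPipeline

# Heavy components (embedding model, reranker, query rewriter) are created on demand.
pipeline = JarvisRAGPipeline(
    collection_name="jarvis_rag_collection",
    use_bm25=True,     # keep hybrid keyword + vector retrieval
    use_rerank=False,  # skip the reranker and take the top candidates as-is
)

# rewrite -> hybrid retrieval -> (optional) rerank, without calling the answer-generating LLM
docs = pipeline.retrieve_only("如何配置RAG的嵌入模型?", n_results=3)
for doc in docs:
    print(doc.metadata.get("source", "unknown"), len(doc.page_content))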
jarvis/jarvis_rag/retriever.py
CHANGED
@@ -39,9 +39,7 @@ class ChromaRetriever:
         self.collection = self.client.get_or_create_collection(
             name=self.collection_name
         )
-        print(
-            f"✅ ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。"
-        )
+        print(f"✅ ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。")
 
         # BM25索引设置
         self.bm25_index_path = os.path.join(self.db_path, f"{collection_name}_bm25.pkl")
@@ -107,7 +105,9 @@ class ChromaRetriever:
             self.bm25_index = BM25Okapi(self.bm25_corpus)
             self._save_bm25_index()
 
-    def retrieve(
+    def retrieve(
+        self, query: str, n_results: int = 5, use_bm25: bool = True
+    ) -> List[Document]:
         """
         使用向量搜索和BM25执行混合检索,然后使用倒数排序融合(RRF)
         对结果进行融合。
@@ -121,7 +121,7 @@ class ChromaRetriever:
 
         # 2. 关键字搜索 (BM25)
         bm25_docs = []
-        if self.bm25_index:
+        if self.bm25_index and use_bm25:
             tokenized_query = query.split()
             doc_scores = self.bm25_index.get_scores(tokenized_query)
 
@@ -144,7 +144,7 @@ class ChromaRetriever:
             ]
 
             # 按分数排序并取最高结果
-            bm25_results_with_docs.sort(key=lambda x: x[2], reverse=True)
+            bm25_results_with_docs.sort(key=lambda x: x[2], reverse=True)  # type: ignore
 
             for doc_text, metadata, _ in bm25_results_with_docs[: n_results * 2]:
                 bm25_docs.append(Document(page_content=doc_text, metadata=metadata))
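The retrieve() docstring names Reciprocal Rank Fusion (RRF) as the way vector and BM25 results are merged. As a generic illustration of the idea (not the package's implementation), each document scores the sum of 1/(k + rank) over every ranking it appears in, so items ranked well by both retrievers float to the top:

from collections import defaultdict
from typing import Dict, List

def rrf_fuse(rankings: List[List[str]], k: int = 60) -> List[str]:
    """Fuse several ranked ID lists; k=60 is the conventional RRF constant."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return [doc_id for doc_id, _ in sorted(scores.items(), key=lambda kv: kv[1], reverse=True)]

vector_hits = ["doc_a", "doc_b", "doc_c"]
bm25_hits = ["doc_c", "doc_a", "doc_d"]
print(rrf_fuse([vector_hits, bm25_hits]))  # ['doc_a', 'doc_c', 'doc_b', 'doc_d']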
jarvis/jarvis_stats/__init__.py
ADDED
@@ -0,0 +1,13 @@
+"""
+Jarvis统计模块
+
+提供指标统计、数据持久化、可视化展示等功能
+"""
+
+from jarvis.jarvis_stats.stats import StatsManager
+from jarvis.jarvis_stats.storage import StatsStorage
+from jarvis.jarvis_stats.visualizer import StatsVisualizer
+
+__all__ = ["StatsManager", "StatsStorage", "StatsVisualizer"]
+
+__version__ = "1.0.0"