jarvis-ai-assistant 0.3.17__py3-none-any.whl → 0.3.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +23 -10
- jarvis/jarvis_agent/edit_file_handler.py +8 -13
- jarvis/jarvis_agent/jarvis.py +13 -3
- jarvis/jarvis_agent/memory_manager.py +4 -4
- jarvis/jarvis_agent/methodology_share_manager.py +2 -2
- jarvis/jarvis_agent/task_analyzer.py +4 -3
- jarvis/jarvis_agent/task_manager.py +6 -6
- jarvis/jarvis_agent/tool_executor.py +2 -2
- jarvis/jarvis_agent/tool_share_manager.py +2 -2
- jarvis/jarvis_code_agent/code_agent.py +21 -29
- jarvis/jarvis_code_analysis/code_review.py +2 -4
- jarvis/jarvis_data/config_schema.json +5 -0
- jarvis/jarvis_git_utils/git_commiter.py +17 -18
- jarvis/jarvis_methodology/main.py +12 -12
- jarvis/jarvis_platform/base.py +21 -13
- jarvis/jarvis_platform/kimi.py +13 -13
- jarvis/jarvis_platform/tongyi.py +17 -15
- jarvis/jarvis_platform/yuanbao.py +11 -11
- jarvis/jarvis_platform_manager/main.py +12 -22
- jarvis/jarvis_rag/cli.py +36 -32
- jarvis/jarvis_rag/embedding_manager.py +11 -6
- jarvis/jarvis_rag/llm_interface.py +6 -5
- jarvis/jarvis_rag/rag_pipeline.py +9 -8
- jarvis/jarvis_rag/reranker.py +3 -2
- jarvis/jarvis_rag/retriever.py +18 -8
- jarvis/jarvis_smart_shell/main.py +306 -46
- jarvis/jarvis_stats/stats.py +40 -0
- jarvis/jarvis_stats/storage.py +220 -9
- jarvis/jarvis_tools/clear_memory.py +0 -11
- jarvis/jarvis_tools/cli/main.py +18 -17
- jarvis/jarvis_tools/edit_file.py +4 -4
- jarvis/jarvis_tools/execute_script.py +5 -1
- jarvis/jarvis_tools/file_analyzer.py +6 -6
- jarvis/jarvis_tools/generate_new_tool.py +6 -17
- jarvis/jarvis_tools/read_code.py +3 -6
- jarvis/jarvis_tools/read_webpage.py +74 -13
- jarvis/jarvis_tools/registry.py +8 -28
- jarvis/jarvis_tools/retrieve_memory.py +5 -16
- jarvis/jarvis_tools/rewrite_file.py +0 -4
- jarvis/jarvis_tools/save_memory.py +2 -10
- jarvis/jarvis_tools/search_web.py +5 -8
- jarvis/jarvis_tools/virtual_tty.py +22 -40
- jarvis/jarvis_utils/clipboard.py +3 -3
- jarvis/jarvis_utils/config.py +8 -0
- jarvis/jarvis_utils/input.py +67 -27
- jarvis/jarvis_utils/methodology.py +3 -3
- jarvis/jarvis_utils/output.py +1 -7
- jarvis/jarvis_utils/utils.py +44 -58
- {jarvis_ai_assistant-0.3.17.dist-info → jarvis_ai_assistant-0.3.19.dist-info}/METADATA +1 -1
- {jarvis_ai_assistant-0.3.17.dist-info → jarvis_ai_assistant-0.3.19.dist-info}/RECORD +55 -55
- {jarvis_ai_assistant-0.3.17.dist-info → jarvis_ai_assistant-0.3.19.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.3.17.dist-info → jarvis_ai_assistant-0.3.19.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.3.17.dist-info → jarvis_ai_assistant-0.3.19.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.3.17.dist-info → jarvis_ai_assistant-0.3.19.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/cli.py
CHANGED
@@ -20,6 +20,7 @@ from jarvis.jarvis_utils.config import (
|
|
20
20
|
get_rag_use_bm25,
|
21
21
|
get_rag_use_rerank,
|
22
22
|
)
|
23
|
+
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
23
24
|
|
24
25
|
|
25
26
|
def is_likely_text_file(file_path: Path) -> bool:
|
@@ -70,7 +71,10 @@ class _CustomPlatformLLM(LLMInterface):
|
|
70
71
|
|
71
72
|
def __init__(self, platform: BasePlatform):
|
72
73
|
self.platform = platform
|
73
|
-
print(
|
74
|
+
PrettyOutput.print(
|
75
|
+
f"使用自定义LLM: 平台='{platform.platform_name()}', 模型='{platform.name()}'",
|
76
|
+
OutputType.INFO,
|
77
|
+
)
|
74
78
|
|
75
79
|
def generate(self, prompt: str, **kwargs) -> str:
|
76
80
|
return self.platform.chat_until_success(prompt)
|
@@ -84,13 +88,13 @@ def _create_custom_llm(platform_name: str, model_name: str) -> Optional[LLMInter
|
|
84
88
|
registry = PlatformRegistry.get_global_platform_registry()
|
85
89
|
platform_instance = registry.create_platform(platform_name)
|
86
90
|
if not platform_instance:
|
87
|
-
print(f"
|
91
|
+
PrettyOutput.print(f"错误: 平台 '{platform_name}' 未找到。", OutputType.ERROR)
|
88
92
|
return None
|
89
93
|
platform_instance.set_model_name(model_name)
|
90
94
|
platform_instance.set_suppress_output(True)
|
91
95
|
return _CustomPlatformLLM(platform_instance)
|
92
96
|
except Exception as e:
|
93
|
-
print(f"
|
97
|
+
PrettyOutput.print(f"创建自定义LLM时出错: {e}", OutputType.ERROR)
|
94
98
|
return None
|
95
99
|
|
96
100
|
|
@@ -114,10 +118,10 @@ def _load_ragignore_spec() -> Tuple[Optional[pathspec.PathSpec], Optional[Path]]
|
|
114
118
|
with open(ignore_file_to_use, "r", encoding="utf-8") as f:
|
115
119
|
patterns = f.read().splitlines()
|
116
120
|
spec = pathspec.PathSpec.from_lines("gitwildmatch", patterns)
|
117
|
-
print(f"
|
121
|
+
PrettyOutput.print(f"加载忽略规则: {ignore_file_to_use}", OutputType.SUCCESS)
|
118
122
|
return spec, project_root_path
|
119
123
|
except Exception as e:
|
120
|
-
print(f"
|
124
|
+
PrettyOutput.print(f"加载 {ignore_file_to_use.name} 文件失败: {e}", OutputType.WARNING)
|
121
125
|
|
122
126
|
return None, None
|
123
127
|
|
@@ -166,7 +170,7 @@ def add_documents(
|
|
166
170
|
continue
|
167
171
|
|
168
172
|
if path.is_dir():
|
169
|
-
print(f"
|
173
|
+
PrettyOutput.print(f"正在扫描目录: {path}", OutputType.INFO)
|
170
174
|
for item in path.rglob("*"):
|
171
175
|
if item.is_file() and is_likely_text_file(item):
|
172
176
|
files_to_process.add(item)
|
@@ -174,10 +178,10 @@ def add_documents(
|
|
174
178
|
if is_likely_text_file(path):
|
175
179
|
files_to_process.add(path)
|
176
180
|
else:
|
177
|
-
print(f"
|
181
|
+
PrettyOutput.print(f"跳过可能的二进制文件: {path}", OutputType.WARNING)
|
178
182
|
|
179
183
|
if not files_to_process:
|
180
|
-
print("
|
184
|
+
PrettyOutput.print("在指定路径中未找到任何文本文件。", OutputType.WARNING)
|
181
185
|
return
|
182
186
|
|
183
187
|
# 使用 .ragignore 过滤文件
|
@@ -198,14 +202,14 @@ def add_documents(
|
|
198
202
|
|
199
203
|
ignored_count = initial_count - len(retained_files)
|
200
204
|
if ignored_count > 0:
|
201
|
-
print(f"
|
205
|
+
PrettyOutput.print(f"根据 .ragignore 规则过滤掉 {ignored_count} 个文件。", OutputType.INFO)
|
202
206
|
files_to_process = retained_files
|
203
207
|
|
204
208
|
if not files_to_process:
|
205
|
-
print("
|
209
|
+
PrettyOutput.print("所有找到的文本文件都被忽略规则过滤掉了。", OutputType.WARNING)
|
206
210
|
return
|
207
211
|
|
208
|
-
print(f"
|
212
|
+
PrettyOutput.print(f"发现 {len(files_to_process)} 个独立文件待处理。", OutputType.INFO)
|
209
213
|
|
210
214
|
try:
|
211
215
|
pipeline = JarvisRAGPipeline(
|
@@ -229,26 +233,26 @@ def add_documents(
|
|
229
233
|
loader = TextLoader(str(file_path), encoding="utf-8")
|
230
234
|
|
231
235
|
docs_batch.extend(loader.load())
|
232
|
-
print(f"
|
236
|
+
PrettyOutput.print(f"已加载: {file_path} (文件 {i + 1}/{total_files})", OutputType.INFO)
|
233
237
|
except Exception as e:
|
234
|
-
print(f"
|
238
|
+
PrettyOutput.print(f"加载失败 {file_path}: {e}", OutputType.WARNING)
|
235
239
|
|
236
240
|
# 当批处理已满或是最后一个文件时处理批处理
|
237
241
|
if docs_batch and (len(docs_batch) >= batch_size or (i + 1) == total_files):
|
238
|
-
print(f"
|
242
|
+
PrettyOutput.print(f"正在处理批次,包含 {len(docs_batch)} 个文档...", OutputType.INFO)
|
239
243
|
pipeline.add_documents(docs_batch)
|
240
244
|
total_docs_added += len(docs_batch)
|
241
|
-
print(f"
|
245
|
+
PrettyOutput.print(f"成功添加 {len(docs_batch)} 个文档。", OutputType.SUCCESS)
|
242
246
|
docs_batch = [] # 清空批处理
|
243
247
|
|
244
248
|
if total_docs_added == 0:
|
245
|
-
print("
|
249
|
+
PrettyOutput.print("未能成功加载任何文档。", OutputType.ERROR)
|
246
250
|
raise typer.Exit(code=1)
|
247
251
|
|
248
|
-
print(f"
|
252
|
+
PrettyOutput.print(f"成功将 {total_docs_added} 个文档的内容添加至集合 '{collection_name}'。", OutputType.SUCCESS)
|
249
253
|
|
250
254
|
except Exception as e:
|
251
|
-
print(f"
|
255
|
+
PrettyOutput.print(f"发生严重错误: {e}", OutputType.ERROR)
|
252
256
|
raise typer.Exit(code=1)
|
253
257
|
|
254
258
|
|
@@ -273,7 +277,7 @@ def list_documents(
|
|
273
277
|
results = collection.get() # 获取集合中的所有项目
|
274
278
|
|
275
279
|
if not results or not results["metadatas"]:
|
276
|
-
print("
|
280
|
+
PrettyOutput.print("知识库中没有找到任何文档。", OutputType.INFO)
|
277
281
|
return
|
278
282
|
|
279
283
|
# 从元数据中提取唯一的源文件路径
|
@@ -285,15 +289,15 @@ def list_documents(
|
|
285
289
|
sources.add(source)
|
286
290
|
|
287
291
|
if not sources:
|
288
|
-
print("
|
292
|
+
PrettyOutput.print("知识库中没有找到任何带有源信息的文档。", OutputType.INFO)
|
289
293
|
return
|
290
294
|
|
291
|
-
print(f"
|
295
|
+
PrettyOutput.print(f"知识库 '{collection_name}' 中共有 {len(sources)} 个独立文档:", OutputType.INFO)
|
292
296
|
for i, source in enumerate(sorted(list(sources)), 1):
|
293
|
-
print(f" {i}. {source}")
|
297
|
+
PrettyOutput.print(f" {i}. {source}", OutputType.INFO)
|
294
298
|
|
295
299
|
except Exception as e:
|
296
|
-
print(f"
|
300
|
+
PrettyOutput.print(f"发生错误: {e}", OutputType.ERROR)
|
297
301
|
raise typer.Exit(code=1)
|
298
302
|
|
299
303
|
|
@@ -330,14 +334,14 @@ def retrieve(
|
|
330
334
|
use_rerank=use_rerank,
|
331
335
|
)
|
332
336
|
|
333
|
-
print(f"
|
337
|
+
PrettyOutput.print(f"正在为问题检索文档: '{question}'", OutputType.INFO)
|
334
338
|
retrieved_docs = pipeline.retrieve_only(question, n_results=n_results)
|
335
339
|
|
336
340
|
if not retrieved_docs:
|
337
|
-
print("
|
341
|
+
PrettyOutput.print("未找到相关文档。", OutputType.INFO)
|
338
342
|
return
|
339
343
|
|
340
|
-
print(f"
|
344
|
+
PrettyOutput.print(f"成功检索到 {len(retrieved_docs)} 个文档:", OutputType.SUCCESS)
|
341
345
|
from jarvis.jarvis_utils.globals import console
|
342
346
|
|
343
347
|
for i, doc in enumerate(retrieved_docs, 1):
|
@@ -350,7 +354,7 @@ def retrieve(
|
|
350
354
|
console.print(Markdown(f"```\n{content}\n```"))
|
351
355
|
|
352
356
|
except Exception as e:
|
353
|
-
print(f"
|
357
|
+
PrettyOutput.print(f"发生错误: {e}", OutputType.ERROR)
|
354
358
|
raise typer.Exit(code=1)
|
355
359
|
|
356
360
|
|
@@ -385,7 +389,7 @@ def query(
|
|
385
389
|
):
|
386
390
|
"""查询RAG知识库并打印答案。"""
|
387
391
|
if model and not platform:
|
388
|
-
print("
|
392
|
+
PrettyOutput.print("错误: --model 需要指定 --platform。", OutputType.ERROR)
|
389
393
|
raise typer.Exit(code=1)
|
390
394
|
|
391
395
|
try:
|
@@ -407,17 +411,17 @@ def query(
|
|
407
411
|
use_rerank=use_rerank,
|
408
412
|
)
|
409
413
|
|
410
|
-
print(f"
|
414
|
+
PrettyOutput.print(f"正在查询: '{question}'", OutputType.INFO)
|
411
415
|
answer = pipeline.query(question)
|
412
416
|
|
413
|
-
print("
|
417
|
+
PrettyOutput.print("答案:", OutputType.INFO)
|
414
418
|
# 我们仍然可以使用 rich.markdown.Markdown,因为 PrettyOutput 底层使用了 rich
|
415
419
|
from jarvis.jarvis_utils.globals import console
|
416
420
|
|
417
421
|
console.print(Markdown(answer))
|
418
422
|
|
419
423
|
except Exception as e:
|
420
|
-
print(f"
|
424
|
+
PrettyOutput.print(f"发生错误: {e}", OutputType.ERROR)
|
421
425
|
raise typer.Exit(code=1)
|
422
426
|
|
423
427
|
|
@@ -432,7 +436,7 @@ except ImportError:
|
|
432
436
|
|
433
437
|
def _check_rag_dependencies():
|
434
438
|
if not _RAG_INSTALLED:
|
435
|
-
print("
|
439
|
+
PrettyOutput.print("RAG依赖项未安装。请运行 'pip install \"jarvis-ai-assistant[rag]\"' 来使用此命令。", OutputType.ERROR)
|
436
440
|
raise typer.Exit(code=1)
|
437
441
|
|
438
442
|
|
jarvis/jarvis_rag/embedding_manager.py
CHANGED
@@ -3,6 +3,7 @@ from typing import List, cast
|
|
3
3
|
from langchain_huggingface import HuggingFaceEmbeddings
|
4
4
|
|
5
5
|
from .cache import EmbeddingCache
|
6
|
+
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
6
7
|
|
7
8
|
|
8
9
|
class EmbeddingManager:
|
@@ -23,7 +24,7 @@ class EmbeddingManager:
|
|
23
24
|
"""
|
24
25
|
self.model_name = model_name
|
25
26
|
|
26
|
-
print(f"
|
27
|
+
PrettyOutput.print(f"初始化嵌入管理器, 模型: '{self.model_name}'...", OutputType.INFO)
|
27
28
|
|
28
29
|
# 缓存的salt是模型名称,以防止冲突
|
29
30
|
self.cache = EmbeddingCache(cache_dir=cache_dir, salt=self.model_name)
|
@@ -42,8 +43,8 @@ class EmbeddingManager:
|
|
42
43
|
show_progress=True,
|
43
44
|
)
|
44
45
|
except Exception as e:
|
45
|
-
print(f"
|
46
|
-
print("请确保您已安装 'sentence_transformers' 和 'torch'。")
|
46
|
+
PrettyOutput.print(f"加载嵌入模型 '{self.model_name}' 时出错: {e}", OutputType.ERROR)
|
47
|
+
PrettyOutput.print("请确保您已安装 'sentence_transformers' 和 'torch'。", OutputType.WARNING)
|
47
48
|
raise
|
48
49
|
|
49
50
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
@@ -71,8 +72,9 @@ class EmbeddingManager:
|
|
71
72
|
|
72
73
|
# 为不在缓存中的文本计算嵌入
|
73
74
|
if texts_to_embed:
|
74
|
-
print(
|
75
|
-
f"
|
75
|
+
PrettyOutput.print(
|
76
|
+
f"缓存未命中。正在为 {len(texts_to_embed)}/{len(texts)} 个文档计算嵌入。",
|
77
|
+
OutputType.INFO,
|
76
78
|
)
|
77
79
|
new_embeddings = self.model.embed_documents(texts_to_embed)
|
78
80
|
|
@@ -83,7 +85,10 @@ class EmbeddingManager:
|
|
83
85
|
for i, embedding in zip(indices_to_embed, new_embeddings):
|
84
86
|
cached_embeddings[i] = embedding
|
85
87
|
else:
|
86
|
-
print(
|
88
|
+
PrettyOutput.print(
|
89
|
+
f"缓存命中。所有 {len(texts)} 个文档的嵌入均从缓存中检索。",
|
90
|
+
OutputType.SUCCESS,
|
91
|
+
)
|
87
92
|
|
88
93
|
return cast(List[List[float]], cached_embeddings)
|
89
94
|
|
jarvis/jarvis_rag/llm_interface.py
CHANGED
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
|
6
6
|
from jarvis.jarvis_agent import Agent as JarvisAgent
|
7
7
|
from jarvis.jarvis_platform.base import BasePlatform
|
8
8
|
from jarvis.jarvis_platform.registry import PlatformRegistry
|
9
|
+
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
9
10
|
|
10
11
|
|
11
12
|
class LLMInterface(ABC):
|
@@ -41,7 +42,7 @@ class ToolAgent_LLM(LLMInterface):
|
|
41
42
|
"""
|
42
43
|
初始化工具-代理 LLM 包装器。
|
43
44
|
"""
|
44
|
-
print("
|
45
|
+
PrettyOutput.print("已初始化工具 Agent 作为最终应答者。", OutputType.INFO)
|
45
46
|
self.allowed_tools = ["read_code", "execute_script"]
|
46
47
|
# 为代理提供一个通用的系统提示
|
47
48
|
self.system_prompt = "You are a helpful assistant. Please answer the user's question based on the provided context. You can use tools to find more information if needed."
|
@@ -83,7 +84,7 @@ class ToolAgent_LLM(LLMInterface):
|
|
83
84
|
return str(final_answer)
|
84
85
|
|
85
86
|
except Exception as e:
|
86
|
-
print(f"
|
87
|
+
PrettyOutput.print(f"Agent 在执行过程中发生错误: {e}", OutputType.ERROR)
|
87
88
|
return "错误: Agent 未能成功生成回答。"
|
88
89
|
|
89
90
|
|
@@ -102,9 +103,9 @@ class JarvisPlatform_LLM(LLMInterface):
|
|
102
103
|
self.registry = PlatformRegistry.get_global_platform_registry()
|
103
104
|
self.platform: BasePlatform = self.registry.get_normal_platform()
|
104
105
|
self.platform.set_suppress_output(False) # 确保模型没有控制台输出
|
105
|
-
print(f"
|
106
|
+
PrettyOutput.print(f"已初始化 Jarvis 平台 LLM,模型: {self.platform.name()}", OutputType.INFO)
|
106
107
|
except Exception as e:
|
107
|
-
print(f"
|
108
|
+
PrettyOutput.print(f"初始化 Jarvis 平台 LLM 失败: {e}", OutputType.ERROR)
|
108
109
|
raise
|
109
110
|
|
110
111
|
def generate(self, prompt: str, **kwargs) -> str:
|
@@ -122,5 +123,5 @@ class JarvisPlatform_LLM(LLMInterface):
|
|
122
123
|
# 使用健壮的chat_until_success方法
|
123
124
|
return self.platform.chat_until_success(prompt)
|
124
125
|
except Exception as e:
|
125
|
-
print(f"
|
126
|
+
PrettyOutput.print(f"调用 Jarvis 平台模型时发生错误: {e}", OutputType.ERROR)
|
126
127
|
return "错误: 无法从本地LLM获取响应。"
|
jarvis/jarvis_rag/rag_pipeline.py
CHANGED
@@ -8,6 +8,7 @@ from .llm_interface import JarvisPlatform_LLM, LLMInterface, ToolAgent_LLM
|
|
8
8
|
from .query_rewriter import QueryRewriter
|
9
9
|
from .reranker import Reranker
|
10
10
|
from .retriever import ChromaRetriever
|
11
|
+
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
11
12
|
from jarvis.jarvis_utils.config import (
|
12
13
|
get_rag_embedding_model,
|
13
14
|
get_rag_rerank_model,
|
@@ -74,7 +75,7 @@ class JarvisRAGPipeline:
|
|
74
75
|
self._reranker: Optional[Reranker] = None
|
75
76
|
self._query_rewriter: Optional[QueryRewriter] = None
|
76
77
|
|
77
|
-
print("
|
78
|
+
PrettyOutput.print("JarvisRAGPipeline 初始化成功 (模型按需加载).", OutputType.SUCCESS)
|
78
79
|
|
79
80
|
def _get_embedding_manager(self) -> EmbeddingManager:
|
80
81
|
if self._embedding_manager is None:
|
@@ -193,7 +194,7 @@ class JarvisRAGPipeline:
|
|
193
194
|
# 2. 为每个重写的查询检索初始候选文档
|
194
195
|
all_candidate_docs = []
|
195
196
|
for q in rewritten_queries:
|
196
|
-
print(f"
|
197
|
+
PrettyOutput.print(f"正在为查询变体 '{q}' 进行混合检索...", OutputType.INFO)
|
197
198
|
candidates = self._get_retriever().retrieve(
|
198
199
|
q, n_results=n_results * 2, use_bm25=self.use_bm25
|
199
200
|
)
|
@@ -208,7 +209,7 @@ class JarvisRAGPipeline:
|
|
208
209
|
|
209
210
|
# 3. 根据*原始*查询对统一的候选池进行重排
|
210
211
|
if self.use_rerank:
|
211
|
-
print(f"
|
212
|
+
PrettyOutput.print(f"正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)...", OutputType.INFO)
|
212
213
|
retrieved_docs = self._get_reranker().rerank(
|
213
214
|
query_text, unique_candidate_docs, top_n=n_results
|
214
215
|
)
|
@@ -229,15 +230,15 @@ class JarvisRAGPipeline:
|
|
229
230
|
)
|
230
231
|
)
|
231
232
|
if sources:
|
232
|
-
print(
|
233
|
+
PrettyOutput.print("根据以下文档回答:", OutputType.INFO)
|
233
234
|
for source in sources:
|
234
|
-
print(f" - {source}")
|
235
|
+
PrettyOutput.print(f" - {source}", OutputType.INFO)
|
235
236
|
|
236
237
|
# 4. 创建最终提示并生成答案
|
237
238
|
# 我们使用原始的query_text作为给LLM的最终提示
|
238
239
|
prompt = self._create_prompt(query_text, retrieved_docs)
|
239
240
|
|
240
|
-
print("
|
241
|
+
PrettyOutput.print("正在从LLM生成答案...", OutputType.INFO)
|
241
242
|
answer = self.llm.generate(prompt)
|
242
243
|
|
243
244
|
return answer
|
@@ -259,7 +260,7 @@ class JarvisRAGPipeline:
|
|
259
260
|
# 2. 检索候选文档
|
260
261
|
all_candidate_docs = []
|
261
262
|
for q in rewritten_queries:
|
262
|
-
print(f"
|
263
|
+
PrettyOutput.print(f"正在为查询变体 '{q}' 进行混合检索...", OutputType.INFO)
|
263
264
|
candidates = self._get_retriever().retrieve(
|
264
265
|
q, n_results=n_results * 2, use_bm25=self.use_bm25
|
265
266
|
)
|
@@ -273,7 +274,7 @@ class JarvisRAGPipeline:
|
|
273
274
|
|
274
275
|
# 3. 重排
|
275
276
|
if self.use_rerank:
|
276
|
-
print(f"
|
277
|
+
PrettyOutput.print(f"正在对 {len(unique_candidate_docs)} 个候选文档进行重排...", OutputType.INFO)
|
277
278
|
retrieved_docs = self._get_reranker().rerank(
|
278
279
|
query_text, unique_candidate_docs, top_n=n_results
|
279
280
|
)
|
jarvis/jarvis_rag/reranker.py
CHANGED
@@ -4,6 +4,7 @@ from langchain.docstore.document import Document
|
|
4
4
|
from sentence_transformers.cross_encoder import ( # type: ignore
|
5
5
|
CrossEncoder,
|
6
6
|
)
|
7
|
+
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
7
8
|
|
8
9
|
|
9
10
|
class Reranker:
|
@@ -19,9 +20,9 @@ class Reranker:
|
|
19
20
|
参数:
|
20
21
|
model_name (str): 要使用的Cross-Encoder模型的名称。
|
21
22
|
"""
|
22
|
-
print(f"
|
23
|
+
PrettyOutput.print(f"正在初始化重排模型: {model_name}...", OutputType.INFO)
|
23
24
|
self.model = CrossEncoder(model_name)
|
24
|
-
print("
|
25
|
+
PrettyOutput.print("重排模型初始化成功。", OutputType.SUCCESS)
|
25
26
|
|
26
27
|
def rerank(
|
27
28
|
self, query: str, documents: List[Document], top_n: int = 5
|
jarvis/jarvis_rag/retriever.py
CHANGED
@@ -8,6 +8,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
8
8
|
from rank_bm25 import BM25Okapi # type: ignore
|
9
9
|
|
10
10
|
from .embedding_manager import EmbeddingManager
|
11
|
+
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
11
12
|
|
12
13
|
|
13
14
|
class ChromaRetriever:
|
@@ -39,7 +40,10 @@ class ChromaRetriever:
|
|
39
40
|
self.collection = self.client.get_or_create_collection(
|
40
41
|
name=self.collection_name
|
41
42
|
)
|
42
|
-
print(
|
43
|
+
PrettyOutput.print(
|
44
|
+
f"ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。",
|
45
|
+
OutputType.SUCCESS,
|
46
|
+
)
|
43
47
|
|
44
48
|
# BM25索引设置
|
45
49
|
self.bm25_index_path = os.path.join(self.db_path, f"{collection_name}_bm25.pkl")
|
@@ -48,24 +52,24 @@ class ChromaRetriever:
|
|
48
52
|
def _load_or_initialize_bm25(self):
|
49
53
|
"""从磁盘加载BM25索引或初始化一个新索引。"""
|
50
54
|
if os.path.exists(self.bm25_index_path):
|
51
|
-
print("
|
55
|
+
PrettyOutput.print("正在加载现有的 BM25 索引...", OutputType.INFO)
|
52
56
|
with open(self.bm25_index_path, "rb") as f:
|
53
57
|
data = pickle.load(f)
|
54
58
|
self.bm25_corpus = data["corpus"]
|
55
59
|
self.bm25_index = BM25Okapi(self.bm25_corpus)
|
56
|
-
print("
|
60
|
+
PrettyOutput.print("BM25 索引加载成功。", OutputType.SUCCESS)
|
57
61
|
else:
|
58
|
-
print("
|
62
|
+
PrettyOutput.print("未找到 BM25 索引,将初始化一个新的。", OutputType.WARNING)
|
59
63
|
self.bm25_corpus = []
|
60
64
|
self.bm25_index = None
|
61
65
|
|
62
66
|
def _save_bm25_index(self):
|
63
67
|
"""将BM25索引保存到磁盘。"""
|
64
68
|
if self.bm25_index:
|
65
|
-
print("
|
69
|
+
PrettyOutput.print("正在保存 BM25 索引...", OutputType.INFO)
|
66
70
|
with open(self.bm25_index_path, "wb") as f:
|
67
71
|
pickle.dump({"corpus": self.bm25_corpus, "index": self.bm25_index}, f)
|
68
|
-
print("
|
72
|
+
PrettyOutput.print("BM25 索引保存成功。", OutputType.SUCCESS)
|
69
73
|
|
70
74
|
def add_documents(
|
71
75
|
self, documents: List[Document], chunk_size=1000, chunk_overlap=100
|
@@ -78,7 +82,10 @@ class ChromaRetriever:
|
|
78
82
|
)
|
79
83
|
chunks = text_splitter.split_documents(documents)
|
80
84
|
|
81
|
-
print(
|
85
|
+
PrettyOutput.print(
|
86
|
+
f"已将 {len(documents)} 个文档拆分为 {len(chunks)} 个块。",
|
87
|
+
OutputType.INFO,
|
88
|
+
)
|
82
89
|
|
83
90
|
if not chunks:
|
84
91
|
return
|
@@ -97,7 +104,10 @@ class ChromaRetriever:
|
|
97
104
|
documents=chunk_texts,
|
98
105
|
metadatas=cast(Any, metadatas),
|
99
106
|
)
|
100
|
-
print(
|
107
|
+
PrettyOutput.print(
|
108
|
+
f"成功将 {len(chunks)} 个块添加到 ChromaDB 集合中。",
|
109
|
+
OutputType.SUCCESS,
|
110
|
+
)
|
101
111
|
|
102
112
|
# 更新并保存BM25索引
|
103
113
|
tokenized_chunks = [doc.split() for doc in chunk_texts]
|