jarvis-ai-assistant 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/prompts.py +26 -4
- jarvis/jarvis_data/config_schema.json +67 -12
- jarvis/jarvis_platform/tongyi.py +9 -9
- jarvis/jarvis_rag/cli.py +79 -23
- jarvis/jarvis_rag/query_rewriter.py +61 -12
- jarvis/jarvis_rag/rag_pipeline.py +143 -34
- jarvis/jarvis_rag/retriever.py +5 -5
- jarvis/jarvis_tools/generate_new_tool.py +22 -1
- jarvis/jarvis_utils/config.py +92 -11
- jarvis/jarvis_utils/globals.py +29 -8
- jarvis/jarvis_utils/input.py +114 -121
- jarvis/jarvis_utils/utils.py +3 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/METADATA +82 -9
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/RECORD +19 -19
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,8 @@ class JarvisRAGPipeline:
|
|
30
30
|
embedding_model: Optional[str] = None,
|
31
31
|
db_path: Optional[str] = None,
|
32
32
|
collection_name: str = "jarvis_rag_collection",
|
33
|
+
use_bm25: bool = True,
|
34
|
+
use_rerank: bool = True,
|
33
35
|
):
|
34
36
|
"""
|
35
37
|
初始化RAG管道。
|
@@ -40,6 +42,8 @@ class JarvisRAGPipeline:
|
|
40
42
|
embedding_model: 嵌入模型的名称。如果为None,则使用配置值。
|
41
43
|
db_path: 持久化向量数据库的路径。如果为None,则使用配置值。
|
42
44
|
collection_name: 向量数据库中集合的名称。
|
45
|
+
use_bm25: 是否在检索中使用BM25。
|
46
|
+
use_rerank: 是否在检索后使用重排器。
|
43
47
|
"""
|
44
48
|
# 确定嵌入模型以隔离数据路径
|
45
49
|
model_name = embedding_model or get_rag_embedding_model()
|
@@ -56,22 +60,87 @@ class JarvisRAGPipeline:
|
|
56
60
|
get_rag_embedding_cache_path(), sanitized_model_name
|
57
61
|
)
|
58
62
|
|
59
|
-
|
60
|
-
|
61
|
-
|
63
|
+
# 存储初始化参数以供延迟加载
|
64
|
+
self.llm = llm if llm is not None else ToolAgent_LLM()
|
65
|
+
self.embedding_model_name = embedding_model or get_rag_embedding_model()
|
66
|
+
self.db_path = db_path
|
67
|
+
self.collection_name = collection_name
|
68
|
+
self.use_bm25 = use_bm25
|
69
|
+
self.use_rerank = use_rerank
|
70
|
+
|
71
|
+
# 延迟加载的组件
|
72
|
+
self._embedding_manager: Optional[EmbeddingManager] = None
|
73
|
+
self._retriever: Optional[ChromaRetriever] = None
|
74
|
+
self._reranker: Optional[Reranker] = None
|
75
|
+
self._query_rewriter: Optional[QueryRewriter] = None
|
76
|
+
|
77
|
+
print("✅ JarvisRAGPipeline 初始化成功 (模型按需加载).")
|
78
|
+
|
79
|
+
def _get_embedding_manager(self) -> EmbeddingManager:
|
80
|
+
if self._embedding_manager is None:
|
81
|
+
sanitized_model_name = self.embedding_model_name.replace("/", "_").replace(
|
82
|
+
"\\", "_"
|
83
|
+
)
|
84
|
+
_final_cache_path = os.path.join(
|
85
|
+
get_rag_embedding_cache_path(), sanitized_model_name
|
86
|
+
)
|
87
|
+
self._embedding_manager = EmbeddingManager(
|
88
|
+
model_name=self.embedding_model_name,
|
89
|
+
cache_dir=_final_cache_path,
|
90
|
+
)
|
91
|
+
return self._embedding_manager
|
92
|
+
|
93
|
+
def _get_retriever(self) -> ChromaRetriever:
|
94
|
+
if self._retriever is None:
|
95
|
+
sanitized_model_name = self.embedding_model_name.replace("/", "_").replace(
|
96
|
+
"\\", "_"
|
97
|
+
)
|
98
|
+
_final_db_path = (
|
99
|
+
str(self.db_path)
|
100
|
+
if self.db_path
|
101
|
+
else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
|
102
|
+
)
|
103
|
+
self._retriever = ChromaRetriever(
|
104
|
+
embedding_manager=self._get_embedding_manager(),
|
105
|
+
db_path=_final_db_path,
|
106
|
+
collection_name=self.collection_name,
|
107
|
+
)
|
108
|
+
return self._retriever
|
109
|
+
|
110
|
+
def _get_collection(self):
|
111
|
+
"""
|
112
|
+
在不加载嵌入模型的情况下,直接获取并返回Chroma集合对象。
|
113
|
+
这对于仅需要访问集合元数据(如列出文档)而无需嵌入功能的操作非常有用。
|
114
|
+
"""
|
115
|
+
# 为了避免初始化embedding_manager,我们直接构建db_path
|
116
|
+
if self._retriever:
|
117
|
+
return self._retriever.collection
|
118
|
+
|
119
|
+
sanitized_model_name = self.embedding_model_name.replace("/", "_").replace(
|
120
|
+
"\\", "_"
|
62
121
|
)
|
63
|
-
|
64
|
-
|
65
|
-
db_path
|
66
|
-
|
122
|
+
_final_db_path = (
|
123
|
+
str(self.db_path)
|
124
|
+
if self.db_path
|
125
|
+
else os.path.join(get_rag_vector_db_path(), sanitized_model_name)
|
67
126
|
)
|
68
|
-
# 除非提供了特定的LLM,否则默认为ToolAgent_LLM
|
69
|
-
self.llm = llm if llm is not None else ToolAgent_LLM()
|
70
|
-
self.reranker = Reranker(model_name=get_rag_rerank_model())
|
71
|
-
# 使用标准LLM执行查询重写任务,而不是代理
|
72
|
-
self.query_rewriter = QueryRewriter(JarvisPlatform_LLM())
|
73
127
|
|
74
|
-
|
128
|
+
# 直接创建ChromaRetriever所使用的chroma_client,但绕过embedding_manager
|
129
|
+
import chromadb
|
130
|
+
|
131
|
+
chroma_client = chromadb.PersistentClient(path=_final_db_path)
|
132
|
+
return chroma_client.get_collection(name=self.collection_name)
|
133
|
+
|
134
|
+
def _get_reranker(self) -> Reranker:
|
135
|
+
if self._reranker is None:
|
136
|
+
self._reranker = Reranker(model_name=get_rag_rerank_model())
|
137
|
+
return self._reranker
|
138
|
+
|
139
|
+
def _get_query_rewriter(self) -> QueryRewriter:
|
140
|
+
if self._query_rewriter is None:
|
141
|
+
# 使用标准LLM执行查询重写任务,而不是代理
|
142
|
+
self._query_rewriter = QueryRewriter(JarvisPlatform_LLM())
|
143
|
+
return self._query_rewriter
|
75
144
|
|
76
145
|
def add_documents(self, documents: List[Document]):
|
77
146
|
"""
|
@@ -80,24 +149,21 @@ class JarvisRAGPipeline:
|
|
80
149
|
参数:
|
81
150
|
documents: 要添加的LangChain文档对象列表。
|
82
151
|
"""
|
83
|
-
self.
|
152
|
+
self._get_retriever().add_documents(documents)
|
84
153
|
|
85
|
-
def _create_prompt(
|
86
|
-
self, query: str, context_docs: List[Document], source_files: List[str]
|
87
|
-
) -> str:
|
154
|
+
def _create_prompt(self, query: str, context_docs: List[Document]) -> str:
|
88
155
|
"""为LLM或代理创建最终的提示。"""
|
89
|
-
|
90
|
-
|
156
|
+
context_details = []
|
157
|
+
for doc in context_docs:
|
158
|
+
source = doc.metadata.get("source", "未知来源")
|
159
|
+
content = doc.page_content
|
160
|
+
context_details.append(f"来源: {source}\n\n---\n{content}\n---")
|
161
|
+
context = "\n\n".join(context_details)
|
91
162
|
|
92
163
|
prompt_template = f"""
|
93
164
|
你是一个专家助手。请根据用户的问题,结合下面提供的参考信息来回答。
|
94
165
|
|
95
|
-
**重要**:
|
96
|
-
|
97
|
-
参考文件列表:
|
98
|
-
---
|
99
|
-
{sources_text}
|
100
|
-
---
|
166
|
+
**重要**: 提供的上下文**仅供参考**,可能不完整或已过时。在回答前,你应该**优先使用工具(如 read_code)来获取最新、最准确的信息**。
|
101
167
|
|
102
168
|
参考上下文:
|
103
169
|
---
|
@@ -122,13 +188,15 @@ class JarvisRAGPipeline:
|
|
122
188
|
由LLM生成的答案。
|
123
189
|
"""
|
124
190
|
# 1. 将原始查询重写为多个查询
|
125
|
-
rewritten_queries = self.
|
191
|
+
rewritten_queries = self._get_query_rewriter().rewrite(query_text)
|
126
192
|
|
127
193
|
# 2. 为每个重写的查询检索初始候选文档
|
128
194
|
all_candidate_docs = []
|
129
195
|
for q in rewritten_queries:
|
130
196
|
print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
|
131
|
-
candidates = self.
|
197
|
+
candidates = self._get_retriever().retrieve(
|
198
|
+
q, n_results=n_results * 2, use_bm25=self.use_bm25
|
199
|
+
)
|
132
200
|
all_candidate_docs.extend(candidates)
|
133
201
|
|
134
202
|
# 对候选文档进行去重
|
@@ -139,12 +207,13 @@ class JarvisRAGPipeline:
|
|
139
207
|
return "我在提供的文档中找不到任何相关信息来回答您的问题。"
|
140
208
|
|
141
209
|
# 3. 根据*原始*查询对统一的候选池进行重排
|
142
|
-
|
143
|
-
f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)..."
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
210
|
+
if self.use_rerank:
|
211
|
+
print(f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排(基于原始问题)...")
|
212
|
+
retrieved_docs = self._get_reranker().rerank(
|
213
|
+
query_text, unique_candidate_docs, top_n=n_results
|
214
|
+
)
|
215
|
+
else:
|
216
|
+
retrieved_docs = unique_candidate_docs[:n_results]
|
148
217
|
|
149
218
|
if not retrieved_docs:
|
150
219
|
return "我在提供的文档中找不到任何相关信息来回答您的问题。"
|
@@ -166,9 +235,49 @@ class JarvisRAGPipeline:
|
|
166
235
|
|
167
236
|
# 4. 创建最终提示并生成答案
|
168
237
|
# 我们使用原始的query_text作为给LLM的最终提示
|
169
|
-
prompt = self._create_prompt(query_text, retrieved_docs
|
238
|
+
prompt = self._create_prompt(query_text, retrieved_docs)
|
170
239
|
|
171
240
|
print("🤖 正在从LLM生成答案...")
|
172
241
|
answer = self.llm.generate(prompt)
|
173
242
|
|
174
243
|
return answer
|
244
|
+
|
245
|
+
def retrieve_only(self, query_text: str, n_results: int = 5) -> List[Document]:
|
246
|
+
"""
|
247
|
+
仅执行检索和重排,不生成答案。
|
248
|
+
|
249
|
+
参数:
|
250
|
+
query_text: 用户的原始问题。
|
251
|
+
n_results: 要检索的最终相关块的数量。
|
252
|
+
|
253
|
+
返回:
|
254
|
+
检索到的文档列表。
|
255
|
+
"""
|
256
|
+
# 1. 重写查询
|
257
|
+
rewritten_queries = self._get_query_rewriter().rewrite(query_text)
|
258
|
+
|
259
|
+
# 2. 检索候选文档
|
260
|
+
all_candidate_docs = []
|
261
|
+
for q in rewritten_queries:
|
262
|
+
print(f"🔍 正在为查询变体 '{q}' 进行混合检索...")
|
263
|
+
candidates = self._get_retriever().retrieve(
|
264
|
+
q, n_results=n_results * 2, use_bm25=self.use_bm25
|
265
|
+
)
|
266
|
+
all_candidate_docs.extend(candidates)
|
267
|
+
|
268
|
+
unique_docs_dict = {doc.page_content: doc for doc in all_candidate_docs}
|
269
|
+
unique_candidate_docs = list(unique_docs_dict.values())
|
270
|
+
|
271
|
+
if not unique_candidate_docs:
|
272
|
+
return []
|
273
|
+
|
274
|
+
# 3. 重排
|
275
|
+
if self.use_rerank:
|
276
|
+
print(f"🔍 正在对 {len(unique_candidate_docs)} 个候选文档进行重排...")
|
277
|
+
retrieved_docs = self._get_reranker().rerank(
|
278
|
+
query_text, unique_candidate_docs, top_n=n_results
|
279
|
+
)
|
280
|
+
else:
|
281
|
+
retrieved_docs = unique_candidate_docs[:n_results]
|
282
|
+
|
283
|
+
return retrieved_docs
|
jarvis/jarvis_rag/retriever.py
CHANGED
@@ -39,9 +39,7 @@ class ChromaRetriever:
|
|
39
39
|
self.collection = self.client.get_or_create_collection(
|
40
40
|
name=self.collection_name
|
41
41
|
)
|
42
|
-
print(
|
43
|
-
f"✅ ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。"
|
44
|
-
)
|
42
|
+
print(f"✅ ChromaDB 客户端已在 '{db_path}' 初始化,集合为 '{collection_name}'。")
|
45
43
|
|
46
44
|
# BM25索引设置
|
47
45
|
self.bm25_index_path = os.path.join(self.db_path, f"{collection_name}_bm25.pkl")
|
@@ -107,7 +105,9 @@ class ChromaRetriever:
|
|
107
105
|
self.bm25_index = BM25Okapi(self.bm25_corpus)
|
108
106
|
self._save_bm25_index()
|
109
107
|
|
110
|
-
def retrieve(
|
108
|
+
def retrieve(
|
109
|
+
self, query: str, n_results: int = 5, use_bm25: bool = True
|
110
|
+
) -> List[Document]:
|
111
111
|
"""
|
112
112
|
使用向量搜索和BM25执行混合检索,然后使用倒数排序融合(RRF)
|
113
113
|
对结果进行融合。
|
@@ -121,7 +121,7 @@ class ChromaRetriever:
|
|
121
121
|
|
122
122
|
# 2. 关键字搜索 (BM25)
|
123
123
|
bm25_docs = []
|
124
|
-
if self.bm25_index:
|
124
|
+
if self.bm25_index and use_bm25:
|
125
125
|
tokenized_query = query.split()
|
126
126
|
doc_scores = self.bm25_index.get_scores(tokenized_query)
|
127
127
|
|
@@ -12,7 +12,9 @@ class generate_new_tool:
|
|
12
12
|
生成并注册新的Jarvis工具。该工具会在用户数据目录下创建新的工具文件,
|
13
13
|
并自动注册到当前的工具注册表中。适用场景:1. 需要创建新的自定义工具;
|
14
14
|
2. 扩展Jarvis功能;3. 自动化重复性操作;4. 封装特定领域的功能。
|
15
|
-
|
15
|
+
重要提示:
|
16
|
+
1. `tool_name` 参数必须与 `tool_code` 中定义的 `name` 属性完全一致。
|
17
|
+
2. 在编写工具代码时,应尽量将工具执行的过程和结果打印出来,方便追踪工具的执行状态。
|
16
18
|
"""
|
17
19
|
|
18
20
|
parameters = {
|
@@ -75,6 +77,25 @@ class generate_new_tool:
|
|
75
77
|
"stderr": f"工具名称 '{tool_name}' 不是有效的Python标识符",
|
76
78
|
}
|
77
79
|
|
80
|
+
# 验证工具代码中的名称是否与tool_name一致
|
81
|
+
import re
|
82
|
+
|
83
|
+
match = re.search(r"^\s*name\s*=\s*[\"'](.+?)[\"']", tool_code, re.MULTILINE)
|
84
|
+
if not match:
|
85
|
+
return {
|
86
|
+
"success": False,
|
87
|
+
"stdout": "",
|
88
|
+
"stderr": "无法在工具代码中找到 'name' 属性。请确保工具类中包含 'name = \"your_tool_name\"'。",
|
89
|
+
}
|
90
|
+
|
91
|
+
code_name = match.group(1)
|
92
|
+
if tool_name != code_name:
|
93
|
+
return {
|
94
|
+
"success": False,
|
95
|
+
"stdout": "",
|
96
|
+
"stderr": f"工具名称不一致:参数 'tool_name' ('{tool_name}') 与代码中的 'name' 属性 ('{code_name}') 必须相同。",
|
97
|
+
}
|
98
|
+
|
78
99
|
# 准备工具目录
|
79
100
|
tools_dir = Path(get_data_dir()) / "tools"
|
80
101
|
tools_dir.mkdir(parents=True, exist_ok=True)
|
jarvis/jarvis_utils/config.py
CHANGED
@@ -115,22 +115,26 @@ def get_shell_name() -> str:
|
|
115
115
|
return os.path.basename(shell_path).lower()
|
116
116
|
|
117
117
|
|
118
|
-
def _get_resolved_model_config(
|
118
|
+
def _get_resolved_model_config(
|
119
|
+
model_group_override: Optional[str] = None,
|
120
|
+
) -> Dict[str, Any]:
|
119
121
|
"""
|
120
122
|
解析并合并模型配置,处理模型组。
|
121
123
|
|
122
124
|
优先级顺序:
|
123
125
|
1. 单独的环境变量 (JARVIS_PLATFORM, JARVIS_MODEL, etc.)
|
124
|
-
2.
|
126
|
+
2. JARVIS_LLM_GROUP 中定义的组配置
|
125
127
|
3. 代码中的默认值
|
126
128
|
|
127
129
|
返回:
|
128
130
|
Dict[str, Any]: 解析后的模型配置字典
|
129
131
|
"""
|
130
132
|
group_config = {}
|
131
|
-
model_group_name = model_group_override or GLOBAL_CONFIG_DATA.get(
|
133
|
+
model_group_name = model_group_override or GLOBAL_CONFIG_DATA.get(
|
134
|
+
"JARVIS_LLM_GROUP"
|
135
|
+
)
|
132
136
|
# The format is a list of single-key dicts: [{'group_name': {...}}, ...]
|
133
|
-
model_groups = GLOBAL_CONFIG_DATA.get("
|
137
|
+
model_groups = GLOBAL_CONFIG_DATA.get("JARVIS_LLM_GROUPS", [])
|
134
138
|
|
135
139
|
if model_group_name and isinstance(model_groups, list):
|
136
140
|
for group_item in model_groups:
|
@@ -202,7 +206,9 @@ def get_thinking_model_name(model_group_override: Optional[str] = None) -> str:
|
|
202
206
|
"""
|
203
207
|
config = _get_resolved_model_config(model_group_override)
|
204
208
|
# Fallback to normal model if thinking model is not specified
|
205
|
-
return config.get(
|
209
|
+
return config.get(
|
210
|
+
"JARVIS_THINKING_MODEL", get_normal_model_name(model_group_override)
|
211
|
+
)
|
206
212
|
|
207
213
|
|
208
214
|
def is_execute_tool_confirm() -> bool:
|
@@ -334,14 +340,65 @@ def get_mcp_config() -> List[Dict[str, Any]]:
|
|
334
340
|
# ==============================================================================
|
335
341
|
|
336
342
|
|
337
|
-
|
343
|
+
DEFAULT_RAG_GROUPS = [
|
344
|
+
{
|
345
|
+
"text": {
|
346
|
+
"embedding_model": "BAAI/bge-m3",
|
347
|
+
"rerank_model": "BAAI/bge-reranker-v2-m3",
|
348
|
+
"use_bm25": True,
|
349
|
+
"use_rerank": True,
|
350
|
+
}
|
351
|
+
},
|
352
|
+
{
|
353
|
+
"code": {
|
354
|
+
"embedding_model": "Qodo/Qodo-Embed-1-7B",
|
355
|
+
"use_bm25": False,
|
356
|
+
"use_rerank": False,
|
357
|
+
}
|
358
|
+
},
|
359
|
+
]
|
360
|
+
|
361
|
+
|
362
|
+
def _get_resolved_rag_config(
|
363
|
+
rag_group_override: Optional[str] = None,
|
364
|
+
) -> Dict[str, Any]:
|
338
365
|
"""
|
339
|
-
|
366
|
+
解析并合并RAG配置,处理RAG组。
|
367
|
+
|
368
|
+
优先级顺序:
|
369
|
+
1. JARVIS_RAG 中的顶级设置 (embedding_model, etc.)
|
370
|
+
2. JARVIS_RAG_GROUP 中定义的组配置
|
371
|
+
3. 代码中的默认值
|
340
372
|
|
341
373
|
返回:
|
342
|
-
Dict[str, Any]: RAG配置字典
|
374
|
+
Dict[str, Any]: 解析后的RAG配置字典
|
343
375
|
"""
|
344
|
-
|
376
|
+
group_config = {}
|
377
|
+
rag_group_name = rag_group_override or GLOBAL_CONFIG_DATA.get("JARVIS_RAG_GROUP")
|
378
|
+
rag_groups = GLOBAL_CONFIG_DATA.get("JARVIS_RAG_GROUPS", DEFAULT_RAG_GROUPS)
|
379
|
+
|
380
|
+
if rag_group_name and isinstance(rag_groups, list):
|
381
|
+
for group_item in rag_groups:
|
382
|
+
if isinstance(group_item, dict) and rag_group_name in group_item:
|
383
|
+
group_config = group_item[rag_group_name]
|
384
|
+
break
|
385
|
+
|
386
|
+
# Start with group config
|
387
|
+
resolved_config = group_config.copy()
|
388
|
+
|
389
|
+
# Override with specific settings from the top-level JARVIS_RAG dict
|
390
|
+
top_level_rag_config = GLOBAL_CONFIG_DATA.get("JARVIS_RAG", {})
|
391
|
+
if isinstance(top_level_rag_config, dict):
|
392
|
+
for key in [
|
393
|
+
"embedding_model",
|
394
|
+
"rerank_model",
|
395
|
+
"use_bm25",
|
396
|
+
"use_rerank",
|
397
|
+
]:
|
398
|
+
if key in top_level_rag_config:
|
399
|
+
resolved_config[key] = top_level_rag_config[key]
|
400
|
+
|
401
|
+
return resolved_config
|
345
402
|
|
346
403
|
|
347
404
|
def get_rag_embedding_model() -> str:
|
@@ -351,7 +408,8 @@ def get_rag_embedding_model() -> str:
|
|
351
408
|
返回:
|
352
409
|
str: 嵌入模型的名称
|
353
410
|
"""
|
354
|
-
|
411
|
+
config = _get_resolved_rag_config()
|
412
|
+
return config.get("embedding_model", "BAAI/bge-m3")
|
355
413
|
|
356
414
|
|
357
415
|
def get_rag_rerank_model() -> str:
|
@@ -361,7 +419,8 @@ def get_rag_rerank_model() -> str:
|
|
361
419
|
返回:
|
362
420
|
str: rerank模型的名称
|
363
421
|
"""
|
364
|
-
|
422
|
+
config = _get_resolved_rag_config()
|
423
|
+
return config.get("rerank_model", "BAAI/bge-reranker-v2-m3")
|
365
424
|
|
366
425
|
|
367
426
|
def get_rag_embedding_cache_path() -> str:
|
@@ -382,3 +441,25 @@ def get_rag_vector_db_path() -> str:
|
|
382
441
|
str: 数据库路径
|
383
442
|
"""
|
384
443
|
return ".jarvis/rag/vectordb"
|
444
|
+
|
445
|
+
|
446
|
+
def get_rag_use_bm25() -> bool:
|
447
|
+
"""
|
448
|
+
获取RAG是否使用BM25。
|
449
|
+
|
450
|
+
返回:
|
451
|
+
bool: 如果使用BM25则返回True,默认为True
|
452
|
+
"""
|
453
|
+
config = _get_resolved_rag_config()
|
454
|
+
return config.get("use_bm25", True) is True
|
455
|
+
|
456
|
+
|
457
|
+
def get_rag_use_rerank() -> bool:
|
458
|
+
"""
|
459
|
+
获取RAG是否使用rerank。
|
460
|
+
|
461
|
+
返回:
|
462
|
+
bool: 如果使用rerank则返回True,默认为True
|
463
|
+
"""
|
464
|
+
config = _get_resolved_rag_config()
|
465
|
+
return config.get("use_rerank", True) is True
|
jarvis/jarvis_utils/globals.py
CHANGED
@@ -9,9 +9,11 @@
|
|
9
9
|
"""
|
10
10
|
import os
|
11
11
|
|
12
|
-
#
|
13
|
-
|
14
|
-
|
12
|
+
# 全局变量:保存消息历史
|
13
|
+
from typing import Any, Dict, Set, List
|
14
|
+
|
15
|
+
message_history: List[str] = []
|
16
|
+
MAX_HISTORY_SIZE = 50
|
15
17
|
|
16
18
|
import colorama
|
17
19
|
from rich.console import Console
|
@@ -169,13 +171,18 @@ def get_interrupt() -> int:
|
|
169
171
|
|
170
172
|
def set_last_message(message: str) -> None:
|
171
173
|
"""
|
172
|
-
|
174
|
+
将消息添加到历史记录中。
|
173
175
|
|
174
176
|
参数:
|
175
177
|
message: 要保存的消息
|
176
178
|
"""
|
177
|
-
global
|
178
|
-
|
179
|
+
global message_history
|
180
|
+
if message:
|
181
|
+
# 避免重复添加
|
182
|
+
if not message_history or message_history[-1] != message:
|
183
|
+
message_history.append(message)
|
184
|
+
if len(message_history) > MAX_HISTORY_SIZE:
|
185
|
+
message_history.pop(0)
|
179
186
|
|
180
187
|
|
181
188
|
def get_last_message() -> str:
|
@@ -183,6 +190,20 @@ def get_last_message() -> str:
|
|
183
190
|
获取最后一条消息。
|
184
191
|
|
185
192
|
返回:
|
186
|
-
str:
|
193
|
+
str: 最后一条消息,如果历史记录为空则返回空字符串
|
194
|
+
"""
|
195
|
+
global message_history
|
196
|
+
if message_history:
|
197
|
+
return message_history[-1]
|
198
|
+
return ""
|
199
|
+
|
200
|
+
|
201
|
+
def get_message_history() -> List[str]:
|
202
|
+
"""
|
203
|
+
获取完整的消息历史记录。
|
204
|
+
|
205
|
+
返回:
|
206
|
+
List[str]: 消息历史列表
|
187
207
|
"""
|
188
|
-
|
208
|
+
global message_history
|
209
|
+
return message_history
|