jarvis-ai-assistant 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/prompts.py +26 -4
- jarvis/jarvis_data/config_schema.json +67 -12
- jarvis/jarvis_platform/tongyi.py +9 -9
- jarvis/jarvis_rag/cli.py +79 -23
- jarvis/jarvis_rag/query_rewriter.py +61 -12
- jarvis/jarvis_rag/rag_pipeline.py +143 -34
- jarvis/jarvis_rag/retriever.py +5 -5
- jarvis/jarvis_tools/generate_new_tool.py +22 -1
- jarvis/jarvis_utils/config.py +92 -11
- jarvis/jarvis_utils/globals.py +29 -8
- jarvis/jarvis_utils/input.py +114 -121
- jarvis/jarvis_utils/utils.py +3 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/METADATA +82 -9
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/RECORD +19 -19
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.2.2.dist-info → jarvis_ai_assistant-0.2.3.dist-info}/top_level.txt +0 -0
jarvis/__init__.py
CHANGED
jarvis/jarvis_agent/prompts.py
CHANGED
@@ -113,6 +113,31 @@ TASK_ANALYSIS_PROMPT = f"""<task_analysis>
|
|
113
113
|
"stderr": f"操作失败: {{str(e)}}"
|
114
114
|
}}
|
115
115
|
```
|
116
|
+
4. **在工具中调用大模型**:如果工具需要调用大模型来完成子任务(例如,生成代码、分析文本等),为了避免干扰主对话流程,建议创建一个独立的大模型实例。
|
117
|
+
```python
|
118
|
+
# 通过 agent 实例获取模型配置
|
119
|
+
agent = args.get("agent")
|
120
|
+
if not agent:
|
121
|
+
return {{"success": False, "stderr": "Agent not found."}}
|
122
|
+
|
123
|
+
current_model = agent.model
|
124
|
+
platform_name = current_model.platform_name()
|
125
|
+
model_name = current_model.name()
|
126
|
+
|
127
|
+
# 创建独立的模型实例
|
128
|
+
from jarvis.jarvis_platform.registry import PlatformRegistry
|
129
|
+
llm = PlatformRegistry().create_platform(platform_name)
|
130
|
+
if not llm:
|
131
|
+
return {{"success": False, "stderr": f"Platform {{platform_name}} not found."}}
|
132
|
+
|
133
|
+
llm.set_model_name(model_name)
|
134
|
+
llm.set_suppress_output(True) # 工具内的调用通常不需要流式输出
|
135
|
+
|
136
|
+
# 使用新实例调用大模型
|
137
|
+
PrettyOutput.print("正在执行子任务...", OutputType.INFO)
|
138
|
+
response = llm.chat_until_success("你的提示词")
|
139
|
+
PrettyOutput.print("子任务完成", OutputType.SUCCESS)
|
140
|
+
```
|
116
141
|
</tool_requirements>
|
117
142
|
<methodology_requirements>
|
118
143
|
方法论格式要求:
|
@@ -139,10 +164,7 @@ arguments:
|
|
139
164
|
from jarvis.jarvis_utils.output import PrettyOutput, OutputType
|
140
165
|
class 工具名称:
|
141
166
|
name = "工具名称"
|
142
|
-
description = "Tool
|
143
|
-
Tool description
|
144
|
-
适用场景:1. 格式化文本; 2. 处理标题; 3. 标准化输出
|
145
|
-
\"\"\"
|
167
|
+
description = "Tool description"
|
146
168
|
parameters = {{
|
147
169
|
"type": "object",
|
148
170
|
"properties": {{
|
@@ -141,38 +141,47 @@
|
|
141
141
|
"description": "思考操作模型名称",
|
142
142
|
"default": "deep_seek"
|
143
143
|
},
|
144
|
-
"
|
144
|
+
"JARVIS_LLM_GROUP": {
|
145
145
|
"type": "string",
|
146
|
-
"description": "选择一个预定义的模型组"
|
146
|
+
"description": "选择一个预定义的模型组",
|
147
|
+
"default": ""
|
147
148
|
},
|
148
|
-
"
|
149
|
+
"JARVIS_LLM_GROUPS": {
|
149
150
|
"type": "array",
|
150
151
|
"description": "预定义的模型配置组",
|
152
|
+
"default": [],
|
151
153
|
"items": {
|
152
154
|
"type": "object",
|
153
155
|
"additionalProperties": {
|
154
156
|
"type": "object",
|
155
157
|
"properties": {
|
156
158
|
"JARVIS_PLATFORM": {
|
157
|
-
"type": "string"
|
159
|
+
"type": "string",
|
160
|
+
"default": "yuanbao"
|
158
161
|
},
|
159
162
|
"JARVIS_MODEL": {
|
160
|
-
"type": "string"
|
163
|
+
"type": "string",
|
164
|
+
"default": "deep_seek_v3"
|
161
165
|
},
|
162
166
|
"JARVIS_THINKING_PLATFORM": {
|
163
|
-
"type": "string"
|
167
|
+
"type": "string",
|
168
|
+
"default": "yuanbao"
|
164
169
|
},
|
165
170
|
"JARVIS_THINKING_MODEL": {
|
166
|
-
"type": "string"
|
171
|
+
"type": "string",
|
172
|
+
"default": "deep_seek"
|
167
173
|
},
|
168
174
|
"JARVIS_MAX_TOKEN_COUNT": {
|
169
|
-
"type": "number"
|
175
|
+
"type": "number",
|
176
|
+
"default": 960000
|
170
177
|
},
|
171
178
|
"JARVIS_MAX_INPUT_TOKEN_COUNT": {
|
172
|
-
"type": "number"
|
179
|
+
"type": "number",
|
180
|
+
"default": 32000
|
173
181
|
},
|
174
182
|
"JARVIS_MAX_BIG_CONTENT_SIZE": {
|
175
|
-
"type": "number"
|
183
|
+
"type": "number",
|
184
|
+
"default": 160000
|
176
185
|
}
|
177
186
|
},
|
178
187
|
"required": [
|
@@ -235,9 +244,43 @@
|
|
235
244
|
"description": "是否启用静态代码分析",
|
236
245
|
"default": true
|
237
246
|
},
|
247
|
+
"JARVIS_RAG_GROUP": {
|
248
|
+
"type": "string",
|
249
|
+
"description": "选择一个预定义的RAG配置组",
|
250
|
+
"default": ""
|
251
|
+
},
|
252
|
+
"JARVIS_RAG_GROUPS": {
|
253
|
+
"type": "array",
|
254
|
+
"description": "预定义的RAG配置组",
|
255
|
+
"default": [],
|
256
|
+
"items": {
|
257
|
+
"type": "object",
|
258
|
+
"additionalProperties": {
|
259
|
+
"type": "object",
|
260
|
+
"properties": {
|
261
|
+
"embedding_model": {
|
262
|
+
"type": "string",
|
263
|
+
"default": "BAAI/bge-base-zh-v1.5"
|
264
|
+
},
|
265
|
+
"rerank_model": {
|
266
|
+
"type": "string",
|
267
|
+
"default": "BAAI/bge-reranker-base"
|
268
|
+
},
|
269
|
+
"use_bm25": {
|
270
|
+
"type": "boolean",
|
271
|
+
"default": true
|
272
|
+
},
|
273
|
+
"use_rerank": {
|
274
|
+
"type": "boolean",
|
275
|
+
"default": true
|
276
|
+
}
|
277
|
+
}
|
278
|
+
}
|
279
|
+
}
|
280
|
+
},
|
238
281
|
"JARVIS_RAG": {
|
239
282
|
"type": "object",
|
240
|
-
"description": "RAG
|
283
|
+
"description": "RAG框架的顶层配置。注意:此处的设置将覆盖任何由JARVIS_RAG_GROUP选择的组配置。",
|
241
284
|
"properties": {
|
242
285
|
"embedding_model": {
|
243
286
|
"type": "string",
|
@@ -248,11 +291,23 @@
|
|
248
291
|
"type": "string",
|
249
292
|
"default": "BAAI/bge-reranker-base",
|
250
293
|
"description": "用于RAG的rerank模型的名称, 默认为 'BAAI/bge-reranker-base'"
|
294
|
+
},
|
295
|
+
"use_bm25": {
|
296
|
+
"type": "boolean",
|
297
|
+
"default": true,
|
298
|
+
"description": "是否在RAG中为检索使用BM25, 默认为 true"
|
299
|
+
},
|
300
|
+
"use_rerank": {
|
301
|
+
"type": "boolean",
|
302
|
+
"default": true,
|
303
|
+
"description": "是否在RAG中为检索使用rerank, 默认为 true"
|
251
304
|
}
|
252
305
|
},
|
253
306
|
"default": {
|
254
307
|
"embedding_model": "BAAI/bge-base-zh-v1.5",
|
255
|
-
"rerank_model": "BAAI/bge-reranker-base"
|
308
|
+
"rerank_model": "BAAI/bge-reranker-base",
|
309
|
+
"use_bm25": true,
|
310
|
+
"use_rerank": true
|
256
311
|
}
|
257
312
|
},
|
258
313
|
"JARVIS_REPLACE_MAP": {
|
jarvis/jarvis_platform/tongyi.py
CHANGED
@@ -81,10 +81,10 @@ class TongyiPlatform(BasePlatform):
|
|
81
81
|
"contentType": "text",
|
82
82
|
"role": "user",
|
83
83
|
"ext": {
|
84
|
-
"searchType": "",
|
84
|
+
"searchType": "depth" if self.web else "",
|
85
85
|
"pptGenerate": False,
|
86
|
-
"deepThink":
|
87
|
-
"deepResearch":
|
86
|
+
"deepThink": self.model_name == "Thinking",
|
87
|
+
"deepResearch": self.model_name == "Deep-Research",
|
88
88
|
},
|
89
89
|
}
|
90
90
|
]
|
@@ -98,10 +98,10 @@ class TongyiPlatform(BasePlatform):
|
|
98
98
|
"contentType": "text",
|
99
99
|
"role": "system",
|
100
100
|
"ext": {
|
101
|
-
"searchType": "",
|
101
|
+
"searchType": "depth" if self.web else "",
|
102
102
|
"pptGenerate": False,
|
103
|
-
"deepThink":
|
104
|
-
"deepResearch":
|
103
|
+
"deepThink": self.model_name == "Thinking",
|
104
|
+
"deepResearch": self.model_name == "Deep-Research",
|
105
105
|
},
|
106
106
|
},
|
107
107
|
)
|
@@ -140,13 +140,13 @@ class TongyiPlatform(BasePlatform):
|
|
140
140
|
"parentMsgId": self.msg_id,
|
141
141
|
"params": {
|
142
142
|
"agentId": "",
|
143
|
-
"searchType": "",
|
143
|
+
"searchType": "depth" if self.web else "",
|
144
144
|
"pptGenerate": False,
|
145
145
|
"bizScene": "code_chat" if self.model_name == "Code-Chat" else "",
|
146
146
|
"bizSceneInfo": {},
|
147
147
|
"specifiedModel": "",
|
148
|
-
"deepThink":
|
149
|
-
"deepResearch":
|
148
|
+
"deepThink": self.model_name == "Thinking",
|
149
|
+
"deepResearch": self.model_name == "Deep-Research",
|
150
150
|
"fileUploadBatchId": (
|
151
151
|
self.uploaded_file_info[0]["batchId"]
|
152
152
|
if self.uploaded_file_info
|
jarvis/jarvis_rag/cli.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
import sys
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Optional, List, Literal, cast
|
4
|
+
from typing import Optional, List, Literal, cast, Tuple
|
5
5
|
import mimetypes
|
6
6
|
|
7
7
|
import pathspec # type: ignore
|
@@ -15,6 +15,11 @@ from langchain_core.document_loaders.base import BaseLoader
|
|
15
15
|
from rich.markdown import Markdown
|
16
16
|
|
17
17
|
from jarvis.jarvis_utils.utils import init_env
|
18
|
+
from jarvis.jarvis_utils.config import (
|
19
|
+
get_rag_embedding_model,
|
20
|
+
get_rag_use_bm25,
|
21
|
+
get_rag_use_rerank,
|
22
|
+
)
|
18
23
|
|
19
24
|
|
20
25
|
def is_likely_text_file(file_path: Path) -> bool:
|
@@ -65,9 +70,7 @@ class _CustomPlatformLLM(LLMInterface):
|
|
65
70
|
|
66
71
|
def __init__(self, platform: BasePlatform):
|
67
72
|
self.platform = platform
|
68
|
-
print(
|
69
|
-
f"✅ 使用自定义LLM: 平台='{platform.platform_name()}', 模型='{platform.name()}'"
|
70
|
-
)
|
73
|
+
print(f"✅ 使用自定义LLM: 平台='{platform.platform_name()}', 模型='{platform.name()}'")
|
71
74
|
|
72
75
|
def generate(self, prompt: str, **kwargs) -> str:
|
73
76
|
return self.platform.chat_until_success(prompt)
|
@@ -91,7 +94,7 @@ def _create_custom_llm(platform_name: str, model_name: str) -> Optional[LLMInter
|
|
91
94
|
return None
|
92
95
|
|
93
96
|
|
94
|
-
def _load_ragignore_spec() ->
|
97
|
+
def _load_ragignore_spec() -> Tuple[Optional[pathspec.PathSpec], Optional[Path]]:
|
95
98
|
"""
|
96
99
|
从项目根目录加载忽略模式。
|
97
100
|
首先查找 `.jarvis/rag/.ragignore`,如果未找到,则回退到 `.gitignore`。
|
@@ -140,9 +143,7 @@ def add_documents(
|
|
140
143
|
"-e",
|
141
144
|
help="嵌入模型的名称。覆盖全局配置。",
|
142
145
|
),
|
143
|
-
db_path: Optional[Path] = typer.Option(
|
144
|
-
None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
|
145
|
-
),
|
146
|
+
db_path: Optional[Path] = typer.Option(None, "--db-path", help="向量数据库的路径。覆盖全局配置。"),
|
146
147
|
batch_size: int = typer.Option(
|
147
148
|
500,
|
148
149
|
"--batch-size",
|
@@ -244,9 +245,7 @@ def add_documents(
|
|
244
245
|
print("❌ 未能成功加载任何文档。")
|
245
246
|
raise typer.Exit(code=1)
|
246
247
|
|
247
|
-
print(
|
248
|
-
f"✅ 成功将 {total_docs_added} 个文档的内容添加至集合 '{collection_name}'。"
|
249
|
-
)
|
248
|
+
print(f"✅ 成功将 {total_docs_added} 个文档的内容添加至集合 '{collection_name}'。")
|
250
249
|
|
251
250
|
except Exception as e:
|
252
251
|
print(f"❌ 发生严重错误: {e}")
|
@@ -261,9 +260,7 @@ def list_documents(
|
|
261
260
|
"-c",
|
262
261
|
help="向量数据库中集合的名称。",
|
263
262
|
),
|
264
|
-
db_path: Optional[Path] = typer.Option(
|
265
|
-
None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
|
266
|
-
),
|
263
|
+
db_path: Optional[Path] = typer.Option(None, "--db-path", help="向量数据库的路径。覆盖全局配置。"),
|
267
264
|
):
|
268
265
|
"""列出指定集合中的所有唯一文档。"""
|
269
266
|
try:
|
@@ -272,7 +269,7 @@ def list_documents(
|
|
272
269
|
collection_name=collection_name,
|
273
270
|
)
|
274
271
|
|
275
|
-
collection = pipeline.
|
272
|
+
collection = pipeline._get_collection()
|
276
273
|
results = collection.get() # 获取集合中的所有项目
|
277
274
|
|
278
275
|
if not results or not results["metadatas"]:
|
@@ -300,6 +297,63 @@ def list_documents(
|
|
300
297
|
raise typer.Exit(code=1)
|
301
298
|
|
302
299
|
|
300
|
+
@app.command("retrieve", help="仅从知识库检索相关文档,不生成答案。")
|
301
|
+
def retrieve(
|
302
|
+
question: str = typer.Argument(..., help="要提出的问题。"),
|
303
|
+
collection_name: str = typer.Option(
|
304
|
+
"jarvis_rag_collection",
|
305
|
+
"--collection",
|
306
|
+
"-c",
|
307
|
+
help="向量数据库中集合的名称。",
|
308
|
+
),
|
309
|
+
embedding_model: Optional[str] = typer.Option(
|
310
|
+
None,
|
311
|
+
"--embedding-model",
|
312
|
+
"-e",
|
313
|
+
help="嵌入模型的名称。覆盖全局配置。",
|
314
|
+
),
|
315
|
+
db_path: Optional[Path] = typer.Option(None, "--db-path", help="向量数据库的路径。覆盖全局配置。"),
|
316
|
+
n_results: int = typer.Option(5, "--top-n", help="要检索的文档数量。"),
|
317
|
+
):
|
318
|
+
"""仅从RAG知识库检索文档并打印结果。"""
|
319
|
+
try:
|
320
|
+
# 如果未在命令行中指定,则从配置中加载RAG设置
|
321
|
+
final_embedding_model = embedding_model or get_rag_embedding_model()
|
322
|
+
use_bm25 = get_rag_use_bm25()
|
323
|
+
use_rerank = get_rag_use_rerank()
|
324
|
+
|
325
|
+
pipeline = JarvisRAGPipeline(
|
326
|
+
embedding_model=final_embedding_model,
|
327
|
+
db_path=str(db_path) if db_path else None,
|
328
|
+
collection_name=collection_name,
|
329
|
+
use_bm25=use_bm25,
|
330
|
+
use_rerank=use_rerank,
|
331
|
+
)
|
332
|
+
|
333
|
+
print(f"🤔 正在为问题检索文档: '{question}'")
|
334
|
+
retrieved_docs = pipeline.retrieve_only(question, n_results=n_results)
|
335
|
+
|
336
|
+
if not retrieved_docs:
|
337
|
+
print("ℹ️ 未找到相关文档。")
|
338
|
+
return
|
339
|
+
|
340
|
+
print(f"✅ 成功检索到 {len(retrieved_docs)} 个文档:")
|
341
|
+
from jarvis.jarvis_utils.globals import console
|
342
|
+
|
343
|
+
for i, doc in enumerate(retrieved_docs, 1):
|
344
|
+
source = doc.metadata.get("source", "未知来源")
|
345
|
+
content = doc.page_content
|
346
|
+
panel_title = f"文档 {i} | 来源: {source}"
|
347
|
+
console.print(
|
348
|
+
f"\n[bold magenta]{panel_title}[/bold magenta]"
|
349
|
+
)
|
350
|
+
console.print(Markdown(f"```\n{content}\n```"))
|
351
|
+
|
352
|
+
except Exception as e:
|
353
|
+
print(f"❌ 发生错误: {e}")
|
354
|
+
raise typer.Exit(code=1)
|
355
|
+
|
356
|
+
|
303
357
|
@app.command("query", help="向知识库提问。")
|
304
358
|
def query(
|
305
359
|
question: str = typer.Argument(..., help="要提出的问题。"),
|
@@ -315,9 +369,7 @@ def query(
|
|
315
369
|
"-e",
|
316
370
|
help="嵌入模型的名称。覆盖全局配置。",
|
317
371
|
),
|
318
|
-
db_path: Optional[Path] = typer.Option(
|
319
|
-
None, "--db-path", help="向量数据库的路径。覆盖全局配置。"
|
320
|
-
),
|
372
|
+
db_path: Optional[Path] = typer.Option(None, "--db-path", help="向量数据库的路径。覆盖全局配置。"),
|
321
373
|
platform: Optional[str] = typer.Option(
|
322
374
|
None,
|
323
375
|
"--platform",
|
@@ -341,11 +393,18 @@ def query(
|
|
341
393
|
if (platform or model) and not custom_llm:
|
342
394
|
raise typer.Exit(code=1)
|
343
395
|
|
396
|
+
# 如果未在命令行中指定,则从配置中加载RAG设置
|
397
|
+
final_embedding_model = embedding_model or get_rag_embedding_model()
|
398
|
+
use_bm25 = get_rag_use_bm25()
|
399
|
+
use_rerank = get_rag_use_rerank()
|
400
|
+
|
344
401
|
pipeline = JarvisRAGPipeline(
|
345
402
|
llm=custom_llm,
|
346
|
-
embedding_model=
|
403
|
+
embedding_model=final_embedding_model,
|
347
404
|
db_path=str(db_path) if db_path else None,
|
348
405
|
collection_name=collection_name,
|
406
|
+
use_bm25=use_bm25,
|
407
|
+
use_rerank=use_rerank,
|
349
408
|
)
|
350
409
|
|
351
410
|
print(f"🤔 正在查询: '{question}'")
|
@@ -373,10 +432,7 @@ except ImportError:
|
|
373
432
|
|
374
433
|
def _check_rag_dependencies():
|
375
434
|
if not _RAG_INSTALLED:
|
376
|
-
print(
|
377
|
-
"❌ RAG依赖项未安装。"
|
378
|
-
"请运行 'pip install \"jarvis-ai-assistant[rag]\"' 来使用此命令。"
|
379
|
-
)
|
435
|
+
print("❌ RAG依赖项未安装。" "请运行 'pip install \"jarvis-ai-assistant[rag]\"' 来使用此命令。")
|
380
436
|
raise typer.Exit(code=1)
|
381
437
|
|
382
438
|
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from typing import List
|
2
2
|
from .llm_interface import LLMInterface
|
3
|
+
from jarvis.jarvis_utils.output import PrettyOutput, OutputType
|
3
4
|
|
4
5
|
|
5
6
|
class QueryRewriter:
|
@@ -20,20 +21,29 @@ class QueryRewriter:
|
|
20
21
|
def _create_prompt_template(self) -> str:
|
21
22
|
"""为多查询重写任务创建提示模板。"""
|
22
23
|
return """
|
23
|
-
|
24
|
+
你是一个精通检索和语言的AI助手。你的任务是将以下这个单一的用户问题,改写为几个语义相关但表达方式不同的搜索查询,并提供英文翻译。这有助于在多语言知识库中进行更全面的搜索。
|
24
25
|
|
25
26
|
请遵循以下原则:
|
26
|
-
1.
|
27
|
-
2.
|
28
|
-
|
29
|
-
|
27
|
+
1. **保留核心意图**: 所有查询都必须围绕原始问题的核心意图。
|
28
|
+
2. **查询类型**:
|
29
|
+
- **同义词/相关术语查询**: 使用原始语言,通过替换同义词或相关术语来生成1-2个新的查询。
|
30
|
+
- **英文翻译查询**: 将原始问题翻译成一个简洁的英文搜索查询。
|
31
|
+
3. **简洁性**: 每个查询都应该是独立的、可以直接用于搜索的短语或问题。
|
32
|
+
4. **严格格式要求**: 你必须将所有重写后的查询放置在 `<REWRITE>` 和 `</REWRITE>` 标签之间。每个查询占一行。不要在标签内外添加任何编号、前缀或解释。
|
33
|
+
|
34
|
+
示例输出格式:
|
35
|
+
<REWRITE>
|
36
|
+
使用不同表述的中文查询
|
37
|
+
另一个中文查询
|
38
|
+
English version of the query
|
39
|
+
</REWRITE>
|
30
40
|
|
31
41
|
原始问题:
|
32
42
|
---
|
33
43
|
{query}
|
34
44
|
---
|
35
45
|
|
36
|
-
|
46
|
+
请将改写后的查询包裹在 `<REWRITE>` 标签内:
|
37
47
|
"""
|
38
48
|
|
39
49
|
def rewrite(self, query: str) -> List[str]:
|
@@ -47,16 +57,55 @@ class QueryRewriter:
|
|
47
57
|
一个经过重写、搜索优化的查询列表。
|
48
58
|
"""
|
49
59
|
prompt = self.rewrite_prompt_template.format(query=query)
|
50
|
-
print(
|
60
|
+
PrettyOutput.print(
|
61
|
+
"正在将原始查询重写为多个搜索查询...", output_type=OutputType.INFO, timestamp=False
|
62
|
+
)
|
63
|
+
|
64
|
+
import re
|
65
|
+
|
66
|
+
max_retries = 3
|
67
|
+
attempts = 0
|
68
|
+
rewritten_queries = []
|
69
|
+
response_text = ""
|
70
|
+
|
71
|
+
while attempts < max_retries:
|
72
|
+
attempts += 1
|
73
|
+
response_text = self.llm.generate(prompt)
|
74
|
+
match = re.search(r"<REWRITE>(.*?)</REWRITE>", response_text, re.DOTALL)
|
75
|
+
|
76
|
+
if match:
|
77
|
+
content = match.group(1).strip()
|
78
|
+
rewritten_queries = [
|
79
|
+
line.strip() for line in content.split("\n") if line.strip()
|
80
|
+
]
|
81
|
+
PrettyOutput.print(
|
82
|
+
f"成功从LLM响应中提取到内容 (尝试 {attempts}/{max_retries})。",
|
83
|
+
output_type=OutputType.SUCCESS,
|
84
|
+
timestamp=False,
|
85
|
+
)
|
86
|
+
break # 提取成功,退出循环
|
87
|
+
else:
|
88
|
+
PrettyOutput.print(
|
89
|
+
f"未能从LLM响应中提取内容。正在重试... ({attempts}/{max_retries})",
|
90
|
+
output_type=OutputType.WARNING,
|
91
|
+
timestamp=False,
|
92
|
+
)
|
51
93
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
94
|
+
# 如果所有重试都失败,则跳过重写步骤
|
95
|
+
if not rewritten_queries:
|
96
|
+
PrettyOutput.print(
|
97
|
+
"所有重试均失败。跳过查询重写,将仅使用原始查询。",
|
98
|
+
output_type=OutputType.ERROR,
|
99
|
+
timestamp=False,
|
100
|
+
)
|
56
101
|
|
57
102
|
# 同时包含原始查询以保证鲁棒性
|
58
103
|
if query not in rewritten_queries:
|
59
104
|
rewritten_queries.insert(0, query)
|
60
105
|
|
61
|
-
print(
|
106
|
+
PrettyOutput.print(
|
107
|
+
f"生成了 {len(rewritten_queries)} 个查询变体。",
|
108
|
+
output_type=OutputType.SUCCESS,
|
109
|
+
timestamp=False,
|
110
|
+
)
|
62
111
|
return rewritten_queries
|