isage_middleware-0.2.4.3-cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
- isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
- isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
- isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
- sage/middleware/__init__.py +59 -0
- sage/middleware/_version.py +6 -0
- sage/middleware/components/__init__.py +30 -0
- sage/middleware/components/extensions_compat.py +141 -0
- sage/middleware/components/sage_db/__init__.py +116 -0
- sage/middleware/components/sage_db/backend.py +136 -0
- sage/middleware/components/sage_db/service.py +15 -0
- sage/middleware/components/sage_flow/__init__.py +76 -0
- sage/middleware/components/sage_flow/python/__init__.py +14 -0
- sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
- sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
- sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
- sage/middleware/components/sage_flow/service.py +14 -0
- sage/middleware/components/sage_mem/__init__.py +83 -0
- sage/middleware/components/sage_sias/__init__.py +59 -0
- sage/middleware/components/sage_sias/continual_learner.py +184 -0
- sage/middleware/components/sage_sias/coreset_selector.py +302 -0
- sage/middleware/components/sage_sias/types.py +94 -0
- sage/middleware/components/sage_tsdb/__init__.py +81 -0
- sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
- sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
- sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
- sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
- sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
- sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
- sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
- sage/middleware/components/sage_tsdb/service.py +17 -0
- sage/middleware/components/vector_stores/__init__.py +25 -0
- sage/middleware/components/vector_stores/chroma.py +483 -0
- sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
- sage/middleware/components/vector_stores/milvus.py +677 -0
- sage/middleware/operators/__init__.py +56 -0
- sage/middleware/operators/agent/__init__.py +24 -0
- sage/middleware/operators/agent/planning/__init__.py +5 -0
- sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
- sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
- sage/middleware/operators/agent/planning/router.py +107 -0
- sage/middleware/operators/agent/runtime.py +296 -0
- sage/middleware/operators/agentic/__init__.py +41 -0
- sage/middleware/operators/agentic/config.py +254 -0
- sage/middleware/operators/agentic/planning_operator.py +125 -0
- sage/middleware/operators/agentic/refined_searcher.py +132 -0
- sage/middleware/operators/agentic/runtime.py +241 -0
- sage/middleware/operators/agentic/timing_operator.py +125 -0
- sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
- sage/middleware/operators/context/__init__.py +17 -0
- sage/middleware/operators/context/critic_evaluation.py +16 -0
- sage/middleware/operators/context/model_context.py +565 -0
- sage/middleware/operators/context/quality_label.py +12 -0
- sage/middleware/operators/context/search_query_results.py +61 -0
- sage/middleware/operators/context/search_result.py +42 -0
- sage/middleware/operators/context/search_session.py +79 -0
- sage/middleware/operators/filters/__init__.py +26 -0
- sage/middleware/operators/filters/context_sink.py +387 -0
- sage/middleware/operators/filters/context_source.py +376 -0
- sage/middleware/operators/filters/evaluate_filter.py +83 -0
- sage/middleware/operators/filters/tool_filter.py +74 -0
- sage/middleware/operators/llm/__init__.py +18 -0
- sage/middleware/operators/llm/sagellm_generator.py +432 -0
- sage/middleware/operators/rag/__init__.py +147 -0
- sage/middleware/operators/rag/arxiv.py +331 -0
- sage/middleware/operators/rag/chunk.py +13 -0
- sage/middleware/operators/rag/document_loaders.py +23 -0
- sage/middleware/operators/rag/evaluate.py +658 -0
- sage/middleware/operators/rag/generator.py +340 -0
- sage/middleware/operators/rag/index_builder/__init__.py +48 -0
- sage/middleware/operators/rag/index_builder/builder.py +363 -0
- sage/middleware/operators/rag/index_builder/manifest.py +101 -0
- sage/middleware/operators/rag/index_builder/storage.py +131 -0
- sage/middleware/operators/rag/pipeline.py +46 -0
- sage/middleware/operators/rag/profiler.py +59 -0
- sage/middleware/operators/rag/promptor.py +400 -0
- sage/middleware/operators/rag/refiner.py +231 -0
- sage/middleware/operators/rag/reranker.py +364 -0
- sage/middleware/operators/rag/retriever.py +1308 -0
- sage/middleware/operators/rag/searcher.py +37 -0
- sage/middleware/operators/rag/types.py +28 -0
- sage/middleware/operators/rag/writer.py +80 -0
- sage/middleware/operators/tools/__init__.py +71 -0
- sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
- sage/middleware/operators/tools/arxiv_searcher.py +102 -0
- sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
- sage/middleware/operators/tools/image_captioner.py +104 -0
- sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
- sage/middleware/operators/tools/searcher_tool.py +514 -0
- sage/middleware/operators/tools/text_detector.py +185 -0
- sage/middleware/operators/tools/url_text_extractor.py +104 -0
- sage/middleware/py.typed +2 -0

sage/middleware/operators/rag/promptor.py
@@ -0,0 +1,400 @@
import json
import os
import time

from jinja2 import Template

from sage.common.core.functions import MapFunction as MapOperator

QA_prompt_template_str = """Instruction:
You are an intelligent assistant with access to a knowledge base. Answer the question below with reference to the provided context.
Only give me the answer and do not output any other words.
{%- if external_corpus %}
Relevant corpus for the current question:
{{ external_corpus }}
{%- endif %}
"""

QA_short_answer_template_str = """Instruction:
You are an intelligent assistant with access to a knowledge base. Answer the question below with reference to the provided context.
Please provide a concise answer and conclude with 'So the final answer is: [your answer]'.
{%- if external_corpus %}
Relevant corpus for the current question:
{{ external_corpus }}
{%- endif %}
"""

summarization_prompt_template_str = """Instruction:
You are an intelligent assistant. Summarize the content provided below in a concise and clear manner.
Only provide the summary and do not include any additional information.
{%- if external_corpus %}
Content to summarize:
{{ external_corpus }}
{%- endif %}
"""
QA_prompt_template = Template(QA_prompt_template_str)
QA_short_answer_template = Template(QA_short_answer_template_str)
summarization_prompt_template = Template(summarization_prompt_template_str)

query_profiler_prompt_template_str = """
For the given query = how did Trump earn his first 1 million dollars?: Analyze the language and internal structure of the query and provide the following information:

1. Does it need joint reasoning across multiple documents?
2. Provide a complexity profile for the query:
   - Complexity: High / Low
   - Joint Reasoning needed: Yes / No
3. Does this query need input chunks to be summarized? If yes, provide a range in words for the summarized chunks.
4. How many distinct pieces of information are needed to answer the query?

database_metadata = The dataset consists of multiple chunks of information from Fortune 500 companies on financial reports from every quarter of 2023.
chunk_size = 1024

Estimate the query profile along with the database_metadata and chunk_size.

Your output must be:
- **Only a valid JSON object**
- **No explanations, no formatting, no comments**
- **No markdown code blocks or prose**
- **Strictly conform to this schema:**

{
"need_joint_reasoning": <true|false>,
"complexity": "High" or "Low",
"need_summarization": <true|false>,
"summarization_length": integer (30-200),
"n_info_items": integer (1-6)
}
"""
query_profiler_prompt_template = Template(query_profiler_prompt_template_str)


class QAPromptor(MapOperator):
    """
    QAPromptor is a prompt operator that generates a QA-style prompt from
    an external corpus and a user query. This class is designed to prepare
    the necessary prompt structure for a question-answering model.

    Attributes:
        config: Configuration data for initializing the prompt operator (e.g., model details, etc.).
        prompt_template: A template used for generating the system prompt; typically includes context or instructions.
    """

    prompt_template: Template

    def __init__(self, config, enable_profile=False, **kwargs):
        super().__init__(**kwargs)

        """
        Initializes the QAPromptor instance with configuration and prompt template.

        :param config: Dictionary containing configuration for the prompt operator.
        """
        self.config = config  # Store the configuration for later use
        self.enable_profile = enable_profile

        # Use the template from the config if present, otherwise fall back to the defaults
        self.use_short_answer = config.get("use_short_answer", False)  # Whether to use short-answer mode

        if "template" in config:
            from jinja2 import Template

            self.prompt_template = Template(config["template"])
        else:
            # Choose a template based on the configuration
            if self.use_short_answer:
                self.prompt_template = QA_short_answer_template
            else:
                self.prompt_template = QA_prompt_template  # Load the QA prompt template

        # Only set up the data storage path when profiling is enabled
        if self.enable_profile:
            from sage.common.config.output_paths import get_sage_paths

            try:
                sage_paths = get_sage_paths()
                self.data_base_path = str(sage_paths.states_dir / "promptor_data")
            except Exception:
                # Fallback to current working directory
                if (
                    self.ctx is not None
                    and hasattr(self.ctx, "env_base_dir")
                    and self.ctx.env_base_dir
                ):
                    self.data_base_path = os.path.join(
                        self.ctx.env_base_dir, ".sage_states", "promptor_data"
                    )
                else:
                    # Use the default path
                    self.data_base_path = os.path.join(os.getcwd(), ".sage_states", "promptor_data")

            os.makedirs(self.data_base_path, exist_ok=True)
            self.data_records = []

    def _save_data_record(self, query, external_corpus, prompt):
        """Save a prompt data record."""
        if not self.enable_profile:
            return

        record = {
            "timestamp": time.time(),
            "query": query,
            "external_corpus": external_corpus,
            "prompt": prompt,
        }
        self.data_records.append(record)
        self._persist_data_records()

    def _persist_data_records(self):
        """Persist data records to a file."""
        if not self.enable_profile or not self.data_records:
            return

        timestamp = int(time.time())
        filename = f"promptor_data_{timestamp}.json"
        path = os.path.join(self.data_base_path, filename)

        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.data_records, f, ensure_ascii=False, indent=2)
            self.data_records = []
        except Exception as e:
            self.logger.error(f"Failed to persist data records: {e}")

    # sage_lib/functions/rag/qapromptor.py
    def execute(self, data) -> list:
        """
        Generate a ChatGPT-style prompt (two messages: system + user).

        Supported input formats:
        1. (query, external_corpus_list_or_str)   # tuple
        2. query_str                              # plain string
        3. {"query": ..., "results": [...]}       # dict (from a retriever)
        4. {"question": ..., "context": [...]}    # dict (from tests)
        """
        self.logger.info(f"QAPromptor received data: {data}")
        try:
            # -------- parse the input --------
            raw = data
            original_data = data  # keep the original input so it can be returned

            if isinstance(raw, dict):
                # dict input - several field names are supported
                query = raw.get("query", raw.get("question", ""))

                # handle the different context field names
                external_corpus_list = []

                # "refining_results" (from the refiner - compressed documents)
                if "refining_results" in raw:
                    results = raw.get("refining_results", [])
                    for result in results:
                        if isinstance(result, str):
                            external_corpus_list.append(result)
                        else:
                            external_corpus_list.append(str(result))

                # "retrieval_results" (from the retriever - raw retrieval results)
                elif "retrieval_results" in raw:
                    results = raw.get("retrieval_results", [])
                    for result in results:
                        if isinstance(result, dict) and "text" in result:
                            external_corpus_list.append(result["text"])
                        elif isinstance(result, str):
                            external_corpus_list.append(result)
                        else:
                            external_corpus_list.append(str(result))

                # "context" (from tests)
                elif "context" in raw:
                    context = raw.get("context", [])
                    if isinstance(context, list):
                        external_corpus_list.extend([str(c) for c in context])
                    else:
                        external_corpus_list.append(str(context))

                # "external_corpus"
                elif "external_corpus" in raw:
                    external_corpus = raw.get("external_corpus", "")
                    if isinstance(external_corpus, list):
                        external_corpus_list.extend([str(c) for c in external_corpus])
                    else:
                        external_corpus_list.append(str(external_corpus))

                external_corpus = "\n".join(external_corpus_list)

            elif isinstance(raw, tuple) and len(raw) == 2:
                # tuple input
                query, external_corpus = raw
                if isinstance(external_corpus, list):
                    external_corpus = "\n".join(external_corpus)
                # for tuple input keep the original behaviour: return the query rather than the raw data
                original_data = query
            else:
                # string input
                query = str(raw)
                external_corpus = ""
                # for string input keep the original behaviour: return the query rather than the raw data
                original_data = query

            external_corpus = external_corpus or ""

            # -------- system prompt --------
            if external_corpus:
                system_prompt = {
                    "role": "system",
                    "content": self.prompt_template.render(external_corpus=external_corpus),
                }
            else:
                system_prompt = {
                    "role": "system",
                    "content": (
                        "You are a helpful AI assistant. Answer the user's questions accurately."
                    ),
                }

            # -------- user prompt --------
            user_prompt = {
                "role": "user",
                "content": f"Question: {query}",
            }
            self.logger.info(
                f"QAPromptor generated prompt: {system_prompt['content']} | {user_prompt['content']}"
            )
            prompt = [system_prompt, user_prompt]

            # Save the data record (only when enable_profile=True)
            if self.enable_profile:
                self._save_data_record(query, external_corpus, prompt)

            return [original_data, prompt]

        except Exception as e:
            self.logger.error("QAPromptor error: %s | input=%s", e, getattr(data, "data", ""))
            fallback = [
                {"role": "system", "content": "System encountered an error."},
                {
                    "role": "user",
                    "content": (
                        "Question: Error occurred. Please try again."
                        f" (Original: {getattr(data, 'data', '')})"
                    ),
                },
            ]
            return fallback

    def __del__(self):
        """Make sure any unsaved records are persisted when the object is destroyed."""
        if hasattr(self, "enable_profile") and self.enable_profile:
            try:
                self._persist_data_records()
            except Exception:
                pass


class SummarizationPromptor(MapOperator):
    """
    SummarizationPromptor is a prompt operator that generates a summarization
    prompt from an external corpus and a user query. This class is designed to
    prepare the necessary prompt structure for a summarization model.

    Attributes:
        config: Configuration data for initializing the prompt operator (e.g., model details, etc.).
        prompt_template: A template used for generating the system prompt; typically includes context or instructions.
    """

    prompt_template: Template

    def __init__(self, config):
        """
        Initializes the SummarizationPromptor instance with configuration and prompt template.

        :param config: Dictionary containing configuration for the prompt operator.
        """
        super().__init__()
        self.config = config  # Store the configuration for later use
        self.prompt_template = (
            summarization_prompt_template  # Load the summarization prompt template
        )

    def execute(self, data) -> list:
        """
        Generates a summarization prompt for the input question and external corpus.

        This method takes the query and external corpus, processes the corpus
        into a single string, and creates a system prompt and user prompt based
        on a predefined template.

        :param data: A Data object containing a tuple. The first element is the query (a string),
            and the second is a list of external corpus entries (contextual information for the model).

        :return: A Data object containing a list with two prompts:
            1. system_prompt: A system prompt based on the template with external corpus data.
            2. user_prompt: A user prompt containing the question to be answered.
        """
        # Unpack the input data into query and external_corpus
        query, external_corpus = data

        # Combine the external corpus list into a single string (in case it's split into multiple parts)
        external_corpus = "".join(external_corpus)

        # Prepare the base data for the system prompt, which includes the external corpus
        base_system_prompt_data = {"external_corpus": external_corpus}

        # query = data
        # Create the system prompt using the template and the external corpus data
        system_prompt = {
            "role": "system",
            "content": self.prompt_template.render(**base_system_prompt_data),
        }
        # system_prompt = {
        #     "role": "system",
        #     "content": ""
        # }
        # Create the user prompt using the query
        user_prompt = {"role": "user", "content": f"Question: {query}"}

        # Combine the system and user prompts into one list
        prompt = [system_prompt, user_prompt]

        # Return the prompt list wrapped in a Data object
        return prompt


class QueryProfilerPromptor(MapOperator):
    """
    QueryProfilerPromptor provides a prompt for profiling queries.
    """

    prompt_template: Template

    def __init__(self, config):
        """
        Initializes the QueryProfilerPromptor instance with configuration and prompt template.

        :param config: Dictionary containing configuration for the prompt operator.
        """
        super().__init__()
        self.config = config  # Store the configuration for later use
        self.prompt_template = (
            query_profiler_prompt_template  # Load the query profiler prompt template
        )

    def execute(self, data) -> list:
        """
        Generates a profiling prompt for the input query.

        :param data: A string representing the query to be profiled.

        :return: A list containing the profiling prompt.
        """
        query = data
        prompt = {
            "role": "user",
            "content": self.prompt_template.render(
                query=query,
                metadata=self.config.get("metadata", {}),
                chunk_size=self.config.get("chunk_size", 1024),
            ),
        }
        return [prompt]
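
A minimal usage sketch (not part of the packaged diff) showing the dict-shaped input that QAPromptor.execute accepts. It assumes isage-middleware and its sage.common dependency are installed and that the MapFunction base class can be constructed without extra arguments; the query and retrieval texts are invented for illustration:

    from sage.middleware.operators.rag.promptor import QAPromptor

    promptor = QAPromptor(config={"use_short_answer": False})

    payload = {
        "query": "Which company reported the highest Q3 2023 revenue?",
        "retrieval_results": [
            {"text": "Company A reported $94.8B in revenue for Q3 2023."},
            {"text": "Company B reported $56.2B in revenue for Q3 2023."},
        ],
    }

    # execute() returns [original_data, prompt]; for dict input, original_data is
    # the input dict itself and prompt is a two-message chat list:
    #   [{"role": "system", "content": "<rendered template>"},
    #    {"role": "user", "content": "Question: ..."}]
    original_data, prompt = promptor.execute(payload)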

sage/middleware/operators/rag/refiner.py
@@ -0,0 +1,231 @@
"""
Refiner Operator - SAGE RAG Pipeline Operator
==============================================

Uses isage-refiner (sage_refiner) for context compression in RAG pipelines.

Installation:
    pip install isage-refiner

Usage:
    from sage.middleware.operators.rag import RefinerOperator

    config = {
        "algorithm": "long_refiner",  # or "reform", "provence", etc.
        "budget": 2048,
        # LongRefiner specific
        "base_model_path": "Qwen/Qwen2.5-3B-Instruct",
        ...
    }

    env.map(RefinerOperator, config)
"""

import json
import os
import time
from typing import Any

from sage.common.config.output_paths import get_states_file
from sage.common.core.functions import MapFunction as MapOperator


class RefinerOperator(MapOperator):
    """
    Refiner Operator for SAGE RAG pipelines.

    Wraps isage-refiner compressors (LongRefiner, REFORM, Provence, etc.)
    for use in SAGE dataflow pipelines.

    Config:
        algorithm: str - "long_refiner", "reform", "provence", "llmlingua2", etc.
        budget: int - Token budget for compression
        enable_profile: bool - Enable data recording for debugging

        # Algorithm-specific config passed through to the compressor
        base_model_path: str - For LongRefiner
        score_model_path: str - For LongRefiner
        ...
    """

    def __init__(self, config: dict, ctx=None):
        super().__init__(config=config, ctx=ctx)
        self.cfg = config
        self.enable_profile = config.get("enable_profile", False)
        self.compressor = None

        # Data recording (only when enable_profile=True)
        if self.enable_profile:
            self.data_base_path = str(get_states_file("dummy", "refiner_data").parent)
            os.makedirs(self.data_base_path, exist_ok=True)
            self.data_records: list[dict] = []

        self._init_compressor()

    def _init_compressor(self):
        """Initialize the compressor from isage-refiner."""
        algorithm = self.cfg.get("algorithm", "long_refiner").lower()

        try:
            if algorithm == "long_refiner":
                from sage_refiner import LongRefinerCompressor

                self.compressor = LongRefinerCompressor(
                    base_model_path=self.cfg.get("base_model_path", "Qwen/Qwen2.5-3B-Instruct"),
                    query_analysis_module_lora_path=self.cfg.get(
                        "query_analysis_module_lora_path", ""
                    ),
                    doc_structuring_module_lora_path=self.cfg.get(
                        "doc_structuring_module_lora_path", ""
                    ),
                    global_selection_module_lora_path=self.cfg.get(
                        "global_selection_module_lora_path", ""
                    ),
                    score_model_path=self.cfg.get("score_model_path", "BAAI/bge-reranker-v2-m3"),
                    max_model_len=self.cfg.get("max_model_len", 25000),
                    gpu_memory_utilization=self.cfg.get("gpu_memory_utilization", 0.5),
                )

            elif algorithm == "reform":
                from sage_refiner import REFORMCompressor

                self.compressor = REFORMCompressor(**self.cfg.get("reform_config", {}))

            elif algorithm == "provence":
                from sage_refiner import ProvenceCompressor

                self.compressor = ProvenceCompressor(**self.cfg.get("provence_config", {}))

            elif algorithm in ("simple", "none"):
                # Simple truncation - no compression
                self.compressor = None
                self.logger.info("Using simple/none mode - no compression")

            else:
                raise ValueError(f"Unsupported algorithm: {algorithm}")

            self.logger.info(f"RefinerOperator initialized with algorithm: {algorithm}")

        except ImportError as e:
            raise ImportError(
                f"Failed to import {algorithm} compressor. "
                f"Install with: pip install isage-refiner\n"
                f"Error: {e}"
            ) from e

    def execute(self, data: dict):
        """Execute document compression.

        Input format:
            {
                "query": str,
                "retrieval_results": List[Dict],  # Retrieved documents
            }

        Output format:
            {
                "query": str,
                "retrieval_results": List[Dict],  # Original (preserved)
                "refining_results": List[str],  # Compressed document texts
            }
        """
        if not isinstance(data, dict):
            self.logger.error(f"Unexpected input format: {type(data)}")
            return data

        query = data.get("query", "")
        docs = data.get("retrieval_results", [])

        # Normalize documents to isage-refiner format
        documents = self._normalize_documents(docs)

        # Compress
        try:
            if self.compressor is None:
                # Simple mode: just extract text
                refined_texts = [
                    doc.get("contents", doc.get("text", str(doc))) for doc in documents
                ]
            else:
                budget = self.cfg.get("budget", 2048)
                result = self.compressor.compress(
                    question=query,
                    document_list=documents,
                    budget=budget,
                )
                # isage-refiner returns a dict with various fields
                refined_texts = result.get("compressed_context", "")
                if isinstance(refined_texts, str):
                    refined_texts = [refined_texts]

        except Exception as e:
            self.logger.error(f"Refiner execution failed: {e}")
            refined_texts = [doc.get("contents", str(doc)) for doc in documents]

        # Save data record if profiling
        if self.enable_profile:
            self._save_data_record(query, documents, refined_texts)

        # Build output
        result_data = data.copy()
        result_data["refining_results"] = refined_texts

        return result_data

    def _normalize_documents(self, docs: list[str | dict]) -> list[dict[str, Any]]:
        """Normalize documents to isage-refiner format (with 'contents' key)."""
        normalized: list[dict[str, Any]] = []
        for doc in docs:
            if isinstance(doc, dict):
                # isage-refiner expects a 'contents' key
                text = doc.get("contents") or doc.get("text") or str(doc)
                normalized.append({"contents": text, **doc})
            elif isinstance(doc, str):
                normalized.append({"contents": doc})
            else:
                normalized.append({"contents": str(doc)})

        return normalized

    def _save_data_record(self, query: str, input_docs: list[dict], refined_docs: list[str]):
        """Save data record (only when enable_profile=True)."""
        if not self.enable_profile:
            return

        record = {
            "timestamp": time.time(),
            "query": query,
            "input_docs": input_docs,
            "refined_docs": refined_docs,
            "budget": self.cfg.get("budget"),
        }
        self.data_records.append(record)

        # Persist every 10 records
        if len(self.data_records) >= 10:
            self._persist_data_records()

    def _persist_data_records(self):
        """Persist data records to disk."""
        if not self.enable_profile or not self.data_records:
            return

        timestamp = int(time.time())
        filename = f"refiner_data_{timestamp}.json"
        path = os.path.join(self.data_base_path, filename)

        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.data_records, f, ensure_ascii=False, indent=2)
            self.logger.info(f"Saved {len(self.data_records)} records to {path}")
            self.data_records = []
        except Exception as e:
            self.logger.error(f"Failed to persist data records: {e}")

    def __del__(self):
        """Ensure data is saved on cleanup."""
        if hasattr(self, "enable_profile") and self.enable_profile:
            try:
                self._persist_data_records()
            except Exception:
                pass
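
A minimal sketch (not part of the packaged diff) of driving RefinerOperator outside a pipeline, using the built-in "simple" mode so no isage-refiner install is required. It assumes the MapFunction base accepts the config/ctx constructor arguments shown above; the query and documents are invented for illustration:

    from sage.middleware.operators.rag.refiner import RefinerOperator

    refiner = RefinerOperator(config={"algorithm": "simple"})

    batch = {
        "query": "What drove revenue growth in Q4 2023?",
        "retrieval_results": [
            {"contents": "Revenue grew 8% in Q4 2023 on strong services demand."},
            "Operating margin expanded by roughly 120 basis points.",
        ],
    }

    # The output keeps "query" and "retrieval_results" and adds "refining_results":
    # a list of document texts, compressed when a real backend such as "long_refiner"
    # is configured, or passed through unchanged in "simple" mode.
    out = refiner.execute(batch)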