isage-middleware 0.2.4.3__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
- isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
- isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
- isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
- sage/middleware/__init__.py +59 -0
- sage/middleware/_version.py +6 -0
- sage/middleware/components/__init__.py +30 -0
- sage/middleware/components/extensions_compat.py +141 -0
- sage/middleware/components/sage_db/__init__.py +116 -0
- sage/middleware/components/sage_db/backend.py +136 -0
- sage/middleware/components/sage_db/service.py +15 -0
- sage/middleware/components/sage_flow/__init__.py +76 -0
- sage/middleware/components/sage_flow/python/__init__.py +14 -0
- sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
- sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
- sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
- sage/middleware/components/sage_flow/service.py +14 -0
- sage/middleware/components/sage_mem/__init__.py +83 -0
- sage/middleware/components/sage_sias/__init__.py +59 -0
- sage/middleware/components/sage_sias/continual_learner.py +184 -0
- sage/middleware/components/sage_sias/coreset_selector.py +302 -0
- sage/middleware/components/sage_sias/types.py +94 -0
- sage/middleware/components/sage_tsdb/__init__.py +81 -0
- sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
- sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
- sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
- sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
- sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
- sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
- sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
- sage/middleware/components/sage_tsdb/service.py +17 -0
- sage/middleware/components/vector_stores/__init__.py +25 -0
- sage/middleware/components/vector_stores/chroma.py +483 -0
- sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
- sage/middleware/components/vector_stores/milvus.py +677 -0
- sage/middleware/operators/__init__.py +56 -0
- sage/middleware/operators/agent/__init__.py +24 -0
- sage/middleware/operators/agent/planning/__init__.py +5 -0
- sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
- sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
- sage/middleware/operators/agent/planning/router.py +107 -0
- sage/middleware/operators/agent/runtime.py +296 -0
- sage/middleware/operators/agentic/__init__.py +41 -0
- sage/middleware/operators/agentic/config.py +254 -0
- sage/middleware/operators/agentic/planning_operator.py +125 -0
- sage/middleware/operators/agentic/refined_searcher.py +132 -0
- sage/middleware/operators/agentic/runtime.py +241 -0
- sage/middleware/operators/agentic/timing_operator.py +125 -0
- sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
- sage/middleware/operators/context/__init__.py +17 -0
- sage/middleware/operators/context/critic_evaluation.py +16 -0
- sage/middleware/operators/context/model_context.py +565 -0
- sage/middleware/operators/context/quality_label.py +12 -0
- sage/middleware/operators/context/search_query_results.py +61 -0
- sage/middleware/operators/context/search_result.py +42 -0
- sage/middleware/operators/context/search_session.py +79 -0
- sage/middleware/operators/filters/__init__.py +26 -0
- sage/middleware/operators/filters/context_sink.py +387 -0
- sage/middleware/operators/filters/context_source.py +376 -0
- sage/middleware/operators/filters/evaluate_filter.py +83 -0
- sage/middleware/operators/filters/tool_filter.py +74 -0
- sage/middleware/operators/llm/__init__.py +18 -0
- sage/middleware/operators/llm/sagellm_generator.py +432 -0
- sage/middleware/operators/rag/__init__.py +147 -0
- sage/middleware/operators/rag/arxiv.py +331 -0
- sage/middleware/operators/rag/chunk.py +13 -0
- sage/middleware/operators/rag/document_loaders.py +23 -0
- sage/middleware/operators/rag/evaluate.py +658 -0
- sage/middleware/operators/rag/generator.py +340 -0
- sage/middleware/operators/rag/index_builder/__init__.py +48 -0
- sage/middleware/operators/rag/index_builder/builder.py +363 -0
- sage/middleware/operators/rag/index_builder/manifest.py +101 -0
- sage/middleware/operators/rag/index_builder/storage.py +131 -0
- sage/middleware/operators/rag/pipeline.py +46 -0
- sage/middleware/operators/rag/profiler.py +59 -0
- sage/middleware/operators/rag/promptor.py +400 -0
- sage/middleware/operators/rag/refiner.py +231 -0
- sage/middleware/operators/rag/reranker.py +364 -0
- sage/middleware/operators/rag/retriever.py +1308 -0
- sage/middleware/operators/rag/searcher.py +37 -0
- sage/middleware/operators/rag/types.py +28 -0
- sage/middleware/operators/rag/writer.py +80 -0
- sage/middleware/operators/tools/__init__.py +71 -0
- sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
- sage/middleware/operators/tools/arxiv_searcher.py +102 -0
- sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
- sage/middleware/operators/tools/image_captioner.py +104 -0
- sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
- sage/middleware/operators/tools/searcher_tool.py +514 -0
- sage/middleware/operators/tools/text_detector.py +185 -0
- sage/middleware/operators/tools/url_text_extractor.py +104 -0
- sage/middleware/py.typed +2 -0
|
@@ -0,0 +1,432 @@
|
|
|
1
|
+
"""SageLLM Generator - 统一 LLM 推理算子
|
|
2
|
+
|
|
3
|
+
通过 EngineFactory 统一创建引擎,不硬编码任何具体引擎实现。
|
|
4
|
+
支持 auto/mock/cuda/ascend 等多种后端类型。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations

import logging
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from typing import Any

# MapFunction is aliased to MapOperator; SageLLMGenerator below subclasses it.
from sage.common.core.functions import MapFunction as MapOperator

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _normalize_input(data: Any) -> tuple[dict[str, Any], str, dict[str, Any]]:
    """
    Normalize operator input into a ``(context, prompt, options)`` triple.

    Supported input shapes:
    - ``None``: empty prompt, empty context, no options.
    - ``str``: used directly as the prompt.
    - ``bytes`` / ``bytearray``: decoded as UTF-8 (undecodable bytes are
      replaced) and used as the prompt.
    - ``dict``: ``"prompt"`` is the prompt and ``"options"`` the generation
      options; every other key is kept as context. A ``None`` prompt or
      ``None`` options are treated as empty.
    - ``Sequence``: ``[context, prompt]`` or ``[context, prompt, options]``.
      The first element is the context (dict) or the prompt (str). The second
      element is the prompt (str), or a dict carrying ``"prompt"`` and
      optional ``"options"``. A third dict element is merged into options.
    - anything else: ``str(data)`` is used as the prompt.

    Returns:
        ``(context, prompt, options)``; ``context`` and ``options`` are always
        new dicts, never the caller's objects.
    """
    context: dict[str, Any] = {}
    prompt: str = ""
    options: dict[str, Any] = {}

    if data is None:
        # Previously fell through to str(None) and produced the prompt "None".
        return context, prompt, options

    if isinstance(data, str):
        prompt = data
    elif isinstance(data, (bytes, bytearray)):
        # str(b"...") would yield the repr "b'...'" rather than the text.
        prompt = data.decode("utf-8", errors="replace")
    elif isinstance(data, dict):
        raw_prompt = data.get("prompt", "")
        prompt = "" if raw_prompt is None else raw_prompt
        # `or {}` tolerates an explicit "options": None.
        options = dict(data.get("options") or {})
        # Keep every other field as context.
        context = {k: v for k, v in data.items() if k not in ("prompt", "options")}
    elif isinstance(data, Sequence):
        # str/bytes/bytearray were handled above, so this is a list/tuple-like.
        if len(data) >= 1:
            first = data[0]
            if isinstance(first, str):
                prompt = first
            elif isinstance(first, dict):
                context = dict(first)
        if len(data) >= 2:
            second = data[1]
            if isinstance(second, str):
                prompt = second
            elif isinstance(second, dict) and "prompt" in second:
                prompt = second["prompt"]
                options.update(second.get("options") or {})
        if len(data) >= 3 and isinstance(data[2], dict):
            options.update(data[2])
    else:
        prompt = str(data)

    return context, prompt, options
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
class SageLLMGenerator(MapOperator):
    """
    Unified SageLLM generation operator; text is produced by an engine built by EngineFactory.

    No concrete engine class (such as HFCudaEngine or MockEngine) is imported or
    hard-coded here. The engine is created lazily through
    ``sagellm_backend.engine.factory.EngineFactory``, which keeps this operator
    decoupled from any particular backend.

    Example:
        ```python
        # Let the factory pick the backend
        generator = SageLLMGenerator(
            model_path="Qwen/Qwen2.5-7B-Instruct",
            backend_type="auto",
        )
        result = generator.execute("Write a poem")

        # Use the mock backend for tests
        generator = SageLLMGenerator(
            backend_type="mock",
            model_path="mock-model",
        )

        # Streaming generation
        async for chunk in generator.stream_async("Tell me a story"):
            print(chunk["text"], end="", flush=True)
        ```

    Attributes:
        backend_type: Backend type forwarded to ``EngineFactory.create``
            (e.g. "auto"/"mock"/"cuda"/"ascend"). "mock" also sets
            ``mock_mode=True`` in the engine config.
        model_path: Model path or HuggingFace model ID.
        device_map: Device mapping strategy, e.g. "auto"/"cuda:0"/"cpu".
        dtype: Data type, e.g. "auto"/"float16"/"bfloat16".
        device: Device name forwarded in the engine config.
        load_in_8bit: Forwarded in the engine config.
        load_in_4bit: Forwarded in the engine config.
        trust_remote_code: Forwarded in the engine config.
        max_tokens: Default maximum number of tokens to generate.
        max_new_tokens: Forwarded in the engine config (see field comment).
        temperature: Sampling temperature.
        top_p: Nucleus sampling parameter.
        top_k: Top-k sampling parameter.
        engine_id: Engine identifier; defaults to ``sage-llm-<id(self)>``.
        timeout: Stored on the instance; not read by any method of this class.
        default_options: Generation options layered over the built-in defaults
            and under per-call options.
    """

    # Core configuration
    backend_type: str = "auto"
    model_path: str = ""
    device_map: str = "auto"
    dtype: str = "float16"
    device: str = "cuda"

    # Settings required by HFCudaEngine (fail-fast design)
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    trust_remote_code: bool = False

    # Default generation parameters
    max_tokens: int = 2048
    max_new_tokens: int = 128  # HFCudaEngine uses this field
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50

    # Engine settings
    engine_id: str = ""
    timeout: float = 120.0
    default_options: dict[str, Any] = field(default_factory=dict)

    # Internal state (excluded from __init__ and repr)
    _engine: Any = field(default=None, init=False, repr=False)
    _initialized: bool = field(default=False, init=False, repr=False)

    def __post_init__(self) -> None:
        """Initialise the MapOperator base and assign a default ``engine_id``."""
        super().__init__()
        if not self.engine_id:
            # id(self) is unique among objects alive at the same time.
            self.engine_id = f"sage-llm-{id(self)}"

    @staticmethod
    def _run_coroutine_sync(coro: Any) -> Any:
        """
        Run ``coro`` to completion from synchronous code and return its result.

        Without a running event loop in this thread, ``asyncio.run`` is used.
        When a loop is already running (the caller is inside async code),
        ``asyncio.run`` cannot be used on this thread. The coroutine then runs on
        a fresh loop in a worker thread, and this call blocks until it finishes.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            has_running_loop = False
        else:
            has_running_loop = True

        if not has_running_loop:
            return asyncio.run(coro)

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def _ensure_engine(self) -> None:
        """
        Create (and, if needed, start) the engine on first use.

        Lazy initialisation: nothing is created until a generation method needs
        the engine. The engine always comes from ``EngineFactory``; no concrete
        engine class is imported here.

        Raises:
            ImportError: ``sagellm_backend`` cannot be imported, or the selected
                backend raised ImportError while being created or started.
            RuntimeError: Any other failure while creating or starting the engine.
        """
        if self._initialized and self._engine is not None:
            return

        # Guard only the factory import, so the install hint below is shown only
        # when sagellm_backend itself is missing.
        try:
            from sagellm_backend.engine.factory import EngineFactory
        except ImportError as e:
            raise ImportError(
                f"Failed to import sagellm_backend. "
                f"Please install it with: pip install sagellm-backend\n"
                f"Original error: {e}"
            ) from e

        config = {
            "engine_id": self.engine_id,
            "model_path": self.model_path,
            "device": self.device,
            "device_map": self.device_map,
            "dtype": self.dtype,
            "load_in_8bit": self.load_in_8bit,
            "load_in_4bit": self.load_in_4bit,
            "trust_remote_code": self.trust_remote_code,
            "max_new_tokens": self.max_new_tokens,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "mock_mode": self.backend_type == "mock",
        }

        logger.info(
            "Creating engine: backend_type=%s, model_path=%s, engine_id=%s",
            self.backend_type,
            self.model_path,
            self.engine_id,
        )

        try:
            self._engine = EngineFactory.create(
                backend_type=self.backend_type,
                config=config,
            )
            self._start_engine()
        except ImportError as e:
            # A dependency of the selected backend is missing. Keep the
            # ImportError type for callers; do not blame sagellm_backend, which
            # imported fine above.
            logger.error("Failed to create engine: %s", e)
            raise
        except Exception as e:
            logger.error("Failed to create engine: %s", e)
            raise RuntimeError(
                f"Failed to create SageLLM engine with backend_type={self.backend_type}: {e}"
            ) from e

        self._initialized = True
        logger.info("Engine created successfully: %s", self.engine_id)

    def _start_engine(self) -> None:
        """Call the engine's ``start()`` hook if it has one and is not running yet."""
        import asyncio

        engine = self._engine
        # is_running is read defensively: an engine exposing start() without an
        # is_running attribute used to fail here with AttributeError.
        if not hasattr(engine, "start") or getattr(engine, "is_running", False):
            return
        started = engine.start()
        if asyncio.iscoroutine(started):
            self._run_coroutine_sync(started)

    def _build_generation_params(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        """
        Build the generation parameters for one request.

        Precedence, lowest to highest:
        1. Instance defaults (max_tokens, temperature, top_p, top_k).
        2. ``default_options``.
        3. Per-call ``options``; entries whose value is ``None`` are ignored.
        """
        params = {
            "prompt": prompt,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }
        params.update(self.default_options)
        params.update({k: v for k, v in options.items() if v is not None})
        return params

    def _call_engine(self, params: dict[str, Any]) -> Any:
        """
        Invoke the engine once with ``params`` and return its raw result.

        - Engines exposing ``execute()`` receive a ``sagellm_protocol`` Request
          when that package is importable, otherwise the plain ``params`` dict.
        - Engines exposing only ``generate()`` are called with ``params`` as
          keyword arguments.

        The result may be a coroutine. Running it is left to the caller, so sync
        and async callers can each handle it appropriately.

        Raises:
            RuntimeError: The engine has neither ``execute()`` nor ``generate()``.
        """
        engine = self._engine
        if hasattr(engine, "execute"):
            try:
                from sagellm_protocol.types import Request as SageLLMRequest
            except ImportError:
                # sagellm_protocol is unavailable: pass the dict straight through.
                return engine.execute(params)

            import uuid

            request = SageLLMRequest(
                request_id=str(uuid.uuid4()),
                trace_id=str(uuid.uuid4()),
                model=self.model_path or "default",
                prompt=params.get("prompt", ""),
                max_tokens=params.get("max_tokens", self.max_tokens),
                stream=False,
                temperature=params.get("temperature", self.temperature),
                top_p=params.get("top_p", self.top_p),
            )
            # Only the import above is guarded. An ImportError raised inside
            # engine.execute() now propagates instead of silently re-running
            # the request with a dict.
            return engine.execute(request)
        if hasattr(engine, "generate"):
            return engine.generate(**params)
        raise RuntimeError(
            f"Engine {type(engine).__name__} does not support "
            "execute() or generate() method"
        )

    @staticmethod
    def _normalize_result(result: Any) -> dict[str, Any]:
        """
        Convert a raw engine result into ``{"text": ..., "usage": ...}``.

        Handles these result types:
        - a plain string;
        - a dict, reading ``text``, then ``generated``, then ``output_text``;
        - an object with ``output_text`` (e.g. a ``sagellm_protocol`` Response),
          whose completion tokens are counted from a list ``output_tokens``;
        - anything else, which is stringified.
        """
        if isinstance(result, str):
            return {"text": result, "usage": {}}
        if isinstance(result, dict):
            return {
                "text": result.get(
                    "text", result.get("generated", result.get("output_text", ""))
                ),
                "usage": result.get("usage", {}),
            }
        if hasattr(result, "output_text"):
            output_tokens = getattr(result, "output_tokens", [])
            num_output_tokens = len(output_tokens) if isinstance(output_tokens, list) else 0
            return {
                "text": result.output_text,
                "usage": {
                    "completion_tokens": num_output_tokens,
                },
            }
        return {"text": str(result), "usage": {}}

    def execute(self, data: Any) -> dict[str, Any]:
        """
        Generate text synchronously.

        Args:
            data: Input in any shape accepted by ``_normalize_input``
                (str / dict / Sequence).

        Returns:
            A dict with:
            - text: the generated text
            - usage: token usage statistics (may be empty)
            - context: the input context; included when non-empty, and always
              included when the prompt is empty

        Raises:
            ImportError, RuntimeError: The engine could not be created.
            RuntimeError: Generation failed.
        """
        import asyncio

        self._ensure_engine()

        context, prompt, options = _normalize_input(data)
        params = self._build_generation_params(prompt, options)

        if not prompt:
            logger.warning("Empty prompt received, returning empty result")
            return {"text": "", "usage": {}, "context": context}

        try:
            logger.debug("Generating with params: %s", params)

            result = self._call_engine(params)
            # Either execute() or generate() may be async; previously a
            # coroutine from generate() was stringified instead of being run.
            if asyncio.iscoroutine(result):
                result = self._run_coroutine_sync(result)

            output = self._normalize_result(result)
            if context:
                output["context"] = context
            return output

        except Exception as e:
            logger.error("Generation failed: %s", e)
            raise RuntimeError(f"SageLLM generation failed: {e}") from e

    async def stream_async(self, data: Any) -> AsyncGenerator[dict[str, Any], None]:
        """
        Stream generated text asynchronously.

        Args:
            data: Input in any shape accepted by ``_normalize_input``
                (str / dict / Sequence).

        Yields:
            Dicts with:
            - text: the current text fragment
            - done: whether generation has finished
            - usage: token usage statistics (on the final chunk, and on dict
              chunks from ``generate_stream``)
            - error: present only on the final chunk when generation failed;
              errors are yielded rather than raised
        """
        import asyncio

        self._ensure_engine()

        _context, prompt, options = _normalize_input(data)
        params = self._build_generation_params(prompt, options)

        if not prompt:
            logger.warning("Empty prompt received for streaming")
            yield {"text": "", "done": True, "usage": {}}
            return

        try:
            logger.debug("Streaming generation with params: %s", params)
            engine = self._engine

            if hasattr(engine, "generate_stream"):
                # Native async streaming interface.
                async for chunk in engine.generate_stream(**params):
                    if isinstance(chunk, str):
                        yield {"text": chunk, "done": False}
                    elif isinstance(chunk, dict):
                        yield {
                            "text": chunk.get("text", ""),
                            "done": chunk.get("done", False),
                            "usage": chunk.get("usage", {}),
                        }
                    else:
                        yield {"text": str(chunk), "done": False}

                # Completion signal
                yield {"text": "", "done": True, "usage": {}}

            elif hasattr(engine, "stream"):
                # Synchronous streaming interface; each next() call runs on the
                # event-loop thread.
                for chunk in engine.stream(**params):
                    if isinstance(chunk, str):
                        yield {"text": chunk, "done": False}
                    elif isinstance(chunk, dict):
                        yield {
                            "text": chunk.get("text", ""),
                            "done": chunk.get("done", False),
                        }
                    else:
                        yield {"text": str(chunk), "done": False}

                yield {"text": "", "done": True, "usage": {}}

            else:
                # No streaming support: fall back to one non-streaming call,
                # using the same dispatch and result handling as execute().
                logger.warning(
                    "Engine %s does not support streaming, "
                    "falling back to non-streaming generation",
                    self.backend_type,
                )
                result = self._call_engine(params)
                if asyncio.iscoroutine(result):
                    # Already inside an event loop: await directly.
                    result = await result
                # Previously a str result crashed here on result.get("usage").
                output = self._normalize_result(result)
                yield {"text": output["text"], "done": True, "usage": output["usage"]}

        except Exception as e:
            logger.error("Streaming generation failed: %s", e)
            yield {"text": "", "done": True, "error": str(e)}

    def shutdown(self) -> None:
        """
        Shut the engine down and release it.

        Calls the engine's ``shutdown()``, or else its ``close()``. If that call
        returns a coroutine, the coroutine is run to completion. Errors are
        logged as warnings, not raised. The engine reference and the
        initialised flag are always reset.
        """
        if self._engine is None:
            return

        engine = self._engine
        try:
            if hasattr(engine, "shutdown"):
                outcome = engine.shutdown()
            elif hasattr(engine, "close"):
                outcome = engine.close()
            else:
                outcome = None

            if outcome is not None:
                import asyncio

                # An async shutdown()/close() was previously never awaited.
                if asyncio.iscoroutine(outcome):
                    self._run_coroutine_sync(outcome)

            logger.info("Engine %s shut down", self.engine_id)
        except Exception as e:
            logger.warning("Error shutting down engine: %s", e)
        finally:
            self._engine = None
            self._initialized = False

    def __del__(self) -> None:
        """Best-effort cleanup on garbage collection; never raises."""
        try:
            self.shutdown()
        except Exception:
            pass
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
# Public API of this module: only the generator class is exported;
# the input-normalisation helper stays private.
__all__ = ["SageLLMGenerator"]
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RAG (Retrieval-Augmented Generation) Operators
|
|
3
|
+
|
|
4
|
+
This module contains domain-specific operators for RAG applications:
|
|
5
|
+
- Pipeline (RAG orchestration and workflow)
|
|
6
|
+
- Profiler (Query profiling and analysis)
|
|
7
|
+
- Document Loaders (Document loading utilities)
|
|
8
|
+
- Generator operators (LLM response generation)
|
|
9
|
+
- Retriever operators (document/passage retrieval)
|
|
10
|
+
- Reranker operators (result reranking)
|
|
11
|
+
- Promptor operators (prompt construction)
|
|
12
|
+
- Evaluation operators (quality metrics)
|
|
13
|
+
- Document processing operators (chunking, refining, writing)
|
|
14
|
+
- External data source operators (ArXiv)
|
|
15
|
+
|
|
16
|
+
These operators inherit from base operator classes in sage.kernel.operators
|
|
17
|
+
and implement RAG-specific business logic.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
# Export types for easier access
|
|
21
|
+
from sage.libs.rag.types import (
|
|
22
|
+
RAGDocument,
|
|
23
|
+
RAGInput,
|
|
24
|
+
RAGOutput,
|
|
25
|
+
RAGQuery,
|
|
26
|
+
RAGResponse,
|
|
27
|
+
create_rag_response,
|
|
28
|
+
ensure_rag_response,
|
|
29
|
+
extract_query,
|
|
30
|
+
extract_results,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Lazy imports to avoid optional dependency issues
# Maps each public attribute name -> (module path, attribute name in that module).
# Nothing listed here is imported when the package is imported; the module-level
# __getattr__ at the bottom of this file resolves an entry on first access.
# Entry order also fixes the order of the lazily-exported names in __all__,
# which is built from _IMPORTS.keys().
_IMPORTS = {
    # Pipeline and Profiler
    # RAGPipeline lives in the middleware layer (L4) as orchestration/pipeline code.
    # It previously pointed to sage.libs.rag.pipeline (L3), which was deleted during
    # the libs -> middleware refactor; it now points to the new location.
    "RAGPipeline": ("sage.middleware.operators.rag.pipeline", "RAGPipeline"),
    "Query_Profiler": ("sage.middleware.operators.rag.profiler", "Query_Profiler"),
    "QueryProfilerResult": ("sage.middleware.operators.rag.profiler", "QueryProfilerResult"),
    # Document Loaders
    # NOTE(review): the loaders (and the splitters under "Document Processing")
    # still resolve from sage.libs.rag, while this wheel also ships
    # operators/rag/document_loaders.py and operators/rag/chunk.py — confirm
    # which location is canonical after the libs -> middleware refactor.
    "TextLoader": ("sage.libs.rag.document_loaders", "TextLoader"),
    "PDFLoader": ("sage.libs.rag.document_loaders", "PDFLoader"),
    "DocxLoader": ("sage.libs.rag.document_loaders", "DocxLoader"),
    "DocLoader": ("sage.libs.rag.document_loaders", "DocLoader"),
    "MarkdownLoader": ("sage.libs.rag.document_loaders", "MarkdownLoader"),
    "LoaderFactory": ("sage.libs.rag.document_loaders", "LoaderFactory"),
    # Generators
    "OpenAIGenerator": ("sage.middleware.operators.rag.generator", "OpenAIGenerator"),
    "HFGenerator": ("sage.middleware.operators.rag.generator", "HFGenerator"),
    "SageLLMRAGGenerator": ("sage.middleware.operators.rag.generator", "SageLLMRAGGenerator"),
    # Retrievers
    "ChromaRetriever": ("sage.middleware.operators.rag.retriever", "ChromaRetriever"),
    "MilvusDenseRetriever": (
        "sage.middleware.operators.rag.retriever",
        "MilvusDenseRetriever",
    ),
    "MilvusSparseRetriever": (
        "sage.middleware.operators.rag.retriever",
        "MilvusSparseRetriever",
    ),
    "Wiki18FAISSRetriever": (
        "sage.middleware.operators.rag.retriever",
        "Wiki18FAISSRetriever",
    ),
    # Rerankers
    "BGEReranker": ("sage.middleware.operators.rag.reranker", "BGEReranker"),
    "LLMbased_Reranker": (
        "sage.middleware.operators.rag.reranker",
        "LLMbased_Reranker",
    ),
    # Promptors
    "QAPromptor": ("sage.middleware.operators.rag.promptor", "QAPromptor"),
    "SummarizationPromptor": (
        "sage.middleware.operators.rag.promptor",
        "SummarizationPromptor",
    ),
    "QueryProfilerPromptor": (
        "sage.middleware.operators.rag.promptor",
        "QueryProfilerPromptor",
    ),
    # Evaluation
    "F1Evaluate": ("sage.middleware.operators.rag.evaluate", "F1Evaluate"),
    "EMEvaluate": ("sage.middleware.operators.rag.evaluate", "EMEvaluate"),
    "RecallEvaluate": ("sage.middleware.operators.rag.evaluate", "RecallEvaluate"),
    "BertRecallEvaluate": (
        "sage.middleware.operators.rag.evaluate",
        "BertRecallEvaluate",
    ),
    "RougeLEvaluate": ("sage.middleware.operators.rag.evaluate", "RougeLEvaluate"),
    "BRSEvaluate": ("sage.middleware.operators.rag.evaluate", "BRSEvaluate"),
    "AccuracyEvaluate": ("sage.middleware.operators.rag.evaluate", "AccuracyEvaluate"),
    "TokenCountEvaluate": (
        "sage.middleware.operators.rag.evaluate",
        "TokenCountEvaluate",
    ),
    "LatencyEvaluate": ("sage.middleware.operators.rag.evaluate", "LatencyEvaluate"),
    "ContextRecallEvaluate": (
        "sage.middleware.operators.rag.evaluate",
        "ContextRecallEvaluate",
    ),
    "CompressionRateEvaluate": (
        "sage.middleware.operators.rag.evaluate",
        "CompressionRateEvaluate",
    ),
    # Document Processing
    "CharacterSplitter": ("sage.libs.rag.chunk", "CharacterSplitter"),
    "SentenceTransformersTokenTextSplitter": (
        "sage.libs.rag.chunk",
        "SentenceTransformersTokenTextSplitter",
    ),
    "RefinerOperator": ("sage.middleware.operators.rag.refiner", "RefinerOperator"),
    "MemoryWriter": ("sage.middleware.operators.rag.writer", "MemoryWriter"),
    # External Data Sources (may require optional dependencies)
    "ArxivPDFDownloader": ("sage.middleware.operators.rag.arxiv", "ArxivPDFDownloader"),
    "ArxivPDFParser": ("sage.middleware.operators.rag.arxiv", "ArxivPDFParser"),
    # Web Search
    "BochaWebSearch": ("sage.middleware.operators.rag.searcher", "BochaWebSearch"),
}
|
|
121
|
+
|
|
122
|
+
# Export all operator names and type utilities.
# Note: a star-import of this package looks up every name listed here, so it
# resolves every lazily-registered operator through __getattr__. That imports
# each target module, including modules with optional dependencies.
__all__ = [  # type: ignore[misc]
    # Types (imported eagerly from sage.libs.rag.types at the top of this module)
    "RAGDocument",
    "RAGQuery",
    "RAGResponse",
    "RAGInput",
    "RAGOutput",
    "ensure_rag_response",
    "extract_query",
    "extract_results",
    "create_rag_response",
    # Operators (lazy loaded; same order as the _IMPORTS table)
    *list(_IMPORTS.keys()),
]
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def __getattr__(name: str):
    """
    Resolve a lazily-registered operator on first attribute access (PEP 562).

    Names listed in ``_IMPORTS`` are imported from their target module only when
    first requested, so optional dependencies of unused operators are never
    loaded. The resolved object is cached in this module's globals, so later
    lookups find it directly and this hook is not called again for that name.

    Raises:
        AttributeError: ``name`` is not a lazily-exported operator, or the
            target module does not define the expected attribute.
        ImportError: The target module (or one of its dependencies) cannot be
            imported.
    """
    try:
        module_name, attr_name = _IMPORTS[name]
    except KeyError:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'") from None

    import importlib

    value = getattr(importlib.import_module(module_name), attr_name)
    # Cache so subsequent accesses bypass this hook entirely.
    globals()[name] = value
    return value
|