isage_middleware-0.2.4.3-cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
- isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
- isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
- isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
- sage/middleware/__init__.py +59 -0
- sage/middleware/_version.py +6 -0
- sage/middleware/components/__init__.py +30 -0
- sage/middleware/components/extensions_compat.py +141 -0
- sage/middleware/components/sage_db/__init__.py +116 -0
- sage/middleware/components/sage_db/backend.py +136 -0
- sage/middleware/components/sage_db/service.py +15 -0
- sage/middleware/components/sage_flow/__init__.py +76 -0
- sage/middleware/components/sage_flow/python/__init__.py +14 -0
- sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
- sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
- sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
- sage/middleware/components/sage_flow/service.py +14 -0
- sage/middleware/components/sage_mem/__init__.py +83 -0
- sage/middleware/components/sage_sias/__init__.py +59 -0
- sage/middleware/components/sage_sias/continual_learner.py +184 -0
- sage/middleware/components/sage_sias/coreset_selector.py +302 -0
- sage/middleware/components/sage_sias/types.py +94 -0
- sage/middleware/components/sage_tsdb/__init__.py +81 -0
- sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
- sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
- sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
- sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
- sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
- sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
- sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
- sage/middleware/components/sage_tsdb/service.py +17 -0
- sage/middleware/components/vector_stores/__init__.py +25 -0
- sage/middleware/components/vector_stores/chroma.py +483 -0
- sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
- sage/middleware/components/vector_stores/milvus.py +677 -0
- sage/middleware/operators/__init__.py +56 -0
- sage/middleware/operators/agent/__init__.py +24 -0
- sage/middleware/operators/agent/planning/__init__.py +5 -0
- sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
- sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
- sage/middleware/operators/agent/planning/router.py +107 -0
- sage/middleware/operators/agent/runtime.py +296 -0
- sage/middleware/operators/agentic/__init__.py +41 -0
- sage/middleware/operators/agentic/config.py +254 -0
- sage/middleware/operators/agentic/planning_operator.py +125 -0
- sage/middleware/operators/agentic/refined_searcher.py +132 -0
- sage/middleware/operators/agentic/runtime.py +241 -0
- sage/middleware/operators/agentic/timing_operator.py +125 -0
- sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
- sage/middleware/operators/context/__init__.py +17 -0
- sage/middleware/operators/context/critic_evaluation.py +16 -0
- sage/middleware/operators/context/model_context.py +565 -0
- sage/middleware/operators/context/quality_label.py +12 -0
- sage/middleware/operators/context/search_query_results.py +61 -0
- sage/middleware/operators/context/search_result.py +42 -0
- sage/middleware/operators/context/search_session.py +79 -0
- sage/middleware/operators/filters/__init__.py +26 -0
- sage/middleware/operators/filters/context_sink.py +387 -0
- sage/middleware/operators/filters/context_source.py +376 -0
- sage/middleware/operators/filters/evaluate_filter.py +83 -0
- sage/middleware/operators/filters/tool_filter.py +74 -0
- sage/middleware/operators/llm/__init__.py +18 -0
- sage/middleware/operators/llm/sagellm_generator.py +432 -0
- sage/middleware/operators/rag/__init__.py +147 -0
- sage/middleware/operators/rag/arxiv.py +331 -0
- sage/middleware/operators/rag/chunk.py +13 -0
- sage/middleware/operators/rag/document_loaders.py +23 -0
- sage/middleware/operators/rag/evaluate.py +658 -0
- sage/middleware/operators/rag/generator.py +340 -0
- sage/middleware/operators/rag/index_builder/__init__.py +48 -0
- sage/middleware/operators/rag/index_builder/builder.py +363 -0
- sage/middleware/operators/rag/index_builder/manifest.py +101 -0
- sage/middleware/operators/rag/index_builder/storage.py +131 -0
- sage/middleware/operators/rag/pipeline.py +46 -0
- sage/middleware/operators/rag/profiler.py +59 -0
- sage/middleware/operators/rag/promptor.py +400 -0
- sage/middleware/operators/rag/refiner.py +231 -0
- sage/middleware/operators/rag/reranker.py +364 -0
- sage/middleware/operators/rag/retriever.py +1308 -0
- sage/middleware/operators/rag/searcher.py +37 -0
- sage/middleware/operators/rag/types.py +28 -0
- sage/middleware/operators/rag/writer.py +80 -0
- sage/middleware/operators/tools/__init__.py +71 -0
- sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
- sage/middleware/operators/tools/arxiv_searcher.py +102 -0
- sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
- sage/middleware/operators/tools/image_captioner.py +104 -0
- sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
- sage/middleware/operators/tools/searcher_tool.py +514 -0
- sage/middleware/operators/tools/text_detector.py +185 -0
- sage/middleware/operators/tools/url_text_extractor.py +104 -0
- sage/middleware/py.typed +2 -0
sage/middleware/operators/tools/searcher_tool.py

@@ -0,0 +1,514 @@

import json
import os
import time
from typing import Any

import requests

from sage.common.core.functions import MapFunction as MapOperator
from sage.middleware.operators.context.model_context import ModelContext
from sage.middleware.operators.context.search_result import SearchResult
from sage.middleware.operators.context.search_session import SearchSession


class BochaSearchTool(MapOperator):
    """
    Improved Bocha search tool - uses the new hierarchical search result structure
    Input: ModelContext (containing search queries)
    Output: ModelContext (containing structured search results)
    """

    def __init__(self, config: dict, **kwargs):
        super().__init__(**kwargs)

        self.url = config.get("url", "https://api.bochasearch.com/search")
        self.api_key = config.get("api_key", os.getenv("BOCHA_API_KEY"))
        self.max_results_per_query = config.get("max_results_per_query", 3)
        self.search_engine_name = config.get("search_engine_name", "Bocha")

        if not self.api_key:
            raise ValueError(
                "BOCHA_API_KEY is required. Set it in environment variables or config."
            )

        self.headers = {
            "Authorization": self.api_key,
            "Content-Type": "application/json",
        }

        self.search_count = 0

        self.logger.info(
            f"BochaSearchTool initialized with max_results_per_query: {self.max_results_per_query}"
        )

    def _execute_single_search(self, query: str) -> dict[str, Any]:
        """
        Execute a single search query

        Args:
            query: The search query string

        Returns:
            Dict: The raw response from the search API
        """
        start_time = time.time()

        payload = json.dumps(
            {
                "query": query,
                "summary": True,
                "count": max(10, self.max_results_per_query * 2),  # Request extra results for filtering
                "page": 1,
            }
        )

        try:
            self.logger.debug(f"Executing search for query: '{query}'")
            response = requests.post(self.url, headers=self.headers, data=payload, timeout=30)
            response.raise_for_status()

            execution_time = int((time.time() - start_time) * 1000)
            result = response.json()
            result["_execution_time_ms"] = execution_time

            return result

        except requests.exceptions.RequestException as e:
            execution_time = int((time.time() - start_time) * 1000)
            self.logger.error(f"Search API request failed for query '{query}': {e}")
            return {
                "error": str(e),
                "data": {"webPages": {"value": []}},
                "_execution_time_ms": execution_time,
            }
        except json.JSONDecodeError as e:
            execution_time = int((time.time() - start_time) * 1000)
            self.logger.error(f"Failed to parse search API response for query '{query}': {e}")
            return {
                "error": "JSON decode error",
                "data": {"webPages": {"value": []}},
                "_execution_time_ms": execution_time,
            }

    def _convert_api_response_to_search_results(
        self, api_response: dict[str, Any], query: str
    ) -> list[SearchResult]:
        """
        Convert a search API response into a list of SearchResult objects

        Args:
            api_response: The response from the search API
            query: The original query

        Returns:
            List[SearchResult]: A list of search result objects
        """
        search_results = []

        try:
            # Check whether the response contains an error
            if "error" in api_response:
                error_result = SearchResult(
                    title=f"Search Error for '{query}'",
                    content=f"Error: {api_response['error']}",
                    source="Error",
                    rank=1,
                    relevance_score=0.0,
                )
                search_results.append(error_result)
                return search_results

            # Extract the web page results
            web_pages = api_response.get("data", {}).get("webPages", {}).get("value", [])

            for i, page in enumerate(web_pages[: self.max_results_per_query]):
                title = page.get("name", "No Title").strip()
                content = page.get("snippet", "No content available").strip()
                source = page.get("url", "No URL").strip()

                # Compute a relevance score (a simple rank-based score)
                relevance_score = max(0.1, 1.0 - (i * 0.1))

                search_result = SearchResult(
                    title=title,
                    content=content,
                    source=source,
                    rank=i + 1,
                    relevance_score=relevance_score,
                )

                search_results.append(search_result)

            # If no results were found
            if not search_results:
                no_results = SearchResult(
                    title=f"No Results for '{query}'",
                    content=f"No search results found for query: '{query}'",
                    source="Search Engine",
                    rank=1,
                    relevance_score=0.0,
                )
                search_results.append(no_results)

        except Exception as e:
            self.logger.error(f"Error converting API response for query '{query}': {e}")
            error_result = SearchResult(
                title=f"Conversion Error for '{query}'",
                content=f"Error processing search results: {str(e)}",
                source="Error",
                rank=1,
                relevance_score=0.0,
            )
            search_results.append(error_result)

        return search_results

    def _create_legacy_chunks_for_compatibility(self, search_session: SearchSession) -> list[str]:
        """
        Create legacy-format retriver_chunks for backward compatibility

        Args:
            search_session: The search session object

        Returns:
            List[str]: A list of legacy-format search result strings
        """
        legacy_chunks = []

        for query_result in search_session.query_results:
            for result in query_result.results:
                legacy_chunk = f"""[Search Result {result.rank} for '{query_result.query}']
Title: {result.title}
Content: {result.content}
Source: {result.source}"""
                legacy_chunks.append(legacy_chunk)

        return legacy_chunks

    def _log_search_summary(self, context: ModelContext, total_queries: int, total_results: int):
        """Log a summary of the search"""
        original_chunks = len(context.retriver_chunks) if context.retriver_chunks else 0

        self.logger.info(
            f"Search completed: "
            f"Queries={total_queries}, "
            f"Total_results={total_results}, "
            f"Original_chunks={original_chunks}, "
            f"Context_UUID={context.uuid}"
        )

    def execute(self, context: ModelContext) -> ModelContext:
        """
        Execute the search and integrate the results into the ModelContext

        Args:
            context: ModelContext object containing the search queries

        Returns:
            ModelContext: The context updated with the search results
        """
        try:
            # Get the search queries
            search_queries = context.get_search_queries()
            self.logger.debug(
                f"BochaSearchTool processing {len(search_queries)} queries for context {context.uuid}"
            )

            # If there are no search queries, return the original context unchanged
            if not search_queries:
                self.logger.info("No search queries provided, returning original context")
                return context

            # Create a search session (if one does not exist yet)
            if not context.search_session:
                context.create_search_session(context.raw_question)

            # Execute all search queries
            total_results = 0

            for query in search_queries:
                self.logger.debug(f"Executing search for query: '{query}'")

                # Run the search
                api_response = self._execute_single_search(query)
                execution_time = api_response.get("_execution_time_ms", 0)

                # Convert to SearchResult objects
                search_results = self._convert_api_response_to_search_results(api_response, query)

                # Compute the total result count (taken from the API response, if available)
                total_count_from_api = len(
                    api_response.get("data", {}).get("webPages", {}).get("value", [])
                )

                # Add the search results to the context
                context.add_search_results(
                    query=query,
                    results=search_results,
                    search_engine=self.search_engine_name,
                    execution_time_ms=execution_time,
                    total_results_count=total_count_from_api,
                )

                total_results += len(search_results)

                self.logger.debug(f"Query '{query}' returned {len(search_results)} results")

            # Update retriver_chunks for backward compatibility
            if context.search_session:
                legacy_chunks = self._create_legacy_chunks_for_compatibility(context.search_session)
                if context.retriver_chunks is None:
                    context.retriver_chunks = []
                context.retriver_chunks.extend(legacy_chunks)

            # Update the search counter
            self.search_count += 1

            # Log the search summary
            self._log_search_summary(context, len(search_queries), total_results)

            # Record search execution info in the tool config
            search_execution_info = {
                "bocha_search_executed": True,
                "queries_count": len(search_queries),
                "total_results": total_results,
                "search_engine": self.search_engine_name,
                "execution_timestamp": int(time.time() * 1000),
                "session_id": (
                    context.search_session.session_id if context.search_session else None
                ),
            }

            context.update_tool_config({"bocha_search_info": search_execution_info})

            return context

        except Exception as e:
            self.logger.error(f"BochaSearchTool execution failed: {e}", exc_info=True)

            # Error handling: record the error in the tool config
            error_info = {
                "bocha_search_error": str(e),
                "error_timestamp": int(time.time() * 1000),
                "attempted_queries": (search_queries if "search_queries" in locals() else []),
            }

            context.update_tool_config({"bocha_search_error": error_info})

            return context


class EnhancedBochaSearchTool(BochaSearchTool):
    """
    Enhanced Bocha search tool with more customization options and result optimization
    Uses the new hierarchical search structure and ModelContext
    """

    def __init__(self, config: dict, **kwargs):
        super().__init__(config, **kwargs)

        self.deduplicate_results = config.get("deduplicate_results", True)
        self.max_total_chunks = config.get("max_total_chunks", 20)
        self.preserve_chunk_order = config.get("preserve_chunk_order", True)
        self.min_relevance_score = config.get("min_relevance_score", 0.1)
        self.diversity_threshold = config.get("diversity_threshold", 0.8)  # Diversity threshold

        self.logger.info(
            f"EnhancedBochaSearchTool initialized: "
            f"deduplicate={self.deduplicate_results}, "
            f"max_total={self.max_total_chunks}, "
            f"min_relevance={self.min_relevance_score}"
        )

    def _calculate_content_similarity(self, content1: str, content2: str) -> float:
        """Compute the similarity of two pieces of content (simple lexical overlap)"""
        words1 = set(content1.lower().split())
        words2 = set(content2.lower().split())

        if not words1 or not words2:
            return 0.0

        intersection = words1.intersection(words2)
        union = words1.union(words2)

        return len(intersection) / len(union) if union else 0.0

    def _deduplicate_search_results(self, search_results: list[SearchResult]) -> list[SearchResult]:
        """Deduplicate search results and optimize for diversity"""
        if not self.deduplicate_results or not search_results:
            return search_results

        # Sort by relevance score
        sorted_results = sorted(search_results, key=lambda x: x.relevance_score, reverse=True)

        deduplicated = []
        seen_sources = set()

        for result in sorted_results:
            # Skip results from a source we have already seen
            if result.source in seen_sources:
                continue

            # Check similarity against the results already selected
            is_diverse = True
            for existing in deduplicated:
                similarity = self._calculate_content_similarity(result.content, existing.content)
                if similarity > self.diversity_threshold:
                    is_diverse = False
                    break

            # Check the relevance score threshold
            if result.relevance_score >= self.min_relevance_score and is_diverse:
                deduplicated.append(result)
                seen_sources.add(result.source)

        # Preserve the original ranking order (if requested)
        if self.preserve_chunk_order:
            # Sort by the original rank
            deduplicated = sorted(deduplicated, key=lambda x: x.rank)

        return deduplicated

    def _optimize_search_session(self, context: ModelContext) -> None:
        """Optimize the search session results"""
        if not context.search_session or not context.search_session.query_results:
            return

        total_optimized = 0

        for query_result in context.search_session.query_results:
            original_count = len(query_result.results)

            # Apply deduplication and diversity optimization
            query_result.results = self._deduplicate_search_results(query_result.results)

            optimized_count = len(query_result.results)
            total_optimized += original_count - optimized_count

            if original_count != optimized_count:
                self.logger.debug(
                    f"Query '{query_result.query}': "
                    f"optimized from {original_count} to {optimized_count} results"
                )

        if total_optimized > 0:
            self.logger.info(
                f"Search optimization removed {total_optimized} duplicate/low-quality results"
            )

    def _limit_total_results(self, context: ModelContext) -> None:
        """Limit the total number of search results"""
        if not context.search_session:
            return

        total_results = context.search_session.get_total_results_count()

        if total_results <= self.max_total_chunks:
            return

        # Collect all results and sort them by relevance
        all_results = []
        for query_result in context.search_session.query_results:
            for result in query_result.results:
                all_results.append((query_result, result))

        # Sort by relevance score
        all_results.sort(key=lambda x: x[1].relevance_score, reverse=True)

        # Clear the existing results
        for query_result in context.search_session.query_results:
            query_result.results = []

        # Redistribute the best results, keeping at least one result per query
        results_per_query = self.max_total_chunks // len(context.search_session.query_results)
        remaining_slots = self.max_total_chunks % len(context.search_session.query_results)

        query_result_counts = dict.fromkeys(context.search_session.query_results, 0)

        for query_result, result in all_results:
            current_count = query_result_counts[query_result]
            max_for_this_query = results_per_query + (1 if remaining_slots > 0 else 0)

            if current_count < max_for_this_query:
                query_result.results.append(result)
                query_result_counts[query_result] += 1

                if current_count + 1 == max_for_this_query and remaining_slots > 0:
                    remaining_slots -= 1

            if sum(query_result_counts.values()) >= self.max_total_chunks:
                break

        new_total = context.search_session.get_total_results_count()
        self.logger.info(f"Limited search results from {total_results} to {new_total}")

    def _update_legacy_chunks(self, context: ModelContext) -> None:
        """Update the legacy-format retriver_chunks to reflect the optimized results"""
        if not context.search_session:
            return

        # Regenerate the legacy chunks
        optimized_chunks = self._create_legacy_chunks_for_compatibility(context.search_session)

        # Merge into the existing chunks (preserving any non-search chunks that may already exist)
        non_search_chunks = []
        if context.retriver_chunks:
            # Try to identify non-search chunks (those without the "[Search Result" marker)
            for chunk in context.retriver_chunks:
                if not chunk.strip().startswith("[Search Result"):
                    non_search_chunks.append(chunk)

        context.retriver_chunks = non_search_chunks + optimized_chunks

    def execute(self, context: ModelContext) -> ModelContext:
        """Enhanced execution logic, including result optimization"""
        try:
            # Run the base search first
            context = super().execute(context)

            # Apply the enhancements
            if context.search_session and context.search_session.query_results:
                # 1. Optimize the search session results (deduplication, diversity)
                self._optimize_search_session(context)

                # 2. Limit the total number of results
                self._limit_total_results(context)

                # 3. Update the legacy chunks to reflect the optimization
                self._update_legacy_chunks(context)

                # 4. Record optimization info in the tool config
                optimization_info = {
                    "enhanced_search_applied": True,
                    "deduplicate_results": self.deduplicate_results,
                    "max_total_chunks": self.max_total_chunks,
                    "min_relevance_score": self.min_relevance_score,
                    "final_results_count": context.search_session.get_total_results_count(),
                    "final_chunks_count": (
                        len(context.retriver_chunks) if context.retriver_chunks else 0
                    ),
                }

                context.update_tool_config({"enhanced_search_info": optimization_info})

                self.logger.info(
                    f"Enhanced search completed: {optimization_info['final_results_count']} results, "
                    f"{optimization_info['final_chunks_count']} chunks"
                )

            return context

        except Exception as e:
            self.logger.error(f"EnhancedBochaSearchTool execution failed: {e}", exc_info=True)

            # Error handling: record the error and keep the basic search results
            error_info = {
                "enhanced_search_error": str(e),
                "error_timestamp": int(time.time() * 1000),
                "fallback_to_basic": True,
            }

            context.update_tool_config({"enhanced_search_error": error_info})

            return context
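For orientation, here is a minimal wiring sketch for the operator above. The config keys come from the two __init__ methods in the diff; the ModelContext constructor is not shown in this diff, so its arguments here are assumptions, and the sketch relies only on the methods the operator itself calls (get_search_queries, add_search_results, update_tool_config) plus the retriver_chunks attribute it populates.

# Hypothetical usage sketch -- ModelContext's constructor arguments are assumed,
# since model_context.py is not part of this diff.
import os

from sage.middleware.operators.context.model_context import ModelContext
from sage.middleware.operators.tools.searcher_tool import EnhancedBochaSearchTool

config = {
    "api_key": os.getenv("BOCHA_API_KEY"),  # required; __init__ raises ValueError without it
    "max_results_per_query": 3,             # results kept per query after filtering
    "max_total_chunks": 20,                 # enhanced tool: cap across all queries
    "deduplicate_results": True,            # enhanced tool: drop same-source/near-duplicate hits
    "min_relevance_score": 0.1,             # enhanced tool: relevance floor
}

tool = EnhancedBochaSearchTool(config)

context = ModelContext(raw_question="What is stream processing?")  # hypothetical ctor
context = tool.execute(context)

# After execute(), results live in the structured search session and in the
# legacy retriver_chunks list kept for backward compatibility.
if context.search_session:
    print(context.search_session.get_total_results_count(), "results")
for chunk in context.retriver_chunks or []:
    print(chunk)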
sage/middleware/operators/tools/text_detector.py

@@ -0,0 +1,185 @@

import os
import time

import torch

from sage.libs.foundation.tools.tool import BaseTool


class text_detector(BaseTool):
    def __init__(self):
        super().__init__(
            tool_name="Text_Detector_Tool",
            tool_description="A tool that detects text in an image using EasyOCR.",
            input_types={
                "image": "str - The path to the image file.",
                "languages": "list - A list of language codes for the OCR model.",
                "detail": "int - The level of detail in the output. Set to 0 for simpler output, 1 for detailed output.",
            },
            output_type="list - A list of detected text blocks.",
            demo_commands=[
                {
                    "command": 'execution = tool.execute(image="path/to/image.png", languages=["en"])',
                    "description": "Detect text in an image using the default language (English).",
                },
                {
                    "command": 'execution = tool.execute(image="path/to/image.png", languages=["en", "de"])',
                    "description": "Detect text in an image using multiple languages (English and German).",
                },
                {
                    "command": 'execution = tool.execute(image="path/to/image.png", languages=["en"], detail=0)',
                    "description": "Detect text in an image with simpler output (text without coordinates and scores).",
                },
            ],
        )
        self.tool_version = "1.0.0"
        self.frequently_used_language = {
            "ch_sim": "Simplified Chinese",
            "ch_tra": "Traditional Chinese",
            "de": "German",
            "en": "English",
            "es": "Spanish",
            "fr": "French",
            "hi": "Hindi",
            "ja": "Japanese",
        }

    def build_tool(self, languages=None):
        """
        Builds and returns the EasyOCR reader model.

        Parameters:
            languages (list): A list of language codes for the OCR model.

        Returns:
            easyocr.Reader: An initialized EasyOCR Reader object.
        """
        languages = languages or ["en"]  # Default to English if no languages provided
        try:
            import easyocr

            reader = easyocr.Reader(languages)
            return reader
        except ImportError:
            raise ImportError("Please install the EasyOCR package using 'pip install easyocr'.")
        except Exception as e:
            print(f"Error building the OCR tool: {e}")
            return None

    def execute(
        self,
        image,
        languages=None,
        max_retries=10,
        retry_delay=5,
        clear_cuda_cache=False,
        **kwargs,
    ):
        """
        Executes the OCR tool to detect text in the provided image.

        Parameters:
            image (str): The path to the image file.
            languages (list): A list of language codes for the OCR model.
            max_retries (int): Maximum number of retry attempts.
            retry_delay (int): Delay in seconds between retry attempts.
            clear_cuda_cache (bool): Whether to clear CUDA cache on out-of-memory errors.
            **kwargs: Additional keyword arguments for the OCR reader.

        Returns:
            list: A list of detected text blocks.
        """
        languages = languages or ["en"]

        for attempt in range(max_retries):
            try:
                reader = self.build_tool(languages)
                if reader is None:
                    raise ValueError("Failed to build the OCR tool.")

                result = reader.readtext(image, **kwargs)
                try:
                    # detail = 1: Convert numpy types to standard Python types
                    from typing import Any, cast

                    cleaned_result = [
                        (
                            [[int(coord[0]), int(coord[1])] for coord in cast(Any, item)[0]],
                            cast(Any, item)[1],
                            round(float(cast(Any, item)[2]), 2),
                        )
                        for item in result
                    ]
                    return cleaned_result
                except Exception:
                    # detail = 0
                    return result

            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    print(f"CUDA out of memory error on attempt {attempt + 1}.")
                    if clear_cuda_cache:
                        print("Clearing CUDA cache and retrying...")
                        torch.cuda.empty_cache()
                    else:
                        print(f"Retrying in {retry_delay} seconds...")
                        time.sleep(retry_delay)
                    continue
                else:
                    print(f"Runtime error: {e}")
                    break
            except Exception as e:
                print(f"Error detecting text: {e}")
                break

        print(f"Failed to detect text after {max_retries} attempts.")
        return []

    def get_metadata(self):
        """
        Returns the metadata for the Text_Detector_Tool.

        Returns:
            dict: A dictionary containing the tool's metadata.
        """
        metadata = super().get_metadata()
        return metadata


if __name__ == "__main__":
    import json

    # Get the directory of the current script
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Example usage of the Text_Detector_Tool
    tool = text_detector()

    # Get tool metadata
    metadata = tool.get_metadata()
    print(metadata)

    # Construct the full path to the image using the script's directory
    # relative_image_path = "examples/chinese_tra.jpg"
    # relative_image_path = "examples/chinese.jpg"
    relative_image_path = "examples/english.png"
    image_path = os.path.join(script_dir, relative_image_path)

    # Check if the image file exists
    if not os.path.exists(image_path):
        print(f"Image file not found: {image_path}")
        print("Please provide a valid image file in the 'examples/' directory.")
        exit(1)

    # Execute the tool
    try:
        # execution = tool.execute(image=image_path, languages=["en", "ch_sim"])
        # execution = tool.execute(image=image_path, languages=["en", "ch_tra"])
        execution = tool.execute(image=image_path, languages=["en"])
        print(json.dumps(execution))

        print("Detected Text:", execution)
    except ValueError as e:
        print(f"Execution failed: {e}")

    print("Done!")
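The cleaning step in execute is easiest to see on a concrete value. EasyOCR's detailed output (detail=1, its default) is a list of (box, text, confidence) triples whose coordinates are numpy integers, which json.dumps rejects; the comprehension above converts them to plain Python types so the result survives the json.dumps call in the __main__ block. A sketch with made-up values:

# Illustrative values only -- the shape matches easyocr.Reader.readtext(detail=1).
import json

import numpy as np

raw = [
    ([[np.int32(10), np.int32(12)], [np.int32(90), np.int32(12)],
      [np.int32(90), np.int32(40)], [np.int32(10), np.int32(40)]],
     "Hello", np.float64(0.9731)),
]

# Same conversion as in text_detector.execute: plain ints for the box
# coordinates, confidence cast to float and rounded to two decimals.
cleaned = [
    ([[int(x), int(y)] for x, y in box], text, round(float(conf), 2))
    for box, text, conf in raw
]

print(json.dumps(cleaned))  # [[[[10, 12], [90, 12], [90, 40], [10, 40]], "Hello", 0.97]]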