isage-middleware 0.2.4.3__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
- isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
- isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
- isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
- sage/middleware/__init__.py +59 -0
- sage/middleware/_version.py +6 -0
- sage/middleware/components/__init__.py +30 -0
- sage/middleware/components/extensions_compat.py +141 -0
- sage/middleware/components/sage_db/__init__.py +116 -0
- sage/middleware/components/sage_db/backend.py +136 -0
- sage/middleware/components/sage_db/service.py +15 -0
- sage/middleware/components/sage_flow/__init__.py +76 -0
- sage/middleware/components/sage_flow/python/__init__.py +14 -0
- sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
- sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
- sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
- sage/middleware/components/sage_flow/service.py +14 -0
- sage/middleware/components/sage_mem/__init__.py +83 -0
- sage/middleware/components/sage_sias/__init__.py +59 -0
- sage/middleware/components/sage_sias/continual_learner.py +184 -0
- sage/middleware/components/sage_sias/coreset_selector.py +302 -0
- sage/middleware/components/sage_sias/types.py +94 -0
- sage/middleware/components/sage_tsdb/__init__.py +81 -0
- sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
- sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
- sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
- sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
- sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
- sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
- sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
- sage/middleware/components/sage_tsdb/service.py +17 -0
- sage/middleware/components/vector_stores/__init__.py +25 -0
- sage/middleware/components/vector_stores/chroma.py +483 -0
- sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
- sage/middleware/components/vector_stores/milvus.py +677 -0
- sage/middleware/operators/__init__.py +56 -0
- sage/middleware/operators/agent/__init__.py +24 -0
- sage/middleware/operators/agent/planning/__init__.py +5 -0
- sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
- sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
- sage/middleware/operators/agent/planning/router.py +107 -0
- sage/middleware/operators/agent/runtime.py +296 -0
- sage/middleware/operators/agentic/__init__.py +41 -0
- sage/middleware/operators/agentic/config.py +254 -0
- sage/middleware/operators/agentic/planning_operator.py +125 -0
- sage/middleware/operators/agentic/refined_searcher.py +132 -0
- sage/middleware/operators/agentic/runtime.py +241 -0
- sage/middleware/operators/agentic/timing_operator.py +125 -0
- sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
- sage/middleware/operators/context/__init__.py +17 -0
- sage/middleware/operators/context/critic_evaluation.py +16 -0
- sage/middleware/operators/context/model_context.py +565 -0
- sage/middleware/operators/context/quality_label.py +12 -0
- sage/middleware/operators/context/search_query_results.py +61 -0
- sage/middleware/operators/context/search_result.py +42 -0
- sage/middleware/operators/context/search_session.py +79 -0
- sage/middleware/operators/filters/__init__.py +26 -0
- sage/middleware/operators/filters/context_sink.py +387 -0
- sage/middleware/operators/filters/context_source.py +376 -0
- sage/middleware/operators/filters/evaluate_filter.py +83 -0
- sage/middleware/operators/filters/tool_filter.py +74 -0
- sage/middleware/operators/llm/__init__.py +18 -0
- sage/middleware/operators/llm/sagellm_generator.py +432 -0
- sage/middleware/operators/rag/__init__.py +147 -0
- sage/middleware/operators/rag/arxiv.py +331 -0
- sage/middleware/operators/rag/chunk.py +13 -0
- sage/middleware/operators/rag/document_loaders.py +23 -0
- sage/middleware/operators/rag/evaluate.py +658 -0
- sage/middleware/operators/rag/generator.py +340 -0
- sage/middleware/operators/rag/index_builder/__init__.py +48 -0
- sage/middleware/operators/rag/index_builder/builder.py +363 -0
- sage/middleware/operators/rag/index_builder/manifest.py +101 -0
- sage/middleware/operators/rag/index_builder/storage.py +131 -0
- sage/middleware/operators/rag/pipeline.py +46 -0
- sage/middleware/operators/rag/profiler.py +59 -0
- sage/middleware/operators/rag/promptor.py +400 -0
- sage/middleware/operators/rag/refiner.py +231 -0
- sage/middleware/operators/rag/reranker.py +364 -0
- sage/middleware/operators/rag/retriever.py +1308 -0
- sage/middleware/operators/rag/searcher.py +37 -0
- sage/middleware/operators/rag/types.py +28 -0
- sage/middleware/operators/rag/writer.py +80 -0
- sage/middleware/operators/tools/__init__.py +71 -0
- sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
- sage/middleware/operators/tools/arxiv_searcher.py +102 -0
- sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
- sage/middleware/operators/tools/image_captioner.py +104 -0
- sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
- sage/middleware/operators/tools/searcher_tool.py +514 -0
- sage/middleware/operators/tools/text_detector.py +185 -0
- sage/middleware/operators/tools/url_text_extractor.py +104 -0
- sage/middleware/py.typed +2 -0
|
@@ -0,0 +1,565 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import time
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
from uuid import uuid4
|
|
7
|
+
|
|
8
|
+
from .critic_evaluation import CriticEvaluation
|
|
9
|
+
from .quality_label import QualityLabel
|
|
10
|
+
from .search_query_results import SearchQueryResults
|
|
11
|
+
from .search_result import SearchResult
|
|
12
|
+
from .search_session import SearchSession
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class ModelContext:
    """Data packet passed between pipeline stages.

    Carries the user's question, retrieval results (the legacy flat
    ``retriver_chunks`` list and the structured ``search_session``), the
    prompts sent to the model, the model's response, the selected tool and
    its configuration, and an optional critic evaluation. Instances can be
    rendered for humans (``__str__``) and round-tripped through dicts, JSON
    strings and files.
    """

    # Packet metadata
    sequence: int = 0
    # Creation time in milliseconds since the Unix epoch.
    timestamp: int = field(default_factory=lambda: int(time.time() * 1000))
    # Generator content
    raw_question: str | None = None
    # Legacy flat retrieval results, kept for backward compatibility;
    # search_session is preferred when it is populated.
    retriver_chunks: list[str] = field(default_factory=list)
    # Structured search results grouped per query.
    search_session: SearchSession | None = None
    # Chat-style messages; entries are read via their "role"/"content" keys.
    prompts: list[dict[str, str]] = field(default_factory=list)
    response: str | None = None
    uuid: str = field(default_factory=lambda: str(uuid4()))
    tool_name: str | None = None
    evaluation: CriticEvaluation | None = None
    # Tool configuration: settings and intermediate results stored by tools.
    tool_config: dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _truncate(value: Any, limit: int) -> str:
        """Return ``str(value)`` cut to ``limit`` characters, appending "..." when cut.

        Converting to ``str`` first makes previews safe for non-string values
        (dicts, lists, numbers) that would otherwise fail on slicing.
        """
        text = str(value)
        return text[:limit] + "..." if len(text) > limit else text

    def __str__(self) -> str:
        """Render the context as a multi-line, human-readable report."""
        # Timestamp is stored in milliseconds; strftime needs seconds.
        timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.timestamp / 1000))

        output_lines: list[str] = []
        output_lines.append("=" * 80)

        # Title line
        title_parts = [f"🤖 AI Processing Result [ID: {self.uuid[:8]}]"]
        if self.tool_name:
            tool_emoji = self._get_tool_emoji(self.tool_name)
            title_parts.append(f"{tool_emoji} Tool: {self.tool_name}")

        output_lines.append(" | ".join(title_parts))
        output_lines.append(f"📅 Time: {timestamp_str} | Sequence: {self.sequence}")

        # Evaluation status line
        if self.evaluation:
            quality_emoji = self._get_quality_emoji(self.evaluation.label)
            status_parts = [
                f"{quality_emoji} Quality: {self.evaluation.label.value}",
                f"Confidence: {self.evaluation.confidence:.2f}",
                f"Output Ready: {'✅' if self.evaluation.ready_for_output else '❌'}",
            ]
            output_lines.append("📊 " + " | ".join(status_parts))

        output_lines.append("=" * 80)

        # Original question
        if self.raw_question:
            output_lines.append("❓ Original Question:")
            output_lines.append(f" {self.raw_question}")
            output_lines.append("")

        # Tool configuration
        if self.tool_config:
            output_lines.append("🔧 Tool Configuration:")
            self._format_tool_config(output_lines)
            output_lines.append("")

        # Search results: prefer the structured search_session, fall back to
        # the legacy flat chunk list.
        if self.search_session and self.search_session.query_results:
            output_lines.append(
                f"🔍 Search Results ({self.search_session.get_total_results_count()} total):"
            )
            self._format_search_session(output_lines)
            output_lines.append("")
        elif self.retriver_chunks:
            output_lines.append(f"📚 Retrieved Information ({len(self.retriver_chunks)} sources):")
            for i, chunk in enumerate(self.retriver_chunks[:3], 1):
                output_lines.append(f" [{i}] {self._truncate(chunk, 150)}")

            if len(self.retriver_chunks) > 3:
                output_lines.append(f" ... and {len(self.retriver_chunks) - 3} more sources")
            output_lines.append("")

        # Processing steps derived from the prompts
        if self.prompts:
            output_lines.append("⚙️ Processing Steps:")
            system_prompts = [p for p in self.prompts if p.get("role") == "system"]
            user_prompts = [p for p in self.prompts if p.get("role") == "user"]

            if system_prompts:
                output_lines.append(f" • System instructions: {len(system_prompts)} phases")
            if user_prompts:
                last_user_prompt = user_prompts[-1].get("content", "")
                # Only show the final user prompt when it adds something
                # beyond the original question.
                if last_user_prompt and last_user_prompt != self.raw_question:
                    preview = self._truncate(last_user_prompt, 100)
                    output_lines.append(f" • Specific task: {preview}")
            output_lines.append("")

        # Model response, one output line per response line
        if self.response:
            output_lines.append("🎯 AI Response:")
            for line in self.response.split("\n"):
                output_lines.append(f" {line}")
            output_lines.append("")

        # Evaluation details
        if self.evaluation:
            output_lines.append("🔍 Evaluation Details:")
            output_lines.append(f" • Reasoning: {self.evaluation.reasoning}")

            if self.evaluation.specific_issues:
                output_lines.append(f" • Issues: {', '.join(self.evaluation.specific_issues)}")

            if self.evaluation.suggestions:
                output_lines.append(f" • Suggestions: {', '.join(self.evaluation.suggestions)}")

            if self.evaluation.should_return_to_chief:
                output_lines.append(" • ⚠️ Should return to Chief for reprocessing")
            output_lines.append("")

        # Summary status indicators
        status_indicators: list[str] = []
        if self.tool_name:
            status_indicators.append(f"Tool: {self.tool_name}")
        if self.response:
            status_indicators.append("✅ Response Generated")
        else:
            status_indicators.append("⏳ Processing")

        # Search result status
        if self.search_session:
            total_results = self.search_session.get_total_results_count()
            status_indicators.append(f"🔍 {total_results} search results")
        elif self.retriver_chunks:
            total_results = len(self.retriver_chunks)
            status_indicators.append(f"📊 {total_results} chunks")

        if self.evaluation:
            status_indicators.append(f"🔍 Evaluated ({self.evaluation.label.value})")
        if self.tool_config:
            status_indicators.append("🔧 Tool Config")

        if status_indicators:
            output_lines.append(f"📋 Status: {' | '.join(status_indicators)}")
            output_lines.append("")

        output_lines.append("=" * 80)
        return "\n".join(output_lines)

    def _format_search_session(self, output_lines: list[str]) -> None:
        """Append a per-query summary (top 3 hits each) of search_session to output_lines."""
        if not self.search_session:
            return
        for i, query_result in enumerate(self.search_session.query_results, 1):
            output_lines.append(
                f" Query {i}: '{query_result.query}' ({query_result.get_results_count()} results)"
            )

            # Show only the first three results of each query.
            for j, result in enumerate(query_result.get_top_results(3), 1):
                title_preview = self._truncate(result.title, 80)
                content_preview = self._truncate(result.content, 100)
                output_lines.append(f" [{j}] {title_preview}")
                output_lines.append(f" {content_preview}")
                output_lines.append(f" Source: {result.source}")

            if query_result.get_results_count() > 3:
                output_lines.append(
                    f" ... and {query_result.get_results_count() - 3} more results"
                )

    def _format_tool_config(self, output_lines: list[str]) -> None:
        """Append a readable summary of tool_config to output_lines.

        The keys "search_queries", "search_analysis" and
        "optimization_metadata" get dedicated layouts; other keys are shown
        as a container size or a truncated string.
        """
        for key, value in self.tool_config.items():
            if key == "search_queries":
                if isinstance(value, list) and value:
                    output_lines.append(f" • Search Queries ({len(value)}):")
                    for i, query in enumerate(value[:5], 1):
                        output_lines.append(f" [{i}] {self._truncate(query, 80)}")
                    if len(value) > 5:
                        output_lines.append(f" ... and {len(value) - 5} more queries")
                else:
                    output_lines.append(f" • Search Queries: {value}")

            elif key == "search_analysis":
                if isinstance(value, dict):
                    output_lines.append(" • Search Analysis:")
                    # Values may be non-strings (e.g. nested dicts); _truncate
                    # stringifies before slicing instead of raising TypeError.
                    if "analysis" in value:
                        analysis_text = self._truncate(value["analysis"], 100)
                        output_lines.append(f" - Analysis: {analysis_text}")
                    if "reasoning" in value:
                        reasoning_text = self._truncate(value["reasoning"], 100)
                        output_lines.append(f" - Reasoning: {reasoning_text}")
                else:
                    output_lines.append(f" • Search Analysis: {value}")

            elif key == "optimization_metadata":
                if isinstance(value, dict):
                    output_lines.append(" • Optimization Metadata:")
                    for meta_key, meta_value in value.items():
                        if isinstance(meta_value, (str, int, float, bool)):
                            output_lines.append(f" - {meta_key}: {meta_value}")
                        else:
                            # Complex values are summarized by type name only.
                            output_lines.append(f" - {meta_key}: {type(meta_value).__name__}")
                else:
                    output_lines.append(f" • Optimization Metadata: {value}")

            else:
                if isinstance(value, (list, dict)):
                    output_lines.append(
                        f" • {key.replace('_', ' ').title()}: {type(value).__name__}({len(value)} items)"
                    )
                else:
                    value_str = self._truncate(value, 50)
                    output_lines.append(f" • {key.replace('_', ' ').title()}: {value_str}")

    def _get_tool_emoji(self, tool_name: str) -> str:
        """Return the display emoji for a tool name ("🔧" for unknown tools)."""
        tool_emojis = {
            "web_search": "🔍",
            "knowledge_retrieval": "📖",
            "calculator": "🧮",
            "code_executor": "💻",
            "data_analyzer": "📊",
            "translation": "🌐",
            "summarizer": "📝",
            "fact_checker": "✅",
            "image_analyzer": "🖼️",
            "weather_service": "🌤️",
            "stock_market": "📈",
            "news_aggregator": "📰",
            "direct_response": "💭",
            "error_handler": "⚠️",
        }
        return tool_emojis.get(tool_name, "🔧")

    def _get_quality_emoji(self, quality_label: QualityLabel) -> str:
        """Return the display emoji for a quality label ("❔" for unknown labels)."""
        quality_emojis = {
            QualityLabel.COMPLETE_EXCELLENT: "🌟",
            QualityLabel.COMPLETE_GOOD: "✅",
            QualityLabel.PARTIAL_NEEDS_IMPROVEMENT: "⚡",
            QualityLabel.INCOMPLETE_MISSING_INFO: "❓",
            QualityLabel.FAILED_POOR_QUALITY: "❌",
            QualityLabel.ERROR_INVALID: "⚠️",
        }
        return quality_emojis.get(quality_label, "❔")

    def to_dict(self) -> dict[str, Any]:
        """Convert to a plain dictionary (JSON-serializable if the contents are).

        The chunk list, each prompt dict, the evaluation's issue/suggestion
        lists and (deeply) tool_config are copied, so mutating the result
        does not affect this context through those members.
        """
        evaluation: dict[str, Any] | None = None
        if self.evaluation:
            evaluation = {
                "label": self.evaluation.label.value,
                "confidence": self.evaluation.confidence,
                "reasoning": self.evaluation.reasoning,
                "specific_issues": list(self.evaluation.specific_issues),
                "suggestions": list(self.evaluation.suggestions),
                "should_return_to_chief": self.evaluation.should_return_to_chief,
                "ready_for_output": self.evaluation.ready_for_output,
            }

        # Key order is preserved from earlier versions so JSON output is stable.
        return {
            "sequence": self.sequence,
            "timestamp": self.timestamp,
            "raw_question": self.raw_question,
            "retriver_chunks": list(self.retriver_chunks or []),
            # Copy each prompt dict so clone() does not share message dicts.
            "prompts": [dict(p) for p in self.prompts or []],
            "response": self.response,
            "uuid": self.uuid,
            "tool_name": self.tool_name,
            "tool_config": (
                self._deep_copy_tool_config(self.tool_config) if self.tool_config else {}
            ),
            "search_session": self.search_session.to_dict() if self.search_session else None,
            "evaluation": evaluation,
        }

    def _deep_copy_tool_config(self, config: dict[str, Any]) -> dict[str, Any]:
        """Return a deep copy of a tool_config dictionary."""
        import copy

        return copy.deepcopy(config)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ModelContext":
        """Create a ModelContext from a dictionary produced by ``to_dict``.

        Missing keys fall back to the field defaults. Top-level lists and
        dicts are copied so the new instance does not alias ``data``; keys
        present with a ``None`` value for list/dict fields (and a missing or
        empty uuid) are normalized to fresh defaults.

        Raises:
            KeyError: if an evaluation entry is present without a "label".
            ValueError: if the evaluation label is not a valid QualityLabel.
        """
        evaluation = None
        eval_data = data.get("evaluation")
        if eval_data:
            evaluation = CriticEvaluation(
                label=QualityLabel(eval_data["label"]),
                confidence=eval_data.get("confidence", 0.0),
                reasoning=eval_data.get("reasoning", ""),
                specific_issues=list(eval_data.get("specific_issues") or []),
                suggestions=list(eval_data.get("suggestions") or []),
                should_return_to_chief=eval_data.get("should_return_to_chief", False),
                ready_for_output=eval_data.get("ready_for_output", False),
            )

        search_session = None
        if data.get("search_session"):
            search_session = SearchSession.from_dict(data["search_session"])

        return cls(
            sequence=data.get("sequence", 0),
            timestamp=data.get("timestamp", int(time.time() * 1000)),
            raw_question=data.get("raw_question"),
            retriver_chunks=list(data.get("retriver_chunks") or []),
            search_session=search_session,
            prompts=[dict(p) for p in data.get("prompts") or []],
            response=data.get("response"),
            # __str__ slices uuid, so never let it become None/empty.
            uuid=data.get("uuid") or str(uuid4()),
            tool_name=data.get("tool_name"),
            evaluation=evaluation,
            tool_config=dict(data.get("tool_config") or {}),
        )

    # Search-result helpers
    def create_search_session(self, original_question: str | None = None) -> SearchSession:
        """Return the current search session, creating one if none exists.

        A new session's original question is ``original_question``, else
        ``raw_question``, else "".
        """
        if not self.search_session:
            self.search_session = SearchSession(
                original_question=original_question or self.raw_question or ""
            )
        return self.search_session

    def add_search_results(
        self,
        query: str,
        results: list[SearchResult],
        search_engine: str = "unknown",
        execution_time_ms: int = 0,
        total_results_count: int | None = None,
    ) -> None:
        """Record the results of one query in the search session (created on demand).

        A falsy ``total_results_count`` (None or 0) falls back to
        ``len(results)``.
        """
        session = self.create_search_session()

        query_results = SearchQueryResults(
            query=query,
            results=results,
            search_engine=search_engine,
            execution_time_ms=execution_time_ms,
            total_results_count=total_results_count or len(results),
        )
        session.add_query_results(query_results)

    def get_search_queries(self) -> list[str]:
        """Return all search queries, from the session if present, else from tool_config."""
        if self.search_session:
            return self.search_session.get_all_queries()
        # Guard against an explicitly stored None so callers can call len().
        return self.get_tool_config("search_queries") or []

    def get_all_search_results(self) -> list[SearchResult]:
        """Return every result in the search session, or [] when there is none."""
        if self.search_session:
            return self.search_session.get_all_results()
        return []

    def get_results_by_query(self, query: str) -> list[SearchResult]:
        """Return the results recorded for ``query``, or [] if unknown."""
        if self.search_session:
            query_results = self.search_session.get_results_by_query(query)
            return query_results.results if query_results else []
        return []

    def get_search_results_count(self) -> int:
        """Return the total number of search results (legacy chunk count as fallback)."""
        if self.search_session:
            return self.search_session.get_total_results_count()
        return len(self.retriver_chunks)

    def has_search_results(self) -> bool:
        """Return True if either the search session or the legacy chunk list has results."""
        return bool(
            (self.search_session and self.search_session.get_total_results_count() > 0)
            or (self.retriver_chunks and len(self.retriver_chunks) > 0)
        )

    # Backward-compatible helpers
    def set_search_queries(
        self, queries: list[str], analysis: dict[str, Any] | None = None
    ) -> None:
        """Store search queries (and optional analysis) in tool_config (legacy API)."""
        self.set_tool_config("search_queries", queries)
        if analysis:
            self.set_tool_config("search_analysis", analysis)

    def get_search_analysis(self) -> dict[str, Any]:
        """Return the stored search analysis, or {} if none."""
        return self.get_tool_config("search_analysis", {})

    def has_search_queries(self) -> bool:
        """Return True if at least one search query is available."""
        return bool(self.get_search_queries())

    # Tool configuration helpers
    def set_tool_config(self, key: str, value: Any) -> None:
        """Set a single tool configuration entry."""
        if self.tool_config is None:
            self.tool_config = {}
        self.tool_config[key] = value

    def get_tool_config(self, key: str, default: Any = None) -> Any:
        """Return a tool configuration entry, or ``default`` if absent."""
        if not self.tool_config:
            return default
        return self.tool_config.get(key, default)

    def update_tool_config(self, config_dict: dict[str, Any]) -> None:
        """Merge ``config_dict`` into the tool configuration."""
        if self.tool_config is None:
            self.tool_config = {}
        self.tool_config.update(config_dict)

    def remove_tool_config(self, key: str) -> Any:
        """Remove a tool configuration entry and return its value (None if absent)."""
        if not self.tool_config:
            return None
        return self.tool_config.pop(key, None)

    def has_tool_config(self, key: str) -> bool:
        """Return True if the tool configuration contains ``key``."""
        return bool(self.tool_config and key in self.tool_config)

    # JSON serialization
    def to_json(self) -> str:
        """Serialize to a pretty-printed JSON string (non-ASCII kept as-is)."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "ModelContext":
        """Create a ModelContext from a JSON string.

        Raises:
            ValueError: if the string is not valid JSON or cannot be converted;
                the original exception is chained as ``__cause__``.
        """
        try:
            data = json.loads(json_str)
            return cls.from_dict(data)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {e}") from e
        except Exception as e:
            raise ValueError(f"Failed to create ModelContext from JSON: {e}") from e

    def save_to_file(self, file_path: str) -> None:
        """Write the JSON form to ``file_path``, creating parent directories.

        Raises:
            OSError: on any failure; the original exception is chained.
        """
        try:
            Path(file_path).parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(self.to_json())
        except Exception as e:
            raise OSError(f"Failed to save ModelContext to {file_path}: {e}") from e

    @classmethod
    def load_from_file(cls, file_path: str) -> "ModelContext":
        """Load a ModelContext from a JSON file.

        Raises:
            FileNotFoundError: if the file does not exist.
            OSError: on any other failure (including invalid content);
                the original exception is chained.
        """
        try:
            with open(file_path, encoding="utf-8") as f:
                return cls.from_json(f.read())
        except FileNotFoundError as e:
            raise FileNotFoundError(f"ModelContext file not found: {file_path}") from e
        except Exception as e:
            raise OSError(f"Failed to load ModelContext from {file_path}: {e}") from e

    def clone(self) -> "ModelContext":
        """Return a copy of this context built via a to_dict/from_dict round trip."""
        return self.from_dict(self.to_dict())

    def update_evaluation(
        self,
        label: QualityLabel,
        confidence: float,
        reasoning: str,
        issues: list[str] | None = None,
        suggestions: list[str] | None = None,
    ) -> None:
        """Replace the evaluation, deriving routing flags from ``label``.

        FAILED_POOR_QUALITY and INCOMPLETE_MISSING_INFO mark the context to be
        returned to the chief; COMPLETE_EXCELLENT and COMPLETE_GOOD mark it
        ready for output.
        """
        self.evaluation = CriticEvaluation(
            label=label,
            confidence=confidence,
            reasoning=reasoning,
            specific_issues=issues or [],
            suggestions=suggestions or [],
            should_return_to_chief=label
            in [QualityLabel.FAILED_POOR_QUALITY, QualityLabel.INCOMPLETE_MISSING_INFO],
            ready_for_output=label in [QualityLabel.COMPLETE_EXCELLENT, QualityLabel.COMPLETE_GOOD],
        )

    def has_complete_response(self) -> bool:
        """Return True if the response is non-empty after stripping whitespace."""
        return bool(self.response and self.response.strip())

    def is_ready_for_output(self) -> bool:
        """Return True if evaluated as ready for output and a response exists."""
        return bool(
            self.evaluation and self.evaluation.ready_for_output and self.has_complete_response()
        )

    def get_processing_summary(self) -> dict[str, Any]:
        """Return a compact dictionary summarizing the processing state."""
        return {
            "uuid": self.uuid,
            "tool_name": self.tool_name,
            "has_response": self.has_complete_response(),
            "has_evaluation": self.evaluation is not None,
            "evaluation_label": (self.evaluation.label.value if self.evaluation else None),
            "confidence": self.evaluation.confidence if self.evaluation else None,
            "ready_for_output": self.is_ready_for_output(),
            "search_results_count": self.get_search_results_count(),
            "prompts_count": len(self.prompts),
            "has_tool_config": bool(self.tool_config),
            "tool_config_keys": (list(self.tool_config.keys()) if self.tool_config else []),
            "has_search_queries": self.has_search_queries(),
            "search_queries_count": len(self.get_search_queries()),
            "timestamp": self.timestamp,
        }
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class QualityLabel(Enum):
    """Labels used to grade the quality of a generated response.

    ModelContext.update_evaluation treats the COMPLETE_* labels as ready for
    output, and FAILED_POOR_QUALITY / INCOMPLETE_MISSING_INFO as needing to
    be sent back to the chief for reprocessing.
    """

    # --- complete answers ---
    COMPLETE_EXCELLENT = "complete_excellent"
    COMPLETE_GOOD = "complete_good"

    # --- partial or incomplete answers ---
    PARTIAL_NEEDS_IMPROVEMENT = "partial_needs_improvement"
    INCOMPLETE_MISSING_INFO = "incomplete_missing_info"

    # --- failures ---
    FAILED_POOR_QUALITY = "failed_poor_quality"
    ERROR_INVALID = "error_invalid"
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .search_result import SearchResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class SearchQueryResults:
    """Result set produced by a single search query."""

    query: str
    results: list[SearchResult] = field(default_factory=list)
    # Time the search was performed, in milliseconds since the Unix epoch.
    search_timestamp: int = field(default_factory=lambda: int(time.time() * 1000))
    total_results_count: int = 0  # Total hit count reported by the search engine
    execution_time_ms: int = 0  # Search execution time in milliseconds
    search_engine: str = "unknown"  # Name of the search engine used
    metadata: dict[str, Any] = field(default_factory=dict)  # Extra search metadata

    def add_result(self, result: SearchResult) -> None:
        """Append one search result."""
        self.results.append(result)

    def get_results_count(self) -> int:
        """Return the number of results actually retrieved (not the engine total)."""
        return len(self.results)

    def get_all_content(self) -> str:
        """Return "title\\ncontent" for every result, separated by blank lines."""
        return "\n\n".join(f"{result.title}\n{result.content}" for result in self.results)

    def get_top_results(self, n: int = 3) -> list[SearchResult]:
        """Return the first ``n`` results.

        A negative ``n`` yields an empty list rather than Python's
        "all but the last |n|" slice semantics.
        """
        return self.results[: max(n, 0)]

    def to_dict(self) -> dict[str, Any]:
        """Convert to a plain dictionary (metadata is shallow-copied)."""
        return {
            "query": self.query,
            "results": [result.to_dict() for result in self.results],
            "search_timestamp": self.search_timestamp,
            "total_results_count": self.total_results_count,
            "execution_time_ms": self.execution_time_ms,
            "search_engine": self.search_engine,
            "metadata": self.metadata.copy(),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SearchQueryResults":
        """Create a SearchQueryResults from a dictionary produced by ``to_dict``.

        Missing keys use the field defaults; ``None`` for "results" or
        "metadata" is treated as empty. Metadata is copied so the instance
        does not alias ``data``.
        """
        results = [SearchResult.from_dict(r) for r in data.get("results") or []]

        return cls(
            query=data.get("query", ""),
            results=results,
            search_timestamp=data.get("search_timestamp", int(time.time() * 1000)),
            total_results_count=data.get("total_results_count", 0),
            execution_time_ms=data.get("execution_time_ms", 0),
            search_engine=data.get("search_engine", "unknown"),
            metadata=dict(data.get("metadata") or {}),
        )
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class SearchResult:
    """A single search hit: its title, body text and where it came from."""

    title: str
    content: str
    source: str
    rank: int = 1  # Position of this hit in the result list
    relevance_score: float = 0.0  # Relevance score assigned to this hit
    # Creation time in milliseconds since the Unix epoch.
    timestamp: int = field(default_factory=lambda: int(time.time() * 1000))

    def __str__(self) -> str:
        """Render the hit as a three-line summary (rank/title, content, source)."""
        return f"[Rank {self.rank}] {self.title}\nContent: {self.content}\nSource: {self.source}"

    def to_dict(self) -> dict[str, Any]:
        """Serialize every field into a plain dictionary, in declaration order."""
        return dict(
            title=self.title,
            content=self.content,
            source=self.source,
            rank=self.rank,
            relevance_score=self.relevance_score,
            timestamp=self.timestamp,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SearchResult":
        """Build a SearchResult from a mapping, using defaults for absent keys.

        A missing timestamp defaults to the current time in milliseconds.
        """
        fallbacks: dict[str, Any] = {
            "title": "",
            "content": "",
            "source": "",
            "rank": 1,
            "relevance_score": 0.0,
            "timestamp": int(time.time() * 1000),
        }
        kwargs = {name: data.get(name, default) for name, default in fallbacks.items()}
        return cls(**kwargs)
|