iflow-mcp_hanw39_reasoning-bank-mcp 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/METADATA +599 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/RECORD +55 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/WHEEL +4 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/entry_points.txt +2 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/licenses/LICENSE +21 -0
- src/__init__.py +16 -0
- src/__main__.py +6 -0
- src/config.py +266 -0
- src/deduplication/__init__.py +19 -0
- src/deduplication/base.py +88 -0
- src/deduplication/factory.py +60 -0
- src/deduplication/strategies/__init__.py +1 -0
- src/deduplication/strategies/semantic_dedup.py +187 -0
- src/default_config.yaml +121 -0
- src/initializers/__init__.py +50 -0
- src/initializers/base.py +196 -0
- src/initializers/embedding_initializer.py +22 -0
- src/initializers/llm_initializer.py +22 -0
- src/initializers/memory_manager_initializer.py +55 -0
- src/initializers/retrieval_initializer.py +32 -0
- src/initializers/storage_initializer.py +22 -0
- src/initializers/tools_initializer.py +48 -0
- src/llm/__init__.py +10 -0
- src/llm/base.py +61 -0
- src/llm/factory.py +75 -0
- src/llm/providers/__init__.py +12 -0
- src/llm/providers/anthropic.py +62 -0
- src/llm/providers/dashscope.py +76 -0
- src/llm/providers/openai.py +76 -0
- src/merge/__init__.py +22 -0
- src/merge/base.py +89 -0
- src/merge/factory.py +60 -0
- src/merge/strategies/__init__.py +1 -0
- src/merge/strategies/llm_merge.py +170 -0
- src/merge/strategies/voting_merge.py +108 -0
- src/prompts/__init__.py +21 -0
- src/prompts/formatters.py +74 -0
- src/prompts/templates.py +184 -0
- src/retrieval/__init__.py +8 -0
- src/retrieval/base.py +37 -0
- src/retrieval/factory.py +55 -0
- src/retrieval/strategies/__init__.py +8 -0
- src/retrieval/strategies/cosine_retrieval.py +47 -0
- src/retrieval/strategies/hybrid_retrieval.py +155 -0
- src/server.py +306 -0
- src/services/__init__.py +5 -0
- src/services/memory_manager.py +403 -0
- src/storage/__init__.py +45 -0
- src/storage/backends/json_backend.py +290 -0
- src/storage/base.py +150 -0
- src/tools/__init__.py +8 -0
- src/tools/extract_memory.py +285 -0
- src/tools/retrieve_memory.py +139 -0
- src/utils/__init__.py +7 -0
- src/utils/similarity.py +54 -0
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
"""extract_memory 工具 - 提取记忆(支持异步)"""
|
|
2
|
+
import json
|
|
3
|
+
import uuid
|
|
4
|
+
import asyncio
|
|
5
|
+
from typing import Dict, List, Optional
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ExtractMemoryTool:
    """Extract reusable memories from an agent trajectory (sync or async).

    Pipeline: judge trajectory success with the LLM (unless a signal is
    provided), extract memory items from the trajectory, embed each item,
    optionally run them through the MemoryManager for dedup/merge detection,
    then persist the survivors to the storage backend.
    """

    def __init__(self, config, storage_backend, llm_provider, embedding_provider, memory_manager=None):
        """
        Initialize the extraction tool.

        Args:
            config: Configuration object (must support ``get(section, default=...)``).
            storage_backend: Storage backend instance.
            llm_provider: LLM provider instance.
            embedding_provider: Embedding provider instance.
            memory_manager: Optional memory manager used for deduplication
                and merge detection.
        """
        self.config = config
        self.storage = storage_backend
        self.llm = llm_provider
        self.embedding = embedding_provider
        self.memory_manager = memory_manager

        # Extraction settings
        extraction_config = config.get("extraction", default={})
        self.max_memories = extraction_config.get("max_memories_per_trajectory", 3)
        self.judge_temp = extraction_config.get("judge_temperature", 0.0)
        self.extract_temp = extraction_config.get("extract_temperature", 1.0)
        self.async_by_default = extraction_config.get("async_by_default", True)

        # Strong references to in-flight background tasks. The event loop
        # keeps only weak references to tasks, so an unreferenced
        # asyncio.create_task() result can be garbage-collected before it
        # finishes running.
        self._background_tasks: set = set()

    async def execute(
        self,
        trajectory: List[Dict],
        query: str,
        success_signal: Optional[bool] = None,
        async_mode: Optional[bool] = None,
        agent_id: Optional[str] = None
    ) -> Dict:
        """
        Run memory extraction.

        Args:
            trajectory: List of trajectory steps.
            query: The task query.
            success_signal: Success/failure flag; judged by the LLM when None.
            async_mode: Whether to process in the background; falls back to
                the configured default when None.
            agent_id: Agent ID used for multi-tenant isolation.

        Returns:
            Result dict. In async mode only a task handle is returned and
            the extraction runs in the background.
        """
        if async_mode is None:
            async_mode = self.async_by_default

        # Task id lets callers correlate background work with this request.
        task_id = f"extract_{uuid.uuid4().hex[:8]}"

        if async_mode:
            # Fire-and-forget: return immediately, process in the background.
            task = asyncio.create_task(
                self._extract_async(task_id, trajectory, query, success_signal, agent_id)
            )
            # Keep a strong reference until the task completes; otherwise it
            # may be garbage-collected mid-flight and never run to the end.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
            return {
                "status": "processing",
                "message": "记忆提取任务已提交,正在后台处理",
                "task_id": task_id,
                "async_mode": True
            }
        else:
            # Synchronous mode: wait for the extraction to finish.
            result = await self._extract_sync(trajectory, query, success_signal, agent_id)
            return {
                **result,
                "task_id": task_id,
                "async_mode": False
            }

    async def _extract_sync(
        self,
        trajectory: List[Dict],
        query: str,
        success_signal: Optional[bool],
        agent_id: Optional[str] = None
    ) -> Dict:
        """Extract, embed and persist memories; return a result dict.

        Never raises: every failure is reported via a ``status == "error"``
        result, so background callers cannot crash the event loop.
        """
        try:
            # 1. Judge success/failure when the caller did not provide it.
            if success_signal is None:
                success_signal = await self._judge_trajectory(trajectory, query)

            # 2. Render the trajectory as text for prompting.
            from ..prompts.formatters import format_trajectory
            trajectory_text = format_trajectory(trajectory)

            # 3. Ask the LLM to extract memory items.
            from ..prompts.templates import get_extract_prompt
            extract_prompt = get_extract_prompt(query, trajectory_text, success_signal)

            response = await self.llm.chat(
                messages=[{"role": "user", "content": extract_prompt}],
                temperature=self.extract_temp
            )

            # 4. Parse the LLM response, capped at the configured maximum.
            memories = self._parse_llm_response(response)[:self.max_memories]

            # 5. Build full memory items and their embeddings.
            new_memories = []
            embeddings_dict = {}
            current_time = datetime.now(timezone.utc).isoformat()

            for mem_data in memories:
                memory_id = f"mem_{uuid.uuid4().hex}"

                memory = {
                    "memory_id": memory_id,
                    "agent_id": agent_id,
                    "timestamp": current_time,
                    "success": success_signal,
                    "title": mem_data["title"],
                    "description": mem_data["description"],
                    "content": mem_data["content"],
                    "query": query,
                    "retrieval_count": 0,
                    "last_retrieved": None,
                    "tags": self._extract_tags(mem_data, query),
                    "metadata": {
                        "extraction_model": self.llm.get_provider_name(),
                        "embedding_model": self.embedding.get_provider_name()
                    }
                }

                # Embed the memory content itself (not the query) so retrieval
                # matches against what the memory says.
                memory_text = f"{mem_data['title']} {mem_data['description']} {mem_data['content']}"
                embedding = await self.embedding.embed(memory_text)

                new_memories.append(memory)
                embeddings_dict[memory_id] = embedding

            # 6. Optional dedup / merge detection via the MemoryManager.
            memories_to_save = new_memories
            management_result = None

            if self.memory_manager:
                try:
                    management_result = await self.memory_manager.on_memory_created(
                        new_memories=new_memories,
                        embeddings=embeddings_dict,
                        agent_id=agent_id
                    )
                    # Use the deduplicated list when management succeeded.
                    if management_result.success:
                        memories_to_save = management_result.metadata.get("unique_memories", new_memories)
                except Exception as e:
                    # A MemoryManager failure must not abort extraction.
                    import logging
                    logging.warning(f"MemoryManager 处理失败: {e}", exc_info=True)

            # 7. Persist the surviving memories.
            saved_memories = []
            for memory in memories_to_save:
                memory_id = memory["memory_id"]
                embedding = embeddings_dict[memory_id]

                await self.storage.add_memory(memory, embedding)

                saved_memories.append({
                    "memory_id": memory_id,
                    "title": memory["title"],
                    "description": memory["description"]
                })

            # 8. Build the result payload.
            result = {
                "status": "completed",
                "success": success_signal,
                "extracted_count": len(saved_memories),
                "memories": saved_memories
            }

            # Surface dedup/merge statistics when management ran.
            if management_result:
                result["management"] = {
                    "duplicates_skipped": management_result.duplicates_found,
                    "merges_triggered": management_result.merged_count,
                    "message": management_result.message
                }

            return result

        except Exception as e:
            return {
                "status": "error",
                "message": f"提取失败: {str(e)}",
                "success": None,
                "extracted_count": 0,
                "memories": []
            }

    async def _extract_async(
        self,
        task_id: str,
        trajectory: List[Dict],
        query: str,
        success_signal: Optional[bool],
        agent_id: Optional[str] = None
    ):
        """Background variant: the result is not returned to the caller.

        ``_extract_sync`` reports all failures in its return value, so this
        task cannot raise into the event loop.
        """
        await self._extract_sync(trajectory, query, success_signal, agent_id)

    async def _judge_trajectory(self, trajectory: List[Dict], query: str) -> bool:
        """Ask the LLM whether the trajectory succeeded; False on any error."""
        try:
            from ..prompts.formatters import format_trajectory
            from ..prompts.templates import get_judge_prompt

            trajectory_text = format_trajectory(trajectory)
            # TODO: segment the trajectory — a single task may contain both
            # successful and failed phases.
            judge_prompt = get_judge_prompt(query, trajectory_text)

            response = await self.llm.chat(
                messages=[{"role": "user", "content": judge_prompt}],
                temperature=self.judge_temp
            )

            # Parse the verdict out of the response.
            result = self._parse_json_response(response)
            return result.get("result", "failure") == "success"

        except Exception:
            # Treat unjudgeable trajectories as failures.
            return False

    def _parse_llm_response(self, response: str) -> List[Dict]:
        """Parse memory items out of the LLM response; [] on any problem."""
        try:
            data = self._parse_json_response(response)
            memories = data.get("memories", [])
            # Guard against a malformed payload where "memories" is not a list.
            return memories if isinstance(memories, list) else []
        except Exception:
            return []

    def _parse_json_response(self, response: str) -> Dict:
        """Extract a JSON object from the response, tolerating ``` fences."""
        # Strip possible markdown code-block markers.
        response = response.strip()
        if response.startswith("```json"):
            response = response[7:]
        if response.startswith("```"):
            response = response[3:]
        if response.endswith("```"):
            response = response[:-3]

        response = response.strip()
        return json.loads(response)

    def _extract_tags(self, memory_data: Dict, query: str) -> List[str]:
        """Derive tags for a memory item.

        Currently a placeholder that returns no tags; success/failure is
        already recorded on the memory item itself. More sophisticated tag
        extraction can be added here.
        """
        return []
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""retrieve_memory 工具 - 检索相关记忆"""
|
|
2
|
+
from typing import Dict, List
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class RetrieveMemoryTool:
    """Tool that retrieves memories relevant to a task query."""

    def __init__(self, config, storage_backend, embedding_provider, retrieval_strategy):
        """
        Initialize the retrieval tool.

        Args:
            config: Configuration object.
            storage_backend: Storage backend instance.
            embedding_provider: Embedding provider instance.
            retrieval_strategy: Retrieval strategy instance.
        """
        self.config = config
        self.storage = storage_backend
        self.embedding = embedding_provider
        self.retrieval = retrieval_strategy

    async def execute(self, query: str, top_k: int = None, agent_id: str = None) -> Dict:
        """
        Run memory retrieval.

        Args:
            query: The task query.
            top_k: Number of memories to fetch; config default when None.
            agent_id: Agent ID used for multi-tenant isolation.

        Returns:
            Result dict with matched memories and a prompt-ready rendering.
        """
        # Fall back to the configured default, then clamp to the ceiling.
        if top_k is None:
            top_k = self.config.get("retrieval", "default_top_k", default=1)
        top_k = min(top_k, self.config.get("retrieval", "max_top_k", default=10))

        # Candidates scoring below this value are dropped.
        threshold = self.config.get("retrieval", "min_score_threshold", default=0.85)

        try:
            # Embed the query once, up front.
            query_vec = np.array(await self.embedding.embed(query))

            # Delegate candidate selection to the configured strategy.
            candidates = await self.retrieval.retrieve(
                query=query,
                query_embedding=query_vec,
                storage_backend=self.storage,
                top_k=top_k,
                agent_id=agent_id
            )

            if not candidates:
                return {
                    "status": "no_memories",
                    "message": "记忆库为空或没有找到相关记忆",
                    "query": query,
                    "memories": [],
                    "formatted_prompt": ""
                }

            # Load full records, update retrieval stats, drop low scorers.
            now = datetime.now(timezone.utc).isoformat()
            hits = []
            dropped = 0  # number of candidates filtered by the threshold

            for mem_id, score in candidates:
                if score < threshold:
                    dropped += 1
                    continue

                record = await self.storage.get_memory_by_id(mem_id)
                if not record:
                    continue

                # Record that this memory was retrieved.
                await self.storage.update_retrieval_stats(mem_id, now)

                hits.append({
                    "memory_id": mem_id,
                    "score": float(score),
                    "title": record["title"],
                    "content": record["content"],
                    "success": record.get("success", True),
                    "tags": record.get("tags", []),
                    "description": record.get("description", "")
                })

            # Every candidate fell below the score threshold.
            if not hits:
                return {
                    "status": "no_memories",
                    "message": f"没有找到相关度高于 {threshold} 的记忆(过滤了 {dropped} 条低相关度记忆)",
                    "query": query,
                    "min_score_threshold": threshold,
                    "filtered_count": dropped,
                    "memories": [],
                    "formatted_prompt": ""
                }

            return {
                "status": "success",
                "query": query,
                "retrieval_strategy": self.retrieval.get_name(),
                "top_k": top_k,
                "min_score_threshold": threshold,
                "filtered_count": dropped,
                "memories": hits,
                "formatted_prompt": self._format_for_prompt(hits)
            }

        except Exception as e:
            return {
                "status": "error",
                "message": f"检索失败: {str(e)}",
                "query": query,
                "memories": [],
                "formatted_prompt": ""
            }

    def _format_for_prompt(self, memories: List[Dict]) -> str:
        """Render retrieved memories as text usable directly in an LLM prompt."""
        if not memories:
            return ""

        from ..prompts.formatters import format_memory_for_prompt
        return format_memory_for_prompt(memories)
|
src/utils/__init__.py
ADDED
src/utils/similarity.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""工具函数 - 相似度计算"""
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import List, Tuple, Dict
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """
    Compute the cosine similarity of two vectors.

    Args:
        vec1: First vector.
        vec2: Second vector.

    Returns:
        Cosine similarity in [-1, 1]; 0.0 when either vector is all zeros.
    """
    dot_product = np.dot(vec1, vec2)
    norm1 = np.linalg.norm(vec1)
    norm2 = np.linalg.norm(vec2)

    # A zero vector has no direction; define its similarity as 0.
    if norm1 == 0 or norm2 == 0:
        return 0.0

    return float(dot_product / (norm1 * norm2))


def find_top_k_similar(
    query_embedding: List[float],
    memory_embeddings: Dict[str, np.ndarray],
    top_k: int = 1
) -> List[Tuple[str, float]]:
    """
    Find the top-k memories most similar to the query.

    Args:
        query_embedding: Embedding vector of the query.
        memory_embeddings: Mapping of {memory_id: embedding}.
        top_k: Number of results to return.

    Returns:
        [(memory_id, similarity_score), ...] in descending similarity order.
    """
    import heapq  # local import keeps the module's top-level imports unchanged

    query_vec = np.array(query_embedding)

    # heapq.nlargest runs in O(n log k) instead of a full O(n log n) sort and
    # is documented as equivalent to sorted(..., key=key, reverse=True)[:k],
    # so the ordering (including ties) matches the previous implementation.
    scored = (
        (memory_id, cosine_similarity(query_vec, memory_vec))
        for memory_id, memory_vec in memory_embeddings.items()
    )
    return heapq.nlargest(top_k, scored, key=lambda pair: pair[1])
|