iflow-mcp_hanw39_reasoning-bank-mcp 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/METADATA +599 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/RECORD +55 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/WHEEL +4 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/entry_points.txt +2 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/licenses/LICENSE +21 -0
- src/__init__.py +16 -0
- src/__main__.py +6 -0
- src/config.py +266 -0
- src/deduplication/__init__.py +19 -0
- src/deduplication/base.py +88 -0
- src/deduplication/factory.py +60 -0
- src/deduplication/strategies/__init__.py +1 -0
- src/deduplication/strategies/semantic_dedup.py +187 -0
- src/default_config.yaml +121 -0
- src/initializers/__init__.py +50 -0
- src/initializers/base.py +196 -0
- src/initializers/embedding_initializer.py +22 -0
- src/initializers/llm_initializer.py +22 -0
- src/initializers/memory_manager_initializer.py +55 -0
- src/initializers/retrieval_initializer.py +32 -0
- src/initializers/storage_initializer.py +22 -0
- src/initializers/tools_initializer.py +48 -0
- src/llm/__init__.py +10 -0
- src/llm/base.py +61 -0
- src/llm/factory.py +75 -0
- src/llm/providers/__init__.py +12 -0
- src/llm/providers/anthropic.py +62 -0
- src/llm/providers/dashscope.py +76 -0
- src/llm/providers/openai.py +76 -0
- src/merge/__init__.py +22 -0
- src/merge/base.py +89 -0
- src/merge/factory.py +60 -0
- src/merge/strategies/__init__.py +1 -0
- src/merge/strategies/llm_merge.py +170 -0
- src/merge/strategies/voting_merge.py +108 -0
- src/prompts/__init__.py +21 -0
- src/prompts/formatters.py +74 -0
- src/prompts/templates.py +184 -0
- src/retrieval/__init__.py +8 -0
- src/retrieval/base.py +37 -0
- src/retrieval/factory.py +55 -0
- src/retrieval/strategies/__init__.py +8 -0
- src/retrieval/strategies/cosine_retrieval.py +47 -0
- src/retrieval/strategies/hybrid_retrieval.py +155 -0
- src/server.py +306 -0
- src/services/__init__.py +5 -0
- src/services/memory_manager.py +403 -0
- src/storage/__init__.py +45 -0
- src/storage/backends/json_backend.py +290 -0
- src/storage/base.py +150 -0
- src/tools/__init__.py +8 -0
- src/tools/extract_memory.py +285 -0
- src/tools/retrieve_memory.py +139 -0
- src/utils/__init__.py +7 -0
- src/utils/similarity.py +54 -0
src/__init__.py
ADDED
|
"""ReasoningBank MCP server package."""

# Kept in sync with the package version published in the wheel metadata
# (was "0.1.0" while the distribution is 0.2.0 — corrected).
__version__ = "0.2.0"
__author__ = "Your Name"  # TODO(review): placeholder — set the real author before release
__description__ = "Memory-augmented reasoning for AI agents via MCP"

from .config import load_config, get_config
# NOTE: ReasoningBankServer and main are deliberately NOT imported from the
# server module here, to avoid circular imports and a RuntimeWarning.

# run_server is imported so the console-script entry point can resolve it.
from .server import run_server

__all__ = [
    "load_config",
    "get_config",
    "run_server",  # exported for the script entry point
]
src/__main__.py
ADDED
src/config.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"""配置管理模块"""
|
|
2
|
+
import os
|
|
3
|
+
import yaml
|
|
4
|
+
import logging
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Dict, Any, Optional
|
|
7
|
+
from dotenv import load_dotenv
|
|
8
|
+
|
|
9
|
+
# 配置日志
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# def _find_dotenv_file() -> Optional[Path]:
|
|
14
|
+
# """
|
|
15
|
+
# 智能查找 .env 文件
|
|
16
|
+
#
|
|
17
|
+
# 查找顺序:
|
|
18
|
+
# 1. 当前工作目录
|
|
19
|
+
# 2. 项目根目录(从 src/config.py 向上查找 pyproject.toml)
|
|
20
|
+
#
|
|
21
|
+
# Returns:
|
|
22
|
+
# Path 对象或 None
|
|
23
|
+
# """
|
|
24
|
+
# # 1. 当前工作目录
|
|
25
|
+
# cwd_env = Path.cwd() / ".env"
|
|
26
|
+
# if cwd_env.exists():
|
|
27
|
+
# logger.debug(f"找到 .env 文件: {cwd_env}")
|
|
28
|
+
# return cwd_env
|
|
29
|
+
#
|
|
30
|
+
# # 2. 项目根目录
|
|
31
|
+
# current_file = Path(__file__).resolve() # src/config.py
|
|
32
|
+
# src_dir = current_file.parent # src/
|
|
33
|
+
# project_root = src_dir.parent # 项目根目录
|
|
34
|
+
#
|
|
35
|
+
# if (project_root / "pyproject.toml").exists():
|
|
36
|
+
# project_env = project_root / ".env"
|
|
37
|
+
# if project_env.exists():
|
|
38
|
+
# logger.debug(f"找到 .env 文件: {project_env}")
|
|
39
|
+
# return project_env
|
|
40
|
+
#
|
|
41
|
+
# logger.debug(".env 文件未找到")
|
|
42
|
+
# return None
|
|
43
|
+
#
|
|
44
|
+
#
|
|
45
|
+
# # 加载环境变量(优先使用已存在的环境变量,如 MCP 传递的)
|
|
46
|
+
# dotenv_path = _find_dotenv_file()
|
|
47
|
+
# if dotenv_path:
|
|
48
|
+
# # override=False 确保不覆盖已存在的环境变量(如 MCP 传递的)
|
|
49
|
+
# load_dotenv(dotenv_path, override=False)
|
|
50
|
+
# logger.debug(f"已加载 .env 文件: {dotenv_path}")
|
|
51
|
+
# else:
|
|
52
|
+
# logger.debug("未找到 .env 文件,将仅使用系统环境变量")
|
|
53
|
+
load_dotenv()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Config:
    """Configuration manager.

    Resolves the location of a YAML configuration file, loads it, and
    substitutes environment-variable references found in string values.
    """

    def __init__(self, config_path: str = "config.yaml"):
        """
        Initialize the configuration.

        Args:
            config_path: Configuration file path. Either:
                - an absolute path such as "/path/to/config.yaml", or
                - a relative path, searched in order in:
                    1. the current working directory
                    2. the project root (directory containing pyproject.toml)
                    3. ~/.reasoningbank/config.yaml

        Raises:
            FileNotFoundError: If no configuration file exists at the
                resolved location.
        """
        self.config_path = self._resolve_config_path(config_path)
        self._config: Dict[str, Any] = {}
        self._load_config()

    def _resolve_config_path(self, config_path: str) -> Path:
        """
        Resolve the configuration file path.

        Search order:
        1. If config_path is absolute, use it directly.
        2. The current working directory.
        3. The project root (located by finding pyproject.toml one level
           above this file's src/ directory).
        4. The user home directory: ~/.reasoningbank/
        5. default_config.yaml inside the src/ directory (the default
           configuration shipped with the package).

        If nothing is found, the home-directory path is returned so the
        subsequent load error guides the user to create a config there.
        """
        path = Path(config_path)

        # 1. Absolute path: use as-is (existence is checked in _load_config).
        if path.is_absolute():
            return path

        # 2. Current working directory.
        cwd_path = Path.cwd() / config_path
        if cwd_path.exists():
            return cwd_path

        # 3. Project root: walk up from this file (src/config.py).
        current_file = Path(__file__).resolve()  # src/config.py
        src_dir = current_file.parent            # src/
        project_root = src_dir.parent            # project root

        # Only trust project_root if pyproject.toml confirms it.
        if (project_root / "pyproject.toml").exists():
            project_config = project_root / config_path
            if project_config.exists():
                return project_config

        # 4. User home directory.
        home_path = Path.home() / ".reasoningbank" / config_path
        if home_path.exists():
            return home_path

        # 5. Fall back to the packaged default configuration.
        default_config = src_dir / "default_config.yaml"
        if default_config.exists():
            logger.info(f"使用默认配置文件: {default_config}")
            logger.info(f"建议在以下位置创建自定义配置: {home_path}")
            return default_config

        # Nothing found: prefer the home path so the error message guides
        # the user towards creating a config there.
        return home_path

    def _load_config(self) -> None:
        """Load the YAML file and substitute environment variables.

        Raises:
            FileNotFoundError: If the resolved config file does not exist.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"配置文件不存在: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            # yaml.safe_load returns None for an empty file; normalize to {}
            # so get() and .all always operate on a dict.
            self._config = yaml.safe_load(f) or {}

        # Substitute ${VAR} style references (dicts are updated in-place).
        self._replace_env_variables(self._config)

    def _replace_env_variables(self, config: Any) -> Any:
        """
        Recursively replace environment-variable references in the config.

        Supported formats (the whole string must be the reference; embedded
        references such as "prefix-${VAR}" are NOT expanded):
        - ${VAR_NAME} : required variable; raises ValueError when unset
        - ${VAR_NAME?} : optional variable; empty string when unset
        - ${VAR_NAME:default} : uses the default when unset

        Dicts are mutated in-place; lists and strings yield new values, so
        callers must use the return value for non-dict inputs.

        Raises:
            ValueError: If a required environment variable is not set.
        """
        if isinstance(config, dict):
            for key, value in config.items():
                config[key] = self._replace_env_variables(value)
        elif isinstance(config, list):
            return [self._replace_env_variables(item) for item in config]
        elif isinstance(config, str):
            if config.startswith("${") and config.endswith("}"):
                var_spec = config[2:-1]

                # With default value: ${VAR_NAME:default_value}
                if ':' in var_spec:
                    var_name, default_value = var_spec.split(':', 1)
                    return os.getenv(var_name, default_value)

                # Optional variable: ${VAR_NAME?}
                if var_spec.endswith('?'):
                    var_name = var_spec[:-1]
                    return os.getenv(var_name, "")

                # Required variable.
                env_value = os.getenv(var_spec)
                if env_value is None:
                    # Build a detailed, actionable error message.
                    error_msg = f"环境变量未设置: {var_spec}\n\n"
                    error_msg += "解决方案:\n"
                    error_msg += "1. 如果使用 MCP,请在配置中添加:\n"
                    error_msg += '   {\n'
                    error_msg += '     "env": {\n'
                    error_msg += f'       "{var_spec}": "your-api-key-here"\n'
                    error_msg += '     }\n'
                    error_msg += '   }\n\n'
                    error_msg += "2. 或在项目根目录创建 .env 文件:\n"
                    error_msg += f'   {var_spec}=your-api-key-here\n\n'
                    raise ValueError(error_msg)
                return env_value

        return config

    def get(self, *keys, default=None) -> Any:
        """
        Fetch a (possibly nested) configuration value.

        Example: config.get('llm', 'provider') -> 'dashscope'

        Args:
            *keys: Path of keys to walk into the nested config dict.
            default: Value returned when any key along the path is missing.
        """
        value = self._config
        for key in keys:
            if isinstance(value, dict):
                value = value.get(key, default)
            else:
                return default
        return value

    def get_llm_config(self) -> Dict[str, Any]:
        """Return the LLM config: provider name merged with its settings."""
        provider = self.get('llm', 'provider')
        return {
            'provider': provider,
            **self.get('llm', provider, default={})
        }

    def get_embedding_config(self) -> Dict[str, Any]:
        """Return the embedding config: provider name merged with its settings."""
        provider = self.get('embedding', 'provider')
        return {
            'provider': provider,
            **self.get('embedding', provider, default={})
        }

    def get_retrieval_config(self) -> Dict[str, Any]:
        """Return retrieval settings plus the active strategy's sub-config."""
        strategy = self.get('retrieval', 'strategy')
        return {
            'strategy': strategy,
            'default_top_k': self.get('retrieval', 'default_top_k', default=1),
            'max_top_k': self.get('retrieval', 'max_top_k', default=10),
            'strategy_config': self.get('retrieval', strategy, default={})
        }

    def get_storage_config(self) -> Dict[str, Any]:
        """Return the storage config: backend name merged with its settings."""
        backend = self.get('storage', 'backend')
        return {
            'backend': backend,
            **self.get('storage', backend, default={})
        }

    def get_extraction_config(self) -> Dict[str, Any]:
        """Return the memory-extraction config section ({} when absent)."""
        return self.get('extraction', default={})

    @property
    def all(self) -> Dict[str, Any]:
        """The complete parsed configuration dictionary."""
        return self._config
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
# Global singleton configuration instance, created lazily by get_config()
# or explicitly by load_config(). Annotated Optional because it starts as
# None (the original annotation claimed a plain Config).
_global_config: Optional[Config] = None
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def load_config(config_path: str = "config.yaml") -> Config:
    """Load (or reload) the global configuration from *config_path*.

    Replaces any previously loaded global Config instance.
    """
    global _global_config
    cfg = Config(config_path)
    _global_config = cfg
    return cfg
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def get_config() -> Config:
    """Return the process-wide Config instance, constructing it lazily."""
    global _global_config
    if _global_config is None:
        # First access: build a Config using the default search paths.
        _global_config = Config()
    return _global_config
|
|
"""
Deduplication module initialization

Registers all built-in deduplication strategies.
"""

from .base import DeduplicationStrategy, DeduplicationResult
from .factory import DeduplicationFactory
from .strategies.semantic_dedup import SemanticDeduplicationStrategy

# Register built-in strategies.
# Importing this package makes "semantic" resolvable through
# DeduplicationFactory.create() without any further setup.
DeduplicationFactory.register("semantic", SemanticDeduplicationStrategy)

__all__ = [
    "DeduplicationStrategy",
    "DeduplicationResult",
    "DeduplicationFactory",
    "SemanticDeduplicationStrategy",
]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Deduplication Strategy Base Interface
|
|
3
|
+
|
|
4
|
+
Provides pluggable deduplication strategies for memory management.
|
|
5
|
+
All operations MUST respect agent_id isolation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import List, Optional, Dict, Any
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class DeduplicationResult:
    """Result of deduplication check"""
    # True when the checked memory matches an already-stored one.
    is_duplicate: bool
    duplicate_of: Optional[str] = None  # memory_id of existing memory
    similarity_score: float = 0.0  # score against the matched memory; 0.0 when no duplicate
    reason: str = ""  # human-readable explanation of the decision
    metadata: Dict[str, Any] = field(default_factory=dict)  # strategy-specific extras
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class DeduplicationStrategy(ABC):
    """
    Base interface for pluggable deduplication strategies.

    Every operation is scoped to a single agent_id so that different
    agents keep fully isolated memory spaces.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Store the strategy-specific configuration.

        Args:
            config: Strategy-specific configuration dictionary
        """
        self.config = config

    @abstractmethod
    async def check_duplicate(
        self,
        memory: Dict[str, Any],
        embedding: Optional[np.ndarray] = None,
        storage_backend: Any = None,
        agent_id: Optional[str] = None
    ) -> DeduplicationResult:
        """
        Decide whether *memory* duplicates an existing memory of the same agent.

        Args:
            memory: Memory object to check (dict with keys such as
                title, content, query, ...)
            embedding: Optional pre-computed embedding vector
            storage_backend: Backend used to query existing memories
            agent_id: Agent ID for isolation (required for multi-tenant safety)

        Returns:
            DeduplicationResult describing whether a duplicate was found.
        """
        ...

    @abstractmethod
    async def find_duplicate_groups(
        self,
        storage_backend: Any,
        agent_id: Optional[str] = None,
        limit: Optional[int] = None
    ) -> List[List[str]]:
        """
        Collect groups of mutually duplicate memories for one agent.

        Args:
            storage_backend: Backend used to query memories
            agent_id: Agent ID for isolation (search only this agent's memories)
            limit: Maximum number of groups to return

        Returns:
            A list of groups; each group is a list of duplicate memory_ids,
            e.g. [["mem_1", "mem_2"], ["mem_3", "mem_4", "mem_5"]].
        """
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Short strategy name used for logging and configuration."""
        ...
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Deduplication Strategy Factory
|
|
3
|
+
|
|
4
|
+
Provides plugin mechanism for registering and creating deduplication strategies.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Dict, Type, Any
|
|
8
|
+
from .base import DeduplicationStrategy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DeduplicationFactory:
    """Factory for creating deduplication strategy instances"""

    # Registry mapping strategy name -> strategy class.
    _strategies: Dict[str, Type[DeduplicationStrategy]] = {}

    @classmethod
    def register(cls, name: str, strategy_class: Type[DeduplicationStrategy]):
        """
        Register a deduplication strategy under *name*.

        Args:
            name: Strategy name (e.g., "hash", "semantic", "hybrid")
            strategy_class: Strategy class (must inherit from DeduplicationStrategy)
        """
        cls._strategies[name] = strategy_class

    @classmethod
    def create(cls, config: Any) -> DeduplicationStrategy:
        """
        Instantiate the strategy selected by the configuration.

        Args:
            config: Config object exposing a get() accessor

        Returns:
            A ready-to-use DeduplicationStrategy instance.

        Raises:
            ValueError: If the configured strategy name is not registered.
        """
        # Unified configuration access; "semantic" is the default strategy.
        strategy_name = config.get('memory_manager', 'deduplication', 'strategy', default='semantic')

        if strategy_name not in cls._strategies:
            raise ValueError(
                f"Unknown deduplication strategy: {strategy_name}. "
                f"Available: {list(cls._strategies.keys())}"
            )

        # The strategy receives the whole deduplication sub-config.
        dedup_config = config.get('memory_manager', 'deduplication', default={})
        return cls._strategies[strategy_name](dedup_config)

    @classmethod
    def list_strategies(cls) -> list:
        """Return list of registered strategy names"""
        return list(cls._strategies.keys())
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Strategies submodule"""
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Semantic Deduplication Strategy
|
|
3
|
+
|
|
4
|
+
Detects semantically similar memories using embedding similarity.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List, Optional, Dict, Any
|
|
8
|
+
import numpy as np
|
|
9
|
+
from ..base import DeduplicationStrategy, DeduplicationResult
|
|
10
|
+
from ...utils.similarity import cosine_similarity
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SemanticDeduplicationStrategy(DeduplicationStrategy):
    """
    Semantic deduplication using embedding cosine similarity.

    Use case: Find memories that describe similar experiences,
    even if the exact wording is different.
    """

    @property
    def name(self) -> str:
        return "semantic"

    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: Deduplication config; reads semantic.threshold
                (default 0.90) and semantic.top_k_check (default 1).
        """
        super().__init__(config)
        semantic_cfg = config.get("semantic", {})
        self.threshold = semantic_cfg.get("threshold", 0.90)
        self.top_k = semantic_cfg.get("top_k_check", 1)

    async def check_duplicate(
        self,
        memory: Dict[str, Any],
        embedding: Optional[np.ndarray] = None,
        storage_backend: Any = None,
        agent_id: Optional[str] = None
    ) -> DeduplicationResult:
        """
        Check if semantically similar memories exist for this agent.

        Args:
            memory: Memory dict
            embedding: Pre-computed embedding of memory query (REQUIRED)
            storage_backend: Storage backend
            agent_id: Agent ID for isolation (CRITICAL)

        Returns:
            DeduplicationResult; is_duplicate is True when the best retrieved
            memory scores at or above the configured threshold. Errors are
            swallowed and reported as non-duplicates so dedup failures never
            block memory insertion.
        """
        if embedding is None:
            logger.warning("No embedding provided for semantic dedup check")
            return DeduplicationResult(
                is_duplicate=False,
                reason="No embedding provided"
            )

        if not storage_backend:
            logger.warning("No storage_backend provided")
            return DeduplicationResult(is_duplicate=False)

        # Retrieve similar memories within agent_id scope
        try:
            # Get retrieval strategy from storage.
            # (Removed an unused `RetrievalFactory` import that lived here.)
            retrieval_strategy = storage_backend.retrieval_strategy
            if not retrieval_strategy:
                logger.warning("No retrieval strategy available")
                return DeduplicationResult(is_duplicate=False)

            # Retrieve top-k similar memories.
            # Note: query text is not used by retrieval (only embedding),
            # but required by interface
            query_text = memory.get("query", "")
            similar_results = await retrieval_strategy.retrieve(
                query=query_text,
                query_embedding=embedding,
                top_k=self.top_k,
                agent_id=agent_id,  # CRITICAL: Only search within this agent
                storage_backend=storage_backend
            )

            # Check if any exceed threshold
            for mem_id, score in similar_results:
                if score >= self.threshold:
                    existing_mem = await storage_backend.get_memory(mem_id)
                    logger.info(
                        f"Found semantically similar memory: {mem_id} "
                        f"(score={score:.3f}, threshold={self.threshold}) "
                        f"for agent_id={agent_id}"
                    )
                    return DeduplicationResult(
                        is_duplicate=True,
                        duplicate_of=mem_id,
                        similarity_score=score,
                        reason=f"Semantically similar to existing memory (score={score:.3f})",
                        metadata={
                            "existing_title": existing_mem.get("title", ""),
                            "threshold": self.threshold
                        }
                    )

            return DeduplicationResult(
                is_duplicate=False,
                reason=f"No similar memories above threshold {self.threshold}"
            )

        except Exception as e:
            # Best-effort: log and treat as non-duplicate rather than fail.
            logger.error(f"Error in semantic dedup check: {e}", exc_info=True)
            return DeduplicationResult(
                is_duplicate=False,
                reason=f"Error during check: {str(e)}"
            )

    async def find_duplicate_groups(
        self,
        storage_backend: Any,
        agent_id: Optional[str] = None,
        limit: Optional[int] = None
    ) -> List[List[str]]:
        """
        Find clusters of semantically similar memories within agent_id scope.

        Uses a greedy clustering approach:
        1. For each memory, find all memories above similarity threshold
        2. Group connected memories together

        Args:
            storage_backend: Storage backend
            agent_id: Agent ID for isolation
            limit: Maximum number of groups to return

        Returns:
            Groups of 2+ memory_ids judged semantically equivalent.
        """
        if not storage_backend:
            return []

        # Get all memories for this agent
        all_memories = await storage_backend.get_all_memories(agent_id=agent_id)

        # Fewer than two memories can never form a group; bail out before
        # the (potentially expensive) embedding fetch.
        if len(all_memories) < 2:
            return []

        all_embeddings = await storage_backend.get_embeddings(
            [m["memory_id"] for m in all_memories]
        )

        n = len(all_memories)
        visited = set()
        groups = []

        for i in range(n):
            mem_id_i = all_memories[i]["memory_id"]

            if mem_id_i in visited:
                continue

            # Start a new group seeded by memory i.
            current_group = [mem_id_i]
            visited.add(mem_id_i)

            emb_i = all_embeddings.get(mem_id_i)
            if emb_i is None:
                # No embedding to compare against; the singleton group is
                # implicitly dropped (groups need 2+ members).
                continue

            # Attach every later, unvisited memory that is similar enough.
            for j in range(i + 1, n):
                mem_id_j = all_memories[j]["memory_id"]

                if mem_id_j in visited:
                    continue

                emb_j = all_embeddings.get(mem_id_j)
                if emb_j is None:
                    continue

                similarity = cosine_similarity(emb_i, emb_j)

                if similarity >= self.threshold:
                    current_group.append(mem_id_j)
                    visited.add(mem_id_j)

            # Only keep groups with 2+ members
            if len(current_group) >= 2:
                groups.append(current_group)

        if limit:
            groups = groups[:limit]

        logger.info(
            f"Found {len(groups)} semantic duplicate groups for agent_id={agent_id}"
        )

        return groups
|