iflow-mcp_hanw39_reasoning-bank-mcp 0.2.0__py3-none-any.whl
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/METADATA +599 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/RECORD +55 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/WHEEL +4 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/entry_points.txt +2 -0
- iflow_mcp_hanw39_reasoning_bank_mcp-0.2.0.dist-info/licenses/LICENSE +21 -0
- src/__init__.py +16 -0
- src/__main__.py +6 -0
- src/config.py +266 -0
- src/deduplication/__init__.py +19 -0
- src/deduplication/base.py +88 -0
- src/deduplication/factory.py +60 -0
- src/deduplication/strategies/__init__.py +1 -0
- src/deduplication/strategies/semantic_dedup.py +187 -0
- src/default_config.yaml +121 -0
- src/initializers/__init__.py +50 -0
- src/initializers/base.py +196 -0
- src/initializers/embedding_initializer.py +22 -0
- src/initializers/llm_initializer.py +22 -0
- src/initializers/memory_manager_initializer.py +55 -0
- src/initializers/retrieval_initializer.py +32 -0
- src/initializers/storage_initializer.py +22 -0
- src/initializers/tools_initializer.py +48 -0
- src/llm/__init__.py +10 -0
- src/llm/base.py +61 -0
- src/llm/factory.py +75 -0
- src/llm/providers/__init__.py +12 -0
- src/llm/providers/anthropic.py +62 -0
- src/llm/providers/dashscope.py +76 -0
- src/llm/providers/openai.py +76 -0
- src/merge/__init__.py +22 -0
- src/merge/base.py +89 -0
- src/merge/factory.py +60 -0
- src/merge/strategies/__init__.py +1 -0
- src/merge/strategies/llm_merge.py +170 -0
- src/merge/strategies/voting_merge.py +108 -0
- src/prompts/__init__.py +21 -0
- src/prompts/formatters.py +74 -0
- src/prompts/templates.py +184 -0
- src/retrieval/__init__.py +8 -0
- src/retrieval/base.py +37 -0
- src/retrieval/factory.py +55 -0
- src/retrieval/strategies/__init__.py +8 -0
- src/retrieval/strategies/cosine_retrieval.py +47 -0
- src/retrieval/strategies/hybrid_retrieval.py +155 -0
- src/server.py +306 -0
- src/services/__init__.py +5 -0
- src/services/memory_manager.py +403 -0
- src/storage/__init__.py +45 -0
- src/storage/backends/json_backend.py +290 -0
- src/storage/base.py +150 -0
- src/tools/__init__.py +8 -0
- src/tools/extract_memory.py +285 -0
- src/tools/retrieve_memory.py +139 -0
- src/utils/__init__.py +7 -0
- src/utils/similarity.py +54 -0
src/llm/factory.py
ADDED
"""Factories for LLM and embedding providers"""
from typing import Dict

from .base import LLMProvider, EmbeddingProvider
from .providers import (
    DashScopeLLMProvider,
    DashScopeEmbeddingProvider,
    OpenAILLMProvider,
    OpenAIEmbeddingProvider,
    AnthropicLLMProvider,
)


class LLMFactory:
    """Factory for LLM providers"""

    _providers = {
        "dashscope": DashScopeLLMProvider,
        "openai": OpenAILLMProvider,
        "anthropic": AnthropicLLMProvider,
    }

    @classmethod
    def create(cls, config: Dict) -> LLMProvider:
        """
        Create an LLM provider instance.

        Args:
            config: Configuration dict containing a 'provider' key and
                provider-specific settings

        Returns:
            LLMProvider instance
        """
        provider_name = config.get("provider")
        if provider_name not in cls._providers:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        provider_class = cls._providers[provider_name]
        return provider_class(config)

    @classmethod
    def register_provider(cls, name: str, provider_class: type):
        """Register a new LLM provider (plugin mechanism)"""
        cls._providers[name] = provider_class


class EmbeddingFactory:
    """Factory for embedding providers"""

    _providers = {
        "dashscope": DashScopeEmbeddingProvider,
        "openai": OpenAIEmbeddingProvider,
    }

    @classmethod
    def create(cls, config: Dict) -> EmbeddingProvider:
        """
        Create an embedding provider instance.

        Args:
            config: Configuration dict containing a 'provider' key and
                provider-specific settings

        Returns:
            EmbeddingProvider instance
        """
        provider_name = config.get("provider")
        if provider_name not in cls._providers:
            raise ValueError(f"Unknown embedding provider: {provider_name}")

        provider_class = cls._providers[provider_name]
        return provider_class(config)

    @classmethod
    def register_provider(cls, name: str, provider_class: type):
        """Register a new embedding provider (plugin mechanism)"""
        cls._providers[name] = provider_class
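Both factories resolve the class from the 'provider' key and hand the entire config dict to the provider's constructor. A minimal usage sketch, assuming the wheel's modules are importable under the top-level package name src (as the file list above indicates) and with a placeholder API key:

import asyncio
from src.llm.factory import LLMFactory

async def main():
    # 'provider' selects the class; the remaining keys are read by the
    # provider's own __init__ (see the provider modules below).
    llm = LLMFactory.create({
        "provider": "openai",
        "api_key": "sk-...",           # placeholder, supply a real key
        "chat_model": "gpt-4o-mini",
        "temperature": 0.2,
    })
    reply = await llm.chat([{"role": "user", "content": "ping"}])
    print(llm.get_provider_name(), reply)

asyncio.run(main())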
src/llm/providers/__init__.py
ADDED
"""Provider package"""
from .dashscope import DashScopeLLMProvider, DashScopeEmbeddingProvider
from .openai import OpenAILLMProvider, OpenAIEmbeddingProvider
from .anthropic import AnthropicLLMProvider

__all__ = [
    "DashScopeLLMProvider",
    "DashScopeEmbeddingProvider",
    "OpenAILLMProvider",
    "OpenAIEmbeddingProvider",
    "AnthropicLLMProvider",
]
src/llm/providers/anthropic.py
ADDED
"""Anthropic (Claude) provider implementation"""
from typing import List, Dict, Optional

try:
    from anthropic import AsyncAnthropic
    ANTHROPIC_AVAILABLE = True
except ImportError:
    ANTHROPIC_AVAILABLE = False

from ..base import LLMProvider


class AnthropicLLMProvider(LLMProvider):
    """Anthropic Claude LLM provider"""

    def __init__(self, config: Dict):
        if not ANTHROPIC_AVAILABLE:
            raise ImportError(
                "Anthropic provider requires 'anthropic' package. "
                "Install it with: pip install anthropic"
            )

        self.api_key = config.get("api_key")
        self.model = config.get("chat_model", "claude-3-5-sonnet-20241022")
        self.default_temperature = config.get("temperature", 0.7)
        self.default_max_tokens = config.get("max_tokens", 4096)

        self.client = AsyncAnthropic(api_key=self.api_key)

    async def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """Call the Anthropic chat API"""
        temperature = temperature if temperature is not None else self.default_temperature
        max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens

        # The Anthropic API takes the system prompt separately from the
        # message list, so split it out here.
        system_message = None
        user_messages = []

        for msg in messages:
            if msg["role"] == "system":
                system_message = msg["content"]
            else:
                user_messages.append(msg)

        # Only pass `system` when one was supplied; an explicit null system
        # prompt is rejected by the API.
        if system_message is not None:
            kwargs["system"] = system_message

        response = await self.client.messages.create(
            model=self.model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=user_messages,
            **kwargs
        )

        return response.content[0].text

    def get_provider_name(self) -> str:
        return f"anthropic:{self.model}"
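Because LLMFactory.register_provider is a plain registry write, third-party providers can be plugged in without touching this package. A toy sketch (EchoLLMProvider is hypothetical; the abstract surface of LLMProvider is inferred from the three concrete providers in this diff, since src/llm/base.py is not shown here):

from typing import Dict, List
from src.llm.base import LLMProvider
from src.llm.factory import LLMFactory

class EchoLLMProvider(LLMProvider):
    """Toy provider that echoes the last user message (for offline tests)."""

    def __init__(self, config: Dict):
        self.model = config.get("chat_model", "echo-1")

    async def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
        return messages[-1]["content"]

    def get_provider_name(self) -> str:
        return f"echo:{self.model}"

# After registration the factory builds it from config like any built-in.
LLMFactory.register_provider("echo", EchoLLMProvider)
llm = LLMFactory.create({"provider": "echo"})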
src/llm/providers/dashscope.py
ADDED
"""DashScope (Tongyi Qianwen / Qwen) provider implementation"""
from typing import List, Dict, Optional
from openai import AsyncOpenAI
from ..base import LLMProvider, EmbeddingProvider


class DashScopeLLMProvider(LLMProvider):
    """DashScope LLM provider (via the OpenAI-compatible endpoint)"""

    def __init__(self, config: Dict):
        self.api_key = config.get("api_key")
        self.base_url = config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        self.model = config.get("chat_model", "qwen-plus")
        self.default_temperature = config.get("temperature", 0.7)
        self.default_max_tokens = config.get("max_tokens", 4096)

        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    async def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """Call the DashScope chat API"""
        temperature = temperature if temperature is not None else self.default_temperature
        max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens

        response = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        return response.choices[0].message.content

    def get_provider_name(self) -> str:
        return f"dashscope:{self.model}"


class DashScopeEmbeddingProvider(EmbeddingProvider):
    """DashScope embedding provider"""

    def __init__(self, config: Dict):
        self.api_key = config.get("api_key")
        self.base_url = config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        self.model = config.get("model", "text-embedding-v3")

        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

        # text-embedding-v3 produces 1024-dimensional vectors
        self._embedding_dim = 1024

    async def embed(self, text: str) -> List[float]:
        """Call the DashScope embedding API"""
        response = await self.client.embeddings.create(
            model=self.model,
            input=text
        )

        return response.data[0].embedding

    def get_provider_name(self) -> str:
        return f"dashscope:{self.model}"

    def get_embedding_dim(self) -> int:
        return self._embedding_dim
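A short sketch of the embedding side (requires a valid DashScope key; the inline cosine computation stands in for the helper in src/utils/similarity.py, which this excerpt does not show):

import asyncio
import math
from src.llm.factory import EmbeddingFactory

async def main():
    emb = EmbeddingFactory.create({
        "provider": "dashscope",
        "api_key": "sk-...",             # placeholder, supply a real key
        "model": "text-embedding-v3",
    })
    a = await emb.embed("reset a forgotten password")
    b = await emb.embed("recover account access")
    assert len(a) == emb.get_embedding_dim() == 1024
    # Plain cosine similarity between the two vectors.
    cos = sum(x * y for x, y in zip(a, b)) / (
        math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    )
    print(f"cosine similarity: {cos:.3f}")

asyncio.run(main())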
src/llm/providers/openai.py
ADDED
"""OpenAI provider implementation"""
from typing import List, Dict, Optional
from openai import AsyncOpenAI
from ..base import LLMProvider, EmbeddingProvider


class OpenAILLMProvider(LLMProvider):
    """OpenAI LLM provider"""

    def __init__(self, config: Dict):
        self.api_key = config.get("api_key")
        self.base_url = config.get("base_url", "https://api.openai.com/v1")
        self.model = config.get("chat_model", "gpt-4o-mini")
        self.default_temperature = config.get("temperature", 0.7)
        self.default_max_tokens = config.get("max_tokens", 4096)

        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    async def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """Call the OpenAI chat API"""
        temperature = temperature if temperature is not None else self.default_temperature
        max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens

        response = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        return response.choices[0].message.content

    def get_provider_name(self) -> str:
        return f"openai:{self.model}"


class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider"""

    def __init__(self, config: Dict):
        self.api_key = config.get("api_key")
        self.base_url = config.get("base_url", "https://api.openai.com/v1")
        self.model = config.get("model", "text-embedding-3-small")

        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

        # text-embedding-3-small is 1536-dimensional, text-embedding-3-large
        # is 3072. Note this heuristic assumes a text-embedding-3-* model.
        self._embedding_dim = 1536 if "small" in self.model else 3072

    async def embed(self, text: str) -> List[float]:
        """Call the OpenAI embedding API"""
        response = await self.client.embeddings.create(
            model=self.model,
            input=text
        )

        return response.data[0].embedding

    def get_provider_name(self) -> str:
        return f"openai:{self.model}"

    def get_embedding_dim(self) -> int:
        return self._embedding_dim
src/merge/__init__.py
ADDED
"""
Merge module initialization

Registers all built-in merge strategies.
"""

from .base import MergeStrategy, MergeResult
from .factory import MergeFactory
from .strategies.voting_merge import VotingMergeStrategy
from .strategies.llm_merge import LLMMergeStrategy

# Register built-in strategies
MergeFactory.register("voting", VotingMergeStrategy)
MergeFactory.register("llm", LLMMergeStrategy)

__all__ = [
    "MergeStrategy",
    "MergeResult",
    "MergeFactory",
    "VotingMergeStrategy",
    "LLMMergeStrategy",
]
src/merge/base.py
ADDED
"""
Merge Strategy Base Interface

Provides pluggable merge strategies for combining similar memories.
All operations MUST respect agent_id isolation.
"""

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field


@dataclass
class MergeResult:
    """Result of a merge operation"""
    success: bool
    merged_memory_id: Optional[str] = None
    merged_from: List[str] = field(default_factory=list)
    abstraction_level: int = 1  # 0=specific case, 1=pattern, 2=principle
    reason: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)


class MergeStrategy(ABC):
    """
    Abstract base class for merge strategies.

    Different strategies can implement different merge algorithms:
    - LLM-based (extract common patterns)
    - Template-based (rule-based merging)
    - Voting-based (keep best, remove rest)
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize strategy with configuration.

        Args:
            config: Strategy-specific configuration
        """
        self.config = config

    @abstractmethod
    async def should_merge(
        self,
        memories: List[Dict[str, Any]],
        agent_id: Optional[str] = None
    ) -> bool:
        """
        Determine if a group of memories should be merged.

        Args:
            memories: List of memory dicts to evaluate
            agent_id: Agent ID (all memories must belong to the same agent)

        Returns:
            True if memories should be merged, False otherwise
        """
        pass

    @abstractmethod
    async def merge(
        self,
        memories: List[Dict[str, Any]],
        agent_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Merge multiple memories into a single, more general memory.

        Args:
            memories: List of memory dicts to merge (all from the same agent_id)
            agent_id: Agent ID for validation

        Returns:
            New merged memory dict with keys:
            - title: str
            - content: str
            - description: str
            - query: str (generalized)
            - abstraction_level: int
            - merged_from: List[str] (memory_ids)
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Return strategy name for logging and config"""
        pass
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Merge Strategy Factory
|
|
3
|
+
|
|
4
|
+
Provides plugin mechanism for registering and creating merge strategies.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Dict, Type, Any
|
|
8
|
+
from .base import MergeStrategy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MergeFactory:
|
|
12
|
+
"""Factory for creating merge strategy instances"""
|
|
13
|
+
|
|
14
|
+
_strategies: Dict[str, Type[MergeStrategy]] = {}
|
|
15
|
+
|
|
16
|
+
@classmethod
|
|
17
|
+
def register(cls, name: str, strategy_class: Type[MergeStrategy]):
|
|
18
|
+
"""
|
|
19
|
+
Register a merge strategy.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
name: Strategy name (e.g., "llm", "voting", "template")
|
|
23
|
+
strategy_class: Strategy class (must inherit from MergeStrategy)
|
|
24
|
+
"""
|
|
25
|
+
cls._strategies[name] = strategy_class
|
|
26
|
+
|
|
27
|
+
@classmethod
|
|
28
|
+
def create(cls, config: Any) -> MergeStrategy:
|
|
29
|
+
"""
|
|
30
|
+
Create a merge strategy instance based on config.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
config: Config object with get() method
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
MergeStrategy instance
|
|
37
|
+
|
|
38
|
+
Raises:
|
|
39
|
+
ValueError: If strategy name is not registered
|
|
40
|
+
"""
|
|
41
|
+
# 使用统一的配置访问方式
|
|
42
|
+
strategy_name = config.get('memory_manager', 'merge', 'strategy', default='voting')
|
|
43
|
+
|
|
44
|
+
if strategy_name not in cls._strategies:
|
|
45
|
+
raise ValueError(
|
|
46
|
+
f"Unknown merge strategy: {strategy_name}. "
|
|
47
|
+
f"Available: {list(cls._strategies.keys())}"
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
strategy_class = cls._strategies[strategy_name]
|
|
51
|
+
|
|
52
|
+
# 获取合并配置
|
|
53
|
+
merge_config = config.get('memory_manager', 'merge', default={})
|
|
54
|
+
|
|
55
|
+
return strategy_class(merge_config)
|
|
56
|
+
|
|
57
|
+
@classmethod
|
|
58
|
+
def list_strategies(cls) -> list:
|
|
59
|
+
"""Return list of registered strategy names"""
|
|
60
|
+
return list(cls._strategies.keys())
|
|
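MergeFactory.create expects a config object whose get() takes a key path plus a default keyword, matching the calls above. A small stand-in for illustration (the project's real accessor lives in src/config.py, which is not shown here):

from src.merge import MergeFactory  # importing the package registers built-ins

class DictConfig:
    """Tiny stand-in for the project's config object (src/config.py)."""

    def __init__(self, data: dict):
        self._data = data

    def get(self, *keys, default=None):
        # Walk the nested dict along the key path, falling back to default.
        node = self._data
        for key in keys:
            if not isinstance(node, dict) or key not in node:
                return default
            node = node[key]
        return node

config = DictConfig({
    "memory_manager": {
        "merge": {"strategy": "voting", "trigger": {"min_similar_count": 3}}
    }
})
strategy = MergeFactory.create(config)   # -> VotingMergeStrategy instance
print(MergeFactory.list_strategies())    # ['voting', 'llm']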
src/merge/strategies/__init__.py
ADDED
"""Strategies submodule"""
src/merge/strategies/llm_merge.py
ADDED
"""
LLM-based Merge Strategy

Uses an LLM to extract common patterns and create a generalized memory.
"""

from typing import List, Dict, Any, Optional
from ..base import MergeStrategy
import logging
import json

logger = logging.getLogger(__name__)


class LLMMergeStrategy(MergeStrategy):
    """
    LLM-driven merge: extract common patterns from similar memories.

    Creates a new, more abstract memory that captures the essence
    of multiple specific experiences.
    """

    @property
    def name(self) -> str:
        return "llm"

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.min_group_size = config.get("trigger", {}).get("min_similar_count", 3)
        self.llm_config = config.get("llm", {})
        self.temperature = self.llm_config.get("temperature", 0.7)
        self.llm_provider = None  # Will be injected by MemoryManager

    def set_llm_provider(self, llm_provider):
        """Inject LLM provider dependency"""
        self.llm_provider = llm_provider

    async def should_merge(
        self,
        memories: List[Dict[str, Any]],
        agent_id: Optional[str] = None
    ) -> bool:
        """
        Check whether the group meets the criteria for an LLM merge.

        Requirements:
        1. At least min_group_size memories
        2. All from the same agent_id
        3. A majority should be successful experiences
        """
        if len(memories) < self.min_group_size:
            return False

        # Validate agent_id consistency
        if agent_id:
            for mem in memories:
                if mem.get("agent_id") != agent_id:
                    logger.warning(
                        f"Memory {mem.get('memory_id')} has different agent_id"
                    )
                    return False

        # Check success rate
        success_count = sum(1 for m in memories if m.get("success", False))
        success_rate = success_count / len(memories)

        # Only merge if the majority are successful (avoid mixing success/failure)
        if success_rate < 0.6:
            logger.info(
                f"Success rate too low ({success_rate:.2f}) for merge, "
                f"agent_id={agent_id}"
            )
            return False

        return True

    async def merge(
        self,
        memories: List[Dict[str, Any]],
        agent_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Use the LLM to extract common patterns and create a generalized memory.

        Args:
            memories: List of similar memories (all from the same agent_id)
            agent_id: Agent ID for validation

        Returns:
            New merged memory dict
        """
        if not self.llm_provider:
            raise RuntimeError("LLM provider not set. Call set_llm_provider() first.")

        if not memories:
            raise ValueError("Cannot merge empty memory list")

        # Validate agent_id
        if agent_id:
            for mem in memories:
                if mem.get("agent_id") != agent_id:
                    raise ValueError(
                        f"Memory {mem.get('memory_id')} belongs to different agent"
                    )

        # Build prompt for the LLM
        prompt = self._build_merge_prompt(memories)

        # Call the LLM
        try:
            response = await self.llm_provider.chat(
                messages=[{"role": "user", "content": prompt}],
                temperature=self.temperature
            )

            # Parse JSON response
            merged_data = json.loads(response)

            # Validate response
            required_fields = ["title", "content", "description", "abstraction_level"]
            for field in required_fields:
                if field not in merged_data:
                    raise ValueError(f"LLM response missing required field: {field}")

            # Build merged memory
            merged_memory = {
                "title": merged_data["title"],
                "content": merged_data["content"],
                "description": merged_data["description"],
                "query": merged_data.get("query", "<通用场景>"),  # "<generic scenario>"
                "success": True,  # Merged memories are positive learnings
                "agent_id": agent_id,
                "is_merged": True,
                "merged_from": [m["memory_id"] for m in memories],
                "merge_metadata": {
                    "merge_strategy": self.name,
                    "original_count": len(memories),
                    "abstraction_level": merged_data.get("abstraction_level", 1),
                    "llm_model": self.llm_config.get("model", "unknown")
                }
            }

            logger.info(
                f"LLM merge successful: {len(memories)} memories -> "
                f"'{merged_memory['title']}' (agent_id={agent_id})"
            )

            return merged_memory

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse LLM response as JSON: {e}")
            raise
        except Exception as e:
            logger.error(f"Error during LLM merge: {e}", exc_info=True)
            raise

    def _build_merge_prompt(self, memories: List[Dict[str, Any]]) -> str:
        """Build the prompt for the LLM to merge memories"""

        # Format memories for the LLM. The prompt is deliberately written in
        # Chinese ("经验" = experience, "标题" = title, "描述" = description,
        # "内容" = content, "原始场景" = original scenario).
        memories_text = ""
        for i, mem in enumerate(memories, 1):
            memories_text += f"\n### 经验 {i}\n"
            memories_text += f"**标题**: {mem.get('title', 'N/A')}\n"
            memories_text += f"**描述**: {mem.get('description', 'N/A')}\n"
            memories_text += f"**内容**: {mem.get('content', 'N/A')}\n"
            memories_text += f"**原始场景**: {mem.get('query', 'N/A')}\n"

        # Deferred import avoids a circular dependency at module load time.
        from ...prompts.templates import get_merge_prompt

        return get_merge_prompt(memories_text)
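Since the LLM is injected, the strategy can be exercised end to end with a stub that returns canned JSON. A minimal sketch, assuming the installed package so that src.prompts.templates is importable; StubLLM and all sample data are hypothetical:

import asyncio
from src.merge.strategies.llm_merge import LLMMergeStrategy

class StubLLM:
    """Hypothetical stand-in provider returning canned JSON."""

    async def chat(self, messages, temperature=None, **kwargs) -> str:
        return (
            '{"title": "Retry transient API failures", '
            '"content": "Back off exponentially before retrying.", '
            '"description": "Pattern distilled from 3 similar incidents.", '
            '"abstraction_level": 1}'
        )

strategy = LLMMergeStrategy({"trigger": {"min_similar_count": 3}})
strategy.set_llm_provider(StubLLM())

memories = [
    {"memory_id": f"m{i}", "agent_id": "agent-1", "success": True,
     "title": f"case {i}", "content": "...", "description": "...",
     "query": "..."}
    for i in range(3)
]

async def main():
    assert await strategy.should_merge(memories, agent_id="agent-1")
    merged = await strategy.merge(memories, agent_id="agent-1")
    print(merged["merged_from"])                           # ['m0', 'm1', 'm2']
    print(merged["merge_metadata"]["abstraction_level"])   # 1

asyncio.run(main())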