keepsake-memory 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keepsake/__init__.py +558 -0
- keepsake/attention.py +146 -0
- keepsake/consolidator.py +395 -0
- keepsake/embedder.py +155 -0
- keepsake/emotion.py +136 -0
- keepsake/forgetter.py +262 -0
- keepsake/py.typed +0 -0
- keepsake/splitter.py +436 -0
- keepsake/storage.py +1360 -0
- keepsake_memory-1.0.0.dist-info/METADATA +424 -0
- keepsake_memory-1.0.0.dist-info/RECORD +14 -0
- keepsake_memory-1.0.0.dist-info/WHEEL +5 -0
- keepsake_memory-1.0.0.dist-info/licenses/LICENSE +21 -0
- keepsake_memory-1.0.0.dist-info/top_level.txt +1 -0
keepsake/attention.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""注意力追踪 — 统计用户对各个话题的关注频率。
|
|
2
|
+
|
|
3
|
+
用户反复提起某个话题 -> 该话题关注度上升 -> 相关碎片在搜索中权重更高。
|
|
4
|
+
|
|
5
|
+
存储: Redis Sorted Set `keepsake:attention`
|
|
6
|
+
- member: 话题词(由 jieba/关键词提取来)
|
|
7
|
+
- score: 关注度累计值(每次提及 +2,情绪烈度加权)
|
|
8
|
+
|
|
9
|
+
三套时间窗口(同 hot_topics 模式):
|
|
10
|
+
- 全局: keepsake:attention (7天)
|
|
11
|
+
- 日榜: keepsake:attention:daily (2天)
|
|
12
|
+
- 周榜: keepsake:attention:weekly (14天)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Any, Dict, List, Optional
|
|
19
|
+
|
|
20
|
+
import redis
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)

# Redis key names of the three attention sorted sets.
ATTENTION_SET = "keepsake:attention"
ATTENTION_DAILY = "keepsake:attention:daily"
ATTENTION_WEEKLY = "keepsake:attention:weekly"

# Expiry time of each sorted set, in seconds.
_ATTENTION_TTL = {
    ATTENTION_SET: 86400 * 7,  # global: 7 days
    ATTENTION_DAILY: 86400 * 2,  # daily board: 2 days
    ATTENTION_WEEKLY: 86400 * 14,  # weekly board: 14 days
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def record_attention(
    client: redis.Redis,
    text: str,
    emotion_intensity: float = 0.0,
    keywords: Optional[List[str]] = None,
    base_increment: float = 2.0,
    emotion_factor: float = 1.5,
) -> None:
    """Record the user's attention to the topics mentioned in a piece of text.

    Every usable keyword has its score increased in the global, daily and
    weekly sorted sets, and each set's expiry is refreshed afterwards.

    Args:
        client: Redis client; a falsy client makes the call a no-op.
        text: The text the keywords came from; an empty text is a no-op.
        emotion_intensity: Emotional intensity of the text; scaled by
            ``emotion_factor`` and added to ``base_increment``.
        keywords: Topic words to credit. Nothing is recorded when this is
            empty or ``None`` (this function does not extract keywords).
        base_increment: Score added per mention before emotion weighting.
        emotion_factor: Multiplier applied to ``emotion_intensity``.

    This is best-effort: any error is logged at debug level and swallowed.
    """
    if not client or not text or not keywords:
        return

    try:
        # Normalise once up front. Non-string entries are skipped so that a
        # single bad keyword cannot abort the update for all the others.
        topics = []
        for kw in keywords:
            if not isinstance(kw, str):
                continue
            topic = kw.lower().strip()
            if len(topic) >= 2:
                topics.append(topic)
        if not topics:
            return

        increment = base_increment + emotion_intensity * emotion_factor

        for topic_set in (ATTENTION_SET, ATTENTION_DAILY, ATTENTION_WEEKLY):
            for topic in topics:
                client.zincrby(topic_set, increment, topic)
            # Refreshing the TTL once per set is enough; it does not depend
            # on which keyword was written.
            client.expire(topic_set, _ATTENTION_TTL.get(topic_set, 86400))

    except Exception as e:
        logger.debug("attention: record_attention error: %s", e)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_attention_score(client: redis.Redis, keyword: str) -> float:
    """Return the current global attention score of ``keyword``.

    The keyword is lower-cased and stripped before the lookup, the same
    normalisation ``record_attention`` applies when writing.

    Returns:
        The score stored in the global attention sorted set, or ``0.0`` when
        the client or keyword is falsy, the keyword has no score, or the
        lookup fails.
    """
    if not client or not keyword:
        return 0.0
    try:
        score = client.zscore(ATTENTION_SET, keyword.lower().strip())
        return float(score) if score is not None else 0.0
    except Exception as e:
        # Log like the sibling helpers do instead of failing silently.
        logger.debug("attention: get_attention_score error: %s", e)
        return 0.0
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def get_top_attention(client: redis.Redis, limit: int = 10,
                      period: str = "all") -> List[Dict[str, Any]]:
    """Return the topics with the highest attention score.

    Args:
        client: Redis client; a falsy client yields an empty list.
        limit: Maximum number of topics to return. Values ``<= 0`` yield an
            empty list.
        period: ``"all"``, ``"daily"`` or ``"weekly"``; any other value falls
            back to the global set.

    Returns:
        A list of ``{"topic": str, "score": float}`` dicts, highest score
        first, with scores rounded to one decimal. Empty on any error.
    """
    # Guard limit explicitly: with limit == 0 the range below would become
    # (0, -1), which Redis interprets as "the whole sorted set".
    if not client or limit <= 0:
        return []

    key = {
        "all": ATTENTION_SET,
        "daily": ATTENTION_DAILY,
        "weekly": ATTENTION_WEEKLY,
    }.get(period, ATTENTION_SET)

    try:
        raw = client.zrevrange(key, 0, limit - 1, withscores=True)
        results = []
        for member, score in raw:
            # Members and scores are bytes or str depending on how the
            # client was configured; handle both.
            topic = member.decode("utf-8") if isinstance(member, bytes) else member
            if isinstance(score, bytes):
                score = score.decode("utf-8")
            results.append({"topic": topic, "score": round(float(score), 1)})
        return results
    except Exception as e:
        logger.debug("attention: get_top_attention error: %s", e)
        return []
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def match_attention_boost(
    client: redis.Redis,
    content: str,
    top_n: int = 10,
    boost_max: float = 1.5,
) -> float:
    """Compute an attention-based weight multiplier for a fragment's content.

    The top ``top_n`` topics of the global attention set are fetched, and the
    boost grows with the share of their total score that belongs to topics
    occurring (case-insensitively, as substrings) in ``content``.

    Args:
        client: Redis client; a falsy client yields the neutral value 1.0.
        content: Fragment text to check against the hot topics.
        top_n: Number of top topics to consider; ``<= 0`` yields 1.0.
        boost_max: Multiplier returned when every considered topic matches.

    Returns:
        A value between 1.0 (no match, no data, or error) and ``boost_max``.
    """
    # top_n <= 0 would turn the range below into (0, -1), i.e. the whole set.
    if not client or not content or top_n <= 0:
        return 1.0

    try:
        raw = client.zrevrange(ATTENTION_SET, 0, top_n - 1, withscores=True)
        if not raw:
            return 1.0

        content_lower = content.lower()
        matched_score = 0.0  # score mass of topics found in the content
        total_score = 0.0  # score mass of all considered topics

        for member, score_raw in raw:
            topic = member.decode("utf-8") if isinstance(member, bytes) else member
            if isinstance(score_raw, bytes):
                score_raw = score_raw.decode("utf-8")
            score = float(score_raw)
            total_score += score
            if len(topic) >= 2 and topic in content_lower:
                matched_score += score

        if total_score <= 0:
            return 1.0

        # Map the matched share linearly onto 1.0..boost_max.
        ratio = min(matched_score / total_score, 1.0)
        return 1.0 + (boost_max - 1.0) * ratio

    except Exception as e:
        logger.debug("attention: match_attention_boost error: %s", e)
        return 1.0
|
keepsake/consolidator.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
"""Consolidation 引擎 — 将同主题碎片分层提炼为更高层记忆。
|
|
2
|
+
|
|
3
|
+
工作流程:
|
|
4
|
+
1. 扫描所有未被消费的碎片(fragment_type != "consumed";已合并的 consolidated 碎片也参与,实现多级提炼)
|
|
5
|
+
2. 用 jieba 关键词提取做主题分组
|
|
6
|
+
3. 每组超过 min_group_size 条时,调 LLM 合并为一条高层次摘要
|
|
7
|
+
4. 存为新碎片(fragment_type="consolidated", level=N+1)
|
|
8
|
+
5. 软删除原始碎片(标记为 fragment_type="consumed",不硬删)
|
|
9
|
+
|
|
10
|
+
配置参数:
|
|
11
|
+
- min_group_size: 最少多少条碎片才触发合并(默认 2)
|
|
12
|
+
- max_age_hours: 只合并超过此年龄的碎片(给新碎片时间积累,默认 72h)
|
|
13
|
+
- llm_model: DashScope 模型名(默认 qwen-plus)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import json
|
|
19
|
+
import logging
|
|
20
|
+
import os
|
|
21
|
+
import time
|
|
22
|
+
import urllib.error
|
|
23
|
+
import urllib.request
|
|
24
|
+
from collections import defaultdict
|
|
25
|
+
from datetime import datetime, timezone
|
|
26
|
+
from typing import Any, Dict, List, Optional
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)

# DashScope API endpoint
DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"

# Default parameters
DEFAULT_MIN_GROUP_SIZE = 2  # merge as soon as there is repeated content
DEFAULT_MAX_AGE_HOURS = 72
DEFAULT_LLM_MODEL = "qwen-plus"
DEFAULT_BATCH_SIZE = 200  # number of fragments scanned per consolidate batch

# LLM timeout
LLM_TIMEOUT = 30

# Merge prompt; the {segments} placeholder is filled via str.format.
CONSOLIDATE_PROMPT = """你是一位知识提炼专家。以下是一组关于同一话题的对话片段。

请将它们合并成一条简洁、信息完整的高层记忆条目,要求:
1. 保留所有关键事实和结论,不丢信息
2. 去掉重复内容
3. 用陈述句表达,像一条知识条目
4. 如果片段之间存在矛盾,指出矛盾但不选边
5. 控制在 200 字以内

对话片段:
{segments}

合并后的知识条目:"""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _get_api_key() -> str:
|
|
59
|
+
"""获取 DashScope API key。"""
|
|
60
|
+
key = os.environ.get("OPENAI_API_KEY", "") or os.environ.get("DASHSCOPE_API_KEY", "")
|
|
61
|
+
# fallback: 从 config.yaml 用 yaml 解析,优先取 providers.dashscope.api_key
|
|
62
|
+
if not key:
|
|
63
|
+
try:
|
|
64
|
+
import yaml
|
|
65
|
+
config_path = os.path.expanduser("~/.hermes/config.yaml")
|
|
66
|
+
if os.path.isfile(config_path):
|
|
67
|
+
with open(config_path) as f:
|
|
68
|
+
cfg = yaml.safe_load(f)
|
|
69
|
+
if cfg:
|
|
70
|
+
key = (
|
|
71
|
+
cfg.get("providers", {}).get("dashscope", {}).get("api_key")
|
|
72
|
+
or cfg.get("model", {}).get("api_key")
|
|
73
|
+
or ""
|
|
74
|
+
)
|
|
75
|
+
key = key.strip().strip("'\"")
|
|
76
|
+
except Exception:
|
|
77
|
+
pass
|
|
78
|
+
return key
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _call_llm(messages: List[Dict[str, str]], model: str = DEFAULT_LLM_MODEL,
              max_retries: int = 2) -> Optional[str]:
    """Call the DashScope chat-completions API and return the reply text.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        model: Model name sent in the request.
        max_retries: Number of additional attempts after the first one.
            Retries back off exponentially (2s, 4s, ...).

    Returns:
        The stripped content of the first choice, or ``None`` when no API key
        is available, the response has no usable choice, the request is
        rejected with a non-retryable client error, or every attempt fails.
    """
    api_key = _get_api_key()
    if not api_key:
        logger.warning("consolidator: no API key for LLM calls")
        return None

    url = f"{DASHSCOPE_BASE}/chat/completions"
    payload = json.dumps({
        "model": model,
        "messages": messages,
        "max_tokens": 512,
        "temperature": 0.3,
    }).encode("utf-8")

    for attempt in range(1 + max_retries):
        if attempt > 0:
            wait = 2.0 * (2 ** (attempt - 1))  # 2s, 4s
            logger.debug("consolidator: retry %d/%d after %.0fs", attempt, max_retries, wait)
            time.sleep(wait)

        req = urllib.request.Request(
            url, data=payload,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            method="POST",
        )
        try:
            with urllib.request.urlopen(req, timeout=LLM_TIMEOUT) as resp:
                data = json.loads(resp.read())
        except urllib.error.HTTPError as e:
            # A 4xx (bad key, bad request) will fail identically on retry;
            # only timeouts (408) and rate limits (429) are worth retrying.
            if 400 <= e.code < 500 and e.code not in (408, 429):
                logger.warning("consolidator: LLM request rejected with HTTP %d: %s",
                               e.code, e.reason)
                return None
            error = e
        except (urllib.error.URLError, ValueError, OSError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            error = e
        else:
            choices = data.get("choices") if isinstance(data, dict) else None
            if not isinstance(choices, list) or not choices:
                # Got a response but no choices; retrying will not help.
                logger.warning("consolidator: LLM returned no choices (attempt %d)", attempt + 1)
                return None
            first = choices[0]
            message = first.get("message") if isinstance(first, dict) else None
            content = message.get("content", "") if isinstance(message, dict) else ""
            if not isinstance(content, str):
                # e.g. "content": null; there is nothing to strip or return.
                logger.warning("consolidator: LLM returned non-text content")
                return None
            return content.strip()

        if attempt < max_retries:
            logger.debug("consolidator: LLM attempt %d failed: %s", attempt + 1, error)
        else:
            logger.warning("consolidator: LLM call failed after %d attempts: %s",
                           attempt + 1, error)
    return None
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class Consolidator:
    """Layered fragment consolidation engine.

    Scans fragment hashes in Redis, groups them by keyword overlap, asks the
    LLM to merge each sufficiently large group into one higher-level
    fragment, and soft-deletes the merged originals by marking them
    ``consumed``.
    """

    def __init__(
        self,
        storage: Any,  # RedisStorage instance (avoid circular import)
        min_group_size: int = DEFAULT_MIN_GROUP_SIZE,
        max_age_hours: int = DEFAULT_MAX_AGE_HOURS,
        llm_model: str = DEFAULT_LLM_MODEL,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ):
        """Store the configuration.

        Args:
            storage: Object exposing ``_get_client()`` that returns a Redis
                client (or a falsy value when Redis is unavailable).
            min_group_size: Minimum number of fragments in a group before it
                is merged.
            max_age_hours: Fragments created more recently than this are
                left alone for now.
            llm_model: Model name passed to the LLM call.
            batch_size: ``count`` hint used for each Redis ``SCAN`` call.
        """
        self._storage = storage
        self._min_group_size = min_group_size
        self._max_age_hours = max_age_hours
        self._llm_model = llm_model
        self._batch_size = batch_size

    def consolidate(self) -> Dict[str, Any]:
        """Run one consolidation pass and return operation statistics.

        Returns:
            ``{"status": "error", "reason": ...}`` when Redis is unavailable;
            otherwise a dict with the counters ``scanned``, ``groups_found``,
            ``merged``, ``skipped`` and ``errors`` (all counted in fragments,
            except ``groups_found``).
        """
        client = self._storage._get_client()
        if not client:
            return {"status": "error", "reason": "Redis not available"}

        stats = {"scanned": 0, "groups_found": 0, "merged": 0, "skipped": 0, "errors": 0}

        # 1. Scan for fragments that are eligible for merging.
        fragments = self._scan_unconsolidated(client)
        stats["scanned"] = len(fragments)
        if not fragments:
            return stats

        # 2. Cluster by topic.
        groups = self._cluster_by_topic(fragments)
        stats["groups_found"] = len(groups)

        # 3. Merge every group that is large enough.
        for group in groups:
            if len(group) < self._min_group_size:
                stats["skipped"] += len(group)
                continue

            if self._merge_group(client, group):
                stats["merged"] += len(group)
            else:
                stats["errors"] += len(group)

        return stats

    def _scan_unconsolidated(self, client) -> List[Dict[str, Any]]:
        """Scan Redis for fragments that are eligible for merging.

        A fragment is eligible when:
        - its ``fragment_type`` is not ``"consumed"`` (not already absorbed
          into a higher-level fragment), and
        - it was created more than ``max_age_hours`` ago (new fragments get
          time to accumulate).

        Already ``consolidated`` fragments are scanned too, which is what
        enables multi-level consolidation.

        Returns:
            Decoded fragment dicts, each with its Redis key under ``"_key"``.
            An empty list on any error.
        """
        try:
            cutoff = (datetime.now(timezone.utc).timestamp() - self._max_age_hours * 3600)
            cursor = 0
            fragments = []

            while True:
                cursor, keys = client.scan(
                    cursor=cursor,
                    match="memory:frag:*",
                    count=self._batch_size,
                )

                if not keys:
                    # SCAN may return an empty page before the iteration is
                    # finished; only a zero cursor means "done".
                    if cursor == 0:
                        break
                    continue

                # Batch the HGETALL calls in a pipeline to cut round trips.
                pipe = client.pipeline()
                for key_b in keys:
                    pipe.hgetall(key_b)
                pipe_results = pipe.execute()

                for key_b, data in zip(keys, pipe_results):
                    key = key_b.decode("utf-8") if isinstance(key_b, bytes) else key_b
                    if not data:
                        continue

                    # Decode field names and values.
                    doc = {}
                    for k_b, v_b in data.items():
                        k = k_b.decode("utf-8") if isinstance(k_b, bytes) else k_b
                        v = v_b.decode("utf-8") if isinstance(v_b, bytes) else v_b
                        doc[k] = v

                    # Skip fragments already absorbed by a higher level.
                    if doc.get("fragment_type", "") == "consumed":
                        continue

                    # Age check. A missing or unparseable timestamp does not
                    # exclude the fragment; it is treated as old enough.
                    created_str = doc.get("created", "")
                    if created_str:
                        try:
                            created_ts = datetime.fromisoformat(created_str).timestamp()
                            if created_ts > cutoff:
                                continue  # too new, wait for a later pass
                        except (ValueError, TypeError):
                            pass

                    doc["_key"] = key
                    fragments.append(doc)

                if cursor == 0:
                    break

            return fragments

        except Exception as e:
            logger.warning("consolidator: scan error: %s", e)
            return []

    def _cluster_by_topic(self, fragments: List[Dict]) -> List[List[Dict]]:
        """Cluster fragments by keyword overlap.

        Strategy:
        - extract up to 5 keywords per fragment via ``extract_keywords``
        - fragments sharing >= 2 keywords with a group's first fragment join
          that group
        - greedy, single pass; does not aim for an optimal clustering

        Fragments without content are dropped. Every remaining fragment ends
        up in exactly one group (possibly alone).
        """
        from .splitter import extract_keywords

        # Extract the keywords of every fragment.
        frag_data = []
        for f in fragments:
            content = f.get("content", "")
            if not content:
                continue
            kws = set(extract_keywords(content, max_keywords=5))
            frag_data.append({"frag": f, "keywords": kws})

        if not frag_data:
            return []

        # Greedy clustering: each unassigned fragment seeds a group and pulls
        # in every other unassigned fragment that overlaps with the seed.
        groups: List[List[Dict]] = []
        assigned = set()

        for i, data in enumerate(frag_data):
            if i in assigned:
                continue
            group = [data["frag"]]
            assigned.add(i)

            for j, other in enumerate(frag_data):
                if j in assigned:
                    continue
                # Overlap of at least 2 keywords with the seed.
                overlap = len(data["keywords"] & other["keywords"])
                if overlap >= 2:
                    group.append(other["frag"])
                    assigned.add(j)

            groups.append(group)

        return groups

    @staticmethod
    def _parse_level(fragment: Dict[str, Any]) -> int:
        """Return the fragment's ``level`` as an int; 1 if missing or malformed."""
        try:
            return int(fragment.get("level", "1"))
        except (ValueError, TypeError):
            return 1

    def _merge_group(self, client, group: List[Dict]) -> bool:
        """Merge one group of fragments into a consolidated fragment via the LLM.

        On success the consolidated fragment is stored under a key derived
        from the hash of its content, and every original is marked
        ``consumed`` (soft delete) with a pointer to it.

        Returns:
            ``True`` when the group was merged; ``False`` when the group has
            too few non-empty fragments, the LLM call produced nothing, or
            the consolidated fragment could not be stored.
        """
        # New level: highest level in the group + 1. Malformed levels count
        # as 1 everywhere, so one bad value cannot abort the merge.
        levels = [self._parse_level(f) for f in group]
        new_level = max([1] + levels) + 1

        # Does the group already contain consolidated material?
        has_consolidated = any(
            f.get("fragment_type") == "consolidated" or lv > 1
            for f, lv in zip(group, levels)
        )

        # Prepare the segment texts and collect tags.
        segments = []
        tags_set = set()
        for f in group:
            content = f.get("content", "")
            if content:
                segments.append(f"• {content[:300]}")
            tags = f.get("tags", "")
            if tags:
                for t in tags.split(","):
                    t = t.strip()
                    # Session tags are not carried over to the merged entry.
                    if t and not t.startswith("session:"):
                        tags_set.add(t)

        if len(segments) < self._min_group_size:
            return False

        # Pick the prompt depending on whether refined content is involved.
        if has_consolidated:
            prompt = (
                "以下是一组已经提炼过的记忆条目和相关的原始对话片段。"
                "请将它们进一步提炼合并成一条更精炼的高层知识条目。\n\n"
                + "\n".join(segments)
                + "\n\n提炼后的高层知识条目:"
            )
        else:
            prompt = CONSOLIDATE_PROMPT.format(segments="\n".join(segments))
        result = _call_llm([
            {"role": "system", "content": "你是一位知识提炼专家,擅长从对话中提取核心信息。"},
            {"role": "user", "content": prompt},
        ], model=self._llm_model)

        if not result:
            return False

        # Sentiment analysis of the merged text.
        from .splitter import analyze_sentiment
        sent_score, sent_label = analyze_sentiment(result)
        now_str = datetime.now(timezone.utc).isoformat()

        mapping = {
            "content": result,
            "tags": ",".join(sorted(tags_set)) if tags_set else "",
            "category": "consolidated",
            "source": "consolidator",
            "created": now_str,
            "fragment_type": "consolidated",
            "level": str(new_level),  # multi-level: raw=1, first merge=2, second merge=3...
            "sentiment_score": str(sent_score),
            "sentiment_label": sent_label,
            "feedback_score": "0",
        }

        # Key of the consolidated fragment: hash of its content.
        import hashlib
        content_hash = hashlib.sha256(result.encode()).hexdigest()[:12]
        consolidated_key = f"memory:frag:{content_hash}"

        # Store unless the same key already exists (dedup check). A write
        # failure must not leave originals marked consumed without a
        # consolidated entry, so bail out before touching them.
        try:
            if client.exists(consolidated_key):
                logger.debug("consolidator: duplicate consolidated result, skipping")
            else:
                client.hset(consolidated_key, mapping=mapping)
        except Exception as e:
            logger.warning("consolidator: failed to store consolidated fragment: %s", e)
            return False

        # Soft-delete the originals (mark as consumed, no hard delete). All
        # three fields go in one HSET so a fragment is never half-marked.
        consumed_marker = {
            "consumed_by": consolidated_key,
            "consumed_at": now_str,
            "fragment_type": "consumed",
        }
        consumed_count = 0
        for f in group:
            key = f.get("_key")
            # Never mark the consolidated entry as consumed by itself, which
            # could happen if the result hashes to a key already in the group.
            if not key or key == consolidated_key:
                continue
            try:
                client.hset(key, mapping=consumed_marker)
                consumed_count += 1
            except Exception as e:
                logger.warning("consolidator: failed to mark %s as consumed: %s", key, e)

        logger.info(
            "consolidator: merged %d fragments → '%s...' (marked %d as consumed)",
            len(group), result[:60], consumed_count,
        )
        return True
|
keepsake/embedder.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""Embedding 客户端 — 支持 OpenAI / DashScope / 自定义兼容端点。"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import Optional
|
|
10
|
+
from urllib.request import Request, urlopen
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Model -> dimension mapping
# ---------------------------------------------------------------------------

_MODEL_DIMENSIONS: dict[str, int] = {
    # OpenAI text-embedding-3 family (plus ada-002)
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
    "text-embedding-ada-002": 1536,
    # DashScope
    "text-embedding-v2": 1536,
    "text-embedding-v3": 1024,
}

# Dimension used for model names missing from the table above.
_DEFAULT_DIM = 1536
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def resolve_dimension(model: str) -> int:
    """Look up the vector dimension for a model name.

    The name is stripped and lower-cased before the lookup; names that are
    not in the table resolve to the default dimension (1536).
    """
    return _MODEL_DIMENSIONS.get(model.strip().lower(), _DEFAULT_DIM)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ---------------------------------------------------------------------------
|
|
38
|
+
# 抽象基类
|
|
39
|
+
# ---------------------------------------------------------------------------
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Embedder(ABC):
    """Abstract interface every embedding client implements."""

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Return the vector dimension produced by the current model."""

    @abstractmethod
    def get_embedding(self, text: str) -> Optional[list[float]]:
        """Take a text and return its float vector."""
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# ---------------------------------------------------------------------------
# OpenAI-compatible
# ---------------------------------------------------------------------------

# Note: this default is the full endpoint URL; it already ends in /embeddings.
_DEFAULT_OPENAI_URL = "https://api.openai.com/v1/embeddings"
_DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class OpenAIEmbedder(Embedder):
    """Client for the OpenAI embedding API and compatible services.

    Also works with services such as DashScope that expose an
    OpenAI-compatible ``/v1/embeddings`` endpoint.
    """

    def __init__(
        self,
        api_key: str = "",
        base_url: str = _DEFAULT_OPENAI_URL,
        model: str = _DEFAULT_OPENAI_MODEL,
    ):
        """Configure the client.

        Args:
            api_key: API key; falls back to the ``OPENAI_API_KEY``
                environment variable when empty.
            base_url: Either the API root (``.../v1``) or the full endpoint
                (``.../v1/embeddings``); both forms are accepted.
            model: Embedding model name; also determines the initially
                expected vector dimension.
        """
        self._api_key = api_key or os.environ.get("OPENAI_API_KEY", "")

        # Normalise to the API root. get_embedding() appends "/embeddings",
        # so a base_url that already ends with it (as the default does) would
        # otherwise produce ".../embeddings/embeddings".
        suffix = "/embeddings"
        root = base_url.rstrip("/")
        if root.endswith(suffix):
            root = root[: -len(suffix)]
        self._base_url = root

        self._model = model
        self._dim = resolve_dimension(model)

    @property
    def dimension(self) -> int:
        """Vector dimension; updated if the service returns a different size."""
        return self._dim

    def get_embedding(self, text: str) -> Optional[list[float]]:
        """Request the embedding of ``text``.

        Returns:
            The embedding vector, or ``None`` when no API key is configured
            or the request or response parsing fails (logged at debug level).
        """
        if not self._api_key:
            logger.warning("embedder: no API key configured")
            return None

        payload = json.dumps({
            "model": self._model,
            "input": text,
        }).encode("utf-8")

        req = Request(
            f"{self._base_url}/embeddings",
            data=payload,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
            method="POST",
        )

        try:
            with urlopen(req, timeout=30) as resp:
                data = json.loads(resp.read())
            emb = data["data"][0]["embedding"]
            # If the service returns a different dimension than expected,
            # trust the service and update self._dim.
            if len(emb) != self._dim:
                logger.info(
                    "embedder: model %s returned %d dims (expected %d), updating",
                    self._model, len(emb), self._dim,
                )
                self._dim = len(emb)
            return emb
        except Exception as e:
            logger.debug("embedder: request failed: %s", e)
            return None
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------

# Provider name -> Embedder class.
_EMBEDDER_PROVIDERS: dict[str, type[Embedder]] = {
    "openai": OpenAIEmbedder,
    "dashscope": OpenAIEmbedder,  # DashScope also goes through /v1/embeddings
}
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def create_embedder(
    provider: str = "",
    api_key: str = "",
    base_url: str = "",
    model: str = "",
) -> Embedder:
    """Create an Embedder instance from explicit arguments and environment.

    Each setting is resolved in this order: explicit argument, environment
    variable, provider-specific default.

    Args:
        provider: ``"openai"``, ``"dashscope"`` or a custom name
            (case-insensitive). Env: ``FRAGMENTED_EMBEDDER``; default
            ``"openai"``. Unknown providers use ``OpenAIEmbedder``.
        api_key: API key. Env: ``OPENAI_API_KEY``; for the ``dashscope``
            provider ``DASHSCOPE_API_KEY`` is tried as well.
        base_url: API endpoint. Env: ``FRAGMENTED_EMBEDDER_URL``.
        model: Model name. Env: ``FRAGMENTED_EMBEDDER_MODEL``.
    """
    provider = (provider or os.environ.get("FRAGMENTED_EMBEDDER", "openai")).strip().lower()

    # Choose the defaults BEFORE applying them. Filling base_url and model
    # with the OpenAI defaults first would make the DashScope defaults
    # unreachable.
    if provider == "dashscope":
        default_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
        default_model = "text-embedding-v2"
    else:
        default_url = _DEFAULT_OPENAI_URL
        default_model = _DEFAULT_OPENAI_MODEL

    api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
    if not api_key and provider == "dashscope":
        api_key = os.environ.get("DASHSCOPE_API_KEY", "")
    base_url = base_url or os.environ.get("FRAGMENTED_EMBEDDER_URL", "") or default_url
    model = model or os.environ.get("FRAGMENTED_EMBEDDER_MODEL", "") or default_model

    cls = _EMBEDDER_PROVIDERS.get(provider, OpenAIEmbedder)
    return cls(api_key=api_key, base_url=base_url, model=model)
|