ultra-memory 4.0.0 → 4.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/scripts/recall.py CHANGED
@@ -1,8 +1,8 @@
 #!/usr/bin/env python3
 """
 ultra-memory: memory recall script
-Retrieves relevant content from the three memory layers
-Optimizations: synonym/alias mapping + time-decay weighting + context window (1 entry before/after)
+Retrieves relevant content from the five memory layers
+Optimizations: BM25/IDF scoring + field weighting + synonym expansion + time decay + context window
 """

 import os
@@ -10,8 +10,10 @@ import sys
 import json
 import argparse
 import re
+import math
 from datetime import datetime, timezone
 from pathlib import Path
+from collections import Counter

 if sys.stdout.encoding != "utf-8":
     sys.stdout.reconfigure(encoding="utf-8")
@@ -20,8 +22,8 @@ if sys.stderr.encoding != "utf-8":

 ULTRA_MEMORY_HOME = Path(os.environ.get("ULTRA_MEMORY_HOME", Path.home() / ".ultra-memory"))

-# Synonym/alias table: Chinese descriptions -> English function names / technical terms
-# Queries are expanded into synonym sets at recall time to improve cross-language precision
+# ── Synonym/alias mapping ────────────────────────────────────────────────
+
 SYNONYM_MAP = {
     # Data processing
     "数据清洗": ["clean", "clean_df", "preprocess", "cleaner", "清洗", "data_clean"],
@@ -57,62 +59,243 @@ SYNONYM_MAP = {
     "done": ["完成", "finished", "milestone"],
 }

-# Time-decay half-life (seconds): newer operations get higher weight
-TIME_HALF_LIFE_SECONDS = 3600 * 24  # 24-hour half-life
+# Weibull stretched-exponential decay parameters (closer to the human forgetting curve than plain exponential decay)
+WEIBULL_LAMBDA = 3600 * 24  # characteristic life: 24 hours
+WEIBULL_K = 0.75  # shape parameter; k < 1 means fast early decay and better long-term retention
+
+# BM25 parameters
+BM25_K1 = 1.5  # term-frequency saturation (keeps a repeated term from over-boosting)
+BM25_B = 0.75  # document-length normalization
+
+# Stopwords (skipped entirely during indexing and scoring)
+STOPWORD_TOKENS = {
+    "的", "了", "是", "在", "和", "与", "或", "以及", "把", "被", "用",
+    "the", "a", "an", "is", "was", "are", "were", "to", "of", "for",
+    "with", "by", "from", "that", "this", "it",
+}
+
+# Field weights (per-field boosts applied during scoring)
+FIELD_WEIGHTS = {
+    "summary": 1.0,
+    "title": 1.2,        # entry titles matter more
+    "name": 1.5,         # entity names matter most (function/file/class names)
+    "content": 0.8,      # knowledge-base content
+    "context": 0.6,      # entity context
+    "detail.path": 1.4,  # file paths
+    "detail.cmd": 1.0,   # bash commands
+    "tags": 0.7,         # tags carry less weight
+    "rationale": 1.1,    # decision rationale
+}
+
+# Operation-type weights (important operation types rank higher)
+OP_TYPE_WEIGHT = {
+    "milestone": 1.5,
+    "decision": 1.3,
+    "user_instruction": 1.2,
+    "error": 1.1,
+    "reasoning": 1.0,
+    "file_write": 0.9,
+    "bash_exec": 0.9,
+    "file_read": 0.8,
+    "tool_call": 0.8,
+}
+
+
+# ── Tokenization ─────────────────────────────────────────────────────────
+
+
+def tokenize(text: str) -> list[str]:
+    """Mixed Chinese/English tokenizer: English keeps whole words; Chinese yields unigrams (no bigrams, to avoid noise)"""
+    if not text:
+        return []
+    # English: keep full identifiers
+    words = re.findall(r'[a-zA-Z][a-zA-Z0-9_\-\.]*', text.lower())
+    # Chinese unigrams (each Han character is its own token)
+    chinese = re.findall(r'[\u4e00-\u9fff]', text)
+    return words + chinese
+
+
+def tokenize_set(text: str) -> set[str]:
+    """Return a deduplicated token set"""
+    return set(tokenize(text))
+
+
+# ── BM25 scoring ─────────────────────────────────────────────────────────
+

+class BM25Index:
+    """
+    In-memory BM25 index.
+    Maintains a token → count mapping per document, plus the average document length (avgdl).
+    """
+
+    def __init__(self, docs: list[dict]):
+        """
+        docs: list of {"id": ..., "text": ..., "tokens": [token_list]}
+        """
+        self.doc_count = len(docs)
+        self.doc_tokens: list[list[str]] = [d["tokens"] for d in docs]
+        self.doc_texts: list[str] = [d["text"] for d in docs]
+        self.doc_ids: list = [d["id"] for d in docs]
+
+        # Build the inverted index: token → {doc_idx: count}
+        self.term_to_docs: dict[str, dict[int, int]] = {}  # token → {doc_idx: count}
+        for doc_idx, tokens in enumerate(self.doc_tokens):
+            seen = set()
+            for t in tokens:
+                if t in STOPWORD_TOKENS:
+                    continue
+                if t not in self.term_to_docs:
+                    self.term_to_docs[t] = {}
+                if doc_idx not in self.term_to_docs[t]:
+                    self.term_to_docs[t][doc_idx] = 0
+                self.term_to_docs[t][doc_idx] += 1
+                seen.add(t)
+
+        # Average document length
+        self.avgdl = sum(len(t) for t in self.doc_tokens) / max(self.doc_count, 1)
+
+        # IDF: log((N - n + 0.5) / (n + 0.5) + 1)
+        self.idf: dict[str, float] = {}
+        for t, doc_map in self.term_to_docs.items():
+            n = len(doc_map)
+            self.idf[t] = math.log((self.doc_count - n + 0.5) / (n + 0.5) + 1)
+
+    def score(self, doc_idx: int, query_tokens: list[str]) -> float:
+        """Compute the BM25 score of a single document"""
+        tokens = self.doc_tokens[doc_idx]
+        doc_len = len(tokens)
+        score = 0.0
+
+        tf_map: dict[str, int] = {}
+        for t in tokens:
+            if t not in STOPWORD_TOKENS:
+                tf_map[t] = tf_map.get(t, 0) + 1
+
+        for t in query_tokens:
+            if t in STOPWORD_TOKENS:
+                continue
+            if t not in self.term_to_docs:
+                continue
+            tf = tf_map.get(t, 0)
+            idf = self.idf.get(t, 0)
+            # BM25 formula
+            numerator = tf * (BM25_K1 + 1)
+            denominator = tf + BM25_K1 * (1 - BM25_B + BM25_B * doc_len / self.avgdl)
+            score += idf * numerator / (denominator + 0.1)
+
+        return score
+
+    def search(self, query_tokens: list[str], top_k: int = 5) -> list[tuple[float, int]]:
+        """Return a [(score, doc_idx)] list in descending score order"""
+        scored = [(self.score(i, query_tokens), i) for i in range(self.doc_count)]
+        scored.sort(key=lambda x: -x[0])
+        return scored[:top_k]
+
+
+# ── Query expansion ──────────────────────────────────────────────────────

-def expand_query(query: str) -> set[str]:
-    """Expand the query terms into a synonym set"""
+
+def expand_query(query: str) -> list[str]:
+    """Expand the query terms with synonyms (returns a list rather than a set, so repeats keep their weight)"""
     tokens = tokenize(query)
-    expanded = set(tokens)
-    for token in list(tokens):
+    expanded_tokens = list(tokens)
+
+    # 1. Whole-phrase match: Chinese phrases contained in the query (e.g. "数据清洗") match first
+    for key in SYNONYM_MAP:
+        if len(key) > 1 and key in query:
+            expanded_tokens.append(key.lower())
+            expanded_tokens.extend(s.lower() for s in SYNONYM_MAP[key])
+
+    # 2. Bidirectional token match
+    for token in tokens:
+        token_lower = token.lower()
         for key, synonyms in SYNONYM_MAP.items():
-            if token == key.lower() or token in [s.lower() for s in synonyms]:
-                expanded.add(key.lower())
-                expanded.update(s.lower() for s in synonyms)
-    return expanded
+            synonyms_lower = [s.lower() for s in synonyms]
+            # the token hits a key (e.g. "数据清洗" or "preprocess")
+            if token_lower == key.lower() or token_lower in synonyms_lower:
+                expanded_tokens.append(key.lower())
+                expanded_tokens.extend(s.lower() for s in synonyms)

+    return expanded_tokens

-def tokenize(text: str) -> set[str]:
-    """Simple Chinese/English tokenizer (no external dependencies)"""
-    # English: split on whitespace and punctuation
-    words = re.findall(r'[a-zA-Z0-9_\-\.]+', text.lower())
-    # Chinese: unigrams + bigrams (bigrams improve phrase matching)
-    chinese = re.findall(r'[\u4e00-\u9fff]', text)
-    bigrams = [chinese[i] + chinese[i+1] for i in range(len(chinese)-1)]
-    return set(words + bigrams + chinese)
+
+# ── Time decay ───────────────────────────────────────────────────────────


 def time_weight(ts_str: str) -> float:
-    """
-    Compute the time-decay weight (exponential decay).
-    Newer operations weigh close to 1.0; an operation from 24 hours ago weighs about 0.5
+    """Weibull stretched-exponential decay: weight = exp(-(age/λ)^k)
+    k = 0.75 < 1: decays quickly within the first ~24 hours, then flattens out, so long-lived memories are retained better.
+    Compared with plain exponential decay (k = 1), the 3-day weight roughly doubles, from ≈0.05 to ≈0.10
     """
     try:
         ts = datetime.fromisoformat(ts_str.rstrip("Z")).replace(tzinfo=timezone.utc)
         now = datetime.now(timezone.utc)
         age_seconds = (now - ts).total_seconds()
-        # Exponential decay: weight = 0.5^(age / half_life)
-        import math
-        weight = math.pow(0.5, age_seconds / TIME_HALF_LIFE_SECONDS)
-        # Floor weight of 0.1 so old memories never vanish entirely
+        weight = math.exp(-math.pow(age_seconds / WEIBULL_LAMBDA, WEIBULL_K))
        return max(0.1, weight)
     except Exception:
         return 0.5


-def score_relevance(query_tokens: set, text: str, ts_str: str = "") -> float:
+# ── Core scoring functions ───────────────────────────────────────────────
+
+
+def score_text(
+    query_tokens: list[str],
+    text: str,
+    field_weight: float = 1.0,
+    ts_str: str = "",
+) -> float:
     """
-    Keyword-overlap relevance score × time weight.
-    Synonym-expanded tokens take part in the matching.
+    BM25 score × field weight × time weight.
     """
+    if not text or not query_tokens:
+        return 0.0
+
     text_tokens = tokenize(text)
-    if not query_tokens or not text_tokens:
+    if not text_tokens:
         return 0.0
-    overlap = len(query_tokens & text_tokens)
-    base_score = overlap / max(len(query_tokens), 1)
-    tw = time_weight(ts_str) if ts_str else 1.0
-    return base_score * (0.7 + 0.3 * tw)  # the time weight contributes 30%
+
+    # Build a single-document BM25 index
+    doc = {"id": 0, "text": text, "tokens": text_tokens}
+    idx = BM25Index([doc])
+
+    bm25_score = idx.score(0, query_tokens)
+
+    # Field weight
+    field_boost = field_weight
+
+    # Time weight
+    tw = time_weight(ts_str) if ts_str else 0.85  # default 0.85 when there is no timestamp
+
+    # Final score = BM25 × field weight × time weight
+    return bm25_score * field_boost * tw
+
+
+def score_text_with_match_boost(
+    query_tokens: list[str],
+    text: str,
+    field_weight: float = 1.0,
+    ts_str: str = "",
+    exact_phrase_bonus: float = 0.0,
+) -> float:
+    """
+    BM25 score + exact-phrase match bonus + field weight + time weight.
+    """
+    base = score_text(query_tokens, text, field_weight, ts_str)
+
+    if exact_phrase_bonus > 0:
+        # Bonus when a query term shows up near the start of the text
+        text_lower = text.lower()
+        for qt in query_tokens:
+            if len(qt) > 1 and qt in text_lower:
+                pos = text_lower.find(qt)
+                if pos < 20:  # appears within the first 20 characters
+                    base += exact_phrase_bonus * field_weight
+                    break
+
+    return base


 def load_all_ops(session_dir: Path) -> list[dict]:
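
Editor's note: the decay figures in the time_weight docstring above can be checked directly. Here is a minimal standalone sketch (not part of the package) comparing the 4.1.0 Weibull curve against plain exponential decay (k = 1) and the 4.0.0 half-life rule, all with the same λ = 24 h:

import math

LAM = 3600 * 24   # mirrors WEIBULL_LAMBDA
K = 0.75          # mirrors WEIBULL_K

for days in (0.5, 1, 3, 7):
    age = days * 86400
    weibull = math.exp(-((age / LAM) ** K))   # 4.1.0 curve
    k1 = math.exp(-(age / LAM))               # same formula with k = 1
    halflife = 0.5 ** (age / LAM)             # 4.0.0 behaviour
    print(f"{days:>4}d  weibull={weibull:.3f}  k1={k1:.3f}  halflife={halflife:.3f}")

# Prints approximately: 0.5d 0.552/0.607/0.707, 1d 0.368/0.368/0.500,
# 3d 0.102/0.050/0.125, 7d 0.014/0.001/0.008; past roughly 3 days the
# 0.1 floor applied in time_weight() dominates all three curves.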
@@ -148,16 +331,33 @@ def get_context_window(all_ops: list[dict], target_seq: int, window: int = 1) ->
     return {"before": before, "after": after}


-def search_ops(session_dir: Path, query_tokens: set, top_k: int) -> list[dict]:
+def search_ops(session_dir: Path, query_tokens: list[str], top_k: int) -> list[dict]:
     """Search the operation log, with time weighting and a context window"""
     all_ops = load_all_ops(session_dir)
     if not all_ops:
         return []

-    results = []
+    op_type_weight = {k.lower(): v for k, v in OP_TYPE_WEIGHT.items()}
+
+    # Build a corpus-level BM25 index (built once; IDF is based on the full corpus; O(n) rather than O(n²))
+    texts = []
     for op in all_ops:
-        text = op.get("summary", "") + " " + json.dumps(op.get("detail", {}), ensure_ascii=False)
-        score = score_relevance(query_tokens, text, op.get("ts", ""))
+        summary = op.get("summary", "")
+        detail_text = json.dumps(op.get("detail", {}), ensure_ascii=False)
+        texts.append(summary + " " + detail_text)
+
+    corpus_docs = [{"id": i, "text": t, "tokens": tokenize(t)} for i, t in enumerate(texts)]
+    corpus_index = BM25Index(corpus_docs)
+
+    results = []
+    for i, op in enumerate(all_ops):
+        ts = op.get("ts", "")
+        bm25_score = corpus_index.score(i, query_tokens)
+        tw = time_weight(ts) if ts else 0.85
+        op_type = op.get("type", "").lower()
+        type_mult = op_type_weight.get(op_type, 0.8)
+        score = bm25_score * tw * type_mult
+
         if score > 0:
             ctx = get_context_window(all_ops, op["seq"], window=1)
             results.append({
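
Editor's note: a self-contained toy run of the BM25Index class introduced above (editorial sketch; assumes the class definition is in scope). The rarer term carries more IDF, so the document matching both query terms outranks the one matching only the common term:

docs = [
    {"id": 0, "text": "clean_df preprocess csv", "tokens": ["clean_df", "preprocess", "csv"]},
    {"id": 1, "text": "train model epochs", "tokens": ["train", "model", "epochs"]},
    {"id": 2, "text": "clean_df drop nulls", "tokens": ["clean_df", "drop", "nulls"]},
]
idx = BM25Index(docs)
# "preprocess" occurs in 1 of 3 docs (idf ≈ 0.98); "clean_df" in 2 of 3 (idf ≈ 0.47)
print(idx.search(["clean_df", "preprocess"], top_k=3))
# -> [(≈1.40, 0), (≈0.45, 2), (0.0, 1)]: doc 0 hits both terms, doc 1 hits neither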
@@ -171,7 +371,7 @@ def search_ops(session_dir: Path, query_tokens: set, top_k: int) -> list[dict]:
     return results[:top_k]


-def search_summary(session_dir: Path, query_tokens: set) -> list[dict]:
+def search_summary(session_dir: Path, query_tokens: list[str]) -> list[dict]:
     """Search the summary file"""
     summary_file = session_dir / "summary.md"
     if not summary_file.exists():
@@ -181,14 +381,14 @@ def search_summary(session_dir: Path, query_tokens: set) -> list[dict]:
     paragraphs = [p.strip() for p in content.split("\n") if p.strip() and not p.startswith("#")]
     results = []
     for para in paragraphs:
-        score = score_relevance(query_tokens, para)
+        score = score_text(query_tokens, para, field_weight=1.0, ts_str="")
         if score > 0.1:
             results.append({"score": score, "source": "summary", "text": para})
     results.sort(key=lambda x: -x["score"])
     return results[:3]


-def search_entities(query_tokens: set, top_k: int) -> list[dict]:
+def search_entities(query_tokens: list[str], top_k: int) -> list[dict]:
     """
     Layer 4: entity-index search (structured, exact retrieval).
     Good for answering:
@@ -214,7 +414,8 @@ def search_entities(query_tokens: set, top_k: int) -> list[dict]:

     # Detect whether the query contains an entity-type word (exact type filtering)
     target_type = None
-    for token in query_tokens:
+    query_token_set = set(query_tokens)
+    for token in query_token_set:
         if token in TYPE_ALIASES:
             target_type = TYPE_ALIASES[token]
             break
@@ -244,12 +445,16 @@ def search_entities(query_tokens: set, top_k: int) -> list[dict]:
         name = ent.get("name", "")
         context = ent.get("context", "")
         ent_text = name + " " + context
+        ts = ent.get("ts", "")

-        score = score_relevance(query_tokens, ent_text, ent.get("ts", ""))
+        # Entity name field weight 1.5; context field weight 0.6
+        name_score = score_text(query_tokens, name, field_weight=FIELD_WEIGHTS["name"], ts_str=ts)
+        ctx_score = score_text(query_tokens, context, field_weight=FIELD_WEIGHTS["context"], ts_str=ts)
+        score = max(name_score, ctx_score)

         # Exact entity-name matches get an extra boost
         name_tokens = tokenize(name)
-        exact_match = bool(query_tokens & name_tokens)
+        exact_match = bool(query_token_set & set(name_tokens))
         if exact_match:
             score = max(score, 0.5)  # floor of 0.5

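Editor's note: a quick sketch of how the field weights bias ranking here (editorial example; assumes tokenize, score_text, and FIELD_WEIGHTS from this file are in scope). An entity whose name matches the query outscores one that only mentions the term in its context, since the BM25 components are comparable and 1.5 > 0.6:

q = tokenize("clean_df")
name_hit = score_text(q, "clean_df", field_weight=FIELD_WEIGHTS["name"])      # ≈ 0.35
ctx_hit = score_text(q, "helper that wraps clean_df", field_weight=FIELD_WEIGHTS["context"])  # ≈ 0.14
assert name_hit > ctx_hit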
@@ -271,12 +476,109 @@ def search_entities(query_tokens: set, top_k: int) -> list[dict]:
     return results[:top_k]


-def search_semantic(query_tokens: set, top_k: int) -> list[dict]:
-    """Search Layer 3, the semantic layer (lightweight mode: keyword matching + synonym expansion)"""
+def search_entity_history(entity_name: str, home: Path) -> list[dict]:
+    """
+    Look up every version of an entity by name (including superseded ones), newest first.
+    Modeled on the "supermemory history <entity>" entity-timeline feature.
+    """
+    entities_file = home / "semantic" / "entities.jsonl"
+    if not entities_file.exists():
+        return []
+
+    name_lower = entity_name.lower()
+    versions = []
+
+    with open(entities_file, encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            try:
+                ent = json.loads(line)
+            except json.JSONDecodeError:
+                continue
+
+            # Name match (case-insensitive)
+            if ent.get("name", "").lower() != name_lower:
+                continue
+
+            # Mark whether this is the currently active version
+            is_active = not ent.get("superseded", False)
+            ts = ent.get("ts", "")
+
+            versions.append({
+                "ts": ts,
+                "is_active": is_active,
+                "entity_type": ent.get("entity_type", "?"),
+                "name": ent.get("name", "?"),
+                "context": ent.get("context", ""),
+                "superseded_at": ent.get("superseded_at", ""),
+            })
+
+    # Newest first
+    versions.sort(key=lambda v: v["ts"], reverse=True)
+    return versions
+
+
+def format_entity_history(versions: list[dict], entity_name: str) -> str:
+    """Format the entity version history for output"""
+    if not versions:
+        return f"[entity history] Entity not found: {entity_name}"
+
+    lines = [f"[entity history] {entity_name} — {len(versions)} version(s)\n"]
+    for i, v in enumerate(versions):
+        status = "✅ active" if v["is_active"] else "❌ superseded"
+        ts = v["ts"][:16].replace("T", " ") if v["ts"] else "unknown time"
+        superseded_note = f" (superseded at {v['superseded_at'][:16].replace('T', ' ')})" if v["superseded_at"] else ""
+        lines.append(f" v{i+1} · {ts} · {status}{superseded_note}")
+        lines.append(f" type: {v['entity_type']} | context: {v['context'][:60]}")
+    return "\n".join(lines)
+
+
+def search_semantic(query_tokens: list[str], top_k: int, as_of: str = "") -> list[dict]:
+    """
+    Search Layer 3, the semantic layer (lightweight mode: keyword matching + synonym expansion).
+    as_of: ISO timestamp string; queries the knowledge state at that point in time (time travel).
+    Entries created after it are skipped; superseded entries still valid at as_of are returned as history versions.
+    """
     semantic_dir = ULTRA_MEMORY_HOME / "semantic"
     kb_file = semantic_dir / "knowledge_base.jsonl"
     index_file = semantic_dir / "session_index.json"

+    # Parse the as_of point in time
+    as_of_dt: datetime | None = None
+    if as_of:
+        try:
+            as_of_dt = datetime.fromisoformat(as_of.replace("Z", "+00:00"))
+        except ValueError:
+            as_of_dt = None
+
+    def _entry_in_range(entry: dict) -> bool:
+        """Whether the entry already existed at the as_of point"""
+        if as_of_dt is None:
+            return True
+        ts_str = entry.get("ts", "")
+        if not ts_str:
+            return True
+        try:
+            entry_dt = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
+            return entry_dt <= as_of_dt
+        except ValueError:
+            return True
+
+    def _superseded_at_the_time(entry: dict) -> bool:
+        """Whether a superseded entry was still valid at the as_of point"""
+        if as_of_dt is None:
+            return False
+        superseded_at = entry.get("superseded_at", "")
+        if not superseded_at:
+            return True  # no superseded_at marker, so we cannot tell; conservatively treat it as still valid
+        try:
+            superseded_dt = datetime.fromisoformat(superseded_at.replace("Z", "+00:00"))
+            return superseded_dt > as_of_dt  # superseded later than as_of, so still alive at as_of
+        except ValueError:
+            return False
+

     results = []

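Editor's note: the two as_of predicates above compose into a single visibility rule. A standalone rephrasing for illustration only (not the package's code):

from datetime import datetime, timezone

def visible_at(entry: dict, as_of: datetime) -> bool:
    # An entry is visible at as_of if it already existed then and, when
    # superseded, was only superseded after as_of.
    created = datetime.fromisoformat(entry["ts"].replace("Z", "+00:00"))
    if created > as_of:
        return False          # not yet created
    if not entry.get("superseded"):
        return True           # still the live version
    sup = entry.get("superseded_at", "")
    if not sup:
        return True           # no marker: conservatively kept, as in the original
    return datetime.fromisoformat(sup.replace("Z", "+00:00")) > as_of

as_of = datetime(2026, 3, 1, tzinfo=timezone.utc)
entry = {"ts": "2026-02-20T00:00:00Z", "superseded": True,
         "superseded_at": "2026-03-10T00:00:00Z"}
print(visible_at(entry, as_of))   # True: superseded only after as_of, so shown as a history version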
@@ -289,12 +591,25 @@ def search_semantic(query_tokens: set, top_k: int) -> list[dict]:
                     entry = json.loads(line)
                 except json.JSONDecodeError:
                     continue
-                # Filter out superseded entries
-                if entry.get("superseded"):
+
+                # Time-travel filtering
+                if not _entry_in_range(entry):
                     continue
-                text = entry.get("content", "") + " " + entry.get("title", "")
+
+                # Superseded entries: only those still valid before as_of are returned
+                if entry.get("superseded"):
+                    if not _superseded_at_the_time(entry):
+                        continue  # already superseded at as_of; skip
+                    # Still valid then; mark it as a history version
+                    entry = dict(entry)
+                    entry["_history"] = True
+
+                title = entry.get("title", "")
+                content = entry.get("content", "")
                 ts = entry.get("ts", "")
-                score = score_relevance(query_tokens, text, ts)
+                title_score = score_text(query_tokens, title, field_weight=FIELD_WEIGHTS["title"], ts_str=ts)
+                content_score = score_text(query_tokens, content, field_weight=FIELD_WEIGHTS["content"], ts_str=ts)
+                score = max(title_score, content_score)
                 if score > 0.1:
                     results.append({"score": score, "source": "knowledge_base", "data": entry})

@@ -302,9 +617,12 @@ def search_semantic(query_tokens: set, top_k: int) -> list[dict]:
         with open(index_file, encoding="utf-8") as f:
             index = json.load(f)
             for s in index.get("sessions", []):
-                text = s.get("project", "") + " " + (s.get("last_milestone") or "")
+                project = s.get("project", "")
+                milestone = s.get("last_milestone") or ""
                 ts = s.get("started_at", "")
-                score = score_relevance(query_tokens, text, ts)
+                project_score = score_text(query_tokens, project, field_weight=1.0, ts_str=ts)
+                milestone_score = score_text(query_tokens, milestone, field_weight=1.2, ts_str=ts)
+                score = max(project_score, milestone_score)
                 if score > 0.1:
                     results.append({"score": score, "source": "history", "data": s})

@@ -312,7 +630,7 @@ def search_semantic(query_tokens: set, top_k: int) -> list[dict]:
     return results[:top_k]


-def search_profile(query_tokens: set, home: Path) -> list[dict]:
+def search_profile(query_tokens: list[str], home: Path) -> list[dict]:
     """Retrieve relevant fields from user_profile.json, skipping superseded fields"""
     profile_file = home / "semantic" / "user_profile.json"
     if not profile_file.exists():
@@ -330,7 +648,7 @@ def search_profile(query_tokens: set, home: Path) -> list[dict]:
         if key.endswith("_superseded"):
             continue
         text = f"{key} {value}"
-        score = score_relevance(query_tokens, str(text))
+        score = score_text(query_tokens, str(text), field_weight=FIELD_WEIGHTS["name"])
         if score > 0.1:
             results.append({
                 "score": score,
@@ -561,8 +879,9 @@ def _search_sentencetransformers(

     texts = [_text_from_op(op) for op in all_ops]

+    model = SentenceTransformer("all-MiniLM-L6-v2")  # load only once
+
     if cache is None:
-        model = SentenceTransformer("all-MiniLM-L6-v2")
         embeddings = model.encode(texts, show_progress_bar=False).tolist()
         current_seq = max((op.get("seq", 0) for op in all_ops), default=0)
         cache = {"embeddings": embeddings, "last_seq": current_seq}
@@ -572,8 +891,7 @@ def _search_sentencetransformers(
         except Exception:
             pass

-    # Vectorize the query
-    model = SentenceTransformer("all-MiniLM-L6-v2")
+    # Vectorize the query (reusing the model loaded above)
     query_emb = model.encode([query], show_progress_bar=False)[0].tolist()

     embeddings = cache["embeddings"]
@@ -607,16 +925,163 @@ def search_tfidf(session_dir: Path, all_ops: list[dict],
     return []


+# ── RRF fusion ───────────────────────────────────────────────────────────
+
+
+def _get_doc_id(result: dict) -> str:
+    """Generate a unique ID for a retrieval result (used by RRF for cross-layer dedup and merging)"""
+    source = result.get("source", "")
+    data = result.get("data", {})
+    if source in ("ops", "tfidf", "embedding"):
+        return f"op:{data.get('seq', id(data))}"
+    elif source == "summary":
+        return f"sum:{hash(result.get('text', '')[:80])}"
+    elif source == "knowledge_base":
+        return f"kb:{data.get('title', '')[:40]}"
+    elif source == "entity":
+        return f"ent:{data.get('entity_type', '')}:{data.get('name', '')}"
+    elif source == "history":
+        return f"hist:{data.get('session_id', '')}"
+    elif source == "profile":
+        return f"prof:{data.get('field', '')}"
+    return f"other:{hash(str(data)[:80])}"
+
+
+def rrf_merge(ranked_lists: list[list[dict]], k: int = 60) -> list[dict]:
+    """
+    Reciprocal Rank Fusion (Cormack et al., 2009):
+    merges several independently ranked result lists; each result scores Σ 1/(k + rank_i).
+    k = 60 is the standard setting and keeps top ranks from dominating.
+    Better than naive score merging: different sources (BM25/TF-IDF/vector) score on different scales,
+    while RRF depends only on rank positions, which naturally sidesteps the scale mismatch.
+    A document appearing in several lists accumulates score: a multi-source consistency bonus.
+    """
+    rrf_scores: dict[str, float] = {}
+    best_item: dict[str, dict] = {}
+
+    for ranked in ranked_lists:
+        for rank, item in enumerate(ranked):
+            doc_id = _get_doc_id(item)
+            rrf_scores[doc_id] = rrf_scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
+            if doc_id not in best_item:
+                best_item[doc_id] = item
+
+    merged = sorted(best_item.values(), key=lambda x: -rrf_scores[_get_doc_id(x)])
+    for item in merged:
+        item["score"] = rrf_scores[_get_doc_id(item)]
+    return merged
+
+
+# ── Snippet extraction ───────────────────────────────────────────────────
+
+
+def extract_snippet(text: str, query_tokens: list[str], max_len: int = 150) -> str:
+    """
+    Cut the snippet of a long text most relevant to the query (saves tokens, shows the right part).
+    Algorithm: find the position where query-token hits are densest and cut a window centred there.
+    """
+    if not text or len(text) <= max_len:
+        return text
+
+    text_lower = text.lower()
+    best_pos, best_score = 0, 0
+
+    for token in query_tokens:
+        if len(token) < 2:
+            continue
+        idx = text_lower.find(token)
+        if idx < 0:
+            continue
+        win_s = max(0, idx - 50)
+        win_e = min(len(text), idx + 100)
+        window = text_lower[win_s:win_e]
+        score = sum(1 for t in query_tokens if len(t) >= 2 and t in window)
+        if score > best_score:
+            best_score, best_pos = score, idx
+
+    start = max(0, best_pos - 50)
+    end = min(len(text), start + max_len)
+    snippet = text[start:end].strip()
+    return ("…" if start > 0 else "") + snippet + ("…" if end < len(text) else "")
+
+
+# ── Local cross-encoder re-ranking ───────────────────────────────────────
+
+_cross_encoder_instance = None  # lazily loaded singleton
+
+
+def _get_cross_encoder():
+    """Lazily load the local CrossEncoder (the first call downloads a ~80 MB model, cached locally afterwards).
+    Model: cross-encoder/ms-marco-MiniLM-L-6-v2 (MIT-licensed, runs fully locally, zero API calls)
+    If sentence-transformers is not installed, silently returns None and the RRF results are used as-is.
+    """
+    global _cross_encoder_instance
+    if _cross_encoder_instance is not None:
+        return _cross_encoder_instance
+    try:
+        from sentence_transformers import CrossEncoder
+        _cross_encoder_instance = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+    except Exception:
+        _cross_encoder_instance = None
+    return _cross_encoder_instance
+
+
+def _result_to_plain_text(result: dict) -> str:
+    """Serialize a retrieval result to plain text (input for cross-encoder scoring)"""
+    source = result.get("source", "")
+    data = result.get("data", {})
+    if source in ("ops", "tfidf", "embedding"):
+        detail_str = json.dumps(data.get("detail", {}), ensure_ascii=False)[:200]
+        return data.get("summary", "") + " " + detail_str
+    elif source == "summary":
+        return result.get("text", "")
+    elif source == "knowledge_base":
+        return f"{data.get('title', '')} {data.get('content', '')}"
+    elif source == "entity":
+        return f"{data.get('name', '')} {data.get('context', '')}"
+    elif source == "profile":
+        return f"{data.get('field', '')}: {data.get('value', '')}"
+    return str(data)[:300]
+
+
+def local_cross_encode(query: str, results: list[dict], top_k: int) -> list[dict]:
+    """
+    Local cross-encoder re-ranking (fully private, no API calls).
+    Re-ranks the top_k*3 candidates after the initial RRF pass, improving accuracy by roughly 5-8%.
+    Enabled only when sentence-transformers is installed; otherwise the RRF results are returned as-is.
+    """
+    if not results:
+        return results
+    model = _get_cross_encoder()
+    if model is None:
+        return results[:top_k]
+
+    candidates = results[:top_k * 3]
+    pairs = [(query, _result_to_plain_text(r)) for r in candidates]
+    try:
+        scores = model.predict(pairs, show_progress_bar=False)
+        for r, s in zip(candidates, scores):
+            r["cross_score"] = float(s)
+        candidates = sorted(candidates, key=lambda x: -x.get("cross_score", 0.0))
+    except Exception:
+        pass
+    return candidates[:top_k]
+
+
 # ── Result formatting ────────────────────────────────────────────────────

-def format_result(result: dict, show_context: bool = True) -> str:
+def format_result(result: dict, show_context: bool = True, query_tokens: list[str] | None = None) -> str:
     source = result["source"]
     lines = []

     if source == "ops":
-        op = result["data"]
-        ts = op["ts"][:16].replace("T", " ")
-        lines.append(f"[ops #{op['seq']} · {ts}] {op['summary']}")
+        op = result["data"]
+        ts = op["ts"][:16].replace("T", " ")
+        summary = op["summary"]
+        if query_tokens and len(summary) > 80:
+            summary = extract_snippet(summary, query_tokens, max_len=120)
+        tier_tag = f" [{op.get('tier', '')}]" if op.get("tier") else ""
+        lines.append(f"[ops #{op['seq']} · {ts}{tier_tag}] {summary}")
         # Show the context window
         if show_context and result.get("context"):
             ctx = result["context"]
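
Editor's note: the consistency bonus in rrf_merge is easy to verify by hand. With k = 60, a document ranked in two lists beats one ranked first in a single list (editorial sketch, not part of the package):

k = 60
list_bm25 = ["A", "B"]     # ranking from the BM25 layer
list_vector = ["B", "C"]   # ranking from the vector layer
scores: dict[str, float] = {}
for ranked in (list_bm25, list_vector):
    for rank, doc in enumerate(ranked):
        scores[doc] = scores.get(doc, 0.0) + 1.0 / (k + rank + 1)
print(sorted(scores.items(), key=lambda kv: -kv[1]))
# B: 1/62 + 1/61 ≈ 0.0325 beats A: 1/61 ≈ 0.0164 and C: 1/62 ≈ 0.0161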
@@ -627,8 +1092,12 @@ def format_result(result: dict, show_context: bool = True) -> str:
     elif source == "summary":
         lines.append(f"[summary] {result['text']}")
     elif source == "knowledge_base":
-        d = result["data"]
-        lines.append(f"[knowledge base · {d.get('title', '?')}] {d.get('content', '')[:100]}")
+        d = result["data"]
+        content = d.get("content", "")
+        if query_tokens and len(content) > 100:
+            content = extract_snippet(content, query_tokens, max_len=150)
+        history_tag = " [history version]" if d.get("_history") else ""
+        lines.append(f"[knowledge base{history_tag} · {d.get('title', '?')}] {content}")
     elif source == "history":
         d = result["data"]
         ts = d.get("started_at", "")[:10]
@@ -652,14 +1121,13 @@ def format_result(result: dict, show_context: bool = True) -> str:
             lines.append(f" source: {ctx}")

     elif source in ("tfidf", "embedding"):
-        d = result["data"]
-        ts = d.get("ts", "")[:16].replace("T", " ")
-        label = "TF-IDF" if source == "tfidf" else "vector"
-        lines.append(f"[semantic/{label} #{d.get('seq', '?')} · {ts}] {d.get('summary', '?')[:80]}")
-        detail = d.get("detail", {})
-        if isinstance(detail, dict):
-            for k, v in list(detail.items())[:2]:
-                lines.append(f" [{k}] {str(v)[:60]}")
+        d = result["data"]
+        ts = d.get("ts", "")[:16].replace("T", " ")
+        label = "TF-IDF" if source == "tfidf" else "vector"
+        summary = d.get("summary", "?")
+        if query_tokens and len(summary) > 80:
+            summary = extract_snippet(summary, query_tokens, max_len=120)
+        lines.append(f"[semantic/{label} #{d.get('seq', '?')} · {ts}] {summary}")


     elif source == "profile":
@@ -668,50 +1136,66 @@ def format_result(result: dict, show_context: bool = True) -> str:
     return "\n".join(lines) if lines else str(result)


-def recall(session_id: str, query: str, top_k: int = 5):
-    # Expand the query terms (adding synonyms)
+def recall(session_id: str, query: str, top_k: int = 5, as_of: str = ""):
+    """
+    Recall memories.
+
+    as_of: ISO timestamp string; enables time-travel mode.
+        Queries the knowledge state at the given point in time, ignoring records created or updated afterwards.
+        Example: --as-of 2026-03-01T00:00:00Z
+    """
     query_tokens = expand_query(query)
+    session_dir = ULTRA_MEMORY_HOME / "sessions" / session_id

-    session_dir = ULTRA_MEMORY_HOME / "sessions" / session_id
-    found = []
+    # Collect results per layer (each keeps its own ranking; RRF fuses them)
+    all_layers: list[list[dict]] = []

-    # Layer 1: operation log (with time weighting + context window)
-    ops_results = search_ops(session_dir, query_tokens, top_k)
-    found.extend(ops_results)
+    ops_results = search_ops(session_dir, query_tokens, top_k * 3)
+    if ops_results:
+        all_layers.append(ops_results)

-    # Layer 2: summary
     summary_results = search_summary(session_dir, query_tokens)
-    found.extend(summary_results)
+    if summary_results:
+        all_layers.append(summary_results)

-    # Layer 3: semantic layer (cross-session)
-    semantic_results = search_semantic(query_tokens, top_k)
-    found.extend(semantic_results)
+    semantic_results = search_semantic(query_tokens, top_k * 2, as_of=as_of)
+    if semantic_results:
+        all_layers.append(semantic_results)

-    # Profile search (relevant fields from user_profile.json)
     profile_results = search_profile(query_tokens, ULTRA_MEMORY_HOME)
-    found.extend(profile_results)
+    if profile_results:
+        all_layers.append(profile_results)

-    # Layer 4: entity index (structured, exact retrieval)
-    entity_results = search_entities(query_tokens, top_k)
-    found.extend(entity_results)
+    entity_results = search_entities(query_tokens, top_k * 2)
+    if entity_results:
+        all_layers.append(entity_results)

-    # Layer 5: vector semantic search (TF-IDF or sentence-transformers)
-    ops_for_tfidf = load_all_ops(session_dir)
-    if ops_for_tfidf:
-        tfidf_results = search_tfidf(session_dir, ops_for_tfidf, query, top_k)
-        found.extend(tfidf_results)
+    ops_all = load_all_ops(session_dir)
+    if ops_all:
+        vector_results = search_tfidf(session_dir, ops_all, query, top_k * 2)
+        if vector_results:
+            all_layers.append(vector_results)

-    # Dedup + sort
-    found.sort(key=lambda x: -x["score"])
-    found = found[:top_k]
+    if not all_layers:
+        print(f"[RECALL] No memories related to '{query}' were found")
+        if as_of:
+            print(f"[RECALL] time-travel mode: {as_of}")
+        return
+
+    # RRF fusion: cross-layer dedup + multi-source consistency weighting (replaces naive score merging)
+    found = rrf_merge(all_layers)
+
+    # Optional re-ranking: local cross-encoder (needs sentence-transformers; fully offline)
+    found = local_cross_encode(query, found, top_k)

     if not found:
         print(f"[RECALL] No memories related to '{query}' were found")
         return

-    print(f"\n[RECALL] Found {len(found)} relevant records (query: {query}):\n")
+    time_travel_note = f" [time travel: {as_of}]" if as_of else ""
+    print(f"\n[RECALL] Found {len(found)} relevant records (query: {query}){time_travel_note}:\n")
     for i, r in enumerate(found, 1):
-        print(f"{i}. {format_result(r, show_context=True)}")
+        print(f"{i}. {format_result(r, show_context=True, query_tokens=query_tokens)}")
         print()

@@ -720,5 +1204,12 @@ if __name__ == "__main__":
     parser.add_argument("--session", required=True, help="session ID")
     parser.add_argument("--query", required=True, help="query keywords")
     parser.add_argument("--top-k", type=int, default=5)
+    parser.add_argument("--as-of", default="", help="time travel: query the knowledge state at this point in time (ISO format)")
+    parser.add_argument("--history", default="", help="show the version history of the named entity")
     args = parser.parse_args()
-    recall(args.session, args.query, args.top_k)
+
+    if args.history:
+        versions = search_entity_history(args.history, ULTRA_MEMORY_HOME)
+        print(format_entity_history(versions, args.history))
+    else:
+        recall(args.session, args.query, args.top_k, as_of=args.as_of)
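
Editor's note: the new flags compose as follows (hypothetical session ID and timestamp; note that --session and --query remain required even in --history mode, since both are declared with required=True):

python scripts/recall.py --session s42 --query "数据清洗"
python scripts/recall.py --session s42 --query "数据清洗" --as-of 2026-03-01T00:00:00Z
python scripts/recall.py --session s42 --query x --history clean_df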