chainmem 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chainmem/__init__.py +67 -0
- chainmem/__main__.py +5 -0
- chainmem/cli/app.py +398 -0
- chainmem/core/__init__.py +67 -0
- chainmem/core/node.py +5 -0
- chainmem/pipeline/ingester.py +149 -0
- chainmem/pipeline/retriever.py +311 -0
- chainmem/store/sqlite_store.py +145 -0
- chainmem-0.3.0.dist-info/LICENSE +21 -0
- chainmem-0.3.0.dist-info/METADATA +322 -0
- chainmem-0.3.0.dist-info/RECORD +14 -0
- chainmem-0.3.0.dist-info/WHEEL +5 -0
- chainmem-0.3.0.dist-info/entry_points.txt +2 -0
- chainmem-0.3.0.dist-info/top_level.txt +1 -0
chainmem/__init__.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""ChainMem — 链式 + 向量混合记忆系统"""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.0"
|
|
4
|
+
|
|
5
|
+
from chainmem.core.node import ChainNode, Chain
|
|
6
|
+
from chainmem.store.sqlite_store import SQLiteStore
|
|
7
|
+
from chainmem.pipeline.ingester import Ingester
|
|
8
|
+
from chainmem.pipeline.retriever import Retriever
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ChainMemory:
|
|
12
|
+
"""ChainMem 主入口类"""
|
|
13
|
+
|
|
14
|
+
def __init__(self, db_path: str = "~/.chainmem/data.db"):
|
|
15
|
+
self.db_path = db_path
|
|
16
|
+
self.store: SQLiteStore | None = None
|
|
17
|
+
self.ingester: Ingester | None = None
|
|
18
|
+
self.retriever: Retriever | None = None
|
|
19
|
+
|
|
20
|
+
def open(self):
|
|
21
|
+
"""打开数据库,加载索引"""
|
|
22
|
+
import os
|
|
23
|
+
path = os.path.expanduser(self.db_path)
|
|
24
|
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
|
25
|
+
|
|
26
|
+
self.store = SQLiteStore(path)
|
|
27
|
+
self.store.initialize()
|
|
28
|
+
self.ingester = Ingester(self.store)
|
|
29
|
+
self.retriever = Retriever(self.store)
|
|
30
|
+
return self
|
|
31
|
+
|
|
32
|
+
def close(self):
|
|
33
|
+
if self.store:
|
|
34
|
+
self.store.close()
|
|
35
|
+
|
|
36
|
+
def ingest(self, text: str, source: str = "", tags: list[str] | None = None) -> Chain:
|
|
37
|
+
"""结链:文本 → 切块 → 嵌入 → 存储"""
|
|
38
|
+
if not self.ingester:
|
|
39
|
+
raise RuntimeError("Call .open() first")
|
|
40
|
+
return self.ingester.ingest(text, source=source, tags=tags or [])
|
|
41
|
+
|
|
42
|
+
def set_model(self, model_name: str):
|
|
43
|
+
"""切换嵌入模型"""
|
|
44
|
+
from chainmem.pipeline.ingester import set_model as _set
|
|
45
|
+
_set(model_name)
|
|
46
|
+
# 重建索引使新模型生效
|
|
47
|
+
if self.retriever:
|
|
48
|
+
self.retriever.rebuild_index()
|
|
49
|
+
return self
|
|
50
|
+
|
|
51
|
+
def retrieve(self, query: str, max_steps: int = 100,
|
|
52
|
+
tags: list[str] | None = None) -> list[str]:
|
|
53
|
+
"""追溯:查询 → 最近邻 → 指针遍历"""
|
|
54
|
+
if not self.retriever:
|
|
55
|
+
raise RuntimeError("Call .open() first")
|
|
56
|
+
return self.retriever.retrieve(query, max_steps=max_steps, tags=tags)
|
|
57
|
+
|
|
58
|
+
def stats(self) -> dict:
|
|
59
|
+
if not self.store:
|
|
60
|
+
raise RuntimeError("Call .open() first")
|
|
61
|
+
return self.store.stats()
|
|
62
|
+
|
|
63
|
+
def __enter__(self):
|
|
64
|
+
return self.open()
|
|
65
|
+
|
|
66
|
+
def __exit__(self, *args):
|
|
67
|
+
self.close()
|
chainmem/__main__.py
ADDED
chainmem/cli/app.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
"""ChainMem CLI"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import sys
|
|
5
|
+
import traceback
|
|
6
|
+
import asyncio
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import typer
|
|
11
|
+
from rich import print as rprint
|
|
12
|
+
from rich.console import Console
|
|
13
|
+
from rich.table import Table
|
|
14
|
+
from rich.panel import Panel
|
|
15
|
+
|
|
16
|
+
from chainmem import ChainMemory
|
|
17
|
+
|
|
18
|
+
app = typer.Typer(help="ChainMem — 链式 + 向量混合记忆系统")
|
|
19
|
+
console = Console()
|
|
20
|
+
DEFAULT_DB = "~/.chainmem/data.db"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _get_cm(db: str | None = None) -> ChainMemory:
|
|
24
|
+
cm = ChainMemory(db_path=db or DEFAULT_DB)
|
|
25
|
+
return cm.open()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@app.command()
|
|
29
|
+
def ingest(
|
|
30
|
+
text: str = typer.Argument(..., help="要结链的文本"),
|
|
31
|
+
source: str = typer.Option("", "--source", "-s", help="来源会话"),
|
|
32
|
+
tags: str = typer.Option("", "--tags", "-t", help="标签(逗号分隔)"),
|
|
33
|
+
db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
|
|
34
|
+
):
|
|
35
|
+
"""结链:文本 → 切块 → 嵌入 → 存储"""
|
|
36
|
+
cm = _get_cm(db)
|
|
37
|
+
tag_list = [t.strip() for t in tags.split(",") if t.strip()]
|
|
38
|
+
|
|
39
|
+
chain = cm.ingest(text, source=source, tags=tag_list)
|
|
40
|
+
|
|
41
|
+
panel = Panel(
|
|
42
|
+
f"[bold green]✓ 结链成功[/bold green]\n\n"
|
|
43
|
+
f"链 ID: {chain.id}\n"
|
|
44
|
+
f"节点数: {chain.node_count}\n"
|
|
45
|
+
f"前缀锚点: [bold]{chain.anchor_prefix}[/bold]\n"
|
|
46
|
+
f"来源: {source or '(未指定)'}\n\n"
|
|
47
|
+
f"[dim]完整文本:[/dim]\n{chain.full_text()}",
|
|
48
|
+
title="ChainMem Ingest",
|
|
49
|
+
)
|
|
50
|
+
rprint(panel)
|
|
51
|
+
cm.close()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@app.command()
|
|
55
|
+
def retrieve(
|
|
56
|
+
query: str = typer.Argument(..., help="查询文本(前缀或关键词)"),
|
|
57
|
+
max_steps: int = typer.Option(100, "--max-steps", "-m", help="最大遍历步数"),
|
|
58
|
+
tags: str = typer.Option("", "--tags", "-t", help="标签过滤(逗号分隔,OR 逻辑)"),
|
|
59
|
+
db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
|
|
60
|
+
):
|
|
61
|
+
"""追溯:查询 → 最近邻 → 指针遍历 → 文本复原(支持标签过滤)"""
|
|
62
|
+
cm = _get_cm(db)
|
|
63
|
+
cm.retriever.rebuild_index()
|
|
64
|
+
|
|
65
|
+
tag_list = [t.strip() for t in tags.split(",") if t.strip()]
|
|
66
|
+
results = cm.retrieve(query, max_steps=max_steps, tags=tag_list or None)
|
|
67
|
+
cm.close()
|
|
68
|
+
|
|
69
|
+
if not results:
|
|
70
|
+
print("⚠ 未找到匹配的记忆")
|
|
71
|
+
return
|
|
72
|
+
|
|
73
|
+
print()
|
|
74
|
+
for i, text in enumerate(results):
|
|
75
|
+
marker = "🟢" if i == 0 else ("🔴" if i == len(results) - 1 else "🔵")
|
|
76
|
+
print(f" {marker} {text}")
|
|
77
|
+
|
|
78
|
+
print()
|
|
79
|
+
print("─" * 50)
|
|
80
|
+
print("完整记忆重现:")
|
|
81
|
+
print("".join(results))
|
|
82
|
+
print("─" * 50)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@app.command()
|
|
86
|
+
def stats(
|
|
87
|
+
db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
|
|
88
|
+
):
|
|
89
|
+
"""查看记忆统计"""
|
|
90
|
+
cm = _get_cm(db)
|
|
91
|
+
s = cm.stats()
|
|
92
|
+
chains = cm.store.get_all_chains()
|
|
93
|
+
cm.close()
|
|
94
|
+
|
|
95
|
+
table = Table(title="ChainMem 统计")
|
|
96
|
+
table.add_column("指标", style="cyan")
|
|
97
|
+
table.add_column("值", style="green")
|
|
98
|
+
table.add_row("数据库", s["db_path"])
|
|
99
|
+
table.add_row("链总数", str(s["chains"]))
|
|
100
|
+
table.add_row("节点总数", str(s["nodes"]))
|
|
101
|
+
console.print(table)
|
|
102
|
+
|
|
103
|
+
if chains:
|
|
104
|
+
console.print("\n[bold]已存储的链:[/bold]")
|
|
105
|
+
for c in chains:
|
|
106
|
+
tag_str = ""
|
|
107
|
+
tags_raw = c.get("tags", [])
|
|
108
|
+
if isinstance(tags_raw, str):
|
|
109
|
+
tags_raw = json.loads(tags_raw)
|
|
110
|
+
if tags_raw:
|
|
111
|
+
tag_str = f" [cyan]{' '.join('#' + t for t in tags_raw)}[/cyan]"
|
|
112
|
+
rprint(f" [dim]{c['id'][:8]}...[/dim] 前缀=[bold]{c['anchor_prefix']}[/bold] "
|
|
113
|
+
f"节点={c['node_count']} 强度={c['strength']:.1f}{tag_str} "
|
|
114
|
+
f"[dim]{c['created_at']}[/dim]")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@app.command()
|
|
118
|
+
def demo():
|
|
119
|
+
"""运行快速演示"""
|
|
120
|
+
import tempfile
|
|
121
|
+
db = tempfile.mktemp(suffix=".db")
|
|
122
|
+
|
|
123
|
+
cm = ChainMemory(db_path=db).open()
|
|
124
|
+
|
|
125
|
+
texts = [
|
|
126
|
+
"其实我的想法是把每一次的记忆包括一次对话全部变成一个链条,这样只要想起开头几个字就能顺着把后面的内容推导出来。",
|
|
127
|
+
"关于股决项目,我觉得应该先做好最薄弱的一环,然后让朋友内测、反馈、再扩,从不用登录墙开始。",
|
|
128
|
+
"用户对医疗养老行业和全栈项目有广泛兴趣,但当前最关注的是股决A股投资APP项目。",
|
|
129
|
+
]
|
|
130
|
+
|
|
131
|
+
for i, t in enumerate(texts):
|
|
132
|
+
chain = cm.ingest(t, source=f"demo_session_{i}", tags=["demo"])
|
|
133
|
+
rprint(f"[dim]✓ 已结链:[/dim] [bold]{chain.anchor_prefix}[/bold]... ({chain.node_count} 节点)")
|
|
134
|
+
|
|
135
|
+
cm.retriever.rebuild_index()
|
|
136
|
+
|
|
137
|
+
queries = [
|
|
138
|
+
"其实我的想法",
|
|
139
|
+
"关于股决",
|
|
140
|
+
]
|
|
141
|
+
|
|
142
|
+
for q in queries:
|
|
143
|
+
console.print(f"\n[bold]🔍 查询:[/bold] \"{q}\"")
|
|
144
|
+
results = cm.retrieve(q)
|
|
145
|
+
if results:
|
|
146
|
+
for i, t in enumerate(results):
|
|
147
|
+
marker = "🟢" if i == 0 else ("🔴" if i == len(results) - 1 else "🔵")
|
|
148
|
+
rprint(f" {marker} {t}")
|
|
149
|
+
else:
|
|
150
|
+
rprint(" [yellow]未找到匹配[/yellow]")
|
|
151
|
+
|
|
152
|
+
cm.close()
|
|
153
|
+
rprint("\n[bold green]✓ 演示完成[/bold green]")
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@app.command()
|
|
157
|
+
def mcp(
|
|
158
|
+
db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
|
|
159
|
+
):
|
|
160
|
+
"""启动 MCP 协议服务器(stdio 模式,供 Hermes 按需调用)"""
|
|
161
|
+
_run_mcp_stdio(db)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@app.command()
|
|
165
|
+
def serve(
|
|
166
|
+
socket_path: str = typer.Option("/tmp/chainmem.sock", "--socket", "-s",
|
|
167
|
+
help="Unix socket 路径"),
|
|
168
|
+
db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
|
|
169
|
+
):
|
|
170
|
+
"""启动持久化 MCP 服务(Unix socket,供 Hermes 常驻连接)
|
|
171
|
+
|
|
172
|
+
模型在启动时一次性加载,之后查询毫秒级响应。
|
|
173
|
+
用 systemd 管理此服务。
|
|
174
|
+
"""
|
|
175
|
+
import os
|
|
176
|
+
import asyncio
|
|
177
|
+
import json
|
|
178
|
+
|
|
179
|
+
# 预加载模型和索引(冷启动,仅一次)
|
|
180
|
+
console.print("[bold]🔄 正在加载嵌入模型...[/bold]")
|
|
181
|
+
cm = _get_cm(db)
|
|
182
|
+
cm.retriever.rebuild_index()
|
|
183
|
+
console.print(f"[bold green]✓ 模型就绪![/bold green] {cm.stats()['nodes']} 个节点已索引")
|
|
184
|
+
cm.close()
|
|
185
|
+
|
|
186
|
+
# 确保 socket 目录存在
|
|
187
|
+
os.makedirs(os.path.dirname(socket_path), exist_ok=True)
|
|
188
|
+
if os.path.exists(socket_path):
|
|
189
|
+
os.unlink(socket_path)
|
|
190
|
+
|
|
191
|
+
async def handle_connection(reader: asyncio.StreamReader,
|
|
192
|
+
writer: asyncio.StreamWriter):
|
|
193
|
+
"""处理一个连接:读取 JSON-RPC,处理后返回"""
|
|
194
|
+
cm_conn = _get_cm(db) # 轻量连接(不加载模型,复用已缓存的嵌入)
|
|
195
|
+
try:
|
|
196
|
+
while True:
|
|
197
|
+
line = await reader.readline()
|
|
198
|
+
if not line:
|
|
199
|
+
break
|
|
200
|
+
line = line.strip().decode("utf-8")
|
|
201
|
+
if not line:
|
|
202
|
+
continue
|
|
203
|
+
try:
|
|
204
|
+
req = json.loads(line)
|
|
205
|
+
await _handle_mcp_request(req, cm_conn, writer)
|
|
206
|
+
except json.JSONDecodeError:
|
|
207
|
+
pass
|
|
208
|
+
except Exception:
|
|
209
|
+
pass
|
|
210
|
+
finally:
|
|
211
|
+
cm_conn.close()
|
|
212
|
+
writer.close()
|
|
213
|
+
|
|
214
|
+
async def server_main():
|
|
215
|
+
server = await asyncio.start_unix_server(handle_connection, path=socket_path)
|
|
216
|
+
os.chmod(socket_path, 0o666) # 多用户可访问
|
|
217
|
+
addr = server.sockets[0].getsockname()
|
|
218
|
+
console.print(f"[bold green]✅ ChainMem MCP 服务启动![/bold green]")
|
|
219
|
+
console.print(f" socket: [bold]{socket_path}[/bold]")
|
|
220
|
+
console.print(f" 模型: all-MiniLM-L6-v2 (已加载)")
|
|
221
|
+
console.print(f" 数据库: {db}")
|
|
222
|
+
async with server:
|
|
223
|
+
await server.serve_forever()
|
|
224
|
+
|
|
225
|
+
asyncio.run(server_main())
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# ── MCP 共享逻辑 ──
|
|
229
|
+
|
|
230
|
+
def _run_mcp_stdio(db: str):
|
|
231
|
+
"""stdio MCP 模式:从 stdin 读请求、stdout 写响应(Hermes 按需调用)"""
|
|
232
|
+
import sys
|
|
233
|
+
import json
|
|
234
|
+
|
|
235
|
+
_cm_instance = None
|
|
236
|
+
|
|
237
|
+
def get_cm():
|
|
238
|
+
nonlocal _cm_instance
|
|
239
|
+
if _cm_instance is None:
|
|
240
|
+
_cm_instance = _get_cm(db)
|
|
241
|
+
return _cm_instance
|
|
242
|
+
|
|
243
|
+
def send_response(id, result):
|
|
244
|
+
msg = json.dumps({"jsonrpc": "2.0", "id": id, "result": result})
|
|
245
|
+
sys.stdout.write(msg + "\n")
|
|
246
|
+
sys.stdout.flush()
|
|
247
|
+
|
|
248
|
+
def send_error(id, code, message):
|
|
249
|
+
msg = json.dumps({"jsonrpc": "2.0", "id": id, "error": {"code": code, "message": message}})
|
|
250
|
+
sys.stdout.write(msg + "\n")
|
|
251
|
+
sys.stdout.flush()
|
|
252
|
+
|
|
253
|
+
for line in sys.stdin:
|
|
254
|
+
line = line.strip()
|
|
255
|
+
if not line:
|
|
256
|
+
continue
|
|
257
|
+
try:
|
|
258
|
+
req = json.loads(line)
|
|
259
|
+
_process_mcp_request(req, get_cm, send_response, send_error)
|
|
260
|
+
except json.JSONDecodeError:
|
|
261
|
+
pass
|
|
262
|
+
except Exception:
|
|
263
|
+
send_error(None, -1, traceback.format_exc())
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
async def _handle_mcp_request(req: dict, cm, writer: asyncio.StreamWriter,
|
|
267
|
+
rebuild_index: bool = True):
|
|
268
|
+
"""异步版 MCP 请求处理(serve 模式用)"""
|
|
269
|
+
import json
|
|
270
|
+
|
|
271
|
+
def send_response(id, result):
|
|
272
|
+
msg = json.dumps({"jsonrpc": "2.0", "id": id, "result": result})
|
|
273
|
+
writer.write((msg + "\n").encode("utf-8"))
|
|
274
|
+
|
|
275
|
+
def send_error(id, code, message):
|
|
276
|
+
msg = json.dumps({"jsonrpc": "2.0", "id": id, "error": {"code": code, "message": message}})
|
|
277
|
+
writer.write((msg + "\n").encode("utf-8"))
|
|
278
|
+
|
|
279
|
+
_process_mcp_request(req, lambda: cm, send_response, send_error,
|
|
280
|
+
rebuild_index=rebuild_index)
|
|
281
|
+
await writer.drain()
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _process_mcp_request(req: dict, get_cm, send_response, send_error,
|
|
285
|
+
rebuild_index: bool = True):
|
|
286
|
+
"""MCP 请求处理核心(stdio 和 serve 模式共用)
|
|
287
|
+
|
|
288
|
+
rebuild_index: True 则在每次 retrieve 前重建索引(stdio 模式),
|
|
289
|
+
False 则仅 ingest 后重建(serve 模式,索引常驻)
|
|
290
|
+
"""
|
|
291
|
+
import json
|
|
292
|
+
req_id = req.get("id")
|
|
293
|
+
method = req.get("method")
|
|
294
|
+
|
|
295
|
+
if method == "tools/list":
|
|
296
|
+
send_response(req_id, {
|
|
297
|
+
"tools": [
|
|
298
|
+
{
|
|
299
|
+
"name": "chainmem_ingest",
|
|
300
|
+
"description": "结链:将文本存储为链式记忆",
|
|
301
|
+
"inputSchema": {
|
|
302
|
+
"type": "object",
|
|
303
|
+
"properties": {
|
|
304
|
+
"text": {"type": "string", "description": "要结链的文本"},
|
|
305
|
+
"source": {"type": "string", "description": "来源会话"},
|
|
306
|
+
"tags": {"type": "string", "description": "标签(逗号分隔)"},
|
|
307
|
+
},
|
|
308
|
+
"required": ["text"],
|
|
309
|
+
},
|
|
310
|
+
},
|
|
311
|
+
{
|
|
312
|
+
"name": "chainmem_retrieve",
|
|
313
|
+
"description": "追溯:输入查询,还原完整记忆链(支持可选标签过滤)",
|
|
314
|
+
"inputSchema": {
|
|
315
|
+
"type": "object",
|
|
316
|
+
"properties": {
|
|
317
|
+
"query": {"type": "string", "description": "查询文本"},
|
|
318
|
+
"tags": {"type": "string",
|
|
319
|
+
"description": "可选,标签过滤(逗号分隔,OR 逻辑)"},
|
|
320
|
+
},
|
|
321
|
+
"required": ["query"],
|
|
322
|
+
},
|
|
323
|
+
},
|
|
324
|
+
{
|
|
325
|
+
"name": "chainmem_stats",
|
|
326
|
+
"description": "查看记忆统计",
|
|
327
|
+
"inputSchema": {
|
|
328
|
+
"type": "object",
|
|
329
|
+
"properties": {},
|
|
330
|
+
},
|
|
331
|
+
},
|
|
332
|
+
]
|
|
333
|
+
})
|
|
334
|
+
|
|
335
|
+
elif method == "tools/call":
|
|
336
|
+
tool_name = req.get("params", {}).get("name")
|
|
337
|
+
arguments = req.get("params", {}).get("arguments", {})
|
|
338
|
+
|
|
339
|
+
if tool_name == "chainmem_ingest":
|
|
340
|
+
text = arguments.get("text", "")
|
|
341
|
+
source = arguments.get("source", "")
|
|
342
|
+
tags = [t.strip() for t in arguments.get("tags", "").split(",") if t.strip()]
|
|
343
|
+
try:
|
|
344
|
+
cm = get_cm()
|
|
345
|
+
chain = cm.ingest(text, source=source, tags=tags)
|
|
346
|
+
cm.retriever.rebuild_index()
|
|
347
|
+
send_response(req_id, {
|
|
348
|
+
"content": [{"type": "text",
|
|
349
|
+
"text": f"结链成功:{chain.node_count} 个节点,前缀「{chain.anchor_prefix}」"}]
|
|
350
|
+
})
|
|
351
|
+
except Exception as e:
|
|
352
|
+
send_error(req_id, -1, str(e))
|
|
353
|
+
|
|
354
|
+
elif tool_name == "chainmem_retrieve":
|
|
355
|
+
query = arguments.get("query", "")
|
|
356
|
+
tags_str = arguments.get("tags", "")
|
|
357
|
+
tag_list = [t.strip() for t in tags_str.split(",") if t.strip()]
|
|
358
|
+
cm = get_cm()
|
|
359
|
+
if rebuild_index:
|
|
360
|
+
cm.retriever.rebuild_index()
|
|
361
|
+
results = cm.retrieve(query, tags=tag_list or None)
|
|
362
|
+
if results:
|
|
363
|
+
full_text = "".join(results)
|
|
364
|
+
send_response(req_id, {
|
|
365
|
+
"content": [{"type": "text", "text": full_text}]
|
|
366
|
+
})
|
|
367
|
+
else:
|
|
368
|
+
send_response(req_id, {
|
|
369
|
+
"content": [{"type": "text", "text": "未找到匹配的记忆"}]
|
|
370
|
+
})
|
|
371
|
+
|
|
372
|
+
elif tool_name == "chainmem_stats":
|
|
373
|
+
cm = get_cm()
|
|
374
|
+
stats = cm.stats()
|
|
375
|
+
text = f"链总数: {stats['chains']}\n节点总数: {stats['nodes']}\n数据库: {stats['db_path']}"
|
|
376
|
+
send_response(req_id, {
|
|
377
|
+
"content": [{"type": "text", "text": text}]
|
|
378
|
+
})
|
|
379
|
+
|
|
380
|
+
else:
|
|
381
|
+
send_error(req_id, -32601, f"未知工具: {tool_name}")
|
|
382
|
+
|
|
383
|
+
elif method == "initialize":
|
|
384
|
+
send_response(req_id, {
|
|
385
|
+
"protocolVersion": "2025-11-25",
|
|
386
|
+
"capabilities": {"tools": {}},
|
|
387
|
+
"serverInfo": {"name": "chainmem", "version": "0.1.0"},
|
|
388
|
+
})
|
|
389
|
+
|
|
390
|
+
elif method == "notifications/initialized":
|
|
391
|
+
pass
|
|
392
|
+
|
|
393
|
+
else:
|
|
394
|
+
send_error(req_id, -32601, f"未知方法: {method}")
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
if __name__ == "__main__":
|
|
398
|
+
app()
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""数据模型:ChainNode 和 Chain"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Optional
|
|
6
|
+
import uuid
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
|
|
11
|
+
class ChainNode:
|
|
12
|
+
"""链节点——记忆的最小单元"""
|
|
13
|
+
id: str
|
|
14
|
+
chain_id: str
|
|
15
|
+
seq: int
|
|
16
|
+
text: str
|
|
17
|
+
embedding: np.ndarray | None = None # shape=(d,),运行时内存中
|
|
18
|
+
prev_id: str | None = None
|
|
19
|
+
next_id: str | None = None
|
|
20
|
+
|
|
21
|
+
@property
|
|
22
|
+
def text_prefix(self) -> str:
|
|
23
|
+
return self.text[:3] if len(self.text) >= 3 else self.text
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
|
|
27
|
+
class Chain:
|
|
28
|
+
"""链——整段记忆的元信息"""
|
|
29
|
+
id: str
|
|
30
|
+
root_id: str
|
|
31
|
+
leaf_id: str
|
|
32
|
+
anchor_prefix: str
|
|
33
|
+
node_count: int
|
|
34
|
+
nodes: list[ChainNode] = field(default_factory=list)
|
|
35
|
+
summary: str = ""
|
|
36
|
+
source: str = ""
|
|
37
|
+
tags: list[str] = field(default_factory=list)
|
|
38
|
+
strength: float = 1.0
|
|
39
|
+
|
|
40
|
+
@classmethod
|
|
41
|
+
def from_nodes(cls, nodes: list[ChainNode]) -> "Chain":
|
|
42
|
+
if not nodes:
|
|
43
|
+
raise ValueError("Cannot create Chain from empty nodes list")
|
|
44
|
+
chain_id = nodes[0].chain_id
|
|
45
|
+
return cls(
|
|
46
|
+
id=chain_id,
|
|
47
|
+
root_id=nodes[0].id,
|
|
48
|
+
leaf_id=nodes[-1].id,
|
|
49
|
+
anchor_prefix=nodes[0].text_prefix,
|
|
50
|
+
node_count=len(nodes),
|
|
51
|
+
nodes=nodes,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
def full_text(self) -> str:
|
|
55
|
+
"""拼接整条链的完整文本"""
|
|
56
|
+
return "".join(node.text for node in self.nodes)
|
|
57
|
+
|
|
58
|
+
def to_dict(self) -> dict:
|
|
59
|
+
return {
|
|
60
|
+
"id": self.id,
|
|
61
|
+
"node_count": self.node_count,
|
|
62
|
+
"anchor_prefix": self.anchor_prefix,
|
|
63
|
+
"source": self.source,
|
|
64
|
+
"tags": self.tags,
|
|
65
|
+
"strength": self.strength,
|
|
66
|
+
"full_text": self.full_text(),
|
|
67
|
+
}
|
chainmem/core/node.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""结链管道:文本 → 切块 → 嵌入 → 存储"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
import re
|
|
5
|
+
import uuid
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from sentence_transformers import SentenceTransformer
|
|
9
|
+
|
|
10
|
+
from chainmem.core.node import ChainNode, Chain
|
|
11
|
+
from chainmem.store.sqlite_store import SQLiteStore
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# 全局复用嵌入模型(加载一次即可)
|
|
15
|
+
_MODEL: SentenceTransformer | None = None
|
|
16
|
+
_MODEL_NAME: str = "all-MiniLM-L6-v2"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _get_model(model_name: str | None = None) -> SentenceTransformer:
|
|
20
|
+
global _MODEL, _MODEL_NAME
|
|
21
|
+
if model_name is not None and model_name != _MODEL_NAME:
|
|
22
|
+
# 切换模型
|
|
23
|
+
_MODEL = SentenceTransformer(model_name)
|
|
24
|
+
_MODEL_NAME = model_name
|
|
25
|
+
elif _MODEL is None:
|
|
26
|
+
_MODEL = SentenceTransformer(_MODEL_NAME)
|
|
27
|
+
return _MODEL
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def set_model(model_name: str):
|
|
31
|
+
"""切换嵌入模型(下次调用 _get_model 时生效)"""
|
|
32
|
+
global _MODEL_NAME, _MODEL
|
|
33
|
+
_MODEL_NAME = model_name
|
|
34
|
+
_MODEL = None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def chunk_text(text: str, max_chars: int = 18) -> list[str]:
|
|
38
|
+
"""将长文本按自然停顿切分为短语块
|
|
39
|
+
|
|
40
|
+
切分规则:
|
|
41
|
+
1. 始终按句号/问号/感叹号等终结标点切分
|
|
42
|
+
2. 始终按逗号/顿号/冒号切分
|
|
43
|
+
3. 过长的块(> max_chars)硬截断
|
|
44
|
+
"""
|
|
45
|
+
# 1. 按终结标点切分(保留标点)
|
|
46
|
+
parts = re.split(r'(?<=[。!?;…\n])\s*', text)
|
|
47
|
+
parts = [p.strip() for p in parts if p.strip()]
|
|
48
|
+
|
|
49
|
+
# 2. 对每个部分按逗号/顿号/冒号再切
|
|
50
|
+
chunks = []
|
|
51
|
+
for part in parts:
|
|
52
|
+
sub = re.split(r'(?<=[,、:])\s*', part)
|
|
53
|
+
for s in sub:
|
|
54
|
+
s = s.strip()
|
|
55
|
+
if not s:
|
|
56
|
+
continue
|
|
57
|
+
if len(s) <= max_chars:
|
|
58
|
+
chunks.append(s)
|
|
59
|
+
else:
|
|
60
|
+
# 过长的硬截断
|
|
61
|
+
for i in range(0, len(s), max_chars):
|
|
62
|
+
chunks.append(s[i:i + max_chars])
|
|
63
|
+
# 3. 合併過短的塊(避免 sentence-transformers 的退化嵌入)
|
|
64
|
+
chunks = merge_short_chunks(chunks)
|
|
65
|
+
return [c for c in chunks if c]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def merge_short_chunks(chunks: list[str], min_chars: int = 6) -> list[str]:
|
|
69
|
+
"""合併過短的塊到前一個相鄰塊
|
|
70
|
+
|
|
71
|
+
sentence-transformers 對 ≤5 字的短文本會產生退化嵌入
|
|
72
|
+
(不同文本得到完全相同向量->cosine=1.0)
|
|
73
|
+
因此需要將短塊併入相鄰的長塊
|
|
74
|
+
"""
|
|
75
|
+
if len(chunks) <= 1:
|
|
76
|
+
return chunks
|
|
77
|
+
merged = []
|
|
78
|
+
for chunk in chunks:
|
|
79
|
+
if merged and len(chunk) <= min_chars:
|
|
80
|
+
# 合併到前一個塊
|
|
81
|
+
merged[-1] = merged[-1] + chunk
|
|
82
|
+
else:
|
|
83
|
+
merged.append(chunk)
|
|
84
|
+
return merged
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class Ingester:
|
|
88
|
+
"""结链器:文本 → 链"""
|
|
89
|
+
|
|
90
|
+
def __init__(self, store: SQLiteStore):
|
|
91
|
+
self.store = store
|
|
92
|
+
self.embedder = _get_model()
|
|
93
|
+
|
|
94
|
+
def ingest(self, text: str, source: str = "", tags: list[str] | None = None) -> Chain:
|
|
95
|
+
chunks = chunk_text(text)
|
|
96
|
+
if not chunks:
|
|
97
|
+
raise ValueError("Empty text after chunking")
|
|
98
|
+
|
|
99
|
+
chain_id = str(uuid.uuid4())
|
|
100
|
+
nodes: list[ChainNode] = []
|
|
101
|
+
|
|
102
|
+
# 1. 嵌入所有块
|
|
103
|
+
embeddings = self.embedder.encode(chunks, normalize_embeddings=True)
|
|
104
|
+
|
|
105
|
+
# 2. 创建节点,串联
|
|
106
|
+
prev_id: str | None = None
|
|
107
|
+
for i, (phrase_text, emb) in enumerate(zip(chunks, embeddings)):
|
|
108
|
+
node_id = str(uuid.uuid4())
|
|
109
|
+
node = ChainNode(
|
|
110
|
+
id=node_id,
|
|
111
|
+
chain_id=chain_id,
|
|
112
|
+
seq=i + 1,
|
|
113
|
+
text=phrase_text,
|
|
114
|
+
embedding=emb,
|
|
115
|
+
prev_id=prev_id,
|
|
116
|
+
)
|
|
117
|
+
if prev_id:
|
|
118
|
+
# 更新前一个节点的 next_id
|
|
119
|
+
nodes[-1].next_id = node_id
|
|
120
|
+
nodes.append(node)
|
|
121
|
+
prev_id = node_id
|
|
122
|
+
|
|
123
|
+
root_id = nodes[0].id
|
|
124
|
+
leaf_id = nodes[-1].id
|
|
125
|
+
|
|
126
|
+
# 3. 存数据库
|
|
127
|
+
self.store.save_chain(
|
|
128
|
+
chain_id=chain_id,
|
|
129
|
+
anchor_prefix=nodes[0].text_prefix,
|
|
130
|
+
root_id=root_id,
|
|
131
|
+
leaf_id=leaf_id,
|
|
132
|
+
node_count=len(nodes),
|
|
133
|
+
source=source,
|
|
134
|
+
tags=tags or [],
|
|
135
|
+
)
|
|
136
|
+
for n in nodes:
|
|
137
|
+
self.store.save_node(
|
|
138
|
+
node_id=n.id,
|
|
139
|
+
chain_id=n.chain_id,
|
|
140
|
+
seq=n.seq,
|
|
141
|
+
text=n.text,
|
|
142
|
+
prev_id=n.prev_id,
|
|
143
|
+
next_id=n.next_id,
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
chain = Chain.from_nodes(nodes)
|
|
147
|
+
chain.source = source
|
|
148
|
+
chain.tags = tags or []
|
|
149
|
+
return chain
|