chainmem 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
chainmem/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """ChainMem — 链式 + 向量混合记忆系统"""
2
+
3
# Package version string.
# NOTE(review): the wheel this module ships in is named 0.3.0 while this
# literal still says 0.1.0 — confirm which one is authoritative.
__version__ = "0.1.0"
4
+
5
+ from chainmem.core.node import ChainNode, Chain
6
+ from chainmem.store.sqlite_store import SQLiteStore
7
+ from chainmem.pipeline.ingester import Ingester
8
+ from chainmem.pipeline.retriever import Retriever
9
+
10
+
11
class ChainMemory:
    """Main entry point of ChainMem.

    Bundles the SQLite store, the ingestion pipeline and the retriever
    behind one object. Call :meth:`open` (or use the instance as a context
    manager) before calling :meth:`ingest`, :meth:`retrieve` or
    :meth:`stats`.
    """

    def __init__(self, db_path: str = "~/.chainmem/data.db"):
        """Remember the database location; nothing is opened yet.

        Args:
            db_path: Path of the SQLite file. ``~`` is expanded by
                :meth:`open`.
        """
        self.db_path = db_path
        self.store: SQLiteStore | None = None
        self.ingester: Ingester | None = None
        self.retriever: Retriever | None = None

    def open(self):
        """Open the database and create the ingester and retriever.

        Returns:
            ``self``, so the call can be chained.
        """
        import os

        # Re-opening must not leak the previously opened store.
        self.close()

        path = os.path.expanduser(self.db_path)
        parent = os.path.dirname(path)
        # A bare file name has no directory part; os.makedirs("") would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)

        self.store = SQLiteStore(path)
        self.store.initialize()
        self.ingester = Ingester(self.store)
        self.retriever = Retriever(self.store)
        return self

    def close(self):
        """Close the store (if open) and drop the pipeline objects.

        Safe to call repeatedly. Afterwards the instance behaves as if it
        had never been opened, so later calls fail with the usual
        "Call .open() first" error instead of touching a closed store.
        """
        store, self.store = self.store, None
        self.ingester = None
        self.retriever = None
        if store is not None:
            store.close()

    def ingest(self, text: str, source: str = "", tags: list[str] | None = None) -> Chain:
        """Ingest *text*: chunk it, embed it and store it as a chain.

        Args:
            text: Text to store.
            source: Optional label of the originating session.
            tags: Optional tags; ``None`` is treated as an empty list.

        Returns:
            The chain that was created.

        Raises:
            RuntimeError: If :meth:`open` has not been called.
        """
        if self.ingester is None:
            raise RuntimeError("Call .open() first")
        return self.ingester.ingest(text, source=source, tags=tags or [])

    def set_model(self, model_name: str):
        """Switch the embedding model.

        Returns:
            ``self``, so the call can be chained.
        """
        from chainmem.pipeline.ingester import _get_model, set_model as _set

        _set(model_name)
        # The ingester captured the old model object when it was built;
        # refresh it so ingestion and retrieval embed with the same model.
        if self.ingester is not None:
            self.ingester.embedder = _get_model()
        # Rebuild the index so the new model takes effect.
        # NOTE(review): Retriever is not visible here — if it also caches a
        # model object it needs the same refresh; confirm.
        if self.retriever is not None:
            self.retriever.rebuild_index()
        return self

    def retrieve(self, query: str, max_steps: int = 100,
                 tags: list[str] | None = None) -> list[str]:
        """Recall a memory by delegating to the retriever.

        Args:
            query: Query text.
            max_steps: Forwarded to the retriever as ``max_steps``.
            tags: Optional tag filter, forwarded unchanged.

        Raises:
            RuntimeError: If :meth:`open` has not been called.
        """
        if self.retriever is None:
            raise RuntimeError("Call .open() first")
        return self.retriever.retrieve(query, max_steps=max_steps, tags=tags)

    def stats(self) -> dict:
        """Return the store's statistics dictionary.

        Raises:
            RuntimeError: If :meth:`open` has not been called.
        """
        if self.store is None:
            raise RuntimeError("Call .open() first")
        return self.store.stats()

    def __enter__(self):
        """Open the memory on entering a ``with`` block."""
        return self.open()

    def __exit__(self, *args):
        """Close the memory on leaving a ``with`` block."""
        self.close()
chainmem/__main__.py ADDED
@@ -0,0 +1,5 @@
1
+ """让 python -m chainmem 可直接运行"""
2
+ from chainmem.cli.app import app
3
+
4
# Run the Typer application only when this module is executed directly
# (e.g. ``python -m chainmem``), not when it is merely imported.
if __name__ == "__main__":
    app()
chainmem/cli/app.py ADDED
@@ -0,0 +1,398 @@
1
+ """ChainMem CLI"""
2
+
3
+ import json
4
+ import sys
5
+ import traceback
6
+ import asyncio
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import typer
11
+ from rich import print as rprint
12
+ from rich.console import Console
13
+ from rich.table import Table
14
+ from rich.panel import Panel
15
+
16
+ from chainmem import ChainMemory
17
+
18
# Typer application object; every @app.command() below registers on it.
app = typer.Typer(help="ChainMem — 链式 + 向量混合记忆系统")
# Shared rich console used for tables and status messages.
console = Console()
# Database path used when a command is not given --db.
DEFAULT_DB = "~/.chainmem/data.db"
21
+
22
+
23
def _get_cm(db: str | None = None) -> ChainMemory:
    """Build a ChainMemory for *db* and return it already opened.

    A falsy *db* (``None`` or an empty string) selects ``DEFAULT_DB``.
    """
    db_path = db if db else DEFAULT_DB
    memory = ChainMemory(db_path=db_path)
    return memory.open()
26
+
27
+
28
@app.command()
def ingest(
    text: str = typer.Argument(..., help="要结链的文本"),
    source: str = typer.Option("", "--source", "-s", help="来源会话"),
    tags: str = typer.Option("", "--tags", "-t", help="标签(逗号分隔)"),
    db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
):
    """结链:文本 → 切块 → 嵌入 → 存储"""
    # NOTE: the docstring above is kept verbatim because Typer shows it as
    # the command's help text.
    #
    # CLI command: store TEXT as a chain and print a summary panel.
    # `tags` is a comma-separated string; blank entries are dropped.
    from rich.markup import escape

    tag_list = [t.strip() for t in tags.split(",") if t.strip()]

    cm = _get_cm(db)
    try:
        chain = cm.ingest(text, source=source, tags=tag_list)
    finally:
        # Release the database even when ingestion fails (e.g. empty text).
        cm.close()

    # User-supplied text is escaped so square brackets in it are printed
    # literally instead of being parsed as rich markup.
    panel = Panel(
        f"[bold green]✓ 结链成功[/bold green]\n\n"
        f"链 ID: {chain.id}\n"
        f"节点数: {chain.node_count}\n"
        f"前缀锚点: [bold]{escape(chain.anchor_prefix)}[/bold]\n"
        f"来源: {escape(source or '(未指定)')}\n\n"
        f"[dim]完整文本:[/dim]\n{escape(chain.full_text())}",
        title="ChainMem Ingest",
    )
    rprint(panel)
52
+
53
+
54
@app.command()
def retrieve(
    query: str = typer.Argument(..., help="查询文本(前缀或关键词)"),
    max_steps: int = typer.Option(100, "--max-steps", "-m", help="最大遍历步数"),
    tags: str = typer.Option("", "--tags", "-t", help="标签过滤(逗号分隔,OR 逻辑)"),
    db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
):
    """追溯:查询 → 最近邻 → 指针遍历 → 文本复原(支持标签过滤)"""
    # NOTE: the docstring above is kept verbatim because Typer shows it as
    # the command's help text.
    #
    # CLI command: look up QUERY, print each recalled chunk with a position
    # marker, then the chunks joined into one text.
    tag_list = [t.strip() for t in tags.split(",") if t.strip()]

    cm = _get_cm(db)
    try:
        # A fresh process starts without an index, so build it first.
        cm.retriever.rebuild_index()
        results = cm.retrieve(query, max_steps=max_steps, tags=tag_list or None)
    finally:
        # Release the database even when the lookup fails.
        cm.close()

    if not results:
        print("⚠ 未找到匹配的记忆")
        return

    print()
    last = len(results) - 1
    for i, text in enumerate(results):
        # green = first chunk, red = last chunk, blue = everything between
        marker = "🟢" if i == 0 else ("🔴" if i == last else "🔵")
        print(f" {marker} {text}")

    print()
    print("─" * 50)
    print("完整记忆重现:")
    print("".join(results))
    print("─" * 50)
83
+
84
+
85
@app.command()
def stats(
    db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
):
    """查看记忆统计"""
    # NOTE: the docstring above is kept verbatim because Typer shows it as
    # the command's help text.
    #
    # CLI command: print a summary table and one line per stored chain.
    from rich.markup import escape

    cm = _get_cm(db)
    try:
        s = cm.stats()
        chains = cm.store.get_all_chains()
    finally:
        # Release the database even when a query fails.
        cm.close()

    table = Table(title="ChainMem 统计")
    table.add_column("指标", style="cyan")
    table.add_column("值", style="green")
    table.add_row("数据库", s["db_path"])
    table.add_row("链总数", str(s["chains"]))
    table.add_row("节点总数", str(s["nodes"]))
    console.print(table)

    if not chains:
        return

    console.print("\n[bold]已存储的链:[/bold]")
    for c in chains:
        tags_raw = c.get("tags", [])
        # Tags may come back as a JSON-encoded string rather than a list.
        if isinstance(tags_raw, str):
            tags_raw = json.loads(tags_raw)
        tag_str = ""
        if tags_raw:
            # Stored text is escaped so brackets are not parsed as markup.
            shown = escape(" ".join("#" + t for t in tags_raw))
            tag_str = f" [cyan]{shown}[/cyan]"
        rprint(f" [dim]{c['id'][:8]}...[/dim] 前缀=[bold]{escape(c['anchor_prefix'])}[/bold] "
               f"节点={c['node_count']} 强度={c['strength']:.1f}{tag_str} "
               f"[dim]{c['created_at']}[/dim]")
115
+
116
+
117
@app.command()
def demo():
    """运行快速演示"""
    # NOTE: the docstring above is kept verbatim because Typer shows it as
    # the command's help text.
    #
    # CLI command: ingest three sample texts into a throw-away database and
    # run two sample queries against it.
    import os
    import tempfile

    texts = [
        "其实我的想法是把每一次的记忆包括一次对话全部变成一个链条,这样只要想起开头几个字就能顺着把后面的内容推导出来。",
        "关于股决项目,我觉得应该先做好最薄弱的一环,然后让朋友内测、反馈、再扩,从不用登录墙开始。",
        "用户对医疗养老行业和全栈项目有广泛兴趣,但当前最关注的是股决A股投资APP项目。",
    ]
    queries = [
        "其实我的想法",
        "关于股决",
    ]

    # tempfile.mktemp() is deprecated and race-prone; a private temporary
    # directory is safe and is removed again once the demo is over.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cm = ChainMemory(db_path=os.path.join(tmp_dir, "demo.db")).open()
        try:
            for i, t in enumerate(texts):
                chain = cm.ingest(t, source=f"demo_session_{i}", tags=["demo"])
                rprint(f"[dim]✓ 已结链:[/dim] [bold]{chain.anchor_prefix}[/bold]... ({chain.node_count} 节点)")

            cm.retriever.rebuild_index()

            for q in queries:
                console.print(f"\n[bold]🔍 查询:[/bold] \"{q}\"")
                results = cm.retrieve(q)
                if not results:
                    rprint(" [yellow]未找到匹配[/yellow]")
                    continue
                last = len(results) - 1
                for i, t in enumerate(results):
                    # green = first chunk, red = last chunk, blue = between
                    marker = "🟢" if i == 0 else ("🔴" if i == last else "🔵")
                    rprint(f" {marker} {t}")
        finally:
            # Close before the directory is deleted, and also on errors.
            cm.close()

    rprint("\n[bold green]✓ 演示完成[/bold green]")
154
+
155
+
156
@app.command()
def mcp(
    db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
):
    """启动 MCP 协议服务器(stdio 模式,供 Hermes 按需调用)"""
    # NOTE: the docstring above is kept verbatim because Typer shows it as
    # the command's help text.
    #
    # Thin CLI wrapper: all the work happens in the stdio request loop.
    return _run_mcp_stdio(db)
162
+
163
+
164
@app.command()
def serve(
    socket_path: str = typer.Option("/tmp/chainmem.sock", "--socket", "-s",
                                    help="Unix socket 路径"),
    db: str = typer.Option(DEFAULT_DB, "--db", "-d", help="数据库路径"),
):
    """启动持久化 MCP 服务(Unix socket,供 Hermes 常驻连接)

    模型在启动时一次性加载,之后查询毫秒级响应。
    用 systemd 管理此服务。
    """
    # NOTE: the docstring above is kept verbatim because Typer shows it as
    # the command's help text.
    #
    # CLI command: serve newline-delimited JSON-RPC requests on a Unix
    # socket until the process is stopped.
    import os

    # Warm-up: open the memory once and build the index at start-up so the
    # embedding model is loaded before the first client connects.
    console.print("[bold]🔄 正在加载嵌入模型...[/bold]")
    cm = _get_cm(db)
    try:
        cm.retriever.rebuild_index()
        console.print(f"[bold green]✓ 模型就绪![/bold green] {cm.stats()['nodes']} 个节点已索引")
    finally:
        cm.close()

    # Make sure the socket directory exists. A bare socket name has no
    # directory part, and os.makedirs("") would raise.
    socket_dir = os.path.dirname(socket_path)
    if socket_dir:
        os.makedirs(socket_dir, exist_ok=True)
    # Remove a stale socket left behind by a previous run.
    if os.path.exists(socket_path):
        os.unlink(socket_path)

    async def handle_connection(reader: asyncio.StreamReader,
                                writer: asyncio.StreamWriter):
        """Serve one client: read JSON lines and answer each request."""
        cm_conn = _get_cm(db)  # per-connection store handle
        try:
            while True:
                raw = await reader.readline()
                if not raw:  # EOF: client disconnected
                    break
                line = raw.strip().decode("utf-8")
                if not line:
                    continue
                try:
                    req = json.loads(line)
                except json.JSONDecodeError:
                    continue  # malformed lines are skipped, as before
                await _handle_mcp_request(req, cm_conn, writer)
        except Exception:
            # Report the failure on stderr instead of dropping the
            # connection silently; the connection is still closed below.
            traceback.print_exc()
        finally:
            cm_conn.close()
            writer.close()

    async def server_main():
        """Bind the socket, announce the service and serve forever."""
        server = await asyncio.start_unix_server(handle_connection, path=socket_path)
        # NOTE(review): 0o666 lets every local user talk to the service
        # (intended as "multi-user access"); confirm this is acceptable for
        # a socket that lives in a shared directory such as /tmp.
        os.chmod(socket_path, 0o666)
        console.print("[bold green]✅ ChainMem MCP 服务启动![/bold green]")
        console.print(f" socket: [bold]{socket_path}[/bold]")
        console.print(" 模型: all-MiniLM-L6-v2 (已加载)")
        console.print(f" 数据库: {db}")
        async with server:
            await server.serve_forever()

    asyncio.run(server_main())
226
+
227
+
228
+ # ── MCP 共享逻辑 ──
229
+
230
def _run_mcp_stdio(db: str):
    """Run the MCP server in stdio mode.

    Reads one JSON-RPC request per line from stdin and writes one JSON
    response per line to stdout until stdin is exhausted. The ChainMemory
    is opened lazily on the first request that needs it and is closed when
    the loop ends.

    Args:
        db: Database path handed to ``_get_cm``.
    """
    import sys
    import json

    _cm_instance = None

    def get_cm():
        """Return the shared ChainMemory, opening it on first use."""
        nonlocal _cm_instance
        if _cm_instance is None:
            _cm_instance = _get_cm(db)
        return _cm_instance

    def _emit(payload):
        """Write one JSON message followed by a newline and flush."""
        sys.stdout.write(json.dumps(payload) + "\n")
        sys.stdout.flush()

    def send_response(req_id, result):
        _emit({"jsonrpc": "2.0", "id": req_id, "result": result})

    def send_error(req_id, code, message):
        _emit({"jsonrpc": "2.0", "id": req_id,
               "error": {"code": code, "message": message}})

    try:
        for line in sys.stdin:
            line = line.strip()
            if not line:
                continue
            try:
                req = json.loads(line)
            except json.JSONDecodeError:
                continue  # malformed lines are skipped, as before
            try:
                _process_mcp_request(req, get_cm, send_response, send_error)
            except Exception:
                # Echo the request id when there is one so the client can
                # match the failure to its call.
                req_id = req.get("id") if isinstance(req, dict) else None
                send_error(req_id, -1, traceback.format_exc())
    finally:
        # Release the database once stdin is exhausted or the loop aborts.
        if _cm_instance is not None:
            _cm_instance.close()
264
+
265
+
266
async def _handle_mcp_request(req: dict, cm, writer: asyncio.StreamWriter,
                              rebuild_index: bool = True):
    """Async MCP request handler used by serve mode.

    Processes *req* with the shared request logic, writes each reply to
    *writer* as one UTF-8 JSON line, and then drains the writer.
    """
    import json

    def _write_line(payload):
        # One JSON document per line, UTF-8 encoded.
        encoded = (json.dumps(payload) + "\n").encode("utf-8")
        writer.write(encoded)

    def send_response(req_id, result):
        _write_line({"jsonrpc": "2.0", "id": req_id, "result": result})

    def send_error(req_id, code, message):
        failure = {"code": code, "message": message}
        _write_line({"jsonrpc": "2.0", "id": req_id, "error": failure})

    def provide_cm():
        # The connection already owns an opened ChainMemory.
        return cm

    _process_mcp_request(req, provide_cm, send_response, send_error,
                         rebuild_index=rebuild_index)
    await writer.drain()
282
+
283
+
284
+ def _process_mcp_request(req: dict, get_cm, send_response, send_error,
285
+ rebuild_index: bool = True):
286
+ """MCP 请求处理核心(stdio 和 serve 模式共用)
287
+
288
+ rebuild_index: True 则在每次 retrieve 前重建索引(stdio 模式),
289
+ False 则仅 ingest 后重建(serve 模式,索引常驻)
290
+ """
291
+ import json
292
+ req_id = req.get("id")
293
+ method = req.get("method")
294
+
295
+ if method == "tools/list":
296
+ send_response(req_id, {
297
+ "tools": [
298
+ {
299
+ "name": "chainmem_ingest",
300
+ "description": "结链:将文本存储为链式记忆",
301
+ "inputSchema": {
302
+ "type": "object",
303
+ "properties": {
304
+ "text": {"type": "string", "description": "要结链的文本"},
305
+ "source": {"type": "string", "description": "来源会话"},
306
+ "tags": {"type": "string", "description": "标签(逗号分隔)"},
307
+ },
308
+ "required": ["text"],
309
+ },
310
+ },
311
+ {
312
+ "name": "chainmem_retrieve",
313
+ "description": "追溯:输入查询,还原完整记忆链(支持可选标签过滤)",
314
+ "inputSchema": {
315
+ "type": "object",
316
+ "properties": {
317
+ "query": {"type": "string", "description": "查询文本"},
318
+ "tags": {"type": "string",
319
+ "description": "可选,标签过滤(逗号分隔,OR 逻辑)"},
320
+ },
321
+ "required": ["query"],
322
+ },
323
+ },
324
+ {
325
+ "name": "chainmem_stats",
326
+ "description": "查看记忆统计",
327
+ "inputSchema": {
328
+ "type": "object",
329
+ "properties": {},
330
+ },
331
+ },
332
+ ]
333
+ })
334
+
335
+ elif method == "tools/call":
336
+ tool_name = req.get("params", {}).get("name")
337
+ arguments = req.get("params", {}).get("arguments", {})
338
+
339
+ if tool_name == "chainmem_ingest":
340
+ text = arguments.get("text", "")
341
+ source = arguments.get("source", "")
342
+ tags = [t.strip() for t in arguments.get("tags", "").split(",") if t.strip()]
343
+ try:
344
+ cm = get_cm()
345
+ chain = cm.ingest(text, source=source, tags=tags)
346
+ cm.retriever.rebuild_index()
347
+ send_response(req_id, {
348
+ "content": [{"type": "text",
349
+ "text": f"结链成功:{chain.node_count} 个节点,前缀「{chain.anchor_prefix}」"}]
350
+ })
351
+ except Exception as e:
352
+ send_error(req_id, -1, str(e))
353
+
354
+ elif tool_name == "chainmem_retrieve":
355
+ query = arguments.get("query", "")
356
+ tags_str = arguments.get("tags", "")
357
+ tag_list = [t.strip() for t in tags_str.split(",") if t.strip()]
358
+ cm = get_cm()
359
+ if rebuild_index:
360
+ cm.retriever.rebuild_index()
361
+ results = cm.retrieve(query, tags=tag_list or None)
362
+ if results:
363
+ full_text = "".join(results)
364
+ send_response(req_id, {
365
+ "content": [{"type": "text", "text": full_text}]
366
+ })
367
+ else:
368
+ send_response(req_id, {
369
+ "content": [{"type": "text", "text": "未找到匹配的记忆"}]
370
+ })
371
+
372
+ elif tool_name == "chainmem_stats":
373
+ cm = get_cm()
374
+ stats = cm.stats()
375
+ text = f"链总数: {stats['chains']}\n节点总数: {stats['nodes']}\n数据库: {stats['db_path']}"
376
+ send_response(req_id, {
377
+ "content": [{"type": "text", "text": text}]
378
+ })
379
+
380
+ else:
381
+ send_error(req_id, -32601, f"未知工具: {tool_name}")
382
+
383
+ elif method == "initialize":
384
+ send_response(req_id, {
385
+ "protocolVersion": "2025-11-25",
386
+ "capabilities": {"tools": {}},
387
+ "serverInfo": {"name": "chainmem", "version": "0.1.0"},
388
+ })
389
+
390
+ elif method == "notifications/initialized":
391
+ pass
392
+
393
+ else:
394
+ send_error(req_id, -32601, f"未知方法: {method}")
395
+
396
+
397
# Allow running this module directly as a script: start the Typer app.
if __name__ == "__main__":
    app()
@@ -0,0 +1,67 @@
1
+ """数据模型:ChainNode 和 Chain"""
2
+
3
+ from __future__ import annotations
4
+ from dataclasses import dataclass, field
5
+ from typing import Optional
6
+ import uuid
7
+ import numpy as np
8
+
9
+
10
@dataclass
class ChainNode:
    """Chain node — the smallest unit of memory.

    Holds one text chunk, the chain it belongs to, its position in that
    chain, and the ids of the neighbouring nodes.
    """
    id: str
    chain_id: str
    seq: int
    text: str
    embedding: np.ndarray | None = None  # shape=(d,), held in memory at runtime
    prev_id: str | None = None
    next_id: str | None = None

    @property
    def text_prefix(self) -> str:
        """Return the first three characters of the text (all of it if shorter)."""
        # Slicing already returns the whole string when it is shorter.
        return self.text[:3]
24
+
25
+
26
@dataclass
class Chain:
    """Chain — the metadata of one whole memory, plus its nodes."""
    id: str
    root_id: str
    leaf_id: str
    anchor_prefix: str
    node_count: int
    nodes: list[ChainNode] = field(default_factory=list)
    summary: str = ""
    source: str = ""
    tags: list[str] = field(default_factory=list)
    strength: float = 1.0

    @classmethod
    def from_nodes(cls, nodes: list[ChainNode]) -> "Chain":
        """Build a chain from an ordered, non-empty list of nodes.

        The first node supplies the chain id, root id and anchor prefix;
        the last node supplies the leaf id.

        Raises:
            ValueError: If *nodes* is empty.
        """
        if len(nodes) == 0:
            raise ValueError("Cannot create Chain from empty nodes list")
        head, tail = nodes[0], nodes[-1]
        return cls(
            id=head.chain_id,
            root_id=head.id,
            leaf_id=tail.id,
            anchor_prefix=head.text_prefix,
            node_count=len(nodes),
            nodes=nodes,
        )

    def full_text(self) -> str:
        """Return the node texts concatenated in list order."""
        pieces = [n.text for n in self.nodes]
        return "".join(pieces)

    def to_dict(self) -> dict:
        """Return a plain-dict summary of the chain, including its full text."""
        return dict(
            id=self.id,
            node_count=self.node_count,
            anchor_prefix=self.anchor_prefix,
            source=self.source,
            tags=self.tags,
            strength=self.strength,
            full_text=self.full_text(),
        )
chainmem/core/node.py ADDED
@@ -0,0 +1,5 @@
1
+ """ChainNode 和 Chain 数据模型"""
2
+
3
+ from chainmem.core import ChainNode, Chain
4
+
5
# Names re-exported by this compatibility module; both come from chainmem.core.
__all__ = ["ChainNode", "Chain"]
@@ -0,0 +1,149 @@
1
+ """结链管道:文本 → 切块 → 嵌入 → 存储"""
2
+
3
+ from __future__ import annotations
4
+ import re
5
+ import uuid
6
+
7
+ import numpy as np
8
+ from sentence_transformers import SentenceTransformer
9
+
10
+ from chainmem.core.node import ChainNode, Chain
11
+ from chainmem.store.sqlite_store import SQLiteStore
12
+
13
+
14
# Embedding model shared across the whole process (loaded once, then reused).
# ``_MODEL`` stays None until the model is first requested.
_MODEL: SentenceTransformer | None = None
# Name of the model to load when ``_MODEL`` is None.
_MODEL_NAME: str = "all-MiniLM-L6-v2"
17
+
18
+
19
def _get_model(model_name: str | None = None) -> SentenceTransformer:
    """Return the shared embedding model, loading it when necessary.

    If *model_name* is given and differs from the current model name, that
    model is loaded and becomes the shared one. Otherwise the cached model
    is returned, loading the current model name first if nothing is cached.
    """
    global _MODEL, _MODEL_NAME

    switching = model_name is not None and model_name != _MODEL_NAME
    if switching:
        # Switch to the explicitly requested model.
        _MODEL = SentenceTransformer(model_name)
        _MODEL_NAME = model_name
        return _MODEL

    if _MODEL is None:
        _MODEL = SentenceTransformer(_MODEL_NAME)
    return _MODEL
28
+
29
+
30
def set_model(model_name: str):
    """Select a different embedding model.

    Only records the name and drops the cached model; the new model is
    loaded by the next ``_get_model`` call.
    """
    global _MODEL_NAME, _MODEL
    _MODEL_NAME, _MODEL = model_name, None
35
+
36
+
37
def chunk_text(text: str, max_chars: int = 18) -> list[str]:
    """Split long text into short phrase chunks at natural pauses.

    Rules:
        1. Always split after sentence-ending punctuation and newlines.
        2. Always split after commas, enumeration commas and colons.
        3. Hard-cut any piece longer than ``max_chars``.

    Chunks that end up too short are merged by ``merge_short_chunks``, so a
    returned chunk may be longer than ``max_chars``.
    """
    # 1. Split after terminal punctuation (the punctuation stays attached).
    sentences = (seg.strip() for seg in re.split(r'(?<=[。!?;…\n])\s*', text))

    pieces: list[str] = []
    for sentence in sentences:
        if not sentence:
            continue
        # 2. Split each sentence again after comma-like punctuation.
        for clause in re.split(r'(?<=[,、:])\s*', sentence):
            clause = clause.strip()
            if not clause:
                continue
            if len(clause) > max_chars:
                # Too long: cut into consecutive max_chars-sized slices.
                pieces.extend(
                    clause[start:start + max_chars]
                    for start in range(0, len(clause), max_chars)
                )
            else:
                pieces.append(clause)

    # 3. Merge chunks that are too short (avoids degenerate
    #    sentence-transformers embeddings).
    merged = merge_short_chunks(pieces)
    return [piece for piece in merged if piece]
66
+
67
+
68
def merge_short_chunks(chunks: list[str], min_chars: int = 6) -> list[str]:
    """Merge chunks that are too short into a neighbouring chunk.

    sentence-transformers produces degenerate embeddings for very short
    texts (different texts map to the same vector, cosine = 1.0), so short
    chunks are glued onto an adjacent chunk.

    A chunk of at most ``min_chars`` characters is appended to the chunk
    before it. Short text at the very start has nothing before it, so it is
    prepended to the following chunk instead; otherwise the root node of a
    chain could keep a degenerate embedding.

    Args:
        chunks: Chunks in reading order.
        min_chars: Chunks of this length or shorter are merged.

    Returns:
        The merged chunks in order. Input with zero or one chunk is
        returned unchanged. If all the text together is still short, a
        single chunk holding all of it is returned.
    """
    if len(chunks) <= 1:
        return chunks

    merged: list[str] = []
    lead = ""  # leading text that is still too short to stand alone
    for chunk in chunks:
        if not merged:
            # No chunk emitted yet: keep accumulating until long enough.
            lead += chunk
            if len(lead) > min_chars:
                merged.append(lead)
                lead = ""
        elif len(chunk) <= min_chars:
            # Merge into the previous chunk.
            merged[-1] += chunk
        else:
            merged.append(chunk)

    if lead:
        # The whole input never exceeded the threshold.
        merged.append(lead)
    return merged
85
+
86
+
87
class Ingester:
    """Ingester: turns a text into a stored chain."""

    def __init__(self, store: SQLiteStore):
        """Keep the store and fetch the shared embedding model."""
        self.store = store
        self.embedder = _get_model()

    def ingest(self, text: str, source: str = "", tags: list[str] | None = None) -> Chain:
        """Chunk *text*, embed the chunks, link them and persist the chain.

        Args:
            text: Text to ingest.
            source: Source label saved with the chain.
            tags: Tags saved with the chain; ``None`` means no tags.

        Returns:
            The in-memory chain, including nodes with their embeddings.

        Raises:
            ValueError: If chunking yields no chunks.
        """
        pieces = chunk_text(text)
        if not pieces:
            raise ValueError("Empty text after chunking")

        chain_id = str(uuid.uuid4())

        # 1. Embed every chunk in one batch.
        vectors = self.embedder.encode(pieces, normalize_embeddings=True)

        # 2. Create the nodes (seq starts at 1) ...
        nodes: list[ChainNode] = [
            ChainNode(
                id=str(uuid.uuid4()),
                chain_id=chain_id,
                seq=position,
                text=piece,
                embedding=vector,
            )
            for position, (piece, vector) in enumerate(zip(pieces, vectors), start=1)
        ]
        # ... and link every neighbouring pair in both directions.
        for earlier, later in zip(nodes, nodes[1:]):
            earlier.next_id = later.id
            later.prev_id = earlier.id

        head, tail = nodes[0], nodes[-1]

        # 3. Persist the chain record, then each node.
        # NOTE(review): embeddings are not passed to save_node here;
        # presumably the index recomputes them from text — confirm.
        self.store.save_chain(
            chain_id=chain_id,
            anchor_prefix=head.text_prefix,
            root_id=head.id,
            leaf_id=tail.id,
            node_count=len(nodes),
            source=source,
            tags=tags or [],
        )
        for node in nodes:
            self.store.save_node(
                node_id=node.id,
                chain_id=node.chain_id,
                seq=node.seq,
                text=node.text,
                prev_id=node.prev_id,
                next_id=node.next_id,
            )

        result = Chain.from_nodes(nodes)
        result.source = source
        result.tags = tags or []
        return result