knowledge-graph-kit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,419 @@
1
+ """
2
+ extractor.py — Schema 引导的实体与关系抽取
3
+
4
+ 流程:
5
+ 1. 解析 schema.txt 得到本体定义
6
+ 2. 对每个 Chunk,构造含 schema 约束的 prompt
7
+ 3. 调用 LLM(OpenAI function calling)抽取实体和关系
8
+ 4. 输出结构化结果(Pydantic 模型)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import os
15
+ import sys
16
+ from pathlib import Path
17
+ from typing import Any, Optional
18
+
19
+ from openai import OpenAI
20
+ from pydantic import BaseModel, Field
21
+
22
# Windows terminals: switch stdout to UTF-8 so the non-ASCII status output
# printed by this module can be encoded.
if sys.platform == "win32":
    try:
        sys.stdout.reconfigure(encoding="utf-8")  # type: ignore[attr-defined]
    except Exception:
        # Best-effort only: any failure is ignored and stdout keeps its
        # existing encoding.
        pass

# ── Lazily created client ────────────────────────────────────
# Shared OpenAI client; stays None until configure() or _get_client() builds it.
_client_instance: OpenAI | None = None
# Model name used by extract_from_chunk(); overwritten by configure() and by
# _get_client() when it creates the client.
_MODEL: str = "gpt-4o-mini"
32
+
33
+
34
def configure(api_key: str, base_url: str | None = None, model: str = "gpt-4o-mini") -> None:
    """Explicitly configure the shared OpenAI client (for programmatic use).

    Args:
        api_key: API key handed to the OpenAI client.
        base_url: Optional base URL handed to the OpenAI client.
        model: Model name stored for later chat-completion calls.
    """
    global _client_instance, _MODEL
    client_options = {"api_key": api_key, "base_url": base_url, "timeout": 90}
    new_client = OpenAI(**client_options)
    # Publish both values only after the client was constructed successfully.
    _client_instance, _MODEL = new_client, model
39
+
40
+
41
def _get_client() -> OpenAI:
    """Return the shared OpenAI client, creating it from environment variables on first use.

    Reads OPENAI_API_KEY (required), OPENAI_BASE_URL (optional) and
    LLM_MODEL_NAME (defaults to "gpt-4o-mini") when no client exists yet.

    Raises:
        RuntimeError: If no client exists and OPENAI_API_KEY is unset or empty.
    """
    global _client_instance, _MODEL

    # Fast path: a client was already created here or via configure().
    if _client_instance is not None:
        return _client_instance

    api_key = os.getenv("OPENAI_API_KEY", "")
    if not api_key:
        raise RuntimeError(
            "OPENAI_API_KEY 未设置。请在 .env 中配置或调用 configure()。"
        )

    endpoint = os.getenv("OPENAI_BASE_URL")
    model_name = os.getenv("LLM_MODEL_NAME", "gpt-4o-mini")
    _client_instance = OpenAI(api_key=api_key, base_url=endpoint, timeout=90)
    _MODEL = model_name
    return _client_instance
55
+
56
+
57
+ # ═══════════════════════════════════════════════════════════
58
+ # 1. 结构化输出 Schema(Pydantic)
59
+ # ═══════════════════════════════════════════════════════════
60
+
61
class Entity(BaseModel):
    """抽取出的实体"""
    # One entity extracted from a chunk.
    # NOTE(review): the class docstring and the Field descriptions are left
    # verbatim because pydantic may surface them in the generated JSON schema
    # — confirm before translating them.

    # Entity name as it should appear in the graph.
    name: str = Field(description="实体名称,如 '解构赋值'、'let关键字'")
    # Entity type; the description lists the type names the prompt allows.
    type: str = Field(description="实体类型: Topic / Concept / Skill / CodeExample / TextSegment / Question")
    # Free-form key/value attributes; defaults to a fresh empty dict per instance.
    properties: dict[str, Any] = Field(default_factory=dict, description="实体属性键值对,如 {'level': '基础', 'difficulty': '容易'}")
66
+
67
+
68
class Relation(BaseModel):
    """抽取出的关系"""
    # One directed relation between two extracted entities.
    # NOTE(review): the class docstring and the Field descriptions are left
    # verbatim because pydantic may surface them in the generated JSON schema
    # — confirm before translating them.

    # Name of the source entity (the description requires it to match an Entity.name).
    source: str = Field(description="源实体名称(必须与某个 Entity.name 一致)")
    # Name of the target entity.
    target: str = Field(description="目标实体名称")
    # Relation type; the description lists the type names the prompt allows.
    type: str = Field(description="关系类型: hasPrerequisite / isA / commonlyConfusedWith / teaches / hasExample / assessedBy")
    # Free-form relation attributes; defaults to a fresh empty dict per instance.
    properties: dict[str, Any] = Field(default_factory=dict, description="关系属性,如 {'confidence': 0.9}")
74
+
75
+
76
class ExtractionResult(BaseModel):
    """单个 Chunk 的抽取结果"""
    # Extraction output for a single chunk; also used by merge_results() as the
    # container for the merged, de-duplicated result.
    # NOTE(review): the class docstring and the Field descriptions are left
    # verbatim because pydantic may surface them in the generated JSON schema
    # — confirm before translating them.

    # Entities extracted from the text (required field, no default).
    entities: list[Entity] = Field(description="从文本中抽取的实体列表")
    # Relations extracted from the text (required field, no default).
    relations: list[Relation] = Field(description="从文本中抽取的关系列表")
80
+
81
+
82
+ # ═══════════════════════════════════════════════════════════
83
+ # 2. Schema 解析
84
+ # ═══════════════════════════════════════════════════════════
85
+
86
def parse_schema(schema_path: str | Path) -> dict[str, Any]:
    """Parse schema.txt into a structured ontology definition.

    The file is read line by line. Blank lines and lines starting with ``#``
    are skipped. A line containing "节点类型" (or "关系类型") together with a
    colon switches to the node (or relation) section.

    Node section:
        ``- Name: # description`` registers a node type; the name is the text
        before the first ``:``, ``#`` or ``(``.
        ``属性: [a, b, c]`` sets the property list of the most recently
        registered node type.

    Relation section:
        ``- name: # description`` or ``name: # description`` registers a
        relation type.
        ``domain: X`` / ``range: Y`` (optionally prefixed with ``- `` and/or
        combined on one comma-separated line) fill in the most recently
        registered relation type.

    Args:
        schema_path: Path of the schema text file (UTF-8).

    Returns:
        ``{"node_types": {name: {"description", "properties"}},
        "relation_types": {name: {"description", "domain", "range"}}}``.
    """
    text = Path(schema_path).read_text(encoding="utf-8")

    node_types: dict[str, Any] = {}
    relation_types: dict[str, Any] = {}
    schema_info: dict[str, Any] = {
        "node_types": node_types,
        "relation_types": relation_types,
    }

    current_section: Optional[str] = None
    current_node_type: Optional[str] = None
    current_rel_type: Optional[str] = None

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue

        # ── Section switches ──
        if "节点类型" in line and ":" in line:
            current_section = "node"
            continue
        if "关系类型" in line and ":" in line:
            current_section = "rel"
            continue

        if current_section == "node":
            # Type line, e.g. "- Topic: # 知识主题"
            if line.startswith("- "):
                content = line[2:]
                type_name = content.split(":")[0].split("#")[0].split("(")[0].strip()
                desc = content.split("#", 1)[1].strip() if "#" in content else ""
                node_types[type_name] = {"description": desc, "properties": []}
                # Track the owner explicitly instead of relying on dict order,
                # which is wrong when a type name is defined a second time.
                current_node_type = type_name
                continue

            # Property line, e.g. "属性: [name, level(认知层次), difficulty]"
            if "属性:" in line:
                props_str = line.split("属性:", 1)[1].strip()
                if props_str.startswith("[") and "]" in props_str:
                    props_str = props_str[1:props_str.index("]")]
                # Drop empty items so "[]" yields [] rather than [""].
                props = [p.strip() for p in props_str.split(",") if p.strip()]
                if current_node_type is not None:
                    node_types[current_node_type]["properties"] = props
                continue

        elif current_section == "rel":
            # Strip the optional "- " bullet once, so "- domain:" / "- range:"
            # are recognised as attribute lines and not as relation names.
            probe = line[2:] if line.startswith("- ") else line
            if ":" in probe:
                candidate_name = probe.split(":", 1)[0].strip()
                if candidate_name not in ("domain", "range"):
                    desc = probe.split("#", 1)[1].strip() if "#" in probe else ""
                    relation_types[candidate_name] = {
                        "description": desc,
                        "domain": "",
                        "range": "",
                    }
                    current_rel_type = candidate_name
                    continue

            # domain/range line, e.g. "domain: Topic, range: Topic"
            if current_rel_type:
                if "domain:" in line:
                    d_val = line.split("domain:", 1)[1].split(",")[0].strip()
                    relation_types[current_rel_type]["domain"] = d_val
                if "range:" in line:
                    r_val = line.split("range:", 1)[1].split(",")[0].strip()
                    relation_types[current_rel_type]["range"] = r_val

    return schema_info
178
+
179
+
180
+ # ═══════════════════════════════════════════════════════════
181
+ # 3. Prompt 构建
182
+ # ═══════════════════════════════════════════════════════════
183
+
184
def build_system_prompt(schema: dict[str, Any]) -> str:
    """Build the system prompt from a parsed schema.

    Args:
        schema: Mapping with "node_types" and "relation_types" in the shape
            returned by parse_schema().

    Returns:
        The prompt text: a bullet list of node types (each followed by its
        property list when it has one), a bullet list of relation types with
        their domain/range, and the fixed extraction rules.
    """

    def _describe_relation(rel_name: str, meta: dict[str, Any]) -> str:
        # One bullet per relation; empty parts are simply left out.
        pieces = [f" - {rel_name}"]
        if meta["description"]:
            pieces.append(f" ({meta['description']})")
        if meta["domain"]:
            pieces.append(f" domain: {meta['domain']}")
        if meta["range"]:
            pieces.append(f" range: {meta['range']}")
        return "".join(pieces)

    # Node type bullets, each optionally followed by a property line.
    node_lines: list[str] = []
    for type_name, meta in schema["node_types"].items():
        suffix = f" ({meta['description']})" if meta["description"] else ""
        node_lines.append(f" - {type_name}{suffix}")
        if meta["properties"]:
            prop_text = f" 属性: {', '.join(meta['properties'])}"
            node_lines.append(f" {prop_text}")

    rel_lines = [
        _describe_relation(rel_name, meta)
        for rel_name, meta in schema["relation_types"].items()
    ]

    node_block = "\n".join(node_lines)
    rel_block = "\n".join(rel_lines)

    return f"""你是一个知识抽取专家。请从教材文本中抽取符合以下本体定义的实体和关系。

## 节点类型(实体)
{node_block}

## 关系类型
{rel_block}

## 抽取规则
1. 只抽取文本中**明确提到**或**可以明确推导**的实体和关系
2. 实体名称要**精确**,使用教材中的术语(如 "let关键字" 而非 "let")
3. ⚠️ **实体类型必须严格从上面定义的节点类型中选择**,不允许发明任何新的类型(如 "Tool"、"Variable" 等)——如果拿不准,用 Concept
4. ⚠️ **关系类型必须严格从上面定义的关系类型中选择**,不允许发明新的关系名
5. ⚠️ **关系方向必须遵守 domain→range 定义**,例如 hasExample 的 domain 是 Topic,range 是 CodeExample,即 (知识点)→[hasExample]→(代码示例),**不能反过来**
6. ⚠️ **hasPrerequisite** 的方向是 (前置知识)→[hasPrerequisite]→(后续知识),即 A 是 B 的前置知识 → (A)→[hasPrerequisite]→(B)
7. CodeExample 的 codeSnippet 属性保存代码内容,language 属性保存编程语言
8. 对于 Concept(原子概念),properties 中可包含 level、difficulty 等
9. 注意层级:示例代码用 CodeExample,知识要点用 Concept 或 Topic
10. 不要重复抽取同一个实体(同名+同类型视为同一实体)
11. 习题中的题目用 Question 类型
"""
225
+
226
+
227
def build_user_prompt(chunk_id: str, title: str, text: str, examples: list[str]) -> str:
    """Build the user prompt for a single chunk.

    Args:
        chunk_id: Identifier shown in the "ID" line.
        title: Chunk title shown in the title line.
        text: Chunk body inserted under the text heading.
        examples: Example labels; when non-empty they are listed,
            comma-separated, on an extra line after the title.

    Returns:
        The formatted prompt text.
    """
    # The example line is only emitted when the chunk actually has examples.
    example_info = f"\n本块包含示例: {', '.join(examples)}" if examples else ""

    return f"""请从以下教材段落中抽取实体和关系。

## Chunk 信息
- ID: {chunk_id}
- 标题: {title}{example_info}

## 文本内容
{text}

请严格按照本体定义的类型进行抽取。
"""
245
+
246
+
247
+ # ═══════════════════════════════════════════════════════════
248
+ # 4. LLM 抽取
249
+ # ═══════════════════════════════════════════════════════════
250
+
251
def extract_from_chunk(
    chunk_id: str,
    title: str,
    text: str,
    examples: list[str],
    system_prompt: str,
    client: OpenAI | None = None,
) -> ExtractionResult:
    """Run LLM extraction (JSON mode) on a single chunk.

    The request is attempted up to three times, sleeping 3s and then 6s
    between attempts. Any exception (API error, empty content, invalid JSON,
    validation failure) triggers a retry.

    Args:
        chunk_id: Chunk identifier, used in the prompt and in failure output.
        title: Chunk title for the user prompt.
        text: Chunk text for the user prompt.
        examples: Example labels for the user prompt.
        system_prompt: Base system prompt; the JSON output contract is appended.
        client: Client to use; falls back to the shared lazily created client.

    Returns:
        The validated extraction result, or an empty result once all attempts
        have failed.
    """
    import time

    llm_client = client or _get_client()
    user_prompt = build_user_prompt(chunk_id, title, text, examples)

    # JSON mode requires the keyword "JSON" to appear in the system prompt.
    json_system = (
        system_prompt
        + "\n\n请以 JSON 格式输出,严格按照以下结构:\n{\n"
        " \"entities\": [{\"name\": \"...\", \"type\": \"Topic|Concept|Skill|CodeExample|TextSegment|Question\", \"properties\": {}}],\n"
        " \"relations\": [{\"source\": \"实体名称\", \"target\": \"实体名称\", \"type\": \"hasPrerequisite|isA|commonlyConfusedWith|teaches|hasExample|assessedBy\", \"properties\": {}}]\n"
        "}\n\n⚠️ type 字段只允许以上列表中的值,不能自己发明。"
    )
    messages = [
        {"role": "system", "content": json_system},
        {"role": "user", "content": user_prompt},
    ]

    max_attempts = 3
    for attempt_no in range(1, max_attempts + 1):
        try:
            response = llm_client.chat.completions.create(
                model=_MODEL,
                messages=messages,
                response_format={"type": "json_object"},
                temperature=0.1,
                max_tokens=4096,
            )
            raw = response.choices[0].message.content
            if not raw:
                raise ValueError("LLM 返回空内容")
            # Parse, then let pydantic validate the structure.
            return ExtractionResult.model_validate(json.loads(raw))
        except Exception as e:
            if attempt_no == max_attempts:
                print(f" ❌ [{chunk_id}] 抽取失败 (3次重试): {e}")
                return ExtractionResult(entities=[], relations=[])
            # Linear back-off: 3s after the first failure, 6s after the second.
            wait = attempt_no * 3
            print(f" ⏳ 重试 {attempt_no}/3 ({wait}s)... 错误: {e}")
            time.sleep(wait)
294
+
295
+
296
+ # ═══════════════════════════════════════════════════════════
297
+ # 5. 批量抽取 + 合并
298
+ # ═══════════════════════════════════════════════════════════
299
+
300
def merge_results(results: list[ExtractionResult]) -> ExtractionResult:
    """Merge per-chunk extraction results, dropping duplicates.

    Entities are keyed by (name, type) and relations by (source, target, type),
    each compared after stripping surrounding whitespace. The first occurrence
    of a key is kept unchanged and first-seen order is preserved.

    Args:
        results: Extraction results to merge, in processing order.

    Returns:
        A single result holding the de-duplicated entities and relations.
    """
    # Dicts keep insertion order, so setdefault() gives "first one wins"
    # while remembering the order in which keys were first seen.
    unique_entities: dict[tuple[str, str], Entity] = {}
    unique_relations: dict[tuple[str, str, str], Relation] = {}

    for chunk_result in results:
        for entity in chunk_result.entities:
            entity_key = (entity.name.strip(), entity.type.strip())
            unique_entities.setdefault(entity_key, entity)
        for relation in chunk_result.relations:
            relation_key = (
                relation.source.strip(),
                relation.target.strip(),
                relation.type.strip(),
            )
            unique_relations.setdefault(relation_key, relation)

    return ExtractionResult(
        entities=list(unique_entities.values()),
        relations=list(unique_relations.values()),
    )
321
+
322
+
323
+ # ═══════════════════════════════════════════════════════════
324
+ # 6. CLI 入口
325
+ # ═══════════════════════════════════════════════════════════
326
+
327
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse schema → load chunks → LLM extraction → write result.

    Usage:
        kg-extractor <txt_path>
        kg-extractor          (uses KG_EXTRACTOR_INPUT / KG_SCHEMA_PATH)

    Environment variables read here:
        KG_SCHEMA_PATH: Schema file overriding the packaged schema.txt.
        KG_EXTRACTOR_INPUT: Input text path when no CLI argument is given.
        EXTRACT_CHUNKS: Maximum number of chunks to process (default 3).
        KG_EXTRACTOR_OUTPUT: Output JSON path (default extraction_result.json).

    Args:
        argv: Command-line arguments without the program name; defaults to
            sys.argv[1:].

    Raises:
        SystemExit: With status 1 when no input path is supplied.
    """
    if argv is None:
        argv = sys.argv[1:]

    from dotenv import load_dotenv
    load_dotenv()

    # 1. Schema: environment override, otherwise the packaged resource.
    schema_path_str = os.environ.get("KG_SCHEMA_PATH")
    if schema_path_str:
        schema_path = Path(schema_path_str)
        print(f"📋 加载本体: {schema_path}")
        schema = parse_schema(schema_path)
    else:
        from importlib.resources import as_file, files

        resource = files("knowledge_graph_kit").joinpath("schema.txt")
        print(f"📋 加载本体: {resource}")
        # as_file() yields a real filesystem path even when the package
        # resource is not stored as a plain file.
        with as_file(resource) as bundled_schema:
            schema = parse_schema(bundled_schema)

    print(f" 节点类型: {', '.join(schema['node_types'])}")
    print(f" 关系类型: {', '.join(schema['relation_types'])}")

    # 2. Build the system prompt
    system_prompt = build_system_prompt(schema)

    # 3. Input text path: CLI argument or environment variable
    if len(argv) >= 1:
        txt_path = argv[0]
    elif os.environ.get("KG_EXTRACTOR_INPUT"):
        txt_path = os.environ["KG_EXTRACTOR_INPUT"]
    else:
        print("用法: kg-extractor <txt_path>", file=sys.stderr)
        print(" 或设置 KG_EXTRACTOR_INPUT 环境变量", file=sys.stderr)
        sys.exit(1)

    from .chunker import chunk_file  # type: ignore[import-untyped]

    print(f"\n📄 加载文本: {txt_path}")
    chunks = chunk_file(txt_path)
    print(f" 共 {len(chunks)} 个 Chunk")

    # Only process chunks with real content (skips bare headings).
    non_empty = [c for c in chunks if len(c.text.strip()) > 20]
    print(f" 有效 Chunk: {len(non_empty)} 个")

    # 4. Extract chunk by chunk (count is capped during testing).
    default_max_chunks = 3
    raw_max_chunks = os.environ.get("EXTRACT_CHUNKS", str(default_max_chunks))
    try:
        max_chunks = int(raw_max_chunks)
        if max_chunks < 0:
            # A negative slice bound would silently drop trailing chunks.
            raise ValueError(raw_max_chunks)
    except ValueError:
        print(
            f"⚠️ EXTRACT_CHUNKS={raw_max_chunks!r} 无效,使用默认值 {default_max_chunks}",
            file=sys.stderr,
        )
        max_chunks = default_max_chunks
    test_chunks = non_empty[:max_chunks]

    all_results: list[ExtractionResult] = []
    for i, c in enumerate(test_chunks):
        text_len = len(c.text)
        print(f"\n🔍 [{i+1}/{len(test_chunks)}] 抽取: [{c.chunk_id}] {c.title} ({text_len} 字符)")
        result = extract_from_chunk(
            chunk_id=c.chunk_id,
            title=c.title,
            # Cap the text sent to the model at 6000 characters.
            text=c.text[:6000],
            examples=[f"{a}-{b}" for a, b in c.examples],
            system_prompt=system_prompt,
        )
        print(f" → 实体: {len(result.entities)} 个, 关系: {len(result.relations)} 条")
        for e in result.entities:
            # Show at most the first three properties to keep the log short.
            props_str = f" {dict(list(e.properties.items())[:3])}" if e.properties else ""
            print(f" 实体: [{e.type}] {e.name}{props_str}")
        for rel in result.relations:
            print(f" 关系: ({rel.source})-[{rel.type}]->({rel.target})")
        all_results.append(result)

    # 5. Merge and de-duplicate
    merged = merge_results(all_results)
    print(f"\n{'='*50}")
    print("📊 合并结果:")
    print(f" 实体: {len(merged.entities)} 个")
    print(f" 关系: {len(merged.relations)} 条")

    # 6. Save the result
    out_path = Path(os.environ.get("KG_EXTRACTOR_OUTPUT", "extraction_result.json"))
    out_path.write_text(
        merged.model_dump_json(indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"✅ 已保存到: {out_path}")
416
+
417
+
418
+ if __name__ == "__main__":
419
+ main()
@@ -0,0 +1,273 @@
1
+ """
2
+ neo4j_writer.py — 将抽取结果写入 Neo4j
3
+
4
+ 流程:
5
+ 1. 读取清洗后的 extraction_result_clean.json
6
+ 2. 连接 Neo4j,创建约束
7
+ 3. 创建节点(实体类型 → Neo4j Label)
8
+ 4. 创建关系
9
+ 5. 添加 __Entity__ 标签(LlamaIndex 兼容)
10
+ 6. 打印统计信息验证
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import os
17
+ import sys
18
+ from pathlib import Path
19
+ from typing import Any
20
+
21
+ from neo4j import GraphDatabase
22
+
23
# Windows terminals: switch stdout to UTF-8 so the non-ASCII status output
# printed by this module can be encoded.
if sys.platform == "win32":
    try:
        sys.stdout.reconfigure(encoding="utf-8")  # type: ignore[attr-defined]
    except Exception:
        # Best-effort only: any failure is ignored and stdout keeps its
        # existing encoding.
        pass

# ── Neo4j configuration ─────────────────────────────────────
# Read once, when this module is imported.
# NOTE(review): main() calls load_dotenv() only later, so values that exist
# solely in a .env file are not visible to these three reads.
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USERNAME", "neo4j")
# NOTE(review): hard-coded fallback password — confirm it is meant for local
# development only.
NEO4J_PASS = os.getenv("NEO4J_PASSWORD", "12345678")

# ── Type → label mapping ────────────────────────────────────
# Entity type name → (primary Neo4j label, whether to also add the Topic label)
LABEL_MAP: dict[str, tuple[str, bool]] = {
    "Topic": ("Topic", False),
    "Concept": ("Concept", True),  # also tagged :Topic (inheritance)
    "Skill": ("Skill", True),  # also tagged :Topic
    "CodeExample": ("CodeExample", False),
    "TextSegment": ("TextSegment", False),
    "Question": ("Question", False),
    "Video": ("Video", False),
    "Student": ("Student", False),
    "StudyRecord": ("StudyRecord", False),
}
47
+
48
+
49
+ # ═══════════════════════════════════════════════════════════
50
+ # 1. Neo4j 操作
51
+ # ═══════════════════════════════════════════════════════════
52
+
53
class Neo4jWriter:
    """Thin wrapper around a Neo4j driver for writing extraction results.

    All nodes are written with the ``Entity`` label and are identified by
    their ``name`` property. The instance can be used as a context manager,
    in which case the driver is closed on exit.
    """

    def __init__(self, uri: str, user: str, password: str):
        """Create the underlying driver for ``uri`` with basic authentication."""
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def __enter__(self) -> "Neo4jWriter":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Return ``identifier`` as a backtick-quoted Cypher label/type name.

        Labels and relationship types cannot be passed as query parameters,
        so they are interpolated into the query text. Quoting them (and
        doubling embedded backticks) keeps values taken from the input file
        from breaking out of the identifier position.

        Raises:
            ValueError: If ``identifier`` is not a non-empty string.
        """
        if not isinstance(identifier, str) or not identifier:
            raise ValueError(f"invalid Cypher identifier: {identifier!r}")
        return "`" + identifier.replace("`", "``") + "`"

    # ── Constraints ───────────────────────────────────────

    def create_constraints(self):
        """Create the uniqueness constraint on Entity.name and the entity_type index."""
        with self.driver.session() as sess:
            # Uniqueness constraint shared by all entity nodes
            sess.run("CREATE CONSTRAINT IF NOT EXISTS FOR (e:Entity) REQUIRE e.name IS UNIQUE")
            # Index on the node type (optional)
            sess.run("CREATE INDEX IF NOT EXISTS FOR (e:Entity) ON (e.entity_type)")
        print(" ✅ 约束已创建")

    # ── Clear ─────────────────────────────────────────────

    def clear_all(self):
        """Delete every node and relationship in the database."""
        with self.driver.session() as sess:
            sess.run("MATCH (n) DETACH DELETE n")
        print(" 🧹 已清空所有数据")

    # ── Write nodes ───────────────────────────────────────

    def write_entities(self, entities: list[dict]) -> int:
        """MERGE one ``Entity`` node per input entity.

        Labels, ``entity_type`` and ``id`` are only set when the node is
        created; ``properties`` (minus ``None`` values) are merged into the
        node on every call.

        Args:
            entities: Dicts with "name", "type" and optional "properties".

        Returns:
            The number of entities processed.

        Raises:
            ValueError: If an entity type cannot be used as a label.
        """
        count = 0
        with self.driver.session() as sess:
            for e in entities:
                name = e["name"]
                etype = e["type"]
                # "properties" may be missing or null in the input.
                props = {k: v for k, v in (e.get("properties") or {}).items() if v is not None}

                # Unknown types fall back to using the type name as the label.
                main_label, add_topic_label = LABEL_MAP.get(etype, (etype, False))
                labels = [main_label]
                if add_topic_label:
                    labels.append("Topic")

                # Labels cannot be query parameters, so they are quoted and
                # interpolated; everything else is passed as a parameter.
                label_clause = ", ".join(
                    f"n:{self._quote_identifier(label)}" for label in labels
                )
                cypher = (
                    "MERGE (n:Entity {name: $name}) "
                    f"ON CREATE SET {label_clause}, "
                    "n.entity_type = $etype, n.id = $name "
                    "SET n += $props"
                )
                sess.run(cypher, name=name, etype=etype, props=props)
                count += 1

        return count

    # ── Write relationships ───────────────────────────────

    def write_relations(self, relations: list[dict]) -> int:
        """MERGE a relationship between two existing ``Entity`` nodes per input relation.

        A relation whose query fails, or whose source/target node does not
        exist, is reported on stdout and skipped.

        Args:
            relations: Dicts with "source", "target", "type" and optional
                "properties".

        Returns:
            The number of relations actually merged into the graph.
        """
        count = 0
        with self.driver.session() as sess:
            for rel in relations:
                src = rel["source"]
                tgt = rel["target"]
                rtype = rel["type"]
                props = {k: v for k, v in (rel.get("properties") or {}).items() if v is not None}

                try:
                    cypher = (
                        "MATCH (a:Entity {name: $src}) "
                        "MATCH (b:Entity {name: $tgt}) "
                        f"MERGE (a)-[r:{self._quote_identifier(rtype)}]->(b) "
                        "SET r += $props "
                        "RETURN count(r) AS c"
                    )
                    merged = sess.run(cypher, src=src, tgt=tgt, props=props).single()["c"]
                except Exception as ex:
                    print(f" ⚠️ 关系写入失败: ({src})-[{rtype}]->({tgt}) → {ex}")
                    continue

                if merged:
                    count += 1
                else:
                    # Both MATCH clauses must hit; otherwise nothing was merged.
                    print(f" ⚠️ 关系已跳过(端点实体不存在): ({src})-[{rtype}]->({tgt})")

        return count

    # ── LlamaIndex compatibility ──────────────────────────

    def ensure_llamaindex_compat(self):
        """Add the ``__Entity__`` label and set ``id = name`` on every Entity node."""
        with self.driver.session() as sess:
            sess.run("MATCH (e:Entity) SET e:`__Entity__`")
            sess.run("MATCH (e:Entity) SET e.id = e.name")
        print(" 🔄 LlamaIndex 兼容标签已添加")

    # ── Statistics ────────────────────────────────────────

    def print_stats(self):
        """Print the Entity node total, the entity_type distribution and per-type relationship counts."""
        with self.driver.session() as sess:
            # Node total
            total = sess.run("MATCH (n:Entity) RETURN count(n) AS c").single()["c"]
            # Distribution by entity_type
            labels = sess.run(
                "MATCH (n:Entity) RETURN n.entity_type AS type, count(*) AS cnt ORDER BY cnt DESC"
            ).data()
            # Relationship types known to the database
            rels = sess.run(
                "CALL db.relationshipTypes() YIELD relationshipType AS type "
                "RETURN type ORDER BY type"
            ).values()
            rel_counts = []
            for rt in rels:
                cnt = sess.run(
                    f"MATCH ()-[r:{self._quote_identifier(rt[0])}]->() RETURN count(r) AS c"
                ).single()["c"]
                rel_counts.append((rt[0], cnt))

        print(f"\n{'='*50}")
        print("📊 Neo4j 统计:")
        print(f" 节点总数: {total}")
        print(" 类型分布:")
        for row in labels:
            # entity_type can be null for nodes created elsewhere; str() keeps
            # the width format from raising on None.
            print(f" - {str(row['type']):12s}: {row['cnt']}")
        print(" 关系:")
        for name, cnt in rel_counts:
            print(f" - {name}: {cnt}")
188
+
189
+
190
+ # ═══════════════════════════════════════════════════════════
191
+ # 2. 主流程
192
+ # ═══════════════════════════════════════════════════════════
193
+
194
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: write an extraction result into Neo4j.

    Usage:
        kg-neo4j-writer <input.json>
        kg-neo4j-writer       (uses the KG_NEO4J_INPUT environment variable)

    Without an argument or KG_NEO4J_INPUT, the first existing file among
    extraction_result_clean.json and extraction_result.json in the current
    directory is used.

    NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD configure the connection;
    NEO4J_CLEAR set to 1/true/yes wipes the database before writing.

    Args:
        argv: Command-line arguments without the program name; defaults to
            sys.argv[1:].

    Raises:
        SystemExit: With status 1 when no input file can be determined.
    """
    if argv is None:
        argv = sys.argv[1:]

    from dotenv import load_dotenv
    load_dotenv()

    # The module-level NEO4J_* values were captured at import time, before
    # load_dotenv() ran. Re-read the environment so .env settings take effect,
    # falling back to the import-time values.
    uri = os.getenv("NEO4J_URI", NEO4J_URI)
    user = os.getenv("NEO4J_USERNAME", NEO4J_USER)
    password = os.getenv("NEO4J_PASSWORD", NEO4J_PASS)

    # 1. Load the data
    if len(argv) >= 1:
        src_path = Path(argv[0])
    elif os.environ.get("KG_NEO4J_INPUT"):
        src_path = Path(os.environ["KG_NEO4J_INPUT"])
    else:
        for candidate in ["extraction_result_clean.json", "extraction_result.json"]:
            p = Path(candidate)
            if p.exists():
                src_path = p
                break
        else:
            print("用法: kg-neo4j-writer <input.json>", file=sys.stderr)
            print(" 或设置 KG_NEO4J_INPUT 环境变量", file=sys.stderr)
            sys.exit(1)

    data = json.loads(src_path.read_text(encoding="utf-8"))
    entities: list[dict] = data["entities"]
    relations: list[dict] = data["relations"]

    print(f"📥 加载数据: {src_path.name}")
    print(f" 实体: {len(entities)} 个")
    print(f" 关系: {len(relations)} 条")

    # 2. Connect to Neo4j
    print(f"\n🔗 连接 Neo4j: {uri}")
    writer = Neo4jWriter(uri, user, password)
    # try/finally guarantees the driver is closed on every exit path,
    # including exceptions raised by any of the write steps.
    try:
        try:
            # Connectivity check
            writer.driver.verify_connectivity()
            print(" ✅ 连接成功")
        except Exception as e:
            print(f" ❌ 连接失败: {e}")
            return

        # 3. Create constraints
        print("\n📌 创建约束...")
        writer.create_constraints()

        # 4. Clear old data (optional, controlled by an environment variable)
        if os.environ.get("NEO4J_CLEAR", "").lower() in ("1", "true", "yes"):
            writer.clear_all()

        # 5. Write nodes
        print("\n📝 写入节点...")
        n_count = writer.write_entities(entities)
        print(f" ✅ 写入 {n_count} 个节点")

        # 6. Write relationships
        print("\n🔗 写入关系...")
        r_count = writer.write_relations(relations)
        print(f" ✅ 写入 {r_count} 条关系")

        # 7. LlamaIndex compatibility
        print("\n🔄 添加 LlamaIndex 兼容标签...")
        writer.ensure_llamaindex_compat()

        # 8. Statistics
        writer.print_stats()
    finally:
        writer.close()

    print("\n🎉 写入完成!")
270
+
271
+
272
+ if __name__ == "__main__":
273
+ main()