knowledge-graph-kit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- knowledge_graph_kit/__init__.py +42 -0
- knowledge_graph_kit/chunker.py +252 -0
- knowledge_graph_kit/entity_resolver.py +393 -0
- knowledge_graph_kit/extractor.py +419 -0
- knowledge_graph_kit/neo4j_writer.py +273 -0
- knowledge_graph_kit/schema.txt +38 -0
- knowledge_graph_kit-0.1.0.dist-info/METADATA +120 -0
- knowledge_graph_kit-0.1.0.dist-info/RECORD +12 -0
- knowledge_graph_kit-0.1.0.dist-info/WHEEL +5 -0
- knowledge_graph_kit-0.1.0.dist-info/entry_points.txt +5 -0
- knowledge_graph_kit-0.1.0.dist-info/licenses/LICENSE +21 -0
- knowledge_graph_kit-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,419 @@
|
|
|
1
|
+
"""
|
|
2
|
+
extractor.py — Schema 引导的实体与关系抽取
|
|
3
|
+
|
|
4
|
+
流程:
|
|
5
|
+
1. 解析 schema.txt 得到本体定义
|
|
6
|
+
2. 对每个 Chunk,构造含 schema 约束的 prompt
|
|
7
|
+
3. 调用 LLM(OpenAI JSON mode,response_format=json_object)抽取实体和关系
|
|
8
|
+
4. 输出结构化结果(Pydantic 模型)
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import os
|
|
15
|
+
import sys
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any, Optional
|
|
18
|
+
|
|
19
|
+
from openai import OpenAI
|
|
20
|
+
from pydantic import BaseModel, Field
|
|
21
|
+
|
|
22
|
+
# Windows 终端 UTF-8
|
|
23
|
+
if sys.platform == "win32":
|
|
24
|
+
try:
|
|
25
|
+
sys.stdout.reconfigure(encoding="utf-8") # type: ignore[attr-defined]
|
|
26
|
+
except Exception:
|
|
27
|
+
pass
|
|
28
|
+
|
|
29
|
+
# ── 惰性客户端 ────────────────────────────────────────────────
|
|
30
|
+
_client_instance: OpenAI | None = None
|
|
31
|
+
_MODEL: str = "gpt-4o-mini"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def configure(
    api_key: str,
    base_url: str | None = None,
    model: str = "gpt-4o-mini",
    timeout: float = 90,
) -> None:
    """Explicitly configure the shared OpenAI client (for programmatic use).

    Replaces any previously created client, so later calls to
    ``extract_from_chunk`` without an explicit ``client`` use these settings.

    Args:
        api_key: OpenAI API key; must be non-empty.
        base_url: Optional alternative API endpoint.
        model: Chat model name used for extraction requests.
        timeout: Request timeout in seconds passed to the OpenAI client.

    Raises:
        RuntimeError: if ``api_key`` is empty (same error type as ``_get_client``).
    """
    global _client_instance, _MODEL
    if not api_key:
        raise RuntimeError("api_key 不能为空。请传入有效的 OPENAI_API_KEY。")
    _client_instance = OpenAI(api_key=api_key, base_url=base_url, timeout=timeout)
    _MODEL = model
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _get_client() -> OpenAI:
    """Return the shared OpenAI client, creating it from env vars on first use.

    Reads ``OPENAI_API_KEY`` (required), ``OPENAI_BASE_URL`` (optional) and
    ``LLM_MODEL_NAME`` (optional). An unset *or empty* ``LLM_MODEL_NAME`` falls
    back to ``gpt-4o-mini`` (an empty ``LLM_MODEL_NAME=`` line in .env would
    otherwise send model ``""`` to the API).

    Raises:
        RuntimeError: if no client exists yet and ``OPENAI_API_KEY`` is unset/empty.
    """
    global _client_instance, _MODEL
    if _client_instance is not None:
        return _client_instance

    api_key = os.getenv("OPENAI_API_KEY", "")
    if not api_key:
        raise RuntimeError(
            "OPENAI_API_KEY 未设置。请在 .env 中配置或调用 configure()。"
        )
    base_url = os.getenv("OPENAI_BASE_URL")
    model = os.getenv("LLM_MODEL_NAME") or "gpt-4o-mini"
    _client_instance = OpenAI(api_key=api_key, base_url=base_url, timeout=90)
    _MODEL = model
    return _client_instance
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# ═══════════════════════════════════════════════════════════
|
|
58
|
+
# 1. 结构化输出 Schema(Pydantic)
|
|
59
|
+
# ═══════════════════════════════════════════════════════════
|
|
60
|
+
|
|
61
|
+
class Entity(BaseModel):
    """抽取出的实体"""

    # One extracted entity. The docstring above is kept verbatim because
    # pydantic uses it as the model description in the JSON schema.
    # Field order determines serialized key order; keep name/type/properties.
    name: str = Field(
        ...,
        description="实体名称,如 '解构赋值'、'let关键字'",
    )
    type: str = Field(
        ...,
        description="实体类型: Topic / Concept / Skill / CodeExample / TextSegment / Question",
    )
    properties: dict[str, Any] = Field(
        default_factory=dict,
        description="实体属性键值对,如 {'level': '基础', 'difficulty': '容易'}",
    )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Relation(BaseModel):
    """抽取出的关系"""

    # One extracted, directed relation between two entities referenced by
    # name. The docstring is kept verbatim (pydantic schema description).
    source: str = Field(
        ...,
        description="源实体名称(必须与某个 Entity.name 一致)",
    )
    target: str = Field(
        ...,
        description="目标实体名称",
    )
    type: str = Field(
        ...,
        description="关系类型: hasPrerequisite / isA / commonlyConfusedWith / teaches / hasExample / assessedBy",
    )
    properties: dict[str, Any] = Field(
        default_factory=dict,
        description="关系属性,如 {'confidence': 0.9}",
    )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class ExtractionResult(BaseModel):
    """单个 Chunk 的抽取结果"""

    # Both lists default to empty so an LLM reply that omits a key (e.g.
    # {"entities": [...]} when no relations were found) still validates,
    # instead of raising inside extract_from_chunk and triggering a retry.
    entities: list[Entity] = Field(default_factory=list, description="从文本中抽取的实体列表")
    relations: list[Relation] = Field(default_factory=list, description="从文本中抽取的关系列表")
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# ═══════════════════════════════════════════════════════════
|
|
83
|
+
# 2. Schema 解析
|
|
84
|
+
# ═══════════════════════════════════════════════════════════
|
|
85
|
+
|
|
86
|
+
def _comment_text(s: str) -> str:
    """Return the stripped text after the first ``#`` in *s*, or ``""``."""
    _, sep, after = s.partition("#")
    return after.strip() if sep else ""


def parse_schema(schema_path: str | Path) -> dict[str, Any]:
    """Parse ``schema.txt`` into a structured ontology definition.

    Blank lines and lines starting with ``#`` are skipped. A line containing
    ``节点类型`` and ``:`` starts the node-type section; a line containing
    ``关系类型`` and ``:`` starts the relation-type section.

    Node section:
        ``- Topic: # 知识主题``           -> node type ``Topic`` with a description
        ``属性: [name, level(认知层次)]``  -> properties of the current node type
    Relation section (the ``- `` bullet is optional on every line):
        ``- hasPrerequisite: # 描述``      -> relation type with a description
        ``domain: Topic, range: Topic``    -> domain/range of the current relation

    Returns:
        ``{"node_types": {name: {"description", "properties"}},
        "relation_types": {name: {"description", "domain", "range"}}}``
    """
    text = Path(schema_path).read_text(encoding="utf-8")

    node_types: dict[str, dict[str, Any]] = {}
    relation_types: dict[str, dict[str, Any]] = {}
    schema_info: dict[str, Any] = {
        "node_types": node_types,
        "relation_types": relation_types,
    }

    section: Optional[str] = None
    current_node_type: Optional[str] = None
    current_rel_type: Optional[str] = None

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue

        # ── section switches ──
        if "节点类型" in line and ":" in line:
            section = "node"
            continue
        if "关系类型" in line and ":" in line:
            section = "rel"
            continue

        if section == "node":
            if line.startswith("- "):
                content = line[2:]
                # Type name is whatever precedes the first ':', '#' or '('.
                type_name = content.split(":")[0].split("#")[0].split("(")[0].strip()
                node_types[type_name] = {
                    "description": _comment_text(content),
                    "properties": [],
                }
                # Track explicitly so a redefined type still receives its
                # following 属性 line (dict order would point elsewhere).
                current_node_type = type_name
                continue

            if "属性:" in line:
                props_str = line.split("属性:", 1)[1].strip()
                if props_str.startswith("[") and "]" in props_str:
                    props_str = props_str[1:props_str.index("]")]
                # Drop empty entries so "属性: []" yields [] rather than [""].
                props = [p.strip() for p in props_str.split(",") if p.strip()]
                if current_node_type is not None:
                    node_types[current_node_type]["properties"] = props
                continue

        elif section == "rel":
            # Strip the optional "- " bullet exactly once, so "- domain: X" is
            # treated as a domain line, not as a relation named "- domain".
            probe = line[2:] if line.startswith("- ") else line
            if ":" in probe:
                rel_name = probe.split(":")[0].strip()
                if rel_name not in ("domain", "range"):
                    relation_types[rel_name] = {
                        "description": _comment_text(probe),
                        "domain": "",
                        "range": "",
                    }
                    current_rel_type = rel_name
                    continue

            # domain/range line, e.g. "domain: Topic, range: Topic"
            if current_rel_type:
                if "domain:" in line:
                    d_val = line.split("domain:", 1)[1].split(",")[0].strip()
                    relation_types[current_rel_type]["domain"] = d_val
                if "range:" in line:
                    r_val = line.split("range:", 1)[1].split(",")[0].strip()
                    relation_types[current_rel_type]["range"] = r_val

    return schema_info
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# ═══════════════════════════════════════════════════════════
|
|
181
|
+
# 3. Prompt 构建
|
|
182
|
+
# ═══════════════════════════════════════════════════════════
|
|
183
|
+
|
|
184
|
+
def build_system_prompt(schema: dict[str, Any]) -> str:
    """Render the system prompt describing the ontology and extraction rules.

    Args:
        schema: Output of ``parse_schema`` with ``node_types`` and
            ``relation_types`` mappings.

    Returns:
        Prompt text listing every node type (with its properties), every
        relation type (with domain/range), then a fixed set of rules.
    """
    node_lines: list[str] = []
    for type_name, node in schema["node_types"].items():
        header = f" - {type_name}"
        if node["description"]:
            header += f" ({node['description']})"
        node_lines.append(header)
        if node["properties"]:
            node_lines.append(f"  属性: {', '.join(node['properties'])}")

    rel_lines = [
        " - "
        + rel_name
        + (f" ({rel['description']})" if rel["description"] else "")
        + (f" domain: {rel['domain']}" if rel["domain"] else "")
        + (f" range: {rel['range']}" if rel["range"] else "")
        for rel_name, rel in schema["relation_types"].items()
    ]

    nodes_text = "\n".join(node_lines)
    rels_text = "\n".join(rel_lines)
    return f"""你是一个知识抽取专家。请从教材文本中抽取符合以下本体定义的实体和关系。

## 节点类型(实体)
{nodes_text}

## 关系类型
{rels_text}

## 抽取规则
1. 只抽取文本中**明确提到**或**可以明确推导**的实体和关系
2. 实体名称要**精确**,使用教材中的术语(如 "let关键字" 而非 "let")
3. ⚠️ **实体类型必须严格从上面定义的节点类型中选择**,不允许发明任何新的类型(如 "Tool"、"Variable" 等)——如果拿不准,用 Concept
4. ⚠️ **关系类型必须严格从上面定义的关系类型中选择**,不允许发明新的关系名
5. ⚠️ **关系方向必须遵守 domain→range 定义**,例如 hasExample 的 domain 是 Topic,range 是 CodeExample,即 (知识点)→[hasExample]→(代码示例),**不能反过来**
6. ⚠️ **hasPrerequisite** 的方向是 (前置知识)→[hasPrerequisite]→(后续知识),即 A 是 B 的前置知识 → (A)→[hasPrerequisite]→(B)
7. CodeExample 的 codeSnippet 属性保存代码内容,language 属性保存编程语言
8. 对于 Concept(原子概念),properties 中可包含 level、difficulty 等
9. 注意层级:示例代码用 CodeExample,知识要点用 Concept 或 Topic
10. 不要重复抽取同一个实体(同名+同类型视为同一实体)
11. 习题中的题目用 Question 类型
"""
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def build_user_prompt(chunk_id: str, title: str, text: str, examples: list[str]) -> str:
    """Render the per-chunk user prompt.

    Args:
        chunk_id: Chunk identifier echoed into the prompt.
        title: Chunk title.
        text: Chunk body to extract from.
        examples: Example labels in the chunk; when non-empty they are listed
            on an extra line directly under the title.
    """
    header_lines = [f"- ID: {chunk_id}", f"- 标题: {title}"]
    if examples:
        header_lines.append("本块包含示例: " + ", ".join(examples))
    header = "\n".join(header_lines)
    return (
        "请从以下教材段落中抽取实体和关系。\n"
        "\n"
        "## Chunk 信息\n"
        f"{header}\n"
        "\n"
        "## 文本内容\n"
        f"{text}\n"
        "\n"
        "请严格按照本体定义的类型进行抽取。\n"
    )
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# ═══════════════════════════════════════════════════════════
|
|
248
|
+
# 4. LLM 抽取
|
|
249
|
+
# ═══════════════════════════════════════════════════════════
|
|
250
|
+
|
|
251
|
+
def extract_from_chunk(
    chunk_id: str,
    title: str,
    text: str,
    examples: list[str],
    system_prompt: str,
    client: OpenAI | None = None,
    *,
    model: str | None = None,
    max_attempts: int = 3,
) -> ExtractionResult:
    """Extract entities and relations from one chunk via the LLM (JSON mode).

    Args:
        chunk_id, title, text, examples: Chunk data rendered into the user prompt.
        system_prompt: Output of ``build_system_prompt``; a JSON-format
            instruction is appended (JSON mode requires "JSON" in the prompt).
        client: OpenAI client to use; defaults to the shared ``_get_client()``.
        model: Chat model name; defaults to the module-level ``_MODEL``.
        max_attempts: Total number of attempts (at least 1). Any exception
            (API error, empty reply, invalid JSON, validation error) is
            retried after a linear back-off of 3s, 6s, ...

    Returns:
        The validated result, or an empty ``ExtractionResult`` when every
        attempt failed.
    """
    import time  # only needed for the retry back-off

    llm_client = client or _get_client()
    # Read _MODEL only after _get_client(), which may set it from LLM_MODEL_NAME.
    model_name = model or _MODEL
    attempts = max(1, max_attempts)
    user_prompt = build_user_prompt(chunk_id, title, text, examples)

    # JSON mode requires the word "JSON" to appear in the system prompt.
    json_system = system_prompt + "\n\n请以 JSON 格式输出,严格按照以下结构:\n{\n \"entities\": [{\"name\": \"...\", \"type\": \"Topic|Concept|Skill|CodeExample|TextSegment|Question\", \"properties\": {}}],\n \"relations\": [{\"source\": \"实体名称\", \"target\": \"实体名称\", \"type\": \"hasPrerequisite|isA|commonlyConfusedWith|teaches|hasExample|assessedBy\", \"properties\": {}}]\n}\n\n⚠️ type 字段只允许以上列表中的值,不能自己发明。"

    for attempt in range(attempts):
        try:
            response = llm_client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": json_system},
                    {"role": "user", "content": user_prompt},
                ],
                response_format={"type": "json_object"},
                temperature=0.1,
                max_tokens=4096,
            )
            raw = response.choices[0].message.content
            if not raw:
                raise ValueError("LLM 返回空内容")
            # Pydantic validates the decoded JSON against the result schema.
            return ExtractionResult.model_validate(json.loads(raw))
        except Exception as e:
            if attempt < attempts - 1:
                wait = (attempt + 1) * 3
                print(f" ⏳ 重试 {attempt + 1}/{attempts} ({wait}s)... 错误: {e}")
                time.sleep(wait)
            else:
                print(f" ❌ [{chunk_id}] 抽取失败 ({attempts}次重试): {e}")
    return ExtractionResult(entities=[], relations=[])
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# ═══════════════════════════════════════════════════════════
|
|
297
|
+
# 5. 批量抽取 + 合并
|
|
298
|
+
# ═══════════════════════════════════════════════════════════
|
|
299
|
+
|
|
300
|
+
def merge_results(results: list[ExtractionResult]) -> ExtractionResult:
    """Merge per-chunk extraction results, dropping duplicates.

    Entities are keyed by ``(name, type)`` and relations by
    ``(source, target, type)`` after stripping surrounding whitespace; the
    first occurrence wins. The emitted objects carry the *stripped* values,
    so relation endpoints agree with the kept entity names (node matching by
    exact ``name``, as done when writing to Neo4j, would otherwise miss).
    Input objects are copied, not mutated.
    """
    seen_entities: set[tuple[str, str]] = set()  # (name, type)
    seen_relations: set[tuple[str, str, str]] = set()  # (source, target, type)

    all_entities: list[Entity] = []
    all_relations: list[Relation] = []

    for r in results:
        for e in r.entities:
            name, etype = e.name.strip(), e.type.strip()
            if (name, etype) in seen_entities:
                continue
            seen_entities.add((name, etype))
            all_entities.append(e.model_copy(update={"name": name, "type": etype}))
        for rel in r.relations:
            key = (rel.source.strip(), rel.target.strip(), rel.type.strip())
            if key in seen_relations:
                continue
            seen_relations.add(key)
            src, tgt, rtype = key
            all_relations.append(
                rel.model_copy(update={"source": src, "target": tgt, "type": rtype})
            )

    return ExtractionResult(entities=all_entities, relations=all_relations)
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
# ═══════════════════════════════════════════════════════════
|
|
324
|
+
# 6. CLI 入口
|
|
325
|
+
# ═══════════════════════════════════════════════════════════
|
|
326
|
+
|
|
327
|
+
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse schema → load chunks → LLM extraction → write JSON.

    Usage:
        kg-extractor <txt_path>
        kg-extractor            (uses KG_EXTRACTOR_INPUT / KG_SCHEMA_PATH env vars)

    Environment:
        KG_SCHEMA_PATH       schema file overriding the packaged schema.txt
        KG_EXTRACTOR_INPUT   input text path when no CLI argument is given
        EXTRACT_CHUNKS       max non-empty chunks to process (default 3;
                             0 or a negative value processes all chunks)
        KG_EXTRACTOR_OUTPUT  output JSON path (default extraction_result.json)
    """
    from contextlib import nullcontext
    from importlib.resources import as_file, files

    from dotenv import load_dotenv

    if argv is None:
        argv = sys.argv[1:]

    load_dotenv()

    # 1. Schema: explicit override, else the packaged schema.txt. as_file()
    #    yields a real filesystem path even for zip installs, where files()
    #    returns a Traversable that Path() cannot wrap.
    schema_path_str = os.environ.get("KG_SCHEMA_PATH")
    schema_source = (
        nullcontext(Path(schema_path_str))
        if schema_path_str
        else as_file(files("knowledge_graph_kit").joinpath("schema.txt"))
    )
    with schema_source as schema_path:
        print(f"📋 加载本体: {schema_path}")
        schema = parse_schema(schema_path)

    node_types = list(schema["node_types"])
    rel_types = list(schema["relation_types"])
    print(f" 节点类型: {', '.join(node_types)}")
    print(f" 关系类型: {', '.join(rel_types)}")

    # 2. The system prompt is shared by every chunk.
    system_prompt = build_system_prompt(schema)

    # 3. Input text: CLI argument, else environment variable.
    if argv:
        txt_path = argv[0]
    elif os.environ.get("KG_EXTRACTOR_INPUT"):
        txt_path = os.environ["KG_EXTRACTOR_INPUT"]
    else:
        print("用法: kg-extractor <txt_path>", file=sys.stderr)
        print(" 或设置 KG_EXTRACTOR_INPUT 环境变量", file=sys.stderr)
        sys.exit(1)

    from .chunker import chunk_file  # type: ignore[import-untyped]

    print(f"\n📄 加载文本: {txt_path}")
    chunks = chunk_file(txt_path)
    print(f" 共 {len(chunks)} 个 Chunk")

    # Skip chunks that are (almost) only a heading.
    non_empty = [c for c in chunks if len(c.text.strip()) > 20]
    print(f" 有效 Chunk: {len(non_empty)} 个")

    # 4. Extract chunk by chunk (limited by EXTRACT_CHUNKS; <= 0 means all).
    max_chunks = int(os.environ.get("EXTRACT_CHUNKS", "3"))
    test_chunks = non_empty if max_chunks <= 0 else non_empty[:max_chunks]
    if len(test_chunks) < len(non_empty):
        print(f" ⚠️ 仅处理前 {len(test_chunks)} 个 Chunk (设置 EXTRACT_CHUNKS=0 处理全部)")

    max_chars = 6000  # per-chunk text limit sent to the LLM
    all_results: list[ExtractionResult] = []
    for i, c in enumerate(test_chunks, start=1):
        text_len = len(c.text)
        print(f"\n🔍 [{i}/{len(test_chunks)}] 抽取: [{c.chunk_id}] {c.title} ({text_len} 字符)")
        if text_len > max_chars:
            print(f" ✂️ 文本过长,截断为前 {max_chars} 字符")
        result = extract_from_chunk(
            chunk_id=c.chunk_id,
            title=c.title,
            text=c.text[:max_chars],
            examples=[f"{a}-{b}" for a, b in c.examples],
            system_prompt=system_prompt,
        )
        print(f" → 实体: {len(result.entities)} 个, 关系: {len(result.relations)} 条")
        for e in result.entities:
            props_str = f" {dict(list(e.properties.items())[:3])}" if e.properties else ""
            print(f" 实体: [{e.type}] {e.name}{props_str}")
        for rel in result.relations:
            print(f" 关系: ({rel.source})-[{rel.type}]->({rel.target})")
        all_results.append(result)

    # 5. Merge and deduplicate.
    merged = merge_results(all_results)
    print(f"\n{'='*50}")
    print(f"📊 合并结果:")
    print(f" 实体: {len(merged.entities)} 个")
    print(f" 关系: {len(merged.relations)} 条")

    # 6. Save.
    out_path = Path(os.environ.get("KG_EXTRACTOR_OUTPUT", "extraction_result.json"))
    out_path.write_text(
        merged.model_dump_json(indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"✅ 已保存到: {out_path}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
"""
|
|
2
|
+
neo4j_writer.py — 将抽取结果写入 Neo4j
|
|
3
|
+
|
|
4
|
+
流程:
|
|
5
|
+
1. 读取抽取结果 JSON(命令行参数或 KG_NEO4J_INPUT 指定;默认依次尝试 extraction_result_clean.json、extraction_result.json)
|
|
6
|
+
2. 连接 Neo4j,创建约束
|
|
7
|
+
3. 创建节点(实体类型 → Neo4j Label)
|
|
8
|
+
4. 创建关系
|
|
9
|
+
5. 添加 __Entity__ 标签(LlamaIndex 兼容)
|
|
10
|
+
6. 打印统计信息验证
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from neo4j import GraphDatabase
|
|
22
|
+
|
|
23
|
+
if sys.platform == "win32":
|
|
24
|
+
try:
|
|
25
|
+
sys.stdout.reconfigure(encoding="utf-8") # type: ignore[attr-defined]
|
|
26
|
+
except Exception:
|
|
27
|
+
pass
|
|
28
|
+
|
|
29
|
+
# ── Neo4j connection settings ───────────────────────────────
# Local-development defaults, each overridable by the environment variable
# named in the call. NOTE: evaluated once at import time.
# NOTE(review): the hard-coded default password is only suitable for local use.
NEO4J_URI: str = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER: str = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASS: str = os.environ.get("NEO4J_PASSWORD", "12345678")
|
|
33
|
+
|
|
34
|
+
# ── Entity type → Neo4j label mapping ─────────────────────────
# Maps an entity type name to (primary Neo4j label, whether the node also gets
# the :Topic label). Concept and Skill are Topic subclasses in the ontology,
# so they carry both labels; every other type maps to a label of its own name.
LABEL_MAP: dict[str, tuple[str, bool]] = {
    type_name: (type_name, type_name in ("Concept", "Skill"))
    for type_name in (
        "Topic",
        "Concept",
        "Skill",
        "CodeExample",
        "TextSegment",
        "Question",
        "Video",
        "Student",
        "StudyRecord",
    )
}
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# ═══════════════════════════════════════════════════════════
|
|
50
|
+
# 1. Neo4j 操作
|
|
51
|
+
# ═══════════════════════════════════════════════════════════
|
|
52
|
+
|
|
53
|
+
class Neo4jWriter:
    """Thin wrapper around a Neo4j driver for writing extraction results.

    Can be used as a context manager (``with Neo4jWriter(...) as w:``) so the
    driver is closed even when a write raises.
    """

    # Node keys owned by the writer. LLM-supplied properties must not override
    # them: ``SET n += $props`` with a "name" property (the schema lists name
    # as a property) would rename the node and break relation matching.
    _RESERVED_NODE_KEYS = frozenset({"name", "id", "entity_type"})

    def __init__(self, uri: str, user: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        self.driver.close()

    def __enter__(self) -> Neo4jWriter:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    # ── helpers ─────────────────────────────────────────────

    @staticmethod
    def _quote(identifier: str) -> str:
        """Backtick-quote a label / relationship type for Cypher interpolation.

        Labels and relationship types cannot be query parameters, and here they
        come from LLM output; quoting (with embedded backticks doubled) blocks
        Cypher injection and allows names containing spaces or hyphens.
        """
        return "`" + identifier.replace("`", "``") + "`"

    @staticmethod
    def _is_storable_list(values: Any) -> bool:
        """True if *values* is a list/tuple of one primitive type (or empty)."""
        return all(isinstance(v, (str, int, float, bool)) for v in values) and (
            len({type(v) for v in values}) <= 1
        )

    @classmethod
    def _clean_props(cls, raw: Any, reserved: frozenset[str] = frozenset()) -> dict[str, Any]:
        """Return a Neo4j-storable copy of an LLM ``properties`` mapping.

        A missing/``null``/non-dict mapping yields ``{}``. ``None`` values and
        *reserved* keys are dropped. Maps and lists that are not homogeneous
        primitive lists are stored as JSON strings, since Neo4j rejects them
        as property values.
        """
        if not isinstance(raw, dict):
            return {}
        cleaned: dict[str, Any] = {}
        for key, value in raw.items():
            if value is None or key in reserved:
                continue
            if isinstance(value, dict) or (
                isinstance(value, (list, tuple)) and not cls._is_storable_list(value)
            ):
                value = json.dumps(value, ensure_ascii=False)
            cleaned[key] = value
        return cleaned

    # ── constraints ─────────────────────────────────────────

    def create_constraints(self):
        """Create a unique constraint on :Entity(name) and an index on entity_type."""
        with self.driver.session() as sess:
            # Unique name shared by all entity nodes (MERGE key below).
            sess.run("CREATE CONSTRAINT IF NOT EXISTS FOR (e:Entity) REQUIRE e.name IS UNIQUE")
            sess.run("CREATE INDEX IF NOT EXISTS FOR (e:Entity) ON (e.entity_type)")
        print(" ✅ 约束已创建")

    # ── clear ───────────────────────────────────────────────

    def clear_all(self):
        """Delete every node and relationship in the database (not only :Entity)."""
        with self.driver.session() as sess:
            sess.run("MATCH (n) DETACH DELETE n")
        print(" 🧹 已清空所有数据")

    # ── nodes ───────────────────────────────────────────────

    def write_entities(self, entities: list[dict]) -> int:
        """MERGE one :Entity node per entity (keyed by name); return how many succeeded.

        The type is mapped to a label via ``LABEL_MAP`` (unknown types use the
        type string itself); Concept/Skill also get :Topic. Labels, ``entity_type``
        and ``id`` are set only when the node is created; properties are merged
        on every call. A failing entity is reported and skipped.
        """
        count = 0
        with self.driver.session() as sess:
            for e in entities:
                name = e["name"]
                etype = e["type"]
                props = self._clean_props(e.get("properties"), self._RESERVED_NODE_KEYS)

                main_label, add_topic_label = LABEL_MAP.get(etype, (etype, False))
                labels = [main_label, "Topic"] if add_topic_label else [main_label]
                label_set = ", ".join(f"n:{self._quote(lbl)}" for lbl in labels)
                cypher = (
                    "MERGE (n:Entity {name: $name}) "
                    f"ON CREATE SET {label_set}, n.entity_type = $etype, n.id = $name "
                    "SET n += $props"
                )
                try:
                    sess.run(cypher, name=name, etype=etype, props=props)
                except Exception as ex:
                    print(f" ⚠️ 节点写入失败: [{etype}] {name} → {ex}")
                    continue
                count += 1
        return count

    # ── relationships ───────────────────────────────────────

    def write_relations(self, relations: list[dict]) -> int:
        """MERGE relationships between existing :Entity nodes; return how many were written.

        Endpoints are matched by ``name``. A relation whose endpoint node does
        not exist is reported and not counted; other failures are logged and
        skipped.
        """
        count = 0
        with self.driver.session() as sess:
            for rel in relations:
                src = rel["source"]
                tgt = rel["target"]
                rtype = rel["type"]
                props = self._clean_props(rel.get("properties"))
                cypher = (
                    "MATCH (a:Entity {name: $src}) "
                    "MATCH (b:Entity {name: $tgt}) "
                    f"MERGE (a)-[r:{self._quote(rtype)}]->(b) "
                    "SET r += $props "
                    "RETURN count(r) AS c"
                )
                try:
                    written = sess.run(cypher, src=src, tgt=tgt, props=props).single()["c"]
                except Exception as ex:
                    print(f" ⚠️ 关系写入失败: ({src})-[{rtype}]->({tgt}) → {ex}")
                    continue
                if written:
                    count += 1
                else:
                    print(f" ⚠️ 关系端点不存在,已跳过: ({src})-[{rtype}]->({tgt})")
        return count

    # ── LlamaIndex compatibility ────────────────────────────

    def ensure_llamaindex_compat(self):
        """Add the __Entity__ label and id = name to every :Entity (LlamaIndex PropertyGraphIndex)."""
        with self.driver.session() as sess:
            sess.run("MATCH (e:Entity) SET e:`__Entity__`, e.id = e.name")
        print(" 🔄 LlamaIndex 兼容标签已添加")

    # ── statistics ──────────────────────────────────────────

    def print_stats(self):
        """Print node count, entity_type distribution and relationship counts."""
        with self.driver.session() as sess:
            total = sess.run("MATCH (n:Entity) RETURN count(n) AS c").single()["c"]
            labels = sess.run(
                "MATCH (n:Entity) RETURN n.entity_type AS type, count(*) AS cnt ORDER BY cnt DESC"
            ).data()
            # One aggregated query instead of one unquoted query per type.
            rel_counts = [
                (row["type"], row["cnt"])
                for row in sess.run(
                    "MATCH ()-[r]->() RETURN type(r) AS type, count(*) AS cnt ORDER BY type"
                ).data()
            ]

        print(f"\n{'='*50}")
        print(f"📊 Neo4j 统计:")
        print(f" 节点总数: {total}")
        print(f" 类型分布:")
        for row in labels:
            # str(): entity_type may be missing (None), which ':12s' rejects.
            print(f" - {str(row['type']):12s}: {row['cnt']}")
        print(f" 关系:")
        for name, cnt in rel_counts:
            print(f" - {name}: {cnt}")
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
# ═══════════════════════════════════════════════════════════
|
|
191
|
+
# 2. 主流程
|
|
192
|
+
# ═══════════════════════════════════════════════════════════
|
|
193
|
+
|
|
194
|
+
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: write an extraction result JSON into Neo4j.

    Usage:
        kg-neo4j-writer <input.json>
        kg-neo4j-writer            (uses KG_NEO4J_INPUT, else extraction_result_clean.json
                                    or extraction_result.json in the working directory)

    NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD configure the connection (they
    may come from .env). NEO4J_CLEAR=1/true/yes wipes the database first.
    Exits with status 1 on a usage error or when the connection fails.
    """
    if argv is None:
        argv = sys.argv[1:]

    from dotenv import load_dotenv

    load_dotenv()

    # The module-level NEO4J_* constants were read at import time, before
    # load_dotenv() ran; re-read so values from .env actually take effect.
    uri = os.getenv("NEO4J_URI", NEO4J_URI)
    user = os.getenv("NEO4J_USERNAME", NEO4J_USER)
    password = os.getenv("NEO4J_PASSWORD", NEO4J_PASS)

    # 1. Load data
    if argv:
        src_path = Path(argv[0])
    elif os.environ.get("KG_NEO4J_INPUT"):
        src_path = Path(os.environ["KG_NEO4J_INPUT"])
    else:
        for candidate in ("extraction_result_clean.json", "extraction_result.json"):
            p = Path(candidate)
            if p.exists():
                src_path = p
                break
        else:
            print("用法: kg-neo4j-writer <input.json>", file=sys.stderr)
            print(" 或设置 KG_NEO4J_INPUT 环境变量", file=sys.stderr)
            sys.exit(1)

    data = json.loads(src_path.read_text(encoding="utf-8"))
    entities: list[dict] = data.get("entities", [])
    relations: list[dict] = data.get("relations", [])

    print(f"📥 加载数据: {src_path.name}")
    print(f" 实体: {len(entities)} 个")
    print(f" 关系: {len(relations)} 条")

    # 2. Connect; the finally clause closes the driver on every exit path.
    print(f"\n🔗 连接 Neo4j: {uri}")
    writer = Neo4jWriter(uri, user, password)
    try:
        try:
            writer.driver.verify_connectivity()
            print(" ✅ 连接成功")
        except Exception as e:
            print(f" ❌ 连接失败: {e}", file=sys.stderr)
            sys.exit(1)

        # 3. Constraints
        print("\n📌 创建约束...")
        writer.create_constraints()

        # 4. Optional wipe of existing data
        if os.environ.get("NEO4J_CLEAR", "").lower() in ("1", "true", "yes"):
            writer.clear_all()

        # 5. Nodes
        print("\n📝 写入节点...")
        n_count = writer.write_entities(entities)
        print(f" ✅ 写入 {n_count} 个节点")

        # 6. Relationships
        print("\n🔗 写入关系...")
        r_count = writer.write_relations(relations)
        print(f" ✅ 写入 {r_count} 条关系")

        # 7. LlamaIndex compatibility
        print("\n🔄 添加 LlamaIndex 兼容标签...")
        writer.ensure_llamaindex_compat()

        # 8. Statistics
        writer.print_stats()
    finally:
        writer.close()

    print("\n🎉 写入完成!")


if __name__ == "__main__":
    main()
|