graphrag-core 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphrag_core/__init__.py +138 -0
- graphrag_core/_cypher.py +15 -0
- graphrag_core/agents/__init__.py +6 -0
- graphrag_core/agents/context.py +16 -0
- graphrag_core/agents/orchestrator.py +34 -0
- graphrag_core/curation/__init__.py +6 -0
- graphrag_core/curation/detection.py +158 -0
- graphrag_core/curation/pipeline.py +39 -0
- graphrag_core/extraction/__init__.py +5 -0
- graphrag_core/extraction/engine.py +154 -0
- graphrag_core/graph/__init__.py +11 -0
- graphrag_core/graph/memory.py +118 -0
- graphrag_core/graph/neo4j.py +196 -0
- graphrag_core/ingestion/__init__.py +19 -0
- graphrag_core/ingestion/chunker.py +45 -0
- graphrag_core/ingestion/parsers.py +128 -0
- graphrag_core/ingestion/pipeline.py +36 -0
- graphrag_core/interfaces.py +229 -0
- graphrag_core/llm/__init__.py +9 -0
- graphrag_core/llm/anthropic.py +35 -0
- graphrag_core/models.py +247 -0
- graphrag_core/py.typed +0 -0
- graphrag_core/registry/__init__.py +5 -0
- graphrag_core/registry/matching.py +23 -0
- graphrag_core/registry/memory.py +81 -0
- graphrag_core/search/__init__.py +11 -0
- graphrag_core/search/fusion.py +34 -0
- graphrag_core/search/memory.py +104 -0
- graphrag_core/search/neo4j.py +186 -0
- graphrag_core/tools/__init__.py +6 -0
- graphrag_core/tools/core_tools.py +88 -0
- graphrag_core/tools/library.py +45 -0
- graphrag_core-0.2.0.dist-info/METADATA +182 -0
- graphrag_core-0.2.0.dist-info/RECORD +36 -0
- graphrag_core-0.2.0.dist-info/WHEEL +4 -0
- graphrag_core-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""graphrag-core: Domain-agnostic Graph RAG framework."""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.2.0"
|
|
4
|
+
|
|
5
|
+
from graphrag_core.interfaces import (
|
|
6
|
+
Agent,
|
|
7
|
+
ApprovalGateway,
|
|
8
|
+
Chunker,
|
|
9
|
+
DetectionLayer,
|
|
10
|
+
DocumentParser,
|
|
11
|
+
EmbeddingModel,
|
|
12
|
+
EntityRegistry,
|
|
13
|
+
ExtractionEngine,
|
|
14
|
+
GraphStore,
|
|
15
|
+
IngestionPipeline,
|
|
16
|
+
LLMClient,
|
|
17
|
+
LLMCurationLayer,
|
|
18
|
+
Orchestrator,
|
|
19
|
+
ReportRenderer,
|
|
20
|
+
SearchEngine,
|
|
21
|
+
)
|
|
22
|
+
from graphrag_core.ingestion import (
|
|
23
|
+
DocxParser,
|
|
24
|
+
MarkdownParser,
|
|
25
|
+
PdfParser,
|
|
26
|
+
TextParser,
|
|
27
|
+
TokenChunker,
|
|
28
|
+
)
|
|
29
|
+
from graphrag_core.extraction import LLMExtractionEngine
|
|
30
|
+
from graphrag_core.graph import InMemoryGraphStore
|
|
31
|
+
from graphrag_core.search import InMemorySearchEngine
|
|
32
|
+
from graphrag_core.registry import InMemoryEntityRegistry
|
|
33
|
+
from graphrag_core.curation import DeterministicDetectionLayer, CurationPipeline
|
|
34
|
+
from graphrag_core.tools import Tool, ToolLibrary, register_core_tools
|
|
35
|
+
from graphrag_core.agents import AgentContext, SequentialOrchestrator
|
|
36
|
+
from graphrag_core.models import (
|
|
37
|
+
AgentResult,
|
|
38
|
+
CurationIssue,
|
|
39
|
+
CurationReport,
|
|
40
|
+
DocumentChunk,
|
|
41
|
+
ExtractionResult,
|
|
42
|
+
GraphNode,
|
|
43
|
+
ImportRun,
|
|
44
|
+
KnownEntity,
|
|
45
|
+
NodeTypeDefinition,
|
|
46
|
+
OntologySchema,
|
|
47
|
+
PropertyDefinition,
|
|
48
|
+
RegistryMatch,
|
|
49
|
+
RelationshipTypeDefinition,
|
|
50
|
+
RenderConfig,
|
|
51
|
+
ReportData,
|
|
52
|
+
SearchResult,
|
|
53
|
+
ToolParameter,
|
|
54
|
+
ToolResult,
|
|
55
|
+
WorkflowResult,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
__all__ = [
|
|
59
|
+
# Protocols
|
|
60
|
+
"Agent",
|
|
61
|
+
"ApprovalGateway",
|
|
62
|
+
"Chunker",
|
|
63
|
+
"DetectionLayer",
|
|
64
|
+
"DocumentParser",
|
|
65
|
+
"EmbeddingModel",
|
|
66
|
+
"EntityRegistry",
|
|
67
|
+
"ExtractionEngine",
|
|
68
|
+
"GraphStore",
|
|
69
|
+
"IngestionPipeline",
|
|
70
|
+
"LLMClient",
|
|
71
|
+
"LLMCurationLayer",
|
|
72
|
+
"Orchestrator",
|
|
73
|
+
"ReportRenderer",
|
|
74
|
+
"SearchEngine",
|
|
75
|
+
# BB1 implementations
|
|
76
|
+
"DocxParser",
|
|
77
|
+
"MarkdownParser",
|
|
78
|
+
"PdfParser",
|
|
79
|
+
"TextParser",
|
|
80
|
+
"TokenChunker",
|
|
81
|
+
# BB2 implementations
|
|
82
|
+
"LLMExtractionEngine",
|
|
83
|
+
# BB3 implementations
|
|
84
|
+
"InMemoryGraphStore",
|
|
85
|
+
# BB4 implementations
|
|
86
|
+
"InMemorySearchEngine",
|
|
87
|
+
# BB5 implementations
|
|
88
|
+
"CurationPipeline",
|
|
89
|
+
"DeterministicDetectionLayer",
|
|
90
|
+
# BB6 implementations
|
|
91
|
+
"InMemoryEntityRegistry",
|
|
92
|
+
# BB7 implementations
|
|
93
|
+
"Tool",
|
|
94
|
+
"ToolLibrary",
|
|
95
|
+
"register_core_tools",
|
|
96
|
+
# BB8 implementations
|
|
97
|
+
"AgentContext",
|
|
98
|
+
"SequentialOrchestrator",
|
|
99
|
+
# Models
|
|
100
|
+
"AgentResult",
|
|
101
|
+
"CurationIssue",
|
|
102
|
+
"CurationReport",
|
|
103
|
+
"DocumentChunk",
|
|
104
|
+
"ExtractionResult",
|
|
105
|
+
"GraphNode",
|
|
106
|
+
"ImportRun",
|
|
107
|
+
"KnownEntity",
|
|
108
|
+
"NodeTypeDefinition",
|
|
109
|
+
"OntologySchema",
|
|
110
|
+
"PropertyDefinition",
|
|
111
|
+
"RegistryMatch",
|
|
112
|
+
"RelationshipTypeDefinition",
|
|
113
|
+
"RenderConfig",
|
|
114
|
+
"ReportData",
|
|
115
|
+
"SearchResult",
|
|
116
|
+
"ToolParameter",
|
|
117
|
+
"ToolResult",
|
|
118
|
+
"WorkflowResult",
|
|
119
|
+
]
|
|
120
|
+
|
|
121
|
+
# Optional backends: Neo4j and Anthropic live behind package extras, so their
# third-party drivers may not be installed. Each backend is imported inside
# its own try block and appended to __all__ only when the import succeeds;
# an ImportError simply means the extra is absent and is deliberately ignored.
try:
    from graphrag_core.graph import Neo4jGraphStore
    __all__.append("Neo4jGraphStore")
except ImportError:
    pass

try:
    from graphrag_core.search import Neo4jHybridSearch
    __all__.append("Neo4jHybridSearch")
except ImportError:
    pass

try:
    from graphrag_core.llm import AnthropicLLMClient
    __all__.append("AnthropicLLMClient")
except ImportError:
    pass
|
graphrag_core/_cypher.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Shared Cypher safety utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
# Pattern for a safe Cypher identifier: a leading letter/underscore followed
# by letters, digits, or underscores. Kept anchored for callers that use
# .match() directly on the compiled pattern.
SAFE_IDENTIFIER = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
# Upper bound for traversal depth parameters (consumed by query builders).
MAX_DEPTH = 10


def validate_identifier(value: str, kind: str) -> str:
    """Reject identifiers that could cause Cypher injection.

    Args:
        value: Candidate label, relationship type, or property name that will
            be interpolated into a Cypher query.
        kind: Human-readable description of what *value* is, used only in the
            error message.

    Returns:
        The validated *value*, unchanged, so the call can be used inline.

    Raises:
        ValueError: If *value* is not a plain ``[A-Za-z_][A-Za-z0-9_]*``
            identifier.
    """
    # fullmatch (not match) is required here: with re.match, the `$` anchor
    # still accepts a single trailing newline (e.g. "Person\n"), which would
    # let a newline through an injection guard. fullmatch demands the entire
    # string match the pattern with no such allowance.
    if not SAFE_IDENTIFIER.fullmatch(value):
        raise ValueError(f"Invalid {kind}: {value!r}")
    return value
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""BB8: Shared context for agent workflows."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class AgentContext:
    """Runtime context passed between agents in a workflow.

    Holds the shared backends an agent needs plus a mutable scratch dict
    that agents can use to pass intermediate results to later agents.
    """

    # Graph backend shared by all agents (duck-typed; presumably satisfies
    # the GraphStore protocol — confirm against callers).
    graph_store: Any
    # Tool library the agents may invoke (duck-typed).
    tool_library: Any
    # Search engine for retrieval during agent execution (duck-typed).
    search_engine: Any
    # Mutable per-workflow state; default_factory gives each context its own
    # dict so instances never share mutable state.
    workflow_state: dict[str, Any] = field(default_factory=dict)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""BB8: Sequential agent orchestrator."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from graphrag_core.agents.context import AgentContext
|
|
6
|
+
from graphrag_core.models import WorkflowResult
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SequentialOrchestrator:
    """Executes a list of agents one after another.

    Execution halts as soon as any agent reports failure; the returned
    :class:`WorkflowResult` then contains the results collected up to and
    including the failing agent, with ``success=False``.
    """

    async def run_workflow(
        self,
        workflow_id: str,
        agents: list,
        context: AgentContext,
    ) -> WorkflowResult:
        """Run *agents* in order against the shared *context*.

        Args:
            workflow_id: Identifier echoed back on the result.
            agents: Objects exposing ``async execute(context)`` returning a
                result with a ``success`` attribute.
            context: Shared runtime context passed to every agent.

        Returns:
            A WorkflowResult summarizing the run.
        """
        completed = []
        all_succeeded = True

        for current in agents:
            outcome = await current.execute(context)
            completed.append(outcome)
            if not outcome.success:
                # First failure aborts the remaining agents.
                all_succeeded = False
                break

        return WorkflowResult(
            workflow_id=workflow_id,
            success=all_succeeded,
            agent_results=completed,
        )
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""BB5: Deterministic detection layer for graph quality issues."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import uuid
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
|
|
8
|
+
from graphrag_core.interfaces import EntityRegistry, GraphStore
|
|
9
|
+
from graphrag_core.models import CurationIssue, GraphNode, OntologySchema
|
|
10
|
+
from graphrag_core.registry.matching import fuzzy_score
|
|
11
|
+
|
|
12
|
+
_PAIRWISE_CAP = 1000
|
|
13
|
+
_FUZZY_THRESHOLD = 0.7
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DeterministicDetectionLayer:
    """Finds duplicates, orphans, and schema violations without LLM calls.

    Duplicate detection uses one of two strategies:

    - with an ``EntityRegistry``: each node is resolved to a canonical
      registry entity and nodes are grouped by entity id (linear cost);
    - without one: node names within the same label are compared pairwise
      with ``fuzzy_score``, capped at ``_PAIRWISE_CAP`` nodes per label to
      avoid quadratic blow-up (a "skipped_detection" issue is emitted for
      oversized groups instead).
    """

    def __init__(self, entity_registry: EntityRegistry | None = None) -> None:
        # Optional registry; when None we fall back to pairwise fuzzy matching.
        self._registry = entity_registry

    async def detect(
        self, graph_store: GraphStore, schema: OntologySchema
    ) -> list[CurationIssue]:
        """Run all deterministic checks and return the combined issue list."""
        issues: list[CurationIssue] = []

        # Single fetch of all nodes, reused by the duplicate and orphan checks.
        nodes = await graph_store.list_nodes()

        issues.extend(await self._detect_duplicates(nodes))
        issues.extend(await self._detect_orphans(nodes, graph_store))
        issues.extend(await self._detect_schema_violations(graph_store, schema))

        return issues

    async def _detect_duplicates(self, nodes: list[GraphNode]) -> list[CurationIssue]:
        """Dispatch duplicate detection per label group."""
        issues: list[CurationIssue] = []

        # Only nodes sharing a label are candidates for being duplicates.
        groups: dict[str, list[GraphNode]] = defaultdict(list)
        for node in nodes:
            groups[node.label].append(node)

        for label, group in groups.items():
            if self._registry is not None:
                issues.extend(await self._detect_duplicates_with_registry(group))
            else:
                if len(group) > _PAIRWISE_CAP:
                    # Pairwise comparison is O(n^2); refuse it for large groups
                    # and surface an advisory issue instead.
                    issues.append(CurationIssue(
                        id=str(uuid.uuid4()),
                        issue_type="skipped_detection",
                        severity="warning",
                        affected_nodes=[],
                        suggested_action=f"Register entities in EntityRegistry for label '{label}' ({len(group)} nodes exceeds pairwise cap of {_PAIRWISE_CAP})",
                        auto_fixable=False,
                        source_layer="deterministic",
                    ))
                else:
                    issues.extend(self._detect_duplicates_pairwise(group))

        return issues

    async def _detect_duplicates_with_registry(
        self, nodes: list[GraphNode]
    ) -> list[CurationIssue]:
        """Group nodes by matched registry entity; >1 node per entity = duplicate."""
        issues: list[CurationIssue] = []
        entity_to_nodes: dict[str, list[str]] = defaultdict(list)

        for node in nodes:
            name = node.properties.get("name", "")
            if not name:
                # Nodes without a name cannot be looked up; skipped silently.
                continue
            matches = await self._registry.lookup(name, node.label, match_strategy="fuzzy")
            if matches:
                # Only the top-ranked match is used to bucket the node.
                entity_to_nodes[matches[0].entity_id].append(node.id)

        for entity_id, node_ids in entity_to_nodes.items():
            if len(node_ids) > 1:
                issues.append(CurationIssue(
                    id=str(uuid.uuid4()),
                    issue_type="duplicate",
                    severity="warning",
                    affected_nodes=node_ids,
                    suggested_action=f"Merge nodes {node_ids} — they match registry entity '{entity_id}'",
                    auto_fixable=False,
                    source_layer="deterministic",
                ))

        return issues

    def _detect_duplicates_pairwise(
        self, nodes: list[GraphNode]
    ) -> list[CurationIssue]:
        """Compare every name pair within a label group with fuzzy_score."""
        issues: list[CurationIssue] = []
        # NOTE(review): seen_pairs only gains entries after a hit, and the
        # i < j iteration visits each pair at most once per call, so this
        # guard is purely defensive in the current flow.
        seen_pairs: set[tuple[str, str]] = set()

        for i, a in enumerate(nodes):
            name_a = a.properties.get("name", "")
            if not name_a:
                continue
            for b in nodes[i + 1:]:
                name_b = b.properties.get("name", "")
                if not name_b:
                    continue
                # Canonical (min, max) ordering so (a, b) and (b, a) collide.
                pair = (min(a.id, b.id), max(a.id, b.id))
                if pair in seen_pairs:
                    continue
                score = fuzzy_score(name_a, name_b)
                if score >= _FUZZY_THRESHOLD:
                    seen_pairs.add(pair)
                    issues.append(CurationIssue(
                        id=str(uuid.uuid4()),
                        issue_type="duplicate",
                        severity="warning",
                        affected_nodes=[a.id, b.id],
                        suggested_action=f"Merge '{name_a}' and '{name_b}' (similarity: {score:.2f})",
                        auto_fixable=False,
                        source_layer="deterministic",
                    ))

        return issues

    async def _detect_orphans(
        self, nodes: list[GraphNode], graph_store: GraphStore
    ) -> list[CurationIssue]:
        """Flag nodes with no relationships (one get_related call per node)."""
        issues: list[CurationIssue] = []

        for node in nodes:
            related = await graph_store.get_related(node.id)
            if not related:
                issues.append(CurationIssue(
                    id=str(uuid.uuid4()),
                    issue_type="orphan",
                    severity="info",
                    affected_nodes=[node.id],
                    suggested_action=f"Node '{node.id}' ({node.label}) has no relationships",
                    auto_fixable=False,
                    source_layer="deterministic",
                ))

        return issues

    async def _detect_schema_violations(
        self, graph_store: GraphStore, schema: OntologySchema
    ) -> list[CurationIssue]:
        """Apply the schema to the store and translate its violations to issues."""
        # apply_schema is invoked before validate_schema; presumably the store
        # requires the schema to be registered before it can validate — confirm
        # against the GraphStore implementations.
        await graph_store.apply_schema(schema)
        violations = await graph_store.validate_schema()

        return [
            CurationIssue(
                id=str(uuid.uuid4()),
                issue_type="schema_violation",
                severity="error",
                affected_nodes=[v.node_id],
                suggested_action=v.message,
                auto_fixable=False,
                source_layer="deterministic",
            )
            for v in violations
        ]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""BB5: Curation pipeline orchestrator."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from graphrag_core.interfaces import ApprovalGateway, DetectionLayer, GraphStore, LLMCurationLayer
|
|
6
|
+
from graphrag_core.models import CurationReport, OntologySchema
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CurationPipeline:
    """Orchestrates the curation flow: detect -> (curate) -> (approve).

    Only the detection layer is mandatory; the LLM curation layer refines
    detected issues when configured, and the approval gateway is retained
    for callers that wire one in.
    """

    def __init__(
        self,
        detection: DetectionLayer,
        llm_curation: LLMCurationLayer | None = None,
        approval: ApprovalGateway | None = None,
    ) -> None:
        """Store the collaborating layers."""
        self._detection = detection
        self._llm_curation = llm_curation
        self._approval = approval

    async def run(
        self,
        graph_store: GraphStore,
        schema: OntologySchema,
    ) -> CurationReport:
        """Scan *graph_store* against *schema* and return a CurationReport."""
        found = await self._detection.detect(graph_store, schema)

        # When an LLM layer is configured, let it refine the raw issue list.
        refiner = self._llm_curation
        if refiner is not None:
            found = await refiner.curate(found)

        # Scan-size metadata for the report.
        scanned_nodes = await graph_store.list_nodes()
        scanned_rels = await graph_store.count_relationships()

        return CurationReport(
            issues=found,
            nodes_scanned=len(scanned_nodes),
            relationships_scanned=scanned_rels,
        )
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""BB2: LLM-powered schema-guided entity extraction engine."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
|
|
7
|
+
from graphrag_core.interfaces import LLMClient
|
|
8
|
+
from graphrag_core.models import (
|
|
9
|
+
DocumentChunk,
|
|
10
|
+
ExtractedNode,
|
|
11
|
+
ExtractedRelationship,
|
|
12
|
+
ExtractionResult,
|
|
13
|
+
ImportRun,
|
|
14
|
+
OntologySchema,
|
|
15
|
+
ProvenanceLink,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LLMExtractionEngine:
    """Extracts entities and relationships from text using an LLM, guided by an ontology schema.

    Each chunk is sent to the LLM with a schema-derived system prompt; the
    JSON response is parsed and then filtered against the schema so that
    only allowed node labels, relationship types, and endpoint label
    combinations survive.
    """

    def __init__(self, llm_client: LLMClient) -> None:
        self._llm = llm_client

    async def extract(
        self,
        chunks: list[DocumentChunk],
        schema: OntologySchema,
        import_run: ImportRun,
    ) -> ExtractionResult:
        """Extract nodes/relationships from *chunks* per *schema*.

        NOTE(review): *import_run* is accepted but never referenced in this
        body — presumably part of the ExtractionEngine protocol; confirm.
        """
        all_nodes: list[ExtractedNode] = []
        all_rels: list[ExtractedRelationship] = []
        all_provenance: list[ProvenanceLink] = []

        # The system prompt depends only on the schema, so build it once.
        system_prompt = self._build_system_prompt(schema)

        for chunk in chunks:
            nodes, rels = await self._extract_chunk(chunk, system_prompt)
            # Drop anything the schema does not allow before recording it.
            nodes, rels = self._validate(nodes, rels, schema)

            # One provenance link per surviving node; confidence is a fixed
            # 1.0 here (no per-node confidence is available from the LLM).
            for node in nodes:
                all_provenance.append(
                    ProvenanceLink(chunk_id=chunk.id, node_id=node.id, confidence=1.0)
                )

            all_nodes.extend(nodes)
            all_rels.extend(rels)

        return ExtractionResult(
            nodes=all_nodes,
            relationships=all_rels,
            provenance=all_provenance,
        )

    async def _extract_chunk(
        self, chunk: DocumentChunk, system_prompt: str
    ) -> tuple[list[ExtractedNode], list[ExtractedRelationship]]:
        """Send one chunk to the LLM and parse its JSON reply."""
        # temperature=0.0 for deterministic, extraction-style output.
        response = await self._llm.complete(
            messages=[{"role": "user", "content": chunk.text}],
            system=system_prompt,
            temperature=0.0,
        )
        return self._parse_response(response)

    def _build_system_prompt(self, schema: OntologySchema) -> str:
        """Render the schema into the extraction system prompt."""
        node_descriptions = []
        for nt in schema.node_types:
            props = ", ".join(
                f"{p.name} ({p.type}{', required' if p.required else ''})"
                for p in nt.properties
            )
            node_descriptions.append(f"- {nt.label}: properties=[{props}]")

        rel_descriptions = []
        for rt in schema.relationship_types:
            rel_descriptions.append(
                f"- {rt.type}: {rt.source_types} -> {rt.target_types}"
            )

        return (
            "You are an entity extraction engine. Extract entities and relationships "
            "from the provided text according to this schema.\n\n"
            "ALLOWED NODE TYPES:\n"
            + "\n".join(node_descriptions)
            + "\n\nALLOWED RELATIONSHIP TYPES:\n"
            + "\n".join(rel_descriptions)
            + "\n\nDo not extract entities or relationships not listed above.\n\n"
            "Respond with ONLY a JSON object in this exact format:\n"
            '{"nodes": [{"id": "<unique_id>", "label": "<NodeType>", "properties": {<key>: <value>}}], '
            '"relationships": [{"source_id": "<node_id>", "target_id": "<node_id>", "type": "<RelType>", "properties": {}}]}\n\n'
            "Rules:\n"
            "- Every node id must be unique and descriptive (e.g., 'person-alice', 'company-acme')\n"
            "- Only use node types and relationship types listed above\n"
            "- Include all required properties for each node type\n"
            "- Return empty arrays if no entities are found"
        )

    def _parse_response(
        self, response: str
    ) -> tuple[list[ExtractedNode], list[ExtractedRelationship]]:
        """Parse the LLM's JSON reply into extracted nodes/relationships.

        NOTE(review): assumes *response* is a bare JSON object; a fenced or
        otherwise non-JSON reply raises json.JSONDecodeError here — confirm
        that the caller is expected to propagate that.
        """
        data = json.loads(response)

        nodes = [
            ExtractedNode(
                id=n["id"],
                label=n["label"],
                # Missing keys raise KeyError above; "properties" alone is optional.
                properties=n.get("properties", {}),
            )
            for n in data.get("nodes", [])
        ]

        rels = [
            ExtractedRelationship(
                source_id=r["source_id"],
                target_id=r["target_id"],
                type=r["type"],
                properties=r.get("properties", {}),
            )
            for r in data.get("relationships", [])
        ]

        return nodes, rels

    def _validate(
        self,
        nodes: list[ExtractedNode],
        rels: list[ExtractedRelationship],
        schema: OntologySchema,
    ) -> tuple[list[ExtractedNode], list[ExtractedRelationship]]:
        """Filter extraction output down to what the schema permits.

        Nodes with unknown labels are dropped; relationships are dropped when
        their type is unknown, either endpoint node was dropped, or the
        endpoint labels violate the type's source/target constraints.
        """
        allowed_labels = {nt.label for nt in schema.node_types}
        allowed_rel_types = {rt.type for rt in schema.relationship_types}
        # type -> (allowed source labels, allowed target labels)
        rel_constraints = {
            rt.type: (set(rt.source_types), set(rt.target_types))
            for rt in schema.relationship_types
        }

        valid_nodes = [n for n in nodes if n.label in allowed_labels]
        valid_node_ids = {n.id for n in valid_nodes}
        node_labels = {n.id: n.label for n in valid_nodes}

        valid_rels = []
        for rel in rels:
            if rel.type not in allowed_rel_types:
                continue
            if rel.source_id not in valid_node_ids or rel.target_id not in valid_node_ids:
                continue
            # Safe lookup: rel.type was just checked against allowed_rel_types.
            source_types, target_types = rel_constraints[rel.type]
            if node_labels[rel.source_id] not in source_types:
                continue
            if node_labels[rel.target_id] not in target_types:
                continue
            valid_rels.append(rel)

        return valid_nodes, valid_rels
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""BB3: Graph store implementations."""
|
|
2
|
+
|
|
3
|
+
from graphrag_core.graph.memory import InMemoryGraphStore
|
|
4
|
+
|
|
5
|
+
__all__ = ["InMemoryGraphStore"]
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from graphrag_core.graph.neo4j import Neo4jGraphStore
|
|
9
|
+
__all__.append("Neo4jGraphStore")
|
|
10
|
+
except ImportError:
|
|
11
|
+
pass
|