rnsr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnsr/__init__.py +118 -0
- rnsr/__main__.py +242 -0
- rnsr/agent/__init__.py +218 -0
- rnsr/agent/cross_doc_navigator.py +767 -0
- rnsr/agent/graph.py +1557 -0
- rnsr/agent/llm_cache.py +575 -0
- rnsr/agent/navigator_api.py +497 -0
- rnsr/agent/provenance.py +772 -0
- rnsr/agent/query_clarifier.py +617 -0
- rnsr/agent/reasoning_memory.py +736 -0
- rnsr/agent/repl_env.py +709 -0
- rnsr/agent/rlm_navigator.py +2108 -0
- rnsr/agent/self_reflection.py +602 -0
- rnsr/agent/variable_store.py +308 -0
- rnsr/benchmarks/__init__.py +118 -0
- rnsr/benchmarks/comprehensive_benchmark.py +733 -0
- rnsr/benchmarks/evaluation_suite.py +1210 -0
- rnsr/benchmarks/finance_bench.py +147 -0
- rnsr/benchmarks/pdf_merger.py +178 -0
- rnsr/benchmarks/performance.py +321 -0
- rnsr/benchmarks/quality.py +321 -0
- rnsr/benchmarks/runner.py +298 -0
- rnsr/benchmarks/standard_benchmarks.py +995 -0
- rnsr/client.py +560 -0
- rnsr/document_store.py +394 -0
- rnsr/exceptions.py +74 -0
- rnsr/extraction/__init__.py +172 -0
- rnsr/extraction/candidate_extractor.py +357 -0
- rnsr/extraction/entity_extractor.py +581 -0
- rnsr/extraction/entity_linker.py +825 -0
- rnsr/extraction/grounded_extractor.py +722 -0
- rnsr/extraction/learned_types.py +599 -0
- rnsr/extraction/models.py +232 -0
- rnsr/extraction/relationship_extractor.py +600 -0
- rnsr/extraction/relationship_patterns.py +511 -0
- rnsr/extraction/relationship_validator.py +392 -0
- rnsr/extraction/rlm_extractor.py +589 -0
- rnsr/extraction/rlm_unified_extractor.py +990 -0
- rnsr/extraction/tot_validator.py +610 -0
- rnsr/extraction/unified_extractor.py +342 -0
- rnsr/indexing/__init__.py +60 -0
- rnsr/indexing/knowledge_graph.py +1128 -0
- rnsr/indexing/kv_store.py +313 -0
- rnsr/indexing/persistence.py +323 -0
- rnsr/indexing/semantic_retriever.py +237 -0
- rnsr/indexing/semantic_search.py +320 -0
- rnsr/indexing/skeleton_index.py +395 -0
- rnsr/ingestion/__init__.py +161 -0
- rnsr/ingestion/chart_parser.py +569 -0
- rnsr/ingestion/document_boundary.py +662 -0
- rnsr/ingestion/font_histogram.py +334 -0
- rnsr/ingestion/header_classifier.py +595 -0
- rnsr/ingestion/hierarchical_cluster.py +515 -0
- rnsr/ingestion/layout_detector.py +356 -0
- rnsr/ingestion/layout_model.py +379 -0
- rnsr/ingestion/ocr_fallback.py +177 -0
- rnsr/ingestion/pipeline.py +936 -0
- rnsr/ingestion/semantic_fallback.py +417 -0
- rnsr/ingestion/table_parser.py +799 -0
- rnsr/ingestion/text_builder.py +460 -0
- rnsr/ingestion/tree_builder.py +402 -0
- rnsr/ingestion/vision_retrieval.py +965 -0
- rnsr/ingestion/xy_cut.py +555 -0
- rnsr/llm.py +733 -0
- rnsr/models.py +167 -0
- rnsr/py.typed +2 -0
- rnsr-0.1.0.dist-info/METADATA +592 -0
- rnsr-0.1.0.dist-info/RECORD +72 -0
- rnsr-0.1.0.dist-info/WHEEL +5 -0
- rnsr-0.1.0.dist-info/entry_points.txt +2 -0
- rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
- rnsr-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RNSR Knowledge Graph - SQLite-Backed Graph Storage
|
|
3
|
+
|
|
4
|
+
Stores entities, relationships, and entity links for ontological
|
|
5
|
+
document understanding and cross-document queries.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
kg = KnowledgeGraph("./data/knowledge_graph.db")
|
|
9
|
+
kg.add_entity(entity)
|
|
10
|
+
kg.add_relationship(relationship)
|
|
11
|
+
|
|
12
|
+
# Query entities
|
|
13
|
+
entities = kg.find_entities_by_name("John Smith")
|
|
14
|
+
|
|
15
|
+
# Get entity relationships
|
|
16
|
+
relationships = kg.get_entity_relationships("ent_abc123")
|
|
17
|
+
|
|
18
|
+
# Cross-document entity linking
|
|
19
|
+
kg.link_entities("ent_doc1_john", "ent_doc2_john", confidence=0.95)
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import json
|
|
25
|
+
import sqlite3
|
|
26
|
+
from contextlib import contextmanager
|
|
27
|
+
from datetime import datetime
|
|
28
|
+
from pathlib import Path
|
|
29
|
+
from typing import Any, Iterator
|
|
30
|
+
|
|
31
|
+
import structlog
|
|
32
|
+
|
|
33
|
+
from rnsr.extraction.models import (
|
|
34
|
+
Entity,
|
|
35
|
+
EntityLink,
|
|
36
|
+
EntityType,
|
|
37
|
+
Mention,
|
|
38
|
+
Relationship,
|
|
39
|
+
RelationType,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Module-level structlog logger, named after this module; used by the classes below
# to emit structured (event name + key/value) log records.
logger = structlog.get_logger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class KnowledgeGraph:
|
|
46
|
+
"""
|
|
47
|
+
SQLite-backed knowledge graph for entity and relationship storage.
|
|
48
|
+
|
|
49
|
+
Supports:
|
|
50
|
+
- Entity storage with mentions and aliases
|
|
51
|
+
- Relationship storage between entities and nodes
|
|
52
|
+
- Cross-document entity linking
|
|
53
|
+
- Efficient querying by name, type, document, and node
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
def __init__(self, db_path: Path | str):
|
|
57
|
+
"""
|
|
58
|
+
Initialize the knowledge graph.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
db_path: Path to the SQLite database file.
|
|
62
|
+
Will be created if it doesn't exist.
|
|
63
|
+
"""
|
|
64
|
+
self.db_path = Path(db_path)
|
|
65
|
+
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
66
|
+
|
|
67
|
+
self._init_db()
|
|
68
|
+
|
|
69
|
+
logger.info("knowledge_graph_initialized", db_path=str(self.db_path))
|
|
70
|
+
|
|
71
|
+
def _init_db(self) -> None:
|
|
72
|
+
"""Create the database schema if it doesn't exist."""
|
|
73
|
+
with self._connect() as conn:
|
|
74
|
+
# Entities table
|
|
75
|
+
conn.execute("""
|
|
76
|
+
CREATE TABLE IF NOT EXISTS entities (
|
|
77
|
+
id TEXT PRIMARY KEY,
|
|
78
|
+
type TEXT NOT NULL,
|
|
79
|
+
canonical_name TEXT NOT NULL,
|
|
80
|
+
aliases TEXT,
|
|
81
|
+
metadata TEXT,
|
|
82
|
+
source_doc_id TEXT,
|
|
83
|
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
84
|
+
)
|
|
85
|
+
""")
|
|
86
|
+
|
|
87
|
+
# Mentions table (entity-to-node links)
|
|
88
|
+
conn.execute("""
|
|
89
|
+
CREATE TABLE IF NOT EXISTS mentions (
|
|
90
|
+
id TEXT PRIMARY KEY,
|
|
91
|
+
entity_id TEXT NOT NULL REFERENCES entities(id) ON DELETE CASCADE,
|
|
92
|
+
node_id TEXT NOT NULL,
|
|
93
|
+
doc_id TEXT NOT NULL,
|
|
94
|
+
context TEXT,
|
|
95
|
+
span_start INTEGER,
|
|
96
|
+
span_end INTEGER,
|
|
97
|
+
page_num INTEGER,
|
|
98
|
+
confidence REAL DEFAULT 1.0
|
|
99
|
+
)
|
|
100
|
+
""")
|
|
101
|
+
|
|
102
|
+
# Relationships table
|
|
103
|
+
conn.execute("""
|
|
104
|
+
CREATE TABLE IF NOT EXISTS relationships (
|
|
105
|
+
id TEXT PRIMARY KEY,
|
|
106
|
+
type TEXT NOT NULL,
|
|
107
|
+
source_id TEXT NOT NULL,
|
|
108
|
+
target_id TEXT NOT NULL,
|
|
109
|
+
source_type TEXT DEFAULT 'entity',
|
|
110
|
+
target_type TEXT DEFAULT 'entity',
|
|
111
|
+
doc_id TEXT,
|
|
112
|
+
confidence REAL DEFAULT 1.0,
|
|
113
|
+
evidence TEXT,
|
|
114
|
+
metadata TEXT,
|
|
115
|
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
116
|
+
)
|
|
117
|
+
""")
|
|
118
|
+
|
|
119
|
+
# Entity links table (cross-document entity resolution)
|
|
120
|
+
conn.execute("""
|
|
121
|
+
CREATE TABLE IF NOT EXISTS entity_links (
|
|
122
|
+
entity_id_1 TEXT NOT NULL,
|
|
123
|
+
entity_id_2 TEXT NOT NULL,
|
|
124
|
+
confidence REAL DEFAULT 1.0,
|
|
125
|
+
link_method TEXT DEFAULT 'exact',
|
|
126
|
+
evidence TEXT,
|
|
127
|
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
128
|
+
PRIMARY KEY (entity_id_1, entity_id_2)
|
|
129
|
+
)
|
|
130
|
+
""")
|
|
131
|
+
|
|
132
|
+
# Indexes for efficient querying
|
|
133
|
+
conn.execute("""
|
|
134
|
+
CREATE INDEX IF NOT EXISTS idx_entities_type
|
|
135
|
+
ON entities(type)
|
|
136
|
+
""")
|
|
137
|
+
conn.execute("""
|
|
138
|
+
CREATE INDEX IF NOT EXISTS idx_entities_name
|
|
139
|
+
ON entities(canonical_name)
|
|
140
|
+
""")
|
|
141
|
+
conn.execute("""
|
|
142
|
+
CREATE INDEX IF NOT EXISTS idx_entities_doc
|
|
143
|
+
ON entities(source_doc_id)
|
|
144
|
+
""")
|
|
145
|
+
conn.execute("""
|
|
146
|
+
CREATE INDEX IF NOT EXISTS idx_mentions_entity
|
|
147
|
+
ON mentions(entity_id)
|
|
148
|
+
""")
|
|
149
|
+
conn.execute("""
|
|
150
|
+
CREATE INDEX IF NOT EXISTS idx_mentions_node
|
|
151
|
+
ON mentions(node_id)
|
|
152
|
+
""")
|
|
153
|
+
conn.execute("""
|
|
154
|
+
CREATE INDEX IF NOT EXISTS idx_mentions_doc
|
|
155
|
+
ON mentions(doc_id)
|
|
156
|
+
""")
|
|
157
|
+
conn.execute("""
|
|
158
|
+
CREATE INDEX IF NOT EXISTS idx_relationships_source
|
|
159
|
+
ON relationships(source_id)
|
|
160
|
+
""")
|
|
161
|
+
conn.execute("""
|
|
162
|
+
CREATE INDEX IF NOT EXISTS idx_relationships_target
|
|
163
|
+
ON relationships(target_id)
|
|
164
|
+
""")
|
|
165
|
+
conn.execute("""
|
|
166
|
+
CREATE INDEX IF NOT EXISTS idx_relationships_type
|
|
167
|
+
ON relationships(type)
|
|
168
|
+
""")
|
|
169
|
+
|
|
170
|
+
conn.commit()
|
|
171
|
+
|
|
172
|
+
@contextmanager
|
|
173
|
+
def _connect(self) -> Iterator[sqlite3.Connection]:
|
|
174
|
+
"""Context manager for database connections."""
|
|
175
|
+
conn = sqlite3.connect(self.db_path)
|
|
176
|
+
conn.row_factory = sqlite3.Row
|
|
177
|
+
# Enable foreign keys
|
|
178
|
+
conn.execute("PRAGMA foreign_keys = ON")
|
|
179
|
+
try:
|
|
180
|
+
yield conn
|
|
181
|
+
finally:
|
|
182
|
+
conn.close()
|
|
183
|
+
|
|
184
|
+
# =========================================================================
|
|
185
|
+
# Entity Operations
|
|
186
|
+
# =========================================================================
|
|
187
|
+
|
|
188
|
+
def add_entity(self, entity: Entity) -> str:
|
|
189
|
+
"""
|
|
190
|
+
Add an entity to the knowledge graph.
|
|
191
|
+
|
|
192
|
+
Args:
|
|
193
|
+
entity: Entity to add.
|
|
194
|
+
|
|
195
|
+
Returns:
|
|
196
|
+
Entity ID.
|
|
197
|
+
"""
|
|
198
|
+
with self._connect() as conn:
|
|
199
|
+
# Insert entity
|
|
200
|
+
conn.execute(
|
|
201
|
+
"""
|
|
202
|
+
INSERT INTO entities (id, type, canonical_name, aliases, metadata, source_doc_id, created_at)
|
|
203
|
+
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
204
|
+
ON CONFLICT(id) DO UPDATE SET
|
|
205
|
+
canonical_name = excluded.canonical_name,
|
|
206
|
+
aliases = excluded.aliases,
|
|
207
|
+
metadata = excluded.metadata
|
|
208
|
+
""",
|
|
209
|
+
(
|
|
210
|
+
entity.id,
|
|
211
|
+
entity.type.value,
|
|
212
|
+
entity.canonical_name,
|
|
213
|
+
json.dumps(entity.aliases),
|
|
214
|
+
json.dumps(entity.metadata),
|
|
215
|
+
entity.source_doc_id,
|
|
216
|
+
entity.created_at.isoformat(),
|
|
217
|
+
),
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
# Insert mentions
|
|
221
|
+
for mention in entity.mentions:
|
|
222
|
+
conn.execute(
|
|
223
|
+
"""
|
|
224
|
+
INSERT INTO mentions (id, entity_id, node_id, doc_id, context, span_start, span_end, page_num, confidence)
|
|
225
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
226
|
+
ON CONFLICT(id) DO UPDATE SET
|
|
227
|
+
context = excluded.context,
|
|
228
|
+
confidence = excluded.confidence
|
|
229
|
+
""",
|
|
230
|
+
(
|
|
231
|
+
mention.id,
|
|
232
|
+
entity.id,
|
|
233
|
+
mention.node_id,
|
|
234
|
+
mention.doc_id,
|
|
235
|
+
mention.context,
|
|
236
|
+
mention.span_start,
|
|
237
|
+
mention.span_end,
|
|
238
|
+
mention.page_num,
|
|
239
|
+
mention.confidence,
|
|
240
|
+
),
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
conn.commit()
|
|
244
|
+
|
|
245
|
+
logger.debug(
|
|
246
|
+
"entity_added",
|
|
247
|
+
entity_id=entity.id,
|
|
248
|
+
type=entity.type.value,
|
|
249
|
+
name=entity.canonical_name,
|
|
250
|
+
mentions=len(entity.mentions),
|
|
251
|
+
)
|
|
252
|
+
|
|
253
|
+
return entity.id
|
|
254
|
+
|
|
255
|
+
def get_entity(self, entity_id: str) -> Entity | None:
|
|
256
|
+
"""
|
|
257
|
+
Get an entity by ID.
|
|
258
|
+
|
|
259
|
+
Args:
|
|
260
|
+
entity_id: Entity ID.
|
|
261
|
+
|
|
262
|
+
Returns:
|
|
263
|
+
Entity or None if not found.
|
|
264
|
+
"""
|
|
265
|
+
with self._connect() as conn:
|
|
266
|
+
cursor = conn.execute(
|
|
267
|
+
"SELECT * FROM entities WHERE id = ?",
|
|
268
|
+
(entity_id,),
|
|
269
|
+
)
|
|
270
|
+
row = cursor.fetchone()
|
|
271
|
+
|
|
272
|
+
if row is None:
|
|
273
|
+
return None
|
|
274
|
+
|
|
275
|
+
# Get mentions
|
|
276
|
+
mentions_cursor = conn.execute(
|
|
277
|
+
"SELECT * FROM mentions WHERE entity_id = ?",
|
|
278
|
+
(entity_id,),
|
|
279
|
+
)
|
|
280
|
+
mentions = [self._row_to_mention(m) for m in mentions_cursor]
|
|
281
|
+
|
|
282
|
+
return self._row_to_entity(row, mentions)
|
|
283
|
+
|
|
284
|
+
def find_entities_by_name(
|
|
285
|
+
self,
|
|
286
|
+
name: str,
|
|
287
|
+
entity_type: EntityType | None = None,
|
|
288
|
+
fuzzy: bool = False,
|
|
289
|
+
) -> list[Entity]:
|
|
290
|
+
"""
|
|
291
|
+
Find entities by name.
|
|
292
|
+
|
|
293
|
+
Args:
|
|
294
|
+
name: Entity name to search for.
|
|
295
|
+
entity_type: Optional type filter.
|
|
296
|
+
fuzzy: If True, use LIKE matching.
|
|
297
|
+
|
|
298
|
+
Returns:
|
|
299
|
+
List of matching entities.
|
|
300
|
+
"""
|
|
301
|
+
with self._connect() as conn:
|
|
302
|
+
if fuzzy:
|
|
303
|
+
name_pattern = f"%{name}%"
|
|
304
|
+
if entity_type:
|
|
305
|
+
cursor = conn.execute(
|
|
306
|
+
"""
|
|
307
|
+
SELECT * FROM entities
|
|
308
|
+
WHERE (canonical_name LIKE ? OR aliases LIKE ?)
|
|
309
|
+
AND type = ?
|
|
310
|
+
""",
|
|
311
|
+
(name_pattern, name_pattern, entity_type.value),
|
|
312
|
+
)
|
|
313
|
+
else:
|
|
314
|
+
cursor = conn.execute(
|
|
315
|
+
"""
|
|
316
|
+
SELECT * FROM entities
|
|
317
|
+
WHERE canonical_name LIKE ? OR aliases LIKE ?
|
|
318
|
+
""",
|
|
319
|
+
(name_pattern, name_pattern),
|
|
320
|
+
)
|
|
321
|
+
else:
|
|
322
|
+
if entity_type:
|
|
323
|
+
cursor = conn.execute(
|
|
324
|
+
"""
|
|
325
|
+
SELECT * FROM entities
|
|
326
|
+
WHERE canonical_name = ? AND type = ?
|
|
327
|
+
""",
|
|
328
|
+
(name, entity_type.value),
|
|
329
|
+
)
|
|
330
|
+
else:
|
|
331
|
+
cursor = conn.execute(
|
|
332
|
+
"SELECT * FROM entities WHERE canonical_name = ?",
|
|
333
|
+
(name,),
|
|
334
|
+
)
|
|
335
|
+
|
|
336
|
+
entities = []
|
|
337
|
+
for row in cursor:
|
|
338
|
+
mentions_cursor = conn.execute(
|
|
339
|
+
"SELECT * FROM mentions WHERE entity_id = ?",
|
|
340
|
+
(row["id"],),
|
|
341
|
+
)
|
|
342
|
+
mentions = [self._row_to_mention(m) for m in mentions_cursor]
|
|
343
|
+
entities.append(self._row_to_entity(row, mentions))
|
|
344
|
+
|
|
345
|
+
return entities
|
|
346
|
+
|
|
347
|
+
def find_entities_by_type(
|
|
348
|
+
self,
|
|
349
|
+
entity_type: EntityType,
|
|
350
|
+
doc_id: str | None = None,
|
|
351
|
+
) -> list[Entity]:
|
|
352
|
+
"""
|
|
353
|
+
Find entities by type.
|
|
354
|
+
|
|
355
|
+
Args:
|
|
356
|
+
entity_type: Entity type to filter by.
|
|
357
|
+
doc_id: Optional document ID filter.
|
|
358
|
+
|
|
359
|
+
Returns:
|
|
360
|
+
List of matching entities.
|
|
361
|
+
"""
|
|
362
|
+
with self._connect() as conn:
|
|
363
|
+
if doc_id:
|
|
364
|
+
cursor = conn.execute(
|
|
365
|
+
"""
|
|
366
|
+
SELECT * FROM entities
|
|
367
|
+
WHERE type = ? AND source_doc_id = ?
|
|
368
|
+
""",
|
|
369
|
+
(entity_type.value, doc_id),
|
|
370
|
+
)
|
|
371
|
+
else:
|
|
372
|
+
cursor = conn.execute(
|
|
373
|
+
"SELECT * FROM entities WHERE type = ?",
|
|
374
|
+
(entity_type.value,),
|
|
375
|
+
)
|
|
376
|
+
|
|
377
|
+
entities = []
|
|
378
|
+
for row in cursor:
|
|
379
|
+
mentions_cursor = conn.execute(
|
|
380
|
+
"SELECT * FROM mentions WHERE entity_id = ?",
|
|
381
|
+
(row["id"],),
|
|
382
|
+
)
|
|
383
|
+
mentions = [self._row_to_mention(m) for m in mentions_cursor]
|
|
384
|
+
entities.append(self._row_to_entity(row, mentions))
|
|
385
|
+
|
|
386
|
+
return entities
|
|
387
|
+
|
|
388
|
+
def find_entities_in_node(self, node_id: str) -> list[Entity]:
|
|
389
|
+
"""
|
|
390
|
+
Find all entities mentioned in a specific node.
|
|
391
|
+
|
|
392
|
+
Args:
|
|
393
|
+
node_id: Node ID.
|
|
394
|
+
|
|
395
|
+
Returns:
|
|
396
|
+
List of entities with mentions in the node.
|
|
397
|
+
"""
|
|
398
|
+
with self._connect() as conn:
|
|
399
|
+
cursor = conn.execute(
|
|
400
|
+
"""
|
|
401
|
+
SELECT DISTINCT e.* FROM entities e
|
|
402
|
+
JOIN mentions m ON e.id = m.entity_id
|
|
403
|
+
WHERE m.node_id = ?
|
|
404
|
+
""",
|
|
405
|
+
(node_id,),
|
|
406
|
+
)
|
|
407
|
+
|
|
408
|
+
entities = []
|
|
409
|
+
for row in cursor:
|
|
410
|
+
mentions_cursor = conn.execute(
|
|
411
|
+
"SELECT * FROM mentions WHERE entity_id = ?",
|
|
412
|
+
(row["id"],),
|
|
413
|
+
)
|
|
414
|
+
mentions = [self._row_to_mention(m) for m in mentions_cursor]
|
|
415
|
+
entities.append(self._row_to_entity(row, mentions))
|
|
416
|
+
|
|
417
|
+
return entities
|
|
418
|
+
|
|
419
|
+
def find_entities_in_document(self, doc_id: str) -> list[Entity]:
|
|
420
|
+
"""
|
|
421
|
+
Find all entities in a document.
|
|
422
|
+
|
|
423
|
+
Args:
|
|
424
|
+
doc_id: Document ID.
|
|
425
|
+
|
|
426
|
+
Returns:
|
|
427
|
+
List of entities with mentions in the document.
|
|
428
|
+
"""
|
|
429
|
+
with self._connect() as conn:
|
|
430
|
+
cursor = conn.execute(
|
|
431
|
+
"""
|
|
432
|
+
SELECT DISTINCT e.* FROM entities e
|
|
433
|
+
JOIN mentions m ON e.id = m.entity_id
|
|
434
|
+
WHERE m.doc_id = ?
|
|
435
|
+
""",
|
|
436
|
+
(doc_id,),
|
|
437
|
+
)
|
|
438
|
+
|
|
439
|
+
entities = []
|
|
440
|
+
for row in cursor:
|
|
441
|
+
mentions_cursor = conn.execute(
|
|
442
|
+
"SELECT * FROM mentions WHERE entity_id = ? AND doc_id = ?",
|
|
443
|
+
(row["id"], doc_id),
|
|
444
|
+
)
|
|
445
|
+
mentions = [self._row_to_mention(m) for m in mentions_cursor]
|
|
446
|
+
entities.append(self._row_to_entity(row, mentions))
|
|
447
|
+
|
|
448
|
+
return entities
|
|
449
|
+
|
|
450
|
+
def delete_entity(self, entity_id: str) -> bool:
|
|
451
|
+
"""
|
|
452
|
+
Delete an entity and its mentions.
|
|
453
|
+
|
|
454
|
+
Args:
|
|
455
|
+
entity_id: Entity ID.
|
|
456
|
+
|
|
457
|
+
Returns:
|
|
458
|
+
True if deleted.
|
|
459
|
+
"""
|
|
460
|
+
with self._connect() as conn:
|
|
461
|
+
cursor = conn.execute(
|
|
462
|
+
"DELETE FROM entities WHERE id = ?",
|
|
463
|
+
(entity_id,),
|
|
464
|
+
)
|
|
465
|
+
conn.commit()
|
|
466
|
+
return cursor.rowcount > 0
|
|
467
|
+
|
|
468
|
+
# =========================================================================
|
|
469
|
+
# Relationship Operations
|
|
470
|
+
# =========================================================================
|
|
471
|
+
|
|
472
|
+
def add_relationship(self, relationship: Relationship) -> str:
|
|
473
|
+
"""
|
|
474
|
+
Add a relationship to the knowledge graph.
|
|
475
|
+
|
|
476
|
+
Args:
|
|
477
|
+
relationship: Relationship to add.
|
|
478
|
+
|
|
479
|
+
Returns:
|
|
480
|
+
Relationship ID.
|
|
481
|
+
"""
|
|
482
|
+
with self._connect() as conn:
|
|
483
|
+
conn.execute(
|
|
484
|
+
"""
|
|
485
|
+
INSERT INTO relationships (id, type, source_id, target_id, source_type, target_type, doc_id, confidence, evidence, metadata, created_at)
|
|
486
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
487
|
+
ON CONFLICT(id) DO UPDATE SET
|
|
488
|
+
confidence = excluded.confidence,
|
|
489
|
+
evidence = excluded.evidence,
|
|
490
|
+
metadata = excluded.metadata
|
|
491
|
+
""",
|
|
492
|
+
(
|
|
493
|
+
relationship.id,
|
|
494
|
+
relationship.type.value,
|
|
495
|
+
relationship.source_id,
|
|
496
|
+
relationship.target_id,
|
|
497
|
+
relationship.source_type,
|
|
498
|
+
relationship.target_type,
|
|
499
|
+
relationship.doc_id,
|
|
500
|
+
relationship.confidence,
|
|
501
|
+
relationship.evidence,
|
|
502
|
+
json.dumps(relationship.metadata),
|
|
503
|
+
relationship.created_at.isoformat(),
|
|
504
|
+
),
|
|
505
|
+
)
|
|
506
|
+
conn.commit()
|
|
507
|
+
|
|
508
|
+
logger.debug(
|
|
509
|
+
"relationship_added",
|
|
510
|
+
relationship_id=relationship.id,
|
|
511
|
+
type=relationship.type.value,
|
|
512
|
+
source=relationship.source_id,
|
|
513
|
+
target=relationship.target_id,
|
|
514
|
+
)
|
|
515
|
+
|
|
516
|
+
return relationship.id
|
|
517
|
+
|
|
518
|
+
def get_relationship(self, relationship_id: str) -> Relationship | None:
|
|
519
|
+
"""
|
|
520
|
+
Get a relationship by ID.
|
|
521
|
+
|
|
522
|
+
Args:
|
|
523
|
+
relationship_id: Relationship ID.
|
|
524
|
+
|
|
525
|
+
Returns:
|
|
526
|
+
Relationship or None if not found.
|
|
527
|
+
"""
|
|
528
|
+
with self._connect() as conn:
|
|
529
|
+
cursor = conn.execute(
|
|
530
|
+
"SELECT * FROM relationships WHERE id = ?",
|
|
531
|
+
(relationship_id,),
|
|
532
|
+
)
|
|
533
|
+
row = cursor.fetchone()
|
|
534
|
+
|
|
535
|
+
if row is None:
|
|
536
|
+
return None
|
|
537
|
+
|
|
538
|
+
return self._row_to_relationship(row)
|
|
539
|
+
|
|
540
|
+
def get_entity_relationships(
|
|
541
|
+
self,
|
|
542
|
+
entity_id: str,
|
|
543
|
+
relationship_type: RelationType | None = None,
|
|
544
|
+
direction: str = "both",
|
|
545
|
+
) -> list[Relationship]:
|
|
546
|
+
"""
|
|
547
|
+
Get relationships involving an entity.
|
|
548
|
+
|
|
549
|
+
Args:
|
|
550
|
+
entity_id: Entity ID.
|
|
551
|
+
relationship_type: Optional type filter.
|
|
552
|
+
direction: "outgoing", "incoming", or "both".
|
|
553
|
+
|
|
554
|
+
Returns:
|
|
555
|
+
List of relationships.
|
|
556
|
+
"""
|
|
557
|
+
with self._connect() as conn:
|
|
558
|
+
conditions = []
|
|
559
|
+
params = []
|
|
560
|
+
|
|
561
|
+
if direction in ("outgoing", "both"):
|
|
562
|
+
conditions.append("source_id = ?")
|
|
563
|
+
params.append(entity_id)
|
|
564
|
+
if direction in ("incoming", "both"):
|
|
565
|
+
conditions.append("target_id = ?")
|
|
566
|
+
params.append(entity_id)
|
|
567
|
+
|
|
568
|
+
where_clause = " OR ".join(conditions)
|
|
569
|
+
|
|
570
|
+
if relationship_type:
|
|
571
|
+
query = f"SELECT * FROM relationships WHERE ({where_clause}) AND type = ?"
|
|
572
|
+
params.append(relationship_type.value)
|
|
573
|
+
else:
|
|
574
|
+
query = f"SELECT * FROM relationships WHERE {where_clause}"
|
|
575
|
+
|
|
576
|
+
cursor = conn.execute(query, params)
|
|
577
|
+
relationships = [self._row_to_relationship(row) for row in cursor]
|
|
578
|
+
|
|
579
|
+
return relationships
|
|
580
|
+
|
|
581
|
+
def get_node_relationships(
|
|
582
|
+
self,
|
|
583
|
+
node_id: str,
|
|
584
|
+
relationship_type: RelationType | None = None,
|
|
585
|
+
) -> list[Relationship]:
|
|
586
|
+
"""
|
|
587
|
+
Get relationships involving a node (as source or target).
|
|
588
|
+
|
|
589
|
+
Args:
|
|
590
|
+
node_id: Node ID.
|
|
591
|
+
relationship_type: Optional type filter.
|
|
592
|
+
|
|
593
|
+
Returns:
|
|
594
|
+
List of relationships.
|
|
595
|
+
"""
|
|
596
|
+
with self._connect() as conn:
|
|
597
|
+
if relationship_type:
|
|
598
|
+
cursor = conn.execute(
|
|
599
|
+
"""
|
|
600
|
+
SELECT * FROM relationships
|
|
601
|
+
WHERE (source_id = ? OR target_id = ?) AND type = ?
|
|
602
|
+
""",
|
|
603
|
+
(node_id, node_id, relationship_type.value),
|
|
604
|
+
)
|
|
605
|
+
else:
|
|
606
|
+
cursor = conn.execute(
|
|
607
|
+
"""
|
|
608
|
+
SELECT * FROM relationships
|
|
609
|
+
WHERE source_id = ? OR target_id = ?
|
|
610
|
+
""",
|
|
611
|
+
(node_id, node_id),
|
|
612
|
+
)
|
|
613
|
+
relationships = [self._row_to_relationship(row) for row in cursor]
|
|
614
|
+
|
|
615
|
+
return relationships
|
|
616
|
+
|
|
617
|
+
def delete_relationship(self, relationship_id: str) -> bool:
|
|
618
|
+
"""
|
|
619
|
+
Delete a relationship.
|
|
620
|
+
|
|
621
|
+
Args:
|
|
622
|
+
relationship_id: Relationship ID.
|
|
623
|
+
|
|
624
|
+
Returns:
|
|
625
|
+
True if deleted.
|
|
626
|
+
"""
|
|
627
|
+
with self._connect() as conn:
|
|
628
|
+
cursor = conn.execute(
|
|
629
|
+
"DELETE FROM relationships WHERE id = ?",
|
|
630
|
+
(relationship_id,),
|
|
631
|
+
)
|
|
632
|
+
conn.commit()
|
|
633
|
+
return cursor.rowcount > 0
|
|
634
|
+
|
|
635
|
+
# =========================================================================
|
|
636
|
+
# Entity Linking Operations
|
|
637
|
+
# =========================================================================
|
|
638
|
+
|
|
639
|
+
def link_entities(
|
|
640
|
+
self,
|
|
641
|
+
entity_id_1: str,
|
|
642
|
+
entity_id_2: str,
|
|
643
|
+
confidence: float = 1.0,
|
|
644
|
+
link_method: str = "exact",
|
|
645
|
+
evidence: str = "",
|
|
646
|
+
) -> None:
|
|
647
|
+
"""
|
|
648
|
+
Create a link between two entities (same real-world entity).
|
|
649
|
+
|
|
650
|
+
Args:
|
|
651
|
+
entity_id_1: First entity ID.
|
|
652
|
+
entity_id_2: Second entity ID.
|
|
653
|
+
confidence: Link confidence (0.0-1.0).
|
|
654
|
+
link_method: How the link was established.
|
|
655
|
+
evidence: Justification for the link.
|
|
656
|
+
"""
|
|
657
|
+
# Ensure consistent ordering
|
|
658
|
+
if entity_id_1 > entity_id_2:
|
|
659
|
+
entity_id_1, entity_id_2 = entity_id_2, entity_id_1
|
|
660
|
+
|
|
661
|
+
with self._connect() as conn:
|
|
662
|
+
conn.execute(
|
|
663
|
+
"""
|
|
664
|
+
INSERT INTO entity_links (entity_id_1, entity_id_2, confidence, link_method, evidence)
|
|
665
|
+
VALUES (?, ?, ?, ?, ?)
|
|
666
|
+
ON CONFLICT(entity_id_1, entity_id_2) DO UPDATE SET
|
|
667
|
+
confidence = MAX(excluded.confidence, entity_links.confidence),
|
|
668
|
+
link_method = excluded.link_method,
|
|
669
|
+
evidence = excluded.evidence
|
|
670
|
+
""",
|
|
671
|
+
(entity_id_1, entity_id_2, confidence, link_method, evidence),
|
|
672
|
+
)
|
|
673
|
+
conn.commit()
|
|
674
|
+
|
|
675
|
+
logger.debug(
|
|
676
|
+
"entities_linked",
|
|
677
|
+
entity_1=entity_id_1,
|
|
678
|
+
entity_2=entity_id_2,
|
|
679
|
+
confidence=confidence,
|
|
680
|
+
method=link_method,
|
|
681
|
+
)
|
|
682
|
+
|
|
683
|
+
def get_linked_entities(
|
|
684
|
+
self,
|
|
685
|
+
entity_id: str,
|
|
686
|
+
min_confidence: float = 0.0,
|
|
687
|
+
) -> list[EntityLink]:
|
|
688
|
+
"""
|
|
689
|
+
Get all entities linked to a given entity.
|
|
690
|
+
|
|
691
|
+
Args:
|
|
692
|
+
entity_id: Entity ID.
|
|
693
|
+
min_confidence: Minimum link confidence.
|
|
694
|
+
|
|
695
|
+
Returns:
|
|
696
|
+
List of EntityLink objects.
|
|
697
|
+
"""
|
|
698
|
+
with self._connect() as conn:
|
|
699
|
+
cursor = conn.execute(
|
|
700
|
+
"""
|
|
701
|
+
SELECT * FROM entity_links
|
|
702
|
+
WHERE (entity_id_1 = ? OR entity_id_2 = ?) AND confidence >= ?
|
|
703
|
+
""",
|
|
704
|
+
(entity_id, entity_id, min_confidence),
|
|
705
|
+
)
|
|
706
|
+
|
|
707
|
+
links = []
|
|
708
|
+
for row in cursor:
|
|
709
|
+
links.append(EntityLink(
|
|
710
|
+
entity_id_1=row["entity_id_1"],
|
|
711
|
+
entity_id_2=row["entity_id_2"],
|
|
712
|
+
confidence=row["confidence"],
|
|
713
|
+
link_method=row["link_method"],
|
|
714
|
+
evidence=row["evidence"] or "",
|
|
715
|
+
created_at=datetime.fromisoformat(row["created_at"]) if row["created_at"] else datetime.utcnow(),
|
|
716
|
+
))
|
|
717
|
+
|
|
718
|
+
return links
|
|
719
|
+
|
|
720
|
+
def find_entity_across_documents(
|
|
721
|
+
self,
|
|
722
|
+
entity_id: str,
|
|
723
|
+
min_confidence: float = 0.5,
|
|
724
|
+
) -> list[Entity]:
|
|
725
|
+
"""
|
|
726
|
+
Find the same entity across multiple documents.
|
|
727
|
+
|
|
728
|
+
Args:
|
|
729
|
+
entity_id: Starting entity ID.
|
|
730
|
+
min_confidence: Minimum link confidence.
|
|
731
|
+
|
|
732
|
+
Returns:
|
|
733
|
+
List of linked entities (including the original).
|
|
734
|
+
"""
|
|
735
|
+
# Get the original entity
|
|
736
|
+
original = self.get_entity(entity_id)
|
|
737
|
+
if not original:
|
|
738
|
+
return []
|
|
739
|
+
|
|
740
|
+
result = [original]
|
|
741
|
+
|
|
742
|
+
# Get linked entities
|
|
743
|
+
links = self.get_linked_entities(entity_id, min_confidence)
|
|
744
|
+
|
|
745
|
+
for link in links:
|
|
746
|
+
linked_id = link.entity_id_2 if link.entity_id_1 == entity_id else link.entity_id_1
|
|
747
|
+
linked_entity = self.get_entity(linked_id)
|
|
748
|
+
if linked_entity:
|
|
749
|
+
result.append(linked_entity)
|
|
750
|
+
|
|
751
|
+
return result
|
|
752
|
+
|
|
753
|
+
# =========================================================================
|
|
754
|
+
# Query Operations
|
|
755
|
+
# =========================================================================
|
|
756
|
+
|
|
757
|
+
def get_entities_mentioned_together(
|
|
758
|
+
self,
|
|
759
|
+
entity_id: str,
|
|
760
|
+
) -> list[tuple[Entity, int]]:
|
|
761
|
+
"""
|
|
762
|
+
Find entities that appear in the same nodes as a given entity.
|
|
763
|
+
|
|
764
|
+
Args:
|
|
765
|
+
entity_id: Entity ID to find co-occurrences for.
|
|
766
|
+
|
|
767
|
+
Returns:
|
|
768
|
+
List of (entity, co-occurrence count) tuples.
|
|
769
|
+
"""
|
|
770
|
+
with self._connect() as conn:
|
|
771
|
+
# Get nodes where the entity is mentioned
|
|
772
|
+
cursor = conn.execute(
|
|
773
|
+
"SELECT DISTINCT node_id FROM mentions WHERE entity_id = ?",
|
|
774
|
+
(entity_id,),
|
|
775
|
+
)
|
|
776
|
+
node_ids = [row["node_id"] for row in cursor]
|
|
777
|
+
|
|
778
|
+
if not node_ids:
|
|
779
|
+
return []
|
|
780
|
+
|
|
781
|
+
# Find other entities in those nodes
|
|
782
|
+
placeholders = ",".join("?" * len(node_ids))
|
|
783
|
+
cursor = conn.execute(
|
|
784
|
+
f"""
|
|
785
|
+
SELECT e.*, COUNT(DISTINCT m.node_id) as co_count
|
|
786
|
+
FROM entities e
|
|
787
|
+
JOIN mentions m ON e.id = m.entity_id
|
|
788
|
+
WHERE m.node_id IN ({placeholders}) AND e.id != ?
|
|
789
|
+
GROUP BY e.id
|
|
790
|
+
ORDER BY co_count DESC
|
|
791
|
+
""",
|
|
792
|
+
node_ids + [entity_id],
|
|
793
|
+
)
|
|
794
|
+
|
|
795
|
+
results = []
|
|
796
|
+
for row in cursor:
|
|
797
|
+
mentions_cursor = conn.execute(
|
|
798
|
+
"SELECT * FROM mentions WHERE entity_id = ?",
|
|
799
|
+
(row["id"],),
|
|
800
|
+
)
|
|
801
|
+
mentions = [self._row_to_mention(m) for m in mentions_cursor]
|
|
802
|
+
entity = self._row_to_entity(row, mentions)
|
|
803
|
+
results.append((entity, row["co_count"]))
|
|
804
|
+
|
|
805
|
+
return results
|
|
806
|
+
|
|
807
|
+
# =========================================================================
|
|
808
|
+
# Statistics
|
|
809
|
+
# =========================================================================
|
|
810
|
+
|
|
811
|
+
def get_stats(self) -> dict[str, Any]:
|
|
812
|
+
"""Get statistics about the knowledge graph."""
|
|
813
|
+
with self._connect() as conn:
|
|
814
|
+
entity_count = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0]
|
|
815
|
+
mention_count = conn.execute("SELECT COUNT(*) FROM mentions").fetchone()[0]
|
|
816
|
+
relationship_count = conn.execute("SELECT COUNT(*) FROM relationships").fetchone()[0]
|
|
817
|
+
link_count = conn.execute("SELECT COUNT(*) FROM entity_links").fetchone()[0]
|
|
818
|
+
|
|
819
|
+
# Type distribution
|
|
820
|
+
type_cursor = conn.execute(
|
|
821
|
+
"SELECT type, COUNT(*) as count FROM entities GROUP BY type"
|
|
822
|
+
)
|
|
823
|
+
type_distribution = {row["type"]: row["count"] for row in type_cursor}
|
|
824
|
+
|
|
825
|
+
return {
|
|
826
|
+
"entity_count": entity_count,
|
|
827
|
+
"mention_count": mention_count,
|
|
828
|
+
"relationship_count": relationship_count,
|
|
829
|
+
"entity_link_count": link_count,
|
|
830
|
+
"entity_type_distribution": type_distribution,
|
|
831
|
+
}
|
|
832
|
+
|
|
833
|
+
def clear(self) -> dict[str, int]:
|
|
834
|
+
"""
|
|
835
|
+
Clear all data from the knowledge graph.
|
|
836
|
+
|
|
837
|
+
Returns:
|
|
838
|
+
Count of deleted items by type.
|
|
839
|
+
"""
|
|
840
|
+
with self._connect() as conn:
|
|
841
|
+
entity_count = conn.execute("DELETE FROM entities").rowcount
|
|
842
|
+
mention_count = conn.execute("DELETE FROM mentions").rowcount
|
|
843
|
+
relationship_count = conn.execute("DELETE FROM relationships").rowcount
|
|
844
|
+
link_count = conn.execute("DELETE FROM entity_links").rowcount
|
|
845
|
+
conn.commit()
|
|
846
|
+
|
|
847
|
+
logger.warning(
|
|
848
|
+
"knowledge_graph_cleared",
|
|
849
|
+
entities=entity_count,
|
|
850
|
+
mentions=mention_count,
|
|
851
|
+
relationships=relationship_count,
|
|
852
|
+
links=link_count,
|
|
853
|
+
)
|
|
854
|
+
|
|
855
|
+
return {
|
|
856
|
+
"entities": entity_count,
|
|
857
|
+
"mentions": mention_count,
|
|
858
|
+
"relationships": relationship_count,
|
|
859
|
+
"entity_links": link_count,
|
|
860
|
+
}
|
|
861
|
+
|
|
862
|
+
# =========================================================================
|
|
863
|
+
# Helper Methods
|
|
864
|
+
# =========================================================================
|
|
865
|
+
|
|
866
|
+
def _row_to_entity(self, row: sqlite3.Row, mentions: list[Mention]) -> Entity:
    """
    Build an Entity from a database row plus its already-loaded mentions.

    ``aliases`` and ``metadata`` are decoded from JSON text; an empty or
    NULL column becomes ``[]`` / ``{}`` respectively. A missing
    ``created_at`` falls back to the current naive UTC time.
    """
    # Fields are staged in the same order the constructor receives them.
    fields = {
        "id": row["id"],
        "type": EntityType(row["type"]),
        "canonical_name": row["canonical_name"],
    }

    aliases_json = row["aliases"]
    fields["aliases"] = json.loads(aliases_json) if aliases_json else []

    metadata_json = row["metadata"]
    fields["metadata"] = json.loads(metadata_json) if metadata_json else {}

    fields["source_doc_id"] = row["source_doc_id"]
    fields["mentions"] = mentions

    created_text = row["created_at"]
    if created_text:
        fields["created_at"] = datetime.fromisoformat(created_text)
    else:
        fields["created_at"] = datetime.utcnow()

    return Entity(**fields)
|
|
878
|
+
|
|
879
|
+
def _row_to_mention(self, row: sqlite3.Row) -> Mention:
    """
    Build a Mention from a database row.

    Every column is copied through unchanged except ``context``, where a
    NULL or empty value is normalised to the empty string.
    """
    fields = {column: row[column] for column in ("id", "node_id", "doc_id")}
    fields["context"] = row["context"] or ""
    for column in ("span_start", "span_end", "page_num", "confidence"):
        fields[column] = row[column]
    return Mention(**fields)
|
|
891
|
+
|
|
892
|
+
def _row_to_relationship(self, row: sqlite3.Row) -> Relationship:
    """
    Build a Relationship from a database row.

    ``evidence`` defaults to the empty string and ``metadata`` (JSON text)
    to an empty dict when the column is NULL or empty. A missing
    ``created_at`` falls back to the current naive UTC time.
    """
    # Fields are staged in the same order the constructor receives them.
    fields = {
        "id": row["id"],
        "type": RelationType(row["type"]),
    }
    for column in (
        "source_id",
        "target_id",
        "source_type",
        "target_type",
        "doc_id",
        "confidence",
    ):
        fields[column] = row[column]

    fields["evidence"] = row["evidence"] or ""

    metadata_json = row["metadata"]
    fields["metadata"] = json.loads(metadata_json) if metadata_json else {}

    created_text = row["created_at"]
    if created_text:
        fields["created_at"] = datetime.fromisoformat(created_text)
    else:
        fields["created_at"] = datetime.utcnow()

    return Relationship(**fields)
|
|
907
|
+
|
|
908
|
+
|
|
909
|
+
class InMemoryKnowledgeGraph:
    """
    In-memory knowledge graph for testing and ephemeral usage.

    Entities, relationships and entity links are held in plain dictionaries
    on the instance; nothing is persisted.

    API-compatible with KnowledgeGraph.
    """

    def __init__(self):
        # Entities and relationships are keyed by their own ``id``.
        self._entities: dict[str, Entity] = {}
        self._relationships: dict[str, Relationship] = {}
        # Links are keyed by the (smaller id, larger id) pair.
        self._entity_links: dict[tuple[str, str], EntityLink] = {}

    # ------------------------------------------------------------------
    # Entities
    # ------------------------------------------------------------------

    def add_entity(self, entity: Entity) -> str:
        """Store ``entity`` under its id (replacing any previous one) and return the id."""
        key = entity.id
        self._entities[key] = entity
        return key

    def get_entity(self, entity_id: str) -> Entity | None:
        """Return the entity stored under ``entity_id``, or ``None``."""
        return self._entities.get(entity_id)

    def find_entities_by_name(
        self,
        name: str,
        entity_type: EntityType | None = None,
        fuzzy: bool = False,
    ) -> list[Entity]:
        """
        Return entities whose name matches ``name``.

        Exact mode compares ``canonical_name`` to ``name`` case-sensitively.
        Fuzzy mode does a case-insensitive substring search over the
        canonical name and every alias. ``entity_type`` optionally narrows
        the candidates.
        """
        needle = name.lower()

        def matches(candidate) -> bool:
            if entity_type and candidate.type != entity_type:
                return False
            if not fuzzy:
                return candidate.canonical_name == name
            if needle in candidate.canonical_name.lower():
                return True
            return any(needle in alias.lower() for alias in candidate.aliases)

        return [candidate for candidate in self._entities.values() if matches(candidate)]

    def find_entities_by_type(
        self,
        entity_type: EntityType,
        doc_id: str | None = None,
    ) -> list[Entity]:
        """Return entities of ``entity_type``, optionally limited to one source document."""
        return [
            candidate
            for candidate in self._entities.values()
            if candidate.type == entity_type
            and (doc_id is None or candidate.source_doc_id == doc_id)
        ]

    def find_entities_in_node(self, node_id: str) -> list[Entity]:
        """Return entities that have at least one mention in node ``node_id``."""
        return [
            candidate
            for candidate in self._entities.values()
            if any(mention.node_id == node_id for mention in candidate.mentions)
        ]

    def find_entities_in_document(self, doc_id: str) -> list[Entity]:
        """Return entities that have at least one mention in document ``doc_id``."""
        return [
            candidate
            for candidate in self._entities.values()
            if any(mention.doc_id == doc_id for mention in candidate.mentions)
        ]

    def delete_entity(self, entity_id: str) -> bool:
        """Remove the entity; return ``True`` if it existed, ``False`` otherwise."""
        try:
            del self._entities[entity_id]
        except KeyError:
            return False
        return True

    # ------------------------------------------------------------------
    # Relationships
    # ------------------------------------------------------------------

    def add_relationship(self, relationship: Relationship) -> str:
        """Store ``relationship`` under its id (replacing any previous one) and return the id."""
        key = relationship.id
        self._relationships[key] = relationship
        return key

    def get_relationship(self, relationship_id: str) -> Relationship | None:
        """Return the relationship stored under ``relationship_id``, or ``None``."""
        return self._relationships.get(relationship_id)

    def get_entity_relationships(
        self,
        entity_id: str,
        relationship_type: RelationType | None = None,
        direction: str = "both",
    ) -> list[Relationship]:
        """
        Return relationships attached to ``entity_id``.

        ``direction`` is ``"outgoing"`` (entity is the source), ``"incoming"``
        (entity is the target) or ``"both"``; any other value matches
        nothing. ``relationship_type`` optionally filters by type.
        """

        def touches(edge) -> bool:
            if direction == "outgoing":
                return edge.source_id == entity_id
            if direction == "incoming":
                return edge.target_id == entity_id
            if direction == "both":
                return edge.source_id == entity_id or edge.target_id == entity_id
            return False

        selected = []
        for edge in self._relationships.values():
            if relationship_type and edge.type != relationship_type:
                continue
            if touches(edge):
                selected.append(edge)
        return selected

    def get_node_relationships(
        self,
        node_id: str,
        relationship_type: RelationType | None = None,
    ) -> list[Relationship]:
        """Return relationships whose source or target id equals ``node_id``."""
        selected = []
        for edge in self._relationships.values():
            if relationship_type and edge.type != relationship_type:
                continue
            if edge.source_id == node_id or edge.target_id == node_id:
                selected.append(edge)
        return selected

    def delete_relationship(self, relationship_id: str) -> bool:
        """Remove the relationship; return ``True`` if it existed, ``False`` otherwise."""
        try:
            del self._relationships[relationship_id]
        except KeyError:
            return False
        return True

    # ------------------------------------------------------------------
    # Entity links
    # ------------------------------------------------------------------

    def link_entities(
        self,
        entity_id_1: str,
        entity_id_2: str,
        confidence: float = 1.0,
        link_method: str = "exact",
        evidence: str = "",
    ) -> None:
        """
        Record a link between two entities.

        The pair is stored in sorted order, so linking (a, b) and (b, a)
        writes the same slot and the later call overwrites the earlier one.
        """
        if entity_id_1 > entity_id_2:
            low, high = entity_id_2, entity_id_1
        else:
            low, high = entity_id_1, entity_id_2

        self._entity_links[(low, high)] = EntityLink(
            entity_id_1=low,
            entity_id_2=high,
            confidence=confidence,
            link_method=link_method,
            evidence=evidence,
        )

    def get_linked_entities(
        self,
        entity_id: str,
        min_confidence: float = 0.0,
    ) -> list[EntityLink]:
        """Return links involving ``entity_id`` with confidence of at least ``min_confidence``."""
        return [
            link
            for link in self._entity_links.values()
            if (link.entity_id_1 == entity_id or link.entity_id_2 == entity_id)
            and link.confidence >= min_confidence
        ]

    def find_entity_across_documents(
        self,
        entity_id: str,
        min_confidence: float = 0.5,
    ) -> list[Entity]:
        """
        Return the entity followed by every entity directly linked to it.

        Only links with confidence of at least ``min_confidence`` are
        followed, and link partners that are no longer stored are skipped.
        Returns an empty list when ``entity_id`` itself is unknown.
        """
        anchor = self.get_entity(entity_id)
        if not anchor:
            return []

        partner_ids = (
            link.entity_id_2 if link.entity_id_1 == entity_id else link.entity_id_1
            for link in self.get_linked_entities(entity_id, min_confidence)
        )
        partners = (self.get_entity(partner_id) for partner_id in partner_ids)
        return [anchor, *(partner for partner in partners if partner)]

    def get_entities_mentioned_together(
        self,
        entity_id: str,
    ) -> list[tuple[Entity, int]]:
        """
        Return entities that share at least one node with ``entity_id``.

        Each result is ``(entity, count)`` where ``count`` is how many of
        that entity's mentions fall in a node where ``entity_id`` is also
        mentioned. Results are ordered by descending count; ties keep
        storage order.
        """
        anchor = self.get_entity(entity_id)
        if not anchor:
            return []

        shared_nodes = {mention.node_id for mention in anchor.mentions}

        tallies: dict[str, int] = {}
        for other in self._entities.values():
            if other.id == entity_id:
                continue
            overlap = sum(1 for mention in other.mentions if mention.node_id in shared_nodes)
            if overlap > 0:
                tallies[other.id] = overlap

        # reverse=True keeps the sort stable, so equal counts stay in storage order.
        ranked = sorted(tallies.items(), key=lambda pair: pair[1], reverse=True)
        return [(self._entities[other_id], overlap) for other_id, overlap in ranked]

    # ------------------------------------------------------------------
    # Statistics / maintenance
    # ------------------------------------------------------------------

    def _mention_total(self) -> int:
        """Return the number of mentions summed over all stored entities."""
        return sum(len(entity.mentions) for entity in self._entities.values())

    def get_stats(self) -> dict[str, Any]:
        """Return item counts plus a histogram of entities per type value."""
        histogram: dict[str, int] = {}
        for entity in self._entities.values():
            label = entity.type.value
            histogram.setdefault(label, 0)
            histogram[label] += 1

        return {
            "entity_count": len(self._entities),
            "mention_count": self._mention_total(),
            "relationship_count": len(self._relationships),
            "entity_link_count": len(self._entity_links),
            "entity_type_distribution": histogram,
        }

    def clear(self) -> dict[str, int]:
        """Empty the graph and return how many items of each kind were removed."""
        removed = {
            "entities": len(self._entities),
            "mentions": self._mention_total(),
            "relationships": len(self._relationships),
            "entity_links": len(self._entity_links),
        }
        for store in (self._entities, self._relationships, self._entity_links):
            store.clear()
        return removed
|