odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/core/knowledge_graph.py
CHANGED
|
@@ -1,309 +1,309 @@
|
|
|
1
|
-
|
|
2
|
-
from typing import Dict, List, Set, Tuple, Optional, Iterator, Any
|
|
3
|
-
from dataclasses import dataclass, field
|
|
4
|
-
from collections import defaultdict
|
|
5
|
-
import logging
|
|
6
|
-
|
|
7
|
-
logger = logging.getLogger(__name__)
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
@dataclass(frozen=True, eq=True)
|
|
11
|
-
class Entity:
|
|
12
|
-
"""Entity representation"""
|
|
13
|
-
name: str
|
|
14
|
-
entity_id: Optional[int] = None
|
|
15
|
-
|
|
16
|
-
def __str__(self) -> str:
|
|
17
|
-
return self.name
|
|
18
|
-
|
|
19
|
-
def __hash__(self) -> int:
|
|
20
|
-
return hash(self.name)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
@dataclass(frozen=True, eq=True)
|
|
24
|
-
class Relation:
|
|
25
|
-
"""Relation/Predicate representation """
|
|
26
|
-
name: str
|
|
27
|
-
relation_id: Optional[int] = None
|
|
28
|
-
|
|
29
|
-
def __str__(self) -> str:
|
|
30
|
-
return self.name
|
|
31
|
-
|
|
32
|
-
def __hash__(self) -> int:
|
|
33
|
-
return hash(self.name)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@dataclass(frozen=True, eq=True)
|
|
37
|
-
class Triple:
|
|
38
|
-
"""
|
|
39
|
-
Triple representation: (head, relation, tail) or l(eh, et)
|
|
40
|
-
"""
|
|
41
|
-
head: Entity
|
|
42
|
-
relation: Relation
|
|
43
|
-
tail: Entity
|
|
44
|
-
|
|
45
|
-
def __str__(self) -> str:
|
|
46
|
-
return f"{self.relation.name}({self.head.name}, {self.tail.name})"
|
|
47
|
-
|
|
48
|
-
def __hash__(self) -> int:
|
|
49
|
-
return hash((self.head, self.relation, self.tail))
|
|
50
|
-
|
|
51
|
-
def to_predicate_logic(self) -> str:
|
|
52
|
-
"""Convert to predicate logic format: relation(head, tail)"""
|
|
53
|
-
return f"{self.relation.name}({self.head.name}, {self.tail.name})"
|
|
54
|
-
|
|
55
|
-
def is_valid(self) -> bool:
|
|
56
|
-
"""Check if triple has valid components"""
|
|
57
|
-
return all([
|
|
58
|
-
self.head.name.strip(),
|
|
59
|
-
self.relation.name.strip(),
|
|
60
|
-
self.tail.name.strip()
|
|
61
|
-
])
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
class KnowledgeGraph:
|
|
65
|
-
|
|
66
|
-
def __init__(self, name: str = "KnowledgeGraph"):
|
|
67
|
-
self.name = name
|
|
68
|
-
|
|
69
|
-
# Core components as per paper definition
|
|
70
|
-
self._entities: Dict[str, Entity] = {} # E: entity set
|
|
71
|
-
self._relations: Dict[str, Relation] = {} # L: relation set
|
|
72
|
-
self._known_facts: Set[Triple] = set() # F: known facts
|
|
73
|
-
self._unknown_facts: Set[Triple] = set() # U: unknown facts
|
|
74
|
-
|
|
75
|
-
# Indexing for efficient access
|
|
76
|
-
self._head_index: Dict[Entity, Set[Triple]] = defaultdict(set)
|
|
77
|
-
self._tail_index: Dict[Entity, Set[Triple]] = defaultdict(set)
|
|
78
|
-
self._relation_index: Dict[Relation, Set[Triple]] = defaultdict(set)
|
|
79
|
-
|
|
80
|
-
# Statistics
|
|
81
|
-
self._stats = {
|
|
82
|
-
'num_entities': 0,
|
|
83
|
-
'num_relations': 0,
|
|
84
|
-
'num_known_facts': 0,
|
|
85
|
-
'num_unknown_facts': 0
|
|
86
|
-
}
|
|
87
|
-
|
|
88
|
-
@property
|
|
89
|
-
def entities(self) -> Set[Entity]:
|
|
90
|
-
"""Get entity set E = {e1, e2, ..., eM}"""
|
|
91
|
-
return set(self._entities.values())
|
|
92
|
-
|
|
93
|
-
@property
|
|
94
|
-
def relations(self) -> Set[Relation]:
|
|
95
|
-
"""Get relation set L = {l1, l2, ..., lN}"""
|
|
96
|
-
return set(self._relations.values())
|
|
97
|
-
|
|
98
|
-
@property
|
|
99
|
-
def known_facts(self) -> Set[Triple]:
|
|
100
|
-
"""Get known facts set F = {f1, f2, ..., fS}"""
|
|
101
|
-
return self._known_facts.copy()
|
|
102
|
-
|
|
103
|
-
@property
|
|
104
|
-
def unknown_facts(self) -> Set[Triple]:
|
|
105
|
-
"""Get unknown facts set U (for inference)"""
|
|
106
|
-
return self._unknown_facts.copy()
|
|
107
|
-
|
|
108
|
-
@property
|
|
109
|
-
def all_facts(self) -> Set[Triple]:
|
|
110
|
-
"""Get all facts F ∪ U"""
|
|
111
|
-
return self._known_facts | self._unknown_facts
|
|
112
|
-
|
|
113
|
-
def get_entity(self, name: str) -> Optional[Entity]:
|
|
114
|
-
"""Get entity by name"""
|
|
115
|
-
return self._entities.get(name)
|
|
116
|
-
|
|
117
|
-
def get_relation(self, name: str) -> Optional[Relation]:
|
|
118
|
-
"""Get relation by name"""
|
|
119
|
-
return self._relations.get(name)
|
|
120
|
-
|
|
121
|
-
def add_entity(self, name: str) -> Entity:
|
|
122
|
-
"""Add entity to E set"""
|
|
123
|
-
if name not in self._entities:
|
|
124
|
-
entity_id = len(self._entities)
|
|
125
|
-
entity = Entity(name=name, entity_id=entity_id)
|
|
126
|
-
self._entities[name] = entity
|
|
127
|
-
self._stats['num_entities'] = len(self._entities)
|
|
128
|
-
logger.debug(f"Added entity: {name}")
|
|
129
|
-
return self._entities[name]
|
|
130
|
-
|
|
131
|
-
def add_relation(self, name: str) -> Relation:
|
|
132
|
-
"""Add relation to L set"""
|
|
133
|
-
if name not in self._relations:
|
|
134
|
-
relation_id = len(self._relations)
|
|
135
|
-
relation = Relation(name=name, relation_id=relation_id)
|
|
136
|
-
self._relations[name] = relation
|
|
137
|
-
self._stats['num_relations'] = len(self._relations)
|
|
138
|
-
logger.debug(f"Added relation: {name}")
|
|
139
|
-
return self._relations[name]
|
|
140
|
-
|
|
141
|
-
def add_known_fact(self, head: str, relation: str, tail: str) -> Triple:
|
|
142
|
-
"""
|
|
143
|
-
Add known fact to F set
|
|
144
|
-
"""
|
|
145
|
-
head_entity = self.add_entity(head)
|
|
146
|
-
relation_obj = self.add_relation(relation)
|
|
147
|
-
tail_entity = self.add_entity(tail)
|
|
148
|
-
|
|
149
|
-
triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
|
|
150
|
-
|
|
151
|
-
if triple.is_valid() and triple not in self._known_facts:
|
|
152
|
-
self._known_facts.add(triple)
|
|
153
|
-
self._update_indices(triple)
|
|
154
|
-
self._stats['num_known_facts'] = len(self._known_facts)
|
|
155
|
-
logger.debug(f"Added known fact: {triple}")
|
|
156
|
-
|
|
157
|
-
return triple
|
|
158
|
-
|
|
159
|
-
def add_unknown_fact(self, head: str, relation: str, tail: str) -> Triple:
|
|
160
|
-
"""Add unknown fact to U set (for inference)"""
|
|
161
|
-
head_entity = self.add_entity(head)
|
|
162
|
-
relation_obj = self.add_relation(relation)
|
|
163
|
-
tail_entity = self.add_entity(tail)
|
|
164
|
-
|
|
165
|
-
triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
|
|
166
|
-
|
|
167
|
-
if triple.is_valid() and triple not in self._unknown_facts:
|
|
168
|
-
self._unknown_facts.add(triple)
|
|
169
|
-
self._update_indices(triple)
|
|
170
|
-
self._stats['num_unknown_facts'] = len(self._unknown_facts)
|
|
171
|
-
logger.debug(f"Added unknown fact: {triple}")
|
|
172
|
-
|
|
173
|
-
return triple
|
|
174
|
-
|
|
175
|
-
def _update_indices(self, triple: Triple):
|
|
176
|
-
"""Update internal indices for efficient querying"""
|
|
177
|
-
self._head_index[triple.head].add(triple)
|
|
178
|
-
self._tail_index[triple.tail].add(triple)
|
|
179
|
-
self._relation_index[triple.relation].add(triple)
|
|
180
|
-
|
|
181
|
-
def get_facts_by_head(self, entity: Entity) -> Set[Triple]:
|
|
182
|
-
"""Get all facts with given entity as head"""
|
|
183
|
-
return self._head_index[entity].copy()
|
|
184
|
-
|
|
185
|
-
def get_facts_by_tail(self, entity: Entity) -> Set[Triple]:
|
|
186
|
-
"""Get all facts with given entity as tail"""
|
|
187
|
-
return self._tail_index[entity].copy()
|
|
188
|
-
|
|
189
|
-
def get_facts_by_relation(self, relation: Relation) -> Set[Triple]:
|
|
190
|
-
"""Get all facts with given relation"""
|
|
191
|
-
return self._relation_index[relation].copy()
|
|
192
|
-
|
|
193
|
-
def contains_fact(self, head: str, relation: str, tail: str) -> bool:
|
|
194
|
-
"""Check if fact exists in known facts F"""
|
|
195
|
-
head_entity = self.get_entity(head)
|
|
196
|
-
relation_obj = self.get_relation(relation)
|
|
197
|
-
tail_entity = self.get_entity(tail)
|
|
198
|
-
|
|
199
|
-
if not all([head_entity, relation_obj, tail_entity]):
|
|
200
|
-
return False
|
|
201
|
-
|
|
202
|
-
triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
|
|
203
|
-
return triple in self._known_facts
|
|
204
|
-
|
|
205
|
-
def get_neighbors(self, entity: Entity, relation: Optional[Relation] = None) -> Set[Entity]:
|
|
206
|
-
"""Get neighboring entities connected by relation"""
|
|
207
|
-
neighbors = set()
|
|
208
|
-
|
|
209
|
-
# Outgoing edges (entity as head)
|
|
210
|
-
for triple in self.get_facts_by_head(entity):
|
|
211
|
-
if relation is None or triple.relation == relation:
|
|
212
|
-
neighbors.add(triple.tail)
|
|
213
|
-
|
|
214
|
-
# Incoming edges (entity as tail)
|
|
215
|
-
for triple in self.get_facts_by_tail(entity):
|
|
216
|
-
if relation is None or triple.relation == relation:
|
|
217
|
-
neighbors.add(triple.head)
|
|
218
|
-
|
|
219
|
-
return neighbors
|
|
220
|
-
|
|
221
|
-
def sample_negative_facts(self, num_samples: int,
|
|
222
|
-
corruption_mode: str = "both") -> List[Triple]:
|
|
223
|
-
"""
|
|
224
|
-
Generate negative facts for training
|
|
225
|
-
Corruption modes: 'head', 'tail', 'both'
|
|
226
|
-
"""
|
|
227
|
-
negative_facts = []
|
|
228
|
-
entities_list = list(self.entities)
|
|
229
|
-
|
|
230
|
-
for _ in range(num_samples):
|
|
231
|
-
# Sample a known fact to corrupt
|
|
232
|
-
known_fact = next(iter(self._known_facts))
|
|
233
|
-
|
|
234
|
-
if corruption_mode == "head" or (corruption_mode == "both" and len(negative_facts) % 2 == 0):
|
|
235
|
-
# Corrupt head entity
|
|
236
|
-
corrupt_head = entities_list[hash(known_fact) % len(entities_list)]
|
|
237
|
-
negative_triple = Triple(corrupt_head, known_fact.relation, known_fact.tail)
|
|
238
|
-
else:
|
|
239
|
-
# Corrupt tail entity
|
|
240
|
-
corrupt_tail = entities_list[hash(known_fact) % len(entities_list)]
|
|
241
|
-
negative_triple = Triple(known_fact.head, known_fact.relation, corrupt_tail)
|
|
242
|
-
|
|
243
|
-
# Ensure it's actually negative
|
|
244
|
-
if negative_triple not in self._known_facts:
|
|
245
|
-
negative_facts.append(negative_triple)
|
|
246
|
-
|
|
247
|
-
return negative_facts
|
|
248
|
-
|
|
249
|
-
def get_statistics(self) -> Dict[str, int]:
|
|
250
|
-
"""Get knowledge graph statistics"""
|
|
251
|
-
return self._stats.copy()
|
|
252
|
-
|
|
253
|
-
def __str__(self) -> str:
|
|
254
|
-
return (f"KnowledgeGraph({self.name}): "
|
|
255
|
-
f"{self._stats['num_entities']} entities, "
|
|
256
|
-
f"{self._stats['num_relations']} relations, "
|
|
257
|
-
f"{self._stats['num_known_facts']} known facts")
|
|
258
|
-
|
|
259
|
-
def __repr__(self) -> str:
|
|
260
|
-
return self.__str__()
|
|
261
|
-
|
|
262
|
-
# --- Serialization helpers for robust save/load ---
|
|
263
|
-
def serialize(self) -> Dict[str, Any]:
|
|
264
|
-
"""
|
|
265
|
-
Serialize the knowledge graph to a portable dict.
|
|
266
|
-
Entities/relations are referenced by name; facts are triplets of names.
|
|
267
|
-
"""
|
|
268
|
-
return {
|
|
269
|
-
'name': self.name,
|
|
270
|
-
'entities': [e.name for e in self.entities],
|
|
271
|
-
'relations': [r.name for r in self.relations],
|
|
272
|
-
'known_facts': [
|
|
273
|
-
(t.head.name, t.relation.name, t.tail.name) for t in self._known_facts
|
|
274
|
-
],
|
|
275
|
-
'unknown_facts': [
|
|
276
|
-
(t.head.name, t.relation.name, t.tail.name) for t in self._unknown_facts
|
|
277
|
-
],
|
|
278
|
-
}
|
|
279
|
-
|
|
280
|
-
@staticmethod
|
|
281
|
-
def deserialize(data: Dict[str, Any]) -> 'KnowledgeGraph':
|
|
282
|
-
"""
|
|
283
|
-
Deserialize a knowledge graph previously produced by serialize().
|
|
284
|
-
"""
|
|
285
|
-
kg = KnowledgeGraph(name=data.get('name', 'KnowledgeGraph'))
|
|
286
|
-
# Pre-create entities and relations
|
|
287
|
-
for e in data.get('entities', []):
|
|
288
|
-
kg.add_entity(e)
|
|
289
|
-
for r in data.get('relations', []):
|
|
290
|
-
kg.add_relation(r)
|
|
291
|
-
for h, r, t in data.get('known_facts', []):
|
|
292
|
-
kg.add_known_fact(h, r, t)
|
|
293
|
-
for h, r, t in data.get('unknown_facts', []):
|
|
294
|
-
kg.add_unknown_fact(h, r, t)
|
|
295
|
-
return kg
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
def load_knowledge_graph_from_triples(triples: List[Tuple[str, str, str]],
|
|
299
|
-
name: str = "LoadedKG") -> KnowledgeGraph:
|
|
300
|
-
"""
|
|
301
|
-
Load knowledge graph from list of (head, relation, tail) tuples
|
|
302
|
-
"""
|
|
303
|
-
kg = KnowledgeGraph(name=name)
|
|
304
|
-
|
|
305
|
-
for head, relation, tail in triples:
|
|
306
|
-
kg.add_known_fact(head, relation, tail)
|
|
307
|
-
|
|
308
|
-
logger.info(f"Loaded knowledge graph: {kg}")
|
|
1
|
+
|
|
2
|
+
from typing import Dict, List, Set, Tuple, Optional, Iterator, Any
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True, eq=True)
|
|
11
|
+
class Entity:
|
|
12
|
+
"""Entity representation"""
|
|
13
|
+
name: str
|
|
14
|
+
entity_id: Optional[int] = None
|
|
15
|
+
|
|
16
|
+
def __str__(self) -> str:
|
|
17
|
+
return self.name
|
|
18
|
+
|
|
19
|
+
def __hash__(self) -> int:
|
|
20
|
+
return hash(self.name)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True, eq=True)
|
|
24
|
+
class Relation:
|
|
25
|
+
"""Relation/Predicate representation """
|
|
26
|
+
name: str
|
|
27
|
+
relation_id: Optional[int] = None
|
|
28
|
+
|
|
29
|
+
def __str__(self) -> str:
|
|
30
|
+
return self.name
|
|
31
|
+
|
|
32
|
+
def __hash__(self) -> int:
|
|
33
|
+
return hash(self.name)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(frozen=True, eq=True)
|
|
37
|
+
class Triple:
|
|
38
|
+
"""
|
|
39
|
+
Triple representation: (head, relation, tail) or l(eh, et)
|
|
40
|
+
"""
|
|
41
|
+
head: Entity
|
|
42
|
+
relation: Relation
|
|
43
|
+
tail: Entity
|
|
44
|
+
|
|
45
|
+
def __str__(self) -> str:
|
|
46
|
+
return f"{self.relation.name}({self.head.name}, {self.tail.name})"
|
|
47
|
+
|
|
48
|
+
def __hash__(self) -> int:
|
|
49
|
+
return hash((self.head, self.relation, self.tail))
|
|
50
|
+
|
|
51
|
+
def to_predicate_logic(self) -> str:
|
|
52
|
+
"""Convert to predicate logic format: relation(head, tail)"""
|
|
53
|
+
return f"{self.relation.name}({self.head.name}, {self.tail.name})"
|
|
54
|
+
|
|
55
|
+
def is_valid(self) -> bool:
|
|
56
|
+
"""Check if triple has valid components"""
|
|
57
|
+
return all([
|
|
58
|
+
self.head.name.strip(),
|
|
59
|
+
self.relation.name.strip(),
|
|
60
|
+
self.tail.name.strip()
|
|
61
|
+
])
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class KnowledgeGraph:
|
|
65
|
+
|
|
66
|
+
def __init__(self, name: str = "KnowledgeGraph"):
|
|
67
|
+
self.name = name
|
|
68
|
+
|
|
69
|
+
# Core components as per paper definition
|
|
70
|
+
self._entities: Dict[str, Entity] = {} # E: entity set
|
|
71
|
+
self._relations: Dict[str, Relation] = {} # L: relation set
|
|
72
|
+
self._known_facts: Set[Triple] = set() # F: known facts
|
|
73
|
+
self._unknown_facts: Set[Triple] = set() # U: unknown facts
|
|
74
|
+
|
|
75
|
+
# Indexing for efficient access
|
|
76
|
+
self._head_index: Dict[Entity, Set[Triple]] = defaultdict(set)
|
|
77
|
+
self._tail_index: Dict[Entity, Set[Triple]] = defaultdict(set)
|
|
78
|
+
self._relation_index: Dict[Relation, Set[Triple]] = defaultdict(set)
|
|
79
|
+
|
|
80
|
+
# Statistics
|
|
81
|
+
self._stats = {
|
|
82
|
+
'num_entities': 0,
|
|
83
|
+
'num_relations': 0,
|
|
84
|
+
'num_known_facts': 0,
|
|
85
|
+
'num_unknown_facts': 0
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
@property
|
|
89
|
+
def entities(self) -> Set[Entity]:
|
|
90
|
+
"""Get entity set E = {e1, e2, ..., eM}"""
|
|
91
|
+
return set(self._entities.values())
|
|
92
|
+
|
|
93
|
+
@property
|
|
94
|
+
def relations(self) -> Set[Relation]:
|
|
95
|
+
"""Get relation set L = {l1, l2, ..., lN}"""
|
|
96
|
+
return set(self._relations.values())
|
|
97
|
+
|
|
98
|
+
@property
|
|
99
|
+
def known_facts(self) -> Set[Triple]:
|
|
100
|
+
"""Get known facts set F = {f1, f2, ..., fS}"""
|
|
101
|
+
return self._known_facts.copy()
|
|
102
|
+
|
|
103
|
+
@property
|
|
104
|
+
def unknown_facts(self) -> Set[Triple]:
|
|
105
|
+
"""Get unknown facts set U (for inference)"""
|
|
106
|
+
return self._unknown_facts.copy()
|
|
107
|
+
|
|
108
|
+
@property
|
|
109
|
+
def all_facts(self) -> Set[Triple]:
|
|
110
|
+
"""Get all facts F ∪ U"""
|
|
111
|
+
return self._known_facts | self._unknown_facts
|
|
112
|
+
|
|
113
|
+
def get_entity(self, name: str) -> Optional[Entity]:
|
|
114
|
+
"""Get entity by name"""
|
|
115
|
+
return self._entities.get(name)
|
|
116
|
+
|
|
117
|
+
def get_relation(self, name: str) -> Optional[Relation]:
|
|
118
|
+
"""Get relation by name"""
|
|
119
|
+
return self._relations.get(name)
|
|
120
|
+
|
|
121
|
+
def add_entity(self, name: str) -> Entity:
|
|
122
|
+
"""Add entity to E set"""
|
|
123
|
+
if name not in self._entities:
|
|
124
|
+
entity_id = len(self._entities)
|
|
125
|
+
entity = Entity(name=name, entity_id=entity_id)
|
|
126
|
+
self._entities[name] = entity
|
|
127
|
+
self._stats['num_entities'] = len(self._entities)
|
|
128
|
+
logger.debug(f"Added entity: {name}")
|
|
129
|
+
return self._entities[name]
|
|
130
|
+
|
|
131
|
+
def add_relation(self, name: str) -> Relation:
|
|
132
|
+
"""Add relation to L set"""
|
|
133
|
+
if name not in self._relations:
|
|
134
|
+
relation_id = len(self._relations)
|
|
135
|
+
relation = Relation(name=name, relation_id=relation_id)
|
|
136
|
+
self._relations[name] = relation
|
|
137
|
+
self._stats['num_relations'] = len(self._relations)
|
|
138
|
+
logger.debug(f"Added relation: {name}")
|
|
139
|
+
return self._relations[name]
|
|
140
|
+
|
|
141
|
+
def add_known_fact(self, head: str, relation: str, tail: str) -> Triple:
|
|
142
|
+
"""
|
|
143
|
+
Add known fact to F set
|
|
144
|
+
"""
|
|
145
|
+
head_entity = self.add_entity(head)
|
|
146
|
+
relation_obj = self.add_relation(relation)
|
|
147
|
+
tail_entity = self.add_entity(tail)
|
|
148
|
+
|
|
149
|
+
triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
|
|
150
|
+
|
|
151
|
+
if triple.is_valid() and triple not in self._known_facts:
|
|
152
|
+
self._known_facts.add(triple)
|
|
153
|
+
self._update_indices(triple)
|
|
154
|
+
self._stats['num_known_facts'] = len(self._known_facts)
|
|
155
|
+
logger.debug(f"Added known fact: {triple}")
|
|
156
|
+
|
|
157
|
+
return triple
|
|
158
|
+
|
|
159
|
+
def add_unknown_fact(self, head: str, relation: str, tail: str) -> Triple:
|
|
160
|
+
"""Add unknown fact to U set (for inference)"""
|
|
161
|
+
head_entity = self.add_entity(head)
|
|
162
|
+
relation_obj = self.add_relation(relation)
|
|
163
|
+
tail_entity = self.add_entity(tail)
|
|
164
|
+
|
|
165
|
+
triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
|
|
166
|
+
|
|
167
|
+
if triple.is_valid() and triple not in self._unknown_facts:
|
|
168
|
+
self._unknown_facts.add(triple)
|
|
169
|
+
self._update_indices(triple)
|
|
170
|
+
self._stats['num_unknown_facts'] = len(self._unknown_facts)
|
|
171
|
+
logger.debug(f"Added unknown fact: {triple}")
|
|
172
|
+
|
|
173
|
+
return triple
|
|
174
|
+
|
|
175
|
+
def _update_indices(self, triple: Triple):
|
|
176
|
+
"""Update internal indices for efficient querying"""
|
|
177
|
+
self._head_index[triple.head].add(triple)
|
|
178
|
+
self._tail_index[triple.tail].add(triple)
|
|
179
|
+
self._relation_index[triple.relation].add(triple)
|
|
180
|
+
|
|
181
|
+
def get_facts_by_head(self, entity: Entity) -> Set[Triple]:
|
|
182
|
+
"""Get all facts with given entity as head"""
|
|
183
|
+
return self._head_index[entity].copy()
|
|
184
|
+
|
|
185
|
+
def get_facts_by_tail(self, entity: Entity) -> Set[Triple]:
|
|
186
|
+
"""Get all facts with given entity as tail"""
|
|
187
|
+
return self._tail_index[entity].copy()
|
|
188
|
+
|
|
189
|
+
def get_facts_by_relation(self, relation: Relation) -> Set[Triple]:
|
|
190
|
+
"""Get all facts with given relation"""
|
|
191
|
+
return self._relation_index[relation].copy()
|
|
192
|
+
|
|
193
|
+
def contains_fact(self, head: str, relation: str, tail: str) -> bool:
|
|
194
|
+
"""Check if fact exists in known facts F"""
|
|
195
|
+
head_entity = self.get_entity(head)
|
|
196
|
+
relation_obj = self.get_relation(relation)
|
|
197
|
+
tail_entity = self.get_entity(tail)
|
|
198
|
+
|
|
199
|
+
if not all([head_entity, relation_obj, tail_entity]):
|
|
200
|
+
return False
|
|
201
|
+
|
|
202
|
+
triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
|
|
203
|
+
return triple in self._known_facts
|
|
204
|
+
|
|
205
|
+
def get_neighbors(self, entity: Entity, relation: Optional[Relation] = None) -> Set[Entity]:
|
|
206
|
+
"""Get neighboring entities connected by relation"""
|
|
207
|
+
neighbors = set()
|
|
208
|
+
|
|
209
|
+
# Outgoing edges (entity as head)
|
|
210
|
+
for triple in self.get_facts_by_head(entity):
|
|
211
|
+
if relation is None or triple.relation == relation:
|
|
212
|
+
neighbors.add(triple.tail)
|
|
213
|
+
|
|
214
|
+
# Incoming edges (entity as tail)
|
|
215
|
+
for triple in self.get_facts_by_tail(entity):
|
|
216
|
+
if relation is None or triple.relation == relation:
|
|
217
|
+
neighbors.add(triple.head)
|
|
218
|
+
|
|
219
|
+
return neighbors
|
|
220
|
+
|
|
221
|
+
def sample_negative_facts(self, num_samples: int,
|
|
222
|
+
corruption_mode: str = "both") -> List[Triple]:
|
|
223
|
+
"""
|
|
224
|
+
Generate negative facts for training
|
|
225
|
+
Corruption modes: 'head', 'tail', 'both'
|
|
226
|
+
"""
|
|
227
|
+
negative_facts = []
|
|
228
|
+
entities_list = list(self.entities)
|
|
229
|
+
|
|
230
|
+
for _ in range(num_samples):
|
|
231
|
+
# Sample a known fact to corrupt
|
|
232
|
+
known_fact = next(iter(self._known_facts))
|
|
233
|
+
|
|
234
|
+
if corruption_mode == "head" or (corruption_mode == "both" and len(negative_facts) % 2 == 0):
|
|
235
|
+
# Corrupt head entity
|
|
236
|
+
corrupt_head = entities_list[hash(known_fact) % len(entities_list)]
|
|
237
|
+
negative_triple = Triple(corrupt_head, known_fact.relation, known_fact.tail)
|
|
238
|
+
else:
|
|
239
|
+
# Corrupt tail entity
|
|
240
|
+
corrupt_tail = entities_list[hash(known_fact) % len(entities_list)]
|
|
241
|
+
negative_triple = Triple(known_fact.head, known_fact.relation, corrupt_tail)
|
|
242
|
+
|
|
243
|
+
# Ensure it's actually negative
|
|
244
|
+
if negative_triple not in self._known_facts:
|
|
245
|
+
negative_facts.append(negative_triple)
|
|
246
|
+
|
|
247
|
+
return negative_facts
|
|
248
|
+
|
|
249
|
+
def get_statistics(self) -> Dict[str, int]:
|
|
250
|
+
"""Get knowledge graph statistics"""
|
|
251
|
+
return self._stats.copy()
|
|
252
|
+
|
|
253
|
+
def __str__(self) -> str:
|
|
254
|
+
return (f"KnowledgeGraph({self.name}): "
|
|
255
|
+
f"{self._stats['num_entities']} entities, "
|
|
256
|
+
f"{self._stats['num_relations']} relations, "
|
|
257
|
+
f"{self._stats['num_known_facts']} known facts")
|
|
258
|
+
|
|
259
|
+
def __repr__(self) -> str:
|
|
260
|
+
return self.__str__()
|
|
261
|
+
|
|
262
|
+
# --- Serialization helpers for robust save/load ---
|
|
263
|
+
def serialize(self) -> Dict[str, Any]:
|
|
264
|
+
"""
|
|
265
|
+
Serialize the knowledge graph to a portable dict.
|
|
266
|
+
Entities/relations are referenced by name; facts are triplets of names.
|
|
267
|
+
"""
|
|
268
|
+
return {
|
|
269
|
+
'name': self.name,
|
|
270
|
+
'entities': [e.name for e in self.entities],
|
|
271
|
+
'relations': [r.name for r in self.relations],
|
|
272
|
+
'known_facts': [
|
|
273
|
+
(t.head.name, t.relation.name, t.tail.name) for t in self._known_facts
|
|
274
|
+
],
|
|
275
|
+
'unknown_facts': [
|
|
276
|
+
(t.head.name, t.relation.name, t.tail.name) for t in self._unknown_facts
|
|
277
|
+
],
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
@staticmethod
|
|
281
|
+
def deserialize(data: Dict[str, Any]) -> 'KnowledgeGraph':
|
|
282
|
+
"""
|
|
283
|
+
Deserialize a knowledge graph previously produced by serialize().
|
|
284
|
+
"""
|
|
285
|
+
kg = KnowledgeGraph(name=data.get('name', 'KnowledgeGraph'))
|
|
286
|
+
# Pre-create entities and relations
|
|
287
|
+
for e in data.get('entities', []):
|
|
288
|
+
kg.add_entity(e)
|
|
289
|
+
for r in data.get('relations', []):
|
|
290
|
+
kg.add_relation(r)
|
|
291
|
+
for h, r, t in data.get('known_facts', []):
|
|
292
|
+
kg.add_known_fact(h, r, t)
|
|
293
|
+
for h, r, t in data.get('unknown_facts', []):
|
|
294
|
+
kg.add_unknown_fact(h, r, t)
|
|
295
|
+
return kg
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def load_knowledge_graph_from_triples(triples: List[Tuple[str, str, str]],
|
|
299
|
+
name: str = "LoadedKG") -> KnowledgeGraph:
|
|
300
|
+
"""
|
|
301
|
+
Load knowledge graph from list of (head, relation, tail) tuples
|
|
302
|
+
"""
|
|
303
|
+
kg = KnowledgeGraph(name=name)
|
|
304
|
+
|
|
305
|
+
for head, relation, tail in triples:
|
|
306
|
+
kg.add_known_fact(head, relation, tail)
|
|
307
|
+
|
|
308
|
+
logger.info(f"Loaded knowledge graph: {kg}")
|
|
309
309
|
return kg
|