odin-engine 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. benchmarks/__init__.py +17 -0
  2. benchmarks/datasets.py +284 -0
  3. benchmarks/metrics.py +275 -0
  4. benchmarks/run_ablation.py +279 -0
  5. benchmarks/run_npll_benchmark.py +270 -0
  6. npll/__init__.py +10 -0
  7. npll/bootstrap.py +474 -0
  8. npll/core/__init__.py +34 -0
  9. npll/core/knowledge_graph.py +309 -0
  10. npll/core/logical_rules.py +497 -0
  11. npll/core/mln.py +475 -0
  12. npll/inference/__init__.py +41 -0
  13. npll/inference/e_step.py +420 -0
  14. npll/inference/elbo.py +435 -0
  15. npll/inference/m_step.py +577 -0
  16. npll/npll_model.py +632 -0
  17. npll/scoring/__init__.py +43 -0
  18. npll/scoring/embeddings.py +442 -0
  19. npll/scoring/probability.py +403 -0
  20. npll/scoring/scoring_module.py +370 -0
  21. npll/training/__init__.py +25 -0
  22. npll/training/evaluation.py +497 -0
  23. npll/training/npll_trainer.py +521 -0
  24. npll/utils/__init__.py +48 -0
  25. npll/utils/batch_utils.py +493 -0
  26. npll/utils/config.py +145 -0
  27. npll/utils/math_utils.py +339 -0
  28. odin/__init__.py +20 -0
  29. odin/engine.py +264 -0
  30. odin_engine-0.1.0.dist-info/METADATA +456 -0
  31. odin_engine-0.1.0.dist-info/RECORD +62 -0
  32. odin_engine-0.1.0.dist-info/WHEEL +5 -0
  33. odin_engine-0.1.0.dist-info/licenses/LICENSE +21 -0
  34. odin_engine-0.1.0.dist-info/top_level.txt +4 -0
  35. retrieval/__init__.py +50 -0
  36. retrieval/adapters.py +140 -0
  37. retrieval/adapters_arango.py +1418 -0
  38. retrieval/aggregators.py +707 -0
  39. retrieval/beam.py +127 -0
  40. retrieval/budget.py +60 -0
  41. retrieval/cache.py +159 -0
  42. retrieval/confidence.py +88 -0
  43. retrieval/eval.py +49 -0
  44. retrieval/linker.py +87 -0
  45. retrieval/metrics.py +105 -0
  46. retrieval/metrics_motifs.py +36 -0
  47. retrieval/orchestrator.py +571 -0
  48. retrieval/ppr/__init__.py +12 -0
  49. retrieval/ppr/anchors.py +41 -0
  50. retrieval/ppr/bippr.py +61 -0
  51. retrieval/ppr/engines.py +257 -0
  52. retrieval/ppr/global_pr.py +76 -0
  53. retrieval/ppr/indexes.py +78 -0
  54. retrieval/ppr.py +156 -0
  55. retrieval/ppr_cache.py +25 -0
  56. retrieval/scoring.py +294 -0
  57. retrieval/utils/__init__.py +0 -0
  58. retrieval/utils/pii_redaction.py +36 -0
  59. retrieval/writers/__init__.py +9 -0
  60. retrieval/writers/arango_writer.py +28 -0
  61. retrieval/writers/base.py +21 -0
  62. retrieval/writers/janus_writer.py +36 -0
benchmarks/__init__.py ADDED
@@ -0,0 +1,17 @@
+ """
+ Odin Benchmarks: Academic validation against standard KG datasets.
+
+ This module provides:
+ - Standard dataset loaders (FB15k-237, WN18RR)
+ - KG completion metrics (MRR, Hits@K)
+ - Benchmark runners for NPLL evaluation
+ - Ablation study tools
+ """
+
+ from .metrics import mrr, hits_at_k, evaluate_rankings
+ from .datasets import load_fb15k237, load_wn18rr, BenchmarkDataset
+
+ __all__ = [
+     "mrr", "hits_at_k", "evaluate_rankings",
+     "load_fb15k237", "load_wn18rr", "BenchmarkDataset"
+ ]
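With these re-exports the benchmark API is importable from the package root. A small illustrative sketch (not part of the diff; the rank list is hypothetical):

    from benchmarks import mrr, hits_at_k

    ranks = [1, 3, 12]                       # hypothetical filtered ranks
    print(mrr(ranks), hits_at_k(ranks, 10))  # ~0.472 and ~0.667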
benchmarks/datasets.py ADDED
@@ -0,0 +1,284 @@
+ """
+ Standard Knowledge Graph Benchmark Datasets
+
+ Provides loaders for:
+ - FB15k-237: Freebase subset (14,541 entities, 237 relations)
+ - WN18RR: WordNet subset (40,943 entities, 11 relations)
+
+ Datasets are downloaded from standard sources and cached locally.
+ """
+
+ import os
+ import urllib.request
+ import tarfile
+ import zipfile
+ from pathlib import Path
+ from typing import List, Tuple, Set, Dict, Optional
+ from dataclasses import dataclass
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ # Dataset URLs - using villmow/datasets_knowledge_embedding (reliable raw files)
+ DATASET_BASE_URLS = {
+     "fb15k237": "https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237",
+     "wn18rr": "https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR",
+ }
+
+ # Default cache directory
+ DEFAULT_CACHE_DIR = Path.home() / ".cache" / "odin_benchmarks"
+
+
+ @dataclass
+ class BenchmarkDataset:
+     """Container for a benchmark dataset."""
+     name: str
+     train_triples: List[Tuple[str, str, str]]
+     valid_triples: List[Tuple[str, str, str]]
+     test_triples: List[Tuple[str, str, str]]
+     entities: List[str]
+     relations: List[str]
+
+     @property
+     def num_entities(self) -> int:
+         return len(self.entities)
+
+     @property
+     def num_relations(self) -> int:
+         return len(self.relations)
+
+     @property
+     def num_train(self) -> int:
+         return len(self.train_triples)
+
+     @property
+     def num_valid(self) -> int:
+         return len(self.valid_triples)
+
+     @property
+     def num_test(self) -> int:
+         return len(self.test_triples)
+
+     def get_train_set(self) -> Set[Tuple[str, str, str]]:
+         return set(self.train_triples)
+
+     def get_all_triples(self) -> Set[Tuple[str, str, str]]:
+         return set(self.train_triples + self.valid_triples + self.test_triples)
+
+     def __repr__(self) -> str:
+         return (
+             f"BenchmarkDataset({self.name})\n"
+             f"  Entities: {self.num_entities:,}\n"
+             f"  Relations: {self.num_relations}\n"
+             f"  Train: {self.num_train:,}\n"
+             f"  Valid: {self.num_valid:,}\n"
+             f"  Test: {self.num_test:,}"
+         )
+
+
+ def _ensure_dir(path: Path):
+     """Ensure directory exists."""
+     path.mkdir(parents=True, exist_ok=True)
+
+
+ def _download_file(url: str, dest: Path):
+     """Download a file with progress."""
+     import ssl
+
+     logger.info(f"Downloading {url}...")
+
+     # Try with certifi SSL context first, fall back to unverified
+     try:
+         import certifi
+         ssl_context = ssl.create_default_context(cafile=certifi.where())
+         with urllib.request.urlopen(url, context=ssl_context) as response:
+             with open(dest, 'wb') as out_file:
+                 out_file.write(response.read())
+     except (ImportError, ssl.SSLError):
+         # Fallback: disable SSL verification (for development only)
+         logger.warning("SSL verification disabled - using unverified context")
+         ssl_context = ssl.create_default_context()
+         ssl_context.check_hostname = False
+         ssl_context.verify_mode = ssl.CERT_NONE
+         with urllib.request.urlopen(url, context=ssl_context) as response:
+             with open(dest, 'wb') as out_file:
+                 out_file.write(response.read())
+
+     logger.info(f"Downloaded to {dest}")
+
+
+ def _extract_tar_gz(archive: Path, dest_dir: Path):
+     """Extract a tar.gz archive."""
+     logger.info(f"Extracting {archive}...")
+     with tarfile.open(archive, "r:gz") as tar:
+         tar.extractall(dest_dir)
+
+
+ def _extract_zip(archive: Path, dest_dir: Path):
+     """Extract a zip archive."""
+     logger.info(f"Extracting {archive}...")
+     with zipfile.ZipFile(archive, 'r') as zip_ref:
+         zip_ref.extractall(dest_dir)
+
+
+ def _load_triples(filepath: Path) -> List[Tuple[str, str, str]]:
+     """Load triples from a TSV file (head, relation, tail)."""
+     triples = []
+     with open(filepath, "r", encoding="utf-8") as f:
+         for line in f:
+             parts = line.strip().split("\t")
+             if len(parts) == 3:
+                 h, r, t = parts
+                 triples.append((h, r, t))
+     return triples
+
+
+ def _extract_entities_and_relations(
+     triples: List[Tuple[str, str, str]]
+ ) -> Tuple[List[str], List[str]]:
+     """Extract unique entities and relations from triples."""
+     entities = set()
+     relations = set()
+     for h, r, t in triples:
+         entities.add(h)
+         entities.add(t)
+         relations.add(r)
+     return sorted(entities), sorted(relations)
+
+
+ def _download_dataset_files(base_url: str, dataset_dir: Path):
+     """Download train/valid/test files for a dataset."""
+     _ensure_dir(dataset_dir)
+
+     for split in ["train", "valid", "test"]:
+         file_path = dataset_dir / f"{split}.txt"
+         if not file_path.exists():
+             url = f"{base_url}/{split}.txt"
+             _download_file(url, file_path)
+
+
+ def load_fb15k237(cache_dir: Optional[Path] = None) -> BenchmarkDataset:
+     """
+     Load FB15k-237 dataset.
+
+     FB15k-237 is a subset of Freebase with:
+     - 14,541 entities
+     - 237 relations
+     - 310,116 triples
+
+     This version removes inverse relations from FB15k to prevent
+     data leakage during evaluation.
+
+     Args:
+         cache_dir: Directory to cache downloaded data
+
+     Returns:
+         BenchmarkDataset object
+     """
+     cache_dir = Path(cache_dir) if cache_dir else DEFAULT_CACHE_DIR
+     _ensure_dir(cache_dir)
+
+     dataset_dir = cache_dir / "FB15k-237"
+
+     # Download if not cached
+     if not (dataset_dir / "train.txt").exists():
+         _download_dataset_files(DATASET_BASE_URLS["fb15k237"], dataset_dir)
+
+     # Load splits
+     train = _load_triples(dataset_dir / "train.txt")
+     valid = _load_triples(dataset_dir / "valid.txt")
+     test = _load_triples(dataset_dir / "test.txt")
+
+     # Extract vocab
+     all_triples = train + valid + test
+     entities, relations = _extract_entities_and_relations(all_triples)
+
+     return BenchmarkDataset(
+         name="FB15k-237",
+         train_triples=train,
+         valid_triples=valid,
+         test_triples=test,
+         entities=entities,
+         relations=relations,
+     )
+
+
+ def load_wn18rr(cache_dir: Optional[Path] = None) -> BenchmarkDataset:
+     """
+     Load WN18RR dataset.
+
+     WN18RR is a subset of WordNet with:
+     - 40,943 entities
+     - 11 relations
+     - 93,003 triples
+
+     This version removes inverse relations from WN18 to prevent
+     data leakage during evaluation.
+
+     Args:
+         cache_dir: Directory to cache downloaded data
+
+     Returns:
+         BenchmarkDataset object
+     """
+     cache_dir = Path(cache_dir) if cache_dir else DEFAULT_CACHE_DIR
+     _ensure_dir(cache_dir)
+
+     dataset_dir = cache_dir / "WN18RR"
+
+     # Download if not cached
+     if not (dataset_dir / "train.txt").exists():
+         _download_dataset_files(DATASET_BASE_URLS["wn18rr"], dataset_dir)
+
+     # Load splits
+     train = _load_triples(dataset_dir / "train.txt")
+     valid = _load_triples(dataset_dir / "valid.txt")
+     test = _load_triples(dataset_dir / "test.txt")
+
+     # Extract vocab
+     all_triples = train + valid + test
+     entities, relations = _extract_entities_and_relations(all_triples)
+
+     return BenchmarkDataset(
+         name="WN18RR",
+         train_triples=train,
+         valid_triples=valid,
+         test_triples=test,
+         entities=entities,
+         relations=relations,
+     )
+
+
+ def dataset_to_kg(dataset: BenchmarkDataset):
+     """
+     Convert BenchmarkDataset to Odin KnowledgeGraph.
+
+     Args:
+         dataset: BenchmarkDataset object
+
+     Returns:
+         KnowledgeGraph object suitable for NPLL training
+     """
+     from npll.core import KnowledgeGraph
+
+     kg = KnowledgeGraph()
+
+     # Add all training triples as known facts
+     for h, r, t in dataset.train_triples:
+         kg.add_known_fact(h, r, t)
+
+     return kg
+
+
+ # Quick test
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+
+     print("Loading FB15k-237...")
+     fb = load_fb15k237()
+     print(fb)
+     print()
+
+     print("Loading WN18RR...")
+     wn = load_wn18rr()
+     print(wn)
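A minimal usage sketch of the loaders above (not part of the package; assumes network access on the first call, after which files are read from ~/.cache/odin_benchmarks):

    from benchmarks.datasets import load_fb15k237, dataset_to_kg

    fb = load_fb15k237()     # downloads train/valid/test on first use, then hits the cache
    print(fb)                # entity/relation/split counts via __repr__
    kg = dataset_to_kg(fb)   # training triples become known facts in the NPLL KnowledgeGraph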
benchmarks/metrics.py ADDED
@@ -0,0 +1,275 @@
+ """
+ Standard Knowledge Graph Completion Metrics
+
+ Implements metrics used in academic KG completion papers:
+ - Mean Reciprocal Rank (MRR)
+ - Hits@K (K=1, 3, 10)
+
+ These metrics evaluate link prediction quality:
+ Given (h, r, ?), rank all entities by predicted score.
+ """
+
+ import numpy as np
+ from typing import List, Dict, Tuple, Any
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class RankingResult:
+     """Result of ranking evaluation for a single triple."""
+     head: str
+     relation: str
+     tail: str
+     tail_rank: int   # Rank of true tail among all candidates
+     head_rank: int   # Rank of true head among all candidates (for inverse)
+     num_candidates: int
+
+
+ def mrr(ranks: List[int]) -> float:
+     """
+     Mean Reciprocal Rank
+
+     MRR = (1/|Q|) * Σ (1/rank_i)
+
+     Args:
+         ranks: List of ranks (1-indexed, where 1 is best)
+
+     Returns:
+         MRR score in [0, 1]
+     """
+     if not ranks:
+         return 0.0
+     return float(np.mean([1.0 / r for r in ranks]))
+
+
+ def hits_at_k(ranks: List[int], k: int) -> float:
+     """
+     Hits@K - proportion of ranks <= K
+
+     Args:
+         ranks: List of ranks (1-indexed)
+         k: Cutoff threshold
+
+     Returns:
+         Hits@K score in [0, 1]
+     """
+     if not ranks:
+         return 0.0
+     return float(np.mean([1 if r <= k else 0 for r in ranks]))
+
+
+ def evaluate_rankings(results: List[RankingResult]) -> Dict[str, float]:
+     """
+     Compute all standard metrics from ranking results.
+
+     Args:
+         results: List of RankingResult objects
+
+     Returns:
+         Dictionary with MRR, Hits@1, Hits@3, Hits@10
+     """
+     if not results:
+         return {
+             "mrr": 0.0,
+             "hits@1": 0.0,
+             "hits@3": 0.0,
+             "hits@10": 0.0,
+             "num_queries": 0
+         }
+
+     # Use tail ranks (standard protocol: predict tail given head+relation)
+     tail_ranks = [r.tail_rank for r in results]
+
+     return {
+         "mrr": mrr(tail_ranks),
+         "hits@1": hits_at_k(tail_ranks, 1),
+         "hits@3": hits_at_k(tail_ranks, 3),
+         "hits@10": hits_at_k(tail_ranks, 10),
+         "num_queries": len(results),
+         "mean_rank": float(np.mean(tail_ranks)),
+         "median_rank": float(np.median(tail_ranks)),
+     }
+
+
+ def filtered_rank(
+     true_entity: str,
+     scores: Dict[str, float],
+     filter_entities: set,
+ ) -> int:
+     """
+     Compute filtered rank (standard protocol).
+
+     Filtered ranking removes other known-true answers from consideration,
+     so the model is only penalized when incorrect entities outscore the true one.
+
+     Args:
+         true_entity: The correct answer
+         scores: Dict mapping entity -> score
+         filter_entities: Other valid entities to filter out
+
+     Returns:
+         Filtered rank (1-indexed)
+     """
+     true_score = scores.get(true_entity, float('-inf'))
+
+     rank = 1
+     for entity, score in scores.items():
+         if entity == true_entity:
+             continue
+         if entity in filter_entities:
+             continue  # Filter out other valid answers
+         if score > true_score:
+             rank += 1
+
+     return rank
+
+
+ class LinkPredictionEvaluator:
+     """
+     Evaluator for link prediction task.
+
+     Standard protocol:
+     1. For each test triple (h, r, t):
+        a. Corrupt tail: score all (h, r, e) for e in entities
+        b. Corrupt head: score all (e, r, t) for e in entities
+     2. Filter out other valid triples (filtered setting)
+     3. Report filtered MRR, Hits@1/3/10
+     """
+
+     def __init__(
+         self,
+         all_entities: List[str],
+         train_triples: set,
+         valid_triples: set = None,
+     ):
+         """
+         Args:
+             all_entities: List of all entity IDs
+             train_triples: Set of (h, r, t) tuples from training
+             valid_triples: Set of (h, r, t) tuples from validation (optional)
+         """
+         self.all_entities = all_entities
+         self.train_triples = train_triples
+         self.valid_triples = valid_triples or set()
+         self.known_triples = train_triples | self.valid_triples
+
+         # Build index for filtering
+         self._build_filter_index()
+
+     def _build_filter_index(self):
+         """Build indices for efficient filtering."""
+         # For (h, r, ?): which tails are valid?
+         self.hr_to_tails = {}
+         # For (?, r, t): which heads are valid?
+         self.rt_to_heads = {}
+
+         for h, r, t in self.known_triples:
+             key_hr = (h, r)
+             key_rt = (r, t)
+
+             if key_hr not in self.hr_to_tails:
+                 self.hr_to_tails[key_hr] = set()
+             self.hr_to_tails[key_hr].add(t)
+
+             if key_rt not in self.rt_to_heads:
+                 self.rt_to_heads[key_rt] = set()
+             self.rt_to_heads[key_rt].add(h)
+
+     def evaluate_triple(
+         self,
+         head: str,
+         relation: str,
+         tail: str,
+         score_fn,
+     ) -> RankingResult:
+         """
+         Evaluate a single test triple.
+
+         Args:
+             head: Head entity
+             relation: Relation type
+             tail: True tail entity
+             score_fn: Function (h, r, t) -> float
+
+         Returns:
+             RankingResult with filtered ranks
+         """
+         # Score all possible tails
+         tail_scores = {}
+         for entity in self.all_entities:
+             tail_scores[entity] = score_fn(head, relation, entity)
+
+         # Get filter set (other valid tails for this h,r pair)
+         filter_tails = self.hr_to_tails.get((head, relation), set()) - {tail}
+
+         # Compute filtered rank
+         tail_rank = filtered_rank(tail, tail_scores, filter_tails)
+
+         # Score all possible heads (for bidirectional evaluation)
+         head_scores = {}
+         for entity in self.all_entities:
+             head_scores[entity] = score_fn(entity, relation, tail)
+
+         # Get filter set for heads
+         filter_heads = self.rt_to_heads.get((relation, tail), set()) - {head}
+         head_rank = filtered_rank(head, head_scores, filter_heads)
+
+         return RankingResult(
+             head=head,
+             relation=relation,
+             tail=tail,
+             tail_rank=tail_rank,
+             head_rank=head_rank,
+             num_candidates=len(self.all_entities),
+         )
+
+     def evaluate_batch(
+         self,
+         test_triples: List[Tuple[str, str, str]],
+         score_fn,
+         verbose: bool = True,
+     ) -> Dict[str, float]:
+         """
+         Evaluate on a batch of test triples.
+
+         Args:
+             test_triples: List of (h, r, t) tuples
+             score_fn: Function (h, r, t) -> float
+             verbose: Print progress
+
+         Returns:
+             Dictionary with all metrics
+         """
+         results = []
+
+         for i, (h, r, t) in enumerate(test_triples):
+             if verbose and (i + 1) % 100 == 0:
+                 print(f"  Evaluated {i + 1}/{len(test_triples)} triples...")
+
+             result = self.evaluate_triple(h, r, t, score_fn)
+             results.append(result)
+
+         return evaluate_rankings(results)
+
+
+ # Convenience function for quick evaluation
+ def quick_evaluate(
+     test_triples: List[Tuple[str, str, str]],
+     score_fn,
+     all_entities: List[str],
+     train_triples: set,
+ ) -> Dict[str, float]:
+     """
+     Quick evaluation helper.
+
+     Args:
+         test_triples: Test set
+         score_fn: Scoring function
+         all_entities: All entities in KG
+         train_triples: Training triples for filtering
+
+     Returns:
+         Metrics dictionary
+     """
+     evaluator = LinkPredictionEvaluator(all_entities, train_triples)
+     return evaluator.evaluate_batch(test_triples, score_fn)
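To see the filtered protocol end to end, here is a toy sketch (not part of the package; the entities, triples, and score_fn are made up for illustration):

    from benchmarks.metrics import quick_evaluate

    entities = ["a", "b", "c", "d"]
    train = {("a", "likes", "b")}             # known triple, filtered out during ranking
    test = [("a", "likes", "c")]

    def score_fn(h, r, t):
        return 1.0 if (h, r, t) == ("a", "likes", "c") else 0.0

    metrics = quick_evaluate(test, score_fn, entities, train)
    print(metrics["mrr"], metrics["hits@1"])  # 1.0 1.0 for this toy scorer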