odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff shows the contents of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
benchmarks/metrics.py CHANGED
@@ -1,275 +1,275 @@
(Every line of this file is removed and re-added, but the removed and added contents are identical; the file is shown once below.)

"""
Standard Knowledge Graph Completion Metrics

Implements metrics used in academic KG completion papers:
- Mean Reciprocal Rank (MRR)
- Hits@K (K=1, 3, 10)

These metrics evaluate link prediction quality:
Given (h, r, ?), rank all entities by predicted score.
"""

import numpy as np
from typing import List, Dict, Tuple, Any
from dataclasses import dataclass


@dataclass
class RankingResult:
    """Result of ranking evaluation for a single triple."""
    head: str
    relation: str
    tail: str
    tail_rank: int  # Rank of true tail among all candidates
    head_rank: int  # Rank of true head among all candidates (for inverse)
    num_candidates: int


def mrr(ranks: List[int]) -> float:
    """
    Mean Reciprocal Rank

    MRR = (1/|Q|) * Σ (1/rank_i)

    Args:
        ranks: List of ranks (1-indexed, where 1 is best)

    Returns:
        MRR score in [0, 1]
    """
    if not ranks:
        return 0.0
    return float(np.mean([1.0 / r for r in ranks]))


def hits_at_k(ranks: List[int], k: int) -> float:
    """
    Hits@K - proportion of ranks <= K

    Args:
        ranks: List of ranks (1-indexed)
        k: Cutoff threshold

    Returns:
        Hits@K score in [0, 1]
    """
    if not ranks:
        return 0.0
    return float(np.mean([1 if r <= k else 0 for r in ranks]))


def evaluate_rankings(results: List[RankingResult]) -> Dict[str, float]:
    """
    Compute all standard metrics from ranking results.

    Args:
        results: List of RankingResult objects

    Returns:
        Dictionary with MRR, Hits@1, Hits@3, Hits@10
    """
    if not results:
        return {
            "mrr": 0.0,
            "hits@1": 0.0,
            "hits@3": 0.0,
            "hits@10": 0.0,
            "num_queries": 0
        }

    # Use tail ranks (standard protocol: predict tail given head+relation)
    tail_ranks = [r.tail_rank for r in results]

    return {
        "mrr": mrr(tail_ranks),
        "hits@1": hits_at_k(tail_ranks, 1),
        "hits@3": hits_at_k(tail_ranks, 3),
        "hits@10": hits_at_k(tail_ranks, 10),
        "num_queries": len(results),
        "mean_rank": float(np.mean(tail_ranks)),
        "median_rank": float(np.median(tail_ranks)),
    }


def filtered_rank(
    true_entity: str,
    scores: Dict[str, float],
    filter_entities: set,
) -> int:
    """
    Compute filtered rank (standard protocol).

    Filtered ranking removes other valid answers from consideration,
    so we only penalize for ranking random entities above the true one.

    Args:
        true_entity: The correct answer
        scores: Dict mapping entity -> score
        filter_entities: Other valid entities to filter out

    Returns:
        Filtered rank (1-indexed)
    """
    true_score = scores.get(true_entity, float('-inf'))

    rank = 1
    for entity, score in scores.items():
        if entity == true_entity:
            continue
        if entity in filter_entities:
            continue  # Filter out other valid answers
        if score > true_score:
            rank += 1

    return rank


class LinkPredictionEvaluator:
    """
    Evaluator for link prediction task.

    Standard protocol:
    1. For each test triple (h, r, t):
       a. Corrupt tail: score all (h, r, e) for e in entities
       b. Corrupt head: score all (e, r, t) for e in entities
    2. Filter out other valid triples (filtered setting)
    3. Report filtered MRR, Hits@1/3/10
    """

    def __init__(
        self,
        all_entities: List[str],
        train_triples: set,
        valid_triples: set = None,
    ):
        """
        Args:
            all_entities: List of all entity IDs
            train_triples: Set of (h, r, t) tuples from training
            valid_triples: Set of (h, r, t) tuples from validation (optional)
        """
        self.all_entities = all_entities
        self.train_triples = train_triples
        self.valid_triples = valid_triples or set()
        self.known_triples = train_triples | self.valid_triples

        # Build index for filtering
        self._build_filter_index()

    def _build_filter_index(self):
        """Build indices for efficient filtering."""
        # For (h, r, ?): which tails are valid?
        self.hr_to_tails = {}
        # For (?, r, t): which heads are valid?
        self.rt_to_heads = {}

        for h, r, t in self.known_triples:
            key_hr = (h, r)
            key_rt = (r, t)

            if key_hr not in self.hr_to_tails:
                self.hr_to_tails[key_hr] = set()
            self.hr_to_tails[key_hr].add(t)

            if key_rt not in self.rt_to_heads:
                self.rt_to_heads[key_rt] = set()
            self.rt_to_heads[key_rt].add(h)

    def evaluate_triple(
        self,
        head: str,
        relation: str,
        tail: str,
        score_fn,
    ) -> RankingResult:
        """
        Evaluate a single test triple.

        Args:
            head: Head entity
            relation: Relation type
            tail: True tail entity
            score_fn: Function (h, r, t) -> float

        Returns:
            RankingResult with filtered ranks
        """
        # Score all possible tails
        tail_scores = {}
        for entity in self.all_entities:
            tail_scores[entity] = score_fn(head, relation, entity)

        # Get filter set (other valid tails for this h,r pair)
        filter_tails = self.hr_to_tails.get((head, relation), set()) - {tail}

        # Compute filtered rank
        tail_rank = filtered_rank(tail, tail_scores, filter_tails)

        # Score all possible heads (for bidirectional evaluation)
        head_scores = {}
        for entity in self.all_entities:
            head_scores[entity] = score_fn(entity, relation, tail)

        # Get filter set for heads
        filter_heads = self.rt_to_heads.get((relation, tail), set()) - {head}
        head_rank = filtered_rank(head, head_scores, filter_heads)

        return RankingResult(
            head=head,
            relation=relation,
            tail=tail,
            tail_rank=tail_rank,
            head_rank=head_rank,
            num_candidates=len(self.all_entities),
        )

    def evaluate_batch(
        self,
        test_triples: List[Tuple[str, str, str]],
        score_fn,
        verbose: bool = True,
    ) -> Dict[str, float]:
        """
        Evaluate on a batch of test triples.

        Args:
            test_triples: List of (h, r, t) tuples
            score_fn: Function (h, r, t) -> float
            verbose: Print progress

        Returns:
            Dictionary with all metrics
        """
        results = []

        for i, (h, r, t) in enumerate(test_triples):
            if verbose and (i + 1) % 100 == 0:
                print(f"  Evaluated {i + 1}/{len(test_triples)} triples...")

            result = self.evaluate_triple(h, r, t, score_fn)
            results.append(result)

        return evaluate_rankings(results)


# Convenience function for quick evaluation
def quick_evaluate(
    test_triples: List[Tuple[str, str, str]],
    score_fn,
    all_entities: List[str],
    train_triples: set,
) -> Dict[str, float]:
    """
    Quick evaluation helper.

    Args:
        test_triples: Test set
        score_fn: Scoring function
        all_entities: All entities in KG
        train_triples: Training triples for filtering

    Returns:
        Metrics dictionary
    """
    evaluator = LinkPredictionEvaluator(all_entities, train_triples)
    return evaluator.evaluate_batch(test_triples, score_fn)
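
As a sanity check on the metric definitions above, here is a small worked example; it is an illustrative sketch (the rank values and score table are invented), and the import assumes the wheel's top-level benchmarks package is on the path:

from benchmarks.metrics import mrr, hits_at_k, filtered_rank

# Three hypothetical queries whose true answers ranked 1st, 2nd, and 10th.
ranks = [1, 2, 10]
assert abs(mrr(ranks) - (1 + 1/2 + 1/10) / 3) < 1e-9  # ~0.533
assert abs(hits_at_k(ranks, 3) - 2 / 3) < 1e-9        # two of three ranks <= 3

# filtered_rank skips other known-valid answers: "gold" is outranked only
# by "noise"; the filtered "other_valid" does not count against it.
scores = {"gold": 0.5, "other_valid": 0.9, "noise": 0.7}
assert filtered_rank("gold", scores, {"other_valid"}) == 2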
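
The module can also be exercised end-to-end. The sketch below runs the full filtered-ranking protocol on a toy three-entity graph, with a hand-written score table standing in for a trained model's scorer; the entity names and scores are made up for illustration:

from benchmarks.metrics import LinkPredictionEvaluator

entities = ["alice", "bob", "carol"]
train = {("alice", "knows", "bob")}
test = [("alice", "knows", "carol")]

# Invented scores; higher means "more plausible". Unlisted triples score 0.
scores = {
    ("alice", "knows", "alice"): 0.1,
    ("alice", "knows", "bob"): 0.9,    # known training triple, filtered out
    ("alice", "knows", "carol"): 0.5,  # the true test answer
}

def score_fn(h, r, t):
    return scores.get((h, r, t), 0.0)

evaluator = LinkPredictionEvaluator(entities, train)
metrics = evaluator.evaluate_batch(test, score_fn, verbose=False)

# With ("alice", "knows", "bob") filtered out, "carol" ranks first among
# the remaining tail candidates, so filtered MRR and Hits@1 are both 1.0.
print(metrics["mrr"], metrics["hits@1"])

quick_evaluate(test, score_fn, entities, train) wraps the same two steps in a single call.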