odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
benchmarks/metrics.py
CHANGED
@@ -1,275 +1,275 @@

(The removed and added hunks are textually identical: all 275 lines are deleted and re-added unchanged, so the file content is shown once.)

"""
Standard Knowledge Graph Completion Metrics

Implements metrics used in academic KG completion papers:
- Mean Reciprocal Rank (MRR)
- Hits@K (K=1, 3, 10)

These metrics evaluate link prediction quality:
Given (h, r, ?), rank all entities by predicted score.
"""

import numpy as np
from typing import List, Dict, Tuple, Any
from dataclasses import dataclass


@dataclass
class RankingResult:
    """Result of ranking evaluation for a single triple."""
    head: str
    relation: str
    tail: str
    tail_rank: int  # Rank of true tail among all candidates
    head_rank: int  # Rank of true head among all candidates (for inverse)
    num_candidates: int


def mrr(ranks: List[int]) -> float:
    """
    Mean Reciprocal Rank

    MRR = (1/|Q|) * Σ (1/rank_i)

    Args:
        ranks: List of ranks (1-indexed, where 1 is best)

    Returns:
        MRR score in [0, 1]
    """
    if not ranks:
        return 0.0
    return float(np.mean([1.0 / r for r in ranks]))


def hits_at_k(ranks: List[int], k: int) -> float:
    """
    Hits@K - proportion of ranks <= K

    Args:
        ranks: List of ranks (1-indexed)
        k: Cutoff threshold

    Returns:
        Hits@K score in [0, 1]
    """
    if not ranks:
        return 0.0
    return float(np.mean([1 if r <= k else 0 for r in ranks]))


def evaluate_rankings(results: List[RankingResult]) -> Dict[str, float]:
    """
    Compute all standard metrics from ranking results.

    Args:
        results: List of RankingResult objects

    Returns:
        Dictionary with MRR, Hits@1, Hits@3, Hits@10
    """
    if not results:
        return {
            "mrr": 0.0,
            "hits@1": 0.0,
            "hits@3": 0.0,
            "hits@10": 0.0,
            "num_queries": 0
        }

    # Use tail ranks (standard protocol: predict tail given head+relation)
    tail_ranks = [r.tail_rank for r in results]

    return {
        "mrr": mrr(tail_ranks),
        "hits@1": hits_at_k(tail_ranks, 1),
        "hits@3": hits_at_k(tail_ranks, 3),
        "hits@10": hits_at_k(tail_ranks, 10),
        "num_queries": len(results),
        "mean_rank": float(np.mean(tail_ranks)),
        "median_rank": float(np.median(tail_ranks)),
    }


def filtered_rank(
    true_entity: str,
    scores: Dict[str, float],
    filter_entities: set,
) -> int:
    """
    Compute filtered rank (standard protocol).

    Filtered ranking removes other valid answers from consideration,
    so we only penalize for ranking random entities above the true one.

    Args:
        true_entity: The correct answer
        scores: Dict mapping entity -> score
        filter_entities: Other valid entities to filter out

    Returns:
        Filtered rank (1-indexed)
    """
    true_score = scores.get(true_entity, float('-inf'))

    rank = 1
    for entity, score in scores.items():
        if entity == true_entity:
            continue
        if entity in filter_entities:
            continue  # Filter out other valid answers
        if score > true_score:
            rank += 1

    return rank


class LinkPredictionEvaluator:
    """
    Evaluator for link prediction task.

    Standard protocol:
    1. For each test triple (h, r, t):
       a. Corrupt tail: score all (h, r, e) for e in entities
       b. Corrupt head: score all (e, r, t) for e in entities
    2. Filter out other valid triples (filtered setting)
    3. Report filtered MRR, Hits@1/3/10
    """

    def __init__(
        self,
        all_entities: List[str],
        train_triples: set,
        valid_triples: set = None,
    ):
        """
        Args:
            all_entities: List of all entity IDs
            train_triples: Set of (h, r, t) tuples from training
            valid_triples: Set of (h, r, t) tuples from validation (optional)
        """
        self.all_entities = all_entities
        self.train_triples = train_triples
        self.valid_triples = valid_triples or set()
        self.known_triples = train_triples | self.valid_triples

        # Build index for filtering
        self._build_filter_index()

    def _build_filter_index(self):
        """Build indices for efficient filtering."""
        # For (h, r, ?): which tails are valid?
        self.hr_to_tails = {}
        # For (?, r, t): which heads are valid?
        self.rt_to_heads = {}

        for h, r, t in self.known_triples:
            key_hr = (h, r)
            key_rt = (r, t)

            if key_hr not in self.hr_to_tails:
                self.hr_to_tails[key_hr] = set()
            self.hr_to_tails[key_hr].add(t)

            if key_rt not in self.rt_to_heads:
                self.rt_to_heads[key_rt] = set()
            self.rt_to_heads[key_rt].add(h)

    def evaluate_triple(
        self,
        head: str,
        relation: str,
        tail: str,
        score_fn,
    ) -> RankingResult:
        """
        Evaluate a single test triple.

        Args:
            head: Head entity
            relation: Relation type
            tail: True tail entity
            score_fn: Function (h, r, t) -> float

        Returns:
            RankingResult with filtered ranks
        """
        # Score all possible tails
        tail_scores = {}
        for entity in self.all_entities:
            tail_scores[entity] = score_fn(head, relation, entity)

        # Get filter set (other valid tails for this h,r pair)
        filter_tails = self.hr_to_tails.get((head, relation), set()) - {tail}

        # Compute filtered rank
        tail_rank = filtered_rank(tail, tail_scores, filter_tails)

        # Score all possible heads (for bidirectional evaluation)
        head_scores = {}
        for entity in self.all_entities:
            head_scores[entity] = score_fn(entity, relation, tail)

        # Get filter set for heads
        filter_heads = self.rt_to_heads.get((relation, tail), set()) - {head}
        head_rank = filtered_rank(head, head_scores, filter_heads)

        return RankingResult(
            head=head,
            relation=relation,
            tail=tail,
            tail_rank=tail_rank,
            head_rank=head_rank,
            num_candidates=len(self.all_entities),
        )

    def evaluate_batch(
        self,
        test_triples: List[Tuple[str, str, str]],
        score_fn,
        verbose: bool = True,
    ) -> Dict[str, float]:
        """
        Evaluate on a batch of test triples.

        Args:
            test_triples: List of (h, r, t) tuples
            score_fn: Function (h, r, t) -> float
            verbose: Print progress

        Returns:
            Dictionary with all metrics
        """
        results = []

        for i, (h, r, t) in enumerate(test_triples):
            if verbose and (i + 1) % 100 == 0:
                print(f" Evaluated {i + 1}/{len(test_triples)} triples...")

            result = self.evaluate_triple(h, r, t, score_fn)
            results.append(result)

        return evaluate_rankings(results)


# Convenience function for quick evaluation
def quick_evaluate(
    test_triples: List[Tuple[str, str, str]],
    score_fn,
    all_entities: List[str],
    train_triples: set,
) -> Dict[str, float]:
    """
    Quick evaluation helper.

    Args:
        test_triples: Test set
        score_fn: Scoring function
        all_entities: All entities in KG
        train_triples: Training triples for filtering

    Returns:
        Metrics dictionary
    """
    evaluator = LinkPredictionEvaluator(all_entities, train_triples)
    return evaluator.evaluate_batch(test_triples, score_fn)
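For orientation, the sketch below shows how the metrics in `benchmarks/metrics.py` might be exercised end to end: `mrr` and `hits_at_k` on a list of ranks, then `LinkPredictionEvaluator` on a toy graph with a made-up scoring function. The toy entities, triples, and `toy_score` function are illustrative assumptions and not part of the package; the `benchmarks.metrics` import path simply follows the file layout listed above.

```python
# Illustrative sketch only: the toy graph and toy_score are assumptions,
# not part of odin-engine; only the imported names come from benchmarks/metrics.py.
from benchmarks.metrics import mrr, hits_at_k, LinkPredictionEvaluator

# Rank-based metrics take 1-indexed ranks (1 = best).
ranks = [1, 2, 5, 10]
print(mrr(ranks))           # (1/4) * (1/1 + 1/2 + 1/5 + 1/10) = 0.45
print(hits_at_k(ranks, 3))  # 2 of the 4 ranks are <= 3, so 0.5

# Toy knowledge graph: entity IDs plus known (h, r, t) triples used for filtering.
entities = ["alice", "bob", "carol", "dave"]
train = {("alice", "knows", "bob"), ("bob", "knows", "carol")}
test = [("alice", "knows", "carol")]

def toy_score(h: str, r: str, t: str) -> float:
    """Stand-in for a learned scoring function (h, r, t) -> float."""
    # Higher score for alphabetically closer head/tail initials; purely for demonstration.
    return -abs(ord(h[0]) - ord(t[0]))

evaluator = LinkPredictionEvaluator(entities, train)
metrics = evaluator.evaluate_batch(test, toy_score, verbose=False)
print(metrics)  # filtered MRR, Hits@1/3/10, mean/median rank over the test triples
```

Because the evaluator scores every entity for both the corrupted-tail and corrupted-head queries, each test triple costs on the order of 2 × |entities| calls to the scoring function, which is worth keeping in mind on large graphs.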