odin_engine-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -0
- benchmarks/datasets.py +284 -0
- benchmarks/metrics.py +275 -0
- benchmarks/run_ablation.py +279 -0
- benchmarks/run_npll_benchmark.py +270 -0
- npll/__init__.py +10 -0
- npll/bootstrap.py +474 -0
- npll/core/__init__.py +34 -0
- npll/core/knowledge_graph.py +309 -0
- npll/core/logical_rules.py +497 -0
- npll/core/mln.py +475 -0
- npll/inference/__init__.py +41 -0
- npll/inference/e_step.py +420 -0
- npll/inference/elbo.py +435 -0
- npll/inference/m_step.py +577 -0
- npll/npll_model.py +632 -0
- npll/scoring/__init__.py +43 -0
- npll/scoring/embeddings.py +442 -0
- npll/scoring/probability.py +403 -0
- npll/scoring/scoring_module.py +370 -0
- npll/training/__init__.py +25 -0
- npll/training/evaluation.py +497 -0
- npll/training/npll_trainer.py +521 -0
- npll/utils/__init__.py +48 -0
- npll/utils/batch_utils.py +493 -0
- npll/utils/config.py +145 -0
- npll/utils/math_utils.py +339 -0
- odin/__init__.py +20 -0
- odin/engine.py +264 -0
- odin_engine-0.1.0.dist-info/METADATA +456 -0
- odin_engine-0.1.0.dist-info/RECORD +62 -0
- odin_engine-0.1.0.dist-info/WHEEL +5 -0
- odin_engine-0.1.0.dist-info/licenses/LICENSE +21 -0
- odin_engine-0.1.0.dist-info/top_level.txt +4 -0
- retrieval/__init__.py +50 -0
- retrieval/adapters.py +140 -0
- retrieval/adapters_arango.py +1418 -0
- retrieval/aggregators.py +707 -0
- retrieval/beam.py +127 -0
- retrieval/budget.py +60 -0
- retrieval/cache.py +159 -0
- retrieval/confidence.py +88 -0
- retrieval/eval.py +49 -0
- retrieval/linker.py +87 -0
- retrieval/metrics.py +105 -0
- retrieval/metrics_motifs.py +36 -0
- retrieval/orchestrator.py +571 -0
- retrieval/ppr/__init__.py +12 -0
- retrieval/ppr/anchors.py +41 -0
- retrieval/ppr/bippr.py +61 -0
- retrieval/ppr/engines.py +257 -0
- retrieval/ppr/global_pr.py +76 -0
- retrieval/ppr/indexes.py +78 -0
- retrieval/ppr.py +156 -0
- retrieval/ppr_cache.py +25 -0
- retrieval/scoring.py +294 -0
- retrieval/utils/__init__.py +0 -0
- retrieval/utils/pii_redaction.py +36 -0
- retrieval/writers/__init__.py +9 -0
- retrieval/writers/arango_writer.py +28 -0
- retrieval/writers/base.py +21 -0
- retrieval/writers/janus_writer.py +36 -0
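
Reviewer note: the three benchmarks/ files reproduced in full below form the evaluation layer over the npll core. A minimal sketch of the intended wiring, using only names visible in this diff (the install step and network access on first run are assumptions; this is an illustration, not documented API):

    # Sketch: benchmark dataset -> NPLL knowledge graph.
    # Assumes `pip install odin-engine`; the first call downloads the data.
    from benchmarks.datasets import load_fb15k237, dataset_to_kg

    dataset = load_fb15k237()      # cached under ~/.cache/odin_benchmarks
    kg = dataset_to_kg(dataset)    # train triples become known facts
    print(dataset)                 # entity/relation/split counts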
benchmarks/__init__.py
ADDED
@@ -0,0 +1,17 @@
+"""
+Odin Benchmarks: Academic validation against standard KG datasets.
+
+This module provides:
+- Standard dataset loaders (FB15k-237, WN18RR)
+- KG completion metrics (MRR, Hits@K)
+- Benchmark runners for NPLL evaluation
+- Ablation study tools
+"""
+
+from .metrics import mrr, hits_at_k, evaluate_rankings
+from .datasets import load_fb15k237, load_wn18rr, BenchmarkDataset
+
+__all__ = [
+    "mrr", "hits_at_k", "evaluate_rankings",
+    "load_fb15k237", "load_wn18rr", "BenchmarkDataset"
+]
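
The re-exports above define the package's public surface. A quick, runnable sanity check of the two metric helpers on toy rank lists (values chosen only to make the arithmetic visible):

    from benchmarks import mrr, hits_at_k

    ranks = [1, 2, 10]          # 1-indexed ranks, 1 is best
    print(mrr(ranks))           # (1/1 + 1/2 + 1/10) / 3 = 0.5333...
    print(hits_at_k(ranks, 3))  # 2 of 3 ranks are <= 3 -> 0.6667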
benchmarks/datasets.py
ADDED
@@ -0,0 +1,284 @@
+"""
+Standard Knowledge Graph Benchmark Datasets
+
+Provides loaders for:
+- FB15k-237: Freebase subset (14,541 entities, 237 relations)
+- WN18RR: WordNet subset (40,943 entities, 11 relations)
+
+Datasets are downloaded from standard sources and cached locally.
+"""
+
+import os
+import urllib.request
+import tarfile
+import zipfile
+from pathlib import Path
+from typing import List, Tuple, Set, Dict, Optional
+from dataclasses import dataclass
+import logging
+
+logger = logging.getLogger(__name__)
+
+# Dataset URLs - using villmow/datasets_knowledge_embedding (reliable raw files)
+DATASET_BASE_URLS = {
+    "fb15k237": "https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/FB15k-237",
+    "wn18rr": "https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR",
+}
+
+# Default cache directory
+DEFAULT_CACHE_DIR = Path.home() / ".cache" / "odin_benchmarks"
+
+
+@dataclass
+class BenchmarkDataset:
+    """Container for a benchmark dataset."""
+    name: str
+    train_triples: List[Tuple[str, str, str]]
+    valid_triples: List[Tuple[str, str, str]]
+    test_triples: List[Tuple[str, str, str]]
+    entities: List[str]
+    relations: List[str]
+
+    @property
+    def num_entities(self) -> int:
+        return len(self.entities)
+
+    @property
+    def num_relations(self) -> int:
+        return len(self.relations)
+
+    @property
+    def num_train(self) -> int:
+        return len(self.train_triples)
+
+    @property
+    def num_valid(self) -> int:
+        return len(self.valid_triples)
+
+    @property
+    def num_test(self) -> int:
+        return len(self.test_triples)
+
+    def get_train_set(self) -> Set[Tuple[str, str, str]]:
+        return set(self.train_triples)
+
+    def get_all_triples(self) -> Set[Tuple[str, str, str]]:
+        return set(self.train_triples + self.valid_triples + self.test_triples)
+
+    def __repr__(self) -> str:
+        return (
+            f"BenchmarkDataset({self.name})\n"
+            f"  Entities: {self.num_entities:,}\n"
+            f"  Relations: {self.num_relations}\n"
+            f"  Train: {self.num_train:,}\n"
+            f"  Valid: {self.num_valid:,}\n"
+            f"  Test: {self.num_test:,}"
+        )
+
+
+def _ensure_dir(path: Path):
+    """Ensure directory exists."""
+    path.mkdir(parents=True, exist_ok=True)
+
+
+def _download_file(url: str, dest: Path):
+    """Download a file to dest."""
+    import ssl
+
+    logger.info(f"Downloading {url}...")
+
+    # Try a certifi-backed SSL context first, fall back to unverified
+    try:
+        import certifi  # inside the try so a missing certifi hits the fallback
+        ssl_context = ssl.create_default_context(cafile=certifi.where())
+        with urllib.request.urlopen(url, context=ssl_context) as response:
+            with open(dest, 'wb') as out_file:
+                out_file.write(response.read())
+    except (ImportError, ssl.SSLError):
+        # Fallback: disable SSL verification (for development only)
+        logger.warning("SSL verification disabled - using unverified context")
+        ssl_context = ssl.create_default_context()
+        ssl_context.check_hostname = False
+        ssl_context.verify_mode = ssl.CERT_NONE
+        with urllib.request.urlopen(url, context=ssl_context) as response:
+            with open(dest, 'wb') as out_file:
+                out_file.write(response.read())
+
+    logger.info(f"Downloaded to {dest}")
+
+
+def _extract_tar_gz(archive: Path, dest_dir: Path):
+    """Extract a tar.gz archive."""
+    logger.info(f"Extracting {archive}...")
+    with tarfile.open(archive, "r:gz") as tar:
+        tar.extractall(dest_dir)
+
+
+def _extract_zip(archive: Path, dest_dir: Path):
+    """Extract a zip archive."""
+    logger.info(f"Extracting {archive}...")
+    with zipfile.ZipFile(archive, 'r') as zip_ref:
+        zip_ref.extractall(dest_dir)
+
+
+def _load_triples(filepath: Path) -> List[Tuple[str, str, str]]:
+    """Load triples from a TSV file (head, relation, tail)."""
+    triples = []
+    with open(filepath, "r", encoding="utf-8") as f:
+        for line in f:
+            parts = line.strip().split("\t")
+            if len(parts) == 3:
+                h, r, t = parts
+                triples.append((h, r, t))
+    return triples
+
+
+def _extract_entities_and_relations(
+    triples: List[Tuple[str, str, str]]
+) -> Tuple[List[str], List[str]]:
+    """Extract unique entities and relations from triples."""
+    entities = set()
+    relations = set()
+    for h, r, t in triples:
+        entities.add(h)
+        entities.add(t)
+        relations.add(r)
+    return sorted(entities), sorted(relations)
+
+
+def _download_dataset_files(base_url: str, dataset_dir: Path):
+    """Download train/valid/test files for a dataset."""
+    _ensure_dir(dataset_dir)
+
+    for split in ["train", "valid", "test"]:
+        file_path = dataset_dir / f"{split}.txt"
+        if not file_path.exists():
+            url = f"{base_url}/{split}.txt"
+            _download_file(url, file_path)
+
+
+def load_fb15k237(cache_dir: Optional[Path] = None) -> BenchmarkDataset:
+    """
+    Load FB15k-237 dataset.
+
+    FB15k-237 is a subset of Freebase with:
+    - 14,541 entities
+    - 237 relations
+    - 310,116 triples
+
+    This version removes inverse relations from FB15k to prevent
+    data leakage during evaluation.
+
+    Args:
+        cache_dir: Directory to cache downloaded data
+
+    Returns:
+        BenchmarkDataset object
+    """
+    cache_dir = Path(cache_dir) if cache_dir else DEFAULT_CACHE_DIR
+    _ensure_dir(cache_dir)
+
+    dataset_dir = cache_dir / "FB15k-237"
+
+    # Download if not cached
+    if not (dataset_dir / "train.txt").exists():
+        _download_dataset_files(DATASET_BASE_URLS["fb15k237"], dataset_dir)
+
+    # Load splits
+    train = _load_triples(dataset_dir / "train.txt")
+    valid = _load_triples(dataset_dir / "valid.txt")
+    test = _load_triples(dataset_dir / "test.txt")
+
+    # Extract vocab
+    all_triples = train + valid + test
+    entities, relations = _extract_entities_and_relations(all_triples)
+
+    return BenchmarkDataset(
+        name="FB15k-237",
+        train_triples=train,
+        valid_triples=valid,
+        test_triples=test,
+        entities=entities,
+        relations=relations,
+    )
+
+
+def load_wn18rr(cache_dir: Optional[Path] = None) -> BenchmarkDataset:
+    """
+    Load WN18RR dataset.
+
+    WN18RR is a subset of WordNet with:
+    - 40,943 entities
+    - 11 relations
+    - 93,003 triples
+
+    This version removes inverse relations from WN18 to prevent
+    data leakage during evaluation.
+
+    Args:
+        cache_dir: Directory to cache downloaded data
+
+    Returns:
+        BenchmarkDataset object
+    """
+    cache_dir = Path(cache_dir) if cache_dir else DEFAULT_CACHE_DIR
+    _ensure_dir(cache_dir)
+
+    dataset_dir = cache_dir / "WN18RR"
+
+    # Download if not cached
+    if not (dataset_dir / "train.txt").exists():
+        _download_dataset_files(DATASET_BASE_URLS["wn18rr"], dataset_dir)
+
+    # Load splits
+    train = _load_triples(dataset_dir / "train.txt")
+    valid = _load_triples(dataset_dir / "valid.txt")
+    test = _load_triples(dataset_dir / "test.txt")
+
+    # Extract vocab
+    all_triples = train + valid + test
+    entities, relations = _extract_entities_and_relations(all_triples)
+
+    return BenchmarkDataset(
+        name="WN18RR",
+        train_triples=train,
+        valid_triples=valid,
+        test_triples=test,
+        entities=entities,
+        relations=relations,
+    )
+
+
+def dataset_to_kg(dataset: BenchmarkDataset):
+    """
+    Convert BenchmarkDataset to Odin KnowledgeGraph.
+
+    Args:
+        dataset: BenchmarkDataset object
+
+    Returns:
+        KnowledgeGraph object suitable for NPLL training
+    """
+    from npll.core import KnowledgeGraph
+
+    kg = KnowledgeGraph()
+
+    # Add all training triples as known facts
+    for h, r, t in dataset.train_triples:
+        kg.add_known_fact(h, r, t)
+
+    return kg
+
+
+# Quick test
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+
+    print("Loading FB15k-237...")
+    fb = load_fb15k237()
+    print(fb)
+    print()
+
+    print("Loading WN18RR...")
+    wn = load_wn18rr()
+    print(wn)
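
Two reviewer notes on the file above. First, _download_file deliberately degrades to an unverified SSL context when certifi is unavailable or the verified handshake fails; that is acceptable for local experiments but worth flagging for any production use. Second, the loaders accept an explicit cache directory, which the sketch below exercises (the printed sizes are the statistics quoted in the docstrings, not re-verified here):

    from pathlib import Path
    from benchmarks.datasets import load_wn18rr

    wn = load_wn18rr(cache_dir=Path("/tmp/kg_cache"))
    print(wn.num_entities, wn.num_relations)  # docstrings say 40,943 and 11
    known = wn.get_all_triples()              # train+valid+test, for filtered eval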
benchmarks/metrics.py
ADDED
@@ -0,0 +1,275 @@
+"""
+Standard Knowledge Graph Completion Metrics
+
+Implements metrics used in academic KG completion papers:
+- Mean Reciprocal Rank (MRR)
+- Hits@K (K=1, 3, 10)
+
+These metrics evaluate link prediction quality:
+Given (h, r, ?), rank all entities by predicted score.
+"""
+
+import numpy as np
+from typing import List, Dict, Tuple, Any
+from dataclasses import dataclass
+
+
+@dataclass
+class RankingResult:
+    """Result of ranking evaluation for a single triple."""
+    head: str
+    relation: str
+    tail: str
+    tail_rank: int  # Rank of true tail among all candidates
+    head_rank: int  # Rank of true head among all candidates (for inverse)
+    num_candidates: int
+
+
+def mrr(ranks: List[int]) -> float:
+    """
+    Mean Reciprocal Rank
+
+    MRR = (1/|Q|) * Σ (1/rank_i)
+
+    Args:
+        ranks: List of ranks (1-indexed, where 1 is best)
+
+    Returns:
+        MRR score in [0, 1]
+    """
+    if not ranks:
+        return 0.0
+    return float(np.mean([1.0 / r for r in ranks]))
+
+
+def hits_at_k(ranks: List[int], k: int) -> float:
+    """
+    Hits@K - proportion of ranks <= K
+
+    Args:
+        ranks: List of ranks (1-indexed)
+        k: Cutoff threshold
+
+    Returns:
+        Hits@K score in [0, 1]
+    """
+    if not ranks:
+        return 0.0
+    return float(np.mean([1 if r <= k else 0 for r in ranks]))
+
+
+def evaluate_rankings(results: List[RankingResult]) -> Dict[str, float]:
+    """
+    Compute all standard metrics from ranking results.
+
+    Args:
+        results: List of RankingResult objects
+
+    Returns:
+        Dictionary with MRR, Hits@1, Hits@3, Hits@10
+    """
+    if not results:
+        return {
+            "mrr": 0.0,
+            "hits@1": 0.0,
+            "hits@3": 0.0,
+            "hits@10": 0.0,
+            "num_queries": 0
+        }
+
+    # Use tail ranks (standard protocol: predict tail given head+relation)
+    tail_ranks = [r.tail_rank for r in results]
+
+    return {
+        "mrr": mrr(tail_ranks),
+        "hits@1": hits_at_k(tail_ranks, 1),
+        "hits@3": hits_at_k(tail_ranks, 3),
+        "hits@10": hits_at_k(tail_ranks, 10),
+        "num_queries": len(results),
+        "mean_rank": float(np.mean(tail_ranks)),
+        "median_rank": float(np.median(tail_ranks)),
+    }
+
+
+def filtered_rank(
+    true_entity: str,
+    scores: Dict[str, float],
+    filter_entities: set,
+) -> int:
+    """
+    Compute filtered rank (standard protocol).
+
+    Filtered ranking removes other valid answers from consideration,
+    so we only penalize for ranking random entities above the true one.
+
+    Args:
+        true_entity: The correct answer
+        scores: Dict mapping entity -> score
+        filter_entities: Other valid entities to filter out
+
+    Returns:
+        Filtered rank (1-indexed)
+    """
+    true_score = scores.get(true_entity, float('-inf'))
+
+    rank = 1
+    for entity, score in scores.items():
+        if entity == true_entity:
+            continue
+        if entity in filter_entities:
+            continue  # Filter out other valid answers
+        if score > true_score:
+            rank += 1
+
+    return rank
+
+
+class LinkPredictionEvaluator:
+    """
+    Evaluator for link prediction task.
+
+    Standard protocol:
+    1. For each test triple (h, r, t):
+       a. Corrupt tail: score all (h, r, e) for e in entities
+       b. Corrupt head: score all (e, r, t) for e in entities
+    2. Filter out other valid triples (filtered setting)
+    3. Report filtered MRR, Hits@1/3/10
+    """
+
+    def __init__(
+        self,
+        all_entities: List[str],
+        train_triples: set,
+        valid_triples: set = None,
+    ):
+        """
+        Args:
+            all_entities: List of all entity IDs
+            train_triples: Set of (h, r, t) tuples from training
+            valid_triples: Set of (h, r, t) tuples from validation (optional)
+        """
+        self.all_entities = all_entities
+        self.train_triples = train_triples
+        self.valid_triples = valid_triples or set()
+        self.known_triples = train_triples | self.valid_triples
+
+        # Build index for filtering
+        self._build_filter_index()
+
+    def _build_filter_index(self):
+        """Build indices for efficient filtering."""
+        # For (h, r, ?): which tails are valid?
+        self.hr_to_tails = {}
+        # For (?, r, t): which heads are valid?
+        self.rt_to_heads = {}
+
+        for h, r, t in self.known_triples:
+            key_hr = (h, r)
+            key_rt = (r, t)
+
+            if key_hr not in self.hr_to_tails:
+                self.hr_to_tails[key_hr] = set()
+            self.hr_to_tails[key_hr].add(t)
+
+            if key_rt not in self.rt_to_heads:
+                self.rt_to_heads[key_rt] = set()
+            self.rt_to_heads[key_rt].add(h)
+
+    def evaluate_triple(
+        self,
+        head: str,
+        relation: str,
+        tail: str,
+        score_fn,
+    ) -> RankingResult:
+        """
+        Evaluate a single test triple.
+
+        Args:
+            head: Head entity
+            relation: Relation type
+            tail: True tail entity
+            score_fn: Function (h, r, t) -> float
+
+        Returns:
+            RankingResult with filtered ranks
+        """
+        # Score all possible tails
+        tail_scores = {}
+        for entity in self.all_entities:
+            tail_scores[entity] = score_fn(head, relation, entity)
+
+        # Get filter set (other valid tails for this h,r pair)
+        filter_tails = self.hr_to_tails.get((head, relation), set()) - {tail}
+
+        # Compute filtered rank
+        tail_rank = filtered_rank(tail, tail_scores, filter_tails)
+
+        # Score all possible heads (for bidirectional evaluation)
+        head_scores = {}
+        for entity in self.all_entities:
+            head_scores[entity] = score_fn(entity, relation, tail)
+
+        # Get filter set for heads
+        filter_heads = self.rt_to_heads.get((relation, tail), set()) - {head}
+        head_rank = filtered_rank(head, head_scores, filter_heads)
+
+        return RankingResult(
+            head=head,
+            relation=relation,
+            tail=tail,
+            tail_rank=tail_rank,
+            head_rank=head_rank,
+            num_candidates=len(self.all_entities),
+        )
+
+    def evaluate_batch(
+        self,
+        test_triples: List[Tuple[str, str, str]],
+        score_fn,
+        verbose: bool = True,
+    ) -> Dict[str, float]:
+        """
+        Evaluate on a batch of test triples.
+
+        Args:
+            test_triples: List of (h, r, t) tuples
+            score_fn: Function (h, r, t) -> float
+            verbose: Print progress
+
+        Returns:
+            Dictionary with all metrics
+        """
+        results = []
+
+        for i, (h, r, t) in enumerate(test_triples):
+            if verbose and (i + 1) % 100 == 0:
+                print(f"  Evaluated {i + 1}/{len(test_triples)} triples...")
+
+            result = self.evaluate_triple(h, r, t, score_fn)
+            results.append(result)
+
+        return evaluate_rankings(results)
+
+
+# Convenience function for quick evaluation
+def quick_evaluate(
+    test_triples: List[Tuple[str, str, str]],
+    score_fn,
+    all_entities: List[str],
+    train_triples: set,
+) -> Dict[str, float]:
+    """
+    Quick evaluation helper.
+
+    Args:
+        test_triples: Test set
+        score_fn: Scoring function
+        all_entities: All entities in KG
+        train_triples: Training triples for filtering
+
+    Returns:
+        Metrics dictionary
+    """
+    evaluator = LinkPredictionEvaluator(all_entities, train_triples)
+    return evaluator.evaluate_batch(test_triples, score_fn)
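
An end-to-end toy run of the filtered protocol implemented above. The score_fn is a stand-in (any (h, r, t) -> float callable works); with this one the single test triple ranks first, so filtered MRR and Hits@1 both come out to 1.0:

    from benchmarks.metrics import LinkPredictionEvaluator

    entities = ["a", "b", "c", "d"]
    train = {("a", "likes", "b"), ("a", "likes", "c")}
    test = [("a", "likes", "d")]

    def score_fn(h, r, t):
        # Toy scorer that prefers tail "d"; real callers plug in a model.
        return 1.0 if t == "d" else 0.0

    evaluator = LinkPredictionEvaluator(entities, train)
    print(evaluator.evaluate_batch(test, score_fn, verbose=False))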