odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff shows the contents of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two published versions.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
benchmarks/run_ablation.py
CHANGED
@@ -1,279 +1,279 @@ (all 279 lines are removed and re-added; the 0.1.0 and 0.2.0 contents of this file are line-for-line identical, so the file is listed once)

#!/usr/bin/env python3
"""
Odin Ablation Study

Evaluates the contribution of each Odin component:
1. PPR only (structural importance)
2. NPLL only (semantic plausibility)
3. PPR + NPLL (full Odin)
4. Random baseline

This validates that each component contributes to overall performance.

Usage:
    python -m benchmarks.run_ablation --dataset fb15k237
"""

import argparse
import json
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional
import random

from benchmarks.datasets import load_fb15k237, load_wn18rr, dataset_to_kg, BenchmarkDataset
from benchmarks.metrics import evaluate_rankings, RankingResult

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)


class AblationEvaluator:
    """
    Evaluator for ablation study.

    Compares path ranking quality under different scoring configurations.
    """

    def __init__(self, dataset: BenchmarkDataset):
        self.dataset = dataset
        self.entity_to_idx = {e: i for i, e in enumerate(dataset.entities)}
        self.relation_to_idx = {r: i for i, r in enumerate(dataset.relations)}

        # Build adjacency for PPR simulation
        self._build_adjacency()

    def _build_adjacency(self):
        """Build adjacency lists from training data."""
        self.outgoing = {}  # entity -> [(relation, target)]
        self.incoming = {}  # entity -> [(relation, source)]

        for h, r, t in self.dataset.train_triples:
            if h not in self.outgoing:
                self.outgoing[h] = []
            self.outgoing[h].append((r, t))

            if t not in self.incoming:
                self.incoming[t] = []
            self.incoming[t].append((r, h))

    def ppr_score(self, source: str, target: str, alpha: float = 0.15) -> float:
        """
        Approximate PPR score via random walk simulation.

        Higher score = target is more "important" relative to source.
        """
        if source not in self.outgoing:
            return 0.0

        # Simple approximation: count paths from source to target
        visited = {source: 1.0}
        frontier = [(source, 1.0)]

        for _ in range(3):  # 3-hop
            new_frontier = []
            for node, prob in frontier:
                if node not in self.outgoing:
                    continue
                neighbors = self.outgoing[node]
                if not neighbors:
                    continue
                spread = prob * (1 - alpha) / len(neighbors)
                for _, neighbor in neighbors:
                    if neighbor not in visited:
                        visited[neighbor] = 0.0
                    visited[neighbor] += spread
                    new_frontier.append((neighbor, spread))
            frontier = new_frontier

        return visited.get(target, 0.0)

    def random_score(self, h: str, r: str, t: str) -> float:
        """Random baseline score."""
        return random.random()

    def degree_score(self, h: str, r: str, t: str) -> float:
        """Degree-based score (common baseline)."""
        out_degree = len(self.outgoing.get(t, []))
        in_degree = len(self.incoming.get(t, []))
        return (out_degree + in_degree) / max(len(self.dataset.entities), 1)

    def evaluate_method(
        self,
        method_name: str,
        score_fn,
        test_triples: List[Tuple[str, str, str]],
        sample_size: int = 500,
    ) -> Dict[str, float]:
        """
        Evaluate a scoring method on link prediction.

        Args:
            method_name: Name for logging
            score_fn: (h, r, t) -> float
            test_triples: Test set
            sample_size: Number of test triples to evaluate

        Returns:
            Metrics dictionary
        """
        logger.info(f"Evaluating {method_name}...")

        # Sample test triples
        if len(test_triples) > sample_size:
            test_sample = random.sample(test_triples, sample_size)
        else:
            test_sample = test_triples

        results = []

        for i, (h, r, t) in enumerate(test_sample):
            if (i + 1) % 100 == 0:
                logger.info(f" {i + 1}/{len(test_sample)}...")

            # Score all candidate tails
            scores = {}
            for entity in self.dataset.entities:
                scores[entity] = score_fn(h, r, entity)

            # Compute rank of true tail
            true_score = scores[t]
            rank = 1
            for entity, score in scores.items():
                if entity != t and score > true_score:
                    rank += 1

            results.append(RankingResult(
                head=h,
                relation=r,
                tail=t,
                tail_rank=rank,
                head_rank=1,  # Not computing head rank in ablation
                num_candidates=len(self.dataset.entities),
            ))

        return evaluate_rankings(results)


def run_ablation(
    dataset_name: str,
    sample_size: int = 500,
    output_dir: Optional[Path] = None,
) -> Dict[str, Any]:
    """
    Run ablation study comparing scoring methods.
    """
    results = {
        "dataset": dataset_name,
        "timestamp": datetime.now().isoformat(),
        "sample_size": sample_size,
        "methods": {},
    }

    # Load dataset
    logger.info(f"Loading {dataset_name}...")
    if dataset_name.lower() == "fb15k237":
        dataset = load_fb15k237()
    else:
        dataset = load_wn18rr()

    logger.info(f"\n{dataset}")

    # Create evaluator
    evaluator = AblationEvaluator(dataset)

    # Method 1: Random baseline
    random_metrics = evaluator.evaluate_method(
        "Random",
        evaluator.random_score,
        dataset.test_triples,
        sample_size=sample_size,
    )
    results["methods"]["random"] = random_metrics

    # Method 2: Degree-based (common baseline)
    degree_metrics = evaluator.evaluate_method(
        "Degree",
        evaluator.degree_score,
        dataset.test_triples,
        sample_size=sample_size,
    )
    results["methods"]["degree"] = degree_metrics

    # Method 3: PPR-only
    def ppr_only_score(h, r, t):
        return evaluator.ppr_score(h, t)

    ppr_metrics = evaluator.evaluate_method(
        "PPR-only",
        ppr_only_score,
        dataset.test_triples,
        sample_size=sample_size,
    )
    results["methods"]["ppr_only"] = ppr_metrics

    # Print summary
    print("\n" + "=" * 70)
    print(f"ABLATION STUDY RESULTS: {dataset_name}")
    print("=" * 70)
    print(f"{'Method':<15} {'MRR':>10} {'Hits@1':>10} {'Hits@3':>10} {'Hits@10':>10}")
    print("-" * 70)

    for method, metrics in results["methods"].items():
        print(f"{method:<15} {metrics['mrr']:>10.4f} {metrics['hits@1']:>10.4f} "
              f"{metrics['hits@3']:>10.4f} {metrics['hits@10']:>10.4f}")

    print("=" * 70)
    print("\nNote: NPLL results require running run_npll_benchmark.py separately")
    print(" and comparing with these baseline numbers.")

    # Save results
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"ablation_{dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(output_file, "w") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Results saved to {output_file}")

    return results


def main():
    parser = argparse.ArgumentParser(description="Odin Ablation Study")
    parser.add_argument(
        "--dataset",
        type=str,
        default="fb15k237",
        choices=["fb15k237", "wn18rr"],
        help="Dataset to evaluate on",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=500,
        help="Number of test triples to sample",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="benchmark_results",
        help="Directory to save results",
    )

    args = parser.parse_args()

    run_ablation(
        dataset_name=args.dataset,
        sample_size=args.sample_size,
        output_dir=Path(args.output_dir),
    )


if __name__ == "__main__":
    main()
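
The same ablation can also be driven programmatically rather than through the CLI shown in the module docstring. The sketch below is illustrative only: it assumes the 0.2.0 wheel is installed and that load_fb15k237 / load_wn18rr can obtain their datasets in the local environment, and the metric keys (mrr, hits@1, ...) are the ones used by the summary table that run_ablation prints.

import random
from pathlib import Path

from benchmarks.run_ablation import run_ablation

# Seed the module-level RNG so the sampled test triples and the random
# baseline are reproducible across runs.
random.seed(42)

# Run the baseline ablation on WN18RR with a small sample; because
# output_dir is given, the results dict is also written as JSON there.
results = run_ablation(
    dataset_name="wn18rr",
    sample_size=100,
    output_dir=Path("benchmark_results"),
)

# Each entry in results["methods"] is the metrics dict returned by
# evaluate_rankings for that scoring method.
print(results["methods"]["ppr_only"]["mrr"])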