odin_engine-0.1.0-py3-none-any.whl → odin_engine-0.2.0-py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
benchmarks/run_npll_benchmark.py
CHANGED
@@ -1,270 +1,270 @@
All 270 lines are removed and re-added with identical content; the file body is shown once below.

#!/usr/bin/env python3
"""
NPLL Benchmark Runner

Evaluates NPLL on standard KG completion benchmarks (FB15k-237, WN18RR)
and reports standard metrics (MRR, Hits@K).

Usage:
    python -m benchmarks.run_npll_benchmark --dataset fb15k237
    python -m benchmarks.run_npll_benchmark --dataset wn18rr
    python -m benchmarks.run_npll_benchmark --dataset fb15k237 --test-subset 1000
"""

import argparse
import json
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Any

import torch

from benchmarks.datasets import load_fb15k237, load_wn18rr, dataset_to_kg, BenchmarkDataset
from benchmarks.metrics import LinkPredictionEvaluator, evaluate_rankings
from npll import NPLLModel
from npll.core import KnowledgeGraph, RuleGenerator
from npll.utils import NPLLConfig, get_config

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)


def load_dataset(name: str) -> BenchmarkDataset:
    """Load dataset by name."""
    if name.lower() == "fb15k237":
        return load_fb15k237()
    elif name.lower() == "wn18rr":
        return load_wn18rr()
    else:
        raise ValueError(f"Unknown dataset: {name}")


def create_npll_model(kg: KnowledgeGraph, config: NPLLConfig) -> NPLLModel:
    """Create and initialize NPLL model."""
    model = NPLLModel(config)

    # Generate rules from the knowledge graph
    logger.info("Generating logical rules...")
    rule_gen = RuleGenerator()
    rules = rule_gen.generate_rules(kg, max_chain_length=2)

    # If no rules generated, add universal rules
    if not rules:
        logger.warning("No rules generated, adding universal fallback rules")
        rules = rule_gen.generate_universal_rules(kg)

    logger.info(f"Generated {len(rules)} rules")

    # Initialize with KG and rules
    model.initialize(kg, rules)

    return model


def train_npll(model: NPLLModel, epochs: int = 10) -> Dict[str, Any]:
    """Train NPLL model."""
    logger.info(f"Training NPLL for {epochs} epochs...")
    start_time = time.time()

    training_state = model.train_model(
        num_epochs=epochs,
        em_iterations=5,
        verbose=True,
    )

    training_time = time.time() - start_time

    return {
        "epochs": epochs,
        "training_time_seconds": training_time,
        "final_elbo": training_state.best_elbo if training_state else None,
    }


def evaluate_npll(
    model: NPLLModel,
    dataset: BenchmarkDataset,
    test_subset: int = None,
) -> Dict[str, float]:
    """
    Evaluate NPLL on link prediction task.

    Args:
        model: Trained NPLL model
        dataset: Benchmark dataset
        test_subset: Limit test to first N triples (for faster evaluation)

    Returns:
        Metrics dictionary
    """
    logger.info("Evaluating on link prediction task...")

    # Prepare test triples
    test_triples = dataset.test_triples
    if test_subset:
        test_triples = test_triples[:test_subset]
        logger.info(f"Using subset of {len(test_triples)} test triples")

    # Create evaluator
    evaluator = LinkPredictionEvaluator(
        all_entities=dataset.entities,
        train_triples=dataset.get_train_set(),
        valid_triples=set(dataset.valid_triples),
    )

    # Define scoring function using NPLL
    def score_fn(h: str, r: str, t: str) -> float:
        try:
            # Use NPLL model to score the triple
            scores = model.score_triples([(h, r, t)])
            return float(scores[0]) if scores else 0.0
        except Exception:
            return 0.0

    # Run evaluation
    start_time = time.time()
    metrics = evaluator.evaluate_batch(test_triples, score_fn, verbose=True)
    eval_time = time.time() - start_time

    metrics["evaluation_time_seconds"] = eval_time
    metrics["triples_evaluated"] = len(test_triples)

    return metrics


def run_benchmark(
    dataset_name: str,
    epochs: int = 10,
    test_subset: int = None,
    output_dir: Path = None,
) -> Dict[str, Any]:
    """
    Run full NPLL benchmark.

    Args:
        dataset_name: Name of dataset (fb15k237 or wn18rr)
        epochs: Training epochs
        test_subset: Limit test evaluation (for speed)
        output_dir: Directory to save results

    Returns:
        Full results dictionary
    """
    results = {
        "dataset": dataset_name,
        "timestamp": datetime.now().isoformat(),
        "config": {},
        "dataset_stats": {},
        "training": {},
        "evaluation": {},
    }

    # Load dataset
    logger.info(f"Loading {dataset_name}...")
    dataset = load_dataset(dataset_name)
    logger.info(f"\n{dataset}")

    results["dataset_stats"] = {
        "num_entities": dataset.num_entities,
        "num_relations": dataset.num_relations,
        "num_train": dataset.num_train,
        "num_valid": dataset.num_valid,
        "num_test": dataset.num_test,
    }

    # Convert to KnowledgeGraph
    logger.info("Converting to KnowledgeGraph...")
    kg = dataset_to_kg(dataset)

    # Create NPLL config
    config = get_config()
    config.embedding_dim = 100
    config.hidden_dim = 200
    results["config"] = {
        "embedding_dim": config.embedding_dim,
        "hidden_dim": config.hidden_dim,
    }

    # Create and train model
    logger.info("Creating NPLL model...")
    model = create_npll_model(kg, config)

    training_results = train_npll(model, epochs=epochs)
    results["training"] = training_results

    # Evaluate
    eval_results = evaluate_npll(model, dataset, test_subset=test_subset)
    results["evaluation"] = eval_results

    # Print summary
    print("\n" + "=" * 60)
    print(f"NPLL BENCHMARK RESULTS: {dataset_name}")
    print("=" * 60)
    print(f"Dataset: {dataset.num_entities:,} entities, {dataset.num_relations} relations")
    print(f"Training: {training_results['training_time_seconds']:.1f}s")
    print("-" * 60)
    print("METRICS (Filtered Setting):")
    print(f" MRR: {eval_results['mrr']:.4f}")
    print(f" Hits@1: {eval_results['hits@1']:.4f}")
    print(f" Hits@3: {eval_results['hits@3']:.4f}")
    print(f" Hits@10: {eval_results['hits@10']:.4f}")
    print(f" Mean Rank: {eval_results['mean_rank']:.1f}")
    print("=" * 60)

    # Save results
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"npll_{dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(output_file, "w") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Results saved to {output_file}")

    return results


def main():
    parser = argparse.ArgumentParser(description="NPLL Benchmark Runner")
    parser.add_argument(
        "--dataset",
        type=str,
        default="fb15k237",
        choices=["fb15k237", "wn18rr"],
        help="Dataset to evaluate on",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--test-subset",
        type=int,
        default=None,
        help="Limit test evaluation to first N triples (for speed)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="benchmark_results",
        help="Directory to save results",
    )

    args = parser.parse_args()

    run_benchmark(
        dataset_name=args.dataset,
        epochs=args.epochs,
        test_subset=args.test_subset,
        output_dir=Path(args.output_dir),
    )


if __name__ == "__main__":
    main()
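Beyond the CLI shown in the module docstring, run_benchmark can be driven programmatically; the dict it returns mirrors the JSON written to --output-dir. A minimal sketch, assuming only that the wheel's benchmarks package is on the import path (no signatures beyond those defined above):

# Minimal sketch: call run_benchmark directly instead of via the CLI.
from pathlib import Path

from benchmarks.run_npll_benchmark import run_benchmark

results = run_benchmark(
    dataset_name="fb15k237",
    epochs=10,
    test_subset=1000,  # rank only the first 1000 test triples for speed
    output_dir=Path("benchmark_results"),
)

# Keys mirror the results dict built inside run_benchmark above.
print(results["dataset_stats"]["num_entities"])
print(results["evaluation"]["mrr"], results["evaluation"]["hits@10"])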
npll/__init__.py
CHANGED
@@ -1,10 +1,10 @@
All 10 lines are removed and re-added with identical content; the file body is shown once below.

"""
NPLL: Neural Probabilistic Logic Learning

The semantic intelligence layer for Odin.
"""

from npll.npll_model import NPLLModel
from npll.bootstrap import KnowledgeBootstrapper

__all__ = ["NPLLModel", "KnowledgeBootstrapper"]
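__all__ pins the package's public surface to NPLLModel and KnowledgeBootstrapper. A minimal end-to-end sketch of that API, reusing only calls that appear in the benchmark runner above (dataset loading goes through benchmarks.datasets; no other signatures are assumed):

# Minimal sketch of the public npll API, using only calls shown in
# benchmarks/run_npll_benchmark.py above.
from npll import NPLLModel
from npll.core import RuleGenerator
from npll.utils import get_config
from benchmarks.datasets import load_wn18rr, dataset_to_kg

dataset = load_wn18rr()
kg = dataset_to_kg(dataset)   # BenchmarkDataset -> KnowledgeGraph

config = get_config()
config.embedding_dim = 100    # same overrides as the benchmark runner
config.hidden_dim = 200

model = NPLLModel(config)
rules = RuleGenerator().generate_rules(kg, max_chain_length=2)
model.initialize(kg, rules)
model.train_model(num_epochs=5, em_iterations=5, verbose=False)

# Score an arbitrary (head, relation, tail) triple; higher is more plausible.
h, r, t = dataset.test_triples[0]
print(float(model.score_triples([(h, r, t)])[0]))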