odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/training/evaluation.py
CHANGED
@@ -1,497 +1,497 @@
"""
Evaluation Metrics for NPLL Implementation
Implements MRR, Hit@K, and other knowledge graph evaluation metrics from the paper
"""

import torch
import numpy as np
from typing import List, Dict, Set, Tuple, Optional, Any, Union
import logging
from dataclasses import dataclass
from collections import defaultdict
import time

from ..core import KnowledgeGraph, Triple, Entity, Relation
from ..npll_model import NPLLModel

logger = logging.getLogger(__name__)


@dataclass
class EvaluationMetrics:
    """
    Comprehensive evaluation metrics for knowledge graph completion
    """
    # Link prediction metrics (standard)
    mrr: float  # Mean Reciprocal Rank
    hit_at_1: float  # Hit@1
    hit_at_3: float  # Hit@3
    hit_at_10: float  # Hit@10

    # Extended metrics
    mean_rank: float  # Mean rank
    median_rank: float  # Median rank

    # Rule quality metrics
    rule_precision: Optional[float] = None
    rule_recall: Optional[float] = None
    rule_f1: Optional[float] = None

    # Confidence calibration
    calibration_error: Optional[float] = None
    confidence_correlation: Optional[float] = None

    # Efficiency metrics
    evaluation_time: float = 0.0
    predictions_per_second: float = 0.0

    def __str__(self) -> str:
        return (f"Evaluation Metrics:\n"
                f" MRR: {self.mrr:.4f}\n"
                f" Hit@1: {self.hit_at_1:.4f}\n"
                f" Hit@3: {self.hit_at_3:.4f}\n"
                f" Hit@10: {self.hit_at_10:.4f}\n"
                f" Mean Rank: {self.mean_rank:.2f}\n"
                f" Median Rank: {self.median_rank:.2f}")


class KnowledgeGraphEvaluator:
    """
    Base evaluator for knowledge graph tasks
    """

    def __init__(self, knowledge_graph: KnowledgeGraph, filter_known: bool = True):
        """
        Initialize evaluator

        Args:
            knowledge_graph: Knowledge graph for evaluation
            filter_known: Whether to filter out known facts during ranking
        """
        self.kg = knowledge_graph
        self.filter_known = filter_known

        # Create sets for efficient lookup
        self.known_fact_set = set()
        for fact in self.kg.known_facts:
            self.known_fact_set.add((fact.head.name, fact.relation.name, fact.tail.name))

    def create_negative_samples(self,
                                test_triple: Triple,
                                corruption_mode: str = "both") -> List[Triple]:
        """
        Create negative samples by corrupting head or tail entities

        Args:
            test_triple: Triple to create negatives for
            corruption_mode: "head", "tail", or "both"

        Returns:
            List of negative triples
        """
        negatives = []

        if corruption_mode in ["head", "both"]:
            # Corrupt head entity
            for entity in self.kg.entities:
                if entity != test_triple.head:
                    negative = Triple(
                        head=entity,
                        relation=test_triple.relation,
                        tail=test_triple.tail
                    )
                    negatives.append(negative)

        if corruption_mode in ["tail", "both"]:
            # Corrupt tail entity
            for entity in self.kg.entities:
                if entity != test_triple.tail:
                    negative = Triple(
                        head=test_triple.head,
                        relation=test_triple.relation,
                        tail=entity
                    )
                    negatives.append(negative)

        return negatives

    def filter_candidates(self, candidates: List[Triple]) -> List[Triple]:
        """Filter out known facts from candidates if filtering is enabled"""
        if not self.filter_known:
            return candidates

        filtered = []
        for candidate in candidates:
            triple_tuple = (candidate.head.name, candidate.relation.name, candidate.tail.name)
            if triple_tuple not in self.known_fact_set:
                filtered.append(candidate)

        return filtered


class LinkPredictionEvaluator(KnowledgeGraphEvaluator):
    """
    Evaluator for link prediction task
    Implements standard knowledge graph completion metrics
    """

    def evaluate_link_prediction(self,
                                 model: NPLLModel,
                                 test_triples: Optional[List[Triple]] = None,
                                 top_k: List[int] = [1, 3, 10],
                                 corruption_mode: str = "both") -> Dict[str, float]:
        """
        Evaluate link prediction performance

        Args:
            model: Trained NPLL model
            test_triples: Test triples (uses unknown facts if None)
            top_k: List of K values for Hit@K computation
            corruption_mode: "head", "tail", or "both"

        Returns:
            Dictionary with evaluation metrics
        """
        if test_triples is None:
            test_triples = list(self.kg.unknown_facts)

        if not test_triples:
            logger.warning("No test triples available for evaluation")
            return {}

        logger.info(f"Evaluating link prediction on {len(test_triples)} test triples")

        start_time = time.time()
        ranks = []

        for i, test_triple in enumerate(test_triples):
            if i % 100 == 0:
                logger.debug(f"Evaluating triple {i}/{len(test_triples)}")

            # Create candidates (test triple + negatives)
            candidates = [test_triple] + self.create_negative_samples(test_triple, corruption_mode)

            # Filter known facts
            candidates = self.filter_candidates(candidates)

            if not candidates:
                continue

            # Get model predictions
            try:
                predictions = model.forward(candidates)
                scores = predictions['probabilities'].cpu().numpy()

                # Find rank of test triple (first candidate)
                test_score = scores[0]
                rank = 1 + np.sum(scores > test_score)  # Rank starts from 1
                ranks.append(rank)

            except Exception as e:
                logger.warning(f"Error evaluating triple {test_triple}: {e}")
                continue

        if not ranks:
            logger.error("No valid ranks computed")
            return {}

        evaluation_time = time.time() - start_time

        # Compute metrics
        ranks = np.array(ranks, dtype=float)

        metrics = {
            'mrr': float(np.mean(1.0 / ranks)),
            'mean_rank': float(np.mean(ranks)),
            'median_rank': float(np.median(ranks)),
            'evaluation_time': evaluation_time,
            'predictions_per_second': len(test_triples) / evaluation_time if evaluation_time > 0 else 0.0
        }

        # Compute Hit@K metrics
        for k in top_k:
            hit_at_k = np.mean(ranks <= k)
            metrics[f'hit@{k}'] = float(hit_at_k)

        logger.info(f"Link prediction evaluation completed: MRR={metrics['mrr']:.4f}")

        return metrics

    def evaluate_entity_ranking(self,
                                model: NPLLModel,
                                query_relations: List[str],
                                top_k: int = 10) -> Dict[str, List[Tuple[str, float]]]:
        """
        Evaluate entity ranking for specific relations

        Args:
            model: Trained NPLL model
            query_relations: Relations to evaluate
            top_k: Number of top entities to return

        Returns:
            Dictionary mapping relations to ranked entity lists
        """
        results = {}

        for relation_name in query_relations:
            relation = self.kg.get_relation(relation_name)
            if relation is None:
                continue

            # For each head entity, rank all possible tail entities
            entity_scores = defaultdict(list)

            for head_entity in self.kg.entities:
                candidates = []
                for tail_entity in self.kg.entities:
                    if head_entity != tail_entity:
                        candidate = Triple(head=head_entity, relation=relation, tail=tail_entity)
                        candidates.append(candidate)

                # Get predictions
                if candidates:
                    predictions = model.forward(candidates)
                    scores = predictions['probabilities'].cpu().numpy()

                    for candidate, score in zip(candidates, scores):
                        entity_scores[candidate.tail.name].append(score)

            # Average scores and rank
            avg_scores = [(entity, np.mean(scores))
                          for entity, scores in entity_scores.items()]
            avg_scores.sort(key=lambda x: x[1], reverse=True)

            results[relation_name] = avg_scores[:top_k]

        return results


class RuleQualityEvaluator(KnowledgeGraphEvaluator):
    """
    Evaluator for logical rule quality
    """

    def evaluate_rule_quality(self, model: NPLLModel) -> Dict[str, float]:
        """
        Evaluate quality of learned logical rules

        Args:
            model: Trained NPLL model

        Returns:
            Dictionary with rule quality metrics
        """
        if not model.is_initialized or model.mln is None:
            return {}

        rule_confidences = model.get_rule_confidences()
        if not rule_confidences:
            return {}

        # Compute rule statistics
        confidences = list(rule_confidences.values())

        metrics = {
            'avg_rule_confidence': float(np.mean(confidences)),
            'std_rule_confidence': float(np.std(confidences)),
            'min_rule_confidence': float(np.min(confidences)),
            'max_rule_confidence': float(np.max(confidences)),
            'num_high_confidence_rules': int(np.sum(np.array(confidences) > 0.8)),
            'num_low_confidence_rules': int(np.sum(np.array(confidences) < 0.2))
        }

        return metrics

    def evaluate_rule_coverage(self, model: NPLLModel) -> Dict[str, float]:
        """
        Evaluate how well rules cover the known facts

        Args:
            model: Trained NPLL model

        Returns:
            Dictionary with coverage metrics
        """
        if not model.is_initialized or model.mln is None:
            return {}

        # This would require more complex analysis of ground rules
        # For now, return basic statistics

        total_ground_rules = len(model.mln.ground_rules) if model.mln.ground_rules else 0
        total_facts = len(self.kg.known_facts) + len(self.kg.unknown_facts)

        coverage_ratio = total_ground_rules / total_facts if total_facts > 0 else 0.0

        return {
            'total_ground_rules': total_ground_rules,
            'total_facts': total_facts,
            'coverage_ratio': coverage_ratio
        }


class ConfidenceCalibrationEvaluator(KnowledgeGraphEvaluator):
    """
    Evaluator for prediction confidence calibration
    """

    def evaluate_calibration(self,
                             model: NPLLModel,
                             test_triples: List[Triple],
                             num_bins: int = 10) -> Dict[str, float]:
        """
        Evaluate confidence calibration using reliability diagrams

        Args:
            model: Trained NPLL model
            test_triples: Test triples with ground truth
            num_bins: Number of confidence bins

        Returns:
            Dictionary with calibration metrics
        """
        # Get model predictions
        predictions = model.forward(test_triples)
        confidences = predictions['probabilities'].cpu().numpy()

        # For this example, assume all test triples are positive
        # In practice, you'd need ground truth labels
        ground_truth = np.ones(len(test_triples))  # Placeholder

        # Compute Expected Calibration Error (ECE)
        bin_boundaries = np.linspace(0, 1, num_bins + 1)
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]

        ece = 0.0
        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
            in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
            prop_in_bin = in_bin.mean()

            if prop_in_bin > 0:
                accuracy_in_bin = ground_truth[in_bin].mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return {
            'expected_calibration_error': float(ece),
            'avg_confidence': float(np.mean(confidences)),
            'confidence_std': float(np.std(confidences))
        }


def create_evaluator(knowledge_graph: KnowledgeGraph,
                     evaluation_type: str = "link_prediction") -> KnowledgeGraphEvaluator:
    """
    Factory function to create appropriate evaluator

    Args:
        knowledge_graph: Knowledge graph for evaluation
        evaluation_type: Type of evaluator to create

    Returns:
        Configured evaluator
    """
    if evaluation_type == "link_prediction":
        return LinkPredictionEvaluator(knowledge_graph)
    elif evaluation_type == "rule_quality":
        return RuleQualityEvaluator(knowledge_graph)
    elif evaluation_type == "confidence_calibration":
        return ConfidenceCalibrationEvaluator(knowledge_graph)
    else:
        return LinkPredictionEvaluator(knowledge_graph)  # Default


def comprehensive_evaluation(model: NPLLModel,
                             knowledge_graph: KnowledgeGraph,
                             test_triples: Optional[List[Triple]] = None) -> EvaluationMetrics:
    """
    Run comprehensive evaluation of NPLL model

    Args:
        model: Trained NPLL model
        knowledge_graph: Knowledge graph for evaluation
        test_triples: Optional test triples

    Returns:
        Comprehensive evaluation metrics
    """
    start_time = time.time()

    # Link prediction evaluation
    link_evaluator = LinkPredictionEvaluator(knowledge_graph)
    link_metrics = link_evaluator.evaluate_link_prediction(model, test_triples)

    # Rule quality evaluation
    rule_evaluator = RuleQualityEvaluator(knowledge_graph)
    rule_metrics = rule_evaluator.evaluate_rule_quality(model)

    evaluation_time = time.time() - start_time

    # Create comprehensive metrics object
    metrics = EvaluationMetrics(
        mrr=link_metrics.get('mrr', 0.0),
        hit_at_1=link_metrics.get('hit@1', 0.0),
        hit_at_3=link_metrics.get('hit@3', 0.0),
        hit_at_10=link_metrics.get('hit@10', 0.0),
        mean_rank=link_metrics.get('mean_rank', 0.0),
        median_rank=link_metrics.get('median_rank', 0.0),
        rule_precision=rule_metrics.get('avg_rule_confidence'),
        evaluation_time=evaluation_time,
        predictions_per_second=link_metrics.get('predictions_per_second', 0.0)
    )

    return metrics


# Example usage function
def example_evaluation():
    """
    Example showing comprehensive evaluation with sample data
    """
    from ..core import load_knowledge_graph_from_triples
    from ..core.logical_rules import RuleGenerator
    from ..npll_model import create_npll_model
    from ..utils import get_config

    # Create sample data
    sample_triples = [
        ('Alice', 'friendOf', 'Bob'),
        ('Bob', 'worksAt', 'Company'),
        ('Charlie', 'friendOf', 'Alice'),
        ('Alice', 'livesIn', 'NYC')
    ]

    kg = load_knowledge_graph_from_triples(sample_triples, "Eval Test")

    # Generate rules
    rule_gen = RuleGenerator(kg)
    rules = rule_gen.generate_simple_rules(min_support=1)

    # Create and train model (simplified)
    config = get_config("ArangoDB_Triples")
    model = create_npll_model(config)
    model.initialize(kg, rules)

    # Add unknown facts for evaluation
    kg.add_unknown_fact('Charlie', 'worksAt', 'Company')
    kg.add_unknown_fact('Bob', 'livesIn', 'NYC')

    # Quick training
    model.train_epoch()

    # Comprehensive evaluation
    metrics = comprehensive_evaluation(model, kg)
    print(f"Evaluation Results: {metrics}")

    # Specific evaluations
    evaluator = create_evaluator(kg)
    link_metrics = evaluator.evaluate_link_prediction(model)
    print(f"Link Prediction: {link_metrics}")

    return metrics


if __name__ == "__main__":
497
497
|
example_evaluation()
|