odin_engine-0.1.0-py3-none-any.whl → odin_engine-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
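
For readers who want to reproduce a comparison like this locally, the sketch below (not part of the package) unpacks both wheels, which are plain zip archives, and prints a unified diff of their text members using only the Python standard library. The file names assume the wheels were fetched under their standard names, e.g. with `pip download odin-engine==0.1.0 --no-deps` and `pip download odin-engine==0.2.0 --no-deps`.

import difflib
import zipfile

def wheel_text_files(path):
    """Map each text-like member of a wheel (a zip archive) to its decoded lines."""
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines(keepends=True)
            for name in zf.namelist()
            if name.endswith((".py", ".txt", "METADATA", "RECORD", "WHEEL", "LICENSE"))
        }

old = wheel_text_files("odin_engine-0.1.0-py3-none-any.whl")
new = wheel_text_files("odin_engine-0.2.0-py3-none-any.whl")

for name in sorted(set(old) | set(new)):
    # Missing members diff against an empty file, so additions and removals both show up.
    for line in difflib.unified_diff(old.get(name, []), new.get(name, []),
                                     fromfile=f"0.1.0/{name}", tofile=f"0.2.0/{name}"):
        print(line, end="")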
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,497 +1,497 @@
-"""
-Evaluation Metrics for NPLL Implementation
-Implements MRR, Hit@K, and other knowledge graph evaluation metrics from the paper
-"""
-
-import torch
-import numpy as np
-from typing import List, Dict, Set, Tuple, Optional, Any, Union
-import logging
-from dataclasses import dataclass
-from collections import defaultdict
-import time
-
-from ..core import KnowledgeGraph, Triple, Entity, Relation
-from ..npll_model import NPLLModel
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class EvaluationMetrics:
-    """
-    Comprehensive evaluation metrics for knowledge graph completion
-    """
-    # Link prediction metrics (standard)
-    mrr: float  # Mean Reciprocal Rank
-    hit_at_1: float  # Hit@1
-    hit_at_3: float  # Hit@3
-    hit_at_10: float  # Hit@10
-
-    # Extended metrics
-    mean_rank: float  # Mean rank
-    median_rank: float  # Median rank
-
-    # Rule quality metrics
-    rule_precision: Optional[float] = None
-    rule_recall: Optional[float] = None
-    rule_f1: Optional[float] = None
-
-    # Confidence calibration
-    calibration_error: Optional[float] = None
-    confidence_correlation: Optional[float] = None
-
-    # Efficiency metrics
-    evaluation_time: float = 0.0
-    predictions_per_second: float = 0.0
-
-    def __str__(self) -> str:
-        return (f"Evaluation Metrics:\n"
-                f"  MRR: {self.mrr:.4f}\n"
-                f"  Hit@1: {self.hit_at_1:.4f}\n"
-                f"  Hit@3: {self.hit_at_3:.4f}\n"
-                f"  Hit@10: {self.hit_at_10:.4f}\n"
-                f"  Mean Rank: {self.mean_rank:.2f}\n"
-                f"  Median Rank: {self.median_rank:.2f}")
-
-
-class KnowledgeGraphEvaluator:
-    """
-    Base evaluator for knowledge graph tasks
-    """
-
-    def __init__(self, knowledge_graph: KnowledgeGraph, filter_known: bool = True):
-        """
-        Initialize evaluator
-
-        Args:
-            knowledge_graph: Knowledge graph for evaluation
-            filter_known: Whether to filter out known facts during ranking
-        """
-        self.kg = knowledge_graph
-        self.filter_known = filter_known
-
-        # Create sets for efficient lookup
-        self.known_fact_set = set()
-        for fact in self.kg.known_facts:
-            self.known_fact_set.add((fact.head.name, fact.relation.name, fact.tail.name))
-
-    def create_negative_samples(self,
-                                test_triple: Triple,
-                                corruption_mode: str = "both") -> List[Triple]:
-        """
-        Create negative samples by corrupting head or tail entities
-
-        Args:
-            test_triple: Triple to create negatives for
-            corruption_mode: "head", "tail", or "both"
-
-        Returns:
-            List of negative triples
-        """
-        negatives = []
-
-        if corruption_mode in ["head", "both"]:
-            # Corrupt head entity
-            for entity in self.kg.entities:
-                if entity != test_triple.head:
-                    negative = Triple(
-                        head=entity,
-                        relation=test_triple.relation,
-                        tail=test_triple.tail
-                    )
-                    negatives.append(negative)
-
-        if corruption_mode in ["tail", "both"]:
-            # Corrupt tail entity
-            for entity in self.kg.entities:
-                if entity != test_triple.tail:
-                    negative = Triple(
-                        head=test_triple.head,
-                        relation=test_triple.relation,
-                        tail=entity
-                    )
-                    negatives.append(negative)
-
-        return negatives
-
-    def filter_candidates(self, candidates: List[Triple]) -> List[Triple]:
-        """Filter out known facts from candidates if filtering is enabled"""
-        if not self.filter_known:
-            return candidates
-
-        filtered = []
-        for candidate in candidates:
-            triple_tuple = (candidate.head.name, candidate.relation.name, candidate.tail.name)
-            if triple_tuple not in self.known_fact_set:
-                filtered.append(candidate)
-
-        return filtered
-
-
-class LinkPredictionEvaluator(KnowledgeGraphEvaluator):
-    """
-    Evaluator for link prediction task
-    Implements standard knowledge graph completion metrics
-    """
-
-    def evaluate_link_prediction(self,
-                                 model: NPLLModel,
-                                 test_triples: Optional[List[Triple]] = None,
-                                 top_k: List[int] = [1, 3, 10],
-                                 corruption_mode: str = "both") -> Dict[str, float]:
-        """
-        Evaluate link prediction performance
-
-        Args:
-            model: Trained NPLL model
-            test_triples: Test triples (uses unknown facts if None)
-            top_k: List of K values for Hit@K computation
-            corruption_mode: "head", "tail", or "both"
-
-        Returns:
-            Dictionary with evaluation metrics
-        """
-        if test_triples is None:
-            test_triples = list(self.kg.unknown_facts)
-
-        if not test_triples:
-            logger.warning("No test triples available for evaluation")
-            return {}
-
-        logger.info(f"Evaluating link prediction on {len(test_triples)} test triples")
-
-        start_time = time.time()
-        ranks = []
-
-        for i, test_triple in enumerate(test_triples):
-            if i % 100 == 0:
-                logger.debug(f"Evaluating triple {i}/{len(test_triples)}")
-
-            # Create candidates (test triple + negatives)
-            candidates = [test_triple] + self.create_negative_samples(test_triple, corruption_mode)
-
-            # Filter known facts
-            candidates = self.filter_candidates(candidates)
-
-            if not candidates:
-                continue
-
-            # Get model predictions
-            try:
-                predictions = model.forward(candidates)
-                scores = predictions['probabilities'].cpu().numpy()
-
-                # Find rank of test triple (first candidate)
-                test_score = scores[0]
-                rank = 1 + np.sum(scores > test_score)  # Rank starts from 1
-                ranks.append(rank)
-
-            except Exception as e:
-                logger.warning(f"Error evaluating triple {test_triple}: {e}")
-                continue
-
-        if not ranks:
-            logger.error("No valid ranks computed")
-            return {}
-
-        evaluation_time = time.time() - start_time
-
-        # Compute metrics
-        ranks = np.array(ranks, dtype=float)
-
-        metrics = {
-            'mrr': float(np.mean(1.0 / ranks)),
-            'mean_rank': float(np.mean(ranks)),
-            'median_rank': float(np.median(ranks)),
-            'evaluation_time': evaluation_time,
-            'predictions_per_second': len(test_triples) / evaluation_time if evaluation_time > 0 else 0.0
-        }
-
-        # Compute Hit@K metrics
-        for k in top_k:
-            hit_at_k = np.mean(ranks <= k)
-            metrics[f'hit@{k}'] = float(hit_at_k)
-
-        logger.info(f"Link prediction evaluation completed: MRR={metrics['mrr']:.4f}")
-
-        return metrics
-
-    def evaluate_entity_ranking(self,
-                                model: NPLLModel,
-                                query_relations: List[str],
-                                top_k: int = 10) -> Dict[str, List[Tuple[str, float]]]:
-        """
-        Evaluate entity ranking for specific relations
-
-        Args:
-            model: Trained NPLL model
-            query_relations: Relations to evaluate
-            top_k: Number of top entities to return
-
-        Returns:
-            Dictionary mapping relations to ranked entity lists
-        """
-        results = {}
-
-        for relation_name in query_relations:
-            relation = self.kg.get_relation(relation_name)
-            if relation is None:
-                continue
-
-            # For each head entity, rank all possible tail entities
-            entity_scores = defaultdict(list)
-
-            for head_entity in self.kg.entities:
-                candidates = []
-                for tail_entity in self.kg.entities:
-                    if head_entity != tail_entity:
-                        candidate = Triple(head=head_entity, relation=relation, tail=tail_entity)
-                        candidates.append(candidate)
-
-                # Get predictions
-                if candidates:
-                    predictions = model.forward(candidates)
-                    scores = predictions['probabilities'].cpu().numpy()
-
-                    for candidate, score in zip(candidates, scores):
-                        entity_scores[candidate.tail.name].append(score)
-
-            # Average scores and rank
-            avg_scores = [(entity, np.mean(scores))
-                          for entity, scores in entity_scores.items()]
-            avg_scores.sort(key=lambda x: x[1], reverse=True)
-
-            results[relation_name] = avg_scores[:top_k]
-
-        return results
-
-
-class RuleQualityEvaluator(KnowledgeGraphEvaluator):
-    """
-    Evaluator for logical rule quality
-    """
-
-    def evaluate_rule_quality(self, model: NPLLModel) -> Dict[str, float]:
-        """
-        Evaluate quality of learned logical rules
-
-        Args:
-            model: Trained NPLL model
-
-        Returns:
-            Dictionary with rule quality metrics
-        """
-        if not model.is_initialized or model.mln is None:
-            return {}
-
-        rule_confidences = model.get_rule_confidences()
-        if not rule_confidences:
-            return {}
-
-        # Compute rule statistics
-        confidences = list(rule_confidences.values())
-
-        metrics = {
-            'avg_rule_confidence': float(np.mean(confidences)),
-            'std_rule_confidence': float(np.std(confidences)),
-            'min_rule_confidence': float(np.min(confidences)),
-            'max_rule_confidence': float(np.max(confidences)),
-            'num_high_confidence_rules': int(np.sum(np.array(confidences) > 0.8)),
-            'num_low_confidence_rules': int(np.sum(np.array(confidences) < 0.2))
-        }
-
-        return metrics
-
-    def evaluate_rule_coverage(self, model: NPLLModel) -> Dict[str, float]:
-        """
-        Evaluate how well rules cover the known facts
-
-        Args:
-            model: Trained NPLL model
-
-        Returns:
-            Dictionary with coverage metrics
-        """
-        if not model.is_initialized or model.mln is None:
-            return {}
-
-        # This would require more complex analysis of ground rules
-        # For now, return basic statistics
-
-        total_ground_rules = len(model.mln.ground_rules) if model.mln.ground_rules else 0
-        total_facts = len(self.kg.known_facts) + len(self.kg.unknown_facts)
-
-        coverage_ratio = total_ground_rules / total_facts if total_facts > 0 else 0.0
-
-        return {
-            'total_ground_rules': total_ground_rules,
-            'total_facts': total_facts,
-            'coverage_ratio': coverage_ratio
-        }
-
-
-class ConfidenceCalibrationEvaluator(KnowledgeGraphEvaluator):
-    """
-    Evaluator for prediction confidence calibration
-    """
-
-    def evaluate_calibration(self,
-                             model: NPLLModel,
-                             test_triples: List[Triple],
-                             num_bins: int = 10) -> Dict[str, float]:
-        """
-        Evaluate confidence calibration using reliability diagrams
-
-        Args:
-            model: Trained NPLL model
-            test_triples: Test triples with ground truth
-            num_bins: Number of confidence bins
-
-        Returns:
-            Dictionary with calibration metrics
-        """
-        # Get model predictions
-        predictions = model.forward(test_triples)
-        confidences = predictions['probabilities'].cpu().numpy()
-
-        # For this example, assume all test triples are positive
-        # In practice, you'd need ground truth labels
-        ground_truth = np.ones(len(test_triples))  # Placeholder
-
-        # Compute Expected Calibration Error (ECE)
-        bin_boundaries = np.linspace(0, 1, num_bins + 1)
-        bin_lowers = bin_boundaries[:-1]
-        bin_uppers = bin_boundaries[1:]
-
-        ece = 0.0
-        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
-            in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
-            prop_in_bin = in_bin.mean()
-
-            if prop_in_bin > 0:
-                accuracy_in_bin = ground_truth[in_bin].mean()
-                avg_confidence_in_bin = confidences[in_bin].mean()
-                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
-
-        return {
-            'expected_calibration_error': float(ece),
-            'avg_confidence': float(np.mean(confidences)),
-            'confidence_std': float(np.std(confidences))
-        }
-
-
-def create_evaluator(knowledge_graph: KnowledgeGraph,
-                     evaluation_type: str = "link_prediction") -> KnowledgeGraphEvaluator:
-    """
-    Factory function to create appropriate evaluator
-
-    Args:
-        knowledge_graph: Knowledge graph for evaluation
-        evaluation_type: Type of evaluator to create
-
-    Returns:
-        Configured evaluator
-    """
-    if evaluation_type == "link_prediction":
-        return LinkPredictionEvaluator(knowledge_graph)
-    elif evaluation_type == "rule_quality":
-        return RuleQualityEvaluator(knowledge_graph)
-    elif evaluation_type == "confidence_calibration":
-        return ConfidenceCalibrationEvaluator(knowledge_graph)
-    else:
-        return LinkPredictionEvaluator(knowledge_graph)  # Default
-
-
-def comprehensive_evaluation(model: NPLLModel,
-                             knowledge_graph: KnowledgeGraph,
-                             test_triples: Optional[List[Triple]] = None) -> EvaluationMetrics:
-    """
-    Run comprehensive evaluation of NPLL model
-
-    Args:
-        model: Trained NPLL model
-        knowledge_graph: Knowledge graph for evaluation
-        test_triples: Optional test triples
-
-    Returns:
-        Comprehensive evaluation metrics
-    """
-    start_time = time.time()
-
-    # Link prediction evaluation
-    link_evaluator = LinkPredictionEvaluator(knowledge_graph)
-    link_metrics = link_evaluator.evaluate_link_prediction(model, test_triples)
-
-    # Rule quality evaluation
-    rule_evaluator = RuleQualityEvaluator(knowledge_graph)
-    rule_metrics = rule_evaluator.evaluate_rule_quality(model)
-
-    evaluation_time = time.time() - start_time
-
-    # Create comprehensive metrics object
-    metrics = EvaluationMetrics(
-        mrr=link_metrics.get('mrr', 0.0),
-        hit_at_1=link_metrics.get('hit@1', 0.0),
-        hit_at_3=link_metrics.get('hit@3', 0.0),
-        hit_at_10=link_metrics.get('hit@10', 0.0),
-        mean_rank=link_metrics.get('mean_rank', 0.0),
-        median_rank=link_metrics.get('median_rank', 0.0),
-        rule_precision=rule_metrics.get('avg_rule_confidence'),
-        evaluation_time=evaluation_time,
-        predictions_per_second=link_metrics.get('predictions_per_second', 0.0)
-    )
-
-    return metrics
-
-
-# Example usage function
-def example_evaluation():
-    """
-    Example showing comprehensive evaluation with sample data
-    """
-    from ..core import load_knowledge_graph_from_triples
-    from ..core.logical_rules import RuleGenerator
-    from ..npll_model import create_npll_model
-    from ..utils import get_config
-
-    # Create sample data
-    sample_triples = [
-        ('Alice', 'friendOf', 'Bob'),
-        ('Bob', 'worksAt', 'Company'),
-        ('Charlie', 'friendOf', 'Alice'),
-        ('Alice', 'livesIn', 'NYC')
-    ]
-
-    kg = load_knowledge_graph_from_triples(sample_triples, "Eval Test")
-
-    # Generate rules
-    rule_gen = RuleGenerator(kg)
-    rules = rule_gen.generate_simple_rules(min_support=1)
-
-    # Create and train model (simplified)
-    config = get_config("ArangoDB_Triples")
-    model = create_npll_model(config)
-    model.initialize(kg, rules)
-
-    # Add unknown facts for evaluation
-    kg.add_unknown_fact('Charlie', 'worksAt', 'Company')
-    kg.add_unknown_fact('Bob', 'livesIn', 'NYC')
-
-    # Quick training
-    model.train_epoch()
-
-    # Comprehensive evaluation
-    metrics = comprehensive_evaluation(model, kg)
-    print(f"Evaluation Results: {metrics}")
-
-    # Specific evaluations
-    evaluator = create_evaluator(kg)
-    link_metrics = evaluator.evaluate_link_prediction(model)
-    print(f"Link Prediction: {link_metrics}")
-
-    return metrics
-
-
-if __name__ == "__main__":
     example_evaluation()
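
For reference, the ranking metrics computed in npll/training/evaluation.py above reduce to the short calculation below. The rank values are hypothetical and serve only to illustrate how MRR, Hit@K, and mean rank are derived from the filtered ranks of the true triples; this is a standalone sketch, not code from the package.

import numpy as np

# Hypothetical filtered ranks of five test triples among their corruptions.
ranks = np.array([1, 3, 2, 10, 57], dtype=float)

mrr = float(np.mean(1.0 / ranks))                            # Mean Reciprocal Rank
hits = {k: float(np.mean(ranks <= k)) for k in (1, 3, 10)}   # Hit@K
mean_rank = float(np.mean(ranks))

print(f"MRR={mrr:.4f}  Hit@1={hits[1]:.2f}  Hit@3={hits[3]:.2f}  "
      f"Hit@10={hits[10]:.2f}  mean rank={mean_rank:.1f}")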