odin-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. benchmarks/__init__.py +17 -0
  2. benchmarks/datasets.py +284 -0
  3. benchmarks/metrics.py +275 -0
  4. benchmarks/run_ablation.py +279 -0
  5. benchmarks/run_npll_benchmark.py +270 -0
  6. npll/__init__.py +10 -0
  7. npll/bootstrap.py +474 -0
  8. npll/core/__init__.py +34 -0
  9. npll/core/knowledge_graph.py +309 -0
  10. npll/core/logical_rules.py +497 -0
  11. npll/core/mln.py +475 -0
  12. npll/inference/__init__.py +41 -0
  13. npll/inference/e_step.py +420 -0
  14. npll/inference/elbo.py +435 -0
  15. npll/inference/m_step.py +577 -0
  16. npll/npll_model.py +632 -0
  17. npll/scoring/__init__.py +43 -0
  18. npll/scoring/embeddings.py +442 -0
  19. npll/scoring/probability.py +403 -0
  20. npll/scoring/scoring_module.py +370 -0
  21. npll/training/__init__.py +25 -0
  22. npll/training/evaluation.py +497 -0
  23. npll/training/npll_trainer.py +521 -0
  24. npll/utils/__init__.py +48 -0
  25. npll/utils/batch_utils.py +493 -0
  26. npll/utils/config.py +145 -0
  27. npll/utils/math_utils.py +339 -0
  28. odin/__init__.py +20 -0
  29. odin/engine.py +264 -0
  30. odin_engine-0.1.0.dist-info/METADATA +456 -0
  31. odin_engine-0.1.0.dist-info/RECORD +62 -0
  32. odin_engine-0.1.0.dist-info/WHEEL +5 -0
  33. odin_engine-0.1.0.dist-info/licenses/LICENSE +21 -0
  34. odin_engine-0.1.0.dist-info/top_level.txt +4 -0
  35. retrieval/__init__.py +50 -0
  36. retrieval/adapters.py +140 -0
  37. retrieval/adapters_arango.py +1418 -0
  38. retrieval/aggregators.py +707 -0
  39. retrieval/beam.py +127 -0
  40. retrieval/budget.py +60 -0
  41. retrieval/cache.py +159 -0
  42. retrieval/confidence.py +88 -0
  43. retrieval/eval.py +49 -0
  44. retrieval/linker.py +87 -0
  45. retrieval/metrics.py +105 -0
  46. retrieval/metrics_motifs.py +36 -0
  47. retrieval/orchestrator.py +571 -0
  48. retrieval/ppr/__init__.py +12 -0
  49. retrieval/ppr/anchors.py +41 -0
  50. retrieval/ppr/bippr.py +61 -0
  51. retrieval/ppr/engines.py +257 -0
  52. retrieval/ppr/global_pr.py +76 -0
  53. retrieval/ppr/indexes.py +78 -0
  54. retrieval/ppr.py +156 -0
  55. retrieval/ppr_cache.py +25 -0
  56. retrieval/scoring.py +294 -0
  57. retrieval/utils/__init__.py +0 -0
  58. retrieval/utils/pii_redaction.py +36 -0
  59. retrieval/writers/__init__.py +9 -0
  60. retrieval/writers/arango_writer.py +28 -0
  61. retrieval/writers/base.py +21 -0
  62. retrieval/writers/janus_writer.py +36 -0
@@ -0,0 +1,521 @@
1
+ """
2
+ NPLL Training Infrastructure
3
+ Complete training loop with E-M algorithm, validation, and checkpointing
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import List, Dict, Set, Tuple, Optional, Any, Union
9
+ import logging
10
+ import time
11
+ import os
12
+ import json
13
+ from dataclasses import dataclass, asdict
14
+ from pathlib import Path
15
+
16
+ from ..npll_model import NPLLModel, NPLLTrainingState
17
+ from ..core import KnowledgeGraph, LogicalRule
18
+ from ..utils import NPLLConfig
19
+ from .evaluation import EvaluationMetrics, create_evaluator
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @dataclass
25
+ class TrainingConfig:
26
+ """
27
+ Configuration for NPLL training process
28
+ """
29
+ # Training parameters
30
+ num_epochs: int = 100
31
+ max_em_iterations_per_epoch: int = 20
32
+ early_stopping_patience: int = 10
33
+
34
+ # Validation
35
+ validate_every_n_epochs: int = 5
36
+ validation_split: float = 0.1
37
+
38
+ # Checkpointing
39
+ save_checkpoints: bool = True
40
+ checkpoint_dir: str = "checkpoints"
41
+ save_every_n_epochs: int = 10
42
+ keep_best_checkpoint: bool = True
43
+
44
+ # Logging
45
+ log_level: str = "INFO"
46
+ log_metrics_every_n_iterations: int = 5
47
+
48
+ # Performance
49
+ device: str = "cpu" # or "cuda"
50
+ num_workers: int = 1
51
+
52
+ # Optimization
53
+ learning_rate_schedule: bool = True
54
+ lr_decay_factor: float = 0.9
55
+ lr_decay_patience: int = 5
56
+
57
+
58
@dataclass
class TrainingResult:
    """
    Result of an NPLL training run, as produced by
    :meth:`NPLLTrainer._create_training_result`.
    """
    # Training progress
    total_epochs: int            # number of epochs actually executed
    total_em_iterations: int     # sum of E-M iterations over all epochs
    final_elbo: float            # last recorded ELBO (-inf if none recorded)
    best_elbo: float             # max recorded ELBO (-inf if none recorded)
    converged: bool              # True if the final epoch reported convergence

    # Training history
    elbo_history: List[float]                          # ELBO per E-M iteration, all epochs
    validation_metrics_history: List[Dict[str, float]]  # one dict per validation run

    # Timing (seconds)
    total_training_time: float
    average_epoch_time: float

    # Model state
    final_model_path: Optional[str] = None
    best_model_path: Optional[str] = None  # NOTE(review): never set in this module

    # Convergence info
    convergence_epoch: Optional[int] = None   # first epoch that reported convergence
    early_stopping_triggered: bool = False
85
+
86
+
87
class NPLLTrainer:
    """
    Complete NPLL training infrastructure.

    Manages the full training pipeline:
    - E-M algorithm execution
    - Validation and evaluation
    - Checkpointing and model saving
    - Early stopping and convergence detection
    - Learning rate scheduling
    """

    def __init__(self,
                 model: "NPLLModel",
                 training_config: "TrainingConfig",
                 evaluator=None):
        """
        Initialize NPLL trainer.

        Args:
            model: NPLL model to train
            training_config: Training configuration
            evaluator: Optional evaluator for validation
        """
        self.model = model
        self.config = training_config
        self.evaluator = evaluator

        # Setup device -- only move the model when it is initialized,
        # since an uninitialized model may have no parameters to move.
        self.device = torch.device(self.config.device)
        if self.model.is_initialized:
            self.model.to(self.device)

        # Setup logging
        self._setup_logging()

        # Mutable per-run training state
        self.training_history = {
            'epochs': [],
            'elbo_history': [],
            'validation_metrics': [],
            'learning_rates': [],
            'convergence_info': []
        }

        # Early stopping state (higher validation score is better)
        self.best_validation_score = float('-inf')
        self.epochs_without_improvement = 0

        # Checkpointing: make sure the target directory exists up front
        if self.config.save_checkpoints:
            self.checkpoint_dir = Path(self.config.checkpoint_dir)
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"NPLL Trainer initialized with config: {self.config}")

    def _setup_logging(self):
        """Apply the configured level name (e.g. "INFO") to this module's logger."""
        log_level = getattr(logging, self.config.log_level.upper())
        logging.getLogger(__name__).setLevel(log_level)

    def train(self,
              validation_kg: Optional["KnowledgeGraph"] = None,
              validation_rules: Optional[List["LogicalRule"]] = None) -> "TrainingResult":
        """
        Complete training process.

        Args:
            validation_kg: Optional validation knowledge graph
            validation_rules: Optional validation rules

        Returns:
            TrainingResult with comprehensive training information

        Raises:
            RuntimeError: if the model has not been initialized
        """
        if not self.model.is_initialized:
            raise RuntimeError("Model must be initialized before training")

        logger.info("Starting NPLL training process")
        training_start_time = time.time()

        # Setup validation if provided; lazily build an evaluator when needed
        validation_available = validation_kg is not None and validation_rules is not None
        if validation_available and self.evaluator is None:
            self.evaluator = create_evaluator(validation_kg)

        # Training loop
        converged = False
        early_stopped = False
        # FIX: pre-bind the loop variable so `epoch + 1` below cannot raise
        # NameError when config.num_epochs == 0 (loop body never runs).
        epoch = -1

        for epoch in range(self.config.num_epochs):
            epoch_start_time = time.time()

            # Train one epoch (delegates E-M iterations to the model)
            epoch_result = self._train_epoch(epoch)

            # Update training history
            self._update_training_history(epoch, epoch_result)

            # Validation (epoch 0 validates too: 0 % n == 0)
            validation_metrics = {}
            if validation_available and epoch % self.config.validate_every_n_epochs == 0:
                validation_metrics = self._validate(validation_kg, validation_rules)
                self.training_history['validation_metrics'].append(validation_metrics)

                # Early stopping check (no-op on empty metrics)
                early_stopped = self._check_early_stopping(validation_metrics)

            # Checkpointing
            if self.config.save_checkpoints and epoch % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(epoch, epoch_result, validation_metrics)

            # Convergence check -- reported by the model for this epoch
            converged = epoch_result['converged']

            # Log progress
            self._log_epoch_progress(epoch, epoch_result, validation_metrics,
                                     time.time() - epoch_start_time)

            # Break conditions
            if converged:
                logger.info(f"Training converged at epoch {epoch}")
                break

            if early_stopped:
                logger.info(f"Early stopping triggered at epoch {epoch}")
                break

        # Training completed
        total_training_time = time.time() - training_start_time

        # Save final model
        final_model_path = None
        if self.config.save_checkpoints:
            final_model_path = self.checkpoint_dir / "final_model.pt"
            self.model.save_model(str(final_model_path))

        # Create training result
        result = self._create_training_result(
            total_epochs=epoch + 1,
            total_training_time=total_training_time,
            converged=converged,
            early_stopped=early_stopped,
            final_model_path=str(final_model_path) if final_model_path else None
        )

        logger.info(f"Training completed: {result}")
        return result

    def _train_epoch(self, epoch: int) -> Dict[str, Any]:
        """Run one epoch of E-M iterations on the model.

        Returns the model's epoch-result dict; keys used by this trainer:
        'converged', 'em_iterations', 'final_elbo', 'iteration_results'.
        """
        logger.debug(f"Training epoch {epoch}")

        epoch_result = self.model.train_epoch(
            max_em_iterations=self.config.max_em_iterations_per_epoch
        )

        return epoch_result

    def _validate(self, validation_kg: "KnowledgeGraph",
                  validation_rules: List["LogicalRule"]) -> Dict[str, float]:
        """Run validation evaluation.

        Returns {} when no evaluator is configured or evaluation fails.
        Always restores the model to train mode on exit.
        """
        if self.evaluator is None:
            return {}

        logger.debug("Running validation evaluation")

        # Set model to eval mode for the duration of evaluation
        self.model.eval()

        try:
            # Primary metric set: link prediction
            metrics = self.evaluator.evaluate_link_prediction(self.model, top_k=[1, 3, 10])

            # Rule quality metrics are best-effort: failure only logs at debug
            try:
                rule_metrics = self.evaluator.evaluate_rule_quality(self.model)
                metrics.update(rule_metrics)
            except Exception as e:
                logger.debug(f"Could not evaluate rule quality: {e}")

            return metrics

        except Exception as e:
            # Validation failures must not abort training
            logger.warning(f"Validation failed: {e}")
            return {}

        finally:
            # Set model back to train mode
            self.model.train()

    def _check_early_stopping(self, validation_metrics: Dict[str, float]) -> bool:
        """Update early-stopping state and report whether to stop.

        Uses MRR as the primary validation metric; empty metrics are a no-op.
        """
        if not validation_metrics:
            return False

        current_score = validation_metrics.get('mrr', float('-inf'))

        if current_score > self.best_validation_score:
            # New best -- reset the patience counter
            self.best_validation_score = current_score
            self.epochs_without_improvement = 0
            return False
        else:
            self.epochs_without_improvement += 1
            return self.epochs_without_improvement >= self.config.early_stopping_patience

    def _save_checkpoint(self, epoch: int, epoch_result: Dict[str, Any],
                         validation_metrics: Dict[str, float]):
        """Save a training checkpoint plus a standalone model snapshot."""
        checkpoint_path = self.checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"

        checkpoint_data = {
            'epoch': epoch,
            'model_state': self.model.get_model_summary(),
            'training_history': self.training_history,
            'training_config': asdict(self.config),
            'epoch_result': epoch_result,
            'validation_metrics': validation_metrics,
            # FIX: persist early-stopping state so resume_training() does not
            # silently reset the patience counter.
            'best_validation_score': self.best_validation_score,
            'epochs_without_improvement': self.epochs_without_improvement
        }

        torch.save(checkpoint_data, checkpoint_path)

        # Save model state separately (loadable without trainer metadata)
        model_checkpoint_path = self.checkpoint_dir / f"model_epoch_{epoch}.pt"
        self.model.save_model(str(model_checkpoint_path))

        logger.debug(f"Checkpoint saved: {checkpoint_path}")

    def _update_training_history(self, epoch: int, epoch_result: Dict[str, Any]):
        """Fold one epoch's results into the running training history."""
        self.training_history['epochs'].append(epoch)
        # One ELBO entry per E-M iteration, flattened across epochs
        self.training_history['elbo_history'].extend(
            [r['elbo'] for r in epoch_result['iteration_results']]
        )
        self.training_history['convergence_info'].append({
            'epoch': epoch,
            'converged': epoch_result['converged'],
            'em_iterations': epoch_result['em_iterations'],
            'final_elbo': epoch_result['final_elbo']
        })

    def _log_epoch_progress(self, epoch: int, epoch_result: Dict[str, Any],
                            validation_metrics: Dict[str, float], epoch_time: float):
        """Log a one-line progress summary for the epoch."""
        elbo = epoch_result['final_elbo']
        em_iters = epoch_result['em_iterations']
        converged = epoch_result['converged']

        log_msg = (f"Epoch {epoch}: ELBO={elbo:.6f}, EM_iters={em_iters}, "
                   f"Converged={converged}, Time={epoch_time:.2f}s")

        if validation_metrics:
            mrr = validation_metrics.get('mrr', 0.0)
            hit1 = validation_metrics.get('hit@1', 0.0)
            log_msg += f", Val_MRR={mrr:.4f}, Val_Hit@1={hit1:.4f}"

        logger.info(log_msg)

    def _create_training_result(self, total_epochs: int, total_training_time: float,
                                converged: bool, early_stopped: bool,
                                final_model_path: Optional[str]) -> "TrainingResult":
        """Assemble the final TrainingResult from the accumulated history."""

        # Total E-M iterations across all epochs
        total_em_iterations = sum(
            info['em_iterations'] for info in self.training_history['convergence_info']
        )

        # Final and best ELBO (-inf when nothing was recorded)
        final_elbo = self.training_history['elbo_history'][-1] if self.training_history['elbo_history'] else float('-inf')
        best_elbo = max(self.training_history['elbo_history']) if self.training_history['elbo_history'] else float('-inf')

        # First epoch that reported convergence, if any
        convergence_epoch = None
        for info in self.training_history['convergence_info']:
            if info['converged']:
                convergence_epoch = info['epoch']
                break

        return TrainingResult(
            total_epochs=total_epochs,
            total_em_iterations=total_em_iterations,
            final_elbo=final_elbo,
            best_elbo=best_elbo,
            converged=converged,
            elbo_history=self.training_history['elbo_history'],
            validation_metrics_history=self.training_history['validation_metrics'],
            total_training_time=total_training_time,
            average_epoch_time=total_training_time / total_epochs if total_epochs > 0 else 0.0,
            final_model_path=final_model_path,
            convergence_epoch=convergence_epoch,
            early_stopping_triggered=early_stopped
        )

    def resume_training(self, checkpoint_path: str) -> "TrainingResult":
        """Resume training from a checkpoint written by _save_checkpoint.

        Restores training history and early-stopping state, then runs train().
        NOTE(review): the epoch counter itself is not resumed -- train()
        always restarts at epoch 0; resuming mid-schedule would require
        train() to accept a start epoch.

        Raises:
            FileNotFoundError: if checkpoint_path does not exist
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        logger.info(f"Resuming training from checkpoint: {checkpoint_path}")

        # FIX: map tensors onto this trainer's device rather than whatever
        # device the checkpoint was saved from (e.g. GPU save, CPU resume).
        checkpoint_data = torch.load(checkpoint_path, map_location=self.device)

        # Restore training history
        self.training_history = checkpoint_data['training_history']

        # FIX: restore early-stopping state; .get() keeps compatibility with
        # older checkpoints that lack these keys.
        self.best_validation_score = checkpoint_data.get(
            'best_validation_score', self.best_validation_score)
        self.epochs_without_improvement = checkpoint_data.get(
            'epochs_without_improvement', self.epochs_without_improvement)

        # Next epoch after the checkpointed one (informational only, see NOTE)
        start_epoch = checkpoint_data['epoch'] + 1
        logger.info(f"Training resumed from epoch {start_epoch}")

        return self.train()

    def get_training_summary(self) -> Dict[str, Any]:
        """Get a comprehensive snapshot of configuration and training state."""
        return {
            'config': asdict(self.config),
            'model_summary': self.model.get_model_summary(),
            'training_history': self.training_history,
            'best_validation_score': self.best_validation_score,
            'epochs_without_improvement': self.epochs_without_improvement
        }
412
+
413
+
414
def create_trainer(model: NPLLModel,
                   training_config: Optional[TrainingConfig] = None,
                   validation_kg: Optional[KnowledgeGraph] = None) -> NPLLTrainer:
    """
    Factory function to create NPLL trainer.

    Args:
        model: NPLL model to train
        training_config: Optional training configuration (defaults used if omitted)
        validation_kg: Optional validation knowledge graph for evaluator

    Returns:
        Configured NPLL trainer
    """
    # Fall back to default configuration when none is supplied
    cfg = TrainingConfig() if training_config is None else training_config

    # Only build an evaluator when validation data is available
    evaluator = create_evaluator(validation_kg) if validation_kg is not None else None

    return NPLLTrainer(model, cfg, evaluator)
437
+
438
+
439
def train_npll_from_scratch(knowledge_graph: KnowledgeGraph,
                            logical_rules: List[LogicalRule],
                            npll_config: Optional[NPLLConfig] = None,
                            training_config: Optional[TrainingConfig] = None) -> Tuple[NPLLModel, TrainingResult]:
    """
    Complete training pipeline from scratch.

    Args:
        knowledge_graph: Knowledge graph for training
        logical_rules: Logical rules
        npll_config: NPLL model configuration (default profile used if omitted)
        training_config: Training configuration

    Returns:
        (Trained model, Training result)
    """
    from ..npll_model import create_npll_model
    from ..utils import get_config

    # Resolve the model configuration, then build and initialize the model
    config = get_config("ArangoDB_Triples") if npll_config is None else npll_config
    model = create_npll_model(config)
    model.initialize(knowledge_graph, logical_rules)

    # Build a trainer around the model and run the full training loop
    trainer = create_trainer(model, training_config)
    return model, trainer.train()
472
+
473
+
474
+ # Example usage function
475
def example_training_pipeline():
    """
    Example showing the complete training pipeline with sample data.

    Builds a tiny knowledge graph, derives rules from it, trains an NPLL
    model, and prints the results. Returns (model, result).
    """
    from ..core import load_knowledge_graph_from_triples
    from ..core.logical_rules import RuleGenerator
    from ..utils import get_config

    # 1. Create sample data (your data adapter would provide this format)
    sample_triples = [
        ('Alice', 'friendOf', 'Bob'),
        ('Bob', 'worksAt', 'Company'),
        ('Charlie', 'friendOf', 'Alice'),
        ('Bob', 'livesIn', 'NYC'),
        ('Alice', 'livesIn', 'NYC'),
        ('Company', 'locatedIn', 'NYC'),
    ]
    kg = load_knowledge_graph_from_triples(sample_triples, "Sample KG")

    # Derive simple + symmetry rules from the graph itself
    generator = RuleGenerator(kg)
    rules = generator.generate_simple_rules(min_support=1)
    rules.extend(generator.generate_symmetry_rules(min_support=1))

    # 2. Configure training (short run sized for the toy graph)
    npll_config = get_config("ArangoDB_Triples")
    training_config = TrainingConfig(
        num_epochs=10,
        max_em_iterations_per_epoch=5,
        early_stopping_patience=3,
        validate_every_n_epochs=2,
    )

    # 3. Train model
    model, result = train_npll_from_scratch(kg, rules, npll_config, training_config)

    # 4. Results
    print(f"Training completed: {result}")
    print(f"Final model: {model}")

    return model, result
518
+
519
+
520
# Allow running this module directly as a quick end-to-end smoke demo.
if __name__ == "__main__":
    example_training_pipeline()
npll/utils/__init__.py ADDED
@@ -0,0 +1,48 @@
1
"""
Utility modules for NPLL implementation
"""

from .config import NPLLConfig, get_config, default_config
from .math_utils import (
    log_sum_exp, safe_log, safe_sigmoid, partition_function_approximation,
    compute_mln_probability, compute_elbo_loss, bernoulli_entropy, bernoulli_log_prob,
    compute_markov_blanket_prob, temperature_scaling, kl_divergence_bernoulli,
    gradient_clipping, compute_metrics, NumericalStabilizer
)
from .batch_utils import (
    GroundRuleBatch, GroundRuleSampler, FactBatchProcessor,
    MemoryEfficientBatcher, AdaptiveBatcher, create_ground_rule_sampler,
    verify_batch_utils
)

# Public API of npll.utils: everything re-exported above, grouped by submodule.
__all__ = [
    # Configuration
    'NPLLConfig',
    'get_config',
    'default_config',

    # Mathematical Utilities
    'log_sum_exp',
    'safe_log',
    'safe_sigmoid',
    'partition_function_approximation',
    'compute_mln_probability',
    'compute_elbo_loss',
    'bernoulli_entropy',
    'bernoulli_log_prob',
    'compute_markov_blanket_prob',
    'temperature_scaling',
    'kl_divergence_bernoulli',
    'gradient_clipping',
    'compute_metrics',
    'NumericalStabilizer',

    # Batch Processing
    'GroundRuleBatch',
    'GroundRuleSampler',
    'FactBatchProcessor',
    'MemoryEfficientBatcher',
    'AdaptiveBatcher',
    'create_ground_rule_sampler',
    'verify_batch_utils'
]