odin-engine 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -0
- benchmarks/datasets.py +284 -0
- benchmarks/metrics.py +275 -0
- benchmarks/run_ablation.py +279 -0
- benchmarks/run_npll_benchmark.py +270 -0
- npll/__init__.py +10 -0
- npll/bootstrap.py +474 -0
- npll/core/__init__.py +34 -0
- npll/core/knowledge_graph.py +309 -0
- npll/core/logical_rules.py +497 -0
- npll/core/mln.py +475 -0
- npll/inference/__init__.py +41 -0
- npll/inference/e_step.py +420 -0
- npll/inference/elbo.py +435 -0
- npll/inference/m_step.py +577 -0
- npll/npll_model.py +632 -0
- npll/scoring/__init__.py +43 -0
- npll/scoring/embeddings.py +442 -0
- npll/scoring/probability.py +403 -0
- npll/scoring/scoring_module.py +370 -0
- npll/training/__init__.py +25 -0
- npll/training/evaluation.py +497 -0
- npll/training/npll_trainer.py +521 -0
- npll/utils/__init__.py +48 -0
- npll/utils/batch_utils.py +493 -0
- npll/utils/config.py +145 -0
- npll/utils/math_utils.py +339 -0
- odin/__init__.py +20 -0
- odin/engine.py +264 -0
- odin_engine-0.1.0.dist-info/METADATA +456 -0
- odin_engine-0.1.0.dist-info/RECORD +62 -0
- odin_engine-0.1.0.dist-info/WHEEL +5 -0
- odin_engine-0.1.0.dist-info/licenses/LICENSE +21 -0
- odin_engine-0.1.0.dist-info/top_level.txt +4 -0
- retrieval/__init__.py +50 -0
- retrieval/adapters.py +140 -0
- retrieval/adapters_arango.py +1418 -0
- retrieval/aggregators.py +707 -0
- retrieval/beam.py +127 -0
- retrieval/budget.py +60 -0
- retrieval/cache.py +159 -0
- retrieval/confidence.py +88 -0
- retrieval/eval.py +49 -0
- retrieval/linker.py +87 -0
- retrieval/metrics.py +105 -0
- retrieval/metrics_motifs.py +36 -0
- retrieval/orchestrator.py +571 -0
- retrieval/ppr/__init__.py +12 -0
- retrieval/ppr/anchors.py +41 -0
- retrieval/ppr/bippr.py +61 -0
- retrieval/ppr/engines.py +257 -0
- retrieval/ppr/global_pr.py +76 -0
- retrieval/ppr/indexes.py +78 -0
- retrieval/ppr.py +156 -0
- retrieval/ppr_cache.py +25 -0
- retrieval/scoring.py +294 -0
- retrieval/utils/__init__.py +0 -0
- retrieval/utils/pii_redaction.py +36 -0
- retrieval/writers/__init__.py +9 -0
- retrieval/writers/arango_writer.py +28 -0
- retrieval/writers/base.py +21 -0
- retrieval/writers/janus_writer.py +36 -0
|
@@ -0,0 +1,521 @@
|
|
|
1
|
+
"""
|
|
2
|
+
NPLL Training Infrastructure
|
|
3
|
+
Complete training loop with E-M algorithm, validation, and checkpointing
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
from typing import List, Dict, Set, Tuple, Optional, Any, Union
|
|
9
|
+
import logging
|
|
10
|
+
import time
|
|
11
|
+
import os
|
|
12
|
+
import json
|
|
13
|
+
from dataclasses import dataclass, asdict
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
from ..npll_model import NPLLModel, NPLLTrainingState
|
|
17
|
+
from ..core import KnowledgeGraph, LogicalRule
|
|
18
|
+
from ..utils import NPLLConfig
|
|
19
|
+
from .evaluation import EvaluationMetrics, create_evaluator
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class TrainingConfig:
    """
    Configuration for the NPLL training process.

    Groups every knob consumed by NPLLTrainer: epoch/E-M budgets,
    validation cadence, checkpointing, logging, device placement, and
    learning-rate scheduling.
    """
    # Training parameters
    num_epochs: int = 100                   # maximum number of training epochs
    max_em_iterations_per_epoch: int = 20   # E-M iterations run inside each epoch
    early_stopping_patience: int = 10       # validation rounds without improvement before stopping

    # Validation
    validate_every_n_epochs: int = 5        # validation runs when epoch % N == 0 (epoch 0 included)
    validation_split: float = 0.1           # NOTE(review): not referenced in this module — confirm consumer

    # Checkpointing
    save_checkpoints: bool = True           # when False, checkpoint_dir is never created
    checkpoint_dir: str = "checkpoints"     # directory created (parents included) on trainer init
    save_every_n_epochs: int = 10           # checkpoint saved when epoch % N == 0
    keep_best_checkpoint: bool = True       # NOTE(review): not referenced in this module — confirm consumer

    # Logging
    log_level: str = "INFO"                 # applied via logging.getLogger(__name__).setLevel
    log_metrics_every_n_iterations: int = 5

    # Performance
    device: str = "cpu"  # or "cuda"
    num_workers: int = 1

    # Optimization
    learning_rate_schedule: bool = True
    lr_decay_factor: float = 0.9
    lr_decay_patience: int = 5
|
58
|
+
@dataclass
class TrainingResult:
    """
    Result of an NPLL training run.

    Returned by NPLLTrainer.train(); aggregates progress counters, the
    full ELBO / validation history, timing, and convergence information.
    """
    # Training progress
    total_epochs: int        # epochs actually executed (may be fewer than configured)
    total_em_iterations: int # E-M iterations summed over all epochs
    final_elbo: float        # last recorded ELBO (-inf if no history)
    best_elbo: float         # maximum ELBO observed (-inf if no history)
    converged: bool          # whether the final epoch reported convergence

    # Training history
    elbo_history: List[float]                           # one entry per E-M iteration
    validation_metrics_history: List[Dict[str, float]]  # one dict per validation run

    # Timing
    total_training_time: float  # wall-clock seconds for the whole run
    average_epoch_time: float   # total_training_time / total_epochs

    # Model state
    final_model_path: Optional[str] = None
    # NOTE(review): best_model_path is never populated by NPLLTrainer — confirm intended.
    best_model_path: Optional[str] = None

    # Convergence info
    convergence_epoch: Optional[int] = None  # first epoch whose result reported converged=True
    early_stopping_triggered: bool = False
|
87
|
+
class NPLLTrainer:
    """
    Complete NPLL training infrastructure.

    Manages the full training pipeline:
    - E-M algorithm execution (delegated to the model's train_epoch)
    - Validation and evaluation
    - Checkpointing and model saving
    - Early stopping and convergence detection
    - Learning rate scheduling
    """

    def __init__(self,
                 model: NPLLModel,
                 training_config: TrainingConfig,
                 evaluator=None):
        """
        Initialize NPLL trainer.

        Args:
            model: NPLL model to train (must be initialized before train()).
            training_config: Training configuration.
            evaluator: Optional evaluator for validation; if None and
                validation data is later passed to train(), one is created.
        """
        self.model = model
        self.config = training_config
        self.evaluator = evaluator

        # Setup device. The model is only moved if already initialized;
        # an uninitialized model may not have parameters to transfer yet.
        self.device = torch.device(self.config.device)
        if self.model.is_initialized:
            self.model.to(self.device)

        # Setup logging
        self._setup_logging()

        # Training state accumulated across epochs.
        self.training_history = {
            'epochs': [],
            'elbo_history': [],
            'validation_metrics': [],
            'learning_rates': [],
            'convergence_info': []
        }

        # Early stopping state: best validation MRR seen so far.
        self.best_validation_score = float('-inf')
        self.epochs_without_improvement = 0

        # Checkpointing. NOTE: self.checkpoint_dir only exists when
        # save_checkpoints is True; all uses below are guarded by that flag.
        if self.config.save_checkpoints:
            self.checkpoint_dir = Path(self.config.checkpoint_dir)
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"NPLL Trainer initialized with config: {self.config}")

    def _setup_logging(self):
        """Apply the configured log level to this module's logger."""
        log_level = getattr(logging, self.config.log_level.upper())
        logging.getLogger(__name__).setLevel(log_level)

    def train(self,
              validation_kg: Optional[KnowledgeGraph] = None,
              validation_rules: Optional[List[LogicalRule]] = None) -> TrainingResult:
        """
        Complete training process.

        Args:
            validation_kg: Optional validation knowledge graph.
            validation_rules: Optional validation rules (both must be given
                for validation to run).

        Returns:
            TrainingResult with comprehensive training information.

        Raises:
            RuntimeError: if the model has not been initialized.

        NOTE(review): assumes config.num_epochs >= 1 — with 0 epochs the
        loop never binds `epoch` and the result construction below would
        raise NameError. Confirm callers never pass num_epochs == 0.
        """
        if not self.model.is_initialized:
            raise RuntimeError("Model must be initialized before training")

        logger.info("Starting NPLL training process")
        training_start_time = time.time()

        # Setup validation if provided: both the KG and rules are required.
        validation_available = validation_kg is not None and validation_rules is not None
        if validation_available and self.evaluator is None:
            self.evaluator = create_evaluator(validation_kg)

        # Training loop
        converged = False
        early_stopped = False

        for epoch in range(self.config.num_epochs):
            epoch_start_time = time.time()

            # Train one epoch
            epoch_result = self._train_epoch(epoch)

            # Update training history
            self._update_training_history(epoch, epoch_result)

            # Validation (runs on epoch 0 as well, since 0 % N == 0)
            validation_metrics = {}
            if validation_available and epoch % self.config.validate_every_n_epochs == 0:
                validation_metrics = self._validate(validation_kg, validation_rules)
                self.training_history['validation_metrics'].append(validation_metrics)

            # Early stopping check (no-op when validation_metrics is empty)
            early_stopped = self._check_early_stopping(validation_metrics)

            # Checkpointing
            if self.config.save_checkpoints and epoch % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(epoch, epoch_result, validation_metrics)

            # Convergence check: trust the model's own per-epoch verdict.
            converged = epoch_result['converged']

            # Log progress
            self._log_epoch_progress(epoch, epoch_result, validation_metrics,
                                     time.time() - epoch_start_time)

            # Break conditions
            if converged:
                logger.info(f"Training converged at epoch {epoch}")
                break

            if early_stopped:
                logger.info(f"Early stopping triggered at epoch {epoch}")
                break

        # Training completed
        total_training_time = time.time() - training_start_time

        # Save final model
        final_model_path = None
        if self.config.save_checkpoints:
            final_model_path = self.checkpoint_dir / "final_model.pt"
            self.model.save_model(str(final_model_path))

        # Create training result (epoch is the last loop value, so +1 gives
        # the number of epochs actually executed)
        result = self._create_training_result(
            total_epochs=epoch + 1,
            total_training_time=total_training_time,
            converged=converged,
            early_stopped=early_stopped,
            final_model_path=str(final_model_path) if final_model_path else None
        )

        logger.info(f"Training completed: {result}")
        return result

    def _train_epoch(self, epoch: int) -> Dict[str, Any]:
        """Train a single epoch by delegating E-M iterations to the model.

        Returns the model's epoch-result dict; the keys consumed downstream
        are 'converged', 'em_iterations', 'final_elbo' and 'iteration_results'.
        """
        logger.debug(f"Training epoch {epoch}")

        # Train epoch with E-M iterations
        epoch_result = self.model.train_epoch(
            max_em_iterations=self.config.max_em_iterations_per_epoch
        )

        return epoch_result

    def _validate(self, validation_kg: KnowledgeGraph,
                  validation_rules: List[LogicalRule]) -> Dict[str, float]:
        """Run validation evaluation.

        Returns an empty dict when no evaluator is configured or the
        evaluation fails; the model is always restored to train mode.
        """
        if self.evaluator is None:
            return {}

        logger.debug("Running validation evaluation")

        # Set model to eval mode
        self.model.eval()

        try:
            # Run evaluation
            metrics = self.evaluator.evaluate_link_prediction(self.model, top_k=[1, 3, 10])

            # Add rule quality metrics if possible (best-effort: failure here
            # only drops the extra metrics, it does not fail validation)
            try:
                rule_metrics = self.evaluator.evaluate_rule_quality(self.model)
                metrics.update(rule_metrics)
            except Exception as e:
                logger.debug(f"Could not evaluate rule quality: {e}")

            return metrics

        except Exception as e:
            logger.warning(f"Validation failed: {e}")
            return {}

        finally:
            # Set model back to train mode
            self.model.train()

    def _check_early_stopping(self, validation_metrics: Dict[str, float]) -> bool:
        """Check if early stopping should be triggered.

        Returns True once the patience budget of non-improving validation
        rounds is exhausted; empty metrics never trigger stopping.
        """
        if not validation_metrics:
            return False

        # Use MRR as primary validation metric
        current_score = validation_metrics.get('mrr', float('-inf'))

        if current_score > self.best_validation_score:
            self.best_validation_score = current_score
            self.epochs_without_improvement = 0
            return False
        else:
            self.epochs_without_improvement += 1
            return self.epochs_without_improvement >= self.config.early_stopping_patience

    def _save_checkpoint(self, epoch: int, epoch_result: Dict[str, Any],
                         validation_metrics: Dict[str, float]):
        """Save a training checkpoint plus a separate model snapshot.

        Writes two files: checkpoint_epoch_<n>.pt (trainer state/history)
        and model_epoch_<n>.pt (via the model's own save_model).
        """
        checkpoint_path = self.checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"

        checkpoint_data = {
            'epoch': epoch,
            # NOTE(review): stores the model *summary*, not weights — the
            # weights live in the model_epoch_<n>.pt file written below.
            'model_state': self.model.get_model_summary(),
            'training_history': self.training_history,
            'training_config': asdict(self.config),
            'epoch_result': epoch_result,
            'validation_metrics': validation_metrics
        }

        torch.save(checkpoint_data, checkpoint_path)

        # Save model state
        model_checkpoint_path = self.checkpoint_dir / f"model_epoch_{epoch}.pt"
        self.model.save_model(str(model_checkpoint_path))

        logger.debug(f"Checkpoint saved: {checkpoint_path}")

    def _update_training_history(self, epoch: int, epoch_result: Dict[str, Any]):
        """Append this epoch's ELBO trace and convergence info to history."""
        self.training_history['epochs'].append(epoch)
        # One ELBO entry per E-M iteration, flattened across epochs.
        self.training_history['elbo_history'].extend(
            [r['elbo'] for r in epoch_result['iteration_results']]
        )
        self.training_history['convergence_info'].append({
            'epoch': epoch,
            'converged': epoch_result['converged'],
            'em_iterations': epoch_result['em_iterations'],
            'final_elbo': epoch_result['final_elbo']
        })

    def _log_epoch_progress(self, epoch: int, epoch_result: Dict[str, Any],
                            validation_metrics: Dict[str, float], epoch_time: float):
        """Log a one-line progress summary for the epoch (INFO level)."""
        elbo = epoch_result['final_elbo']
        em_iters = epoch_result['em_iterations']
        converged = epoch_result['converged']

        log_msg = (f"Epoch {epoch}: ELBO={elbo:.6f}, EM_iters={em_iters}, "
                   f"Converged={converged}, Time={epoch_time:.2f}s")

        # Validation numbers are appended only on validation epochs.
        if validation_metrics:
            mrr = validation_metrics.get('mrr', 0.0)
            hit1 = validation_metrics.get('hit@1', 0.0)
            log_msg += f", Val_MRR={mrr:.4f}, Val_Hit@1={hit1:.4f}"

        logger.info(log_msg)

    def _create_training_result(self, total_epochs: int, total_training_time: float,
                                converged: bool, early_stopped: bool,
                                final_model_path: Optional[str]) -> TrainingResult:
        """Assemble the TrainingResult from the accumulated training history."""

        # Get total EM iterations
        total_em_iterations = sum(
            info['em_iterations'] for info in self.training_history['convergence_info']
        )

        # Get final and best ELBO (-inf sentinels when no history exists)
        final_elbo = self.training_history['elbo_history'][-1] if self.training_history['elbo_history'] else float('-inf')
        best_elbo = max(self.training_history['elbo_history']) if self.training_history['elbo_history'] else float('-inf')

        # Find the first epoch that reported convergence, if any
        convergence_epoch = None
        for info in self.training_history['convergence_info']:
            if info['converged']:
                convergence_epoch = info['epoch']
                break

        return TrainingResult(
            total_epochs=total_epochs,
            total_em_iterations=total_em_iterations,
            final_elbo=final_elbo,
            best_elbo=best_elbo,
            converged=converged,
            elbo_history=self.training_history['elbo_history'],
            validation_metrics_history=self.training_history['validation_metrics'],
            total_training_time=total_training_time,
            average_epoch_time=total_training_time / total_epochs if total_epochs > 0 else 0.0,
            final_model_path=final_model_path,
            convergence_epoch=convergence_epoch,
            early_stopping_triggered=early_stopped
        )

    def resume_training(self, checkpoint_path: str) -> TrainingResult:
        """Resume training from a checkpoint.

        NOTE(review): this restores the training history but then calls
        self.train() which restarts the epoch loop from 0 — start_epoch is
        computed and logged but never used (acknowledged in the comment
        below). It also does not restore model weights from the matching
        model_epoch_<n>.pt file. Confirm whether full resumption is planned.
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        logger.info(f"Resuming training from checkpoint: {checkpoint_path}")

        # NOTE(review): torch.load without map_location may fail when the
        # checkpoint was saved on a different device — confirm.
        checkpoint_data = torch.load(checkpoint_path)

        # Restore training history
        self.training_history = checkpoint_data['training_history']

        # Resume training from next epoch
        start_epoch = checkpoint_data['epoch'] + 1

        # Continue training
        # (This would require modifying the train method to accept start_epoch)
        logger.info(f"Training resumed from epoch {start_epoch}")

        return self.train()

    def get_training_summary(self) -> Dict[str, Any]:
        """Get a comprehensive, serializable snapshot of the trainer state."""
        return {
            'config': asdict(self.config),
            'model_summary': self.model.get_model_summary(),
            'training_history': self.training_history,
            'best_validation_score': self.best_validation_score,
            'epochs_without_improvement': self.epochs_without_improvement
        }
413
|
+
|
|
414
|
+
def create_trainer(model: NPLLModel,
                   training_config: Optional[TrainingConfig] = None,
                   validation_kg: Optional[KnowledgeGraph] = None) -> NPLLTrainer:
    """
    Factory function to create an NPLL trainer.

    Args:
        model: NPLL model to train
        training_config: Optional training configuration; a default
            TrainingConfig is used when omitted
        validation_kg: Optional validation knowledge graph; when given,
            an evaluator is built from it and attached to the trainer

    Returns:
        Configured NPLL trainer
    """
    cfg = TrainingConfig() if training_config is None else training_config
    evaluator = create_evaluator(validation_kg) if validation_kg is not None else None
    return NPLLTrainer(model, cfg, evaluator)
|
+
|
|
439
|
+
def train_npll_from_scratch(knowledge_graph: KnowledgeGraph,
                            logical_rules: List[LogicalRule],
                            npll_config: Optional[NPLLConfig] = None,
                            training_config: Optional[TrainingConfig] = None) -> Tuple[NPLLModel, TrainingResult]:
    """
    Complete training pipeline from scratch.

    Builds a model from the given (or default) NPLL configuration,
    initializes it on the knowledge graph and rules, then trains it.

    Args:
        knowledge_graph: Knowledge graph for training
        logical_rules: Logical rules
        npll_config: NPLL model configuration; defaults to the
            "ArangoDB_Triples" preset when omitted
        training_config: Training configuration

    Returns:
        (Trained model, Training result)
    """
    from ..npll_model import create_npll_model
    from ..utils import get_config

    # Fall back to the packaged ArangoDB-triples preset.
    cfg = npll_config if npll_config is not None else get_config("ArangoDB_Triples")

    # Build and initialize the model on the provided data.
    model = create_npll_model(cfg)
    model.initialize(knowledge_graph, logical_rules)

    # Train via the factory-built trainer.
    trainer = create_trainer(model, training_config)
    training_result = trainer.train()

    return model, training_result
473
|
+
|
|
474
|
+
# Example usage function
def example_training_pipeline():
    """
    Example showing the complete training pipeline with sample data.

    Builds a tiny social-graph KG, generates simple and symmetry rules,
    trains a model for a few epochs, and prints the results.
    """
    from ..core import load_knowledge_graph_from_triples
    from ..core.logical_rules import RuleGenerator
    from ..utils import get_config

    # 1. Create sample data (your data adapter would provide this format)
    demo_triples = [
        ('Alice', 'friendOf', 'Bob'),
        ('Bob', 'worksAt', 'Company'),
        ('Charlie', 'friendOf', 'Alice'),
        ('Bob', 'livesIn', 'NYC'),
        ('Alice', 'livesIn', 'NYC'),
        ('Company', 'locatedIn', 'NYC')
    ]

    # Load knowledge graph
    graph = load_knowledge_graph_from_triples(demo_triples, "Sample KG")

    # Generate rules: simple rules plus symmetry rules, low support for toy data.
    generator = RuleGenerator(graph)
    rules = generator.generate_simple_rules(min_support=1)
    rules.extend(generator.generate_symmetry_rules(min_support=1))

    # 2. Configure training
    model_cfg = get_config("ArangoDB_Triples")
    train_cfg = TrainingConfig(
        num_epochs=10,
        max_em_iterations_per_epoch=5,
        early_stopping_patience=3,
        validate_every_n_epochs=2
    )

    # 3. Train model
    model, result = train_npll_from_scratch(graph, rules, model_cfg, train_cfg)

    # 4. Results
    print(f"Training completed: {result}")
    print(f"Final model: {model}")

    return model, result
|
519
|
+
|
|
520
|
+
# Script entry point: run the example pipeline when executed directly.
if __name__ == "__main__":
    example_training_pipeline()
|
npll/utils/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility modules for NPLL implementation
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .config import NPLLConfig, get_config, default_config
|
|
6
|
+
from .math_utils import (
|
|
7
|
+
log_sum_exp, safe_log, safe_sigmoid, partition_function_approximation,
|
|
8
|
+
compute_mln_probability, compute_elbo_loss, bernoulli_entropy, bernoulli_log_prob,
|
|
9
|
+
compute_markov_blanket_prob, temperature_scaling, kl_divergence_bernoulli,
|
|
10
|
+
gradient_clipping, compute_metrics, NumericalStabilizer
|
|
11
|
+
)
|
|
12
|
+
from .batch_utils import (
|
|
13
|
+
GroundRuleBatch, GroundRuleSampler, FactBatchProcessor,
|
|
14
|
+
MemoryEfficientBatcher, AdaptiveBatcher, create_ground_rule_sampler,
|
|
15
|
+
verify_batch_utils
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Public API of npll.utils: configuration, math helpers, and batch processing.
__all__ = [
    # Configuration
    'NPLLConfig',
    'get_config',
    'default_config',

    # Mathematical Utilities
    'log_sum_exp',
    'safe_log',
    'safe_sigmoid',
    'partition_function_approximation',
    'compute_mln_probability',
    'compute_elbo_loss',
    'bernoulli_entropy',
    'bernoulli_log_prob',
    'compute_markov_blanket_prob',
    'temperature_scaling',
    'kl_divergence_bernoulli',
    'gradient_clipping',
    'compute_metrics',
    'NumericalStabilizer',

    # Batch Processing
    'GroundRuleBatch',
    'GroundRuleSampler',
    'FactBatchProcessor',
    'MemoryEfficientBatcher',
    'AdaptiveBatcher',
    'create_ground_rule_sampler',
    'verify_batch_utils'
]
|