odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/npll_model.py CHANGED
@@ -1,632 +1,632 @@
1
- """
2
- Main NPLL Model - Neural Probabilistic Logic Learning
3
- Integrates all components: MLN, Scoring Module, E-step, M-step
4
- Exact implementation of the complete NPLL framework from the paper
5
- """
6
-
7
- import torch
8
- import torch.nn as nn
9
- from typing import List, Dict, Set, Tuple, Optional, Any, Union
10
- import logging
11
- from dataclasses import dataclass
12
- import time
13
- import os
14
-
15
- from .core import (
16
- KnowledgeGraph, Entity, Relation, Triple, LogicalRule, GroundRule,
17
- MarkovLogicNetwork, create_mln_from_kg_and_rules
18
- )
19
- from .scoring import (
20
- NPLLScoringModule, create_scoring_module,
21
- ProbabilityTransform, create_probability_components
22
- )
23
- from .inference import (
24
- EStepRunner, MStepRunner, ELBOComputer, EStepResult, MStepResult,
25
- create_e_step_runner, create_m_step_runner
26
- )
27
- from .utils import NPLLConfig, get_config
28
-
29
- logger = logging.getLogger(__name__)
30
-
31
-
32
- @dataclass
33
- class NPLLTrainingState:
34
- """
35
- Training state for NPLL model
36
- Tracks progress through E-M iterations
37
- """
38
- epoch: int
39
- em_iteration: int
40
- elbo_history: List[float]
41
- rule_weight_history: List[List[float]]
42
- convergence_info: Dict[str, Any]
43
- training_time: float
44
- best_elbo: float
45
- best_weights: Optional[torch.Tensor] = None
46
-
47
- def __str__(self) -> str:
48
- return (f"NPLL Training State:\n"
49
- f" Epoch: {self.epoch}\n"
50
- f" EM Iteration: {self.em_iteration}\n"
51
- f" Current ELBO: {self.elbo_history[-1] if self.elbo_history else 'N/A'}\n"
52
- f" Best ELBO: {self.best_elbo:.6f}\n"
53
- f" Training Time: {self.training_time:.2f}s")
54
-
55
-
56
- class NPLLModel(nn.Module):
57
- """
58
- Complete Neural Probabilistic Logic Learning Model
59
-
60
- Integrates all paper components:
61
- - Knowledge Graph representation (Section 3)
62
- - Scoring Module with bilinear function (Section 4.1, Equation 7)
63
- - Markov Logic Network (Sections 3-4, Equations 1-2)
64
- - E-M Algorithm (Sections 4.2-4.3, Equations 8-14)
65
- - ELBO optimization (Equations 3-5)
66
- """
67
-
68
- def __init__(self, config: NPLLConfig):
69
- super().__init__()
70
- self.config = config
71
-
72
- # Core components (initialized when knowledge graph is provided)
73
- self.knowledge_graph: Optional[KnowledgeGraph] = None
74
- self.mln: Optional[MarkovLogicNetwork] = None
75
- self.scoring_module: Optional[NPLLScoringModule] = None
76
-
77
- # Inference components
78
- self.e_step_runner: Optional[EStepRunner] = None
79
- self.m_step_runner: Optional[MStepRunner] = None
80
- self.elbo_computer: Optional[ELBOComputer] = None
81
-
82
- # Probability transformation
83
- self.probability_transform: Optional[ProbabilityTransform] = None
84
-
85
- # Training state
86
- self.training_state: Optional[NPLLTrainingState] = None
87
- self.is_initialized = False
88
-
89
- # Model metadata
90
- self.model_version = "1.0"
91
- self.creation_time = time.time()
92
- self.calibration_version = None
93
-
94
- def initialize(self,
95
- knowledge_graph: KnowledgeGraph,
96
- logical_rules: List[LogicalRule]):
97
- """
98
- Initialize NPLL model with knowledge graph and rules
99
-
100
- Args:
101
- knowledge_graph: Knowledge graph K = (E, L, F)
102
- logical_rules: List of logical rules R
103
- """
104
- logger.info("Initializing NPLL model...")
105
-
106
- # Store knowledge graph
107
- self.knowledge_graph = knowledge_graph
108
-
109
- # Create and initialize MLN
110
- self.mln = create_mln_from_kg_and_rules(knowledge_graph, logical_rules, self.config)
111
-
112
- # Create scoring module
113
- self.scoring_module = create_scoring_module(self.config, knowledge_graph)
114
-
115
- # Create inference components
116
- self.e_step_runner = create_e_step_runner(self.config)
117
- self.m_step_runner = create_m_step_runner(self.config)
118
-
119
- # Create ELBO computer
120
- from .inference import create_elbo_computer
121
- self.elbo_computer = create_elbo_computer(self.config)
122
-
123
- # Create probability transformation (enable per-relation groups if available)
124
- num_rel = None
125
- if self.scoring_module is not None and hasattr(self.scoring_module, 'embedding_manager'):
126
- emb_mgr = self.scoring_module.embedding_manager
127
- if hasattr(emb_mgr, 'relation_num_groups'):
128
- num_rel = emb_mgr.relation_num_groups
129
- per_relation = num_rel is not None
130
- prob_transform, _ = create_probability_components(
131
- self.config.temperature,
132
- per_relation=per_relation,
133
- num_relations=(num_rel or 1)
134
- )
135
- self.probability_transform = prob_transform
136
-
137
- # Initialize training state
138
- self.training_state = NPLLTrainingState(
139
- epoch=0,
140
- em_iteration=0,
141
- elbo_history=[],
142
- rule_weight_history=[],
143
- convergence_info={},
144
- training_time=0.0,
145
- best_elbo=float('-inf')
146
- )
147
-
148
- self.is_initialized = True
149
-
150
- logger.info(f"NPLL model initialized with {len(logical_rules)} rules, "
151
- f"{len(knowledge_graph.entities)} entities, "
152
- f"{len(knowledge_graph.relations)} relations")
153
-
154
- def forward(self, triples: List[Triple]) -> Dict[str, torch.Tensor]:
155
- """
156
- Forward pass through NPLL model
157
-
158
- Args:
159
- triples: List of triples to score
160
-
161
- Returns:
162
- Dictionary with scores and probabilities
163
- """
164
- if not self.is_initialized:
165
- raise RuntimeError("NPLL model not initialized. Call initialize() first.")
166
-
167
- # Get raw scores from scoring module
168
- raw_scores = self.scoring_module.forward(triples)
169
-
170
- # Group IDs per relation (no vocab mutation)
171
- group_ids = None
172
- if hasattr(self.scoring_module, 'embedding_manager') and getattr(self.probability_transform, 'per_group', False):
173
- emb_mgr = self.scoring_module.embedding_manager
174
- group_ids = emb_mgr.relation_group_ids_for_triples(triples, add_if_missing=False)
175
- # Ensure transform capacity if table grew
176
- if hasattr(self.probability_transform, 'ensure_num_groups'):
177
- self.probability_transform.ensure_num_groups(emb_mgr.relation_num_groups)
178
-
179
- # Transform to probabilities
180
- probabilities = self.probability_transform(raw_scores, apply_temperature=True, group_ids=group_ids)
181
-
182
- # Get log probabilities
183
- log_probabilities = self.probability_transform.get_log_probabilities(raw_scores, apply_temperature=True, group_ids=group_ids)
184
-
185
- return {
186
- 'raw_scores': raw_scores,
187
- 'probabilities': probabilities,
188
- 'log_probabilities': log_probabilities
189
- }
190
-
191
- def predict_single_triple(self, head: str, relation: str, tail: str, transient: bool = True) -> Dict[str, float]:
192
- """
193
- Predict probability for a single triple
194
-
195
- Args:
196
- head: Head entity name
197
- relation: Relation name
198
- tail: Tail entity name
199
- transient: If True, do not mutate the underlying knowledge graph
200
-
201
- Returns:
202
- Dictionary with prediction results
203
- """
204
- if not self.is_initialized:
205
- raise RuntimeError("NPLL model not initialized")
206
-
207
- # Create triple object without mutating KG by default
208
- if transient:
209
- head_entity = Entity(head)
210
- relation_obj = Relation(relation)
211
- tail_entity = Entity(tail)
212
- else:
213
- head_entity = self.knowledge_graph.get_entity(head) or self.knowledge_graph.add_entity(head)
214
- relation_obj = self.knowledge_graph.get_relation(relation) or self.knowledge_graph.add_relation(relation)
215
- tail_entity = self.knowledge_graph.get_entity(tail) or self.knowledge_graph.add_entity(tail)
216
-
217
- triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
218
-
219
- # Get predictions
220
- self.eval()
221
- with torch.no_grad():
222
- results = self.forward([triple])
223
-
224
- return {
225
- 'probability': results['probabilities'][0].item(),
226
- 'log_probability': results['log_probabilities'][0].item(),
227
- 'raw_score': results['raw_scores'][0].item()
228
- }
229
-
230
- def run_single_em_iteration(self) -> Dict[str, Any]:
231
- """
232
- Run a single E-M iteration
233
-
234
- Returns:
235
- Dictionary with iteration results
236
- """
237
- if not self.is_initialized:
238
- raise RuntimeError("NPLL model not initialized")
239
-
240
- iteration_start_time = time.time()
241
-
242
- logger.debug(f"Starting E-M iteration {self.training_state.em_iteration}")
243
-
244
- # E-step: Optimize Q(U)
245
- logger.debug("Running E-step...")
246
- e_step_result = self.e_step_runner.run_e_step(
247
- self.mln, self.scoring_module, self.knowledge_graph
248
- )
249
-
250
- # M-step: Optimize rule weights ω
251
- logger.debug("Running M-step...")
252
- m_step_result = self.m_step_runner.run_m_step(self.mln, e_step_result)
253
-
254
- # Update training state
255
- current_elbo = e_step_result.elbo_value.item()
256
- self.training_state.elbo_history.append(current_elbo)
257
-
258
- if self.mln.rule_weights is not None:
259
- current_weights = self.mln.rule_weights.data.tolist()
260
- self.training_state.rule_weight_history.append(current_weights)
261
-
262
- # Track best snapshot (MLN + scoring)
263
- if current_elbo > self.training_state.best_elbo:
264
- self.training_state.best_elbo = current_elbo
265
- self.training_state.best_weights = {
266
- 'mln': {k: v.clone() for k, v in self.mln.state_dict().items()},
267
- 'scoring': {k: v.clone() for k, v in self.scoring_module.state_dict().items()} if self.scoring_module else {},
268
- }
269
-
270
- self.training_state.em_iteration += 1
271
- iteration_time = time.time() - iteration_start_time
272
- self.training_state.training_time += iteration_time
273
-
274
- # Check convergence
275
- converged = self._check_em_convergence()
276
-
277
- iteration_result = {
278
- 'em_iteration': self.training_state.em_iteration - 1,
279
- 'e_step_result': e_step_result,
280
- 'm_step_result': m_step_result,
281
- 'elbo': current_elbo,
282
- 'iteration_time': iteration_time,
283
- 'converged': converged,
284
- 'convergence_info': {
285
- 'e_step_converged': e_step_result.convergence_info.get('converged', False),
286
- 'm_step_converged': m_step_result.convergence_info.get('converged', False)
287
- }
288
- }
289
-
290
- logger.debug(f"E-M iteration completed: ELBO={current_elbo:.6f}, "
291
- f"Time={iteration_time:.2f}s, Converged={converged}")
292
-
293
- return iteration_result
294
-
295
- def train_epoch(self, max_em_iterations: Optional[int] = None) -> Dict[str, Any]:
296
- """
297
- Train for one epoch (multiple E-M iterations until convergence)
298
-
299
- Args:
300
- max_em_iterations: Maximum E-M iterations per epoch
301
-
302
- Returns:
303
- Dictionary with epoch results
304
- """
305
- if not self.is_initialized:
306
- raise RuntimeError("NPLL model not initialized")
307
-
308
- max_iterations = max_em_iterations or self.config.em_iterations
309
- epoch_start_time = time.time()
310
-
311
- logger.info(f"Starting training epoch {self.training_state.epoch}")
312
-
313
- epoch_results = []
314
- converged = False
315
-
316
- for em_iter in range(max_iterations):
317
- iteration_result = self.run_single_em_iteration()
318
- epoch_results.append(iteration_result)
319
-
320
- if iteration_result['converged']:
321
- converged = True
322
- logger.info(f"Converged after {em_iter + 1} E-M iterations")
323
- break
324
-
325
- epoch_time = time.time() - epoch_start_time
326
- self.training_state.epoch += 1
327
-
328
- epoch_summary = {
329
- 'epoch': self.training_state.epoch - 1,
330
- 'em_iterations': len(epoch_results),
331
- 'converged': converged,
332
- 'final_elbo': epoch_results[-1]['elbo'] if epoch_results else float('-inf'),
333
- 'best_elbo_this_epoch': max(r['elbo'] for r in epoch_results) if epoch_results else float('-inf'),
334
- 'epoch_time': epoch_time,
335
- 'iteration_results': epoch_results
336
- }
337
-
338
- logger.info(f"Epoch {self.training_state.epoch - 1} completed: "
339
- f"ELBO={epoch_summary['final_elbo']:.6f}, "
340
- f"EM iterations={epoch_summary['em_iterations']}, "
341
- f"Time={epoch_time:.2f}s")
342
-
343
- return epoch_summary
344
-
345
- def _check_em_convergence(self) -> bool:
346
- """Check if E-M algorithm has converged with patience and relative tolerance"""
347
- if len(self.training_state.elbo_history) < 2:
348
- return False
349
- h = self.training_state.elbo_history
350
- rel = abs(h[-1] - h[-2]) / (abs(h[-2]) + 1e-8)
351
- elbo_ok = rel < getattr(self.config, 'elbo_rel_tol', self.config.convergence_threshold)
352
-
353
- weight_ok = True
354
- if len(self.training_state.rule_weight_history) >= 2:
355
- current_weights = torch.tensor(self.training_state.rule_weight_history[-1])
356
- prev_weights = torch.tensor(self.training_state.rule_weight_history[-2])
357
- weight_change = torch.norm(current_weights - prev_weights).item()
358
- weight_ok = weight_change < getattr(self.config, 'weight_abs_tol', self.config.convergence_threshold)
359
-
360
- if elbo_ok and weight_ok:
361
- hits = self.training_state.convergence_info.get('hits', 0) + 1
362
- self.training_state.convergence_info['hits'] = hits
363
- else:
364
- self.training_state.convergence_info['hits'] = 0
365
- patience = getattr(self.config, 'convergence_patience', 3)
366
- return self.training_state.convergence_info.get('hits', 0) >= patience
367
-
368
- def get_rule_confidences(self) -> Dict[str, float]:
369
- """Get learned confidence scores for all rules"""
370
- if not self.is_initialized or self.mln is None or self.mln.rule_weights is None:
371
- return {}
372
-
373
- confidences = {}
374
- for i, rule in enumerate(self.mln.logical_rules):
375
- # Convert rule weight to confidence using sigmoid
376
- weight = self.mln.rule_weights[i].item()
377
- confidence = torch.sigmoid(torch.tensor(weight)).item()
378
- confidences[rule.rule_id] = confidence
379
-
380
- return confidences
381
-
382
- def predict_unknown_facts(self, top_k: int = 10) -> List[Tuple[Triple, float]]:
383
- """
384
- Predict probabilities for unknown facts
385
-
386
- Args:
387
- top_k: Number of top predictions to return
388
-
389
- Returns:
390
- List of (triple, probability) tuples sorted by probability
391
- """
392
- if not self.is_initialized:
393
- raise RuntimeError("NPLL model not initialized")
394
-
395
- unknown_facts = list(self.knowledge_graph.unknown_facts)
396
- if not unknown_facts:
397
- return []
398
-
399
- # Get predictions for unknown facts
400
- results = self.forward(unknown_facts)
401
- probabilities = results['probabilities']
402
-
403
- # Create (triple, probability) pairs
404
- predictions = list(zip(unknown_facts, probabilities.tolist()))
405
-
406
- # Sort by probability (descending)
407
- predictions.sort(key=lambda x: x[1], reverse=True)
408
-
409
- return predictions[:top_k]
410
-
411
- def save_model(self, filepath: str):
412
- """Save complete NPLL model state with KG and rules serialization"""
413
- if not self.is_initialized:
414
- raise RuntimeError("Cannot save uninitialized model")
415
-
416
- rules = self.mln.logical_rules if self.mln else []
417
- payload = {
418
- 'schema_version': self.model_version,
419
- 'config': self.config.__dict__,
420
- 'creation_time': self.creation_time,
421
- 'training_state': self.training_state,
422
- 'mln_state': self.mln.state_dict() if self.mln else None,
423
- 'scoring_state': self.scoring_module.state_dict() if self.scoring_module else None,
424
- 'prob_state': self.probability_transform.state_dict() if hasattr(self.probability_transform, 'state_dict') else None,
425
- 'knowledge_graph': self.knowledge_graph.serialize() if self.knowledge_graph else None,
426
- 'rules': [rule.serialize() for rule in rules],
427
- }
428
- torch.save(payload, filepath)
429
- logger.info(f"NPLL model saved to {filepath}")
430
-
431
- def load_model(self, filepath: str):
432
- """
433
- Load complete NPLL model state, rebuilding KG and rules before weights.
434
-
435
- PRODUCTION FIX: Respects device configuration (cpu/cuda) instead of hardcoding cpu.
436
- """
437
- if not os.path.exists(filepath):
438
- raise FileNotFoundError(f"Model file not found: {filepath}")
439
-
440
- from .core import KnowledgeGraph, LogicalRule
441
-
442
- # Respect device config: load to configured device (cpu or cuda)
443
- # Note: weights_only=False required for custom objects (NPLLTrainingState, etc.)
444
- device = getattr(self, 'device', 'cpu')
445
- payload = torch.load(filepath, map_location=device, weights_only=False)
446
-
447
- # Rebuild config
448
- self.config = NPLLConfig(**payload['config']) if isinstance(payload.get('config'), dict) else payload['config']
449
- # Rebuild KG and rules
450
- kg = KnowledgeGraph.deserialize(payload['knowledge_graph']) if payload.get('knowledge_graph') else None
451
- rules = [LogicalRule.deserialize(x) for x in payload.get('rules', [])]
452
- # Recreate components
453
- self.__init__(self.config)
454
- if kg is not None:
455
- self.initialize(kg, rules)
456
- # Load weights
457
- if payload.get('mln_state') and self.mln:
458
- self.mln.load_state_dict(payload['mln_state'])
459
- if payload.get('scoring_state') and self.scoring_module:
460
- self.scoring_module.load_state_dict(payload['scoring_state'])
461
- if payload.get('prob_state') and hasattr(self.probability_transform, 'load_state_dict'):
462
- self.probability_transform.load_state_dict(payload['prob_state'])
463
- # Metadata
464
- self.model_version = payload.get('schema_version', '1.0')
465
- self.creation_time = payload.get('creation_time', time.time())
466
- self.training_state = payload.get('training_state', self.training_state)
467
- logger.info(f"NPLL model loaded from {filepath}")
468
-
469
- def save_to_buffer(self, buffer):
470
- """Save complete NPLL model state to a BytesIO buffer (for database storage)."""
471
- if not self.is_initialized:
472
- raise RuntimeError("Cannot save uninitialized model")
473
-
474
- rules = self.mln.logical_rules if self.mln else []
475
- payload = {
476
- 'schema_version': self.model_version,
477
- 'config': self.config.__dict__,
478
- 'creation_time': self.creation_time,
479
- 'training_state': self.training_state,
480
- 'mln_state': self.mln.state_dict() if self.mln else None,
481
- 'scoring_state': self.scoring_module.state_dict() if self.scoring_module else None,
482
- 'prob_state': self.probability_transform.state_dict() if hasattr(self.probability_transform, 'state_dict') else None,
483
- 'knowledge_graph': self.knowledge_graph.serialize() if self.knowledge_graph else None,
484
- 'rules': [rule.serialize() for rule in rules],
485
- }
486
- torch.save(payload, buffer)
487
- logger.info("NPLL model saved to buffer")
488
-
489
- def load_from_buffer(self, buffer):
490
- """Load complete NPLL model state from a BytesIO buffer (for database storage)."""
491
- from .core import KnowledgeGraph, LogicalRule
492
-
493
- # Respect device config
494
- device = getattr(self, 'device', 'cpu')
495
- buffer.seek(0)
496
- payload = torch.load(buffer, map_location=device, weights_only=False)
497
-
498
- # Rebuild config
499
- self.config = NPLLConfig(**payload['config']) if isinstance(payload.get('config'), dict) else payload['config']
500
- # Rebuild KG and rules
501
- kg = KnowledgeGraph.deserialize(payload['knowledge_graph']) if payload.get('knowledge_graph') else None
502
- rules = [LogicalRule.deserialize(x) for x in payload.get('rules', [])]
503
- # Recreate components
504
- self.__init__(self.config)
505
- if kg is not None:
506
- self.initialize(kg, rules)
507
- # Load weights
508
- if payload.get('mln_state') and self.mln:
509
- self.mln.load_state_dict(payload['mln_state'])
510
- if payload.get('scoring_state') and self.scoring_module:
511
- self.scoring_module.load_state_dict(payload['scoring_state'])
512
- if payload.get('prob_state') and hasattr(self.probability_transform, 'load_state_dict'):
513
- self.probability_transform.load_state_dict(payload['prob_state'])
514
- # Metadata
515
- self.model_version = payload.get('schema_version', '1.0')
516
- self.creation_time = payload.get('creation_time', time.time())
517
- self.training_state = payload.get('training_state', self.training_state)
518
- logger.info("NPLL model loaded from buffer")
519
-
520
- def _get_device(self) -> str:
521
- for p in self.parameters():
522
- return str(p.device)
523
- return 'cpu'
524
-
525
- def get_model_summary(self) -> Dict[str, Any]:
526
- """Get comprehensive model summary (device-safe)"""
527
- summary = {
528
- 'model_version': self.model_version,
529
- 'is_initialized': self.is_initialized,
530
- 'config': self.config.__dict__ if self.config else {},
531
- 'creation_time': self.creation_time,
532
- 'calibration_version': self.calibration_version,
533
- 'training_state': self.training_state.__dict__ if self.training_state else {},
534
- 'num_parameters': sum(p.numel() for p in self.parameters()),
535
- 'device': self._get_device(),
536
- }
537
- if self.is_initialized:
538
- summary.update({
539
- 'knowledge_graph_stats': self.knowledge_graph.get_statistics() if self.knowledge_graph else {},
540
- 'mln_stats': self.mln.get_rule_statistics() if self.mln else {},
541
- 'rule_confidences': self.get_rule_confidences(),
542
- })
543
- return summary
544
-
545
- # --- Utilities: checkpoints and calibration ---
546
- def restore_best(self) -> bool:
547
- """Restore the best-scoring MLN and scoring-module weights if available."""
548
- if not self.training_state or not getattr(self.training_state, 'best_weights', None):
549
- return False
550
- best = self.training_state.best_weights
551
- if 'mln' in best and self.mln:
552
- self.mln.load_state_dict(best['mln'])
553
- if 'scoring' in best and self.scoring_module:
554
- self.scoring_module.load_state_dict(best['scoring'])
555
- return True
556
-
557
- def calibrate_temperature_on_data(self,
558
- triples: List[Triple],
559
- labels: torch.Tensor,
560
- max_iter: int = 100,
561
- version: Optional[str] = None) -> float:
562
- """
563
- Calibrate the ProbabilityTransform temperature on a holdout set.
564
- Stores the learned temperature in the transform and records a calibration version.
565
- Returns the optimized temperature value.
566
- """
567
- if not self.is_initialized:
568
- raise RuntimeError("NPLL model not initialized")
569
- self.eval()
570
- with torch.no_grad():
571
- scores = self.scoring_module.forward(triples)
572
-
573
- # Extract group_ids for calibration if per_group is enabled
574
- group_ids = None
575
- if hasattr(self.scoring_module, 'embedding_manager') and getattr(self.probability_transform, 'per_group', False):
576
- emb_mgr = self.scoring_module.embedding_manager
577
- group_ids = emb_mgr.relation_group_ids_for_triples(triples, add_if_missing=False)
578
- if hasattr(self.probability_transform, 'ensure_num_groups'):
579
- self.probability_transform.ensure_num_groups(emb_mgr.relation_num_groups)
580
-
581
- optimized_temp = self.probability_transform.calibrate_temperature(scores, labels.float(), max_iter=max_iter, group_ids=group_ids)
582
- self.calibration_version = version or f"temp@{time.time():.0f}"
583
- return optimized_temp
584
-
585
- def __str__(self) -> str:
586
- if not self.is_initialized:
587
- return "NPLL Model (not initialized)"
588
-
589
- stats = self.knowledge_graph.get_statistics() if self.knowledge_graph else {}
590
- return (f"NPLL Model:\n"
591
- f" Entities: {stats.get('num_entities', 0)}\n"
592
- f" Relations: {stats.get('num_relations', 0)}\n"
593
- f" Known Facts: {stats.get('num_known_facts', 0)}\n"
594
- f" Unknown Facts: {stats.get('num_unknown_facts', 0)}\n"
595
- f" Rules: {len(self.mln.logical_rules) if self.mln else 0}\n"
596
- f" Training State: {self.training_state}")
597
-
598
-
599
- def create_npll_model(config: NPLLConfig) -> NPLLModel:
600
- """
601
- Factory function to create NPLL model
602
-
603
- Args:
604
- config: NPLL configuration
605
-
606
- Returns:
607
- Uninitialized NPLL model
608
- """
609
- return NPLLModel(config)
610
-
611
-
612
- def create_initialized_npll_model(knowledge_graph: KnowledgeGraph,
613
- logical_rules: List[LogicalRule],
614
- config: Optional[NPLLConfig] = None) -> NPLLModel:
615
- """
616
- Factory function to create and initialize NPLL model
617
-
618
- Args:
619
- knowledge_graph: Knowledge graph
620
- logical_rules: List of logical rules
621
- config: Optional configuration (uses default if not provided)
622
-
623
- Returns:
624
- Initialized NPLL model ready for training
625
- """
626
- if config is None:
627
- config = get_config("ArangoDB_Triples") # Default configuration
628
-
629
- model = NPLLModel(config)
630
- model.initialize(knowledge_graph, logical_rules)
631
-
1
+ """
2
+ Main NPLL Model - Neural Probabilistic Logic Learning
3
+ Integrates all components: MLN, Scoring Module, E-step, M-step
4
+ Exact implementation of the complete NPLL framework from the paper
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import List, Dict, Set, Tuple, Optional, Any, Union
10
+ import logging
11
+ from dataclasses import dataclass
12
+ import time
13
+ import os
14
+
15
+ from .core import (
16
+ KnowledgeGraph, Entity, Relation, Triple, LogicalRule, GroundRule,
17
+ MarkovLogicNetwork, create_mln_from_kg_and_rules
18
+ )
19
+ from .scoring import (
20
+ NPLLScoringModule, create_scoring_module,
21
+ ProbabilityTransform, create_probability_components
22
+ )
23
+ from .inference import (
24
+ EStepRunner, MStepRunner, ELBOComputer, EStepResult, MStepResult,
25
+ create_e_step_runner, create_m_step_runner
26
+ )
27
+ from .utils import NPLLConfig, get_config
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ @dataclass
33
+ class NPLLTrainingState:
34
+ """
35
+ Training state for NPLL model
36
+ Tracks progress through E-M iterations
37
+ """
38
+ epoch: int
39
+ em_iteration: int
40
+ elbo_history: List[float]
41
+ rule_weight_history: List[List[float]]
42
+ convergence_info: Dict[str, Any]
43
+ training_time: float
44
+ best_elbo: float
45
+ best_weights: Optional[torch.Tensor] = None
46
+
47
+ def __str__(self) -> str:
48
+ return (f"NPLL Training State:\n"
49
+ f" Epoch: {self.epoch}\n"
50
+ f" EM Iteration: {self.em_iteration}\n"
51
+ f" Current ELBO: {self.elbo_history[-1] if self.elbo_history else 'N/A'}\n"
52
+ f" Best ELBO: {self.best_elbo:.6f}\n"
53
+ f" Training Time: {self.training_time:.2f}s")
54
+
55
+
56
+ class NPLLModel(nn.Module):
57
+ """
58
+ Complete Neural Probabilistic Logic Learning Model
59
+
60
+ Integrates all paper components:
61
+ - Knowledge Graph representation (Section 3)
62
+ - Scoring Module with bilinear function (Section 4.1, Equation 7)
63
+ - Markov Logic Network (Sections 3-4, Equations 1-2)
64
+ - E-M Algorithm (Sections 4.2-4.3, Equations 8-14)
65
+ - ELBO optimization (Equations 3-5)
66
+ """
67
+
68
+ def __init__(self, config: NPLLConfig):
69
+ super().__init__()
70
+ self.config = config
71
+
72
+ # Core components (initialized when knowledge graph is provided)
73
+ self.knowledge_graph: Optional[KnowledgeGraph] = None
74
+ self.mln: Optional[MarkovLogicNetwork] = None
75
+ self.scoring_module: Optional[NPLLScoringModule] = None
76
+
77
+ # Inference components
78
+ self.e_step_runner: Optional[EStepRunner] = None
79
+ self.m_step_runner: Optional[MStepRunner] = None
80
+ self.elbo_computer: Optional[ELBOComputer] = None
81
+
82
+ # Probability transformation
83
+ self.probability_transform: Optional[ProbabilityTransform] = None
84
+
85
+ # Training state
86
+ self.training_state: Optional[NPLLTrainingState] = None
87
+ self.is_initialized = False
88
+
89
+ # Model metadata
90
+ self.model_version = "1.0"
91
+ self.creation_time = time.time()
92
+ self.calibration_version = None
93
+
94
    def initialize(self,
                   knowledge_graph: KnowledgeGraph,
                   logical_rules: List[LogicalRule]):
        """
        Initialize NPLL model with knowledge graph and rules.

        Builds every dependent component in order: MLN, scoring module,
        E/M-step runners, ELBO computer, probability transform, and a fresh
        training state. Sets `is_initialized` last, so partially-constructed
        models are never reported as ready.

        Args:
            knowledge_graph: Knowledge graph K = (E, L, F)
            logical_rules: List of logical rules R
        """
        logger.info("Initializing NPLL model...")

        # Store knowledge graph
        self.knowledge_graph = knowledge_graph

        # Create and initialize MLN
        self.mln = create_mln_from_kg_and_rules(knowledge_graph, logical_rules, self.config)

        # Create scoring module
        self.scoring_module = create_scoring_module(self.config, knowledge_graph)

        # Create inference components
        self.e_step_runner = create_e_step_runner(self.config)
        self.m_step_runner = create_m_step_runner(self.config)

        # Create ELBO computer
        # NOTE(review): imported here rather than at module top — presumably to
        # avoid a circular import with .inference; confirm before hoisting.
        from .inference import create_elbo_computer
        self.elbo_computer = create_elbo_computer(self.config)

        # Create probability transformation (enable per-relation groups if available)
        # Per-relation temperature groups are enabled only when the scoring
        # module's embedding manager exposes `relation_num_groups`.
        num_rel = None
        if self.scoring_module is not None and hasattr(self.scoring_module, 'embedding_manager'):
            emb_mgr = self.scoring_module.embedding_manager
            if hasattr(emb_mgr, 'relation_num_groups'):
                num_rel = emb_mgr.relation_num_groups
        per_relation = num_rel is not None
        prob_transform, _ = create_probability_components(
            self.config.temperature,
            per_relation=per_relation,
            num_relations=(num_rel or 1)
        )
        self.probability_transform = prob_transform

        # Initialize training state (fresh history; best ELBO starts at -inf
        # so the first iteration always becomes the best snapshot)
        self.training_state = NPLLTrainingState(
            epoch=0,
            em_iteration=0,
            elbo_history=[],
            rule_weight_history=[],
            convergence_info={},
            training_time=0.0,
            best_elbo=float('-inf')
        )

        self.is_initialized = True

        logger.info(f"NPLL model initialized with {len(logical_rules)} rules, "
                    f"{len(knowledge_graph.entities)} entities, "
                    f"{len(knowledge_graph.relations)} relations")
153
+
154
+ def forward(self, triples: List[Triple]) -> Dict[str, torch.Tensor]:
155
+ """
156
+ Forward pass through NPLL model
157
+
158
+ Args:
159
+ triples: List of triples to score
160
+
161
+ Returns:
162
+ Dictionary with scores and probabilities
163
+ """
164
+ if not self.is_initialized:
165
+ raise RuntimeError("NPLL model not initialized. Call initialize() first.")
166
+
167
+ # Get raw scores from scoring module
168
+ raw_scores = self.scoring_module.forward(triples)
169
+
170
+ # Group IDs per relation (no vocab mutation)
171
+ group_ids = None
172
+ if hasattr(self.scoring_module, 'embedding_manager') and getattr(self.probability_transform, 'per_group', False):
173
+ emb_mgr = self.scoring_module.embedding_manager
174
+ group_ids = emb_mgr.relation_group_ids_for_triples(triples, add_if_missing=False)
175
+ # Ensure transform capacity if table grew
176
+ if hasattr(self.probability_transform, 'ensure_num_groups'):
177
+ self.probability_transform.ensure_num_groups(emb_mgr.relation_num_groups)
178
+
179
+ # Transform to probabilities
180
+ probabilities = self.probability_transform(raw_scores, apply_temperature=True, group_ids=group_ids)
181
+
182
+ # Get log probabilities
183
+ log_probabilities = self.probability_transform.get_log_probabilities(raw_scores, apply_temperature=True, group_ids=group_ids)
184
+
185
+ return {
186
+ 'raw_scores': raw_scores,
187
+ 'probabilities': probabilities,
188
+ 'log_probabilities': log_probabilities
189
+ }
190
+
191
+ def predict_single_triple(self, head: str, relation: str, tail: str, transient: bool = True) -> Dict[str, float]:
192
+ """
193
+ Predict probability for a single triple
194
+
195
+ Args:
196
+ head: Head entity name
197
+ relation: Relation name
198
+ tail: Tail entity name
199
+ transient: If True, do not mutate the underlying knowledge graph
200
+
201
+ Returns:
202
+ Dictionary with prediction results
203
+ """
204
+ if not self.is_initialized:
205
+ raise RuntimeError("NPLL model not initialized")
206
+
207
+ # Create triple object without mutating KG by default
208
+ if transient:
209
+ head_entity = Entity(head)
210
+ relation_obj = Relation(relation)
211
+ tail_entity = Entity(tail)
212
+ else:
213
+ head_entity = self.knowledge_graph.get_entity(head) or self.knowledge_graph.add_entity(head)
214
+ relation_obj = self.knowledge_graph.get_relation(relation) or self.knowledge_graph.add_relation(relation)
215
+ tail_entity = self.knowledge_graph.get_entity(tail) or self.knowledge_graph.add_entity(tail)
216
+
217
+ triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
218
+
219
+ # Get predictions
220
+ self.eval()
221
+ with torch.no_grad():
222
+ results = self.forward([triple])
223
+
224
+ return {
225
+ 'probability': results['probabilities'][0].item(),
226
+ 'log_probability': results['log_probabilities'][0].item(),
227
+ 'raw_score': results['raw_scores'][0].item()
228
+ }
229
+
230
    def run_single_em_iteration(self) -> Dict[str, Any]:
        """
        Run a single E-M iteration.

        Order matters: the E-step optimizes Q(U) against the current rule
        weights, then the M-step updates the weights ω using the E-step
        result. Training state (ELBO history, weight history, best-weights
        snapshot, iteration counter, timing) is mutated in place.

        Returns:
            Dictionary with iteration results (E/M step results, ELBO,
            timing, and convergence flags)

        Raises:
            RuntimeError: If the model has not been initialized yet.
        """
        if not self.is_initialized:
            raise RuntimeError("NPLL model not initialized")

        iteration_start_time = time.time()

        logger.debug(f"Starting E-M iteration {self.training_state.em_iteration}")

        # E-step: Optimize Q(U)
        logger.debug("Running E-step...")
        e_step_result = self.e_step_runner.run_e_step(
            self.mln, self.scoring_module, self.knowledge_graph
        )

        # M-step: Optimize rule weights ω (consumes the E-step posterior)
        logger.debug("Running M-step...")
        m_step_result = self.m_step_runner.run_m_step(self.mln, e_step_result)

        # Update training state
        current_elbo = e_step_result.elbo_value.item()
        self.training_state.elbo_history.append(current_elbo)

        if self.mln.rule_weights is not None:
            current_weights = self.mln.rule_weights.data.tolist()
            self.training_state.rule_weight_history.append(current_weights)

        # Track best snapshot (MLN + scoring); tensors are cloned so later
        # training steps cannot overwrite the saved snapshot in place.
        if current_elbo > self.training_state.best_elbo:
            self.training_state.best_elbo = current_elbo
            self.training_state.best_weights = {
                'mln': {k: v.clone() for k, v in self.mln.state_dict().items()},
                'scoring': {k: v.clone() for k, v in self.scoring_module.state_dict().items()} if self.scoring_module else {},
            }

        self.training_state.em_iteration += 1
        iteration_time = time.time() - iteration_start_time
        self.training_state.training_time += iteration_time

        # Check convergence (uses the histories appended above)
        converged = self._check_em_convergence()

        # em_iteration was already incremented, so report the index of the
        # iteration that just ran.
        iteration_result = {
            'em_iteration': self.training_state.em_iteration - 1,
            'e_step_result': e_step_result,
            'm_step_result': m_step_result,
            'elbo': current_elbo,
            'iteration_time': iteration_time,
            'converged': converged,
            'convergence_info': {
                'e_step_converged': e_step_result.convergence_info.get('converged', False),
                'm_step_converged': m_step_result.convergence_info.get('converged', False)
            }
        }

        logger.debug(f"E-M iteration completed: ELBO={current_elbo:.6f}, "
                     f"Time={iteration_time:.2f}s, Converged={converged}")

        return iteration_result
294
+
295
+ def train_epoch(self, max_em_iterations: Optional[int] = None) -> Dict[str, Any]:
296
+ """
297
+ Train for one epoch (multiple E-M iterations until convergence)
298
+
299
+ Args:
300
+ max_em_iterations: Maximum E-M iterations per epoch
301
+
302
+ Returns:
303
+ Dictionary with epoch results
304
+ """
305
+ if not self.is_initialized:
306
+ raise RuntimeError("NPLL model not initialized")
307
+
308
+ max_iterations = max_em_iterations or self.config.em_iterations
309
+ epoch_start_time = time.time()
310
+
311
+ logger.info(f"Starting training epoch {self.training_state.epoch}")
312
+
313
+ epoch_results = []
314
+ converged = False
315
+
316
+ for em_iter in range(max_iterations):
317
+ iteration_result = self.run_single_em_iteration()
318
+ epoch_results.append(iteration_result)
319
+
320
+ if iteration_result['converged']:
321
+ converged = True
322
+ logger.info(f"Converged after {em_iter + 1} E-M iterations")
323
+ break
324
+
325
+ epoch_time = time.time() - epoch_start_time
326
+ self.training_state.epoch += 1
327
+
328
+ epoch_summary = {
329
+ 'epoch': self.training_state.epoch - 1,
330
+ 'em_iterations': len(epoch_results),
331
+ 'converged': converged,
332
+ 'final_elbo': epoch_results[-1]['elbo'] if epoch_results else float('-inf'),
333
+ 'best_elbo_this_epoch': max(r['elbo'] for r in epoch_results) if epoch_results else float('-inf'),
334
+ 'epoch_time': epoch_time,
335
+ 'iteration_results': epoch_results
336
+ }
337
+
338
+ logger.info(f"Epoch {self.training_state.epoch - 1} completed: "
339
+ f"ELBO={epoch_summary['final_elbo']:.6f}, "
340
+ f"EM iterations={epoch_summary['em_iterations']}, "
341
+ f"Time={epoch_time:.2f}s")
342
+
343
+ return epoch_summary
344
+
345
+ def _check_em_convergence(self) -> bool:
346
+ """Check if E-M algorithm has converged with patience and relative tolerance"""
347
+ if len(self.training_state.elbo_history) < 2:
348
+ return False
349
+ h = self.training_state.elbo_history
350
+ rel = abs(h[-1] - h[-2]) / (abs(h[-2]) + 1e-8)
351
+ elbo_ok = rel < getattr(self.config, 'elbo_rel_tol', self.config.convergence_threshold)
352
+
353
+ weight_ok = True
354
+ if len(self.training_state.rule_weight_history) >= 2:
355
+ current_weights = torch.tensor(self.training_state.rule_weight_history[-1])
356
+ prev_weights = torch.tensor(self.training_state.rule_weight_history[-2])
357
+ weight_change = torch.norm(current_weights - prev_weights).item()
358
+ weight_ok = weight_change < getattr(self.config, 'weight_abs_tol', self.config.convergence_threshold)
359
+
360
+ if elbo_ok and weight_ok:
361
+ hits = self.training_state.convergence_info.get('hits', 0) + 1
362
+ self.training_state.convergence_info['hits'] = hits
363
+ else:
364
+ self.training_state.convergence_info['hits'] = 0
365
+ patience = getattr(self.config, 'convergence_patience', 3)
366
+ return self.training_state.convergence_info.get('hits', 0) >= patience
367
+
368
+ def get_rule_confidences(self) -> Dict[str, float]:
369
+ """Get learned confidence scores for all rules"""
370
+ if not self.is_initialized or self.mln is None or self.mln.rule_weights is None:
371
+ return {}
372
+
373
+ confidences = {}
374
+ for i, rule in enumerate(self.mln.logical_rules):
375
+ # Convert rule weight to confidence using sigmoid
376
+ weight = self.mln.rule_weights[i].item()
377
+ confidence = torch.sigmoid(torch.tensor(weight)).item()
378
+ confidences[rule.rule_id] = confidence
379
+
380
+ return confidences
381
+
382
+ def predict_unknown_facts(self, top_k: int = 10) -> List[Tuple[Triple, float]]:
383
+ """
384
+ Predict probabilities for unknown facts
385
+
386
+ Args:
387
+ top_k: Number of top predictions to return
388
+
389
+ Returns:
390
+ List of (triple, probability) tuples sorted by probability
391
+ """
392
+ if not self.is_initialized:
393
+ raise RuntimeError("NPLL model not initialized")
394
+
395
+ unknown_facts = list(self.knowledge_graph.unknown_facts)
396
+ if not unknown_facts:
397
+ return []
398
+
399
+ # Get predictions for unknown facts
400
+ results = self.forward(unknown_facts)
401
+ probabilities = results['probabilities']
402
+
403
+ # Create (triple, probability) pairs
404
+ predictions = list(zip(unknown_facts, probabilities.tolist()))
405
+
406
+ # Sort by probability (descending)
407
+ predictions.sort(key=lambda x: x[1], reverse=True)
408
+
409
+ return predictions[:top_k]
410
+
411
+ def save_model(self, filepath: str):
412
+ """Save complete NPLL model state with KG and rules serialization"""
413
+ if not self.is_initialized:
414
+ raise RuntimeError("Cannot save uninitialized model")
415
+
416
+ rules = self.mln.logical_rules if self.mln else []
417
+ payload = {
418
+ 'schema_version': self.model_version,
419
+ 'config': self.config.__dict__,
420
+ 'creation_time': self.creation_time,
421
+ 'training_state': self.training_state,
422
+ 'mln_state': self.mln.state_dict() if self.mln else None,
423
+ 'scoring_state': self.scoring_module.state_dict() if self.scoring_module else None,
424
+ 'prob_state': self.probability_transform.state_dict() if hasattr(self.probability_transform, 'state_dict') else None,
425
+ 'knowledge_graph': self.knowledge_graph.serialize() if self.knowledge_graph else None,
426
+ 'rules': [rule.serialize() for rule in rules],
427
+ }
428
+ torch.save(payload, filepath)
429
+ logger.info(f"NPLL model saved to {filepath}")
430
+
431
    def load_model(self, filepath: str):
        """
        Load complete NPLL model state, rebuilding KG and rules before weights.

        Order is load payload -> rebuild config -> re-run __init__ -> rebuild
        KG/rules via initialize() -> restore weights -> restore metadata.

        PRODUCTION FIX: Respects device configuration (cpu/cuda) instead of hardcoding cpu.

        Raises:
            FileNotFoundError: If `filepath` does not exist.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Model file not found: {filepath}")

        from .core import KnowledgeGraph, LogicalRule

        # Respect device config: load to configured device (cpu or cuda)
        # Note: weights_only=False required for custom objects (NPLLTrainingState, etc.)
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects;
        # only load checkpoint files from trusted sources.
        device = getattr(self, 'device', 'cpu')
        payload = torch.load(filepath, map_location=device, weights_only=False)

        # Rebuild config (payloads may store either the dict or the object)
        self.config = NPLLConfig(**payload['config']) if isinstance(payload.get('config'), dict) else payload['config']
        # Rebuild KG and rules
        kg = KnowledgeGraph.deserialize(payload['knowledge_graph']) if payload.get('knowledge_graph') else None
        rules = [LogicalRule.deserialize(x) for x in payload.get('rules', [])]
        # Recreate components: re-running __init__ resets this instance to a
        # clean state before initialize() rebuilds the components.
        self.__init__(self.config)
        if kg is not None:
            self.initialize(kg, rules)
        # Load weights (only into components that initialize() created)
        if payload.get('mln_state') and self.mln:
            self.mln.load_state_dict(payload['mln_state'])
        if payload.get('scoring_state') and self.scoring_module:
            self.scoring_module.load_state_dict(payload['scoring_state'])
        if payload.get('prob_state') and hasattr(self.probability_transform, 'load_state_dict'):
            self.probability_transform.load_state_dict(payload['prob_state'])
        # Metadata (falls back to the fresh values set by __init__)
        self.model_version = payload.get('schema_version', '1.0')
        self.creation_time = payload.get('creation_time', time.time())
        self.training_state = payload.get('training_state', self.training_state)
        logger.info(f"NPLL model loaded from {filepath}")
468
+
469
+ def save_to_buffer(self, buffer):
470
+ """Save complete NPLL model state to a BytesIO buffer (for database storage)."""
471
+ if not self.is_initialized:
472
+ raise RuntimeError("Cannot save uninitialized model")
473
+
474
+ rules = self.mln.logical_rules if self.mln else []
475
+ payload = {
476
+ 'schema_version': self.model_version,
477
+ 'config': self.config.__dict__,
478
+ 'creation_time': self.creation_time,
479
+ 'training_state': self.training_state,
480
+ 'mln_state': self.mln.state_dict() if self.mln else None,
481
+ 'scoring_state': self.scoring_module.state_dict() if self.scoring_module else None,
482
+ 'prob_state': self.probability_transform.state_dict() if hasattr(self.probability_transform, 'state_dict') else None,
483
+ 'knowledge_graph': self.knowledge_graph.serialize() if self.knowledge_graph else None,
484
+ 'rules': [rule.serialize() for rule in rules],
485
+ }
486
+ torch.save(payload, buffer)
487
+ logger.info("NPLL model saved to buffer")
488
+
489
    def load_from_buffer(self, buffer):
        """
        Load complete NPLL model state from a BytesIO buffer (for database
        storage). Mirrors load_model: rebuild config -> re-run __init__ ->
        rebuild KG/rules -> restore weights -> restore metadata.
        """
        from .core import KnowledgeGraph, LogicalRule

        # Respect device config
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects;
        # only load buffers originating from trusted sources.
        device = getattr(self, 'device', 'cpu')
        buffer.seek(0)
        payload = torch.load(buffer, map_location=device, weights_only=False)

        # Rebuild config (payloads may store either the dict or the object)
        self.config = NPLLConfig(**payload['config']) if isinstance(payload.get('config'), dict) else payload['config']
        # Rebuild KG and rules
        kg = KnowledgeGraph.deserialize(payload['knowledge_graph']) if payload.get('knowledge_graph') else None
        rules = [LogicalRule.deserialize(x) for x in payload.get('rules', [])]
        # Recreate components: reset this instance before re-initializing
        self.__init__(self.config)
        if kg is not None:
            self.initialize(kg, rules)
        # Load weights (only into components that initialize() created)
        if payload.get('mln_state') and self.mln:
            self.mln.load_state_dict(payload['mln_state'])
        if payload.get('scoring_state') and self.scoring_module:
            self.scoring_module.load_state_dict(payload['scoring_state'])
        if payload.get('prob_state') and hasattr(self.probability_transform, 'load_state_dict'):
            self.probability_transform.load_state_dict(payload['prob_state'])
        # Metadata (falls back to the fresh values set by __init__)
        self.model_version = payload.get('schema_version', '1.0')
        self.creation_time = payload.get('creation_time', time.time())
        self.training_state = payload.get('training_state', self.training_state)
        logger.info("NPLL model loaded from buffer")
519
+
520
+ def _get_device(self) -> str:
521
+ for p in self.parameters():
522
+ return str(p.device)
523
+ return 'cpu'
524
+
525
+ def get_model_summary(self) -> Dict[str, Any]:
526
+ """Get comprehensive model summary (device-safe)"""
527
+ summary = {
528
+ 'model_version': self.model_version,
529
+ 'is_initialized': self.is_initialized,
530
+ 'config': self.config.__dict__ if self.config else {},
531
+ 'creation_time': self.creation_time,
532
+ 'calibration_version': self.calibration_version,
533
+ 'training_state': self.training_state.__dict__ if self.training_state else {},
534
+ 'num_parameters': sum(p.numel() for p in self.parameters()),
535
+ 'device': self._get_device(),
536
+ }
537
+ if self.is_initialized:
538
+ summary.update({
539
+ 'knowledge_graph_stats': self.knowledge_graph.get_statistics() if self.knowledge_graph else {},
540
+ 'mln_stats': self.mln.get_rule_statistics() if self.mln else {},
541
+ 'rule_confidences': self.get_rule_confidences(),
542
+ })
543
+ return summary
544
+
545
+ # --- Utilities: checkpoints and calibration ---
546
+ def restore_best(self) -> bool:
547
+ """Restore the best-scoring MLN and scoring-module weights if available."""
548
+ if not self.training_state or not getattr(self.training_state, 'best_weights', None):
549
+ return False
550
+ best = self.training_state.best_weights
551
+ if 'mln' in best and self.mln:
552
+ self.mln.load_state_dict(best['mln'])
553
+ if 'scoring' in best and self.scoring_module:
554
+ self.scoring_module.load_state_dict(best['scoring'])
555
+ return True
556
+
557
    def calibrate_temperature_on_data(self,
                                      triples: List[Triple],
                                      labels: torch.Tensor,
                                      max_iter: int = 100,
                                      version: Optional[str] = None) -> float:
        """
        Calibrate the ProbabilityTransform temperature on a holdout set.

        Scores are computed under no_grad (frozen scorer); the transform's own
        calibration runs outside that context so its temperature can be
        optimized. Stores the learned temperature in the transform and records
        a calibration version. Returns the optimized temperature value.

        Args:
            triples: Holdout triples to score
            labels: Target labels for the holdout triples (cast to float)
            max_iter: Maximum optimizer iterations for calibration
            version: Optional calibration-version tag; defaults to a
                timestamp-based tag

        Raises:
            RuntimeError: If the model has not been initialized yet.
        """
        if not self.is_initialized:
            raise RuntimeError("NPLL model not initialized")
        self.eval()
        with torch.no_grad():
            scores = self.scoring_module.forward(triples)

        # Extract group_ids for calibration if per_group is enabled
        group_ids = None
        if hasattr(self.scoring_module, 'embedding_manager') and getattr(self.probability_transform, 'per_group', False):
            emb_mgr = self.scoring_module.embedding_manager
            group_ids = emb_mgr.relation_group_ids_for_triples(triples, add_if_missing=False)
            if hasattr(self.probability_transform, 'ensure_num_groups'):
                self.probability_transform.ensure_num_groups(emb_mgr.relation_num_groups)

        optimized_temp = self.probability_transform.calibrate_temperature(scores, labels.float(), max_iter=max_iter, group_ids=group_ids)
        self.calibration_version = version or f"temp@{time.time():.0f}"
        return optimized_temp
584
+
585
+ def __str__(self) -> str:
586
+ if not self.is_initialized:
587
+ return "NPLL Model (not initialized)"
588
+
589
+ stats = self.knowledge_graph.get_statistics() if self.knowledge_graph else {}
590
+ return (f"NPLL Model:\n"
591
+ f" Entities: {stats.get('num_entities', 0)}\n"
592
+ f" Relations: {stats.get('num_relations', 0)}\n"
593
+ f" Known Facts: {stats.get('num_known_facts', 0)}\n"
594
+ f" Unknown Facts: {stats.get('num_unknown_facts', 0)}\n"
595
+ f" Rules: {len(self.mln.logical_rules) if self.mln else 0}\n"
596
+ f" Training State: {self.training_state}")
597
+
598
+
599
def create_npll_model(config: NPLLConfig) -> NPLLModel:
    """
    Factory function to create an NPLL model.

    Args:
        config: NPLL configuration

    Returns:
        Uninitialized NPLL model; call `initialize()` before use.
    """
    model = NPLLModel(config)
    return model
610
+
611
+
612
def create_initialized_npll_model(knowledge_graph: KnowledgeGraph,
                                  logical_rules: List[LogicalRule],
                                  config: Optional[NPLLConfig] = None) -> NPLLModel:
    """
    Factory function to create and initialize an NPLL model in one step.

    Args:
        knowledge_graph: Knowledge graph
        logical_rules: List of logical rules
        config: Optional configuration; when omitted, the default
            configuration is used

    Returns:
        Initialized NPLL model ready for training
    """
    effective_config = config if config is not None else get_config("ArangoDB_Triples")
    model = NPLLModel(effective_config)
    model.initialize(knowledge_graph, logical_rules)
    return model