odin_engine-0.1.0-py3-none-any.whl → odin_engine-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/inference/elbo.py CHANGED
@@ -1,435 +1,435 @@
+"""
+ELBO (Evidence Lower Bound) computation for NPLL
+Exact implementation of Equations 3-5 from the paper
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Dict, Tuple, Optional, Any
+import logging
+from dataclasses import dataclass
+
+from ..core import Triple, LogicalRule, GroundRule, KnowledgeGraph
+from ..core.mln import MarkovLogicNetwork
+from ..utils.config import NPLLConfig
+from ..utils.math_utils import (
+    safe_log, bernoulli_log_prob, kl_divergence_bernoulli,
+    log_sum_exp, compute_elbo_loss
+)
+from ..utils.batch_utils import GroundRuleBatch, GroundRuleSampler
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ELBOComponents:
+    """
+    Components of ELBO computation for analysis and debugging
+    """
+    joint_term: torch.Tensor       # Σ_U Q(U) log P(F,U|ω)
+    entropy_term: torch.Tensor     # -Σ_U Q(U) log Q(U)
+    elbo: torch.Tensor             # Total ELBO
+    kl_divergence: torch.Tensor    # KL(Q||P) if available
+    num_samples: int               # Number of samples used
+
+    def __str__(self) -> str:
+        return (f"ELBO Components:\n"
+                f" Joint term: {self.joint_term.item():.6f}\n"
+                f" Entropy term: {self.entropy_term.item():.6f}\n"
+                f" Total ELBO: {self.elbo.item():.6f}\n"
+                f" Samples: {self.num_samples}")
+
+
+class ELBOComputer(nn.Module):
+    """
+    Computes Evidence Lower Bound (ELBO) for NPLL training
+
+    Paper Section 4: The optimization objective becomes maximizing the ELBO value
+
+    Implements exact computation of Equations 3-5:
+    - Equation 3: log P(F|ω) = log[P(F,U|ω)/Q(U)] - log[P(U|F,ω)/Q(U)]
+    - Equation 4: log P(F|ω) = ELBO + KL(q||p)
+    - Equation 5: ELBO = Σ_U Q(U) log P(F,U|ω) - Σ_U Q(U) log Q(U)
+    """
+
+    def __init__(self, config: NPLLConfig):
+        super().__init__()
+        self.config = config
+        self.elbo_weight = config.elbo_weight
+        self.kl_weight = config.kl_weight
+
+    def compute_elbo(self,
+                     mln: MarkovLogicNetwork,
+                     known_facts: List[Triple],
+                     unknown_facts: List[Triple],
+                     approximate_posterior_probs: torch.Tensor,
+                     ground_rule_batches: Optional[List[GroundRuleBatch]] = None) -> ELBOComponents:
+        """
+        Compute ELBO following Equation 5
+        """
+        if len(unknown_facts) == 0:
+            # No unknown facts, ELBO is just log P(F|ω)
+            log_known_prob = self._compute_known_facts_probability(mln, known_facts)
+            zero = self._scalar_like(0.0, log_known_prob)
+            return ELBOComponents(
+                joint_term=log_known_prob,
+                entropy_term=zero,
+                elbo=log_known_prob,
+                kl_divergence=zero,
+                num_samples=0
+            )
+
+        # Sample possible worlds for ELBO computation
+        sampled_worlds = self._sample_possible_worlds(
+            known_facts, unknown_facts, approximate_posterior_probs
+        )
+
+        if not sampled_worlds:
+            logger.warning("No sampled worlds for ELBO computation")
+            zero = self._scalar_like(0.0, approximate_posterior_probs)
+            return ELBOComponents(
+                joint_term=zero,
+                entropy_term=zero,
+                elbo=zero,
+                kl_divergence=zero,
+                num_samples=0
+            )
+
+        # Compute joint term: E_Q[log P(F,U|ω)] via MC average (do NOT weight by Q again)
+        joint_term = self._compute_joint_term(mln, sampled_worlds, propagate_grads=False)
+
+        # Compute entropy term: -Σ_U Q(U) log Q(U)
+        entropy_term = self._compute_entropy_term(approximate_posterior_probs)
+
+        # Total ELBO
+        elbo = joint_term + entropy_term
+
+        # Optional: compute KL divergence if true posterior is available
+        kl_div = torch.tensor(0.0)
+        return ELBOComponents(
+            joint_term=joint_term,
+            entropy_term=entropy_term,
+            elbo=elbo,
+            kl_divergence=kl_div,
+            num_samples=len(sampled_worlds)
+        )
+
+    def _compute_joint_term(self,
+                            mln: MarkovLogicNetwork,
+                            sampled_worlds: List[Dict[Triple, bool]],
+                            propagate_grads: bool = False) -> torch.Tensor:
+        """
+        Monte-Carlo estimate of E_Q[log P(F,U|ω)] as a simple average over worlds.
+        If propagate_grads=True, allow gradients to flow to MLN weights.
+        """
+        if not sampled_worlds:
+            ref = next(mln.parameters()).detach() if any(mln.parameters()) else None
+            return self._scalar_like(0.0, ref)
+
+        log_joint = []
+        for world_assignment in sampled_worlds:
+            logp = mln.compute_joint_probability(world_assignment, detach_weights=not propagate_grads)
+            log_joint.append(logp)
+        log_joint_tensor = torch.stack(log_joint)
+        return log_joint_tensor.mean()
+
+    def _compute_entropy_term(self, q_probs: torch.Tensor) -> torch.Tensor:
+        """
+        Compute entropy term: -Σ_U Q(U) log Q(U)
+
+        This is the second term in Equation 5
+        """
+        # Clamp probabilities to avoid log(0)
+        q_probs_clamped = torch.clamp(q_probs, min=1e-8, max=1.0 - 1e-8)
+
+        # Bernoulli entropy: -[p*log(p) + (1-p)*log(1-p)]
+        entropy = -(q_probs_clamped * safe_log(q_probs_clamped) +
+                    (1 - q_probs_clamped) * safe_log(1 - q_probs_clamped))
+
+        # Sum over all unknown facts
+        total_entropy = torch.sum(entropy)
+        return total_entropy
+
+    def _sample_possible_worlds(self,
+                                known_facts: List[Triple],
+                                unknown_facts: List[Triple],
+                                q_probs: torch.Tensor,
+                                num_samples: int = None) -> List[Dict[Triple, bool]]:
+        """Sample worlds over unknown facts only, preserving order; vectorized draws."""
+        M = len(unknown_facts)
+        if M == 0:
+            return []
+        num_samples = num_samples or min(100, 2 ** min(M, 10))
+        device = q_probs.device
+        probs = q_probs.unsqueeze(0).expand(num_samples, -1)
+        samples = torch.bernoulli(probs).bool()
+        worlds: List[Dict[Triple, bool]] = []
+        for s in samples:
+            w = {f: True for f in known_facts}
+            for i, fact in enumerate(unknown_facts):
+                w[fact] = bool(s[i].item())
+            worlds.append(w)
+        return worlds
+
+    def _compute_world_q_probability(self,
+                                     world_assignment: Dict[Triple, bool],
+                                     q_probs: torch.Tensor,
+                                     unknown_facts: Optional[List[Triple]] = None) -> torch.Tensor:
+        """
+        Compute Q(U) probability for a specific world assignment
+
+        Q(U) = ∏_{uk∈U} p_k^{uk} (1-p_k)^{1-uk}
+        """
+        ref = q_probs
+        log_prob = self._scalar_like(0.0, ref)
+        facts = unknown_facts if unknown_facts is not None else [f for f in world_assignment.keys() if f not in []]
+        for i, fact in enumerate(facts):
+            if i >= len(q_probs):
+                break
+            prob = torch.clamp(q_probs[i], min=1e-8, max=1.0 - 1e-8)
+            truth_value = world_assignment[fact]
+            log_prob = log_prob + (safe_log(prob) if truth_value else safe_log(1 - prob))
+        return torch.exp(log_prob)
+
+    def _compute_known_facts_probability(self,
+                                         mln: MarkovLogicNetwork,
+                                         known_facts: List[Triple]) -> torch.Tensor:
+        """Compute log P(F|ω) when no unknown facts"""
+        if not known_facts:
+            return torch.tensor(0.0)
+
+        # Create assignment with all known facts as true
+        known_assignment = {fact: True for fact in known_facts}
+
+        # Compute joint probability
+        log_prob = mln.compute_joint_probability(known_assignment)
+
+        return log_prob
+
+    def compute_elbo_gradient(self,
+                              mln: MarkovLogicNetwork,
+                              known_facts: List[Triple],
+                              unknown_facts: List[Triple],
+                              q_probs: torch.Tensor) -> torch.Tensor:
+        """Compute ∂ELBO/∂ω allowing gradients to MLN weights."""
+        if mln.rule_weights is not None:
+            mln.rule_weights.requires_grad_(True)
+            if mln.rule_weights.grad is not None:
+                mln.rule_weights.grad.zero_()
+
+        sampled_worlds = self._sample_possible_worlds(known_facts, unknown_facts, q_probs)
+        joint_term = self._compute_joint_term(mln, sampled_worlds, propagate_grads=True)
+        entropy_term = self._compute_entropy_term(q_probs)
+        elbo = joint_term + entropy_term
+        loss = -elbo
+        loss.backward()
+
+        if mln.rule_weights is not None and mln.rule_weights.grad is not None:
+            gradients = mln.rule_weights.grad.clone()
+        else:
+            ref = next(mln.parameters()).detach() if any(mln.parameters()) else None
+            zeros = torch.zeros(len(mln.logical_rules)) if ref is None else torch.zeros(len(mln.logical_rules), device=ref.device, dtype=ref.dtype)
+            gradients = zeros
+        return gradients
+
+
+class ELBOLoss(nn.Module):
+    """
+    ELBO-based loss function for NPLL training
+
+    Implements the loss function that maximizes ELBO (minimizes negative ELBO)
+    """
+
+    def __init__(self, config: NPLLConfig):
+        super().__init__()
+        self.config = config
+        self.elbo_computer = ELBOComputer(config)
+
+        # Loss weights
+        self.elbo_weight = config.elbo_weight
+        self.regularization_weight = 0.01
+
+    def forward(self,
+                mln: MarkovLogicNetwork,
+                known_facts: List[Triple],
+                unknown_facts: List[Triple],
+                predicted_probs: torch.Tensor,
+                target_probs: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+        """
+        Compute ELBO loss
+
+        """
+        # Compute ELBO components
+        elbo_components = self.elbo_computer.compute_elbo(
+            mln, known_facts, unknown_facts, predicted_probs
+        )
+
+        # Primary loss: negative ELBO (we want to maximize ELBO)
+        elbo_loss = -elbo_components.elbo
+
+        # Optional supervised loss if targets provided
+        supervised_loss = torch.tensor(0.0)
+        if target_probs is not None:
+            supervised_loss = F.binary_cross_entropy(predicted_probs, target_probs)
+
+        # Regularization loss on rule weights
+        regularization_loss = torch.tensor(0.0)
+        if mln.rule_weights is not None:
+            regularization_loss = torch.sum(mln.rule_weights ** 2)
+
+        # Total loss
+        total_loss = (self.elbo_weight * elbo_loss +
+                      supervised_loss +
+                      self.regularization_weight * regularization_loss)
+
+        return {
+            'total_loss': total_loss,
+            'elbo_loss': elbo_loss,
+            'supervised_loss': supervised_loss,
+            'regularization_loss': regularization_loss,
+            'elbo_components': elbo_components
+        }
+
+
+class VariationalInference:
+    """
+    Variational inference for NPLL using ELBO optimization
+    """
+
+    def __init__(self, config: NPLLConfig):
+        self.config = config
+        self.elbo_computer = ELBOComputer(config)
+        self.convergence_threshold = config.convergence_threshold
+        self.max_iterations = config.em_iterations
+
+    def optimize_approximate_posterior(self,
+                                       mln: MarkovLogicNetwork,
+                                       known_facts: List[Triple],
+                                       unknown_facts: List[Triple],
+                                       initial_probs: Optional[torch.Tensor] = None) -> Dict[str, Any]:
+        """
+        Optimize approximate posterior Q(U) to maximize ELBO
+
+        This implements the variational inference component of the E-step
+
+        """
+        num_unknown = len(unknown_facts)
+        if num_unknown == 0:
+            return {
+                'optimized_probs': torch.tensor([]),
+                'elbo_history': [],
+                'converged': True,
+                'iterations': 0
+            }
+
+        # Initialize logits phi (unconstrained)
+        if initial_probs is None:
+            phi = torch.zeros(num_unknown)
+        else:
+            with torch.no_grad():
+                init = torch.clamp(initial_probs, 1e-6, 1 - 1e-6)
+                phi = torch.log(init / (1 - init))
+        device = mln.rule_weights.device if (mln.rule_weights is not None) else torch.device('cpu')
+        phi = phi.to(device)
+        phi.requires_grad_(True)
+
+        optimizer = torch.optim.Adam([phi], lr=0.01)
+        elbo_history: List[float] = []
+        prev_elbo = float('-inf')
+
+        for iteration in range(self.max_iterations):
+            optimizer.zero_grad()
+            q_probs = torch.sigmoid(phi)
+            elbo_components = self.elbo_computer.compute_elbo(
+                mln, known_facts, unknown_facts, q_probs
+            )
+            loss = -elbo_components.elbo
+            loss.backward(retain_graph=True)
+            torch.nn.utils.clip_grad_norm_([phi], max_norm=1.0)
+            optimizer.step()
+
+            # Detach to prevent gradient accumulation issues
+            current_elbo = elbo_components.elbo.detach().item()
+            elbo_history.append(current_elbo)
+            if abs(current_elbo - prev_elbo) < self.convergence_threshold:
+                logger.debug(f"Variational inference converged at iteration {iteration}")
+                break
+            prev_elbo = current_elbo
+
+        return {
+            'optimized_probs': torch.sigmoid(phi).detach(),
+            'elbo_history': elbo_history,
+            'converged': len(elbo_history) < self.max_iterations,
+            'iterations': len(elbo_history),
+            'final_elbo': elbo_history[-1] if elbo_history else float('-inf')
+        }
+
+    def _scalar_like(self, value: float, ref: Optional[torch.Tensor]) -> torch.Tensor:
+        """Create a scalar tensor on the same device/dtype as ref (if provided)."""
+        if ref is None:
+            return torch.tensor(value)
+        return torch.tensor(value, device=ref.device, dtype=ref.dtype)
+
+
+def create_elbo_computer(config: NPLLConfig) -> ELBOComputer:
+    """Factory function to create ELBO computer"""
+    return ELBOComputer(config)
+
+
+def verify_elbo_implementation():
+    """Verify ELBO computation implementation"""
+    from ..utils.config import default_config
+    from ..core import load_knowledge_graph_from_triples
+    from ..core.mln import create_mln_from_kg_and_rules
+    from ..core.logical_rules import Variable, Atom, RuleType, LogicalRule
+
+    # Create test data
+    test_triples = [
+        ("A", "r1", "B"),
+        ("B", "r2", "C")
+    ]
+
+    kg = load_knowledge_graph_from_triples(test_triples)
+
+    # Create test rule
+    from ..core import Relation
+    r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
+    x, y, z = Variable('x'), Variable('y'), Variable('z')
+
+    test_rule = LogicalRule(
+        rule_id="test_elbo_rule",
+        body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
+        head=Atom(r3, (x, z)),
+        rule_type=RuleType.TRANSITIVITY
+    )
+
+    # Create MLN
+    mln = create_mln_from_kg_and_rules(kg, [test_rule], default_config)
+
+    # Test ELBO computation
+    elbo_computer = ELBOComputer(default_config)
+
+    known_facts = list(kg.known_facts)
+    unknown_facts = [list(kg.known_facts)[0]]  # Treat one known fact as unknown for testing
+    q_probs = torch.tensor([0.8])  # High probability for the "unknown" fact
+
+    elbo_components = elbo_computer.compute_elbo(
+        mln, known_facts[:-1], unknown_facts, q_probs
+    )
+
+    # Verify components
+    assert torch.isfinite(elbo_components.elbo), "ELBO should be finite"
+    assert torch.isfinite(elbo_components.joint_term), "Joint term should be finite"
+    assert torch.isfinite(elbo_components.entropy_term), "Entropy term should be finite"
+
+    # Test variational inference
+    vi = VariationalInference(default_config)
+    vi_result = vi.optimize_approximate_posterior(mln, known_facts[:-1], unknown_facts)
+
+    assert len(vi_result['optimized_probs']) == len(unknown_facts), "Should optimize all unknown facts"
+    assert len(vi_result['elbo_history']) > 0, "Should have ELBO history"
+
+    logger.info("ELBO implementation verified successfully")
+
     return True
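
For orientation, a minimal usage sketch of this module follows, mirroring the file's own verify_elbo_implementation() helper. The top-level import paths (npll.utils.config, npll.core, npll.core.mln, npll.inference.elbo) are assumed from the wheel layout listed above, and the rule_id "demo_rule" is an arbitrary illustrative value; treat this as a sketch under those assumptions, not part of the package diff.

import torch

from npll.utils.config import default_config
from npll.core import Relation, load_knowledge_graph_from_triples
from npll.core.logical_rules import Atom, LogicalRule, RuleType, Variable
from npll.core.mln import create_mln_from_kg_and_rules
from npll.inference.elbo import ELBOComputer, VariationalInference

# Tiny knowledge graph plus one transitivity rule, as in verify_elbo_implementation().
kg = load_knowledge_graph_from_triples([("A", "r1", "B"), ("B", "r2", "C")])
x, y, z = Variable("x"), Variable("y"), Variable("z")
r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
rule = LogicalRule(
    rule_id="demo_rule",  # illustrative identifier, not taken from the package
    body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
    head=Atom(r3, (x, z)),
    rule_type=RuleType.TRANSITIVITY,
)
mln = create_mln_from_kg_and_rules(kg, [rule], default_config)

# Hold one known fact out as "unknown" and give it a Bernoulli posterior q(u).
facts = list(kg.known_facts)
known, unknown = facts[:-1], [facts[-1]]
q_probs = torch.tensor([0.8])

# Equation 5: ELBO = E_Q[log P(F,U|ω)] + H[Q]; compute_elbo returns both terms.
components = ELBOComputer(default_config).compute_elbo(mln, known, unknown, q_probs)
print(components)

# E-step-style refinement of q(u) by gradient ascent on the ELBO.
vi = VariationalInference(default_config)
result = vi.optimize_approximate_posterior(mln, known, unknown, initial_probs=q_probs)
print(result["optimized_probs"], result["final_elbo"])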