odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/inference/elbo.py
CHANGED
@@ -1,435 +1,435 @@

(In the rendered diff, lines 1-434 of this file appear as removed and then re-added with identical content; only line 435 is shown as unchanged context. The file body is therefore reproduced once below.)

"""
ELBO (Evidence Lower Bound) computation for NPLL
Exact implementation of Equations 3-5 from the paper
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional, Any
import logging
from dataclasses import dataclass

from ..core import Triple, LogicalRule, GroundRule, KnowledgeGraph
from ..core.mln import MarkovLogicNetwork
from ..utils.config import NPLLConfig
from ..utils.math_utils import (
    safe_log, bernoulli_log_prob, kl_divergence_bernoulli,
    log_sum_exp, compute_elbo_loss
)
from ..utils.batch_utils import GroundRuleBatch, GroundRuleSampler

logger = logging.getLogger(__name__)


@dataclass
class ELBOComponents:
    """
    Components of ELBO computation for analysis and debugging
    """
    joint_term: torch.Tensor      # Σ_U Q(U) log P(F,U|ω)
    entropy_term: torch.Tensor    # -Σ_U Q(U) log Q(U)
    elbo: torch.Tensor            # Total ELBO
    kl_divergence: torch.Tensor   # KL(Q||P) if available
    num_samples: int              # Number of samples used

    def __str__(self) -> str:
        return (f"ELBO Components:\n"
                f" Joint term: {self.joint_term.item():.6f}\n"
                f" Entropy term: {self.entropy_term.item():.6f}\n"
                f" Total ELBO: {self.elbo.item():.6f}\n"
                f" Samples: {self.num_samples}")


class ELBOComputer(nn.Module):
    """
    Computes Evidence Lower Bound (ELBO) for NPLL training

    Paper Section 4: The optimization objective becomes maximizing the ELBO value

    Implements exact computation of Equations 3-5:
    - Equation 3: log P(F|ω) = log[P(F,U|ω)/Q(U)] - log[P(U|F,ω)/Q(U)]
    - Equation 4: log P(F|ω) = ELBO + KL(q||p)
    - Equation 5: ELBO = Σ_U Q(U) log P(F,U|ω) - Σ_U Q(U) log Q(U)
    """

    def __init__(self, config: NPLLConfig):
        super().__init__()
        self.config = config
        self.elbo_weight = config.elbo_weight
        self.kl_weight = config.kl_weight

    def compute_elbo(self,
                     mln: MarkovLogicNetwork,
                     known_facts: List[Triple],
                     unknown_facts: List[Triple],
                     approximate_posterior_probs: torch.Tensor,
                     ground_rule_batches: Optional[List[GroundRuleBatch]] = None) -> ELBOComponents:
        """
        Compute ELBO following Equation 5
        """
        if len(unknown_facts) == 0:
            # No unknown facts, ELBO is just log P(F|ω)
            log_known_prob = self._compute_known_facts_probability(mln, known_facts)
            zero = self._scalar_like(0.0, log_known_prob)
            return ELBOComponents(
                joint_term=log_known_prob,
                entropy_term=zero,
                elbo=log_known_prob,
                kl_divergence=zero,
                num_samples=0
            )

        # Sample possible worlds for ELBO computation
        sampled_worlds = self._sample_possible_worlds(
            known_facts, unknown_facts, approximate_posterior_probs
        )

        if not sampled_worlds:
            logger.warning("No sampled worlds for ELBO computation")
            zero = self._scalar_like(0.0, approximate_posterior_probs)
            return ELBOComponents(
                joint_term=zero,
                entropy_term=zero,
                elbo=zero,
                kl_divergence=zero,
                num_samples=0
            )

        # Compute joint term: E_Q[log P(F,U|ω)] via MC average (do NOT weight by Q again)
        joint_term = self._compute_joint_term(mln, sampled_worlds, propagate_grads=False)

        # Compute entropy term: -Σ_U Q(U) log Q(U)
        entropy_term = self._compute_entropy_term(approximate_posterior_probs)

        # Total ELBO
        elbo = joint_term + entropy_term

        # Optional: compute KL divergence if true posterior is available
        kl_div = torch.tensor(0.0)
        return ELBOComponents(
            joint_term=joint_term,
            entropy_term=entropy_term,
            elbo=elbo,
            kl_divergence=kl_div,
            num_samples=len(sampled_worlds)
        )

    def _compute_joint_term(self,
                            mln: MarkovLogicNetwork,
                            sampled_worlds: List[Dict[Triple, bool]],
                            propagate_grads: bool = False) -> torch.Tensor:
        """
        Monte-Carlo estimate of E_Q[log P(F,U|ω)] as a simple average over worlds.
        If propagate_grads=True, allow gradients to flow to MLN weights.
        """
        if not sampled_worlds:
            ref = next(mln.parameters()).detach() if any(mln.parameters()) else None
            return self._scalar_like(0.0, ref)

        log_joint = []
        for world_assignment in sampled_worlds:
            logp = mln.compute_joint_probability(world_assignment, detach_weights=not propagate_grads)
            log_joint.append(logp)
        log_joint_tensor = torch.stack(log_joint)
        return log_joint_tensor.mean()

    def _compute_entropy_term(self, q_probs: torch.Tensor) -> torch.Tensor:
        """
        Compute entropy term: -Σ_U Q(U) log Q(U)

        This is the second term in Equation 5
        """
        # Clamp probabilities to avoid log(0)
        q_probs_clamped = torch.clamp(q_probs, min=1e-8, max=1.0 - 1e-8)

        # Bernoulli entropy: -[p*log(p) + (1-p)*log(1-p)]
        entropy = -(q_probs_clamped * safe_log(q_probs_clamped) +
                    (1 - q_probs_clamped) * safe_log(1 - q_probs_clamped))

        # Sum over all unknown facts
        total_entropy = torch.sum(entropy)
        return total_entropy

    def _sample_possible_worlds(self,
                                known_facts: List[Triple],
                                unknown_facts: List[Triple],
                                q_probs: torch.Tensor,
                                num_samples: int = None) -> List[Dict[Triple, bool]]:
        """Sample worlds over unknown facts only, preserving order; vectorized draws."""
        M = len(unknown_facts)
        if M == 0:
            return []
        num_samples = num_samples or min(100, 2 ** min(M, 10))
        device = q_probs.device
        probs = q_probs.unsqueeze(0).expand(num_samples, -1)
        samples = torch.bernoulli(probs).bool()
        worlds: List[Dict[Triple, bool]] = []
        for s in samples:
            w = {f: True for f in known_facts}
            for i, fact in enumerate(unknown_facts):
                w[fact] = bool(s[i].item())
            worlds.append(w)
        return worlds

    def _compute_world_q_probability(self,
                                     world_assignment: Dict[Triple, bool],
                                     q_probs: torch.Tensor,
                                     unknown_facts: Optional[List[Triple]] = None) -> torch.Tensor:
        """
        Compute Q(U) probability for a specific world assignment

        Q(U) = ∏_{uk∈U} p_k^{uk} (1-p_k)^{1-uk}
        """
        ref = q_probs
        log_prob = self._scalar_like(0.0, ref)
        facts = unknown_facts if unknown_facts is not None else [f for f in world_assignment.keys() if f not in []]
        for i, fact in enumerate(facts):
            if i >= len(q_probs):
                break
            prob = torch.clamp(q_probs[i], min=1e-8, max=1.0 - 1e-8)
            truth_value = world_assignment[fact]
            log_prob = log_prob + (safe_log(prob) if truth_value else safe_log(1 - prob))
        return torch.exp(log_prob)

    def _compute_known_facts_probability(self,
                                         mln: MarkovLogicNetwork,
                                         known_facts: List[Triple]) -> torch.Tensor:
        """Compute log P(F|ω) when no unknown facts"""
        if not known_facts:
            return torch.tensor(0.0)

        # Create assignment with all known facts as true
        known_assignment = {fact: True for fact in known_facts}

        # Compute joint probability
        log_prob = mln.compute_joint_probability(known_assignment)

        return log_prob

    def compute_elbo_gradient(self,
                              mln: MarkovLogicNetwork,
                              known_facts: List[Triple],
                              unknown_facts: List[Triple],
                              q_probs: torch.Tensor) -> torch.Tensor:
        """Compute ∂ELBO/∂ω allowing gradients to MLN weights."""
        if mln.rule_weights is not None:
            mln.rule_weights.requires_grad_(True)
            if mln.rule_weights.grad is not None:
                mln.rule_weights.grad.zero_()

        sampled_worlds = self._sample_possible_worlds(known_facts, unknown_facts, q_probs)
        joint_term = self._compute_joint_term(mln, sampled_worlds, propagate_grads=True)
        entropy_term = self._compute_entropy_term(q_probs)
        elbo = joint_term + entropy_term
        loss = -elbo
        loss.backward()

        if mln.rule_weights is not None and mln.rule_weights.grad is not None:
            gradients = mln.rule_weights.grad.clone()
        else:
            ref = next(mln.parameters()).detach() if any(mln.parameters()) else None
            zeros = torch.zeros(len(mln.logical_rules)) if ref is None else torch.zeros(len(mln.logical_rules), device=ref.device, dtype=ref.dtype)
            gradients = zeros
        return gradients


class ELBOLoss(nn.Module):
    """
    ELBO-based loss function for NPLL training

    Implements the loss function that maximizes ELBO (minimizes negative ELBO)
    """

    def __init__(self, config: NPLLConfig):
        super().__init__()
        self.config = config
        self.elbo_computer = ELBOComputer(config)

        # Loss weights
        self.elbo_weight = config.elbo_weight
        self.regularization_weight = 0.01

    def forward(self,
                mln: MarkovLogicNetwork,
                known_facts: List[Triple],
                unknown_facts: List[Triple],
                predicted_probs: torch.Tensor,
                target_probs: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Compute ELBO loss

        """
        # Compute ELBO components
        elbo_components = self.elbo_computer.compute_elbo(
            mln, known_facts, unknown_facts, predicted_probs
        )

        # Primary loss: negative ELBO (we want to maximize ELBO)
        elbo_loss = -elbo_components.elbo

        # Optional supervised loss if targets provided
        supervised_loss = torch.tensor(0.0)
        if target_probs is not None:
            supervised_loss = F.binary_cross_entropy(predicted_probs, target_probs)

        # Regularization loss on rule weights
        regularization_loss = torch.tensor(0.0)
        if mln.rule_weights is not None:
            regularization_loss = torch.sum(mln.rule_weights ** 2)

        # Total loss
        total_loss = (self.elbo_weight * elbo_loss +
                      supervised_loss +
                      self.regularization_weight * regularization_loss)

        return {
            'total_loss': total_loss,
            'elbo_loss': elbo_loss,
            'supervised_loss': supervised_loss,
            'regularization_loss': regularization_loss,
            'elbo_components': elbo_components
        }


class VariationalInference:
    """
    Variational inference for NPLL using ELBO optimization
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.elbo_computer = ELBOComputer(config)
        self.convergence_threshold = config.convergence_threshold
        self.max_iterations = config.em_iterations

    def optimize_approximate_posterior(self,
                                       mln: MarkovLogicNetwork,
                                       known_facts: List[Triple],
                                       unknown_facts: List[Triple],
                                       initial_probs: Optional[torch.Tensor] = None) -> Dict[str, Any]:
        """
        Optimize approximate posterior Q(U) to maximize ELBO

        This implements the variational inference component of the E-step

        """
        num_unknown = len(unknown_facts)
        if num_unknown == 0:
            return {
                'optimized_probs': torch.tensor([]),
                'elbo_history': [],
                'converged': True,
                'iterations': 0
            }

        # Initialize logits phi (unconstrained)
        if initial_probs is None:
            phi = torch.zeros(num_unknown)
        else:
            with torch.no_grad():
                init = torch.clamp(initial_probs, 1e-6, 1 - 1e-6)
                phi = torch.log(init / (1 - init))
        device = mln.rule_weights.device if (mln.rule_weights is not None) else torch.device('cpu')
        phi = phi.to(device)
        phi.requires_grad_(True)

        optimizer = torch.optim.Adam([phi], lr=0.01)
        elbo_history: List[float] = []
        prev_elbo = float('-inf')

        for iteration in range(self.max_iterations):
            optimizer.zero_grad()
            q_probs = torch.sigmoid(phi)
            elbo_components = self.elbo_computer.compute_elbo(
                mln, known_facts, unknown_facts, q_probs
            )
            loss = -elbo_components.elbo
            loss.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_([phi], max_norm=1.0)
            optimizer.step()

            # Detach to prevent gradient accumulation issues
            current_elbo = elbo_components.elbo.detach().item()
            elbo_history.append(current_elbo)
            if abs(current_elbo - prev_elbo) < self.convergence_threshold:
                logger.debug(f"Variational inference converged at iteration {iteration}")
                break
            prev_elbo = current_elbo

        return {
            'optimized_probs': torch.sigmoid(phi).detach(),
            'elbo_history': elbo_history,
            'converged': len(elbo_history) < self.max_iterations,
            'iterations': len(elbo_history),
            'final_elbo': elbo_history[-1] if elbo_history else float('-inf')
        }

    def _scalar_like(self, value: float, ref: Optional[torch.Tensor]) -> torch.Tensor:
        """Create a scalar tensor on the same device/dtype as ref (if provided)."""
        if ref is None:
            return torch.tensor(value)
        return torch.tensor(value, device=ref.device, dtype=ref.dtype)


def create_elbo_computer(config: NPLLConfig) -> ELBOComputer:
    """Factory function to create ELBO computer"""
    return ELBOComputer(config)


def verify_elbo_implementation():
    """Verify ELBO computation implementation"""
    from ..utils.config import default_config
    from ..core import load_knowledge_graph_from_triples
    from ..core.mln import create_mln_from_kg_and_rules
    from ..core.logical_rules import Variable, Atom, RuleType, LogicalRule

    # Create test data
    test_triples = [
        ("A", "r1", "B"),
        ("B", "r2", "C")
    ]

    kg = load_knowledge_graph_from_triples(test_triples)

    # Create test rule
    from ..core import Relation
    r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
    x, y, z = Variable('x'), Variable('y'), Variable('z')

    test_rule = LogicalRule(
        rule_id="test_elbo_rule",
        body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
        head=Atom(r3, (x, z)),
        rule_type=RuleType.TRANSITIVITY
    )

    # Create MLN
    mln = create_mln_from_kg_and_rules(kg, [test_rule], default_config)

    # Test ELBO computation
    elbo_computer = ELBOComputer(default_config)

    known_facts = list(kg.known_facts)
    unknown_facts = [list(kg.known_facts)[0]]  # Treat one known fact as unknown for testing
    q_probs = torch.tensor([0.8])  # High probability for the "unknown" fact

    elbo_components = elbo_computer.compute_elbo(
        mln, known_facts[:-1], unknown_facts, q_probs
    )

    # Verify components
    assert torch.isfinite(elbo_components.elbo), "ELBO should be finite"
    assert torch.isfinite(elbo_components.joint_term), "Joint term should be finite"
    assert torch.isfinite(elbo_components.entropy_term), "Entropy term should be finite"

    # Test variational inference
    vi = VariationalInference(default_config)
    vi_result = vi.optimize_approximate_posterior(mln, known_facts[:-1], unknown_facts)

    assert len(vi_result['optimized_probs']) == len(unknown_facts), "Should optimize all unknown facts"
    assert len(vi_result['elbo_history']) > 0, "Should have ELBO history"

    logger.info("ELBO implementation verified successfully")

    return True