odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/inference/e_step.py
CHANGED
@@ -1,420 +1,420 @@
All 419 changed lines are shown as removed and re-added with identical visible content; line 420 (`return True`) is the only unchanged context line. New file content:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Set, Tuple, Optional, Any
import logging
from dataclasses import dataclass
from collections import defaultdict

from ..core import Triple, LogicalRule, GroundRule, KnowledgeGraph
from ..core.mln import MarkovLogicNetwork
from ..scoring import NPLLScoringModule
from ..utils.config import NPLLConfig
from ..utils.math_utils import (
    safe_log, safe_sigmoid, bernoulli_log_prob, bernoulli_entropy,
    kl_divergence_bernoulli
)
from ..utils.batch_utils import GroundRuleBatch, GroundRuleSampler
from .elbo import ELBOComputer, VariationalInference

logger = logging.getLogger(__name__)


@dataclass
class EStepResult:
    """
    Result of E-step computation containing all relevant outputs
    """
    approximate_posterior_probs: torch.Tensor  # Q(U) probabilities
    fact_probabilities: Dict[Triple, float]  # Individual fact probabilities
    ground_rule_expectations: torch.Tensor  # Expected ground rule counts
    entropy: torch.Tensor  # Total entropy of Q(U)
    elbo_value: torch.Tensor  # Current ELBO value
    convergence_info: Dict[str, Any]  # Convergence diagnostics
    iteration_count: int  # Number of iterations used

    def __str__(self) -> str:
        return (f"E-step Result:\n"
                f" Unknown facts: {len(self.approximate_posterior_probs)}\n"
                f" Mean probability: {self.approximate_posterior_probs.mean().item():.4f}\n"
                f" Entropy: {self.entropy.item():.4f}\n"
                f" ELBO: {self.elbo_value.item():.4f}\n"
                f" Iterations: {self.iteration_count}")


class MeanFieldApproximation(nn.Module):
    """
    Mean-field approximation for approximate posterior Q(U)
    """

    def __init__(self, config: NPLLConfig):
        super().__init__()
        self.config = config
        self.max_iterations = config.mean_field_iterations
        self.convergence_threshold = config.convergence_threshold

    def initialize_fact_probabilities(self, unknown_facts: List[Triple],
                                      scoring_module: NPLLScoringModule) -> torch.Tensor:
        """
        Initialize fact probabilities using scoring module
        """
        if not unknown_facts:
            return torch.tensor([])

        # Get initial scores from scoring module
        with torch.no_grad():
            initial_scores = scoring_module.forward(unknown_facts)
            # Transform to probabilities using sigmoid
            initial_probs = safe_sigmoid(initial_scores)

        return initial_probs

    def compute_q_u_distribution(self, fact_probs: torch.Tensor,
                                 ground_rule_structure: Optional[List[List[int]]] = None) -> Dict[str, torch.Tensor]:
        """
        Compute Q(U) distribution components
        """
        if len(fact_probs) == 0:
            return {
                'fact_probs': torch.tensor([]),
                'log_probs': torch.tensor([]),
                'entropy': torch.tensor(0.0)
            }

        # Clamp probabilities for numerical stability
        fact_probs_clamped = torch.clamp(fact_probs, min=1e-8, max=1.0 - 1e-8)

        # Compute log probabilities
        log_probs = safe_log(fact_probs_clamped)
        log_neg_probs = safe_log(1 - fact_probs_clamped)

        # Compute entropy of individual facts
        fact_entropies = bernoulli_entropy(fact_probs_clamped)
        total_entropy = torch.sum(fact_entropies)

        result = {
            'fact_probs': fact_probs_clamped,
            'log_probs': log_probs,
            'log_neg_probs': log_neg_probs,
            'fact_entropies': fact_entropies,
            'total_entropy': total_entropy
        }

        # If ground rule structure provided, compute ground rule probabilities
        if ground_rule_structure is not None:
            ground_rule_probs = self._compute_ground_rule_probabilities(
                fact_probs_clamped, ground_rule_structure
            )
            result.update(ground_rule_probs)

        return result

    def _compute_ground_rule_probabilities(self, fact_probs: torch.Tensor,
                                           ground_rule_structure: List[List[int]]) -> Dict[str, torch.Tensor]:
        """
        Compute probabilities for ground rules under mean-field approximation

        For ground rule with facts [i, j, k]: P(rule) = ∏ p_i * p_j * p_k
        """
        ground_rule_probs = []
        ground_rule_log_probs = []

        for fact_indices in ground_rule_structure:
            if fact_indices:
                # Get probabilities for facts in this ground rule
                rule_fact_probs = fact_probs[fact_indices]

                # Product probability (independence assumption)
                rule_prob = torch.prod(rule_fact_probs)
                rule_log_prob = torch.sum(safe_log(rule_fact_probs))

                ground_rule_probs.append(rule_prob)
                ground_rule_log_probs.append(rule_log_prob)
            else:
                ground_rule_probs.append(torch.tensor(0.0))
                ground_rule_log_probs.append(torch.tensor(float('-inf')))

        return {
            'ground_rule_probs': torch.stack(ground_rule_probs) if ground_rule_probs else torch.tensor([]),
            'ground_rule_log_probs': torch.stack(ground_rule_log_probs) if ground_rule_log_probs else torch.tensor([])
        }


class EStepOptimizer(nn.Module):
    """
    Optimizes the E-step objective function
    """

    def __init__(self, config: NPLLConfig):
        super().__init__()
        self.config = config
        self.mean_field = MeanFieldApproximation(config)
        self.elbo_computer = ELBOComputer(config)
        self.variational_inference = VariationalInference(config)

        # Optimization parameters
        self.max_iterations = config.mean_field_iterations
        self.convergence_threshold = config.convergence_threshold

    def optimize_approximate_posterior(self,
                                       mln: MarkovLogicNetwork,
                                       scoring_module: NPLLScoringModule,
                                       known_facts: List[Triple],
                                       unknown_facts: List[Triple],
                                       ground_rule_batches: Optional[List[GroundRuleBatch]] = None) -> EStepResult:
        """
        Main E-step optimization procedure
        """
        if not unknown_facts:
            # No unknown facts to optimize
            return EStepResult(
                approximate_posterior_probs=torch.tensor([]),
                fact_probabilities={},
                ground_rule_expectations=torch.tensor([]),
                entropy=torch.tensor(0.0),
                elbo_value=torch.tensor(0.0),
                convergence_info={'converged': True, 'reason': 'no_unknown_facts'},
                iteration_count=0
            )

        logger.debug(f"Starting E-step optimization for {len(unknown_facts)} unknown facts")

        # Initialize fact probabilities using scoring module
        current_probs = self.mean_field.initialize_fact_probabilities(unknown_facts, scoring_module)

        # Run a single VI optimization; it internally iterates and computes ELBO history
        updated_result = self.variational_inference.optimize_approximate_posterior(
            mln, known_facts, unknown_facts, current_probs.detach()
        )
        current_probs = updated_result['optimized_probs'].detach()
        elbo_history = updated_result['elbo_history']
        converged = updated_result['converged']
        iteration = updated_result['iterations']

        # Compute final Q(U) distribution
        final_q_dist = self.mean_field.compute_q_u_distribution(current_probs)

        # Compute ground rule expectations
        ground_rule_expectations = self._compute_ground_rule_expectations(
            mln, current_probs, unknown_facts
        )

        # Create fact probability dictionary
        fact_prob_dict = {fact: current_probs[i].item()
                          for i, fact in enumerate(unknown_facts)}

        # Final ELBO computation (detached to avoid gradient issues)
        with torch.no_grad():
            final_elbo_components = self.elbo_computer.compute_elbo(
                mln, known_facts, unknown_facts, current_probs.detach()
            )

        convergence_info = {
            'converged': converged,
            'final_change': 0.0,
            'elbo_history': elbo_history,
            'reason': 'converged' if converged else 'max_iterations'
        }

        result = EStepResult(
            approximate_posterior_probs=current_probs,
            fact_probabilities=fact_prob_dict,
            ground_rule_expectations=ground_rule_expectations,
            entropy=final_q_dist['total_entropy'],
            elbo_value=final_elbo_components.elbo,
            convergence_info=convergence_info,
            iteration_count=iteration + 1
        )

        logger.debug(f"E-step completed: {result}")

        return result

    def _compute_ground_rule_expectations(self,
                                          mln: MarkovLogicNetwork,
                                          fact_probs: torch.Tensor,
                                          unknown_facts: List[Triple]) -> torch.Tensor:
        """
        Compute expected ground rule counts under Q(U)
        """
        if not mln.logical_rules or len(fact_probs) == 0:
            return torch.tensor([])

        # Create fact index mapping
        fact_to_idx = {fact: i for i, fact in enumerate(unknown_facts)}

        # Compute expectations for each logical rule
        expected_counts = torch.zeros(len(mln.logical_rules))

        for rule_idx, rule in enumerate(mln.logical_rules):
            ground_rules = mln.rule_to_ground_rules[rule.rule_id]
            rule_expectation = 0.0

            for ground_rule in ground_rules:
                # Compute expected satisfaction for this ground rule
                ground_rule_factors = []

                # Body facts
                for fact in ground_rule.body_facts:
                    if fact in fact_to_idx:
                        # Unknown fact - use probability
                        fact_idx = fact_to_idx[fact]
                        ground_rule_factors.append(fact_probs[fact_idx])
                    else:
                        # Known fact - assume probability 1
                        ground_rule_factors.append(torch.tensor(1.0))

                # Head fact
                if ground_rule.head_fact in fact_to_idx:
                    fact_idx = fact_to_idx[ground_rule.head_fact]
                    ground_rule_factors.append(fact_probs[fact_idx])
                else:
                    ground_rule_factors.append(torch.tensor(1.0))

                # Expected satisfaction is product of all factors
                if ground_rule_factors:
                    ground_rule_expectation = torch.prod(torch.stack(ground_rule_factors))
                    rule_expectation += ground_rule_expectation.item()

            expected_counts[rule_idx] = rule_expectation

        return expected_counts


class EStepRunner:
    """
    High-level runner for E-step computations
    Handles batching, parallelization, and result aggregation
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.e_step_optimizer = EStepOptimizer(config)
        self.ground_rule_sampler = GroundRuleSampler(config)

    def run_e_step(self,
                   mln: MarkovLogicNetwork,
                   scoring_module: NPLLScoringModule,
                   kg: KnowledgeGraph) -> EStepResult:
        known_facts = list(kg.known_facts)
        unknown_facts = list(kg.unknown_facts)

        if not unknown_facts:
            logger.info("No unknown facts for E-step")
            return EStepResult(
                approximate_posterior_probs=torch.tensor([]),
                fact_probabilities={},
                ground_rule_expectations=torch.tensor([]),
                entropy=torch.tensor(0.0),
                elbo_value=torch.tensor(0.0),
                convergence_info={'converged': True, 'reason': 'no_unknown_facts'},
                iteration_count=0
            )

        logger.info(f"Running E-step for {len(known_facts)} known and {len(unknown_facts)} unknown facts")

        # Create ground rule batches for efficient processing
        ground_rule_batches = None
        if len(mln.ground_rules) > self.config.batch_size:
            ground_rule_batches = self.ground_rule_sampler.sample_ground_rules(
                mln.ground_rules,
                num_batches=max(1, len(mln.ground_rules) // self.config.batch_size),
                sampling_strategy="uniform"
            )

        # Run optimization
        result = self.e_step_optimizer.optimize_approximate_posterior(
            mln, scoring_module, known_facts, unknown_facts, ground_rule_batches
        )

        logger.info(f"E-step completed: ELBO={result.elbo_value.item():.4f}, "
                    f"Entropy={result.entropy.item():.4f}, Iterations={result.iteration_count}")

        return result

    def run_e_step_with_constraints(self,
                                    mln: MarkovLogicNetwork,
                                    scoring_module: NPLLScoringModule,
                                    kg: KnowledgeGraph,
                                    fact_constraints: Optional[Dict[Triple, Tuple[float, float]]] = None) -> EStepResult:

        # Standard E-step
        result = self.run_e_step(mln, scoring_module, kg)

        # Apply constraints if provided
        if fact_constraints and len(result.approximate_posterior_probs) > 0:
            constrained_probs = result.approximate_posterior_probs.clone()
            unknown_facts = list(kg.unknown_facts)

            for i, fact in enumerate(unknown_facts):
                if fact in fact_constraints:
                    min_prob, max_prob = fact_constraints[fact]
                    constrained_probs[i] = torch.clamp(constrained_probs[i], min_prob, max_prob)

            # Update result with constrained probabilities
            result.approximate_posterior_probs = constrained_probs
            result.fact_probabilities = {
                fact: constrained_probs[i].item()
                for i, fact in enumerate(unknown_facts)
            }

        return result


def create_e_step_runner(config: NPLLConfig) -> EStepRunner:
    """Factory function to create E-step runner"""
    return EStepRunner(config)


def verify_e_step_implementation():
    """Verify E-step implementation"""
    from ..utils.config import default_config
    from ..core import load_knowledge_graph_from_triples
    from ..core.mln import create_mln_from_kg_and_rules
    from ..core.logical_rules import Variable, Atom, RuleType, LogicalRule
    from ..scoring import create_scoring_module

    # Create test data
    test_triples = [
        ("A", "r1", "B"),
        ("B", "r2", "C"),
    ]

    kg = load_knowledge_graph_from_triples(test_triples)

    # Add unknown facts
    kg.add_unknown_fact("A", "r3", "C")

    # Create test rule
    from ..core import Relation
    r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
    x, y, z = Variable('x'), Variable('y'), Variable('z')

    test_rule = LogicalRule(
        rule_id="test_e_step_rule",
        body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
        head=Atom(r3, (x, z)),
        rule_type=RuleType.TRANSITIVITY
    )

    # Create MLN and scoring module
    mln = create_mln_from_kg_and_rules(kg, [test_rule], default_config)
    scoring_module = create_scoring_module(default_config, kg)

    # Test E-step
    e_step_runner = EStepRunner(default_config)
    result = e_step_runner.run_e_step(mln, scoring_module, kg)

    # Verify results
    assert len(result.approximate_posterior_probs) == len(kg.unknown_facts), \
        "Should have probabilities for all unknown facts"

    assert torch.all(result.approximate_posterior_probs >= 0), "Probabilities should be non-negative"
    assert torch.all(result.approximate_posterior_probs <= 1), "Probabilities should be <= 1"

    assert torch.isfinite(result.elbo_value), "ELBO should be finite"
    assert torch.isfinite(result.entropy), "Entropy should be finite"

    logger.info("E-step implementation verified successfully")

    return True
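As a reading aid only (not part of either released wheel), here is a minimal sketch of how the E-step API above could be driven end to end. It assumes the package layout implied by the relative imports in the file (npll.core, npll.scoring, npll.utils.config, npll.inference.e_step) and reuses the toy graph and rule from verify_e_step_implementation; the rule_id value is illustrative.

from npll.core import Relation, load_knowledge_graph_from_triples
from npll.core.logical_rules import Atom, LogicalRule, RuleType, Variable
from npll.core.mln import create_mln_from_kg_and_rules
from npll.inference.e_step import create_e_step_runner
from npll.scoring import create_scoring_module
from npll.utils.config import default_config

# Toy knowledge graph: two known facts plus one unknown candidate fact (A, r3, C).
kg = load_knowledge_graph_from_triples([("A", "r1", "B"), ("B", "r2", "C")])
kg.add_unknown_fact("A", "r3", "C")

# A transitivity-style rule: r1(x, y) AND r2(y, z) -> r3(x, z).
x, y, z = Variable("x"), Variable("y"), Variable("z")
rule = LogicalRule(
    rule_id="example_transitivity_rule",  # illustrative identifier
    body=[Atom(Relation("r1"), (x, y)), Atom(Relation("r2"), (y, z))],
    head=Atom(Relation("r3"), (x, z)),
    rule_type=RuleType.TRANSITIVITY,
)

# Ground the rule into an MLN, build the neural scoring module, and run one E-step.
mln = create_mln_from_kg_and_rules(kg, [rule], default_config)
scoring_module = create_scoring_module(default_config, kg)
runner = create_e_step_runner(default_config)
result = runner.run_e_step(mln, scoring_module, kg)

# result.fact_probabilities maps each unknown Triple to its approximate posterior probability.
print(result)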