odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/inference/e_step.py CHANGED
@@ -1,420 +1,420 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import List, Dict, Set, Tuple, Optional, Any
- import logging
- from dataclasses import dataclass
- from collections import defaultdict
-
- from ..core import Triple, LogicalRule, GroundRule, KnowledgeGraph
- from ..core.mln import MarkovLogicNetwork
- from ..scoring import NPLLScoringModule
- from ..utils.config import NPLLConfig
- from ..utils.math_utils import (
-     safe_log, safe_sigmoid, bernoulli_log_prob, bernoulli_entropy,
-     kl_divergence_bernoulli
- )
- from ..utils.batch_utils import GroundRuleBatch, GroundRuleSampler
- from .elbo import ELBOComputer, VariationalInference
-
- logger = logging.getLogger(__name__)
-
-
- @dataclass
- class EStepResult:
-     """
-     Result of E-step computation containing all relevant outputs
-     """
-     approximate_posterior_probs: torch.Tensor  # Q(U) probabilities
-     fact_probabilities: Dict[Triple, float]  # Individual fact probabilities
-     ground_rule_expectations: torch.Tensor  # Expected ground rule counts
-     entropy: torch.Tensor  # Total entropy of Q(U)
-     elbo_value: torch.Tensor  # Current ELBO value
-     convergence_info: Dict[str, Any]  # Convergence diagnostics
-     iteration_count: int  # Number of iterations used
-
-     def __str__(self) -> str:
-         return (f"E-step Result:\n"
-                 f" Unknown facts: {len(self.approximate_posterior_probs)}\n"
-                 f" Mean probability: {self.approximate_posterior_probs.mean().item():.4f}\n"
-                 f" Entropy: {self.entropy.item():.4f}\n"
-                 f" ELBO: {self.elbo_value.item():.4f}\n"
-                 f" Iterations: {self.iteration_count}")
-
-
- class MeanFieldApproximation(nn.Module):
-     """
-     Mean-field approximation for approximate posterior Q(U)
-     """
-
-     def __init__(self, config: NPLLConfig):
-         super().__init__()
-         self.config = config
-         self.max_iterations = config.mean_field_iterations
-         self.convergence_threshold = config.convergence_threshold
-
-     def initialize_fact_probabilities(self, unknown_facts: List[Triple],
-                                       scoring_module: NPLLScoringModule) -> torch.Tensor:
-         """
-         Initialize fact probabilities using scoring module
-         """
-         if not unknown_facts:
-             return torch.tensor([])
-
-         # Get initial scores from scoring module
-         with torch.no_grad():
-             initial_scores = scoring_module.forward(unknown_facts)
-             # Transform to probabilities using sigmoid
-             initial_probs = safe_sigmoid(initial_scores)
-
-         return initial_probs
-
-     def compute_q_u_distribution(self, fact_probs: torch.Tensor,
-                                  ground_rule_structure: Optional[List[List[int]]] = None) -> Dict[str, torch.Tensor]:
-         """
-         Compute Q(U) distribution components
-         """
-         if len(fact_probs) == 0:
-             return {
-                 'fact_probs': torch.tensor([]),
-                 'log_probs': torch.tensor([]),
-                 'entropy': torch.tensor(0.0)
-             }
-
-         # Clamp probabilities for numerical stability
-         fact_probs_clamped = torch.clamp(fact_probs, min=1e-8, max=1.0 - 1e-8)
-
-         # Compute log probabilities
-         log_probs = safe_log(fact_probs_clamped)
-         log_neg_probs = safe_log(1 - fact_probs_clamped)
-
-         # Compute entropy of individual facts
-         fact_entropies = bernoulli_entropy(fact_probs_clamped)
-         total_entropy = torch.sum(fact_entropies)
-
-         result = {
-             'fact_probs': fact_probs_clamped,
-             'log_probs': log_probs,
-             'log_neg_probs': log_neg_probs,
-             'fact_entropies': fact_entropies,
-             'total_entropy': total_entropy
-         }
-
-         # If ground rule structure provided, compute ground rule probabilities
-         if ground_rule_structure is not None:
-             ground_rule_probs = self._compute_ground_rule_probabilities(
-                 fact_probs_clamped, ground_rule_structure
-             )
-             result.update(ground_rule_probs)
-
-         return result
-
-     def _compute_ground_rule_probabilities(self, fact_probs: torch.Tensor,
-                                            ground_rule_structure: List[List[int]]) -> Dict[str, torch.Tensor]:
-         """
-         Compute probabilities for ground rules under mean-field approximation
-
-         For ground rule with facts [i, j, k]: P(rule) = ∏ p_i * p_j * p_k
-         """
-         ground_rule_probs = []
-         ground_rule_log_probs = []
-
-         for fact_indices in ground_rule_structure:
-             if fact_indices:
-                 # Get probabilities for facts in this ground rule
-                 rule_fact_probs = fact_probs[fact_indices]
-
-                 # Product probability (independence assumption)
-                 rule_prob = torch.prod(rule_fact_probs)
-                 rule_log_prob = torch.sum(safe_log(rule_fact_probs))
-
-                 ground_rule_probs.append(rule_prob)
-                 ground_rule_log_probs.append(rule_log_prob)
-             else:
-                 ground_rule_probs.append(torch.tensor(0.0))
-                 ground_rule_log_probs.append(torch.tensor(float('-inf')))
-
-         return {
-             'ground_rule_probs': torch.stack(ground_rule_probs) if ground_rule_probs else torch.tensor([]),
-             'ground_rule_log_probs': torch.stack(ground_rule_log_probs) if ground_rule_log_probs else torch.tensor([])
-         }
-
-
- class EStepOptimizer(nn.Module):
-     """
-     Optimizes the E-step objective function
-     """
-
-     def __init__(self, config: NPLLConfig):
-         super().__init__()
-         self.config = config
-         self.mean_field = MeanFieldApproximation(config)
-         self.elbo_computer = ELBOComputer(config)
-         self.variational_inference = VariationalInference(config)
-
-         # Optimization parameters
-         self.max_iterations = config.mean_field_iterations
-         self.convergence_threshold = config.convergence_threshold
-
-     def optimize_approximate_posterior(self,
-                                        mln: MarkovLogicNetwork,
-                                        scoring_module: NPLLScoringModule,
-                                        known_facts: List[Triple],
-                                        unknown_facts: List[Triple],
-                                        ground_rule_batches: Optional[List[GroundRuleBatch]] = None) -> EStepResult:
-         """
-         Main E-step optimization procedure
-         """
-         if not unknown_facts:
-             # No unknown facts to optimize
-             return EStepResult(
-                 approximate_posterior_probs=torch.tensor([]),
-                 fact_probabilities={},
-                 ground_rule_expectations=torch.tensor([]),
-                 entropy=torch.tensor(0.0),
-                 elbo_value=torch.tensor(0.0),
-                 convergence_info={'converged': True, 'reason': 'no_unknown_facts'},
-                 iteration_count=0
-             )
-
-         logger.debug(f"Starting E-step optimization for {len(unknown_facts)} unknown facts")
-
-         # Initialize fact probabilities using scoring module
-         current_probs = self.mean_field.initialize_fact_probabilities(unknown_facts, scoring_module)
-
-         # Run a single VI optimization; it internally iterates and computes ELBO history
-         updated_result = self.variational_inference.optimize_approximate_posterior(
-             mln, known_facts, unknown_facts, current_probs.detach()
-         )
-         current_probs = updated_result['optimized_probs'].detach()
-         elbo_history = updated_result['elbo_history']
-         converged = updated_result['converged']
-         iteration = updated_result['iterations']
-
-         # Compute final Q(U) distribution
-         final_q_dist = self.mean_field.compute_q_u_distribution(current_probs)
-
-         # Compute ground rule expectations
-         ground_rule_expectations = self._compute_ground_rule_expectations(
-             mln, current_probs, unknown_facts
-         )
-
-         # Create fact probability dictionary
-         fact_prob_dict = {fact: current_probs[i].item()
-                           for i, fact in enumerate(unknown_facts)}
-
-         # Final ELBO computation (detached to avoid gradient issues)
-         with torch.no_grad():
-             final_elbo_components = self.elbo_computer.compute_elbo(
-                 mln, known_facts, unknown_facts, current_probs.detach()
-             )
-
-         convergence_info = {
-             'converged': converged,
-             'final_change': 0.0,
-             'elbo_history': elbo_history,
-             'reason': 'converged' if converged else 'max_iterations'
-         }
-
-         result = EStepResult(
-             approximate_posterior_probs=current_probs,
-             fact_probabilities=fact_prob_dict,
-             ground_rule_expectations=ground_rule_expectations,
-             entropy=final_q_dist['total_entropy'],
-             elbo_value=final_elbo_components.elbo,
-             convergence_info=convergence_info,
-             iteration_count=iteration + 1
-         )
-
-         logger.debug(f"E-step completed: {result}")
-
-         return result
-
-     def _compute_ground_rule_expectations(self,
-                                           mln: MarkovLogicNetwork,
-                                           fact_probs: torch.Tensor,
-                                           unknown_facts: List[Triple]) -> torch.Tensor:
-         """
-         Compute expected ground rule counts under Q(U)
-         """
-         if not mln.logical_rules or len(fact_probs) == 0:
-             return torch.tensor([])
-
-         # Create fact index mapping
-         fact_to_idx = {fact: i for i, fact in enumerate(unknown_facts)}
-
-         # Compute expectations for each logical rule
-         expected_counts = torch.zeros(len(mln.logical_rules))
-
-         for rule_idx, rule in enumerate(mln.logical_rules):
-             ground_rules = mln.rule_to_ground_rules[rule.rule_id]
-             rule_expectation = 0.0
-
-             for ground_rule in ground_rules:
-                 # Compute expected satisfaction for this ground rule
-                 ground_rule_factors = []
-
-                 # Body facts
-                 for fact in ground_rule.body_facts:
-                     if fact in fact_to_idx:
-                         # Unknown fact - use probability
-                         fact_idx = fact_to_idx[fact]
-                         ground_rule_factors.append(fact_probs[fact_idx])
-                     else:
-                         # Known fact - assume probability 1
-                         ground_rule_factors.append(torch.tensor(1.0))
-
-                 # Head fact
-                 if ground_rule.head_fact in fact_to_idx:
-                     fact_idx = fact_to_idx[ground_rule.head_fact]
-                     ground_rule_factors.append(fact_probs[fact_idx])
-                 else:
-                     ground_rule_factors.append(torch.tensor(1.0))
-
-                 # Expected satisfaction is product of all factors
-                 if ground_rule_factors:
-                     ground_rule_expectation = torch.prod(torch.stack(ground_rule_factors))
-                     rule_expectation += ground_rule_expectation.item()
-
-             expected_counts[rule_idx] = rule_expectation
-
-         return expected_counts
-
-
- class EStepRunner:
-     """
-     High-level runner for E-step computations
-     Handles batching, parallelization, and result aggregation
-     """
-
-     def __init__(self, config: NPLLConfig):
-         self.config = config
-         self.e_step_optimizer = EStepOptimizer(config)
-         self.ground_rule_sampler = GroundRuleSampler(config)
-
-     def run_e_step(self,
-                    mln: MarkovLogicNetwork,
-                    scoring_module: NPLLScoringModule,
-                    kg: KnowledgeGraph) -> EStepResult:
-         known_facts = list(kg.known_facts)
-         unknown_facts = list(kg.unknown_facts)
-
-         if not unknown_facts:
-             logger.info("No unknown facts for E-step")
-             return EStepResult(
-                 approximate_posterior_probs=torch.tensor([]),
-                 fact_probabilities={},
-                 ground_rule_expectations=torch.tensor([]),
-                 entropy=torch.tensor(0.0),
-                 elbo_value=torch.tensor(0.0),
-                 convergence_info={'converged': True, 'reason': 'no_unknown_facts'},
-                 iteration_count=0
-             )
-
-         logger.info(f"Running E-step for {len(known_facts)} known and {len(unknown_facts)} unknown facts")
-
-         # Create ground rule batches for efficient processing
-         ground_rule_batches = None
-         if len(mln.ground_rules) > self.config.batch_size:
-             ground_rule_batches = self.ground_rule_sampler.sample_ground_rules(
-                 mln.ground_rules,
-                 num_batches=max(1, len(mln.ground_rules) // self.config.batch_size),
-                 sampling_strategy="uniform"
-             )
-
-         # Run optimization
-         result = self.e_step_optimizer.optimize_approximate_posterior(
-             mln, scoring_module, known_facts, unknown_facts, ground_rule_batches
-         )
-
-         logger.info(f"E-step completed: ELBO={result.elbo_value.item():.4f}, "
-                     f"Entropy={result.entropy.item():.4f}, Iterations={result.iteration_count}")
-
-         return result
-
-     def run_e_step_with_constraints(self,
-                                     mln: MarkovLogicNetwork,
-                                     scoring_module: NPLLScoringModule,
-                                     kg: KnowledgeGraph,
-                                     fact_constraints: Optional[Dict[Triple, Tuple[float, float]]] = None) -> EStepResult:
-
-         # Standard E-step
-         result = self.run_e_step(mln, scoring_module, kg)
-
-         # Apply constraints if provided
-         if fact_constraints and len(result.approximate_posterior_probs) > 0:
-             constrained_probs = result.approximate_posterior_probs.clone()
-             unknown_facts = list(kg.unknown_facts)
-
-             for i, fact in enumerate(unknown_facts):
-                 if fact in fact_constraints:
-                     min_prob, max_prob = fact_constraints[fact]
-                     constrained_probs[i] = torch.clamp(constrained_probs[i], min_prob, max_prob)
-
-             # Update result with constrained probabilities
-             result.approximate_posterior_probs = constrained_probs
-             result.fact_probabilities = {
-                 fact: constrained_probs[i].item()
-                 for i, fact in enumerate(unknown_facts)
-             }
-
-         return result
-
-
- def create_e_step_runner(config: NPLLConfig) -> EStepRunner:
-     """Factory function to create E-step runner"""
-     return EStepRunner(config)
-
-
- def verify_e_step_implementation():
-     """Verify E-step implementation"""
-     from ..utils.config import default_config
-     from ..core import load_knowledge_graph_from_triples
-     from ..core.mln import create_mln_from_kg_and_rules
-     from ..core.logical_rules import Variable, Atom, RuleType, LogicalRule
-     from ..scoring import create_scoring_module
-
-     # Create test data
-     test_triples = [
-         ("A", "r1", "B"),
-         ("B", "r2", "C"),
-     ]
-
-     kg = load_knowledge_graph_from_triples(test_triples)
-
-     # Add unknown facts
-     kg.add_unknown_fact("A", "r3", "C")
-
-     # Create test rule
-     from ..core import Relation
-     r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
-     x, y, z = Variable('x'), Variable('y'), Variable('z')
-
-     test_rule = LogicalRule(
-         rule_id="test_e_step_rule",
-         body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
-         head=Atom(r3, (x, z)),
-         rule_type=RuleType.TRANSITIVITY
-     )
-
-     # Create MLN and scoring module
-     mln = create_mln_from_kg_and_rules(kg, [test_rule], default_config)
-     scoring_module = create_scoring_module(default_config, kg)
-
-     # Test E-step
-     e_step_runner = EStepRunner(default_config)
-     result = e_step_runner.run_e_step(mln, scoring_module, kg)
-
-     # Verify results
-     assert len(result.approximate_posterior_probs) == len(kg.unknown_facts), \
-         "Should have probabilities for all unknown facts"
-
-     assert torch.all(result.approximate_posterior_probs >= 0), "Probabilities should be non-negative"
-     assert torch.all(result.approximate_posterior_probs <= 1), "Probabilities should be <= 1"
-
-     assert torch.isfinite(result.elbo_value), "ELBO should be finite"
-     assert torch.isfinite(result.entropy), "Entropy should be finite"
-
-     logger.info("E-step implementation verified successfully")
-
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import List, Dict, Set, Tuple, Optional, Any
+ import logging
+ from dataclasses import dataclass
+ from collections import defaultdict
+
+ from ..core import Triple, LogicalRule, GroundRule, KnowledgeGraph
+ from ..core.mln import MarkovLogicNetwork
+ from ..scoring import NPLLScoringModule
+ from ..utils.config import NPLLConfig
+ from ..utils.math_utils import (
+     safe_log, safe_sigmoid, bernoulli_log_prob, bernoulli_entropy,
+     kl_divergence_bernoulli
+ )
+ from ..utils.batch_utils import GroundRuleBatch, GroundRuleSampler
+ from .elbo import ELBOComputer, VariationalInference
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class EStepResult:
+     """
+     Result of E-step computation containing all relevant outputs
+     """
+     approximate_posterior_probs: torch.Tensor  # Q(U) probabilities
+     fact_probabilities: Dict[Triple, float]  # Individual fact probabilities
+     ground_rule_expectations: torch.Tensor  # Expected ground rule counts
+     entropy: torch.Tensor  # Total entropy of Q(U)
+     elbo_value: torch.Tensor  # Current ELBO value
+     convergence_info: Dict[str, Any]  # Convergence diagnostics
+     iteration_count: int  # Number of iterations used
+
+     def __str__(self) -> str:
+         return (f"E-step Result:\n"
+                 f" Unknown facts: {len(self.approximate_posterior_probs)}\n"
+                 f" Mean probability: {self.approximate_posterior_probs.mean().item():.4f}\n"
+                 f" Entropy: {self.entropy.item():.4f}\n"
+                 f" ELBO: {self.elbo_value.item():.4f}\n"
+                 f" Iterations: {self.iteration_count}")
+
+
+ class MeanFieldApproximation(nn.Module):
+     """
+     Mean-field approximation for approximate posterior Q(U)
+     """
+
+     def __init__(self, config: NPLLConfig):
+         super().__init__()
+         self.config = config
+         self.max_iterations = config.mean_field_iterations
+         self.convergence_threshold = config.convergence_threshold
+
+     def initialize_fact_probabilities(self, unknown_facts: List[Triple],
+                                       scoring_module: NPLLScoringModule) -> torch.Tensor:
+         """
+         Initialize fact probabilities using scoring module
+         """
+         if not unknown_facts:
+             return torch.tensor([])
+
+         # Get initial scores from scoring module
+         with torch.no_grad():
+             initial_scores = scoring_module.forward(unknown_facts)
+             # Transform to probabilities using sigmoid
+             initial_probs = safe_sigmoid(initial_scores)
+
+         return initial_probs
+
+     def compute_q_u_distribution(self, fact_probs: torch.Tensor,
+                                  ground_rule_structure: Optional[List[List[int]]] = None) -> Dict[str, torch.Tensor]:
+         """
+         Compute Q(U) distribution components
+         """
+         if len(fact_probs) == 0:
+             return {
+                 'fact_probs': torch.tensor([]),
+                 'log_probs': torch.tensor([]),
+                 'entropy': torch.tensor(0.0)
+             }
+
+         # Clamp probabilities for numerical stability
+         fact_probs_clamped = torch.clamp(fact_probs, min=1e-8, max=1.0 - 1e-8)
+
+         # Compute log probabilities
+         log_probs = safe_log(fact_probs_clamped)
+         log_neg_probs = safe_log(1 - fact_probs_clamped)
+
+         # Compute entropy of individual facts
+         fact_entropies = bernoulli_entropy(fact_probs_clamped)
+         total_entropy = torch.sum(fact_entropies)
+
+         result = {
+             'fact_probs': fact_probs_clamped,
+             'log_probs': log_probs,
+             'log_neg_probs': log_neg_probs,
+             'fact_entropies': fact_entropies,
+             'total_entropy': total_entropy
+         }
+
+         # If ground rule structure provided, compute ground rule probabilities
+         if ground_rule_structure is not None:
+             ground_rule_probs = self._compute_ground_rule_probabilities(
+                 fact_probs_clamped, ground_rule_structure
+             )
+             result.update(ground_rule_probs)
+
+         return result
+
+     def _compute_ground_rule_probabilities(self, fact_probs: torch.Tensor,
+                                            ground_rule_structure: List[List[int]]) -> Dict[str, torch.Tensor]:
+         """
+         Compute probabilities for ground rules under mean-field approximation
+
+         For ground rule with facts [i, j, k]: P(rule) = ∏ p_i * p_j * p_k
+         """
+         ground_rule_probs = []
+         ground_rule_log_probs = []
+
+         for fact_indices in ground_rule_structure:
+             if fact_indices:
+                 # Get probabilities for facts in this ground rule
+                 rule_fact_probs = fact_probs[fact_indices]
+
+                 # Product probability (independence assumption)
+                 rule_prob = torch.prod(rule_fact_probs)
+                 rule_log_prob = torch.sum(safe_log(rule_fact_probs))
+
+                 ground_rule_probs.append(rule_prob)
+                 ground_rule_log_probs.append(rule_log_prob)
+             else:
+                 ground_rule_probs.append(torch.tensor(0.0))
+                 ground_rule_log_probs.append(torch.tensor(float('-inf')))
+
+         return {
+             'ground_rule_probs': torch.stack(ground_rule_probs) if ground_rule_probs else torch.tensor([]),
+             'ground_rule_log_probs': torch.stack(ground_rule_log_probs) if ground_rule_log_probs else torch.tensor([])
+         }
+
+
+ class EStepOptimizer(nn.Module):
+     """
+     Optimizes the E-step objective function
+     """
+
+     def __init__(self, config: NPLLConfig):
+         super().__init__()
+         self.config = config
+         self.mean_field = MeanFieldApproximation(config)
+         self.elbo_computer = ELBOComputer(config)
+         self.variational_inference = VariationalInference(config)
+
+         # Optimization parameters
+         self.max_iterations = config.mean_field_iterations
+         self.convergence_threshold = config.convergence_threshold
+
+     def optimize_approximate_posterior(self,
+                                        mln: MarkovLogicNetwork,
+                                        scoring_module: NPLLScoringModule,
+                                        known_facts: List[Triple],
+                                        unknown_facts: List[Triple],
+                                        ground_rule_batches: Optional[List[GroundRuleBatch]] = None) -> EStepResult:
+         """
+         Main E-step optimization procedure
+         """
+         if not unknown_facts:
+             # No unknown facts to optimize
+             return EStepResult(
+                 approximate_posterior_probs=torch.tensor([]),
+                 fact_probabilities={},
+                 ground_rule_expectations=torch.tensor([]),
+                 entropy=torch.tensor(0.0),
+                 elbo_value=torch.tensor(0.0),
+                 convergence_info={'converged': True, 'reason': 'no_unknown_facts'},
+                 iteration_count=0
+             )
+
+         logger.debug(f"Starting E-step optimization for {len(unknown_facts)} unknown facts")
+
+         # Initialize fact probabilities using scoring module
+         current_probs = self.mean_field.initialize_fact_probabilities(unknown_facts, scoring_module)
+
+         # Run a single VI optimization; it internally iterates and computes ELBO history
+         updated_result = self.variational_inference.optimize_approximate_posterior(
+             mln, known_facts, unknown_facts, current_probs.detach()
+         )
+         current_probs = updated_result['optimized_probs'].detach()
+         elbo_history = updated_result['elbo_history']
+         converged = updated_result['converged']
+         iteration = updated_result['iterations']
+
+         # Compute final Q(U) distribution
+         final_q_dist = self.mean_field.compute_q_u_distribution(current_probs)
+
+         # Compute ground rule expectations
+         ground_rule_expectations = self._compute_ground_rule_expectations(
+             mln, current_probs, unknown_facts
+         )
+
+         # Create fact probability dictionary
+         fact_prob_dict = {fact: current_probs[i].item()
+                           for i, fact in enumerate(unknown_facts)}
+
+         # Final ELBO computation (detached to avoid gradient issues)
+         with torch.no_grad():
+             final_elbo_components = self.elbo_computer.compute_elbo(
+                 mln, known_facts, unknown_facts, current_probs.detach()
+             )
+
+         convergence_info = {
+             'converged': converged,
+             'final_change': 0.0,
+             'elbo_history': elbo_history,
+             'reason': 'converged' if converged else 'max_iterations'
+         }
+
+         result = EStepResult(
+             approximate_posterior_probs=current_probs,
+             fact_probabilities=fact_prob_dict,
+             ground_rule_expectations=ground_rule_expectations,
+             entropy=final_q_dist['total_entropy'],
+             elbo_value=final_elbo_components.elbo,
+             convergence_info=convergence_info,
+             iteration_count=iteration + 1
+         )
+
+         logger.debug(f"E-step completed: {result}")
+
+         return result
+
+     def _compute_ground_rule_expectations(self,
+                                           mln: MarkovLogicNetwork,
+                                           fact_probs: torch.Tensor,
+                                           unknown_facts: List[Triple]) -> torch.Tensor:
+         """
+         Compute expected ground rule counts under Q(U)
+         """
+         if not mln.logical_rules or len(fact_probs) == 0:
+             return torch.tensor([])
+
+         # Create fact index mapping
+         fact_to_idx = {fact: i for i, fact in enumerate(unknown_facts)}
+
+         # Compute expectations for each logical rule
+         expected_counts = torch.zeros(len(mln.logical_rules))
+
+         for rule_idx, rule in enumerate(mln.logical_rules):
+             ground_rules = mln.rule_to_ground_rules[rule.rule_id]
+             rule_expectation = 0.0
+
+             for ground_rule in ground_rules:
+                 # Compute expected satisfaction for this ground rule
+                 ground_rule_factors = []
+
+                 # Body facts
+                 for fact in ground_rule.body_facts:
+                     if fact in fact_to_idx:
+                         # Unknown fact - use probability
+                         fact_idx = fact_to_idx[fact]
+                         ground_rule_factors.append(fact_probs[fact_idx])
+                     else:
+                         # Known fact - assume probability 1
+                         ground_rule_factors.append(torch.tensor(1.0))
+
+                 # Head fact
+                 if ground_rule.head_fact in fact_to_idx:
+                     fact_idx = fact_to_idx[ground_rule.head_fact]
+                     ground_rule_factors.append(fact_probs[fact_idx])
+                 else:
+                     ground_rule_factors.append(torch.tensor(1.0))
+
+                 # Expected satisfaction is product of all factors
+                 if ground_rule_factors:
+                     ground_rule_expectation = torch.prod(torch.stack(ground_rule_factors))
+                     rule_expectation += ground_rule_expectation.item()
+
+             expected_counts[rule_idx] = rule_expectation
+
+         return expected_counts
+
+
+ class EStepRunner:
+     """
+     High-level runner for E-step computations
+     Handles batching, parallelization, and result aggregation
+     """
+
+     def __init__(self, config: NPLLConfig):
+         self.config = config
+         self.e_step_optimizer = EStepOptimizer(config)
+         self.ground_rule_sampler = GroundRuleSampler(config)
+
+     def run_e_step(self,
+                    mln: MarkovLogicNetwork,
+                    scoring_module: NPLLScoringModule,
+                    kg: KnowledgeGraph) -> EStepResult:
+         known_facts = list(kg.known_facts)
+         unknown_facts = list(kg.unknown_facts)
+
+         if not unknown_facts:
+             logger.info("No unknown facts for E-step")
+             return EStepResult(
+                 approximate_posterior_probs=torch.tensor([]),
+                 fact_probabilities={},
+                 ground_rule_expectations=torch.tensor([]),
+                 entropy=torch.tensor(0.0),
+                 elbo_value=torch.tensor(0.0),
+                 convergence_info={'converged': True, 'reason': 'no_unknown_facts'},
+                 iteration_count=0
+             )
+
+         logger.info(f"Running E-step for {len(known_facts)} known and {len(unknown_facts)} unknown facts")
+
+         # Create ground rule batches for efficient processing
+         ground_rule_batches = None
+         if len(mln.ground_rules) > self.config.batch_size:
+             ground_rule_batches = self.ground_rule_sampler.sample_ground_rules(
+                 mln.ground_rules,
+                 num_batches=max(1, len(mln.ground_rules) // self.config.batch_size),
+                 sampling_strategy="uniform"
+             )
+
+         # Run optimization
+         result = self.e_step_optimizer.optimize_approximate_posterior(
+             mln, scoring_module, known_facts, unknown_facts, ground_rule_batches
+         )
+
+         logger.info(f"E-step completed: ELBO={result.elbo_value.item():.4f}, "
+                     f"Entropy={result.entropy.item():.4f}, Iterations={result.iteration_count}")
+
+         return result
+
+     def run_e_step_with_constraints(self,
+                                     mln: MarkovLogicNetwork,
+                                     scoring_module: NPLLScoringModule,
+                                     kg: KnowledgeGraph,
+                                     fact_constraints: Optional[Dict[Triple, Tuple[float, float]]] = None) -> EStepResult:
+
+         # Standard E-step
+         result = self.run_e_step(mln, scoring_module, kg)
+
+         # Apply constraints if provided
+         if fact_constraints and len(result.approximate_posterior_probs) > 0:
+             constrained_probs = result.approximate_posterior_probs.clone()
+             unknown_facts = list(kg.unknown_facts)
+
+             for i, fact in enumerate(unknown_facts):
+                 if fact in fact_constraints:
+                     min_prob, max_prob = fact_constraints[fact]
+                     constrained_probs[i] = torch.clamp(constrained_probs[i], min_prob, max_prob)
+
+             # Update result with constrained probabilities
+             result.approximate_posterior_probs = constrained_probs
+             result.fact_probabilities = {
+                 fact: constrained_probs[i].item()
+                 for i, fact in enumerate(unknown_facts)
+             }
+
+         return result
+
+
+ def create_e_step_runner(config: NPLLConfig) -> EStepRunner:
+     """Factory function to create E-step runner"""
+     return EStepRunner(config)
+
+
+ def verify_e_step_implementation():
+     """Verify E-step implementation"""
+     from ..utils.config import default_config
+     from ..core import load_knowledge_graph_from_triples
+     from ..core.mln import create_mln_from_kg_and_rules
+     from ..core.logical_rules import Variable, Atom, RuleType, LogicalRule
+     from ..scoring import create_scoring_module
+
+     # Create test data
+     test_triples = [
+         ("A", "r1", "B"),
+         ("B", "r2", "C"),
+     ]
+
+     kg = load_knowledge_graph_from_triples(test_triples)
+
+     # Add unknown facts
+     kg.add_unknown_fact("A", "r3", "C")
+
+     # Create test rule
+     from ..core import Relation
+     r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
+     x, y, z = Variable('x'), Variable('y'), Variable('z')
+
+     test_rule = LogicalRule(
+         rule_id="test_e_step_rule",
+         body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
+         head=Atom(r3, (x, z)),
+         rule_type=RuleType.TRANSITIVITY
+     )
+
+     # Create MLN and scoring module
+     mln = create_mln_from_kg_and_rules(kg, [test_rule], default_config)
+     scoring_module = create_scoring_module(default_config, kg)
+
+     # Test E-step
+     e_step_runner = EStepRunner(default_config)
+     result = e_step_runner.run_e_step(mln, scoring_module, kg)
+
+     # Verify results
+     assert len(result.approximate_posterior_probs) == len(kg.unknown_facts), \
+         "Should have probabilities for all unknown facts"
+
+     assert torch.all(result.approximate_posterior_probs >= 0), "Probabilities should be non-negative"
+     assert torch.all(result.approximate_posterior_probs <= 1), "Probabilities should be <= 1"
+
+     assert torch.isfinite(result.elbo_value), "ELBO should be finite"
+     assert torch.isfinite(result.entropy), "Entropy should be finite"
+
+     logger.info("E-step implementation verified successfully")
+
      return True
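
For readers following the e_step.py internals above, here is a minimal, self-contained sketch of the quantities that compute_q_u_distribution, _compute_ground_rule_probabilities, and _compute_ground_rule_expectations produce. It substitutes plain torch calls for the package's safe_log and bernoulli_entropy helpers (assumed here to be natural-log based) and uses made-up marginals; it illustrates the math, not the package's API.

    import torch

    # Mean-field marginals for three unknown facts (illustrative values only)
    fact_probs = torch.tensor([0.9, 0.6, 0.2])

    # Clamp for numerical stability, as compute_q_u_distribution does
    p = torch.clamp(fact_probs, min=1e-8, max=1.0 - 1e-8)

    # Per-fact Bernoulli entropy: H(p) = -p*log(p) - (1-p)*log(1-p)
    fact_entropies = -(p * torch.log(p) + (1 - p) * torch.log(1 - p))
    total_entropy = fact_entropies.sum()  # about 1.4985 nats for these values

    # A ground rule over facts [0, 1]: under the mean-field independence
    # assumption, its probability is the product of its members' marginals
    rule_indices = [0, 1]
    rule_prob = torch.prod(p[rule_indices])               # 0.9 * 0.6 = 0.54
    rule_log_prob = torch.sum(torch.log(p[rule_indices]))

    # Expected count for a rule with two groundings, as in
    # _compute_ground_rule_expectations: a grounding whose other member is a
    # known fact contributes that fact with probability 1.0
    expected_count = rule_prob + p[0] * 1.0               # 0.54 + 0.9 = 1.44

    print(f"Total entropy: {total_entropy.item():.4f}")
    print(f"Rule probability: {rule_prob.item():.4f}")
    print(f"Expected rule count: {expected_count.item():.4f}")

The clamping in run_e_step_with_constraints acts on the same marginals: each constrained fact is pushed into its caller-supplied [min_prob, max_prob] range with torch.clamp, and the rest of the E-step result is left untouched.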