odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/core/mln.py CHANGED
@@ -1,475 +1,475 @@
1
- """
2
- Markov Logic Network (MLN) implementation for NPLL
3
- """
4
-
5
- import torch
6
- import torch.nn as nn
7
- from typing import List, Dict, Set, Tuple, Optional, Any
8
- from collections import defaultdict
9
- import logging
10
- from dataclasses import dataclass
11
-
12
- from .knowledge_graph import KnowledgeGraph, Triple
13
- from .logical_rules import LogicalRule, GroundRule
14
- from ..utils.config import NPLLConfig
15
- from ..utils.math_utils import log_sum_exp, partition_function_approximation, compute_mln_probability
16
-
17
- logger = logging.getLogger(__name__)
18
-
19
-
20
- @dataclass
21
- class MLNState:
22
-
23
- fact_assignments: Dict[Triple, bool] # Truth values for all facts
24
- known_facts: Set[Triple] # Known facts F
25
- unknown_facts: Set[Triple] # Unknown facts U
26
-
27
- def __post_init__(self):
28
- """Validate MLN state"""
29
- all_facts = set(self.fact_assignments.keys())
30
- expected_facts = self.known_facts | self.unknown_facts
31
-
32
- if all_facts != expected_facts:
33
- missing = expected_facts - all_facts
34
- extra = all_facts - expected_facts
35
- logger.warning(f"MLN state inconsistency. Missing: {len(missing)}, Extra: {len(extra)}")
36
-
37
- def evaluate_ground_rule(self, ground_rule: GroundRule) -> bool:
38
-
39
- # Check if all body facts are true
40
- body_satisfied = all(
41
- self.fact_assignments.get(fact, False)
42
- for fact in ground_rule.body_facts
43
- )
44
-
45
- # If body is false, rule is vacuously true
46
- if not body_satisfied:
47
- return True
48
-
49
- # If body is true, check if head is true
50
- head_satisfied = self.fact_assignments.get(ground_rule.head_fact, False)
51
- return head_satisfied
52
-
53
- def count_satisfied_ground_rules(self, ground_rules: List[GroundRule]) -> int:
54
- """Count number of ground rules satisfied in this state"""
55
- return sum(1 for gr in ground_rules if self.evaluate_ground_rule(gr))
56
-
57
-
58
- class MarkovLogicNetwork(nn.Module):
59
-
60
- def __init__(self, config: NPLLConfig):
61
- super().__init__()
62
- self.config = config
63
-
64
- # Core MLN components
65
- self.knowledge_graph: Optional[KnowledgeGraph] = None
66
- self.logical_rules: List[LogicalRule] = []
67
- self.ground_rules: List[GroundRule] = []
68
-
69
- # Rule weights ω (learnable parameters)
70
- self.rule_weights: Optional[nn.Parameter] = None
71
-
72
- # Ground rule organization
73
- self.rule_to_ground_rules: Dict[str, List[GroundRule]] = defaultdict(list)
74
- self.ground_rule_facts: Set[Triple] = set()
75
- # Inverted index for fast lookup: fact -> ground rules containing it
76
- self.fact_to_groundrules: Dict[Triple, List[GroundRule]] = defaultdict(list)
77
-
78
- # Caching for efficiency
79
- self._partition_function_cache: Dict[str, torch.Tensor] = {}
80
- self._ground_rule_counts_cache: Optional[torch.Tensor] = None
81
-
82
- def add_knowledge_graph(self, kg: KnowledgeGraph):
83
- """Add knowledge graph to MLN"""
84
- self.knowledge_graph = kg
85
- logger.info(f"Added knowledge graph with {len(kg.known_facts)} known facts")
86
-
87
- def add_logical_rules(self, rules: List[LogicalRule]):
88
- self.logical_rules.extend(rules)
89
-
90
- # Initialize or expand rule weights
91
- if self.rule_weights is None:
92
- # Initialize rule weights to small random values
93
- initial_weights = torch.randn(len(rules)) * 0.1
94
- self.rule_weights = nn.Parameter(initial_weights, requires_grad=True)
95
- else:
96
- # Expand existing weights
97
- old_weights = self.rule_weights.data
98
- new_weights = torch.randn(len(rules)) * 0.1
99
- expanded_weights = torch.cat([old_weights, new_weights])
100
- self.rule_weights = nn.Parameter(expanded_weights, requires_grad=True)
101
-
102
- # Generate ground rules for new logical rules
103
- if self.knowledge_graph is not None:
104
- self._generate_ground_rules(rules)
105
-
106
- logger.info(f"Added {len(rules)} logical rules. Total: {len(self.logical_rules)}")
107
-
108
- def _generate_ground_rules(self, rules: List[LogicalRule]):
109
- """Generate ground rules from logical rules using knowledge graph"""
110
- new_ground_rules = []
111
-
112
- for rule in rules:
113
- # Generate ground rules for this logical rule
114
- ground_rules = rule.generate_ground_rules(
115
- self.knowledge_graph,
116
- max_groundings=self.config.max_ground_rules
117
- )
118
-
119
- # Add to collections
120
- new_ground_rules.extend(ground_rules)
121
- self.rule_to_ground_rules[rule.rule_id].extend(ground_rules)
122
-
123
- # Collect all facts involved in ground rules
124
- for gr in ground_rules:
125
- self.ground_rule_facts.update(gr.get_all_facts())
126
- # Build inverted index for each fact
127
- for f in gr.get_all_facts():
128
- self.fact_to_groundrules[f].append(gr)
129
-
130
- self.ground_rules.extend(new_ground_rules)
131
- logger.info(f"Generated {len(new_ground_rules)} ground rules. Total: {len(self.ground_rules)}")
132
-
133
- def compute_ground_rule_counts(self, fact_assignments: Dict[Triple, bool]) -> torch.Tensor:
134
- """
135
- Compute N(F,U) - number of satisfied ground rules for each logical rule
136
- """
137
- if not self.logical_rules:
138
- return torch.tensor([])
139
-
140
- rule_counts = torch.zeros(len(self.logical_rules))
141
-
142
- for rule_idx, rule in enumerate(self.logical_rules):
143
- ground_rules = self.rule_to_ground_rules[rule.rule_id]
144
- satisfied_count = 0
145
-
146
- for ground_rule in ground_rules:
147
- # Check if this ground rule is satisfied
148
- all_facts = ground_rule.get_all_facts()
149
-
150
- # Check if all required facts are true (for body) and conclusion follows
151
- body_satisfied = all(
152
- fact_assignments.get(fact, False)
153
- for fact in ground_rule.body_facts
154
- )
155
-
156
- if body_satisfied:
157
- # If body is true, rule is satisfied if head is also true
158
- head_satisfied = fact_assignments.get(ground_rule.head_fact, False)
159
- if head_satisfied:
160
- satisfied_count += 1
161
- else:
162
- # If body is false, rule is vacuously satisfied
163
- satisfied_count += 1
164
-
165
- rule_counts[rule_idx] = satisfied_count
166
-
167
- return rule_counts
168
-
169
- def compute_partition_function(self, sample_states: Optional[List[Dict[Triple, bool]]] = None,
170
- use_approximation: bool = True) -> torch.Tensor:
171
- """
172
- Compute MLN partition function Z(ω) from Equation 2
173
- """
174
- if self.rule_weights is None or len(self.logical_rules) == 0:
175
- return torch.tensor(0.0)
176
-
177
- # Use caching if available
178
- cache_key = str(self.rule_weights.data.tolist())
179
- if cache_key in self._partition_function_cache:
180
- return self._partition_function_cache[cache_key]
181
-
182
- if use_approximation or sample_states is not None:
183
- # Sampling-based approximation
184
- if sample_states is None:
185
- sample_states = self._generate_sample_states(num_samples=1000)
186
-
187
- # Compute counts for each sample state
188
- all_counts = []
189
- for state_assignment in sample_states:
190
- counts = self.compute_ground_rule_counts(state_assignment)
191
- all_counts.append(counts)
192
-
193
- if all_counts:
194
- counts_tensor = torch.stack(all_counts) # [num_samples, num_rules]
195
- log_partition = partition_function_approximation(
196
- self.rule_weights, counts_tensor, use_log_domain=True
197
- )
198
- else:
199
- log_partition = torch.tensor(0.0)
200
- else:
201
- # Exact computation (intractable for large graphs)
202
- logger.warning("Exact partition function computation is intractable for large graphs")
203
- log_partition = self._compute_exact_partition_function()
204
-
205
- # Cache result
206
- self._partition_function_cache[cache_key] = log_partition
207
-
208
- return log_partition
209
-
210
- def compute_joint_probability(self, fact_assignments: Dict[Triple, bool],
211
- log_partition: Optional[torch.Tensor] = None,
212
- detach_weights: bool = False) -> torch.Tensor:
213
- """
214
- Compute joint probability P(F,U|ω) from Equation 1
215
-
216
- """
217
- if self.rule_weights is None:
218
- return torch.tensor(0.0)
219
-
220
- # Compute ground rule counts N(F,U)
221
- counts = self.compute_ground_rule_counts(fact_assignments)
222
-
223
- # Compute log partition function if not provided
224
- if log_partition is None:
225
- log_partition = self.compute_partition_function()
226
-
227
- # Compute log probability using utility function
228
- weights_to_use = self.rule_weights.detach() if detach_weights else self.rule_weights
229
- log_prob = compute_mln_probability(
230
- weights_to_use, counts.unsqueeze(0), log_partition
231
- )
232
-
233
- return log_prob.squeeze(0)
234
-
235
- def _generate_sample_states(self, num_samples: int = 1000) -> List[Dict[Triple, bool]]:
236
- """
237
- Generate sample states for partition function approximation
238
- """
239
- sample_states = []
240
-
241
- if self.knowledge_graph is None:
242
- return sample_states
243
-
244
- # Get all facts that appear in ground rules
245
- all_facts = list(self.ground_rule_facts)
246
-
247
- if not all_facts:
248
- return sample_states
249
-
250
- # Generate random assignments
251
- for _ in range(num_samples):
252
- # Start with known facts as true
253
- assignment = {}
254
-
255
- # Set known facts to true
256
- for fact in self.knowledge_graph.known_facts:
257
- assignment[fact] = True
258
-
259
- # Randomly assign unknown facts
260
- unknown_facts_in_rules = [f for f in all_facts if f not in assignment]
261
- for fact in unknown_facts_in_rules:
262
- # Assign random truth value (could be made smarter)
263
- assignment[fact] = torch.rand(1).item() > 0.5
264
-
265
- sample_states.append(assignment)
266
-
267
- return sample_states
268
-
269
- def _compute_exact_partition_function(self) -> torch.Tensor:
270
- """
271
- Compute exact partition function
272
- """
273
- if not self.ground_rule_facts:
274
- return torch.tensor(0.0)
275
-
276
- all_facts = list(self.ground_rule_facts)
277
- num_facts = len(all_facts)
278
-
279
- if num_facts > 20: # Arbitrary limit to prevent memory explosion
280
- logger.error(f"Too many facts ({num_facts}) for exact partition function computation")
281
- return self.compute_partition_function(use_approximation=True)
282
-
283
- # Enumerate all possible truth assignments
284
- total_log_prob = []
285
-
286
- for i in range(2 ** num_facts):
287
- # Generate truth assignment from binary representation
288
- assignment = {}
289
- for j, fact in enumerate(all_facts):
290
- assignment[fact] = bool((i >> j) & 1)
291
-
292
- # Compute counts for this assignment
293
- counts = self.compute_ground_rule_counts(assignment)
294
-
295
- # Compute potential
296
- log_potential = torch.sum(self.rule_weights * counts)
297
- total_log_prob.append(log_potential)
298
-
299
- # Compute log-sum-exp
300
- if total_log_prob:
301
- log_partition = log_sum_exp(torch.stack(total_log_prob))
302
- else:
303
- log_partition = torch.tensor(0.0)
304
-
305
- return log_partition
306
-
307
- def get_rule_statistics(self) -> Dict[str, Any]:
308
- """Get statistics about the MLN"""
309
- stats = {
310
- 'num_logical_rules': len(self.logical_rules),
311
- 'num_ground_rules': len(self.ground_rules),
312
- 'num_facts_in_ground_rules': len(self.ground_rule_facts),
313
- 'rule_weights': self.rule_weights.data.tolist() if self.rule_weights is not None else []
314
- }
315
-
316
- # Per-rule statistics
317
- rule_stats = []
318
- for i, rule in enumerate(self.logical_rules):
319
- ground_rules = self.rule_to_ground_rules[rule.rule_id]
320
- rule_stat = {
321
- 'rule_id': rule.rule_id,
322
- 'rule_type': rule.rule_type.value,
323
- 'num_ground_rules': len(ground_rules),
324
- 'weight': self.rule_weights[i].item() if self.rule_weights is not None else 0.0,
325
- 'learned_confidence': torch.sigmoid(self.rule_weights[i]).item() if self.rule_weights is not None else None,
326
- 'support': rule.support
327
- }
328
- rule_stats.append(rule_stat)
329
-
330
- stats['rule_details'] = rule_stats
331
-
332
- return stats
333
-
334
- def forward(self, fact_assignments_batch: List[Dict[Triple, bool]]) -> torch.Tensor:
335
- """
336
- Forward pass for batch of fact assignments
337
- """
338
- if not fact_assignments_batch:
339
- return torch.tensor([])
340
-
341
- # Compute partition function once
342
- log_partition = self.compute_partition_function()
343
-
344
- # Compute probabilities for each assignment
345
- log_probs = []
346
- for assignment in fact_assignments_batch:
347
- log_prob = self.compute_joint_probability(assignment, log_partition)
348
- log_probs.append(log_prob)
349
-
350
- return torch.stack(log_probs) if log_probs else torch.tensor([])
351
-
352
- def sample_from_distribution(self, num_samples: int = 100) -> List[Dict[Triple, bool]]:
353
- """
354
- Sample fact assignments from MLN distribution using Gibbs sampling
355
- """
356
- if not self.ground_rule_facts:
357
- return []
358
-
359
- samples = []
360
- all_facts = list(self.ground_rule_facts)
361
-
362
- # Initialize with random assignment
363
- current_assignment = {fact: torch.rand(1).item() > 0.5 for fact in all_facts}
364
-
365
- # Set known facts to true (they don't change)
366
- if self.knowledge_graph:
367
- for fact in self.knowledge_graph.known_facts:
368
- current_assignment[fact] = True
369
-
370
- # Gibbs sampling
371
- for _ in range(num_samples):
372
- # Sample each unknown fact given others
373
- for fact in all_facts:
374
- if self.knowledge_graph and fact in self.knowledge_graph.known_facts:
375
- continue # Skip known facts
376
-
377
- # Compute conditional probability P(fact=True | others)
378
- prob_true = self._compute_conditional_probability(fact, current_assignment)
379
-
380
- # Sample from Bernoulli distribution
381
- current_assignment[fact] = torch.rand(1).item() < prob_true
382
-
383
- # Store sample
384
- samples.append(current_assignment.copy())
385
-
386
- return samples
387
-
388
- def _compute_conditional_probability(self, target_fact: Triple,
389
- current_assignment: Dict[Triple, bool]) -> float:
390
- """
391
- Compute P(target_fact=True | other_facts) using local MLN structure
392
- """
393
- # Create two assignments: one with target_fact=True, one with False
394
- assignment_true = current_assignment.copy()
395
- assignment_false = current_assignment.copy()
396
- assignment_true[target_fact] = True
397
- assignment_false[target_fact] = False
398
-
399
- # Compute unnormalized probabilities
400
- log_prob_true = self.compute_joint_probability(assignment_true)
401
- log_prob_false = self.compute_joint_probability(assignment_false)
402
-
403
- # Normalize using log-sum-exp
404
- log_probs = torch.stack([log_prob_false, log_prob_true])
405
- normalized_probs = torch.softmax(log_probs, dim=0)
406
-
407
- return normalized_probs[1].item() # Return P(target_fact=True)
408
-
409
-
410
- def create_mln_from_kg_and_rules(kg: KnowledgeGraph, rules: List[LogicalRule],
411
- config: NPLLConfig) -> MarkovLogicNetwork:
412
- """
413
- Factory function to create MLN from knowledge graph and logical rules
414
- """
415
- mln = MarkovLogicNetwork(config)
416
- mln.add_knowledge_graph(kg)
417
- mln.add_logical_rules(rules)
418
-
419
- logger.info(f"Created MLN with {len(rules)} rules and {len(mln.ground_rules)} ground rules")
420
-
421
- return mln
422
-
423
-
424
- def verify_mln_implementation():
425
- """Verify MLN implementation with small test case"""
426
- from ..utils.config import default_config
427
- from .knowledge_graph import Entity, Relation, load_knowledge_graph_from_triples
428
- from .logical_rules import Variable, Atom, RuleType
429
-
430
- # Create test knowledge graph
431
- test_triples = [
432
- ("Tom", "plays", "basketball"),
433
- ("Tom", "friend", "John"),
434
- ("John", "plays", "soccer")
435
- ]
436
-
437
- kg = load_knowledge_graph_from_triples(test_triples, "TestKG")
438
-
439
- # Create test rule: plays(x, y) ∧ friend(x, z) ⇒ plays(z, y)
440
- plays_rel = Relation("plays")
441
- friend_rel = Relation("friend")
442
-
443
- x, y, z = Variable('x'), Variable('y'), Variable('z')
444
-
445
- body_atoms = [
446
- Atom(plays_rel, (x, y)),
447
- Atom(friend_rel, (x, z))
448
- ]
449
- head_atom = Atom(plays_rel, (z, y))
450
-
451
- test_rule = LogicalRule(
452
- rule_id="test_transitivity",
453
- body=body_atoms,
454
- head=head_atom,
455
- rule_type=RuleType.TRANSITIVITY,
456
- confidence=0.8
457
- )
458
-
459
- # Create MLN
460
- mln = create_mln_from_kg_and_rules(kg, [test_rule], default_config)
461
-
462
- # Verify MLN properties
463
- assert len(mln.logical_rules) == 1, "Should have 1 logical rule"
464
- assert len(mln.ground_rules) > 0, "Should have generated ground rules"
465
- assert mln.rule_weights is not None, "Should have initialized rule weights"
466
-
467
- # Test probability computation
468
- test_assignment = {fact: True for fact in kg.known_facts}
469
- log_prob = mln.compute_joint_probability(test_assignment)
470
-
471
- assert torch.isfinite(log_prob), "Joint probability should be finite"
472
-
473
- logger.info("MLN implementation verified successfully")
474
-
1
+ """
2
+ Markov Logic Network (MLN) implementation for NPLL
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from typing import List, Dict, Set, Tuple, Optional, Any
8
+ from collections import defaultdict
9
+ import logging
10
+ from dataclasses import dataclass
11
+
12
+ from .knowledge_graph import KnowledgeGraph, Triple
13
+ from .logical_rules import LogicalRule, GroundRule
14
+ from ..utils.config import NPLLConfig
15
+ from ..utils.math_utils import log_sum_exp, partition_function_approximation, compute_mln_probability
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ @dataclass
21
+ class MLNState:
22
+
23
+ fact_assignments: Dict[Triple, bool] # Truth values for all facts
24
+ known_facts: Set[Triple] # Known facts F
25
+ unknown_facts: Set[Triple] # Unknown facts U
26
+
27
+ def __post_init__(self):
28
+ """Validate MLN state"""
29
+ all_facts = set(self.fact_assignments.keys())
30
+ expected_facts = self.known_facts | self.unknown_facts
31
+
32
+ if all_facts != expected_facts:
33
+ missing = expected_facts - all_facts
34
+ extra = all_facts - expected_facts
35
+ logger.warning(f"MLN state inconsistency. Missing: {len(missing)}, Extra: {len(extra)}")
36
+
37
+ def evaluate_ground_rule(self, ground_rule: GroundRule) -> bool:
38
+
39
+ # Check if all body facts are true
40
+ body_satisfied = all(
41
+ self.fact_assignments.get(fact, False)
42
+ for fact in ground_rule.body_facts
43
+ )
44
+
45
+ # If body is false, rule is vacuously true
46
+ if not body_satisfied:
47
+ return True
48
+
49
+ # If body is true, check if head is true
50
+ head_satisfied = self.fact_assignments.get(ground_rule.head_fact, False)
51
+ return head_satisfied
52
+
53
+ def count_satisfied_ground_rules(self, ground_rules: List[GroundRule]) -> int:
54
+ """Count number of ground rules satisfied in this state"""
55
+ return sum(1 for gr in ground_rules if self.evaluate_ground_rule(gr))
56
+
57
+
58
+ class MarkovLogicNetwork(nn.Module):
59
+
60
+ def __init__(self, config: NPLLConfig):
61
+ super().__init__()
62
+ self.config = config
63
+
64
+ # Core MLN components
65
+ self.knowledge_graph: Optional[KnowledgeGraph] = None
66
+ self.logical_rules: List[LogicalRule] = []
67
+ self.ground_rules: List[GroundRule] = []
68
+
69
+ # Rule weights ω (learnable parameters)
70
+ self.rule_weights: Optional[nn.Parameter] = None
71
+
72
+ # Ground rule organization
73
+ self.rule_to_ground_rules: Dict[str, List[GroundRule]] = defaultdict(list)
74
+ self.ground_rule_facts: Set[Triple] = set()
75
+ # Inverted index for fast lookup: fact -> ground rules containing it
76
+ self.fact_to_groundrules: Dict[Triple, List[GroundRule]] = defaultdict(list)
77
+
78
+ # Caching for efficiency
79
+ self._partition_function_cache: Dict[str, torch.Tensor] = {}
80
+ self._ground_rule_counts_cache: Optional[torch.Tensor] = None
81
+
82
+ def add_knowledge_graph(self, kg: KnowledgeGraph):
83
+ """Add knowledge graph to MLN"""
84
+ self.knowledge_graph = kg
85
+ logger.info(f"Added knowledge graph with {len(kg.known_facts)} known facts")
86
+
87
+ def add_logical_rules(self, rules: List[LogicalRule]):
88
+ self.logical_rules.extend(rules)
89
+
90
+ # Initialize or expand rule weights
91
+ if self.rule_weights is None:
92
+ # Initialize rule weights to small random values
93
+ initial_weights = torch.randn(len(rules)) * 0.1
94
+ self.rule_weights = nn.Parameter(initial_weights, requires_grad=True)
95
+ else:
96
+ # Expand existing weights
97
+ old_weights = self.rule_weights.data
98
+ new_weights = torch.randn(len(rules)) * 0.1
99
+ expanded_weights = torch.cat([old_weights, new_weights])
100
+ self.rule_weights = nn.Parameter(expanded_weights, requires_grad=True)
101
+
102
+ # Generate ground rules for new logical rules
103
+ if self.knowledge_graph is not None:
104
+ self._generate_ground_rules(rules)
105
+
106
+ logger.info(f"Added {len(rules)} logical rules. Total: {len(self.logical_rules)}")
107
+
108
+ def _generate_ground_rules(self, rules: List[LogicalRule]):
109
+ """Generate ground rules from logical rules using knowledge graph"""
110
+ new_ground_rules = []
111
+
112
+ for rule in rules:
113
+ # Generate ground rules for this logical rule
114
+ ground_rules = rule.generate_ground_rules(
115
+ self.knowledge_graph,
116
+ max_groundings=self.config.max_ground_rules
117
+ )
118
+
119
+ # Add to collections
120
+ new_ground_rules.extend(ground_rules)
121
+ self.rule_to_ground_rules[rule.rule_id].extend(ground_rules)
122
+
123
+ # Collect all facts involved in ground rules
124
+ for gr in ground_rules:
125
+ self.ground_rule_facts.update(gr.get_all_facts())
126
+ # Build inverted index for each fact
127
+ for f in gr.get_all_facts():
128
+ self.fact_to_groundrules[f].append(gr)
129
+
130
+ self.ground_rules.extend(new_ground_rules)
131
+ logger.info(f"Generated {len(new_ground_rules)} ground rules. Total: {len(self.ground_rules)}")
132
+
133
+ def compute_ground_rule_counts(self, fact_assignments: Dict[Triple, bool]) -> torch.Tensor:
134
+ """
135
+ Compute N(F,U) - number of satisfied ground rules for each logical rule
136
+ """
137
+ if not self.logical_rules:
138
+ return torch.tensor([])
139
+
140
+ rule_counts = torch.zeros(len(self.logical_rules))
141
+
142
+ for rule_idx, rule in enumerate(self.logical_rules):
143
+ ground_rules = self.rule_to_ground_rules[rule.rule_id]
144
+ satisfied_count = 0
145
+
146
+ for ground_rule in ground_rules:
147
+ # Check if this ground rule is satisfied
148
+ all_facts = ground_rule.get_all_facts()
149
+
150
+ # Check if all required facts are true (for body) and conclusion follows
151
+ body_satisfied = all(
152
+ fact_assignments.get(fact, False)
153
+ for fact in ground_rule.body_facts
154
+ )
155
+
156
+ if body_satisfied:
157
+ # If body is true, rule is satisfied if head is also true
158
+ head_satisfied = fact_assignments.get(ground_rule.head_fact, False)
159
+ if head_satisfied:
160
+ satisfied_count += 1
161
+ else:
162
+ # If body is false, rule is vacuously satisfied
163
+ satisfied_count += 1
164
+
165
+ rule_counts[rule_idx] = satisfied_count
166
+
167
+ return rule_counts
168
+
169
+ def compute_partition_function(self, sample_states: Optional[List[Dict[Triple, bool]]] = None,
170
+ use_approximation: bool = True) -> torch.Tensor:
171
+ """
172
+ Compute MLN partition function Z(ω) from Equation 2
173
+ """
174
+ if self.rule_weights is None or len(self.logical_rules) == 0:
175
+ return torch.tensor(0.0)
176
+
177
+ # Use caching if available
178
+ cache_key = str(self.rule_weights.data.tolist())
179
+ if cache_key in self._partition_function_cache:
180
+ return self._partition_function_cache[cache_key]
181
+
182
+ if use_approximation or sample_states is not None:
183
+ # Sampling-based approximation
184
+ if sample_states is None:
185
+ sample_states = self._generate_sample_states(num_samples=1000)
186
+
187
+ # Compute counts for each sample state
188
+ all_counts = []
189
+ for state_assignment in sample_states:
190
+ counts = self.compute_ground_rule_counts(state_assignment)
191
+ all_counts.append(counts)
192
+
193
+ if all_counts:
194
+ counts_tensor = torch.stack(all_counts) # [num_samples, num_rules]
195
+ log_partition = partition_function_approximation(
196
+ self.rule_weights, counts_tensor, use_log_domain=True
197
+ )
198
+ else:
199
+ log_partition = torch.tensor(0.0)
200
+ else:
201
+ # Exact computation (intractable for large graphs)
202
+ logger.warning("Exact partition function computation is intractable for large graphs")
203
+ log_partition = self._compute_exact_partition_function()
204
+
205
+ # Cache result
206
+ self._partition_function_cache[cache_key] = log_partition
207
+
208
+ return log_partition
209
+
210
+ def compute_joint_probability(self, fact_assignments: Dict[Triple, bool],
211
+ log_partition: Optional[torch.Tensor] = None,
212
+ detach_weights: bool = False) -> torch.Tensor:
213
+ """
214
+ Compute joint probability P(F,U|ω) from Equation 1
215
+
216
+ """
217
+ if self.rule_weights is None:
218
+ return torch.tensor(0.0)
219
+
220
+ # Compute ground rule counts N(F,U)
221
+ counts = self.compute_ground_rule_counts(fact_assignments)
222
+
223
+ # Compute log partition function if not provided
224
+ if log_partition is None:
225
+ log_partition = self.compute_partition_function()
226
+
227
+ # Compute log probability using utility function
228
+ weights_to_use = self.rule_weights.detach() if detach_weights else self.rule_weights
229
+ log_prob = compute_mln_probability(
230
+ weights_to_use, counts.unsqueeze(0), log_partition
231
+ )
232
+
233
+ return log_prob.squeeze(0)
234
+
235
+ def _generate_sample_states(self, num_samples: int = 1000) -> List[Dict[Triple, bool]]:
236
+ """
237
+ Generate sample states for partition function approximation
238
+ """
239
+ sample_states = []
240
+
241
+ if self.knowledge_graph is None:
242
+ return sample_states
243
+
244
+ # Get all facts that appear in ground rules
245
+ all_facts = list(self.ground_rule_facts)
246
+
247
+ if not all_facts:
248
+ return sample_states
249
+
250
+ # Generate random assignments
251
+ for _ in range(num_samples):
252
+ # Start with known facts as true
253
+ assignment = {}
254
+
255
+ # Set known facts to true
256
+ for fact in self.knowledge_graph.known_facts:
257
+ assignment[fact] = True
258
+
259
+ # Randomly assign unknown facts
260
+ unknown_facts_in_rules = [f for f in all_facts if f not in assignment]
261
+ for fact in unknown_facts_in_rules:
262
+ # Assign random truth value (could be made smarter)
263
+ assignment[fact] = torch.rand(1).item() > 0.5
264
+
265
+ sample_states.append(assignment)
266
+
267
+ return sample_states
268
+
269
+ def _compute_exact_partition_function(self) -> torch.Tensor:
270
+ """
271
+ Compute exact partition function
272
+ """
273
+ if not self.ground_rule_facts:
274
+ return torch.tensor(0.0)
275
+
276
+ all_facts = list(self.ground_rule_facts)
277
+ num_facts = len(all_facts)
278
+
279
+ if num_facts > 20: # Arbitrary limit to prevent memory explosion
280
+ logger.error(f"Too many facts ({num_facts}) for exact partition function computation")
281
+ return self.compute_partition_function(use_approximation=True)
282
+
283
+ # Enumerate all possible truth assignments
284
+ total_log_prob = []
285
+
286
+ for i in range(2 ** num_facts):
287
+ # Generate truth assignment from binary representation
288
+ assignment = {}
289
+ for j, fact in enumerate(all_facts):
290
+ assignment[fact] = bool((i >> j) & 1)
291
+
292
+ # Compute counts for this assignment
293
+ counts = self.compute_ground_rule_counts(assignment)
294
+
295
+ # Compute potential
296
+ log_potential = torch.sum(self.rule_weights * counts)
297
+ total_log_prob.append(log_potential)
298
+
299
+ # Compute log-sum-exp
300
+ if total_log_prob:
301
+ log_partition = log_sum_exp(torch.stack(total_log_prob))
302
+ else:
303
+ log_partition = torch.tensor(0.0)
304
+
305
+ return log_partition
306
+
307
def get_rule_statistics(self) -> Dict[str, Any]:
    """Summarize the MLN: rule counts, weight vector, and per-rule details."""
    weights = self.rule_weights
    have_weights = weights is not None

    # One detail record per logical rule, aligned with the weight vector index.
    details = []
    for idx, rule in enumerate(self.logical_rules):
        grounded = self.rule_to_ground_rules[rule.rule_id]
        details.append({
            'rule_id': rule.rule_id,
            'rule_type': rule.rule_type.value,
            'num_ground_rules': len(grounded),
            'weight': weights[idx].item() if have_weights else 0.0,
            # sigmoid maps the unbounded weight to a (0, 1) confidence score
            'learned_confidence': torch.sigmoid(weights[idx]).item() if have_weights else None,
            'support': rule.support,
        })

    return {
        'num_logical_rules': len(self.logical_rules),
        'num_ground_rules': len(self.ground_rules),
        'num_facts_in_ground_rules': len(self.ground_rule_facts),
        'rule_weights': weights.data.tolist() if have_weights else [],
        'rule_details': details,
    }
333
+
334
def forward(self, fact_assignments_batch: "List[Dict[Triple, bool]]") -> torch.Tensor:
    """Log-probability of each assignment in the batch under the MLN.

    Returns an empty tensor for an empty batch; otherwise a 1-D tensor of
    per-assignment log-probabilities, all normalized by a single partition
    function evaluation shared across the batch.
    """
    if not fact_assignments_batch:
        return torch.tensor([])

    # The normalization constant does not depend on the assignment, so
    # compute it once for the whole batch.
    log_z = self.compute_partition_function()

    scores = [
        self.compute_joint_probability(assignment, log_z)
        for assignment in fact_assignments_batch
    ]
    return torch.stack(scores)
351
+
352
def sample_from_distribution(self, num_samples: int = 100) -> "List[Dict[Triple, bool]]":
    """Draw fact assignments from the MLN distribution via Gibbs sampling.

    Facts observed in the knowledge graph are clamped to True for the whole
    chain; every other fact is resampled from its conditional distribution
    once per sweep, and the state after each sweep is recorded as a sample.
    """
    facts = list(self.ground_rule_facts)
    if not facts:
        return []

    # Random initial state: one Bernoulli(0.5) draw per fact.
    state = {f: torch.rand(1).item() > 0.5 for f in facts}

    kg = self.knowledge_graph
    if kg:
        # Observed facts are evidence — fixed to True throughout.
        for f in kg.known_facts:
            state[f] = True

    drawn = []
    for _ in range(num_samples):
        for f in facts:
            if kg and f in kg.known_facts:
                continue  # never resample evidence
            # P(f = True | rest of the current state)
            p_true = self._compute_conditional_probability(f, state)
            state[f] = torch.rand(1).item() < p_true
        # Snapshot the state after a full sweep over all unknown facts.
        drawn.append(state.copy())

    return drawn
387
+
388
+ def _compute_conditional_probability(self, target_fact: Triple,
389
+ current_assignment: Dict[Triple, bool]) -> float:
390
+ """
391
+ Compute P(target_fact=True | other_facts) using local MLN structure
392
+ """
393
+ # Create two assignments: one with target_fact=True, one with False
394
+ assignment_true = current_assignment.copy()
395
+ assignment_false = current_assignment.copy()
396
+ assignment_true[target_fact] = True
397
+ assignment_false[target_fact] = False
398
+
399
+ # Compute unnormalized probabilities
400
+ log_prob_true = self.compute_joint_probability(assignment_true)
401
+ log_prob_false = self.compute_joint_probability(assignment_false)
402
+
403
+ # Normalize using log-sum-exp
404
+ log_probs = torch.stack([log_prob_false, log_prob_true])
405
+ normalized_probs = torch.softmax(log_probs, dim=0)
406
+
407
+ return normalized_probs[1].item() # Return P(target_fact=True)
408
+
409
+
410
def create_mln_from_kg_and_rules(kg: KnowledgeGraph, rules: List[LogicalRule],
                                 config: NPLLConfig) -> MarkovLogicNetwork:
    """Build a ready-to-use MLN: attach the KG, then ground the given rules.

    Convenience wrapper around MarkovLogicNetwork construction that wires in
    the knowledge graph before rule grounding, so ground rules are generated
    against the graph's entities.
    """
    network = MarkovLogicNetwork(config)
    network.add_knowledge_graph(kg)
    network.add_logical_rules(rules)

    logger.info(f"Created MLN with {len(rules)} rules and {len(network.ground_rules)} ground rules")

    return network
422
+
423
+
424
def verify_mln_implementation():
    """Smoke-test the MLN on a tiny hand-built KG and one transitivity rule."""
    from ..utils.config import default_config
    from .knowledge_graph import Entity, Relation, load_knowledge_graph_from_triples
    from .logical_rules import Variable, Atom, RuleType

    # Toy graph: two players plus one friendship edge.
    triples = [
        ("Tom", "plays", "basketball"),
        ("Tom", "friend", "John"),
        ("John", "plays", "soccer")
    ]
    kg = load_knowledge_graph_from_triples(triples, "TestKG")

    # Rule under test: plays(x, y) ∧ friend(x, z) ⇒ plays(z, y)
    plays = Relation("plays")
    friend = Relation("friend")
    var_x, var_y, var_z = Variable('x'), Variable('y'), Variable('z')

    rule = LogicalRule(
        rule_id="test_transitivity",
        body=[Atom(plays, (var_x, var_y)), Atom(friend, (var_x, var_z))],
        head=Atom(plays, (var_z, var_y)),
        rule_type=RuleType.TRANSITIVITY,
        confidence=0.8
    )

    mln = create_mln_from_kg_and_rules(kg, [rule], default_config)

    # Structural sanity checks on the constructed network.
    assert len(mln.logical_rules) == 1, "Should have 1 logical rule"
    assert len(mln.ground_rules) > 0, "Should have generated ground rules"
    assert mln.rule_weights is not None, "Should have initialized rule weights"

    # The joint log-probability of the observed facts must be a finite number.
    everything_true = {fact: True for fact in kg.known_facts}
    log_prob = mln.compute_joint_probability(everything_true)
    assert torch.isfinite(log_prob), "Joint probability should be finite"

    logger.info("MLN implementation verified successfully")

    return True