odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/core/mln.py
CHANGED
|
@@ -1,475 +1,475 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Markov Logic Network (MLN) implementation for NPLL
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
|
-
import torch.nn as nn
|
|
7
|
-
from typing import List, Dict, Set, Tuple, Optional, Any
|
|
8
|
-
from collections import defaultdict
|
|
9
|
-
import logging
|
|
10
|
-
from dataclasses import dataclass
|
|
11
|
-
|
|
12
|
-
from .knowledge_graph import KnowledgeGraph, Triple
|
|
13
|
-
from .logical_rules import LogicalRule, GroundRule
|
|
14
|
-
from ..utils.config import NPLLConfig
|
|
15
|
-
from ..utils.math_utils import log_sum_exp, partition_function_approximation, compute_mln_probability
|
|
16
|
-
|
|
17
|
-
logger = logging.getLogger(__name__)
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@dataclass
class MLNState:
    """A complete truth assignment over the facts tracked by the MLN.

    Holds one Boolean value per fact, partitioned into the known set F and
    the unknown set U, and can evaluate ground rules under that assignment.
    """

    fact_assignments: Dict[Triple, bool]  # Truth values for all facts
    known_facts: Set[Triple]  # Known facts F
    unknown_facts: Set[Triple]  # Unknown facts U

    def __post_init__(self):
        """Warn (not raise) when the assignment keys disagree with F ∪ U."""
        expected_facts = self.known_facts | self.unknown_facts
        all_facts = set(self.fact_assignments.keys())
        if all_facts != expected_facts:
            missing = expected_facts - all_facts
            extra = all_facts - expected_facts
            logger.warning(f"MLN state inconsistency. Missing: {len(missing)}, Extra: {len(extra)}")

    def evaluate_ground_rule(self, ground_rule: GroundRule) -> bool:
        """Return True iff this ground rule holds under the current assignment.

        The rule is an implication body ⇒ head: it is vacuously true when any
        body fact is false, and otherwise true exactly when the head is true.
        Facts absent from the assignment are treated as false.
        """
        truth = self.fact_assignments
        for fact in ground_rule.body_facts:
            if not truth.get(fact, False):
                # Unsatisfied body: the implication holds trivially.
                return True
        # Body holds, so satisfaction reduces to the head's truth value.
        return truth.get(ground_rule.head_fact, False)

    def count_satisfied_ground_rules(self, ground_rules: List[GroundRule]) -> int:
        """Count number of ground rules satisfied in this state"""
        return sum(self.evaluate_ground_rule(gr) for gr in ground_rules)
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
class MarkovLogicNetwork(nn.Module):
    """Markov Logic Network over a knowledge graph and weighted logical rules.

    Maintains the logical rules, their groundings against the attached
    knowledge graph, and one learnable weight per logical rule
    (``self.rule_weights``). Provides (approximate) partition-function and
    joint-probability computation, plus Gibbs sampling over fact assignments.
    """

    def __init__(self, config: NPLLConfig):
        super().__init__()
        self.config = config

        # Core MLN components
        self.knowledge_graph: Optional[KnowledgeGraph] = None
        self.logical_rules: List[LogicalRule] = []
        self.ground_rules: List[GroundRule] = []

        # Rule weights ω (learnable parameters)
        self.rule_weights: Optional[nn.Parameter] = None

        # Ground rule organization
        self.rule_to_ground_rules: Dict[str, List[GroundRule]] = defaultdict(list)
        self.ground_rule_facts: Set[Triple] = set()
        # Inverted index for fast lookup: fact -> ground rules containing it
        self.fact_to_groundrules: Dict[Triple, List[GroundRule]] = defaultdict(list)

        # Caching for efficiency
        # NOTE(review): this cache is keyed by the stringified weight vector and
        # is never evicted, so it grows with every distinct weight setting.
        self._partition_function_cache: Dict[str, torch.Tensor] = {}
        # NOTE(review): _ground_rule_counts_cache is declared but not used
        # anywhere in this class as shown — possibly used by subclasses or dead.
        self._ground_rule_counts_cache: Optional[torch.Tensor] = None

    def add_knowledge_graph(self, kg: KnowledgeGraph):
        """Add knowledge graph to MLN"""
        self.knowledge_graph = kg
        logger.info(f"Added knowledge graph with {len(kg.known_facts)} known facts")

    def add_logical_rules(self, rules: List[LogicalRule]):
        """Append logical rules, grow the weight vector, and ground the new rules.

        New weights are initialized to small random values; existing weights are
        preserved and the parameter is re-created with the concatenated vector.
        Grounding only happens if a knowledge graph has already been attached.
        """
        self.logical_rules.extend(rules)

        # Initialize or expand rule weights
        if self.rule_weights is None:
            # Initialize rule weights to small random values
            initial_weights = torch.randn(len(rules)) * 0.1
            self.rule_weights = nn.Parameter(initial_weights, requires_grad=True)
        else:
            # Expand existing weights
            # NOTE(review): replacing the nn.Parameter object means any optimizer
            # already holding the old parameter must be rebuilt — verify callers.
            old_weights = self.rule_weights.data
            new_weights = torch.randn(len(rules)) * 0.1
            expanded_weights = torch.cat([old_weights, new_weights])
            self.rule_weights = nn.Parameter(expanded_weights, requires_grad=True)

        # Generate ground rules for new logical rules
        if self.knowledge_graph is not None:
            self._generate_ground_rules(rules)

        logger.info(f"Added {len(rules)} logical rules. Total: {len(self.logical_rules)}")

    def _generate_ground_rules(self, rules: List[LogicalRule]):
        """Generate ground rules from logical rules using knowledge graph"""
        new_ground_rules = []

        for rule in rules:
            # Generate ground rules for this logical rule
            # (delegated to the rule itself; capped by config.max_ground_rules)
            ground_rules = rule.generate_ground_rules(
                self.knowledge_graph,
                max_groundings=self.config.max_ground_rules
            )

            # Add to collections
            new_ground_rules.extend(ground_rules)
            self.rule_to_ground_rules[rule.rule_id].extend(ground_rules)

            # Collect all facts involved in ground rules
            for gr in ground_rules:
                self.ground_rule_facts.update(gr.get_all_facts())
                # Build inverted index for each fact
                for f in gr.get_all_facts():
                    self.fact_to_groundrules[f].append(gr)

        self.ground_rules.extend(new_ground_rules)
        logger.info(f"Generated {len(new_ground_rules)} ground rules. Total: {len(self.ground_rules)}")

    def compute_ground_rule_counts(self, fact_assignments: Dict[Triple, bool]) -> torch.Tensor:
        """
        Compute N(F,U) - number of satisfied ground rules for each logical rule

        Returns a float tensor of shape [num_logical_rules]; facts missing from
        ``fact_assignments`` are treated as false. Returns an empty tensor when
        no logical rules have been added.
        """
        if not self.logical_rules:
            return torch.tensor([])

        rule_counts = torch.zeros(len(self.logical_rules))

        for rule_idx, rule in enumerate(self.logical_rules):
            ground_rules = self.rule_to_ground_rules[rule.rule_id]
            satisfied_count = 0

            for ground_rule in ground_rules:
                # Check if this ground rule is satisfied
                # NOTE(review): all_facts is computed but never used below.
                all_facts = ground_rule.get_all_facts()

                # Check if all required facts are true (for body) and conclusion follows
                body_satisfied = all(
                    fact_assignments.get(fact, False)
                    for fact in ground_rule.body_facts
                )

                if body_satisfied:
                    # If body is true, rule is satisfied if head is also true
                    head_satisfied = fact_assignments.get(ground_rule.head_fact, False)
                    if head_satisfied:
                        satisfied_count += 1
                else:
                    # If body is false, rule is vacuously satisfied
                    satisfied_count += 1

            rule_counts[rule_idx] = satisfied_count

        return rule_counts

    def compute_partition_function(self, sample_states: Optional[List[Dict[Triple, bool]]] = None,
                                   use_approximation: bool = True) -> torch.Tensor:
        """
        Compute MLN partition function Z(ω) from Equation 2

        Returns the log-partition value. By default a sampling-based
        approximation over (given or freshly generated) sample states is used;
        the exact enumeration path is only taken when ``use_approximation`` is
        False and no sample states are supplied. Results are memoized per
        weight vector in ``self._partition_function_cache``.
        """
        if self.rule_weights is None or len(self.logical_rules) == 0:
            return torch.tensor(0.0)

        # Use caching if available
        # (stringified weight list as key — any weight update yields a new key)
        cache_key = str(self.rule_weights.data.tolist())
        if cache_key in self._partition_function_cache:
            return self._partition_function_cache[cache_key]

        if use_approximation or sample_states is not None:
            # Sampling-based approximation
            if sample_states is None:
                sample_states = self._generate_sample_states(num_samples=1000)

            # Compute counts for each sample state
            all_counts = []
            for state_assignment in sample_states:
                counts = self.compute_ground_rule_counts(state_assignment)
                all_counts.append(counts)

            if all_counts:
                counts_tensor = torch.stack(all_counts)  # [num_samples, num_rules]
                log_partition = partition_function_approximation(
                    self.rule_weights, counts_tensor, use_log_domain=True
                )
            else:
                log_partition = torch.tensor(0.0)
        else:
            # Exact computation (intractable for large graphs)
            logger.warning("Exact partition function computation is intractable for large graphs")
            log_partition = self._compute_exact_partition_function()

        # Cache result
        self._partition_function_cache[cache_key] = log_partition

        return log_partition

    def compute_joint_probability(self, fact_assignments: Dict[Triple, bool],
                                  log_partition: Optional[torch.Tensor] = None,
                                  detach_weights: bool = False) -> torch.Tensor:
        """
        Compute joint probability P(F,U|ω) from Equation 1

        Returns the scalar log-probability of the given assignment. Pass a
        precomputed ``log_partition`` to avoid recomputing it per call;
        ``detach_weights=True`` blocks gradients through the rule weights.
        """
        if self.rule_weights is None:
            return torch.tensor(0.0)

        # Compute ground rule counts N(F,U)
        counts = self.compute_ground_rule_counts(fact_assignments)

        # Compute log partition function if not provided
        if log_partition is None:
            log_partition = self.compute_partition_function()

        # Compute log probability using utility function
        weights_to_use = self.rule_weights.detach() if detach_weights else self.rule_weights
        log_prob = compute_mln_probability(
            weights_to_use, counts.unsqueeze(0), log_partition
        )

        return log_prob.squeeze(0)

    def _generate_sample_states(self, num_samples: int = 1000) -> List[Dict[Triple, bool]]:
        """
        Generate sample states for partition function approximation

        Each state fixes all known facts to True and assigns each remaining
        fact that appears in a ground rule a fair-coin truth value. Returns an
        empty list when no knowledge graph or no ground-rule facts exist.
        """
        sample_states = []

        if self.knowledge_graph is None:
            return sample_states

        # Get all facts that appear in ground rules
        all_facts = list(self.ground_rule_facts)

        if not all_facts:
            return sample_states

        # Generate random assignments
        for _ in range(num_samples):
            # Start with known facts as true
            assignment = {}

            # Set known facts to true
            for fact in self.knowledge_graph.known_facts:
                assignment[fact] = True

            # Randomly assign unknown facts
            unknown_facts_in_rules = [f for f in all_facts if f not in assignment]
            for fact in unknown_facts_in_rules:
                # Assign random truth value (could be made smarter)
                assignment[fact] = torch.rand(1).item() > 0.5

            sample_states.append(assignment)

        return sample_states

    def _compute_exact_partition_function(self) -> torch.Tensor:
        """
        Compute exact partition function

        Enumerates all 2^n truth assignments over the ground-rule facts and
        log-sum-exps their potentials. Falls back to the sampling
        approximation above 20 facts to avoid exponential blow-up.
        """
        if not self.ground_rule_facts:
            return torch.tensor(0.0)

        all_facts = list(self.ground_rule_facts)
        num_facts = len(all_facts)

        if num_facts > 20:  # Arbitrary limit to prevent memory explosion
            logger.error(f"Too many facts ({num_facts}) for exact partition function computation")
            return self.compute_partition_function(use_approximation=True)

        # Enumerate all possible truth assignments
        total_log_prob = []

        for i in range(2 ** num_facts):
            # Generate truth assignment from binary representation
            # (bit j of i gives the truth value of all_facts[j])
            assignment = {}
            for j, fact in enumerate(all_facts):
                assignment[fact] = bool((i >> j) & 1)

            # Compute counts for this assignment
            counts = self.compute_ground_rule_counts(assignment)

            # Compute potential
            log_potential = torch.sum(self.rule_weights * counts)
            total_log_prob.append(log_potential)

        # Compute log-sum-exp
        if total_log_prob:
            log_partition = log_sum_exp(torch.stack(total_log_prob))
        else:
            log_partition = torch.tensor(0.0)

        return log_partition

    def get_rule_statistics(self) -> Dict[str, Any]:
        """Get statistics about the MLN

        Returns a dict with aggregate counts, the current weight vector, and a
        per-rule breakdown under 'rule_details' (weight, sigmoid-confidence,
        grounding count, and the rule's own support attribute).
        """
        stats = {
            'num_logical_rules': len(self.logical_rules),
            'num_ground_rules': len(self.ground_rules),
            'num_facts_in_ground_rules': len(self.ground_rule_facts),
            'rule_weights': self.rule_weights.data.tolist() if self.rule_weights is not None else []
        }

        # Per-rule statistics
        rule_stats = []
        for i, rule in enumerate(self.logical_rules):
            ground_rules = self.rule_to_ground_rules[rule.rule_id]
            rule_stat = {
                'rule_id': rule.rule_id,
                'rule_type': rule.rule_type.value,
                'num_ground_rules': len(ground_rules),
                'weight': self.rule_weights[i].item() if self.rule_weights is not None else 0.0,
                'learned_confidence': torch.sigmoid(self.rule_weights[i]).item() if self.rule_weights is not None else None,
                'support': rule.support
            }
            rule_stats.append(rule_stat)

        stats['rule_details'] = rule_stats

        return stats

    def forward(self, fact_assignments_batch: List[Dict[Triple, bool]]) -> torch.Tensor:
        """
        Forward pass for batch of fact assignments

        Returns a 1-D tensor of log-probabilities, one per assignment, sharing
        a single partition-function evaluation across the batch.
        """
        if not fact_assignments_batch:
            return torch.tensor([])

        # Compute partition function once
        log_partition = self.compute_partition_function()

        # Compute probabilities for each assignment
        log_probs = []
        for assignment in fact_assignments_batch:
            log_prob = self.compute_joint_probability(assignment, log_partition)
            log_probs.append(log_prob)

        return torch.stack(log_probs) if log_probs else torch.tensor([])

    def sample_from_distribution(self, num_samples: int = 100) -> List[Dict[Triple, bool]]:
        """
        Sample fact assignments from MLN distribution using Gibbs sampling

        Known facts are clamped to True; each sweep resamples every other
        ground-rule fact from its conditional, and each sweep's assignment is
        recorded as one sample (no burn-in or thinning is applied here).
        """
        if not self.ground_rule_facts:
            return []

        samples = []
        all_facts = list(self.ground_rule_facts)

        # Initialize with random assignment
        current_assignment = {fact: torch.rand(1).item() > 0.5 for fact in all_facts}

        # Set known facts to true (they don't change)
        if self.knowledge_graph:
            for fact in self.knowledge_graph.known_facts:
                current_assignment[fact] = True

        # Gibbs sampling
        for _ in range(num_samples):
            # Sample each unknown fact given others
            for fact in all_facts:
                if self.knowledge_graph and fact in self.knowledge_graph.known_facts:
                    continue  # Skip known facts

                # Compute conditional probability P(fact=True | others)
                prob_true = self._compute_conditional_probability(fact, current_assignment)

                # Sample from Bernoulli distribution
                current_assignment[fact] = torch.rand(1).item() < prob_true

            # Store sample
            samples.append(current_assignment.copy())

        return samples

    def _compute_conditional_probability(self, target_fact: Triple,
                                         current_assignment: Dict[Triple, bool]) -> float:
        """
        Compute P(target_fact=True | other_facts) using local MLN structure

        Evaluates the joint under both truth values of the target fact and
        normalizes the pair with softmax in log-space; the shared partition
        term cancels in the ratio.
        """
        # Create two assignments: one with target_fact=True, one with False
        assignment_true = current_assignment.copy()
        assignment_false = current_assignment.copy()
        assignment_true[target_fact] = True
        assignment_false[target_fact] = False

        # Compute unnormalized probabilities
        log_prob_true = self.compute_joint_probability(assignment_true)
        log_prob_false = self.compute_joint_probability(assignment_false)

        # Normalize using log-sum-exp
        log_probs = torch.stack([log_prob_false, log_prob_true])
        normalized_probs = torch.softmax(log_probs, dim=0)

        return normalized_probs[1].item()  # Return P(target_fact=True)
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
def create_mln_from_kg_and_rules(kg: KnowledgeGraph, rules: List[LogicalRule],
                                 config: NPLLConfig) -> MarkovLogicNetwork:
    """Convenience constructor: wire a knowledge graph and rules into a new MLN.

    Args:
        kg: The knowledge graph whose known facts drive rule grounding.
        rules: Logical rules to attach (each gets a learnable weight).
        config: NPLL configuration controlling grounding limits etc.

    Returns:
        A MarkovLogicNetwork with the graph attached and rules grounded.
    """
    mln = MarkovLogicNetwork(config)
    # The graph must be attached before the rules so grounding can run
    # as part of add_logical_rules.
    mln.add_knowledge_graph(kg)
    mln.add_logical_rules(rules)

    logger.info(f"Created MLN with {len(rules)} rules and {len(mln.ground_rules)} ground rules")

    return mln
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
def verify_mln_implementation():
    """Smoke-test the MLN pipeline on a tiny hand-built knowledge graph."""
    from ..utils.config import default_config
    from .knowledge_graph import Entity, Relation, load_knowledge_graph_from_triples
    from .logical_rules import Variable, Atom, RuleType

    # Tiny three-fact graph: Tom plays basketball, Tom is John's friend,
    # John plays soccer.
    triples = [
        ("Tom", "plays", "basketball"),
        ("Tom", "friend", "John"),
        ("John", "plays", "soccer")
    ]
    kg = load_knowledge_graph_from_triples(triples, "TestKG")

    # Single test rule: plays(x, y) ∧ friend(x, z) ⇒ plays(z, y)
    plays = Relation("plays")
    friend = Relation("friend")
    x, y, z = Variable('x'), Variable('y'), Variable('z')
    rule = LogicalRule(
        rule_id="test_transitivity",
        body=[Atom(plays, (x, y)), Atom(friend, (x, z))],
        head=Atom(plays, (z, y)),
        rule_type=RuleType.TRANSITIVITY,
        confidence=0.8
    )

    mln = create_mln_from_kg_and_rules(kg, [rule], default_config)

    # Structural checks on the assembled network.
    assert len(mln.logical_rules) == 1, "Should have 1 logical rule"
    assert len(mln.ground_rules) > 0, "Should have generated ground rules"
    assert mln.rule_weights is not None, "Should have initialized rule weights"

    # The all-known-facts-true assignment must yield a finite log-probability.
    assignment = {fact: True for fact in kg.known_facts}
    log_prob = mln.compute_joint_probability(assignment)
    assert torch.isfinite(log_prob), "Joint probability should be finite"

    logger.info("MLN implementation verified successfully")
|
|
474
|
-
|
|
1
|
+
"""
|
|
2
|
+
Markov Logic Network (MLN) implementation for NPLL
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
from typing import List, Dict, Set, Tuple, Optional, Any
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
import logging
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
|
|
12
|
+
from .knowledge_graph import KnowledgeGraph, Triple
|
|
13
|
+
from .logical_rules import LogicalRule, GroundRule
|
|
14
|
+
from ..utils.config import NPLLConfig
|
|
15
|
+
from ..utils.math_utils import log_sum_exp, partition_function_approximation, compute_mln_probability
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
|
|
21
|
+
class MLNState:
|
|
22
|
+
|
|
23
|
+
fact_assignments: Dict[Triple, bool] # Truth values for all facts
|
|
24
|
+
known_facts: Set[Triple] # Known facts F
|
|
25
|
+
unknown_facts: Set[Triple] # Unknown facts U
|
|
26
|
+
|
|
27
|
+
def __post_init__(self):
|
|
28
|
+
"""Validate MLN state"""
|
|
29
|
+
all_facts = set(self.fact_assignments.keys())
|
|
30
|
+
expected_facts = self.known_facts | self.unknown_facts
|
|
31
|
+
|
|
32
|
+
if all_facts != expected_facts:
|
|
33
|
+
missing = expected_facts - all_facts
|
|
34
|
+
extra = all_facts - expected_facts
|
|
35
|
+
logger.warning(f"MLN state inconsistency. Missing: {len(missing)}, Extra: {len(extra)}")
|
|
36
|
+
|
|
37
|
+
def evaluate_ground_rule(self, ground_rule: GroundRule) -> bool:
|
|
38
|
+
|
|
39
|
+
# Check if all body facts are true
|
|
40
|
+
body_satisfied = all(
|
|
41
|
+
self.fact_assignments.get(fact, False)
|
|
42
|
+
for fact in ground_rule.body_facts
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
# If body is false, rule is vacuously true
|
|
46
|
+
if not body_satisfied:
|
|
47
|
+
return True
|
|
48
|
+
|
|
49
|
+
# If body is true, check if head is true
|
|
50
|
+
head_satisfied = self.fact_assignments.get(ground_rule.head_fact, False)
|
|
51
|
+
return head_satisfied
|
|
52
|
+
|
|
53
|
+
def count_satisfied_ground_rules(self, ground_rules: List[GroundRule]) -> int:
|
|
54
|
+
"""Count number of ground rules satisfied in this state"""
|
|
55
|
+
return sum(1 for gr in ground_rules if self.evaluate_ground_rule(gr))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class MarkovLogicNetwork(nn.Module):
|
|
59
|
+
|
|
60
|
+
def __init__(self, config: NPLLConfig):
|
|
61
|
+
super().__init__()
|
|
62
|
+
self.config = config
|
|
63
|
+
|
|
64
|
+
# Core MLN components
|
|
65
|
+
self.knowledge_graph: Optional[KnowledgeGraph] = None
|
|
66
|
+
self.logical_rules: List[LogicalRule] = []
|
|
67
|
+
self.ground_rules: List[GroundRule] = []
|
|
68
|
+
|
|
69
|
+
# Rule weights ω (learnable parameters)
|
|
70
|
+
self.rule_weights: Optional[nn.Parameter] = None
|
|
71
|
+
|
|
72
|
+
# Ground rule organization
|
|
73
|
+
self.rule_to_ground_rules: Dict[str, List[GroundRule]] = defaultdict(list)
|
|
74
|
+
self.ground_rule_facts: Set[Triple] = set()
|
|
75
|
+
# Inverted index for fast lookup: fact -> ground rules containing it
|
|
76
|
+
self.fact_to_groundrules: Dict[Triple, List[GroundRule]] = defaultdict(list)
|
|
77
|
+
|
|
78
|
+
# Caching for efficiency
|
|
79
|
+
self._partition_function_cache: Dict[str, torch.Tensor] = {}
|
|
80
|
+
self._ground_rule_counts_cache: Optional[torch.Tensor] = None
|
|
81
|
+
|
|
82
|
+
def add_knowledge_graph(self, kg: KnowledgeGraph):
|
|
83
|
+
"""Add knowledge graph to MLN"""
|
|
84
|
+
self.knowledge_graph = kg
|
|
85
|
+
logger.info(f"Added knowledge graph with {len(kg.known_facts)} known facts")
|
|
86
|
+
|
|
87
|
+
def add_logical_rules(self, rules: List[LogicalRule]):
|
|
88
|
+
self.logical_rules.extend(rules)
|
|
89
|
+
|
|
90
|
+
# Initialize or expand rule weights
|
|
91
|
+
if self.rule_weights is None:
|
|
92
|
+
# Initialize rule weights to small random values
|
|
93
|
+
initial_weights = torch.randn(len(rules)) * 0.1
|
|
94
|
+
self.rule_weights = nn.Parameter(initial_weights, requires_grad=True)
|
|
95
|
+
else:
|
|
96
|
+
# Expand existing weights
|
|
97
|
+
old_weights = self.rule_weights.data
|
|
98
|
+
new_weights = torch.randn(len(rules)) * 0.1
|
|
99
|
+
expanded_weights = torch.cat([old_weights, new_weights])
|
|
100
|
+
self.rule_weights = nn.Parameter(expanded_weights, requires_grad=True)
|
|
101
|
+
|
|
102
|
+
# Generate ground rules for new logical rules
|
|
103
|
+
if self.knowledge_graph is not None:
|
|
104
|
+
self._generate_ground_rules(rules)
|
|
105
|
+
|
|
106
|
+
logger.info(f"Added {len(rules)} logical rules. Total: {len(self.logical_rules)}")
|
|
107
|
+
|
|
108
|
+
def _generate_ground_rules(self, rules: List[LogicalRule]):
|
|
109
|
+
"""Generate ground rules from logical rules using knowledge graph"""
|
|
110
|
+
new_ground_rules = []
|
|
111
|
+
|
|
112
|
+
for rule in rules:
|
|
113
|
+
# Generate ground rules for this logical rule
|
|
114
|
+
ground_rules = rule.generate_ground_rules(
|
|
115
|
+
self.knowledge_graph,
|
|
116
|
+
max_groundings=self.config.max_ground_rules
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
# Add to collections
|
|
120
|
+
new_ground_rules.extend(ground_rules)
|
|
121
|
+
self.rule_to_ground_rules[rule.rule_id].extend(ground_rules)
|
|
122
|
+
|
|
123
|
+
# Collect all facts involved in ground rules
|
|
124
|
+
for gr in ground_rules:
|
|
125
|
+
self.ground_rule_facts.update(gr.get_all_facts())
|
|
126
|
+
# Build inverted index for each fact
|
|
127
|
+
for f in gr.get_all_facts():
|
|
128
|
+
self.fact_to_groundrules[f].append(gr)
|
|
129
|
+
|
|
130
|
+
self.ground_rules.extend(new_ground_rules)
|
|
131
|
+
logger.info(f"Generated {len(new_ground_rules)} ground rules. Total: {len(self.ground_rules)}")
|
|
132
|
+
|
|
133
|
+
def compute_ground_rule_counts(self, fact_assignments: Dict[Triple, bool]) -> torch.Tensor:
|
|
134
|
+
"""
|
|
135
|
+
Compute N(F,U) - number of satisfied ground rules for each logical rule
|
|
136
|
+
"""
|
|
137
|
+
if not self.logical_rules:
|
|
138
|
+
return torch.tensor([])
|
|
139
|
+
|
|
140
|
+
rule_counts = torch.zeros(len(self.logical_rules))
|
|
141
|
+
|
|
142
|
+
for rule_idx, rule in enumerate(self.logical_rules):
|
|
143
|
+
ground_rules = self.rule_to_ground_rules[rule.rule_id]
|
|
144
|
+
satisfied_count = 0
|
|
145
|
+
|
|
146
|
+
for ground_rule in ground_rules:
|
|
147
|
+
# Check if this ground rule is satisfied
|
|
148
|
+
all_facts = ground_rule.get_all_facts()
|
|
149
|
+
|
|
150
|
+
# Check if all required facts are true (for body) and conclusion follows
|
|
151
|
+
body_satisfied = all(
|
|
152
|
+
fact_assignments.get(fact, False)
|
|
153
|
+
for fact in ground_rule.body_facts
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
if body_satisfied:
|
|
157
|
+
# If body is true, rule is satisfied if head is also true
|
|
158
|
+
head_satisfied = fact_assignments.get(ground_rule.head_fact, False)
|
|
159
|
+
if head_satisfied:
|
|
160
|
+
satisfied_count += 1
|
|
161
|
+
else:
|
|
162
|
+
# If body is false, rule is vacuously satisfied
|
|
163
|
+
satisfied_count += 1
|
|
164
|
+
|
|
165
|
+
rule_counts[rule_idx] = satisfied_count
|
|
166
|
+
|
|
167
|
+
return rule_counts
|
|
168
|
+
|
|
169
|
+
def compute_partition_function(self, sample_states: Optional[List[Dict[Triple, bool]]] = None,
|
|
170
|
+
use_approximation: bool = True) -> torch.Tensor:
|
|
171
|
+
"""
|
|
172
|
+
Compute MLN partition function Z(ω) from Equation 2
|
|
173
|
+
"""
|
|
174
|
+
if self.rule_weights is None or len(self.logical_rules) == 0:
|
|
175
|
+
return torch.tensor(0.0)
|
|
176
|
+
|
|
177
|
+
# Use caching if available
|
|
178
|
+
cache_key = str(self.rule_weights.data.tolist())
|
|
179
|
+
if cache_key in self._partition_function_cache:
|
|
180
|
+
return self._partition_function_cache[cache_key]
|
|
181
|
+
|
|
182
|
+
if use_approximation or sample_states is not None:
|
|
183
|
+
# Sampling-based approximation
|
|
184
|
+
if sample_states is None:
|
|
185
|
+
sample_states = self._generate_sample_states(num_samples=1000)
|
|
186
|
+
|
|
187
|
+
# Compute counts for each sample state
|
|
188
|
+
all_counts = []
|
|
189
|
+
for state_assignment in sample_states:
|
|
190
|
+
counts = self.compute_ground_rule_counts(state_assignment)
|
|
191
|
+
all_counts.append(counts)
|
|
192
|
+
|
|
193
|
+
if all_counts:
|
|
194
|
+
counts_tensor = torch.stack(all_counts) # [num_samples, num_rules]
|
|
195
|
+
log_partition = partition_function_approximation(
|
|
196
|
+
self.rule_weights, counts_tensor, use_log_domain=True
|
|
197
|
+
)
|
|
198
|
+
else:
|
|
199
|
+
log_partition = torch.tensor(0.0)
|
|
200
|
+
else:
|
|
201
|
+
# Exact computation (intractable for large graphs)
|
|
202
|
+
logger.warning("Exact partition function computation is intractable for large graphs")
|
|
203
|
+
log_partition = self._compute_exact_partition_function()
|
|
204
|
+
|
|
205
|
+
# Cache result
|
|
206
|
+
self._partition_function_cache[cache_key] = log_partition
|
|
207
|
+
|
|
208
|
+
return log_partition
|
|
209
|
+
|
|
210
|
+
def compute_joint_probability(self, fact_assignments: Dict[Triple, bool],
                              log_partition: Optional[torch.Tensor] = None,
                              detach_weights: bool = False) -> torch.Tensor:
    """
    Compute the joint log-probability P(F,U|ω) from Equation 1.

    Args:
        fact_assignments: Truth assignment for each fact.
        log_partition: Optional precomputed log Z(ω); computed (with
            caching) when omitted.
        detach_weights: If True, block gradients through the rule weights.

    Returns:
        Scalar tensor with the log-probability; 0.0 when no weights exist.
    """
    if self.rule_weights is None:
        return torch.tensor(0.0)

    # Ground rule satisfaction counts N(F,U) for this assignment
    counts = self.compute_ground_rule_counts(fact_assignments)

    # Fall back to the (cached) partition function when none was supplied
    if log_partition is None:
        log_partition = self.compute_partition_function()

    # Optionally cut the gradient path through the rule weights
    weights_to_use = self.rule_weights.detach() if detach_weights else self.rule_weights

    result = compute_mln_probability(weights_to_use, counts.unsqueeze(0), log_partition)
    return result.squeeze(0)
|
|
234
|
+
|
|
235
|
+
def _generate_sample_states(self, num_samples: int = 1000) -> List[Dict[Triple, bool]]:
|
|
236
|
+
"""
|
|
237
|
+
Generate sample states for partition function approximation
|
|
238
|
+
"""
|
|
239
|
+
sample_states = []
|
|
240
|
+
|
|
241
|
+
if self.knowledge_graph is None:
|
|
242
|
+
return sample_states
|
|
243
|
+
|
|
244
|
+
# Get all facts that appear in ground rules
|
|
245
|
+
all_facts = list(self.ground_rule_facts)
|
|
246
|
+
|
|
247
|
+
if not all_facts:
|
|
248
|
+
return sample_states
|
|
249
|
+
|
|
250
|
+
# Generate random assignments
|
|
251
|
+
for _ in range(num_samples):
|
|
252
|
+
# Start with known facts as true
|
|
253
|
+
assignment = {}
|
|
254
|
+
|
|
255
|
+
# Set known facts to true
|
|
256
|
+
for fact in self.knowledge_graph.known_facts:
|
|
257
|
+
assignment[fact] = True
|
|
258
|
+
|
|
259
|
+
# Randomly assign unknown facts
|
|
260
|
+
unknown_facts_in_rules = [f for f in all_facts if f not in assignment]
|
|
261
|
+
for fact in unknown_facts_in_rules:
|
|
262
|
+
# Assign random truth value (could be made smarter)
|
|
263
|
+
assignment[fact] = torch.rand(1).item() > 0.5
|
|
264
|
+
|
|
265
|
+
sample_states.append(assignment)
|
|
266
|
+
|
|
267
|
+
return sample_states
|
|
268
|
+
|
|
269
|
+
def _compute_exact_partition_function(self) -> torch.Tensor:
|
|
270
|
+
"""
|
|
271
|
+
Compute exact partition function
|
|
272
|
+
"""
|
|
273
|
+
if not self.ground_rule_facts:
|
|
274
|
+
return torch.tensor(0.0)
|
|
275
|
+
|
|
276
|
+
all_facts = list(self.ground_rule_facts)
|
|
277
|
+
num_facts = len(all_facts)
|
|
278
|
+
|
|
279
|
+
if num_facts > 20: # Arbitrary limit to prevent memory explosion
|
|
280
|
+
logger.error(f"Too many facts ({num_facts}) for exact partition function computation")
|
|
281
|
+
return self.compute_partition_function(use_approximation=True)
|
|
282
|
+
|
|
283
|
+
# Enumerate all possible truth assignments
|
|
284
|
+
total_log_prob = []
|
|
285
|
+
|
|
286
|
+
for i in range(2 ** num_facts):
|
|
287
|
+
# Generate truth assignment from binary representation
|
|
288
|
+
assignment = {}
|
|
289
|
+
for j, fact in enumerate(all_facts):
|
|
290
|
+
assignment[fact] = bool((i >> j) & 1)
|
|
291
|
+
|
|
292
|
+
# Compute counts for this assignment
|
|
293
|
+
counts = self.compute_ground_rule_counts(assignment)
|
|
294
|
+
|
|
295
|
+
# Compute potential
|
|
296
|
+
log_potential = torch.sum(self.rule_weights * counts)
|
|
297
|
+
total_log_prob.append(log_potential)
|
|
298
|
+
|
|
299
|
+
# Compute log-sum-exp
|
|
300
|
+
if total_log_prob:
|
|
301
|
+
log_partition = log_sum_exp(torch.stack(total_log_prob))
|
|
302
|
+
else:
|
|
303
|
+
log_partition = torch.tensor(0.0)
|
|
304
|
+
|
|
305
|
+
return log_partition
|
|
306
|
+
|
|
307
|
+
def get_rule_statistics(self) -> Dict[str, Any]:
    """Get statistics about the MLN"""
    have_weights = self.rule_weights is not None

    stats: Dict[str, Any] = {
        'num_logical_rules': len(self.logical_rules),
        'num_ground_rules': len(self.ground_rules),
        'num_facts_in_ground_rules': len(self.ground_rule_facts),
        'rule_weights': self.rule_weights.data.tolist() if have_weights else [],
    }

    # Per-rule statistics
    rule_stats = []
    for idx, rule in enumerate(self.logical_rules):
        groundings = self.rule_to_ground_rules[rule.rule_id]
        rule_stats.append({
            'rule_id': rule.rule_id,
            'rule_type': rule.rule_type.value,
            'num_ground_rules': len(groundings),
            'weight': self.rule_weights[idx].item() if have_weights else 0.0,
            'learned_confidence': torch.sigmoid(self.rule_weights[idx]).item() if have_weights else None,
            'support': rule.support,
        })

    stats['rule_details'] = rule_stats
    return stats
|
|
333
|
+
|
|
334
|
+
def forward(self, fact_assignments_batch: List[Dict[Triple, bool]]) -> torch.Tensor:
    """
    Forward pass: joint log-probability for each assignment in the batch.

    Returns an empty tensor for an empty batch.
    """
    if not fact_assignments_batch:
        return torch.tensor([])

    # One partition function evaluation shared by the whole batch
    log_partition = self.compute_partition_function()

    log_probs = [self.compute_joint_probability(assignment, log_partition)
                 for assignment in fact_assignments_batch]

    return torch.stack(log_probs) if log_probs else torch.tensor([])
|
|
351
|
+
|
|
352
|
+
def sample_from_distribution(self, num_samples: int = 100) -> List[Dict[Triple, bool]]:
    """
    Draw fact assignments from the MLN distribution via Gibbs sampling.

    Known facts act as evidence and stay clamped to True; every other fact
    is resampled from its conditional once per sweep, and the state after
    each sweep is recorded as one sample.
    """
    if not self.ground_rule_facts:
        return []

    facts = list(self.ground_rule_facts)

    # Random initial state over every fact appearing in a ground rule
    current_assignment = {fact: torch.rand(1).item() > 0.5 for fact in facts}

    # Clamp evidence: known facts are True and never resampled
    if self.knowledge_graph:
        for fact in self.knowledge_graph.known_facts:
            current_assignment[fact] = True

    samples = []
    for _ in range(num_samples):
        # One Gibbs sweep: resample each free fact given all the others
        for fact in facts:
            if self.knowledge_graph and fact in self.knowledge_graph.known_facts:
                continue  # evidence stays fixed

            # P(fact = True | rest of the state)
            prob_true = self._compute_conditional_probability(fact, current_assignment)

            # Bernoulli draw with that probability
            current_assignment[fact] = torch.rand(1).item() < prob_true

        # Snapshot the post-sweep state
        samples.append(dict(current_assignment))

    return samples
|
|
387
|
+
|
|
388
|
+
def _compute_conditional_probability(self, target_fact: Triple,
|
|
389
|
+
current_assignment: Dict[Triple, bool]) -> float:
|
|
390
|
+
"""
|
|
391
|
+
Compute P(target_fact=True | other_facts) using local MLN structure
|
|
392
|
+
"""
|
|
393
|
+
# Create two assignments: one with target_fact=True, one with False
|
|
394
|
+
assignment_true = current_assignment.copy()
|
|
395
|
+
assignment_false = current_assignment.copy()
|
|
396
|
+
assignment_true[target_fact] = True
|
|
397
|
+
assignment_false[target_fact] = False
|
|
398
|
+
|
|
399
|
+
# Compute unnormalized probabilities
|
|
400
|
+
log_prob_true = self.compute_joint_probability(assignment_true)
|
|
401
|
+
log_prob_false = self.compute_joint_probability(assignment_false)
|
|
402
|
+
|
|
403
|
+
# Normalize using log-sum-exp
|
|
404
|
+
log_probs = torch.stack([log_prob_false, log_prob_true])
|
|
405
|
+
normalized_probs = torch.softmax(log_probs, dim=0)
|
|
406
|
+
|
|
407
|
+
return normalized_probs[1].item() # Return P(target_fact=True)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def create_mln_from_kg_and_rules(kg: KnowledgeGraph, rules: List[LogicalRule],
                                 config: NPLLConfig) -> MarkovLogicNetwork:
    """
    Factory: build a MarkovLogicNetwork, attach the knowledge graph, then
    register the logical rules.

    Args:
        kg: Knowledge graph providing entities, relations, and known facts.
        rules: Logical rules to ground against the graph.
        config: NPLL configuration passed to the network constructor.

    Returns:
        The fully initialized MLN.
    """
    mln = MarkovLogicNetwork(config)
    mln.add_knowledge_graph(kg)
    mln.add_logical_rules(rules)

    logger.info(f"Created MLN with {len(rules)} rules and {len(mln.ground_rules)} ground rules")

    return mln
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def verify_mln_implementation():
    """Smoke-test the MLN implementation on a tiny hand-built example."""
    from ..utils.config import default_config
    from .knowledge_graph import Entity, Relation, load_knowledge_graph_from_triples
    from .logical_rules import Variable, Atom, RuleType

    # Tiny knowledge graph: two players, one friendship
    test_triples = [
        ("Tom", "plays", "basketball"),
        ("Tom", "friend", "John"),
        ("John", "plays", "soccer"),
    ]
    kg = load_knowledge_graph_from_triples(test_triples, "TestKG")

    # Rule under test: plays(x, y) ∧ friend(x, z) ⇒ plays(z, y)
    plays_rel = Relation("plays")
    friend_rel = Relation("friend")
    x, y, z = Variable('x'), Variable('y'), Variable('z')

    test_rule = LogicalRule(
        rule_id="test_transitivity",
        body=[Atom(plays_rel, (x, y)), Atom(friend_rel, (x, z))],
        head=Atom(plays_rel, (z, y)),
        rule_type=RuleType.TRANSITIVITY,
        confidence=0.8,
    )

    # Build the network from the graph and single rule
    mln = create_mln_from_kg_and_rules(kg, [test_rule], default_config)

    # Structural checks
    assert len(mln.logical_rules) == 1, "Should have 1 logical rule"
    assert len(mln.ground_rules) > 0, "Should have generated ground rules"
    assert mln.rule_weights is not None, "Should have initialized rule weights"

    # The fully-observed world must get a finite log-probability
    test_assignment = {fact: True for fact in kg.known_facts}
    log_prob = mln.compute_joint_probability(test_assignment)
    assert torch.isfinite(log_prob), "Joint probability should be finite"

    logger.info("MLN implementation verified successfully")

    return True
|