odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/inference/m_step.py
CHANGED
@@ -1,577 +1,577 @@

(Every line of the file is shown as removed and re-added with identical text; only the final line, return True, is unchanged context. This suggests a whitespace- or line-ending-only change, so the content is shown once below.)

import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Dict, Set, Tuple, Optional, Any, Callable
import logging
from dataclasses import dataclass
from collections import defaultdict
import numpy as np

from ..core import Triple, LogicalRule, GroundRule, KnowledgeGraph
from ..core.mln import MarkovLogicNetwork
from ..utils.config import NPLLConfig
from ..utils.math_utils import (
    safe_log, gradient_clipping, compute_markov_blanket_prob,
    log_sum_exp, partition_function_approximation
)
from ..utils.batch_utils import GroundRuleBatch, GroundRuleSampler
from .e_step import EStepResult

logger = logging.getLogger(__name__)
@dataclass
class MStepResult:
    """
    Result of M-step computation containing optimization details
    """
    updated_rule_weights: torch.Tensor  # New rule weights ω
    weight_changes: torch.Tensor        # Changes in weights from previous iteration
    gradient_norms: torch.Tensor        # Gradient norms for each rule
    pseudo_likelihood: torch.Tensor     # Final pseudo-likelihood value
    optimization_history: List[float]   # History of objective values
    convergence_info: Dict[str, Any]    # Convergence diagnostics
    iteration_count: int                # Number of optimization iterations

    def __str__(self) -> str:
        return (f"M-step Result:\n"
                f"  Rule weights: {self.updated_rule_weights.tolist()}\n"
                f"  Max weight change: {torch.max(torch.abs(self.weight_changes)).item():.6f}\n"
                f"  Pseudo-likelihood: {self.pseudo_likelihood.item():.4f}\n"
                f"  Iterations: {self.iteration_count}")
class PseudoLikelihoodComputer:
    """
    Computes pseudo-log-likelihood objective for M-step optimization
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.markov_blanket_size = config.markov_blanket_size

    def compute_pseudo_likelihood(self,
                                  mln: MarkovLogicNetwork,
                                  fact_probabilities: Dict[Triple, float],
                                  ground_rule_batches: Optional[List[GroundRuleBatch]] = None) -> torch.Tensor:
        """
        Compute pseudo-log-likelihood objective
        """
        if not fact_probabilities or mln.rule_weights is None:
            return torch.tensor(0.0)

        likelihood_parts = []

        # Compute pseudo-likelihood for each fact
        for fact, q_prob in fact_probabilities.items():
            # Get Markov blanket for this fact
            markov_blanket = self._get_markov_blanket(fact, mln)

            # Compute P(fact=True|ω, Markov Blanket) as a tensor
            p_true = self._compute_conditional_probability(
                fact, markov_blanket, mln, fact_probabilities
            )
            # Full cross-entropy against Q: q*log p + (1-q)*log(1-p)
            device = mln.rule_weights.device if mln.rule_weights is not None else 'cpu'
            q = torch.tensor(q_prob, device=device, requires_grad=False)
            contribution = q * safe_log(p_true) + (1.0 - q) * safe_log(1.0 - p_true)
            likelihood_parts.append(contribution)

        # Sum all parts at the end to create the final tensor
        if likelihood_parts:
            total_pseudo_likelihood = torch.sum(torch.stack(likelihood_parts))
        else:
            total_pseudo_likelihood = torch.tensor(0.0)

        return total_pseudo_likelihood

    def _get_markov_blanket(self, target_fact: Triple, mln: MarkovLogicNetwork) -> Set[Triple]:
        """
        Get Markov blanket for a fact - all facts that appear in the same ground rules
        """
        markov_blanket = set()

        # Find all ground rules containing the target fact
        for ground_rule in mln.ground_rules:
            rule_facts = ground_rule.get_all_facts()
            if target_fact in rule_facts:
                # Add all other facts in this ground rule to Markov blanket
                markov_blanket.update(fact for fact in rule_facts if fact != target_fact)

                # Limit size to prevent computational explosion
                if len(markov_blanket) >= self.markov_blanket_size:
                    break

        return markov_blanket
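
Note: the per-fact term accumulated above is the cross-entropy of the conditional probability p against the E-step marginal q, i.e. q*log(p) + (1-q)*log(1-p), summed over all facts. A stand-alone sketch with toy numbers (not package code; clamped_log is an assumed stand-in for the package's safe_log):

    import torch

    def clamped_log(x, eps=1e-12):
        # stand-in for safe_log: clamp away from 0 before taking the log
        return torch.log(torch.clamp(x, min=eps))

    # toy E-step marginals (q) and conditional probabilities (p)
    q = torch.tensor([0.9, 0.2, 0.5])  # Q(fact = True) per unknown fact
    p = torch.tensor([0.8, 0.3, 0.5])  # P(fact = True | weights, Markov blanket)

    # per-fact cross-entropy contributions, summed into the pseudo-log-likelihood
    pll = (q * clamped_log(p) + (1 - q) * clamped_log(1 - p)).sum()
    print(pll.item())  # maximal when p matches q exactly
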
    def _compute_conditional_probability(self,
                                         target_fact: Triple,
                                         markov_blanket: Set[Triple],
                                         mln: MarkovLogicNetwork,
                                         fact_probabilities: Dict[Triple, float]) -> torch.Tensor:
        """Expected feature-difference log-odds approximation; no hard 0/1 blanket."""
        device = mln.rule_weights.device if mln.rule_weights is not None else 'cpu'
        if mln.rule_weights is None:
            return torch.tensor(0.5, device=device)

        delta = torch.zeros((), device=device, requires_grad=True)
        rule_id_to_idx = {rule.rule_id: i for i, rule in enumerate(mln.logical_rules)}
        relevant_ground_rules = getattr(mln, 'fact_to_groundrules', {}).get(target_fact, [])
        if not relevant_ground_rules:
            relevant_ground_rules = [gr for gr in mln.ground_rules if target_fact in gr.get_all_facts()]

        def q_true(f: Triple) -> torch.Tensor:
            if f in fact_probabilities:
                return torch.tensor(fact_probabilities[f], device=device)
            if mln.knowledge_graph and f in mln.knowledge_graph.known_facts:
                return torch.tensor(1.0, device=device)
            return torch.tensor(0.5, device=device)

        for gr in relevant_ground_rules:
            w = mln.rule_weights[rule_id_to_idx[gr.parent_rule.rule_id]]
            if gr.body_facts:
                p_body_true = torch.stack([q_true(f) for f in gr.body_facts]).prod()
            else:
                p_body_true = torch.tensor(1.0, device=device)

            p_head_true = q_true(gr.head_fact)
            p_head_false = 1.0 - p_head_true

            if gr.head_fact == target_fact:
                delta = delta + w * p_body_true
            elif target_fact in gr.body_facts:
                other_body = [f for f in gr.body_facts if f != target_fact]
                p_other_body_true = (torch.stack([q_true(f) for f in other_body]).prod()
                                     if other_body else torch.tensor(1.0, device=device))
                delta = delta - w * (p_other_body_true * p_head_false)
            else:
                pass

        delta = torch.clamp(delta, -40.0, 40.0)
        return torch.sigmoid(delta)
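
Note: the method above scores the target fact by an expected feature-difference log-odds: a rule whose head is the target adds w * P(body true), a rule containing the target in its body subtracts w * P(other body true) * (1 - P(head true)), and the accumulated delta passes through a clamped sigmoid. A stand-alone toy calculation (hypothetical numbers, not package code):

    import torch

    # One rule "body -> target" with weight 1.5 and body marginal 0.8;
    # one rule "target AND other -> head" with weight 0.7,
    # other-body marginal 0.9, head marginal 0.6 (all toy values).
    w_head_rule, p_body = 1.5, 0.8
    w_body_rule, p_other_body, p_head = 0.7, 0.9, 0.6

    delta = torch.tensor(w_head_rule * p_body                         # target as head: reward
                         - w_body_rule * p_other_body * (1 - p_head)) # target in body: penalty
    p_target = torch.sigmoid(torch.clamp(delta, -40.0, 40.0))
    print(p_target.item())  # sigmoid(1.2 - 0.252) ≈ 0.72
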
    def _compute_local_potential(self,
                                 target_fact: Triple,
                                 assignment: Dict[Triple, bool],
                                 mln: MarkovLogicNetwork) -> torch.Tensor:
        """
        Compute local potential for assignment involving target fact
        """
        # Start with a zero tensor that is on the same device as the weights
        local_potential = torch.tensor(0.0, device=mln.rule_weights.device)

        if mln.rule_weights is None:
            return local_potential

        # Find all ground rules containing the target fact
        relevant_ground_rules = []
        for ground_rule in mln.ground_rules:
            if target_fact in ground_rule.get_all_facts():
                relevant_ground_rules.append(ground_rule)

        # Create a mapping from rule_id to its index in the weights tensor
        rule_id_to_idx = {rule.rule_id: i for i, rule in enumerate(mln.logical_rules)}

        for ground_rule in relevant_ground_rules:
            # Check if ground rule is satisfied
            body_satisfied = all(
                assignment.get(fact, False)
                for fact in ground_rule.body_facts
            )

            rule_satisfied = False
            if not body_satisfied:
                # Body false -> rule vacuously true
                rule_satisfied = True
            else:
                # Body true -> check head
                rule_satisfied = assignment.get(ground_rule.head_fact, False)

            # Add contribution to potential
            if rule_satisfied:
                rule_idx = rule_id_to_idx.get(ground_rule.parent_rule.rule_id)
                if rule_idx is not None:
                    # Add the weight *tensor* to the potential, preserving the graph
                    local_potential = local_potential + mln.rule_weights[rule_idx]

        return local_potential
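
Note: the local potential above follows the textbook Markov logic network form, sum_i w_i * n_i(x), the weighted count of satisfied ground rules under a hard assignment, with P(x) proportional to exp of that sum. A minimal sketch (toy numbers, not package code):

    import torch

    weights = torch.tensor([0.5, 1.2])      # one weight per first-order rule
    n_satisfied = torch.tensor([3.0, 1.0])  # satisfied groundings per rule under an assignment

    # Standard MLN log-potential: sum_i w_i * n_i(x); P(x) ∝ exp(log_potential)
    log_potential = (weights * n_satisfied).sum()
    print(log_potential.item())  # 0.5*3 + 1.2*1 = 2.7
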
class GradientComputer:
    """
    Computes gradients for rule weight optimization
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.pseudo_likelihood_computer = PseudoLikelihoodComputer(config)

    def compute_rule_weight_gradients(self,
                                      mln: MarkovLogicNetwork,
                                      fact_probabilities: Dict[Triple, float]) -> torch.Tensor:
        """
        Compute gradients of pseudo-likelihood with respect to rule weights
        """
        if mln.rule_weights is None or not fact_probabilities:
            return torch.zeros(len(mln.logical_rules))

        # Enable gradient computation
        if mln.rule_weights is not None:
            mln.rule_weights.requires_grad_(True)

        # Compute pseudo-likelihood
        pseudo_likelihood = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
            mln, fact_probabilities
        )

        # Compute gradients
        pseudo_likelihood.backward()

        if mln.rule_weights.grad is not None:
            gradients = mln.rule_weights.grad.clone()
        else:
            gradients = torch.zeros_like(mln.rule_weights)

        return gradients

    def compute_finite_difference_gradients(self,
                                            mln: MarkovLogicNetwork,
                                            fact_probabilities: Dict[Triple, float],
                                            epsilon: float = 1e-5) -> torch.Tensor:
        """
        Compute gradients using finite differences (for verification/debugging)
        """
        if mln.rule_weights is None:
            return torch.zeros(len(mln.logical_rules))

        gradients = torch.zeros_like(mln.rule_weights)
        original_weights = mln.rule_weights.data.clone()

        for i in range(len(mln.rule_weights)):
            # Forward difference
            mln.rule_weights.data[i] += epsilon
            pseudo_likelihood_plus = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
                mln, fact_probabilities
            )

            # Backward difference
            mln.rule_weights.data[i] -= 2 * epsilon
            pseudo_likelihood_minus = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
                mln, fact_probabilities
            )

            # Compute gradient
            gradients[i] = (pseudo_likelihood_plus - pseudo_likelihood_minus) / (2 * epsilon)

            # Restore original weight
            mln.rule_weights.data[i] = original_weights[i]

        return gradients
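
Note: compute_finite_difference_gradients implements the standard central-difference check, (f(w+eps) - f(w-eps)) / (2*eps) per coordinate. A self-contained version of the same pattern, checking autograd against finite differences on a toy objective rather than the package's pseudo-likelihood:

    import torch

    def objective(w):
        # any smooth scalar function of the weights (toy stand-in for the PLL)
        c = torch.tensor([0.9, 0.2], dtype=torch.float64)
        return (w * c).sum() - 0.1 * (w ** 2).sum()

    w = torch.tensor([0.5, -0.3], dtype=torch.float64, requires_grad=True)
    objective(w).backward()
    autograd_grad = w.grad.clone()

    eps = 1e-6
    fd_grad = torch.zeros_like(w)
    with torch.no_grad():
        for i in range(len(w)):
            w[i] += eps
            f_plus = objective(w)
            w[i] -= 2 * eps
            f_minus = objective(w)
            w[i] += eps  # restore the original value
            fd_grad[i] = (f_plus - f_minus) / (2 * eps)

    print(torch.allclose(autograd_grad, fd_grad, atol=1e-6))  # True
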
class MStepOptimizer:
    """
    Main M-step optimizer that updates rule weights ω
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.pseudo_likelihood_computer = PseudoLikelihoodComputer(config)
        self.gradient_computer = GradientComputer(config)

        # Optimization parameters
        self.learning_rate = config.learning_rate
        self.max_iterations = 100  # M-step specific iterations
        self.convergence_threshold = config.convergence_threshold
        self.grad_clip_norm = config.grad_clip_norm

    def optimize_rule_weights(self,
                              mln: MarkovLogicNetwork,
                              e_step_result: EStepResult) -> MStepResult:
        """
        Main M-step optimization procedure
        """
        if mln.rule_weights is None:
            logger.warning("No rule weights to optimize in M-step")
            return MStepResult(
                updated_rule_weights=torch.tensor([]),
                weight_changes=torch.tensor([]),
                gradient_norms=torch.tensor([]),
                pseudo_likelihood=torch.tensor(0.0),
                optimization_history=[],
                convergence_info={'converged': True, 'reason': 'no_weights'},
                iteration_count=0
            )

        logger.debug(f"Starting M-step optimization for {len(mln.rule_weights)} rule weights")

        # Store initial weights
        initial_weights = mln.rule_weights.data.clone()

        # Setup optimizer
        optimizer = optim.Adam([mln.rule_weights], lr=self.learning_rate)

        # Optimization history
        objective_history = []
        gradient_history = []

        converged = False
        iteration = 0

        for iteration in range(self.max_iterations):
            optimizer.zero_grad()

            # Ensure gradients are enabled for rule weights
            if mln.rule_weights is not None:
                mln.rule_weights.requires_grad_(True)

            # Compute pseudo-likelihood objective
            pseudo_likelihood = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
                mln, e_step_result.fact_probabilities
            )

            # We want to maximize pseudo-likelihood, so minimize negative
            loss = -pseudo_likelihood

            # Backward pass
            loss.backward()

            # Store objective value (detached to prevent gradient issues)
            objective_history.append(pseudo_likelihood.detach().item())

            # Clip gradients
            if mln.rule_weights.grad is not None:
                grad_norm = gradient_clipping([mln.rule_weights], self.grad_clip_norm)
                gradient_history.append(grad_norm)
            else:
                gradient_history.append(0.0)

            # Optimization step
            optimizer.step()

            # Check convergence
            if iteration > 0:
                objective_change = abs(objective_history[-1] - objective_history[-2])
                if objective_change < self.convergence_threshold:
                    converged = True
                    logger.debug(f"M-step converged at iteration {iteration}")
                    break

        # Compute final metrics
        final_weights = mln.rule_weights.data.clone()
        weight_changes = final_weights - initial_weights

        final_pseudo_likelihood = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
            mln, e_step_result.fact_probabilities
        )

        convergence_info = {
            'converged': converged,
            'final_objective_change': abs(objective_history[-1] - objective_history[-2]) if len(objective_history) > 1 else 0.0,
            'objective_history': objective_history,
            'gradient_history': gradient_history,
            'reason': 'converged' if converged else 'max_iterations'
        }

        result = MStepResult(
            updated_rule_weights=final_weights,
            weight_changes=weight_changes,
            gradient_norms=torch.tensor(gradient_history),
            pseudo_likelihood=final_pseudo_likelihood,
            optimization_history=objective_history,
            convergence_info=convergence_info,
            iteration_count=iteration + 1
        )

        logger.debug(f"M-step completed: {result}")

        return result
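
Note: the loop above is the usual maximize-by-minimizing-the-negative pattern with Adam, gradient clipping, and an objective-change stopping rule. A generic stand-alone sketch of that pattern (toy objective, not package code; clip_grad_norm_ stands in for the package's gradient_clipping helper):

    import torch

    w = torch.zeros(2, requires_grad=True)  # rule weights to learn
    opt = torch.optim.Adam([w], lr=0.05)
    target = torch.tensor([1.0, -0.5])

    prev = None
    for step in range(100):
        opt.zero_grad()
        objective = -((w - target) ** 2).sum()  # something to maximize
        (-objective).backward()                 # minimize the negative
        torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)
        opt.step()
        val = objective.item()
        if prev is not None and abs(val - prev) < 1e-6:
            break                               # objective change below threshold
        prev = val
    print(w.detach())  # approaches target
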
    def optimize_with_regularization(self,
                                     mln: MarkovLogicNetwork,
                                     e_step_result: EStepResult,
                                     l1_weight: float = 0.01,
                                     l2_weight: float = 0.01) -> MStepResult:
        """
        M-step optimization with L1 and L2 regularization on rule weights

        Helps prevent overfitting and encourages sparse rule sets
        """
        if mln.rule_weights is None:
            return self.optimize_rule_weights(mln, e_step_result)

        initial_weights = mln.rule_weights.data.clone()
        optimizer = optim.Adam([mln.rule_weights], lr=self.learning_rate)

        objective_history = []
        converged = False

        for iteration in range(self.max_iterations):
            optimizer.zero_grad()

            # Primary pseudo-likelihood objective
            pseudo_likelihood = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
                mln, e_step_result.fact_probabilities
            )

            # Regularization terms
            l1_reg = torch.sum(torch.abs(mln.rule_weights))
            l2_reg = torch.sum(mln.rule_weights ** 2)

            # Total objective (maximize pseudo-likelihood, minimize regularization)
            total_objective = pseudo_likelihood - l1_weight * l1_reg - l2_weight * l2_reg
            loss = -total_objective

            # Optimization
            loss.backward()
            gradient_clipping([mln.rule_weights], self.grad_clip_norm)
            optimizer.step()

            objective_history.append(total_objective.item())

            # Convergence check
            if iteration > 0:
                if abs(objective_history[-1] - objective_history[-2]) < self.convergence_threshold:
                    converged = True
                    break

        final_weights = mln.rule_weights.data.clone()
        weight_changes = final_weights - initial_weights

        return MStepResult(
            updated_rule_weights=final_weights,
            weight_changes=weight_changes,
            gradient_norms=torch.tensor([0.0]),  # Placeholder
            pseudo_likelihood=pseudo_likelihood,
            optimization_history=objective_history,
            convergence_info={'converged': converged},
            iteration_count=len(objective_history)
        )
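
Note: the regularized variant optimizes an elastic-net style objective, the pseudo-likelihood minus l1_weight * ||w||_1 minus l2_weight * ||w||^2, negated for the minimizer. One step's loss with toy values (stand-alone sketch, not package code):

    import torch

    w = torch.tensor([0.8, -1.5, 0.0], requires_grad=True)
    pseudo_likelihood = torch.tensor(2.0)  # stand-in scalar for the PLL term
    l1_weight, l2_weight = 0.01, 0.01

    loss = -(pseudo_likelihood
             - l1_weight * w.abs().sum()    # L1: pushes weights toward exact zero
             - l2_weight * (w ** 2).sum())  # L2: shrinks large weights
    print(loss.item())  # -(2.0 - 0.01*2.3 - 0.01*2.89) ≈ -1.9481
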
class MStepRunner:
    """
    High-level runner for M-step computations
    Handles different optimization strategies and result aggregation
    """

    def __init__(self, config: NPLLConfig):
        self.config = config
        self.m_step_optimizer = MStepOptimizer(config)

    def run_m_step(self,
                   mln: MarkovLogicNetwork,
                   e_step_result: EStepResult,
                   optimization_strategy: str = "adam") -> MStepResult:
        """
        Run complete M-step for rule weight optimization
        """
        logger.info(f"Running M-step with {optimization_strategy} optimization")

        if optimization_strategy == "regularized":
            result = self.m_step_optimizer.optimize_with_regularization(
                mln, e_step_result, l1_weight=0.01, l2_weight=0.01
            )
        else:
            # Standard optimization
            result = self.m_step_optimizer.optimize_rule_weights(mln, e_step_result)

        logger.info(f"M-step completed: Pseudo-likelihood={result.pseudo_likelihood.item():.4f}, "
                    f"Max weight change={torch.max(torch.abs(result.weight_changes)).item():.6f}")

        return result

    def run_m_step_with_constraints(self,
                                    mln: MarkovLogicNetwork,
                                    e_step_result: EStepResult,
                                    weight_bounds: Optional[List[Tuple[float, float]]] = None) -> MStepResult:
        """
        Run M-step with constraints on rule weights
        """
        # Standard M-step
        result = self.run_m_step(mln, e_step_result)

        # Apply weight constraints
        if weight_bounds and mln.rule_weights is not None:
            with torch.no_grad():
                for i, (min_weight, max_weight) in enumerate(weight_bounds):
                    if i < len(mln.rule_weights):
                        mln.rule_weights[i] = torch.clamp(
                            mln.rule_weights[i], min_weight, max_weight
                        )

            # Update result with constrained weights
            result.updated_rule_weights = mln.rule_weights.data.clone()

        return result
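
Note: a hypothetical driver showing how the runner slots into the EM alternation with the E-step. Names and call signatures are taken from this file and e_step.py; the setup of config, mln, scoring_module and kg is assumed to match verify_m_step_implementation below:

    # Hypothetical EM driver, not part of the package.
    from npll.inference.e_step import create_e_step_runner
    from npll.inference.m_step import create_m_step_runner

    def run_em(config, mln, scoring_module, kg, n_rounds=5):
        e_runner = create_e_step_runner(config)
        m_runner = create_m_step_runner(config)
        for _ in range(n_rounds):
            e_result = e_runner.run_e_step(mln, scoring_module, kg)  # update Q(U)
            m_result = m_runner.run_m_step(mln, e_result)            # update rule weights ω
            if m_result.convergence_info.get('converged'):
                break
        return m_result
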
def create_m_step_runner(config: NPLLConfig) -> MStepRunner:
    """Factory function to create M-step runner"""
    return MStepRunner(config)


def verify_m_step_implementation():
    """Verify M-step implementation"""
    from ..utils.config import default_config
    from ..core import load_knowledge_graph_from_triples
    from ..core.mln import create_mln_from_kg_and_rules
    from ..core.logical_rules import Variable, Atom, RuleType, LogicalRule
    from ..scoring import create_scoring_module
    from .e_step import create_e_step_runner

    # Create test data
    test_triples = [
        ("A", "r1", "B"),
        ("B", "r2", "C"),
    ]

    kg = load_knowledge_graph_from_triples(test_triples)
    kg.add_unknown_fact("A", "r3", "C")

    # Create test rule
    from ..core import Relation
    r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
    x, y, z = Variable('x'), Variable('y'), Variable('z')

    test_rule = LogicalRule(
        rule_id="test_m_step_rule",
        body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
        head=Atom(r3, (x, z)),
        rule_type=RuleType.TRANSITIVITY
    )

    # Create MLN and scoring module
    mln = create_mln_from_kg_and_rules(kg, [test_rule], default_config)
    scoring_module = create_scoring_module(default_config, kg)

    # Run E-step to get Q(U)
    e_step_runner = create_e_step_runner(default_config)
    e_step_result = e_step_runner.run_e_step(mln, scoring_module, kg)

    # Test M-step
    m_step_runner = MStepRunner(default_config)
    m_step_result = m_step_runner.run_m_step(mln, e_step_result)

    # Verify results
    assert len(m_step_result.updated_rule_weights) == len(mln.logical_rules), \
        "Should have weights for all rules"

    assert torch.all(torch.isfinite(m_step_result.updated_rule_weights)), \
        "All weights should be finite"

    assert torch.isfinite(m_step_result.pseudo_likelihood), \
        "Pseudo-likelihood should be finite"

    logger.info("M-step implementation verified successfully")

    return True