odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/inference/m_step.py CHANGED
@@ -1,577 +1,577 @@
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from typing import List, Dict, Set, Tuple, Optional, Any, Callable
+ import logging
+ from dataclasses import dataclass
+ from collections import defaultdict
+ import numpy as np
+
+ from ..core import Triple, LogicalRule, GroundRule, KnowledgeGraph
+ from ..core.mln import MarkovLogicNetwork
+ from ..utils.config import NPLLConfig
+ from ..utils.math_utils import (
+     safe_log, gradient_clipping, compute_markov_blanket_prob,
+     log_sum_exp, partition_function_approximation
+ )
+ from ..utils.batch_utils import GroundRuleBatch, GroundRuleSampler
+ from .e_step import EStepResult
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class MStepResult:
+     """
+     Result of M-step computation containing optimization details
+     """
+     updated_rule_weights: torch.Tensor   # New rule weights ω
+     weight_changes: torch.Tensor         # Changes in weights from previous iteration
+     gradient_norms: torch.Tensor         # Gradient norms for each rule
+     pseudo_likelihood: torch.Tensor      # Final pseudo-likelihood value
+     optimization_history: List[float]    # History of objective values
+     convergence_info: Dict[str, Any]     # Convergence diagnostics
+     iteration_count: int                 # Number of optimization iterations
+
+     def __str__(self) -> str:
+         return (f"M-step Result:\n"
+                 f"  Rule weights: {self.updated_rule_weights.tolist()}\n"
+                 f"  Max weight change: {torch.max(torch.abs(self.weight_changes)).item():.6f}\n"
+                 f"  Pseudo-likelihood: {self.pseudo_likelihood.item():.4f}\n"
+                 f"  Iterations: {self.iteration_count}")
+
+
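The dataclass above is the record that every optimization entry point in this module returns. As orientation for readers of this diff, a small illustrative instantiation follows (dummy values, not package output; it assumes the wheel is installed so the module imports under the path shown in the file list):

# Illustrative only: dummy values, shown to make the __str__ summary concrete.
import torch
from npll.inference.m_step import MStepResult

demo = MStepResult(
    updated_rule_weights=torch.tensor([0.8, 1.2]),
    weight_changes=torch.tensor([0.05, -0.02]),
    gradient_norms=torch.tensor([0.4, 0.7]),
    pseudo_likelihood=torch.tensor(-3.14),
    optimization_history=[-3.8, -3.3, -3.14],
    convergence_info={'converged': True, 'reason': 'converged'},
    iteration_count=3,
)
print(demo)  # rule weights, max |Δw|, pseudo-likelihood, iteration count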
+ class PseudoLikelihoodComputer:
+     """
+     Computes pseudo-log-likelihood objective for M-step optimization
+
+     """
+
+     def __init__(self, config: NPLLConfig):
+         self.config = config
+         self.markov_blanket_size = config.markov_blanket_size
+
+     def compute_pseudo_likelihood(self,
+                                   mln: MarkovLogicNetwork,
+                                   fact_probabilities: Dict[Triple, float],
+                                   ground_rule_batches: Optional[List[GroundRuleBatch]] = None) -> torch.Tensor:
+         """
+         Compute pseudo-log-likelihood objective
+         """
+         if not fact_probabilities or mln.rule_weights is None:
+             return torch.tensor(0.0)
+
+         likelihood_parts = []
+
+         # Compute pseudo-likelihood for each fact
+         for fact, q_prob in fact_probabilities.items():
+             # Get Markov blanket for this fact
+             markov_blanket = self._get_markov_blanket(fact, mln)
+
+             # Compute P(fact=True|ω, Markov Blanket) as a tensor
+             p_true = self._compute_conditional_probability(
+                 fact, markov_blanket, mln, fact_probabilities
+             )
+             # Full cross-entropy against Q: q*log p + (1-q)*log(1-p)
+             device = mln.rule_weights.device if mln.rule_weights is not None else 'cpu'
+             q = torch.tensor(q_prob, device=device, requires_grad=False)
+             contribution = q * safe_log(p_true) + (1.0 - q) * safe_log(1.0 - p_true)
+             likelihood_parts.append(contribution)
+
+         # Sum all parts at the end to create the final tensor
+         if likelihood_parts:
+             total_pseudo_likelihood = torch.sum(torch.stack(likelihood_parts))
+         else:
+             total_pseudo_likelihood = torch.tensor(0.0)
+
+         return total_pseudo_likelihood
+
+     def _get_markov_blanket(self, target_fact: Triple, mln: MarkovLogicNetwork) -> Set[Triple]:
+         """
+         Get Markov blanket for a fact - all facts that appear in the same ground rules
+
+         """
+         markov_blanket = set()
+
+         # Find all ground rules containing the target fact
+         for ground_rule in mln.ground_rules:
+             rule_facts = ground_rule.get_all_facts()
+             if target_fact in rule_facts:
+                 # Add all other facts in this ground rule to Markov blanket
+                 markov_blanket.update(fact for fact in rule_facts if fact != target_fact)
+
+             # Limit size to prevent computational explosion
+             if len(markov_blanket) >= self.markov_blanket_size:
+                 break
+
+         return markov_blanket
+
+     def _compute_conditional_probability(self,
+                                          target_fact: Triple,
+                                          markov_blanket: Set[Triple],
+                                          mln: MarkovLogicNetwork,
+                                          fact_probabilities: Dict[Triple, float]) -> torch.Tensor:
+         """Expected feature-difference log-odds approximation; no hard 0/1 blanket."""
+         device = mln.rule_weights.device if mln.rule_weights is not None else 'cpu'
+         if mln.rule_weights is None:
+             return torch.tensor(0.5, device=device)
+
+         delta = torch.zeros((), device=device, requires_grad=True)
+         rule_id_to_idx = {rule.rule_id: i for i, rule in enumerate(mln.logical_rules)}
+         relevant_ground_rules = getattr(mln, 'fact_to_groundrules', {}).get(target_fact, [])
+         if not relevant_ground_rules:
+             relevant_ground_rules = [gr for gr in mln.ground_rules if target_fact in gr.get_all_facts()]
+
+         def q_true(f: Triple) -> torch.Tensor:
+             if f in fact_probabilities:
+                 return torch.tensor(fact_probabilities[f], device=device)
+             if mln.knowledge_graph and f in mln.knowledge_graph.known_facts:
+                 return torch.tensor(1.0, device=device)
+             return torch.tensor(0.5, device=device)
+
+         for gr in relevant_ground_rules:
+             w = mln.rule_weights[rule_id_to_idx[gr.parent_rule.rule_id]]
+             if gr.body_facts:
+                 p_body_true = torch.stack([q_true(f) for f in gr.body_facts]).prod()
+             else:
+                 p_body_true = torch.tensor(1.0, device=device)
+
+             p_head_true = q_true(gr.head_fact)
+             p_head_false = 1.0 - p_head_true
+
+             if gr.head_fact == target_fact:
+                 delta = delta + w * p_body_true
+             elif target_fact in gr.body_facts:
+                 other_body = [f for f in gr.body_facts if f != target_fact]
+                 p_other_body_true = (torch.stack([q_true(f) for f in other_body]).prod()
+                                      if other_body else torch.tensor(1.0, device=device))
+                 delta = delta - w * (p_other_body_true * p_head_false)
+             else:
+                 pass
+
+         delta = torch.clamp(delta, -40.0, 40.0)
+         return torch.sigmoid(delta)
+
+     def _compute_local_potential(self,
+                                  target_fact: Triple,
+                                  assignment: Dict[Triple, bool],
+                                  mln: MarkovLogicNetwork) -> torch.Tensor:
+         """
+         Compute local potential for assignment involving target fact
+
+         """
+         # Start with a zero tensor that is on the same device as the weights
+         local_potential = torch.tensor(0.0, device=mln.rule_weights.device)
+
+         if mln.rule_weights is None:
+             return local_potential
+
+         # Find all ground rules containing the target fact
+         relevant_ground_rules = []
+         for ground_rule in mln.ground_rules:
+             if target_fact in ground_rule.get_all_facts():
+                 relevant_ground_rules.append(ground_rule)
+
+         # Create a mapping from rule_id to its index in the weights tensor
+         rule_id_to_idx = {rule.rule_id: i for i, rule in enumerate(mln.logical_rules)}
+
+         for ground_rule in relevant_ground_rules:
+             # Check if ground rule is satisfied
+             body_satisfied = all(
+                 assignment.get(fact, False)
+                 for fact in ground_rule.body_facts
+             )
+
+             rule_satisfied = False
+             if not body_satisfied:
+                 # Body false -> rule vacuously true
+                 rule_satisfied = True
+             else:
+                 # Body true -> check head
+                 rule_satisfied = assignment.get(ground_rule.head_fact, False)
+
+             # Add contribution to potential
+             if rule_satisfied:
+                 rule_idx = rule_id_to_idx.get(ground_rule.parent_rule.rule_id)
+                 if rule_idx is not None:
+                     # Add the weight *tensor* to the potential, preserving the graph
+                     local_potential = local_potential + mln.rule_weights[rule_idx]
+
+         return local_potential
+
+
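PseudoLikelihoodComputer scores each fact by a cross-entropy between its E-step belief q and a conditional probability p = sigmoid(Δ), where Δ accumulates +w·P(body true) for rules whose head is the target fact and -w·P(other body true)·P(head false) for rules whose body contains it. A standalone numeric sketch of one fact's contribution (toy values, plain PyTorch, independent of the package):

# Standalone sketch, not package code: one rule with weight w supports the
# target fact as its head, so delta = w * P(body true); the fact's term is
# q*log(p) + (1-q)*log(1-p) with p = sigmoid(delta).
import torch

q = torch.tensor(0.9)                      # E-step belief Q(fact = True)
w = torch.tensor(1.5, requires_grad=True)  # the rule's weight ω
p_body = torch.tensor(0.8)                 # product of body-fact beliefs

delta = w * p_body                         # target fact is the rule head
p_true = torch.sigmoid(torch.clamp(delta, -40.0, 40.0))

eps = 1e-12                                # stand-in for safe_log's guard
term = q * torch.log(p_true + eps) + (1 - q) * torch.log(1 - p_true + eps)
term.backward()                            # d(term)/dw is what the M-step follows
print(f"p={p_true.item():.4f}  term={term.item():.4f}  grad={w.grad.item():.4f}")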
+ class GradientComputer:
+     """
+     Computes gradients for rule weight optimization
+
+     """
+
+     def __init__(self, config: NPLLConfig):
+         self.config = config
+         self.pseudo_likelihood_computer = PseudoLikelihoodComputer(config)
+
+     def compute_rule_weight_gradients(self,
+                                       mln: MarkovLogicNetwork,
+                                       fact_probabilities: Dict[Triple, float]) -> torch.Tensor:
+         """
+         Compute gradients of pseudo-likelihood with respect to rule weights
+         """
+         if mln.rule_weights is None or not fact_probabilities:
+             return torch.zeros(len(mln.logical_rules))
+
+         # Enable gradient computation
+         if mln.rule_weights is not None:
+             mln.rule_weights.requires_grad_(True)
+
+         # Compute pseudo-likelihood
+         pseudo_likelihood = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
+             mln, fact_probabilities
+         )
+
+         # Compute gradients
+         pseudo_likelihood.backward()
+
+         if mln.rule_weights.grad is not None:
+             gradients = mln.rule_weights.grad.clone()
+         else:
+             gradients = torch.zeros_like(mln.rule_weights)
+
+         return gradients
+
+     def compute_finite_difference_gradients(self,
+                                             mln: MarkovLogicNetwork,
+                                             fact_probabilities: Dict[Triple, float],
+                                             epsilon: float = 1e-5) -> torch.Tensor:
+         """
+         Compute gradients using finite differences (for verification/debugging)
+         """
+         if mln.rule_weights is None:
+             return torch.zeros(len(mln.logical_rules))
+
+         gradients = torch.zeros_like(mln.rule_weights)
+         original_weights = mln.rule_weights.data.clone()
+
+         for i in range(len(mln.rule_weights)):
+             # Forward difference
+             mln.rule_weights.data[i] += epsilon
+             pseudo_likelihood_plus = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
+                 mln, fact_probabilities
+             )
+
+             # Backward difference
+             mln.rule_weights.data[i] -= 2 * epsilon
+             pseudo_likelihood_minus = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
+                 mln, fact_probabilities
+             )
+
+             # Compute gradient
+             gradients[i] = (pseudo_likelihood_plus - pseudo_likelihood_minus) / (2 * epsilon)
+
+             # Restore original weight
+             mln.rule_weights.data[i] = original_weights[i]
+
+         return gradients
+
+
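compute_finite_difference_gradients is the usual central-difference check on autograd: perturb one weight by ±epsilon and compare (f(w+eps) - f(w-eps)) / (2·eps) with the backward-pass gradient. A self-contained sketch of the same verification pattern on a toy objective (not package code):

# Standalone gradient check: autograd vs. central finite differences.
import torch

def objective(w: torch.Tensor) -> torch.Tensor:
    return -(w - 1.0).pow(2).sum()       # toy stand-in for the pseudo-likelihood

w = torch.tensor([0.5, 2.0], requires_grad=True)
objective(w).backward()
autograd_grad = w.grad.clone()

eps = 1e-5
fd_grad = torch.zeros_like(w)
with torch.no_grad():
    for i in range(len(w)):
        w[i] += eps
        plus = objective(w)
        w[i] -= 2 * eps
        minus = objective(w)
        w[i] += eps                      # restore the original weight
        fd_grad[i] = (plus - minus) / (2 * eps)

print(torch.allclose(autograd_grad, fd_grad, atol=1e-4))  # expect True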
+ class MStepOptimizer:
+     """
+     Main M-step optimizer that updates rule weights ω
+
+     """
+
+     def __init__(self, config: NPLLConfig):
+         self.config = config
+         self.pseudo_likelihood_computer = PseudoLikelihoodComputer(config)
+         self.gradient_computer = GradientComputer(config)
+
+         # Optimization parameters
+         self.learning_rate = config.learning_rate
+         self.max_iterations = 100  # M-step specific iterations
+         self.convergence_threshold = config.convergence_threshold
+         self.grad_clip_norm = config.grad_clip_norm
+
+     def optimize_rule_weights(self,
+                               mln: MarkovLogicNetwork,
+                               e_step_result: EStepResult) -> MStepResult:
+         """
+         Main M-step optimization procedure
+
+         """
+         if mln.rule_weights is None:
+             logger.warning("No rule weights to optimize in M-step")
+             return MStepResult(
+                 updated_rule_weights=torch.tensor([]),
+                 weight_changes=torch.tensor([]),
+                 gradient_norms=torch.tensor([]),
+                 pseudo_likelihood=torch.tensor(0.0),
+                 optimization_history=[],
+                 convergence_info={'converged': True, 'reason': 'no_weights'},
+                 iteration_count=0
+             )
+
+         logger.debug(f"Starting M-step optimization for {len(mln.rule_weights)} rule weights")
+
+         # Store initial weights
+         initial_weights = mln.rule_weights.data.clone()
+
+         # Setup optimizer
+         optimizer = optim.Adam([mln.rule_weights], lr=self.learning_rate)
+
+         # Optimization history
+         objective_history = []
+         gradient_history = []
+
+         converged = False
+         iteration = 0
+
+         for iteration in range(self.max_iterations):
+             optimizer.zero_grad()
+
+             # Ensure gradients are enabled for rule weights
+             if mln.rule_weights is not None:
+                 mln.rule_weights.requires_grad_(True)
+
+             # Compute pseudo-likelihood objective
+             pseudo_likelihood = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
+                 mln, e_step_result.fact_probabilities
+             )
+
+             # We want to maximize pseudo-likelihood, so minimize negative
+             loss = -pseudo_likelihood
+
+             # Backward pass
+             loss.backward()
+
+             # Store objective value (detached to prevent gradient issues)
+             objective_history.append(pseudo_likelihood.detach().item())
+
+             # Clip gradients
+             if mln.rule_weights.grad is not None:
+                 grad_norm = gradient_clipping([mln.rule_weights], self.grad_clip_norm)
+                 gradient_history.append(grad_norm)
+             else:
+                 gradient_history.append(0.0)
+
+             # Optimization step
+             optimizer.step()
+
+             # Check convergence
+             if iteration > 0:
+                 objective_change = abs(objective_history[-1] - objective_history[-2])
+                 if objective_change < self.convergence_threshold:
+                     converged = True
+                     logger.debug(f"M-step converged at iteration {iteration}")
+                     break
+
+         # Compute final metrics
+         final_weights = mln.rule_weights.data.clone()
+         weight_changes = final_weights - initial_weights
+
+         final_pseudo_likelihood = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
+             mln, e_step_result.fact_probabilities
+         )
+
+         convergence_info = {
+             'converged': converged,
+             'final_objective_change': abs(objective_history[-1] - objective_history[-2]) if len(objective_history) > 1 else 0.0,
+             'objective_history': objective_history,
+             'gradient_history': gradient_history,
+             'reason': 'converged' if converged else 'max_iterations'
+         }
+
+         result = MStepResult(
+             updated_rule_weights=final_weights,
+             weight_changes=weight_changes,
+             gradient_norms=torch.tensor(gradient_history),
+             pseudo_likelihood=final_pseudo_likelihood,
+             optimization_history=objective_history,
+             convergence_info=convergence_info,
+             iteration_count=iteration + 1
+         )
+
+         logger.debug(f"M-step completed: {result}")
+
+         return result
+
+     def optimize_with_regularization(self,
+                                      mln: MarkovLogicNetwork,
+                                      e_step_result: EStepResult,
+                                      l1_weight: float = 0.01,
+                                      l2_weight: float = 0.01) -> MStepResult:
+         """
+         M-step optimization with L1 and L2 regularization on rule weights
+
+         Helps prevent overfitting and encourages sparse rule sets
+         """
+         if mln.rule_weights is None:
+             return self.optimize_rule_weights(mln, e_step_result)
+
+         initial_weights = mln.rule_weights.data.clone()
+         optimizer = optim.Adam([mln.rule_weights], lr=self.learning_rate)
+
+         objective_history = []
+         converged = False
+
+         for iteration in range(self.max_iterations):
+             optimizer.zero_grad()
+
+             # Primary pseudo-likelihood objective
+             pseudo_likelihood = self.pseudo_likelihood_computer.compute_pseudo_likelihood(
+                 mln, e_step_result.fact_probabilities
+             )
+
+             # Regularization terms
+             l1_reg = torch.sum(torch.abs(mln.rule_weights))
+             l2_reg = torch.sum(mln.rule_weights ** 2)
+
+             # Total objective (maximize pseudo-likelihood, minimize regularization)
+             total_objective = pseudo_likelihood - l1_weight * l1_reg - l2_weight * l2_reg
+             loss = -total_objective
+
+             # Optimization
+             loss.backward()
+             gradient_clipping([mln.rule_weights], self.grad_clip_norm)
+             optimizer.step()
+
+             objective_history.append(total_objective.item())
+
+             # Convergence check
+             if iteration > 0:
+                 if abs(objective_history[-1] - objective_history[-2]) < self.convergence_threshold:
+                     converged = True
+                     break
+
+         final_weights = mln.rule_weights.data.clone()
+         weight_changes = final_weights - initial_weights
+
+         return MStepResult(
+             updated_rule_weights=final_weights,
+             weight_changes=weight_changes,
+             gradient_norms=torch.tensor([0.0]),  # Placeholder
+             pseudo_likelihood=pseudo_likelihood,
+             optimization_history=objective_history,
+             convergence_info={'converged': converged},
+             iteration_count=len(objective_history)
+         )
+
+
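Both optimizer methods share one loop shape: Adam minimizes the negated objective, gradients are clipped, and the loop stops once the absolute change in the objective falls below the configured threshold. A self-contained sketch of that loop with the L1/L2 penalties from optimize_with_regularization (a toy quadratic stands in for the pseudo-likelihood, and torch.nn.utils.clip_grad_norm_ stands in for the package's gradient_clipping helper):

# Standalone sketch of the optimization loop shape used above.
import torch
import torch.optim as optim

weights = torch.zeros(3, requires_grad=True)   # stands in for rule weights ω
optimizer = optim.Adam([weights], lr=0.05)
l1_weight, l2_weight, threshold = 0.01, 0.01, 1e-6

history = []
for iteration in range(100):
    optimizer.zero_grad()
    pseudo_likelihood = -((weights - torch.tensor([1.0, 2.0, -1.0])) ** 2).sum()
    objective = (pseudo_likelihood
                 - l1_weight * weights.abs().sum()
                 - l2_weight * (weights ** 2).sum())
    (-objective).backward()                    # maximize by minimizing the negative
    torch.nn.utils.clip_grad_norm_([weights], max_norm=1.0)
    optimizer.step()
    history.append(objective.item())
    if iteration > 0 and abs(history[-1] - history[-2]) < threshold:
        break

print(len(history), weights.detach())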
+ class MStepRunner:
+     """
+     High-level runner for M-step computations
+     Handles different optimization strategies and result aggregation
+     """
+
+     def __init__(self, config: NPLLConfig):
+         self.config = config
+         self.m_step_optimizer = MStepOptimizer(config)
+
+     def run_m_step(self,
+                    mln: MarkovLogicNetwork,
+                    e_step_result: EStepResult,
+                    optimization_strategy: str = "adam") -> MStepResult:
+         """
+         Run complete M-step for rule weight optimization
+
+         """
+         logger.info(f"Running M-step with {optimization_strategy} optimization")
+
+         if optimization_strategy == "regularized":
+             result = self.m_step_optimizer.optimize_with_regularization(
+                 mln, e_step_result, l1_weight=0.01, l2_weight=0.01
+             )
+         else:
+             # Standard optimization
+             result = self.m_step_optimizer.optimize_rule_weights(mln, e_step_result)
+
+         logger.info(f"M-step completed: Pseudo-likelihood={result.pseudo_likelihood.item():.4f}, "
+                     f"Max weight change={torch.max(torch.abs(result.weight_changes)).item():.6f}")
+
+         return result
+
+     def run_m_step_with_constraints(self,
+                                     mln: MarkovLogicNetwork,
+                                     e_step_result: EStepResult,
+                                     weight_bounds: Optional[List[Tuple[float, float]]] = None) -> MStepResult:
+         """
+         Run M-step with constraints on rule weights
+
+         """
+         # Standard M-step
+         result = self.run_m_step(mln, e_step_result)
+
+         # Apply weight constraints
+         if weight_bounds and mln.rule_weights is not None:
+             with torch.no_grad():
+                 for i, (min_weight, max_weight) in enumerate(weight_bounds):
+                     if i < len(mln.rule_weights):
+                         mln.rule_weights[i] = torch.clamp(
+                             mln.rule_weights[i], min_weight, max_weight
+                         )
+
+             # Update result with constrained weights
+             result.updated_rule_weights = mln.rule_weights.data.clone()
+
+         return result
+
+
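A hypothetical usage sketch of the runner API above; `mln` and `e_step_result` are placeholders assumed to be built the way verify_m_step_implementation() below builds them:

# Hypothetical usage sketch; mln and e_step_result are assumed inputs.
from npll.utils.config import default_config
from npll.inference.m_step import create_m_step_runner

def fit_rule_weights(mln, e_step_result):
    runner = create_m_step_runner(default_config)

    # Default Adam pass, or the L1/L2-regularized variant:
    result = runner.run_m_step(mln, e_step_result,
                               optimization_strategy="regularized")

    # Optionally clamp every learned weight into a fixed band afterwards:
    bounded = runner.run_m_step_with_constraints(
        mln, e_step_result,
        weight_bounds=[(-5.0, 5.0)] * len(result.updated_rule_weights),
    )
    return bounded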
+ def create_m_step_runner(config: NPLLConfig) -> MStepRunner:
+     """Factory function to create M-step runner"""
+     return MStepRunner(config)
+
+
+ def verify_m_step_implementation():
+     """Verify M-step implementation"""
+     from ..utils.config import default_config
+     from ..core import load_knowledge_graph_from_triples
+     from ..core.mln import create_mln_from_kg_and_rules
+     from ..core.logical_rules import Variable, Atom, RuleType, LogicalRule
+     from ..scoring import create_scoring_module
+     from .e_step import create_e_step_runner
+
+     # Create test data
+     test_triples = [
+         ("A", "r1", "B"),
+         ("B", "r2", "C"),
+     ]
+
+     kg = load_knowledge_graph_from_triples(test_triples)
+     kg.add_unknown_fact("A", "r3", "C")
+
+     # Create test rule
+     from ..core import Relation
+     r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
+     x, y, z = Variable('x'), Variable('y'), Variable('z')
+
+     test_rule = LogicalRule(
+         rule_id="test_m_step_rule",
+         body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
+         head=Atom(r3, (x, z)),
+         rule_type=RuleType.TRANSITIVITY
+     )
+
+     # Create MLN and scoring module
+     mln = create_mln_from_kg_and_rules(kg, [test_rule], default_config)
+     scoring_module = create_scoring_module(default_config, kg)
+
+     # Run E-step to get Q(U)
+     e_step_runner = create_e_step_runner(default_config)
+     e_step_result = e_step_runner.run_e_step(mln, scoring_module, kg)
+
+     # Test M-step
+     m_step_runner = MStepRunner(default_config)
+     m_step_result = m_step_runner.run_m_step(mln, e_step_result)
+
+     # Verify results
+     assert len(m_step_result.updated_rule_weights) == len(mln.logical_rules), \
+         "Should have weights for all rules"
+
+     assert torch.all(torch.isfinite(m_step_result.updated_rule_weights)), \
+         "All weights should be finite"
+
+     assert torch.isfinite(m_step_result.pseudo_likelihood), \
+         "Pseudo-likelihood should be finite"
+
+     logger.info("M-step implementation verified successfully")
+
      return True
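This module covers only the M-step; during training it alternates with the E-step. A hedged sketch of that outer EM-style loop, assembled from the factories this file and its verification routine reference (the toy graph and rule mirror verify_m_step_implementation(); the stopping test on max |Δw| is illustrative, not the package's criterion):

# Hypothetical EM-style driver assembled from the factories shown above.
import torch
from npll.utils.config import default_config
from npll.core import load_knowledge_graph_from_triples, Relation
from npll.core.logical_rules import Variable, Atom, RuleType, LogicalRule
from npll.core.mln import create_mln_from_kg_and_rules
from npll.scoring import create_scoring_module
from npll.inference.e_step import create_e_step_runner
from npll.inference.m_step import create_m_step_runner

kg = load_knowledge_graph_from_triples([("A", "r1", "B"), ("B", "r2", "C")])
kg.add_unknown_fact("A", "r3", "C")  # the fact whose belief EM should sharpen

r1, r2, r3 = Relation("r1"), Relation("r2"), Relation("r3")
x, y, z = Variable('x'), Variable('y'), Variable('z')
rule = LogicalRule(rule_id="demo_rule",
                   body=[Atom(r1, (x, y)), Atom(r2, (y, z))],
                   head=Atom(r3, (x, z)),
                   rule_type=RuleType.TRANSITIVITY)

mln = create_mln_from_kg_and_rules(kg, [rule], default_config)
scoring = create_scoring_module(default_config, kg)
e_runner = create_e_step_runner(default_config)
m_runner = create_m_step_runner(default_config)

for em_iter in range(5):                              # fixed budget for the sketch
    e_result = e_runner.run_e_step(mln, scoring, kg)  # E: infer beliefs Q(U)
    m_result = m_runner.run_m_step(mln, e_result)     # M: refit rule weights ω
    if torch.max(torch.abs(m_result.weight_changes)) < 1e-4:
        break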