odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/scoring/probability.py
@@ -1,403 +1,403 @@
(Every line of this file is marked removed and re-added, but the two sides are identical as rendered, so the content is shown once below.)

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional, Union
import logging
import numpy as np

from ..utils.math_utils import safe_sigmoid, safe_log, bernoulli_log_prob, bernoulli_entropy
from ..core import Triple

logger = logging.getLogger(__name__)


class ProbabilityTransform(nn.Module):
    """
    Temperature scaling with optional per-group (e.g., per-relation) temperatures.
    Guarantees T>0 via softplus on an unconstrained log-T parameter.
    """

    def __init__(self, temperature: float = 1.0, per_group: bool = False, num_groups: int = 1):
        super().__init__()
        self.per_group = per_group
        init = torch.log(torch.expm1(torch.tensor(float(temperature))))
        if per_group:
            assert num_groups >= 1
            self.log_t = nn.Parameter(init.repeat(int(num_groups)))
        else:
            self.log_t = nn.Parameter(init.unsqueeze(0))

    def _temperature(self, device, group_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        T = F.softplus(self.log_t.to(device)) + 1e-6
        if self.per_group:
            if group_ids is None:
                raise ValueError("group_ids is required when per_group=True")
            return T.index_select(0, group_ids)
        return T[0]

    def ensure_num_groups(self, num_groups: int):
        if not self.per_group:
            return
        cur = self.log_t.numel()
        if num_groups <= cur:
            return
        with torch.no_grad():
            new = torch.empty(num_groups, device=self.log_t.device, dtype=self.log_t.dtype)
            new[:cur] = self.log_t.data
            init = torch.log(torch.expm1(torch.tensor(1.0, device=new.device, dtype=new.dtype)))
            new[cur:] = init
            self.log_t = nn.Parameter(new)

    def forward(self, scores: torch.Tensor,
                apply_temperature: bool = True,
                group_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        scaled = scores if not apply_temperature else scores / self._temperature(scores.device, group_ids)
        return torch.sigmoid(scaled)

    def log_probs(self, scores: torch.Tensor,
                  apply_temperature: bool = True,
                  group_ids: Optional[torch.LongTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        scaled = scores if not apply_temperature else scores / self._temperature(scores.device, group_ids)
        return F.logsigmoid(scaled), F.logsigmoid(-scaled)

    # Backward-compat alias
    def get_log_probabilities(self, scores: torch.Tensor, apply_temperature: bool = True,
                              group_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        log_p, _ = self.log_probs(scores, apply_temperature, group_ids)
        return log_p

    @torch.no_grad()
    def calibrate_temperature(self,
                              logits: torch.Tensor,  # raw scores before sigmoid
                              labels: torch.Tensor,  # {0,1}
                              max_iter: int = 100,
                              weight: Optional[torch.Tensor] = None,
                              group_ids: Optional[torch.LongTensor] = None) -> float:
        """
        Temperature scaling on held-out logits using BCEWithLogitsLoss.
        Optimizes log-T only.
        """
        logits = logits.detach()
        labels = labels.float().detach()
        if weight is not None:
            weight = weight.to(logits.device)

        optimizer = torch.optim.LBFGS([self.log_t], max_iter=max_iter)

        def closure():
            optimizer.zero_grad()
            # Always divide by T (scalar for global, vector for per-group)
            if self.per_group:
                T = self._temperature(logits.device, group_ids)
            else:
                T = self._temperature(logits.device, None)
            scaled = logits / T
            loss = F.binary_cross_entropy_with_logits(scaled, labels, weight=weight)
            loss.backward()
            return loss

        optimizer.step(closure)
        T = (F.softplus(self.log_t) + 1e-6).detach().cpu()
        logger.info(f"Temperature calibrated; mean T={T.mean().item():.4f}")
        return T.mean().item()


class FactProbabilityComputer:
    """
    Computes probabilities for facts in the context of NPLL
    Handles both known facts F and unknown facts U
    """

    def __init__(self, probability_transform: ProbabilityTransform):
        self.prob_transform = probability_transform

    def compute_fact_probabilities(self, scores: torch.Tensor,
                                   fact_types: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
        """
        Compute probabilities and stable log-probabilities for facts from scores.
        """
        probabilities = self.prob_transform(scores)
        log_p, log1m_p = self.prob_transform.log_probs(scores)
        result = {
            'probabilities': probabilities,
            'log_probabilities': log_p,
            'neg_log_probabilities': log1m_p
        }
        if fact_types is not None:
            device = probabilities.device
            known_mask = torch.tensor([ft == 'known' for ft in fact_types], device=device, dtype=torch.bool)
            unknown_mask = ~known_mask
            if torch.any(known_mask):
                result['known_probabilities'] = probabilities[known_mask]
            if torch.any(unknown_mask):
                result['unknown_probabilities'] = probabilities[unknown_mask]
        return result

    def compute_bernoulli_parameters(self, scores: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute Bernoulli distribution parameters for facts
        """
        probabilities = self.prob_transform(scores)
        log_p, log1m_p = self.prob_transform.log_probs(scores)
        return {
            'success_prob': probabilities,
            'failure_prob': 1 - probabilities,
            'log_success_prob': log_p,
            'log_failure_prob': log1m_p,
            'entropy': bernoulli_entropy(probabilities)
        }


class ApproximatePosteriorComputer:
    """
    Computes approximate posterior distribution Q(U) as described in paper Section 4.2
    """

    def __init__(self, probability_transform: ProbabilityTransform):
        self.prob_transform = probability_transform
        self.fact_prob_computer = FactProbabilityComputer(probability_transform)

    def compute_q_u_distribution(self, unknown_fact_scores: torch.Tensor,
                                 ground_rule_structure: Optional[List[List[int]]] = None) -> Dict[str, torch.Tensor]:
        """
        Compute approximate posterior distribution Q(U)
        """
        # Get Bernoulli parameters for unknown facts
        bernoulli_params = self.fact_prob_computer.compute_bernoulli_parameters(unknown_fact_scores)

        result = {
            'fact_probabilities': bernoulli_params['success_prob'],
            'fact_log_probabilities': bernoulli_params['log_success_prob'],
            'fact_entropies': bernoulli_params['entropy']
        }

        # If ground rule structure provided, compute ground rule probabilities
        if ground_rule_structure is not None:
            ground_rule_probs = []
            ground_rule_log_probs = []

            for rule_fact_indices in ground_rule_structure:
                if rule_fact_indices:
                    # Product of fact probabilities in this ground rule
                    rule_fact_probs = bernoulli_params['success_prob'][rule_fact_indices]
                    rule_prob = torch.prod(rule_fact_probs)
                    rule_log_prob = torch.sum(bernoulli_params['log_success_prob'][rule_fact_indices])

                    ground_rule_probs.append(rule_prob)
                    ground_rule_log_probs.append(rule_log_prob)

            if ground_rule_probs:
                result['ground_rule_probabilities'] = torch.stack(ground_rule_probs)
                result['ground_rule_log_probabilities'] = torch.stack(ground_rule_log_probs)

        return result

    def compute_expected_counts(self, fact_probabilities: torch.Tensor,
                                ground_rule_structure: List[List[int]]) -> torch.Tensor:
        """
        Compute expected counts N(F,U) for ground rules
        """
        expected_counts = []

        for rule_fact_indices in ground_rule_structure:
            if rule_fact_indices:
                # Expected count for this ground rule is product of fact probabilities
                rule_fact_probs = fact_probabilities[rule_fact_indices]
                expected_count = torch.prod(rule_fact_probs)
                expected_counts.append(expected_count)
            else:
                expected_counts.append(torch.tensor(0.0, device=fact_probabilities.device))

        return torch.stack(expected_counts) if expected_counts else torch.empty(0, device=fact_probabilities.device)


class ProbabilityCalibrator:
    """
    Post-hoc probability calibration using various methods
    Improves reliability of confidence estimates
    """

    def __init__(self, method: str = 'platt'):
        """
        Args:
            method: Calibration method ('platt', 'isotonic', 'temperature')
        """
        self.method = method
        self.calibration_function = None
        self.is_fitted = False

    def fit(self, predicted_probs: np.ndarray, true_labels: np.ndarray):
        """
        Fit calibration function to data
        """
        if self.method == 'platt':
            from sklearn.calibration import CalibratedClassifierCV
            from sklearn.linear_model import LogisticRegression

            # Platt scaling using logistic regression
            self.calibration_function = LogisticRegression()
            self.calibration_function.fit(predicted_probs.reshape(-1, 1), true_labels)

        elif self.method == 'isotonic':
            from sklearn.isotonic import IsotonicRegression

            self.calibration_function = IsotonicRegression(out_of_bounds='clip')
            self.calibration_function.fit(predicted_probs, true_labels)

        elif self.method == 'temperature':
            # Temperature scaling (implemented in ProbabilityTransform)
            pass

        self.is_fitted = True
        logger.info(f"Calibration function fitted using {self.method}")

    def transform(self, predicted_probs: np.ndarray) -> np.ndarray:
        """Apply calibration to predicted probabilities"""
        if not self.is_fitted:
            logger.warning("Calibration function not fitted, returning original probabilities")
            return predicted_probs

        if self.method == 'platt':
            return self.calibration_function.predict_proba(predicted_probs.reshape(-1, 1))[:, 1]
        elif self.method == 'isotonic':
            return self.calibration_function.transform(predicted_probs)
        else:
            return predicted_probs

    def compute_calibration_error(self, predicted_probs: np.ndarray,
                                  true_labels: np.ndarray,
                                  n_bins: int = 10) -> float:
        """
        Compute Expected Calibration Error (ECE)

        ECE measures the difference between confidence and accuracy
        """
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]

        ece = 0
        total_samples = len(predicted_probs)

        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
            # Find predictions in this bin
            in_bin = (predicted_probs > bin_lower) & (predicted_probs <= bin_upper)
            prop_in_bin = in_bin.sum() / total_samples

            if prop_in_bin > 0:
                accuracy_in_bin = true_labels[in_bin].mean()
                avg_confidence_in_bin = predicted_probs[in_bin].mean()
                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece


class ConfidenceEstimator:
    """
    Estimates confidence scores for logical rules using scoring module outputs
    Integrates with NPLL's probabilistic framework
    """

    def __init__(self, probability_transform: ProbabilityTransform):
        self.prob_transform = probability_transform
        self.posterior_computer = ApproximatePosteriorComputer(probability_transform)

    def estimate_rule_confidence(self, rule_satisfaction_scores: torch.Tensor,
                                 supporting_evidence_scores: torch.Tensor) -> Dict[str, float]:
        """
        Estimate confidence for a logical rule based on satisfaction and evidence
        """
        # Convert scores to probabilities
        satisfaction_probs = self.prob_transform(rule_satisfaction_scores)
        evidence_probs = self.prob_transform(supporting_evidence_scores)

        # Compute various confidence metrics
        confidence_metrics = {
            'mean_satisfaction': satisfaction_probs.mean().item(),
            'median_satisfaction': satisfaction_probs.median().item(),
            'min_satisfaction': satisfaction_probs.min().item(),
            'max_satisfaction': satisfaction_probs.max().item(),
            'std_satisfaction': satisfaction_probs.std().item(),
            'mean_evidence': evidence_probs.mean().item(),
            'evidence_strength': evidence_probs.sum().item(),
            'num_supporting_instances': (evidence_probs > 0.5).sum().item()
        }

        # Combined confidence score (weighted average)
        combined_confidence = (
            0.7 * confidence_metrics['mean_satisfaction'] +
            0.3 * confidence_metrics['mean_evidence']
        )
        confidence_metrics['combined_confidence'] = combined_confidence

        return confidence_metrics

    def compute_uncertainty_measures(self, probabilities: torch.Tensor) -> Dict[str, float]:
        """
        Compute various uncertainty measures for probability estimates

        Returns:
            Dictionary with uncertainty metrics
        """
        # Entropy-based uncertainty
        entropy = bernoulli_entropy(probabilities).mean().item()

        # Variance-based uncertainty
        variance = (probabilities * (1 - probabilities)).mean().item()

        # Confidence intervals (assuming independence)
        confidence_95 = 1.96 * torch.sqrt(probabilities * (1 - probabilities))
        mean_ci_width = confidence_95.mean().item()

        return {
            'entropy': entropy,
            'variance': variance,
            'mean_ci_width_95': mean_ci_width,
            'prediction_uncertainty': entropy,  # Alias for compatibility
        }


def create_probability_components(temperature: float = 1.0,
                                  per_relation: bool = False,
                                  num_relations: int = 1) -> Tuple[ProbabilityTransform, ApproximatePosteriorComputer]:
    """
    Factory function to create probability computation components.
    Supports optional per-relation temperature scaling.
    """
    prob_transform = ProbabilityTransform(temperature=temperature,
                                          per_group=per_relation,
                                          num_groups=(num_relations if per_relation else 1))
    posterior_computer = ApproximatePosteriorComputer(prob_transform)
    return prob_transform, posterior_computer


def verify_probability_computations():
    """Verify probability computation implementations"""
    # Test probability transform
    prob_transform = ProbabilityTransform(temperature=1.0)

    # Test scores
    test_scores = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
    probs = prob_transform(test_scores, apply_temperature=False)

    # Verify sigmoid properties
    assert torch.all(probs >= 0) and torch.all(probs <= 1), "Probabilities not in [0,1]"
    assert abs(probs[2].item() - 0.5) < 1e-5, "sigmoid(0) should be 0.5"

    # Test Bernoulli computations
    fact_computer = FactProbabilityComputer(prob_transform)
    bernoulli_params = fact_computer.compute_bernoulli_parameters(test_scores)

    # Verify Bernoulli properties
    success_prob = bernoulli_params['success_prob']
    failure_prob = bernoulli_params['failure_prob']
    assert torch.allclose(success_prob + failure_prob, torch.ones_like(success_prob)), \
        "Success + failure probabilities should sum to 1"

    logger.info("Probability computation implementations verified successfully")

    return True
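
A note on the calibration path: calibrate_temperature is decorated with @torch.no_grad() while its LBFGS closure calls loss.backward(). With autograd disabled for the whole method, the loss carries no graph and that backward call raises at runtime, so a caller would need to re-enable gradients (e.g. torch.enable_grad()) around the loss computation. The following is a minimal, self-contained sketch of the same recipe, a softplus-parameterized temperature initialized through the inverse softplus log(expm1(T)), written without the decorator; all data is synthetic and the sizes are made up for illustration.

import torch
import torch.nn.functional as F

# Hypothetical held-out data: logits that are systematically overconfident.
torch.manual_seed(0)
logits = 3.0 * torch.randn(2000)
labels = (torch.rand(2000) < torch.sigmoid(logits / 2.0)).float()  # true temperature is 2.0

# Unconstrained parameter; softplus keeps T > 0, and the inverse-softplus init makes T start at 1.0.
log_t = torch.nn.Parameter(torch.log(torch.expm1(torch.tensor(1.0))))
optimizer = torch.optim.LBFGS([log_t], max_iter=100)

def closure():
    optimizer.zero_grad()
    T = F.softplus(log_t) + 1e-6
    loss = F.binary_cross_entropy_with_logits(logits / T, labels)
    loss.backward()
    return loss

optimizer.step(closure)
print(f"fitted T = {(F.softplus(log_t) + 1e-6).item():.3f}")  # should land near 2.0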
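
The ground-rule bookkeeping in ApproximatePosteriorComputer is easiest to see on a toy input. A sketch, assuming the wheel is installed so that npll.scoring.probability is importable; the scores and rule structure below are invented.

import torch
from npll.scoring.probability import create_probability_components

prob_transform, posterior = create_probability_components(temperature=1.0)

# Four hypothetical unknown facts; two hypothetical ground rules over their indices.
scores = torch.tensor([2.0, -1.0, 0.5, 3.0])
structure = [[0, 1], [2, 3]]

q_u = posterior.compute_q_u_distribution(scores, structure)
# Each ground-rule probability is the product of its member facts' Bernoulli means,
# so the expected counts N(F,U) coincide with 'ground_rule_probabilities' here.
counts = posterior.compute_expected_counts(q_u['fact_probabilities'], structure)
print(counts)  # tensor of shape (2,)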
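
ConfidenceEstimator.estimate_rule_confidence mixes rule satisfaction and supporting evidence with fixed 0.7/0.3 weights. A quick illustration with invented scores, under the same import assumption as above:

import torch
from npll.scoring.probability import ProbabilityTransform, ConfidenceEstimator

estimator = ConfidenceEstimator(ProbabilityTransform())
satisfaction = torch.tensor([1.2, 0.8, 2.0])  # hypothetical rule-satisfaction scores
evidence = torch.tensor([0.5, -0.2, 1.5])     # hypothetical supporting-evidence scores

metrics = estimator.estimate_rule_confidence(satisfaction, evidence)
# combined_confidence == 0.7 * mean_satisfaction + 0.3 * mean_evidence
print(metrics['combined_confidence'])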
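
ProbabilityCalibrator's ECE loop bins predictions on half-open (lower, upper] intervals, so a prediction of exactly 0.0 falls into no bin. A small in-sample demo of fit/transform plus ECE, assuming scikit-learn is installed alongside the wheel; the synthetic data is deliberately miscalibrated, and a real workflow would fit on held-out data rather than in-sample.

import numpy as np
from npll.scoring.probability import ProbabilityCalibrator

rng = np.random.default_rng(0)
raw = rng.uniform(size=5000)
labels = (rng.uniform(size=5000) < raw ** 2).astype(float)  # true rate is raw**2, not raw

calibrator = ProbabilityCalibrator(method='isotonic')
print("ECE before:", calibrator.compute_calibration_error(raw, labels))

calibrator.fit(raw, labels)          # in-sample for brevity
adjusted = calibrator.transform(raw)
print("ECE after:", calibrator.compute_calibration_error(adjusted, labels))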