odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/scoring/probability.py
CHANGED
@@ -1,403 +1,403 @@
All 402 removed lines are re-added with identical text (a whitespace or line-ending-only change); the only other line in the file, `return True` at line 403, is unchanged context. The file content, shown once:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional, Union
import logging
import numpy as np

from ..utils.math_utils import safe_sigmoid, safe_log, bernoulli_log_prob, bernoulli_entropy
from ..core import Triple

logger = logging.getLogger(__name__)


class ProbabilityTransform(nn.Module):
    """
    Temperature scaling with optional per-group (e.g., per-relation) temperatures.
    Guarantees T>0 via softplus on an unconstrained log-T parameter.
    """

    def __init__(self, temperature: float = 1.0, per_group: bool = False, num_groups: int = 1):
        super().__init__()
        self.per_group = per_group
        init = torch.log(torch.expm1(torch.tensor(float(temperature))))
        if per_group:
            assert num_groups >= 1
            self.log_t = nn.Parameter(init.repeat(int(num_groups)))
        else:
            self.log_t = nn.Parameter(init.unsqueeze(0))

    def _temperature(self, device, group_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        T = F.softplus(self.log_t.to(device)) + 1e-6
        if self.per_group:
            if group_ids is None:
                raise ValueError("group_ids is required when per_group=True")
            return T.index_select(0, group_ids)
        return T[0]

    def ensure_num_groups(self, num_groups: int):
        if not self.per_group:
            return
        cur = self.log_t.numel()
        if num_groups <= cur:
            return
        with torch.no_grad():
            new = torch.empty(num_groups, device=self.log_t.device, dtype=self.log_t.dtype)
            new[:cur] = self.log_t.data
            init = torch.log(torch.expm1(torch.tensor(1.0, device=new.device, dtype=new.dtype)))
            new[cur:] = init
            self.log_t = nn.Parameter(new)

    def forward(self, scores: torch.Tensor,
                apply_temperature: bool = True,
                group_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        scaled = scores if not apply_temperature else scores / self._temperature(scores.device, group_ids)
        return torch.sigmoid(scaled)

    def log_probs(self, scores: torch.Tensor,
                  apply_temperature: bool = True,
                  group_ids: Optional[torch.LongTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        scaled = scores if not apply_temperature else scores / self._temperature(scores.device, group_ids)
        return F.logsigmoid(scaled), F.logsigmoid(-scaled)

    # Backward-compat alias
    def get_log_probabilities(self, scores: torch.Tensor, apply_temperature: bool = True,
                              group_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        log_p, _ = self.log_probs(scores, apply_temperature, group_ids)
        return log_p

    @torch.no_grad()
    def calibrate_temperature(self,
                              logits: torch.Tensor,   # raw scores before sigmoid
                              labels: torch.Tensor,   # {0,1}
                              max_iter: int = 100,
                              weight: Optional[torch.Tensor] = None,
                              group_ids: Optional[torch.LongTensor] = None) -> float:
        """
        Temperature scaling on held-out logits using BCEWithLogitsLoss.
        Optimizes log-T only.
        """
        logits = logits.detach()
        labels = labels.float().detach()
        if weight is not None:
            weight = weight.to(logits.device)

        optimizer = torch.optim.LBFGS([self.log_t], max_iter=max_iter)

        def closure():
            optimizer.zero_grad()
            # Always divide by T (scalar for global, vector for per-group)
            if self.per_group:
                T = self._temperature(logits.device, group_ids)
            else:
                T = self._temperature(logits.device, None)
            scaled = logits / T
            loss = F.binary_cross_entropy_with_logits(scaled, labels, weight=weight)
            loss.backward()
            return loss

        optimizer.step(closure)
        T = (F.softplus(self.log_t) + 1e-6).detach().cpu()
        logger.info(f"Temperature calibrated; mean T={T.mean().item():.4f}")
        return T.mean().item()


class FactProbabilityComputer:
    """
    Computes probabilities for facts in the context of NPLL
    Handles both known facts F and unknown facts U
    """

    def __init__(self, probability_transform: ProbabilityTransform):
        self.prob_transform = probability_transform

    def compute_fact_probabilities(self, scores: torch.Tensor,
                                   fact_types: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
        """
        Compute probabilities and stable log-probabilities for facts from scores.
        """
        probabilities = self.prob_transform(scores)
        log_p, log1m_p = self.prob_transform.log_probs(scores)
        result = {
            'probabilities': probabilities,
            'log_probabilities': log_p,
            'neg_log_probabilities': log1m_p
        }
        if fact_types is not None:
            device = probabilities.device
            known_mask = torch.tensor([ft == 'known' for ft in fact_types], device=device, dtype=torch.bool)
            unknown_mask = ~known_mask
            if torch.any(known_mask):
                result['known_probabilities'] = probabilities[known_mask]
            if torch.any(unknown_mask):
                result['unknown_probabilities'] = probabilities[unknown_mask]
        return result

    def compute_bernoulli_parameters(self, scores: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute Bernoulli distribution parameters for facts
        """
        probabilities = self.prob_transform(scores)
        log_p, log1m_p = self.prob_transform.log_probs(scores)
        return {
            'success_prob': probabilities,
            'failure_prob': 1 - probabilities,
            'log_success_prob': log_p,
            'log_failure_prob': log1m_p,
            'entropy': bernoulli_entropy(probabilities)
        }


class ApproximatePosteriorComputer:
    """
    Computes approximate posterior distribution Q(U) as described in paper Section 4.2
    """

    def __init__(self, probability_transform: ProbabilityTransform):
        self.prob_transform = probability_transform
        self.fact_prob_computer = FactProbabilityComputer(probability_transform)

    def compute_q_u_distribution(self, unknown_fact_scores: torch.Tensor,
                                 ground_rule_structure: Optional[List[List[int]]] = None) -> Dict[str, torch.Tensor]:
        """
        Compute approximate posterior distribution Q(U)
        """
        # Get Bernoulli parameters for unknown facts
        bernoulli_params = self.fact_prob_computer.compute_bernoulli_parameters(unknown_fact_scores)

        result = {
            'fact_probabilities': bernoulli_params['success_prob'],
            'fact_log_probabilities': bernoulli_params['log_success_prob'],
            'fact_entropies': bernoulli_params['entropy']
        }

        # If ground rule structure provided, compute ground rule probabilities
        if ground_rule_structure is not None:
            ground_rule_probs = []
            ground_rule_log_probs = []

            for rule_fact_indices in ground_rule_structure:
                if rule_fact_indices:
                    # Product of fact probabilities in this ground rule
                    rule_fact_probs = bernoulli_params['success_prob'][rule_fact_indices]
                    rule_prob = torch.prod(rule_fact_probs)
                    rule_log_prob = torch.sum(bernoulli_params['log_success_prob'][rule_fact_indices])

                    ground_rule_probs.append(rule_prob)
                    ground_rule_log_probs.append(rule_log_prob)

            if ground_rule_probs:
                result['ground_rule_probabilities'] = torch.stack(ground_rule_probs)
                result['ground_rule_log_probabilities'] = torch.stack(ground_rule_log_probs)

        return result

    def compute_expected_counts(self, fact_probabilities: torch.Tensor,
                                ground_rule_structure: List[List[int]]) -> torch.Tensor:
        """
        Compute expected counts N(F,U) for ground rules
        """
        expected_counts = []

        for rule_fact_indices in ground_rule_structure:
            if rule_fact_indices:
                # Expected count for this ground rule is product of fact probabilities
                rule_fact_probs = fact_probabilities[rule_fact_indices]
                expected_count = torch.prod(rule_fact_probs)
                expected_counts.append(expected_count)
            else:
                expected_counts.append(torch.tensor(0.0, device=fact_probabilities.device))

        return torch.stack(expected_counts) if expected_counts else torch.empty(0, device=fact_probabilities.device)


class ProbabilityCalibrator:
    """
    Post-hoc probability calibration using various methods
    Improves reliability of confidence estimates
    """

    def __init__(self, method: str = 'platt'):
        """
        Args:
            method: Calibration method ('platt', 'isotonic', 'temperature')
        """
        self.method = method
        self.calibration_function = None
        self.is_fitted = False

    def fit(self, predicted_probs: np.ndarray, true_labels: np.ndarray):
        """
        Fit calibration function to data
        """
        if self.method == 'platt':
            from sklearn.calibration import CalibratedClassifierCV
            from sklearn.linear_model import LogisticRegression

            # Platt scaling using logistic regression
            self.calibration_function = LogisticRegression()
            self.calibration_function.fit(predicted_probs.reshape(-1, 1), true_labels)

        elif self.method == 'isotonic':
            from sklearn.isotonic import IsotonicRegression

            self.calibration_function = IsotonicRegression(out_of_bounds='clip')
            self.calibration_function.fit(predicted_probs, true_labels)

        elif self.method == 'temperature':
            # Temperature scaling (implemented in ProbabilityTransform)
            pass

        self.is_fitted = True
        logger.info(f"Calibration function fitted using {self.method}")

    def transform(self, predicted_probs: np.ndarray) -> np.ndarray:
        """Apply calibration to predicted probabilities"""
        if not self.is_fitted:
            logger.warning("Calibration function not fitted, returning original probabilities")
            return predicted_probs

        if self.method == 'platt':
            return self.calibration_function.predict_proba(predicted_probs.reshape(-1, 1))[:, 1]
        elif self.method == 'isotonic':
            return self.calibration_function.transform(predicted_probs)
        else:
            return predicted_probs

    def compute_calibration_error(self, predicted_probs: np.ndarray,
                                  true_labels: np.ndarray,
                                  n_bins: int = 10) -> float:
        """
        Compute Expected Calibration Error (ECE)

        ECE measures the difference between confidence and accuracy
        """
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]

        ece = 0
        total_samples = len(predicted_probs)

        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
            # Find predictions in this bin
            in_bin = (predicted_probs > bin_lower) & (predicted_probs <= bin_upper)
            prop_in_bin = in_bin.sum() / total_samples

            if prop_in_bin > 0:
                accuracy_in_bin = true_labels[in_bin].mean()
                avg_confidence_in_bin = predicted_probs[in_bin].mean()
                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece


class ConfidenceEstimator:
    """
    Estimates confidence scores for logical rules using scoring module outputs
    Integrates with NPLL's probabilistic framework
    """

    def __init__(self, probability_transform: ProbabilityTransform):
        self.prob_transform = probability_transform
        self.posterior_computer = ApproximatePosteriorComputer(probability_transform)

    def estimate_rule_confidence(self, rule_satisfaction_scores: torch.Tensor,
                                 supporting_evidence_scores: torch.Tensor) -> Dict[str, float]:
        """
        Estimate confidence for a logical rule based on satisfaction and evidence
        """
        # Convert scores to probabilities
        satisfaction_probs = self.prob_transform(rule_satisfaction_scores)
        evidence_probs = self.prob_transform(supporting_evidence_scores)

        # Compute various confidence metrics
        confidence_metrics = {
            'mean_satisfaction': satisfaction_probs.mean().item(),
            'median_satisfaction': satisfaction_probs.median().item(),
            'min_satisfaction': satisfaction_probs.min().item(),
            'max_satisfaction': satisfaction_probs.max().item(),
            'std_satisfaction': satisfaction_probs.std().item(),
            'mean_evidence': evidence_probs.mean().item(),
            'evidence_strength': evidence_probs.sum().item(),
            'num_supporting_instances': (evidence_probs > 0.5).sum().item()
        }

        # Combined confidence score (weighted average)
        combined_confidence = (
            0.7 * confidence_metrics['mean_satisfaction'] +
            0.3 * confidence_metrics['mean_evidence']
        )
        confidence_metrics['combined_confidence'] = combined_confidence

        return confidence_metrics

    def compute_uncertainty_measures(self, probabilities: torch.Tensor) -> Dict[str, float]:
        """
        Compute various uncertainty measures for probability estimates

        Returns:
            Dictionary with uncertainty metrics
        """
        # Entropy-based uncertainty
        entropy = bernoulli_entropy(probabilities).mean().item()

        # Variance-based uncertainty
        variance = (probabilities * (1 - probabilities)).mean().item()

        # Confidence intervals (assuming independence)
        confidence_95 = 1.96 * torch.sqrt(probabilities * (1 - probabilities))
        mean_ci_width = confidence_95.mean().item()

        return {
            'entropy': entropy,
            'variance': variance,
            'mean_ci_width_95': mean_ci_width,
            'prediction_uncertainty': entropy,  # Alias for compatibility
        }


def create_probability_components(temperature: float = 1.0,
                                  per_relation: bool = False,
                                  num_relations: int = 1) -> Tuple[ProbabilityTransform, ApproximatePosteriorComputer]:
    """
    Factory function to create probability computation components.
    Supports optional per-relation temperature scaling.
    """
    prob_transform = ProbabilityTransform(temperature=temperature,
                                          per_group=per_relation,
                                          num_groups=(num_relations if per_relation else 1))
    posterior_computer = ApproximatePosteriorComputer(prob_transform)
    return prob_transform, posterior_computer


def verify_probability_computations():
    """Verify probability computation implementations"""
    # Test probability transform
    prob_transform = ProbabilityTransform(temperature=1.0)

    # Test scores
    test_scores = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
    probs = prob_transform(test_scores, apply_temperature=False)

    # Verify sigmoid properties
    assert torch.all(probs >= 0) and torch.all(probs <= 1), "Probabilities not in [0,1]"
    assert abs(probs[2].item() - 0.5) < 1e-5, "sigmoid(0) should be 0.5"

    # Test Bernoulli computations
    fact_computer = FactProbabilityComputer(prob_transform)
    bernoulli_params = fact_computer.compute_bernoulli_parameters(test_scores)

    # Verify Bernoulli properties
    success_prob = bernoulli_params['success_prob']
    failure_prob = bernoulli_params['failure_prob']
    assert torch.allclose(success_prob + failure_prob, torch.ones_like(success_prob)), \
        "Success + failure probabilities should sum to 1"

    logger.info("Probability computation implementations verified successfully")

    return True
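Since the probability.py hunk above is content-identical across 0.1.0 and 0.2.0, the module's behavior is easiest to judge from usage. Below is a minimal sketch (not part of the package) of the calibration path; it assumes odin-engine is installed and importable as npll.scoring.probability, and the logits/labels are synthetic stand-ins for scorer output.

# Minimal sketch, assuming an installed odin-engine wheel: global temperature
# scaling with ProbabilityTransform, then an ECE check via ProbabilityCalibrator.
import torch
from npll.scoring.probability import ProbabilityTransform, ProbabilityCalibrator

torch.manual_seed(0)
logits = torch.randn(1000) * 3.0  # deliberately miscalibrated (overconfident) scale
labels = (torch.sigmoid(logits / 2.0) > torch.rand(1000)).float()  # synthetic labels

transform = ProbabilityTransform(temperature=1.0)
mean_t = transform.calibrate_temperature(logits, labels)  # fits log-T with LBFGS
print(f"fitted temperature: {mean_t:.3f}")  # should land roughly near 2.0 here

probs = transform(logits).detach()            # calibrated probabilities in (0, 1)
log_p, log1m_p = transform.log_probs(logits)  # numerically stable log p, log(1-p)

calibrator = ProbabilityCalibrator(method='temperature')
ece = calibrator.compute_calibration_error(probs.numpy(), labels.numpy())
print(f"ECE after scaling: {ece:.4f}")

The per-group path works the same way, except that per_group=True requires group_ids (e.g., relation ids) in both the calibration and inference calls, and ensure_num_groups grows the parameter vector when new groups appear.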