odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/utils/math_utils.py
CHANGED
@@ -1,339 +1,339 @@

The removed and re-added versions of lines 1-338 are identical as rendered here; only line 339 (LARGE_NUMBER = 1e8) appears as unchanged context. The file content is shown once:

"""
Mathematical utilities for NPLL implementation
Implements key mathematical functions from the paper
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple, Optional, Union
import math
import logging

logger = logging.getLogger(__name__)


def log_sum_exp(values: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
    """
    Numerically stable log-sum-exp operation
    Used for partition function computation in MLN (Equation 2)

    Args:
        values: Input tensor
        dim: Dimension to sum over
        keepdim: Whether to keep dimension

    Returns:
        log(sum(exp(values))) computed stably
    """
    max_val, _ = values.max(dim=dim, keepdim=True)

    # Handle case where all values are -inf
    max_val = torch.where(torch.isfinite(max_val), max_val, torch.zeros_like(max_val))

    # Compute log-sum-exp
    result = max_val + torch.log(torch.sum(torch.exp(values - max_val), dim=dim, keepdim=True))

    if not keepdim:
        result = result.squeeze(dim)

    return result


def safe_log(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Safe logarithm that clamps input to avoid log(0)
    Used throughout NPLL for probability computations
    """
    return torch.log(torch.clamp(x, min=eps))


def safe_sigmoid(x: torch.Tensor) -> torch.Tensor:
    """
    Numerically stable sigmoid function
    Used in scoring module probability transformation (Section 4.1)
    """
    return torch.sigmoid(torch.clamp(x, min=-50, max=50))


def partition_function_approximation(rule_weights: torch.Tensor,
                                     ground_rule_counts: torch.Tensor,
                                     use_log_domain: bool = True) -> torch.Tensor:
    """
    Approximate MLN partition function Z(ω) from Equation 2

    Paper Equation 2: Z(ω) = Σ_{F,U} ∏_{r∈R} exp(ωr * N(F,U))

    For large knowledge graphs, exact computation is intractable,
    so we use sampling-based approximation

    Args:
        rule_weights: Tensor of shape [num_rules] - rule weights ω
        ground_rule_counts: Tensor of shape [num_samples, num_rules] - N(F,U) values
        use_log_domain: Whether to compute in log domain for stability

    Returns:
        Approximation of partition function
    """
    if use_log_domain:
        # Compute in log domain: log(Z) = log_sum_exp(Σ_r ωr * N(F,U))
        log_potentials = torch.sum(rule_weights.unsqueeze(0) * ground_rule_counts, dim=1)
        log_partition = log_sum_exp(log_potentials, dim=0)
        return log_partition
    else:
        # Direct computation (less stable)
        potentials = torch.exp(torch.sum(rule_weights.unsqueeze(0) * ground_rule_counts, dim=1))
        return torch.sum(potentials)


def compute_mln_probability(rule_weights: torch.Tensor,
                            ground_rule_counts: torch.Tensor,
                            log_partition: torch.Tensor) -> torch.Tensor:
    """
    Compute MLN probability P(F,U|ω) from Equation 1

    Paper Equation 1: P(F,U|ω) = (1/Z(ω)) * ∏_{r∈R} exp(ωr * N(F,U))

    Args:
        rule_weights: Tensor of shape [num_rules] - ω values
        ground_rule_counts: Tensor of shape [batch_size, num_rules] - N(F,U) values
        log_partition: Log partition function log(Z(ω))

    Returns:
        Log probabilities of shape [batch_size]
    """
    # Compute log potential: Σ_r ωr * N(F,U)
    log_potentials = torch.sum(rule_weights.unsqueeze(0) * ground_rule_counts, dim=1)

    # Subtract log partition function
    log_probabilities = log_potentials - log_partition

    return log_probabilities


def compute_elbo_loss(predicted_probs: torch.Tensor,
                      approximate_posterior: torch.Tensor,
                      rule_weights: torch.Tensor,
                      ground_rule_counts: torch.Tensor,
                      log_partition: torch.Tensor) -> torch.Tensor:
    """
    Compute ELBO loss from Equation 5

    Paper Equation 5: ELBO = Σ_U Q(U) log P(F,U|ω) - Σ_U Q(U) log Q(U)

    Args:
        predicted_probs: Predicted fact probabilities from scoring module
        approximate_posterior: Q(U) distribution
        rule_weights: MLN rule weights ω
        ground_rule_counts: Ground rule satisfaction counts N(F,U)
        log_partition: Log partition function

    Returns:
        Negative ELBO loss (to minimize)
    """
    # First term: Σ_U Q(U) log P(F,U|ω)
    log_joint_probs = compute_mln_probability(rule_weights, ground_rule_counts, log_partition)
    joint_term = torch.sum(approximate_posterior * log_joint_probs)

    # Second term: -Σ_U Q(U) log Q(U) (entropy)
    entropy_term = -torch.sum(approximate_posterior * safe_log(approximate_posterior))

    elbo = joint_term + entropy_term

    # Return negative ELBO as loss (we want to maximize ELBO)
    return -elbo


def bernoulli_entropy(p: torch.Tensor) -> torch.Tensor:
    """
    Compute entropy of Bernoulli distribution
    Used in E-step for Q(U) entropy computation

    H(p) = -p*log(p) - (1-p)*log(1-p)
    """
    return -(p * safe_log(p) + (1 - p) * safe_log(1 - p))


def bernoulli_log_prob(value: torch.Tensor, prob: torch.Tensor) -> torch.Tensor:
    """
    Compute log probability of Bernoulli distribution
    Used for fact probability computations in E-step

    log P(x=value) = value*log(p) + (1-value)*log(1-p)
    """
    return value * safe_log(prob) + (1 - value) * safe_log(1 - prob)


def compute_markov_blanket_prob(fact_prob: torch.Tensor,
                                neighbor_probs: torch.Tensor,
                                rule_weights: torch.Tensor) -> torch.Tensor:
    """
    Compute probability of fact given Markov blanket
    Used in M-step pseudo-likelihood computation (Equation 13)

    Args:
        fact_prob: Probability of target fact
        neighbor_probs: Probabilities of facts in Markov blanket
        rule_weights: Weights of rules involving this fact

    Returns:
        P(uk | Markov Blanket)
    """
    # Simplified computation - in practice this involves complex inference
    # over the local Markov network structure

    # Compute local potential based on neighboring facts and rule weights
    local_potential = torch.sum(rule_weights * neighbor_probs)

    # Normalize using sigmoid
    return torch.sigmoid(local_potential + torch.logit(fact_prob))


def temperature_scaling(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Apply temperature scaling for confidence calibration
    Used in paper for calibrating confidence estimates

    Args:
        logits: Input logits
        temperature: Temperature parameter (1.0 = no scaling)

    Returns:
        Temperature-scaled probabilities
    """
    scaled_logits = logits / temperature
    return torch.softmax(scaled_logits, dim=-1)


def kl_divergence_bernoulli(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """
    Compute KL divergence between two Bernoulli distributions
    KL(p||q) = p*log(p/q) + (1-p)*log((1-p)/(1-q))

    Used for measuring distance between true and approximate posteriors
    """
    eps = 1e-8
    p = torch.clamp(p, eps, 1 - eps)
    q = torch.clamp(q, eps, 1 - eps)

    kl = p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))
    return kl


def gradient_clipping(parameters: List[torch.nn.Parameter],
                      max_norm: float = 1.0) -> float:
    """
    Clip gradients to prevent exploding gradients
    Returns the total norm before clipping
    """
    total_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm)
    return total_norm.item()


def compute_metrics(predictions: torch.Tensor,
                    targets: torch.Tensor,
                    k_values: List[int] = [1, 3, 10]) -> Dict[str, float]:
    """
    Compute evaluation metrics as specified in paper Section 5.2

    Args:
        predictions: Predicted scores/ranks [batch_size, num_entities]
        targets: Target entity indices [batch_size]
        k_values: Values of k for Hit@k computation

    Returns:
        Dictionary with MRR and Hit@k metrics
    """
    batch_size = predictions.size(0)

    # Get ranks of target entities
    sorted_indices = torch.argsort(predictions, dim=1, descending=True)
    ranks = torch.zeros(batch_size, dtype=torch.float)

    for i in range(batch_size):
        target_idx = targets[i]
        rank = (sorted_indices[i] == target_idx).nonzero(as_tuple=True)[0][0] + 1
        ranks[i] = rank.float()

    # Compute MRR
    mrr = torch.mean(1.0 / ranks).item()

    # Compute Hit@k
    metrics = {'MRR': mrr}
    for k in k_values:
        hit_at_k = torch.mean((ranks <= k).float()).item()
        metrics[f'Hit@{k}'] = hit_at_k

    return metrics


def moving_average(values: List[float], window_size: int = 10) -> float:
    """Compute moving average of values"""
    if len(values) < window_size:
        return sum(values) / len(values) if values else 0.0
    return sum(values[-window_size:]) / window_size


def cosine_similarity(x: torch.Tensor, y: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute cosine similarity between vectors"""
    return F.cosine_similarity(x, y, dim=dim)


def euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute Euclidean distance between vectors"""
    return torch.norm(x - y, p=2, dim=-1)


def sample_negative_triples(positive_triples: torch.Tensor,
                            num_entities: int,
                            corruption_mode: str = 'both') -> torch.Tensor:
    """
    Sample negative triples for training by corrupting positive ones

    Args:
        positive_triples: Tensor of shape [batch_size, 3] (head, relation, tail)
        num_entities: Total number of entities
        corruption_mode: 'head', 'tail', or 'both'

    Returns:
        Negative triples tensor
    """
    batch_size = positive_triples.size(0)
    negative_triples = positive_triples.clone()

    for i in range(batch_size):
        if corruption_mode == 'head' or (corruption_mode == 'both' and i % 2 == 0):
            # Corrupt head entity
            negative_triples[i, 0] = torch.randint(0, num_entities, (1,))
        else:
            # Corrupt tail entity
            negative_triples[i, 2] = torch.randint(0, num_entities, (1,))

    return negative_triples


class NumericalStabilizer:
    """Utility class for numerical stability in computations"""

    @staticmethod
    def stabilize_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Numerically stable softmax"""
        x_max, _ = torch.max(x, dim=dim, keepdim=True)
        exp_x = torch.exp(x - x_max)
        return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)

    @staticmethod
    def stabilize_log_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Numerically stable log-softmax"""
        return F.log_softmax(x, dim=dim)

    @staticmethod
    def clamp_probabilities(p: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        """Clamp probabilities to valid range"""
        return torch.clamp(p, min=eps, max=1.0 - eps)


# Constants for numerical stability
EPS = 1e-8
LOG_EPS = math.log(EPS)
LARGE_NUMBER = 1e8
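Below is a minimal usage sketch, not part of the published package, that exercises a few of the utilities above. It assumes the module is importable as npll.utils.math_utils (the path listed in this diff) and that PyTorch is installed; the tensor shapes and entity count are illustrative only.

# Minimal usage sketch (not shipped with odin-engine): exercises log_sum_exp,
# compute_metrics, and sample_negative_triples from the module shown above.
import torch

from npll.utils.math_utils import (
    log_sum_exp,
    compute_metrics,
    sample_negative_triples,
)

# log_sum_exp should agree with torch.logsumexp on ordinary finite inputs
values = torch.randn(4, 7)
assert torch.allclose(log_sum_exp(values, dim=-1),
                      torch.logsumexp(values, dim=-1), atol=1e-6)

# Rank-based metrics over a toy batch of entity scores
scores = torch.randn(8, 100)            # [batch_size, num_entities]
targets = torch.randint(0, 100, (8,))   # gold entity index per query
print(compute_metrics(scores, targets, k_values=[1, 3, 10]))

# Negative sampling by corrupting heads/tails of positive triples
positives = torch.randint(0, 100, (8, 3))
negatives = sample_negative_triples(positives, num_entities=100, corruption_mode='both')
print(negatives.shape)  # same shape as positives: [8, 3]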