odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/utils/math_utils.py CHANGED
@@ -1,339 +1,339 @@
+"""
+Mathematical utilities for NPLL implementation
+Implements key mathematical functions from the paper
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+from typing import List, Dict, Tuple, Optional, Union
+import math
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def log_sum_exp(values: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
+    """
+    Numerically stable log-sum-exp operation
+    Used for partition function computation in MLN (Equation 2)
+
+    Args:
+        values: Input tensor
+        dim: Dimension to sum over
+        keepdim: Whether to keep dimension
+
+    Returns:
+        log(sum(exp(values))) computed stably
+    """
+    max_val, _ = values.max(dim=dim, keepdim=True)
+
+    # Handle case where all values are -inf
+    max_val = torch.where(torch.isfinite(max_val), max_val, torch.zeros_like(max_val))
+
+    # Compute log-sum-exp
+    result = max_val + torch.log(torch.sum(torch.exp(values - max_val), dim=dim, keepdim=True))
+
+    if not keepdim:
+        result = result.squeeze(dim)
+
+    return result
+
+
+def safe_log(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+    """
+    Safe logarithm that clamps input to avoid log(0)
+    Used throughout NPLL for probability computations
+    """
+    return torch.log(torch.clamp(x, min=eps))
+
+
+def safe_sigmoid(x: torch.Tensor) -> torch.Tensor:
+    """
+    Numerically stable sigmoid function
+    Used in scoring module probability transformation (Section 4.1)
+    """
+    return torch.sigmoid(torch.clamp(x, min=-50, max=50))
+
+
+def partition_function_approximation(rule_weights: torch.Tensor,
+                                     ground_rule_counts: torch.Tensor,
+                                     use_log_domain: bool = True) -> torch.Tensor:
+    """
+    Approximate MLN partition function Z(ω) from Equation 2
+
+    Paper Equation 2: Z(ω) = Σ_{F,U} ∏_{r∈R} exp(ωr * N(F,U))
+
+    For large knowledge graphs, exact computation is intractable,
+    so we use sampling-based approximation
+
+    Args:
+        rule_weights: Tensor of shape [num_rules] - rule weights ω
+        ground_rule_counts: Tensor of shape [num_samples, num_rules] - N(F,U) values
+        use_log_domain: Whether to compute in log domain for stability
+
+    Returns:
+        Approximation of partition function
+    """
+    if use_log_domain:
+        # Compute in log domain: log(Z) = log_sum_exp(Σ_r ωr * N(F,U))
+        log_potentials = torch.sum(rule_weights.unsqueeze(0) * ground_rule_counts, dim=1)
+        log_partition = log_sum_exp(log_potentials, dim=0)
+        return log_partition
+    else:
+        # Direct computation (less stable)
+        potentials = torch.exp(torch.sum(rule_weights.unsqueeze(0) * ground_rule_counts, dim=1))
+        return torch.sum(potentials)
+
+
+def compute_mln_probability(rule_weights: torch.Tensor,
+                            ground_rule_counts: torch.Tensor,
+                            log_partition: torch.Tensor) -> torch.Tensor:
+    """
+    Compute MLN probability P(F,U|ω) from Equation 1
+
+    Paper Equation 1: P(F,U|ω) = (1/Z(ω)) * ∏_{r∈R} exp(ωr * N(F,U))
+
+    Args:
+        rule_weights: Tensor of shape [num_rules] - ω values
+        ground_rule_counts: Tensor of shape [batch_size, num_rules] - N(F,U) values
+        log_partition: Log partition function log(Z(ω))
+
+    Returns:
+        Log probabilities of shape [batch_size]
+    """
+    # Compute log potential: Σ_r ωr * N(F,U)
+    log_potentials = torch.sum(rule_weights.unsqueeze(0) * ground_rule_counts, dim=1)
+
+    # Subtract log partition function
+    log_probabilities = log_potentials - log_partition
+
+    return log_probabilities
+
+
+def compute_elbo_loss(predicted_probs: torch.Tensor,
+                      approximate_posterior: torch.Tensor,
+                      rule_weights: torch.Tensor,
+                      ground_rule_counts: torch.Tensor,
+                      log_partition: torch.Tensor) -> torch.Tensor:
+    """
+    Compute ELBO loss from Equation 5
+
+    Paper Equation 5: ELBO = Σ_U Q(U) log P(F,U|ω) - Σ_U Q(U) log Q(U)
+
+    Args:
+        predicted_probs: Predicted fact probabilities from scoring module
+        approximate_posterior: Q(U) distribution
+        rule_weights: MLN rule weights ω
+        ground_rule_counts: Ground rule satisfaction counts N(F,U)
+        log_partition: Log partition function
+
+    Returns:
+        Negative ELBO loss (to minimize)
+    """
+    # First term: Σ_U Q(U) log P(F,U|ω)
+    log_joint_probs = compute_mln_probability(rule_weights, ground_rule_counts, log_partition)
+    joint_term = torch.sum(approximate_posterior * log_joint_probs)
+
+    # Second term: -Σ_U Q(U) log Q(U) (entropy)
+    entropy_term = -torch.sum(approximate_posterior * safe_log(approximate_posterior))
+
+    elbo = joint_term + entropy_term
+
+    # Return negative ELBO as loss (we want to maximize ELBO)
+    return -elbo
+
+
+def bernoulli_entropy(p: torch.Tensor) -> torch.Tensor:
+    """
+    Compute entropy of Bernoulli distribution
+    Used in E-step for Q(U) entropy computation
+
+    H(p) = -p*log(p) - (1-p)*log(1-p)
+    """
+    return -(p * safe_log(p) + (1 - p) * safe_log(1 - p))
+
+
+def bernoulli_log_prob(value: torch.Tensor, prob: torch.Tensor) -> torch.Tensor:
+    """
+    Compute log probability of Bernoulli distribution
+    Used for fact probability computations in E-step
+
+    log P(x=value) = value*log(p) + (1-value)*log(1-p)
+    """
+    return value * safe_log(prob) + (1 - value) * safe_log(1 - prob)
+
+
+def compute_markov_blanket_prob(fact_prob: torch.Tensor,
+                                neighbor_probs: torch.Tensor,
+                                rule_weights: torch.Tensor) -> torch.Tensor:
+    """
+    Compute probability of fact given Markov blanket
+    Used in M-step pseudo-likelihood computation (Equation 13)
+
+    Args:
+        fact_prob: Probability of target fact
+        neighbor_probs: Probabilities of facts in Markov blanket
+        rule_weights: Weights of rules involving this fact
+
+    Returns:
+        P(uk | Markov Blanket)
+    """
+    # Simplified computation - in practice this involves complex inference
+    # over the local Markov network structure
+
+    # Compute local potential based on neighboring facts and rule weights
+    local_potential = torch.sum(rule_weights * neighbor_probs)
+
+    # Normalize using sigmoid
+    return torch.sigmoid(local_potential + torch.logit(fact_prob))
+
+
+def temperature_scaling(logits: torch.Tensor, temperature: float) -> torch.Tensor:
+    """
+    Apply temperature scaling for confidence calibration
+    Used in paper for calibrating confidence estimates
+
+    Args:
+        logits: Input logits
+        temperature: Temperature parameter (1.0 = no scaling)
+
+    Returns:
+        Temperature-scaled probabilities
+    """
+    scaled_logits = logits / temperature
+    return torch.softmax(scaled_logits, dim=-1)
+
+
+def kl_divergence_bernoulli(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+    """
+    Compute KL divergence between two Bernoulli distributions
+    KL(p||q) = p*log(p/q) + (1-p)*log((1-p)/(1-q))
+
+    Used for measuring distance between true and approximate posteriors
+    """
+    eps = 1e-8
+    p = torch.clamp(p, eps, 1 - eps)
+    q = torch.clamp(q, eps, 1 - eps)
+
+    kl = p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))
+    return kl
+
+
+def gradient_clipping(parameters: List[torch.nn.Parameter],
+                      max_norm: float = 1.0) -> float:
+    """
+    Clip gradients to prevent exploding gradients
+    Returns the total norm before clipping
+    """
+    total_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm)
+    return total_norm.item()
+
+
+def compute_metrics(predictions: torch.Tensor,
+                    targets: torch.Tensor,
+                    k_values: List[int] = [1, 3, 10]) -> Dict[str, float]:
+    """
+    Compute evaluation metrics as specified in paper Section 5.2
+
+    Args:
+        predictions: Predicted scores/ranks [batch_size, num_entities]
+        targets: Target entity indices [batch_size]
+        k_values: Values of k for Hit@k computation
+
+    Returns:
+        Dictionary with MRR and Hit@k metrics
+    """
+    batch_size = predictions.size(0)
+
+    # Get ranks of target entities
+    sorted_indices = torch.argsort(predictions, dim=1, descending=True)
+    ranks = torch.zeros(batch_size, dtype=torch.float)
+
+    for i in range(batch_size):
+        target_idx = targets[i]
+        rank = (sorted_indices[i] == target_idx).nonzero(as_tuple=True)[0][0] + 1
+        ranks[i] = rank.float()
+
+    # Compute MRR
+    mrr = torch.mean(1.0 / ranks).item()
+
+    # Compute Hit@k
+    metrics = {'MRR': mrr}
+    for k in k_values:
+        hit_at_k = torch.mean((ranks <= k).float()).item()
+        metrics[f'Hit@{k}'] = hit_at_k
+
+    return metrics
+
+
+def moving_average(values: List[float], window_size: int = 10) -> float:
+    """Compute moving average of values"""
+    if len(values) < window_size:
+        return sum(values) / len(values) if values else 0.0
+    return sum(values[-window_size:]) / window_size
+
+
+def cosine_similarity(x: torch.Tensor, y: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """Compute cosine similarity between vectors"""
+    return F.cosine_similarity(x, y, dim=dim)
+
+
+def euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Compute Euclidean distance between vectors"""
+    return torch.norm(x - y, p=2, dim=-1)
+
+
+def sample_negative_triples(positive_triples: torch.Tensor,
+                            num_entities: int,
+                            corruption_mode: str = 'both') -> torch.Tensor:
+    """
+    Sample negative triples for training by corrupting positive ones
+
+    Args:
+        positive_triples: Tensor of shape [batch_size, 3] (head, relation, tail)
+        num_entities: Total number of entities
+        corruption_mode: 'head', 'tail', or 'both'
+
+    Returns:
+        Negative triples tensor
+    """
+    batch_size = positive_triples.size(0)
+    negative_triples = positive_triples.clone()
+
+    for i in range(batch_size):
+        if corruption_mode == 'head' or (corruption_mode == 'both' and i % 2 == 0):
+            # Corrupt head entity
+            negative_triples[i, 0] = torch.randint(0, num_entities, (1,))
+        else:
+            # Corrupt tail entity
+            negative_triples[i, 2] = torch.randint(0, num_entities, (1,))
+
+    return negative_triples
+
+
+class NumericalStabilizer:
+    """Utility class for numerical stability in computations"""
+
+    @staticmethod
+    def stabilize_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        """Numerically stable softmax"""
+        x_max, _ = torch.max(x, dim=dim, keepdim=True)
+        exp_x = torch.exp(x - x_max)
+        return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)
+
+    @staticmethod
+    def stabilize_log_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        """Numerically stable log-softmax"""
+        return F.log_softmax(x, dim=dim)
+
+    @staticmethod
+    def clamp_probabilities(p: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+        """Clamp probabilities to valid range"""
+        return torch.clamp(p, min=eps, max=1.0 - eps)
+
+
+# Constants for numerical stability
+EPS = 1e-8
+LOG_EPS = math.log(EPS)
  LARGE_NUMBER = 1e8
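
Usage notes for the module above. The stable log_sum_exp is the workhorse behind the partition-function code: it subtracts the per-row maximum before exponentiating, so it agrees with the naive formula on moderate inputs and stays finite where exp() would overflow. A minimal check, assuming the wheel is installed so the module is importable as npll.utils.math_utils (its path in the file list above):

import torch
from npll.utils.math_utils import log_sum_exp

x = torch.tensor([1.0, 2.0, 3.0])
naive = torch.log(torch.exp(x).sum())
print(torch.allclose(log_sum_exp(x), naive))   # True on moderate inputs

big = torch.tensor([1000.0, 1000.0])
print(torch.log(torch.exp(big).sum()))         # inf: exp() overflows in float32
print(log_sum_exp(big))                        # tensor(1000.6931), still finite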
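
The MLN utilities compose: partition_function_approximation produces log Z(ω) from sampled ground-rule counts, compute_mln_probability uses it to form log P(F,U|ω), and compute_elbo_loss turns Equation 5 into a scalar to minimize. A minimal sketch with synthetic tensors (the rule count, sample count, and random inputs are illustrative, not values from the package):

import torch
from npll.utils.math_utils import partition_function_approximation, compute_elbo_loss

num_rules, num_samples = 4, 32
rule_weights = torch.randn(num_rules, requires_grad=True)                   # ω
ground_rule_counts = torch.randint(0, 5, (num_samples, num_rules)).float()  # N(F,U) per sampled world

log_Z = partition_function_approximation(rule_weights, ground_rule_counts)

approximate_posterior = torch.rand(num_samples)   # Q(U): one Bernoulli probability per world
predicted_probs = torch.rand(num_samples)         # scoring-module outputs (not used by this loss term)

loss = compute_elbo_loss(predicted_probs, approximate_posterior,
                         rule_weights, ground_rule_counts, log_Z)
loss.backward()                                   # gradients flow back to the rule weights ω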
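
bernoulli_entropy and kl_divergence_bernoulli back the E-step entropy and posterior-distance computations. Two hand-checkable values (illustrative only): a fair coin has entropy log 2 ≈ 0.6931 nats, and the KL divergence vanishes when p = q:

import torch
from npll.utils.math_utils import bernoulli_entropy, kl_divergence_bernoulli

p = torch.tensor([0.5, 0.9])
q = torch.tensor([0.5, 0.1])
print(bernoulli_entropy(p))           # tensor([0.6931, 0.3251])
print(kl_divergence_bernoulli(p, q))  # tensor([0.0000, 1.7578])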
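
compute_metrics turns a score matrix and target indices into the MRR and Hit@k numbers described in its docstring. A worked toy example with two queries over four candidate entities (the scores are made up): the first target is ranked 1st and the second 2nd, so MRR = (1/1 + 1/2) / 2 = 0.75 and Hit@1 = 0.5:

import torch
from npll.utils.math_utils import compute_metrics

predictions = torch.tensor([[0.9, 0.1, 0.3, 0.2],    # query 0: target entity 0 ranks 1st
                            [0.2, 0.7, 0.1, 0.8]])   # query 1: target entity 1 ranks 2nd
targets = torch.tensor([0, 1])
print(compute_metrics(predictions, targets))
# {'MRR': 0.75, 'Hit@1': 0.5, 'Hit@3': 1.0, 'Hit@10': 1.0}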