odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/scoring/scoring_module.py
@@ -1,370 +1,370 @@
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import List, Dict, Tuple, Optional
-import logging
-
-from ..core import Triple
-from ..utils.config import NPLLConfig
-from ..utils.math_utils import safe_sigmoid
-from .embeddings import EmbeddingManager
-
-logger = logging.getLogger(__name__)
-
-
-class BilinearScoringFunction(nn.Module):
-
-    def __init__(self, config: NPLLConfig):
-        super().__init__()
-        self.config = config
-
-        # Dimensions from paper
-        self.entity_dim = config.entity_embedding_dim  # d
-        self.relation_dim = config.relation_embedding_dim  # d
-        self.hidden_dim = config.scoring_hidden_dim  # k
-
-        # Ensure dimensions match paper assumptions
-        assert self.entity_dim == self.relation_dim, \
-            "Paper assumes entity and relation embeddings have same dimension d"
-
-        # Bilinear tensor W_R: d×d×k dimensional tensor
-        # For efficiency, we implement this as k separate d×d matrices
-        self.W_R = nn.Parameter(torch.zeros(self.hidden_dim, self.entity_dim, self.entity_dim))
-
-        # Linear tensor V_R: k×2d dimensional tensor
-        # Maps concatenated [eh; et] to k dimensions
-        self.V_R = nn.Parameter(torch.zeros(self.hidden_dim, 2 * self.entity_dim))
-
-        # Output projection u_R: k-dimensional vector
-        self.u_R = nn.Parameter(torch.zeros(self.hidden_dim))
-
-        # Bias term b_R: k-dimensional vector
-        self.b_R = nn.Parameter(torch.zeros(self.hidden_dim))
-
-        # Activation function f (paper uses ReLU)
-        if config.scoring_activation == "relu":
-            self.activation = F.relu
-        elif config.scoring_activation == "tanh":
-            self.activation = torch.tanh
-        elif config.scoring_activation == "gelu":
-            self.activation = F.gelu
-        else:
-            self.activation = F.relu
-
-        # Initialize parameters
-        self._initialize_parameters()
-
-    def _initialize_parameters(self):
-        """Initialize parameters using Xavier/Glorot initialization"""
-        # Initialize bilinear tensor W_R
-        nn.init.xavier_uniform_(self.W_R.data)
-
-        # Initialize linear tensor V_R
-        nn.init.xavier_uniform_(self.V_R.data)
-
-        # Initialize output projection u_R
-        nn.init.xavier_uniform_(self.u_R.data.unsqueeze(0))
-        self.u_R.data.squeeze_(0)
-
-        # Initialize bias b_R to small values
-        nn.init.constant_(self.b_R.data, 0.1)
-
-    def forward(self, head_embeddings: torch.Tensor,
-                relation_embeddings: torch.Tensor,
-                tail_embeddings: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass implementing Equation 7
-
-        Args:
-            head_embeddings: [batch_size, entity_dim] - eh vectors
-            relation_embeddings: [batch_size, relation_dim] - relation vectors (unused in Eq 7 but kept for completeness)
-            tail_embeddings: [batch_size, entity_dim] - et vectors
-
-        Returns:
-            scores: [batch_size] - g(l, eh, et) scores
-        """
-        batch_size = head_embeddings.size(0)
-
-        # Step 1: Compute bilinear term e^T_h W_R et
-        # W_R has shape [k, d, d], we need to compute bilinear for each k
-        bilinear_terms = []
-
-        for i in range(self.hidden_dim):
-            # For each k-th slice: e^T_h W_R[i] et
-            # head_embeddings: [batch_size, d]
-            # W_R[i]: [d, d]
-            # tail_embeddings: [batch_size, d]
-
-            # Compute head_embeddings @ W_R[i] -> [batch_size, d]
-            temp = torch.matmul(head_embeddings, self.W_R[i])
-
-            # Compute (head_embeddings @ W_R[i]) @ tail_embeddings^T -> [batch_size]
-            bilinear_i = torch.sum(temp * tail_embeddings, dim=1)
-            bilinear_terms.append(bilinear_i)
-
-        # Stack bilinear terms: [batch_size, k]
-        bilinear_output = torch.stack(bilinear_terms, dim=1)
-
-        # Compute linear term V_R [eh; et]
-        # Concatenate head and tail embeddings: [batch_size, 2d]
-        concatenated = torch.cat([head_embeddings, tail_embeddings], dim=1)
-
-        # Linear transformation: V_R @ [eh; et]^T -> [batch_size, k]
-        linear_output = torch.matmul(concatenated, self.V_R.transpose(0, 1))
-
-        # Combine terms inside activation
-        # e^T_h W_R et + V_R [eh; et] + b_R
-        combined = bilinear_output + linear_output + self.b_R.unsqueeze(0)
-
-        # Apply non-linear activation f
-        activated = self.activation(combined)  # [batch_size, k]
-
-        # Final projection u^T_R f(...)
-        scores = torch.matmul(activated, self.u_R)  # [batch_size]
-
-        return scores
-
-    def forward_single(self, head_embedding: torch.Tensor,
-                       relation_embedding: torch.Tensor,
-                       tail_embedding: torch.Tensor) -> torch.Tensor:
-        """Forward pass for single triple"""
-        # Add batch dimension
-        head_batch = head_embedding.unsqueeze(0)
-        rel_batch = relation_embedding.unsqueeze(0)
-        tail_batch = tail_embedding.unsqueeze(0)
-
-        score = self.forward(head_batch, rel_batch, tail_batch)
-        return score.squeeze(0)
-
-
-class NPLLScoringModule(nn.Module):
-
-    def __init__(self, config: NPLLConfig):
-        super().__init__()
-        self.config = config
-
-        # Embedding manager for entities and relations
-        self.embedding_manager = EmbeddingManager(config)
-
-        # Bilinear scoring function (Equation 7)
-        self.scoring_function = BilinearScoringFunction(config)
-
-        # Temperature parameter for calibration (paper mentions temperature scaling)
-        self.temperature = nn.Parameter(torch.tensor(config.temperature))
-
-    def forward(self, triples: List[Triple]) -> torch.Tensor:
-        """
-        Score a batch of triples
-
-        Args:
-            triples: List of Triple objects to score
-
-        Returns:
-            scores: [batch_size] raw scores g(l, eh, et)
-        """
-        if not triples:
-            return torch.tensor([])
-
-        # Extract entity and relation names
-        head_names = [triple.head.name for triple in triples]
-        relation_names = [triple.relation.name for triple in triples]
-        tail_names = [triple.tail.name for triple in triples]
-
-        # Get embeddings
-        head_embs, rel_embs, tail_embs = self.embedding_manager.get_embeddings_for_scoring(
-            head_names, relation_names, tail_names, add_if_missing=False
-        )
-
-        # Compute scores using Equation 7
-        scores = self.scoring_function(head_embs, rel_embs, tail_embs)
-
-        return scores
-
-    def forward_with_names(self, head_names: List[str],
-                           relation_names: List[str],
-                           tail_names: List[str]) -> torch.Tensor:
-        """Score triples given entity/relation names directly"""
-        # Get embeddings
-        head_embs, rel_embs, tail_embs = self.embedding_manager.get_embeddings_for_scoring(
-            head_names, relation_names, tail_names, add_if_missing=False
-        )
-
-        # Compute scores
-        scores = self.scoring_function(head_embs, rel_embs, tail_embs)
-
-        return scores
-
-    def score_single_triple(self, triple: Triple) -> torch.Tensor:
-        """Score a single triple"""
-        head_emb, rel_emb, tail_emb = self.embedding_manager.get_triple_embeddings(triple, add_if_missing=False)
-
-        # Update entity embeddings
-        head_emb = self.embedding_manager.update_entity_embeddings(head_emb.unsqueeze(0)).squeeze(0)
-        tail_emb = self.embedding_manager.update_entity_embeddings(tail_emb.unsqueeze(0)).squeeze(0)
-
-        score = self.scoring_function.forward_single(head_emb, rel_emb, tail_emb)
-        return score
-
-    def get_probabilities(self, triples: List[Triple],
-                          apply_temperature: bool = True) -> torch.Tensor:
-        """
-        Get probability scores p = sigmoid(g(l, eh, et))
-        """
-        scores = self.forward(triples)
-
-        if apply_temperature:
-            # Apply temperature scaling for calibration
-            scores = scores / self.temperature
-
-        # Apply sigmoid transformation
-        probabilities = safe_sigmoid(scores)
-
-        return probabilities
-
-    def compute_fact_scores(self, known_facts: List[Triple],
-                            unknown_facts: List[Triple]) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Compute scores for known and unknown facts
-
-        Used in E-step for computing approximate posterior Q(U)
-        """
-        known_scores = self.forward(known_facts) if known_facts else torch.tensor([])
-        unknown_scores = self.forward(unknown_facts) if unknown_facts else torch.tensor([])
-
-        return known_scores, unknown_scores
-
-    def build_vocabulary(self, kg):
-        """Build vocabulary from knowledge graph"""
-        self.embedding_manager.build_vocabulary_from_kg(kg)
-
-    @property
-    def entity_vocab_size(self) -> int:
-        """Number of entities in vocabulary"""
-        return self.embedding_manager.entity_vocab_size
-
-    @property
-    def relation_vocab_size(self) -> int:
-        """Number of relations in vocabulary"""
-        return self.embedding_manager.relation_vocab_size
-
-
-class BatchedScoringModule(nn.Module):
-    """
-    Optimized scoring module for large-scale batch processing
-    Implements memory-efficient batching for scoring many triples
-    """
-
-    def __init__(self, config: NPLLConfig, base_scoring_module: NPLLScoringModule):
-        super().__init__()
-        self.config = config
-        self.base_module = base_scoring_module
-        self.batch_size = config.batch_size
-
-    def score_large_batch(self, triples: List[Triple]) -> torch.Tensor:
-        """Score large batch of triples with memory-efficient batching"""
-        all_scores = []
-
-        # Process in smaller batches
-        for i in range(0, len(triples), self.batch_size):
-            batch_triples = triples[i:i + self.batch_size]
-            batch_scores = self.base_module.forward(batch_triples)
-            all_scores.append(batch_scores)
-
-        # Concatenate all scores
-        if all_scores:
-            return torch.cat(all_scores, dim=0)
-        else:
-            return torch.tensor([])
-
-    def score_all_possible_triples(self, entities: List[str], relations: List[str]) -> torch.Tensor:
-        """
-        Score all possible (head, relation, tail) combinations
-        Warning: This can be very memory intensive for large vocabularies
-        """
-        all_triples = []
-
-        # Generate all possible combinations
-        for head in entities:
-            for relation in relations:
-                for tail in entities:
-                    if head != tail:  # Avoid self-loops typically
-                        all_triples.append((head, relation, tail))
-
-        logger.warning(f"Scoring {len(all_triples)} possible triples - this may be memory intensive")
-
-        # Score in batches
-        all_scores = []
-        for i in range(0, len(all_triples), self.batch_size):
-            batch_triples = all_triples[i:i + self.batch_size]
-
-            # Extract names
-            head_names = [t[0] for t in batch_triples]
-            rel_names = [t[1] for t in batch_triples]
-            tail_names = [t[2] for t in batch_triples]
-
-            batch_scores = self.base_module.forward_with_names(head_names, rel_names, tail_names)
-            all_scores.append(batch_scores)
-
-        if all_scores:
-            return torch.cat(all_scores, dim=0)
-        else:
-            return torch.tensor([])
-
-
-def create_scoring_module(config: NPLLConfig, kg=None) -> NPLLScoringModule:
-    """
-    Factory function to create and initialize NPLL scoring module
-
-    Args:
-        config: NPLL configuration
-        kg: Optional knowledge graph to build vocabulary from
-
-    Returns:
-        Initialized NPLLScoringModule
-    """
-    scoring_module = NPLLScoringModule(config)
-
-    if kg is not None:
-        scoring_module.build_vocabulary(kg)
-        logger.info(f"Built vocabulary: {scoring_module.entity_vocab_size} entities, "
-                    f"{scoring_module.relation_vocab_size} relations")
-
-    return scoring_module
-
-
-def verify_equation7_implementation():
-    """
-    Verification function to ensure Equation 7 is implemented correctly
-    Tests the mathematical operations step by step
-    """
-    from ..utils.config import default_config
-
-    config = default_config
-    scoring_func = BilinearScoringFunction(config)
-
-    # Create test inputs
-    batch_size = 3
-    d = config.entity_embedding_dim
-    k = config.scoring_hidden_dim
-
-    head_emb = torch.randn(batch_size, d)
-    tail_emb = torch.randn(batch_size, d)
-    rel_emb = torch.randn(batch_size, d)  # Not used in Eq 7 but kept for interface
-
-    # Forward pass
-    scores = scoring_func(head_emb, rel_emb, tail_emb)
-
-    # Verify output shape
-    assert scores.shape == (batch_size,), f"Expected shape ({batch_size},), got {scores.shape}"
-
-    # Verify no NaN values
-    assert not torch.isnan(scores).any(), "Output contains NaN values"
-
-    # Test single input
-    single_score = scoring_func.forward_single(head_emb[0], rel_emb[0], tail_emb[0])
-    assert abs(single_score.item() - scores[0].item()) < 1e-5, "Single vs batch mismatch"
-
-    logger.info("Equation 7 implementation verified successfully")
-
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Dict, Tuple, Optional
+import logging
+
+from ..core import Triple
+from ..utils.config import NPLLConfig
+from ..utils.math_utils import safe_sigmoid
+from .embeddings import EmbeddingManager
+
+logger = logging.getLogger(__name__)
+
+
+class BilinearScoringFunction(nn.Module):
+
+    def __init__(self, config: NPLLConfig):
+        super().__init__()
+        self.config = config
+
+        # Dimensions from paper
+        self.entity_dim = config.entity_embedding_dim  # d
+        self.relation_dim = config.relation_embedding_dim  # d
+        self.hidden_dim = config.scoring_hidden_dim  # k
+
+        # Ensure dimensions match paper assumptions
+        assert self.entity_dim == self.relation_dim, \
+            "Paper assumes entity and relation embeddings have same dimension d"
+
+        # Bilinear tensor W_R: d×d×k dimensional tensor
+        # For efficiency, we implement this as k separate d×d matrices
+        self.W_R = nn.Parameter(torch.zeros(self.hidden_dim, self.entity_dim, self.entity_dim))
+
+        # Linear tensor V_R: k×2d dimensional tensor
+        # Maps concatenated [eh; et] to k dimensions
+        self.V_R = nn.Parameter(torch.zeros(self.hidden_dim, 2 * self.entity_dim))
+
+        # Output projection u_R: k-dimensional vector
+        self.u_R = nn.Parameter(torch.zeros(self.hidden_dim))
+
+        # Bias term b_R: k-dimensional vector
+        self.b_R = nn.Parameter(torch.zeros(self.hidden_dim))
+
+        # Activation function f (paper uses ReLU)
+        if config.scoring_activation == "relu":
+            self.activation = F.relu
+        elif config.scoring_activation == "tanh":
+            self.activation = torch.tanh
+        elif config.scoring_activation == "gelu":
+            self.activation = F.gelu
+        else:
+            self.activation = F.relu
+
+        # Initialize parameters
+        self._initialize_parameters()
+
+    def _initialize_parameters(self):
+        """Initialize parameters using Xavier/Glorot initialization"""
+        # Initialize bilinear tensor W_R
+        nn.init.xavier_uniform_(self.W_R.data)
+
+        # Initialize linear tensor V_R
+        nn.init.xavier_uniform_(self.V_R.data)
+
+        # Initialize output projection u_R
+        nn.init.xavier_uniform_(self.u_R.data.unsqueeze(0))
+        self.u_R.data.squeeze_(0)
+
+        # Initialize bias b_R to small values
+        nn.init.constant_(self.b_R.data, 0.1)
+
+    def forward(self, head_embeddings: torch.Tensor,
+                relation_embeddings: torch.Tensor,
+                tail_embeddings: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass implementing Equation 7
+
+        Args:
+            head_embeddings: [batch_size, entity_dim] - eh vectors
+            relation_embeddings: [batch_size, relation_dim] - relation vectors (unused in Eq 7 but kept for completeness)
+            tail_embeddings: [batch_size, entity_dim] - et vectors
+
+        Returns:
+            scores: [batch_size] - g(l, eh, et) scores
+        """
+        batch_size = head_embeddings.size(0)
+
+        # Step 1: Compute bilinear term e^T_h W_R et
+        # W_R has shape [k, d, d], we need to compute bilinear for each k
+        bilinear_terms = []
+
+        for i in range(self.hidden_dim):
+            # For each k-th slice: e^T_h W_R[i] et
+            # head_embeddings: [batch_size, d]
+            # W_R[i]: [d, d]
+            # tail_embeddings: [batch_size, d]
+
+            # Compute head_embeddings @ W_R[i] -> [batch_size, d]
+            temp = torch.matmul(head_embeddings, self.W_R[i])
+
+            # Compute (head_embeddings @ W_R[i]) @ tail_embeddings^T -> [batch_size]
+            bilinear_i = torch.sum(temp * tail_embeddings, dim=1)
+            bilinear_terms.append(bilinear_i)
+
+        # Stack bilinear terms: [batch_size, k]
+        bilinear_output = torch.stack(bilinear_terms, dim=1)
+
+        # Compute linear term V_R [eh; et]
+        # Concatenate head and tail embeddings: [batch_size, 2d]
+        concatenated = torch.cat([head_embeddings, tail_embeddings], dim=1)
+
+        # Linear transformation: V_R @ [eh; et]^T -> [batch_size, k]
+        linear_output = torch.matmul(concatenated, self.V_R.transpose(0, 1))
+
+        # Combine terms inside activation
+        # e^T_h W_R et + V_R [eh; et] + b_R
+        combined = bilinear_output + linear_output + self.b_R.unsqueeze(0)
+
+        # Apply non-linear activation f
+        activated = self.activation(combined)  # [batch_size, k]
+
+        # Final projection u^T_R f(...)
+        scores = torch.matmul(activated, self.u_R)  # [batch_size]
+
+        return scores
+
+    def forward_single(self, head_embedding: torch.Tensor,
+                       relation_embedding: torch.Tensor,
+                       tail_embedding: torch.Tensor) -> torch.Tensor:
+        """Forward pass for single triple"""
+        # Add batch dimension
+        head_batch = head_embedding.unsqueeze(0)
+        rel_batch = relation_embedding.unsqueeze(0)
+        tail_batch = tail_embedding.unsqueeze(0)
+
+        score = self.forward(head_batch, rel_batch, tail_batch)
+        return score.squeeze(0)
+
+
+class NPLLScoringModule(nn.Module):
+
+    def __init__(self, config: NPLLConfig):
+        super().__init__()
+        self.config = config
+
+        # Embedding manager for entities and relations
+        self.embedding_manager = EmbeddingManager(config)
+
+        # Bilinear scoring function (Equation 7)
+        self.scoring_function = BilinearScoringFunction(config)
+
+        # Temperature parameter for calibration (paper mentions temperature scaling)
+        self.temperature = nn.Parameter(torch.tensor(config.temperature))
+
+    def forward(self, triples: List[Triple]) -> torch.Tensor:
+        """
+        Score a batch of triples
+
+        Args:
+            triples: List of Triple objects to score
+
+        Returns:
+            scores: [batch_size] raw scores g(l, eh, et)
+        """
+        if not triples:
+            return torch.tensor([])
+
+        # Extract entity and relation names
+        head_names = [triple.head.name for triple in triples]
+        relation_names = [triple.relation.name for triple in triples]
+        tail_names = [triple.tail.name for triple in triples]
+
+        # Get embeddings
+        head_embs, rel_embs, tail_embs = self.embedding_manager.get_embeddings_for_scoring(
+            head_names, relation_names, tail_names, add_if_missing=False
+        )
+
+        # Compute scores using Equation 7
+        scores = self.scoring_function(head_embs, rel_embs, tail_embs)
+
+        return scores
+
+    def forward_with_names(self, head_names: List[str],
+                           relation_names: List[str],
+                           tail_names: List[str]) -> torch.Tensor:
+        """Score triples given entity/relation names directly"""
+        # Get embeddings
+        head_embs, rel_embs, tail_embs = self.embedding_manager.get_embeddings_for_scoring(
+            head_names, relation_names, tail_names, add_if_missing=False
+        )
+
+        # Compute scores
+        scores = self.scoring_function(head_embs, rel_embs, tail_embs)
+
+        return scores
+
+    def score_single_triple(self, triple: Triple) -> torch.Tensor:
+        """Score a single triple"""
+        head_emb, rel_emb, tail_emb = self.embedding_manager.get_triple_embeddings(triple, add_if_missing=False)
+
+        # Update entity embeddings
+        head_emb = self.embedding_manager.update_entity_embeddings(head_emb.unsqueeze(0)).squeeze(0)
+        tail_emb = self.embedding_manager.update_entity_embeddings(tail_emb.unsqueeze(0)).squeeze(0)
+
+        score = self.scoring_function.forward_single(head_emb, rel_emb, tail_emb)
+        return score
+
+    def get_probabilities(self, triples: List[Triple],
+                          apply_temperature: bool = True) -> torch.Tensor:
+        """
+        Get probability scores p = sigmoid(g(l, eh, et))
+        """
+        scores = self.forward(triples)
+
+        if apply_temperature:
+            # Apply temperature scaling for calibration
+            scores = scores / self.temperature
+
+        # Apply sigmoid transformation
+        probabilities = safe_sigmoid(scores)
+
+        return probabilities
+
+    def compute_fact_scores(self, known_facts: List[Triple],
+                            unknown_facts: List[Triple]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute scores for known and unknown facts
+
+        Used in E-step for computing approximate posterior Q(U)
+        """
+        known_scores = self.forward(known_facts) if known_facts else torch.tensor([])
+        unknown_scores = self.forward(unknown_facts) if unknown_facts else torch.tensor([])
+
+        return known_scores, unknown_scores
+
+    def build_vocabulary(self, kg):
+        """Build vocabulary from knowledge graph"""
+        self.embedding_manager.build_vocabulary_from_kg(kg)
+
+    @property
+    def entity_vocab_size(self) -> int:
+        """Number of entities in vocabulary"""
+        return self.embedding_manager.entity_vocab_size
+
+    @property
+    def relation_vocab_size(self) -> int:
+        """Number of relations in vocabulary"""
+        return self.embedding_manager.relation_vocab_size
+
+
+class BatchedScoringModule(nn.Module):
+    """
+    Optimized scoring module for large-scale batch processing
+    Implements memory-efficient batching for scoring many triples
+    """
+
+    def __init__(self, config: NPLLConfig, base_scoring_module: NPLLScoringModule):
+        super().__init__()
+        self.config = config
+        self.base_module = base_scoring_module
+        self.batch_size = config.batch_size
+
+    def score_large_batch(self, triples: List[Triple]) -> torch.Tensor:
+        """Score large batch of triples with memory-efficient batching"""
+        all_scores = []
+
+        # Process in smaller batches
+        for i in range(0, len(triples), self.batch_size):
+            batch_triples = triples[i:i + self.batch_size]
+            batch_scores = self.base_module.forward(batch_triples)
+            all_scores.append(batch_scores)
+
+        # Concatenate all scores
+        if all_scores:
+            return torch.cat(all_scores, dim=0)
+        else:
+            return torch.tensor([])
+
+    def score_all_possible_triples(self, entities: List[str], relations: List[str]) -> torch.Tensor:
+        """
+        Score all possible (head, relation, tail) combinations
+        Warning: This can be very memory intensive for large vocabularies
+        """
+        all_triples = []
+
+        # Generate all possible combinations
+        for head in entities:
+            for relation in relations:
+                for tail in entities:
+                    if head != tail:  # Avoid self-loops typically
+                        all_triples.append((head, relation, tail))
+
+        logger.warning(f"Scoring {len(all_triples)} possible triples - this may be memory intensive")
+
+        # Score in batches
+        all_scores = []
+        for i in range(0, len(all_triples), self.batch_size):
+            batch_triples = all_triples[i:i + self.batch_size]
+
+            # Extract names
+            head_names = [t[0] for t in batch_triples]
+            rel_names = [t[1] for t in batch_triples]
+            tail_names = [t[2] for t in batch_triples]
+
+            batch_scores = self.base_module.forward_with_names(head_names, rel_names, tail_names)
+            all_scores.append(batch_scores)
+
+        if all_scores:
+            return torch.cat(all_scores, dim=0)
+        else:
+            return torch.tensor([])
+
+
+def create_scoring_module(config: NPLLConfig, kg=None) -> NPLLScoringModule:
+    """
+    Factory function to create and initialize NPLL scoring module
+
+    Args:
+        config: NPLL configuration
+        kg: Optional knowledge graph to build vocabulary from
+
+    Returns:
+        Initialized NPLLScoringModule
+    """
+    scoring_module = NPLLScoringModule(config)
+
+    if kg is not None:
+        scoring_module.build_vocabulary(kg)
+        logger.info(f"Built vocabulary: {scoring_module.entity_vocab_size} entities, "
+                    f"{scoring_module.relation_vocab_size} relations")
+
+    return scoring_module
+
+
+def verify_equation7_implementation():
+    """
+    Verification function to ensure Equation 7 is implemented correctly
+    Tests the mathematical operations step by step
+    """
+    from ..utils.config import default_config
+
+    config = default_config
+    scoring_func = BilinearScoringFunction(config)
+
+    # Create test inputs
+    batch_size = 3
+    d = config.entity_embedding_dim
+    k = config.scoring_hidden_dim
+
+    head_emb = torch.randn(batch_size, d)
+    tail_emb = torch.randn(batch_size, d)
+    rel_emb = torch.randn(batch_size, d)  # Not used in Eq 7 but kept for interface
+
+    # Forward pass
+    scores = scoring_func(head_emb, rel_emb, tail_emb)
+
+    # Verify output shape
+    assert scores.shape == (batch_size,), f"Expected shape ({batch_size},), got {scores.shape}"
+
+    # Verify no NaN values
+    assert not torch.isnan(scores).any(), "Output contains NaN values"
+
+    # Test single input
+    single_score = scoring_func.forward_single(head_emb[0], rel_emb[0], tail_emb[0])
+    assert abs(single_score.item() - scores[0].item()) < 1e-5, "Single vs batch mismatch"
+
+    logger.info("Equation 7 implementation verified successfully")
+
     return True
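
Note on the hunk above: BilinearScoringFunction computes the Equation 7 score g(l, eh, et) = u_R^T f(e_h^T W_R e_t + V_R [eh; et] + b_R) with an explicit Python loop over the k slices of W_R. As a reference point only (this sketch is not part of either published version; tensor names and shapes are taken from the comments in the hunk), the same computation can be written in vectorized form with torch.einsum:

import torch

def equation7_scores(head, tail, W_R, V_R, b_R, u_R, activation=torch.relu):
    """Vectorized Equation 7: u_R^T f(e_h^T W_R e_t + V_R [eh; et] + b_R)."""
    # Bilinear term for all k slices at once:
    # head [B, d] x W_R [k, d, d] x tail [B, d] -> [B, k]
    bilinear = torch.einsum("bi,kij,bj->bk", head, W_R, tail)
    # Linear term: [eh; et] is [B, 2d]; @ V_R^T [2d, k] -> [B, k]
    linear = torch.cat([head, tail], dim=1) @ V_R.T
    # Activation, then output projection: [B, k] @ u_R [k] -> [B]
    return activation(bilinear + linear + b_R) @ u_R

# Shape check mirroring verify_equation7_implementation()
B, d, k = 3, 8, 4
scores = equation7_scores(torch.randn(B, d), torch.randn(B, d),
                          torch.randn(k, d, d), torch.randn(k, 2 * d),
                          torch.randn(k), torch.randn(k))
assert scores.shape == (B,)

Up to floating-point reordering this should agree with the loop-based forward in the hunk; the loop form may trade speed for lower peak memory, since it never materializes a [batch, k, d] intermediate.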