odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/scoring/scoring_module.py
CHANGED
@@ -1,370 +1,370 @@

(The hunk rewrites the whole file: lines 1-369 are removed and re-added with text that is identical as rendered here; only line 370, return True, is marked unchanged. The file content is shown once below.)

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional
import logging

from ..core import Triple
from ..utils.config import NPLLConfig
from ..utils.math_utils import safe_sigmoid
from .embeddings import EmbeddingManager

logger = logging.getLogger(__name__)


class BilinearScoringFunction(nn.Module):

    def __init__(self, config: NPLLConfig):
        super().__init__()
        self.config = config

        # Dimensions from paper
        self.entity_dim = config.entity_embedding_dim  # d
        self.relation_dim = config.relation_embedding_dim  # d
        self.hidden_dim = config.scoring_hidden_dim  # k

        # Ensure dimensions match paper assumptions
        assert self.entity_dim == self.relation_dim, \
            "Paper assumes entity and relation embeddings have same dimension d"

        # Bilinear tensor W_R: d×d×k dimensional tensor
        # For efficiency, we implement this as k separate d×d matrices
        self.W_R = nn.Parameter(torch.zeros(self.hidden_dim, self.entity_dim, self.entity_dim))

        # Linear tensor V_R: k×2d dimensional tensor
        # Maps concatenated [eh; et] to k dimensions
        self.V_R = nn.Parameter(torch.zeros(self.hidden_dim, 2 * self.entity_dim))

        # Output projection u_R: k-dimensional vector
        self.u_R = nn.Parameter(torch.zeros(self.hidden_dim))

        # Bias term b_R: k-dimensional vector
        self.b_R = nn.Parameter(torch.zeros(self.hidden_dim))

        # Activation function f (paper uses ReLU)
        if config.scoring_activation == "relu":
            self.activation = F.relu
        elif config.scoring_activation == "tanh":
            self.activation = torch.tanh
        elif config.scoring_activation == "gelu":
            self.activation = F.gelu
        else:
            self.activation = F.relu

        # Initialize parameters
        self._initialize_parameters()

    def _initialize_parameters(self):
        """Initialize parameters using Xavier/Glorot initialization"""
        # Initialize bilinear tensor W_R
        nn.init.xavier_uniform_(self.W_R.data)

        # Initialize linear tensor V_R
        nn.init.xavier_uniform_(self.V_R.data)

        # Initialize output projection u_R
        nn.init.xavier_uniform_(self.u_R.data.unsqueeze(0))
        self.u_R.data.squeeze_(0)

        # Initialize bias b_R to small values
        nn.init.constant_(self.b_R.data, 0.1)

    def forward(self, head_embeddings: torch.Tensor,
                relation_embeddings: torch.Tensor,
                tail_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Forward pass implementing Equation 7

        Args:
            head_embeddings: [batch_size, entity_dim] - eh vectors
            relation_embeddings: [batch_size, relation_dim] - relation vectors (unused in Eq 7 but kept for completeness)
            tail_embeddings: [batch_size, entity_dim] - et vectors

        Returns:
            scores: [batch_size] - g(l, eh, et) scores
        """
        batch_size = head_embeddings.size(0)

        # Step 1: Compute bilinear term e^T_h W_R et
        # W_R has shape [k, d, d], we need to compute bilinear for each k
        bilinear_terms = []

        for i in range(self.hidden_dim):
            # For each k-th slice: e^T_h W_R[i] et
            # head_embeddings: [batch_size, d]
            # W_R[i]: [d, d]
            # tail_embeddings: [batch_size, d]

            # Compute head_embeddings @ W_R[i] -> [batch_size, d]
            temp = torch.matmul(head_embeddings, self.W_R[i])

            # Compute (head_embeddings @ W_R[i]) @ tail_embeddings^T -> [batch_size]
            bilinear_i = torch.sum(temp * tail_embeddings, dim=1)
            bilinear_terms.append(bilinear_i)

        # Stack bilinear terms: [batch_size, k]
        bilinear_output = torch.stack(bilinear_terms, dim=1)

        # Compute linear term V_R [eh; et]
        # Concatenate head and tail embeddings: [batch_size, 2d]
        concatenated = torch.cat([head_embeddings, tail_embeddings], dim=1)

        # Linear transformation: V_R @ [eh; et]^T -> [batch_size, k]
        linear_output = torch.matmul(concatenated, self.V_R.transpose(0, 1))

        # Combine terms inside activation
        # e^T_h W_R et + V_R [eh; et] + b_R
        combined = bilinear_output + linear_output + self.b_R.unsqueeze(0)

        # Apply non-linear activation f
        activated = self.activation(combined)  # [batch_size, k]

        # Final projection u^T_R f(...)
        scores = torch.matmul(activated, self.u_R)  # [batch_size]

        return scores

    def forward_single(self, head_embedding: torch.Tensor,
                       relation_embedding: torch.Tensor,
                       tail_embedding: torch.Tensor) -> torch.Tensor:
        """Forward pass for single triple"""
        # Add batch dimension
        head_batch = head_embedding.unsqueeze(0)
        rel_batch = relation_embedding.unsqueeze(0)
        tail_batch = tail_embedding.unsqueeze(0)

        score = self.forward(head_batch, rel_batch, tail_batch)
        return score.squeeze(0)


class NPLLScoringModule(nn.Module):

    def __init__(self, config: NPLLConfig):
        super().__init__()
        self.config = config

        # Embedding manager for entities and relations
        self.embedding_manager = EmbeddingManager(config)

        # Bilinear scoring function (Equation 7)
        self.scoring_function = BilinearScoringFunction(config)

        # Temperature parameter for calibration (paper mentions temperature scaling)
        self.temperature = nn.Parameter(torch.tensor(config.temperature))

    def forward(self, triples: List[Triple]) -> torch.Tensor:
        """
        Score a batch of triples

        Args:
            triples: List of Triple objects to score

        Returns:
            scores: [batch_size] raw scores g(l, eh, et)
        """
        if not triples:
            return torch.tensor([])

        # Extract entity and relation names
        head_names = [triple.head.name for triple in triples]
        relation_names = [triple.relation.name for triple in triples]
        tail_names = [triple.tail.name for triple in triples]

        # Get embeddings
        head_embs, rel_embs, tail_embs = self.embedding_manager.get_embeddings_for_scoring(
            head_names, relation_names, tail_names, add_if_missing=False
        )

        # Compute scores using Equation 7
        scores = self.scoring_function(head_embs, rel_embs, tail_embs)

        return scores

    def forward_with_names(self, head_names: List[str],
                           relation_names: List[str],
                           tail_names: List[str]) -> torch.Tensor:
        """Score triples given entity/relation names directly"""
        # Get embeddings
        head_embs, rel_embs, tail_embs = self.embedding_manager.get_embeddings_for_scoring(
            head_names, relation_names, tail_names, add_if_missing=False
        )

        # Compute scores
        scores = self.scoring_function(head_embs, rel_embs, tail_embs)

        return scores

    def score_single_triple(self, triple: Triple) -> torch.Tensor:
        """Score a single triple"""
        head_emb, rel_emb, tail_emb = self.embedding_manager.get_triple_embeddings(triple, add_if_missing=False)

        # Update entity embeddings
        head_emb = self.embedding_manager.update_entity_embeddings(head_emb.unsqueeze(0)).squeeze(0)
        tail_emb = self.embedding_manager.update_entity_embeddings(tail_emb.unsqueeze(0)).squeeze(0)

        score = self.scoring_function.forward_single(head_emb, rel_emb, tail_emb)
        return score

    def get_probabilities(self, triples: List[Triple],
                          apply_temperature: bool = True) -> torch.Tensor:
        """
        Get probability scores p = sigmoid(g(l, eh, et))
        """
        scores = self.forward(triples)

        if apply_temperature:
            # Apply temperature scaling for calibration
            scores = scores / self.temperature

        # Apply sigmoid transformation
        probabilities = safe_sigmoid(scores)

        return probabilities

    def compute_fact_scores(self, known_facts: List[Triple],
                            unknown_facts: List[Triple]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute scores for known and unknown facts

        Used in E-step for computing approximate posterior Q(U)
        """
        known_scores = self.forward(known_facts) if known_facts else torch.tensor([])
        unknown_scores = self.forward(unknown_facts) if unknown_facts else torch.tensor([])

        return known_scores, unknown_scores

    def build_vocabulary(self, kg):
        """Build vocabulary from knowledge graph"""
        self.embedding_manager.build_vocabulary_from_kg(kg)

    @property
    def entity_vocab_size(self) -> int:
        """Number of entities in vocabulary"""
        return self.embedding_manager.entity_vocab_size

    @property
    def relation_vocab_size(self) -> int:
        """Number of relations in vocabulary"""
        return self.embedding_manager.relation_vocab_size


class BatchedScoringModule(nn.Module):
    """
    Optimized scoring module for large-scale batch processing
    Implements memory-efficient batching for scoring many triples
    """

    def __init__(self, config: NPLLConfig, base_scoring_module: NPLLScoringModule):
        super().__init__()
        self.config = config
        self.base_module = base_scoring_module
        self.batch_size = config.batch_size

    def score_large_batch(self, triples: List[Triple]) -> torch.Tensor:
        """Score large batch of triples with memory-efficient batching"""
        all_scores = []

        # Process in smaller batches
        for i in range(0, len(triples), self.batch_size):
            batch_triples = triples[i:i + self.batch_size]
            batch_scores = self.base_module.forward(batch_triples)
            all_scores.append(batch_scores)

        # Concatenate all scores
        if all_scores:
            return torch.cat(all_scores, dim=0)
        else:
            return torch.tensor([])

    def score_all_possible_triples(self, entities: List[str], relations: List[str]) -> torch.Tensor:
        """
        Score all possible (head, relation, tail) combinations
        Warning: This can be very memory intensive for large vocabularies
        """
        all_triples = []

        # Generate all possible combinations
        for head in entities:
            for relation in relations:
                for tail in entities:
                    if head != tail:  # Avoid self-loops typically
                        all_triples.append((head, relation, tail))

        logger.warning(f"Scoring {len(all_triples)} possible triples - this may be memory intensive")

        # Score in batches
        all_scores = []
        for i in range(0, len(all_triples), self.batch_size):
            batch_triples = all_triples[i:i + self.batch_size]

            # Extract names
            head_names = [t[0] for t in batch_triples]
            rel_names = [t[1] for t in batch_triples]
            tail_names = [t[2] for t in batch_triples]

            batch_scores = self.base_module.forward_with_names(head_names, rel_names, tail_names)
            all_scores.append(batch_scores)

        if all_scores:
            return torch.cat(all_scores, dim=0)
        else:
            return torch.tensor([])


def create_scoring_module(config: NPLLConfig, kg=None) -> NPLLScoringModule:
    """
    Factory function to create and initialize NPLL scoring module

    Args:
        config: NPLL configuration
        kg: Optional knowledge graph to build vocabulary from

    Returns:
        Initialized NPLLScoringModule
    """
    scoring_module = NPLLScoringModule(config)

    if kg is not None:
        scoring_module.build_vocabulary(kg)
        logger.info(f"Built vocabulary: {scoring_module.entity_vocab_size} entities, "
                    f"{scoring_module.relation_vocab_size} relations")

    return scoring_module


def verify_equation7_implementation():
    """
    Verification function to ensure Equation 7 is implemented correctly
    Tests the mathematical operations step by step
    """
    from ..utils.config import default_config

    config = default_config
    scoring_func = BilinearScoringFunction(config)

    # Create test inputs
    batch_size = 3
    d = config.entity_embedding_dim
    k = config.scoring_hidden_dim

    head_emb = torch.randn(batch_size, d)
    tail_emb = torch.randn(batch_size, d)
    rel_emb = torch.randn(batch_size, d)  # Not used in Eq 7 but kept for interface

    # Forward pass
    scores = scoring_func(head_emb, rel_emb, tail_emb)

    # Verify output shape
    assert scores.shape == (batch_size,), f"Expected shape ({batch_size},), got {scores.shape}"

    # Verify no NaN values
    assert not torch.isnan(scores).any(), "Output contains NaN values"

    # Test single input
    single_score = scoring_func.forward_single(head_emb[0], rel_emb[0], tail_emb[0])
    assert abs(single_score.item() - scores[0].item()) < 1e-5, "Single vs batch mismatch"

    logger.info("Equation 7 implementation verified successfully")

    return True
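
A note on the forward pass above: the Python loop over hidden_dim computes the slice-wise bilinear term e^T_h W_R[i] et one slice at a time, which is equivalent to a single tensor contraction. The sketch below is a standalone equivalence check with random tensors, not code from the package (bilinear_scores_einsum is a hypothetical helper):

import torch

def bilinear_scores_einsum(head, tail, W_R, V_R, u_R, b_R, activation=torch.relu):
    # Slice-wise bilinear term for all k slices at once:
    # [batch, d] x [k, d, d] x [batch, d] -> [batch, k]
    bilinear = torch.einsum('bi,kij,bj->bk', head, W_R, tail)
    # Linear term V_R [eh; et]: [batch, 2d] @ [2d, k] -> [batch, k]
    linear = torch.cat([head, tail], dim=1) @ V_R.t()
    # Output projection u_R^T f(...): [batch, k] @ [k] -> [batch]
    return activation(bilinear + linear + b_R) @ u_R

b, d, k = 3, 8, 4
head, tail = torch.randn(b, d), torch.randn(b, d)
W_R, V_R = torch.randn(k, d, d), torch.randn(k, 2 * d)
u_R, b_R = torch.randn(k), torch.randn(k)

# The loop from BilinearScoringFunction.forward, reduced to its core:
loop = torch.stack([(head @ W_R[i] * tail).sum(dim=1) for i in range(k)], dim=1)
assert torch.allclose(loop, torch.einsum('bi,kij,bj->bk', head, W_R, tail), atol=1e-5)

print(bilinear_scores_einsum(head, tail, W_R, V_R, u_R, b_R).shape)  # torch.Size([3])

The einsum form avoids k separate matmul calls and the intermediate Python list, which matters once hidden_dim grows.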
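
BilinearScoringFunction reads only four fields from its config (entity_embedding_dim, relation_embedding_dim, scoring_hidden_dim, scoring_activation), so a quick smoke test outside the training loop can use a stand-in object. This is a minimal sketch under that assumption; real use goes through NPLLConfig from npll/utils/config.py:

from types import SimpleNamespace

import torch

from npll.scoring.scoring_module import BilinearScoringFunction

# Stand-in carrying only the attributes the module reads; not the real NPLLConfig.
cfg = SimpleNamespace(
    entity_embedding_dim=16,
    relation_embedding_dim=16,  # must equal entity_embedding_dim (asserted in __init__)
    scoring_hidden_dim=8,
    scoring_activation="relu",
)

scoring = BilinearScoringFunction(cfg)
h = torch.randn(5, cfg.entity_embedding_dim)
r = torch.randn(5, cfg.relation_embedding_dim)  # unused by Equation 7, required by the interface
t = torch.randn(5, cfg.entity_embedding_dim)

print(scoring(h, r, t).shape)  # torch.Size([5])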
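
On the memory warning in score_all_possible_triples: because self-loops are skipped, the candidate set grows as |E| * (|E| - 1) * |R|, so even a modest vocabulary makes exhaustive enumeration expensive. A quick count for hypothetical vocabulary sizes:

# Candidate count for a hypothetical vocabulary of 1,000 entities and 50 relations:
num_entities, num_relations = 1_000, 50
candidates = num_entities * (num_entities - 1) * num_relations  # head != tail
print(f"{candidates:,}")  # 49,950,000 triples, before any batching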