odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/scoring/embeddings.py
CHANGED
@@ -1,442 +1,442 @@ (the diff marks every line of the file as removed and re-added; the content on both sides is identical, so the file is listed once below)

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Set, Tuple
import logging
from collections import OrderedDict

from ..core import KnowledgeGraph, Entity, Relation, Triple
from ..utils.config import NPLLConfig

logger = logging.getLogger(__name__)


class EntityEmbedding(nn.Module):
    """
    Entity embedding layer with dynamic vocabulary expansion
    """

    def __init__(self, config: NPLLConfig, initial_vocab_size: int = 10000):
        super().__init__()
        self.config = config
        self.embedding_dim = config.entity_embedding_dim

        # Entity vocabulary mapping: entity_name -> index
        self.entity_to_idx: Dict[str, int] = {}
        self.idx_to_entity: Dict[int, str] = {}

        # Reserve indices: 0 padding, 1 OOV; new entries from 2+
        self.padding_idx = 0
        self.oov_idx = 1
        # Embedding layer - will expand dynamically
        self.embedding = nn.Embedding(
            num_embeddings=initial_vocab_size,
            embedding_dim=self.embedding_dim,
            padding_idx=self.padding_idx
        )

        # Initialize embeddings using Xavier initialization
        nn.init.xavier_uniform_(self.embedding.weight.data)
        # Set padding embedding to zero; init OOV to a valid vector
        with torch.no_grad():
            self.embedding.weight[self.padding_idx].zero_()
        # Track next available index (0 pad, 1 oov reserved)
        self._next_idx = 2

    def add_entity(self, entity_name: str) -> int:
        """Add entity to vocabulary and return its index"""
        if entity_name not in self.entity_to_idx:
            # Check if we need to expand embedding layer
            if self._next_idx >= self.embedding.num_embeddings:
                self._expand_embeddings()

            # Add entity to vocabulary
            idx = self._next_idx
            self.entity_to_idx[entity_name] = idx
            self.idx_to_entity[idx] = entity_name
            self._next_idx += 1

            logger.debug(f"Added entity '{entity_name}' with index {idx}")

        return self.entity_to_idx[entity_name]

    def _expand_embeddings(self, grow_by: int = 1000, on_expand=None):
        """Expand embedding layer when vocabulary grows (preserve device/dtype)."""
        old = self.embedding
        old_size = old.num_embeddings
        new_size = max(old_size * 2, self._next_idx + grow_by)
        new = nn.Embedding(
            num_embeddings=new_size,
            embedding_dim=self.embedding_dim,
            padding_idx=self.padding_idx,
            dtype=old.weight.dtype,
            device=old.weight.device,
        )
        with torch.no_grad():
            new.weight[:old_size].copy_(old.weight)
            nn.init.xavier_uniform_(new.weight[old_size:])
            new.weight[self.padding_idx].zero_()
        self.embedding = new
        if on_expand is not None:
            on_expand(self.embedding)
        logger.info(f"Expanded entity embeddings from {old_size} to {new_size}")

    def get_entity_index(self, entity_name: str, add_if_missing: bool = False) -> int:
        """Get index for entity; returns OOV if missing and add_if_missing=False."""
        idx = self.entity_to_idx.get(entity_name)
        if idx is None:
            return self.add_entity(entity_name) if add_if_missing else self.oov_idx
        return idx

    def get_entity_name(self, idx: int) -> Optional[str]:
        """Get entity name from index"""
        return self.idx_to_entity.get(idx)

    def get_embedding(self, entity_name: str, add_if_missing: bool = False) -> torch.Tensor:
        """Get embedding vector for entity (device-safe)."""
        idx = self.get_entity_index(entity_name, add_if_missing=add_if_missing)
        device = self.embedding.weight.device
        return self.embedding(torch.tensor([idx], dtype=torch.long, device=device)).squeeze(0)

    def get_embeddings_batch(self, entity_names: List[str], add_if_missing: bool = False) -> torch.Tensor:
        """Get embedding vectors for batch of entities (device-safe)."""
        indices = [self.get_entity_index(name, add_if_missing=add_if_missing) for name in entity_names]
        device = self.embedding.weight.device
        indices_tensor = torch.tensor(indices, dtype=torch.long, device=device)
        return self.embedding(indices_tensor)

    def forward(self, entity_indices: torch.Tensor) -> torch.Tensor:
        """Forward pass for entity embeddings"""
        return self.embedding(entity_indices)

    @property
    def vocab_size(self) -> int:
        """Current vocabulary size"""
        return len(self.entity_to_idx)

    def state_dict_with_vocab(self) -> Dict:
        """Get state dict including vocabulary mappings"""
        state = super().state_dict()
        state['entity_to_idx'] = self.entity_to_idx.copy()
        state['idx_to_entity'] = self.idx_to_entity.copy()
        state['_next_idx'] = self._next_idx
        state['padding_idx'] = self.padding_idx
        state['oov_idx'] = self.oov_idx
        return state

    def load_state_dict_with_vocab(self, state_dict: Dict):
        """Load state dict including vocabulary mappings"""
        # Load vocabulary mappings first
        self.entity_to_idx = state_dict.pop('entity_to_idx', {})
        self.idx_to_entity = state_dict.pop('idx_to_entity', {})
        self._next_idx = state_dict.pop('_next_idx', 2)
        self.padding_idx = state_dict.pop('padding_idx', 0)
        self.oov_idx = state_dict.pop('oov_idx', 1)
        needed = max(self._next_idx, self.embedding.num_embeddings)
        if needed > self.embedding.num_embeddings:
            old = self.embedding
            new = nn.Embedding(
                num_embeddings=needed,
                embedding_dim=self.embedding_dim,
                padding_idx=self.padding_idx,
                dtype=old.weight.dtype,
                device=old.weight.device,
            )
            with torch.no_grad():
                new.weight[:old.num_embeddings].copy_(old.weight)
                nn.init.xavier_uniform_(new.weight[old.num_embeddings:])
                new.weight[self.padding_idx].zero_()
            self.embedding = new
        # Load model parameters leniently in case of size diffs
        super().load_state_dict(state_dict, strict=False)


class RelationEmbedding(nn.Module):
    """
    Relation embedding layer with dynamic vocabulary expansion
    """

    def __init__(self, config: NPLLConfig, initial_vocab_size: int = 1000):
        super().__init__()
        self.config = config
        self.embedding_dim = config.relation_embedding_dim

        # Relation vocabulary mapping
        self.relation_to_idx: Dict[str, int] = {}
        self.idx_to_relation: Dict[int, str] = {}

        # Embedding layer
        self.padding_idx = 0
        self.oov_idx = 1
        self.embedding = nn.Embedding(
            num_embeddings=initial_vocab_size,
            embedding_dim=self.embedding_dim,
            padding_idx=self.padding_idx
        )

        # Initialize embeddings
        nn.init.xavier_uniform_(self.embedding.weight.data)
        with torch.no_grad():
            self.embedding.weight[self.padding_idx].zero_()

        self._next_idx = 2

    def add_relation(self, relation_name: str) -> int:
        """Add relation to vocabulary and return its index"""
        if relation_name not in self.relation_to_idx:
            if self._next_idx >= self.embedding.num_embeddings:
                self._expand_embeddings()

            idx = self._next_idx
            self.relation_to_idx[relation_name] = idx
            self.idx_to_relation[idx] = relation_name
            self._next_idx += 1

            logger.debug(f"Added relation '{relation_name}' with index {idx}")

        return self.relation_to_idx[relation_name]

    def _expand_embeddings(self, grow_by: int = 100, on_expand=None):
        """Expand embedding layer when vocabulary grows (preserve device/dtype)."""
        old = self.embedding
        old_size = old.num_embeddings
        new_size = max(old_size * 2, self._next_idx + grow_by)
        new = nn.Embedding(
            num_embeddings=new_size,
            embedding_dim=self.embedding_dim,
            padding_idx=self.padding_idx,
            dtype=old.weight.dtype,
            device=old.weight.device,
        )
        with torch.no_grad():
            new.weight[:old_size].copy_(old.weight)
            nn.init.xavier_uniform_(new.weight[old_size:])
            new.weight[self.padding_idx].zero_()
        self.embedding = new
        if on_expand is not None:
            on_expand(self.embedding)
        logger.info(f"Expanded relation embeddings from {old_size} to {new_size}")

    def get_relation_index(self, relation_name: str, add_if_missing: bool = False) -> int:
        """Get index for relation; returns OOV if missing and add_if_missing=False."""
        idx = self.relation_to_idx.get(relation_name)
        if idx is None:
            return self.add_relation(relation_name) if add_if_missing else self.oov_idx
        return idx

    def get_embedding(self, relation_name: str, add_if_missing: bool = False) -> torch.Tensor:
        """Get embedding vector for relation (device-safe)."""
        idx = self.get_relation_index(relation_name, add_if_missing=add_if_missing)
        device = self.embedding.weight.device
        return self.embedding(torch.tensor([idx], dtype=torch.long, device=device)).squeeze(0)

    def get_embeddings_batch(self, relation_names: List[str], add_if_missing: bool = False) -> torch.Tensor:
        """Get embedding vectors for batch of relations (device-safe)."""
        indices = [self.get_relation_index(name, add_if_missing=add_if_missing) for name in relation_names]
        device = self.embedding.weight.device
        indices_tensor = torch.tensor(indices, dtype=torch.long, device=device)
        return self.embedding(indices_tensor)

    def forward(self, relation_indices: torch.Tensor) -> torch.Tensor:
        """Forward pass for relation embeddings"""
        return self.embedding(relation_indices)

    @property
    def vocab_size(self) -> int:
        """Current vocabulary size"""
        return len(self.relation_to_idx)


class EmbeddingManager(nn.Module):
    """
    Manages both entity and relation embeddings
    Handles vocabulary building from knowledge graph
    """

    def __init__(self, config: NPLLConfig):
        super().__init__()
        self.config = config

        # Entity and relation embedding layers
        self.entity_embeddings = EntityEmbedding(config)
        self.relation_embeddings = RelationEmbedding(config)

        # Neural network for updating embeddings (paper Section 4.1)
        self.entity_update_network = nn.Sequential(
            nn.Linear(config.entity_embedding_dim, config.entity_embedding_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.entity_embedding_dim, config.entity_embedding_dim)
        )

    def build_vocabulary_from_kg(self, kg: KnowledgeGraph):
        """Build vocabulary from knowledge graph entities and relations"""
        logger.info("Building vocabulary from knowledge graph...")

        # Add all entities
        for entity in kg.entities:
            self.entity_embeddings.add_entity(entity.name)

        # Add all relations
        for relation in kg.relations:
            self.relation_embeddings.add_relation(relation.name)

        logger.info(f"Vocabulary built: {self.entity_embeddings.vocab_size} entities, "
                    f"{self.relation_embeddings.vocab_size} relations")

    def get_triple_embeddings(self, triple: Triple, add_if_missing: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get embeddings for a triple (head, relation, tail)
        """
        head_emb = self.entity_embeddings.get_embedding(triple.head.name, add_if_missing=add_if_missing)
        rel_emb = self.relation_embeddings.get_embedding(triple.relation.name, add_if_missing=add_if_missing)
        tail_emb = self.entity_embeddings.get_embedding(triple.tail.name, add_if_missing=add_if_missing)

        return head_emb, rel_emb, tail_emb

    def get_triple_embeddings_batch(self, triples: List[Triple], add_if_missing: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Get embeddings for batch of triples"""
        head_names = [t.head.name for t in triples]
        rel_names = [t.relation.name for t in triples]
        tail_names = [t.tail.name for t in triples]

        head_embs = self.entity_embeddings.get_embeddings_batch(head_names, add_if_missing=add_if_missing)
        rel_embs = self.relation_embeddings.get_embeddings_batch(rel_names, add_if_missing=add_if_missing)
        tail_embs = self.entity_embeddings.get_embeddings_batch(tail_names, add_if_missing=add_if_missing)

        return head_embs, rel_embs, tail_embs

    def update_entity_embeddings(self, entity_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Update entity embeddings using neural network
        """
        return self.entity_update_network(entity_embeddings)

    def get_entity_embedding(self, entity_name: str) -> torch.Tensor:
        """Get embedding for single entity"""
        return self.entity_embeddings.get_embedding(entity_name)

    def get_relation_embedding(self, relation_name: str) -> torch.Tensor:
        """Get embedding for single relation"""
        return self.relation_embeddings.get_embedding(relation_name)

    def get_embeddings_for_scoring(self, head_names: List[str],
                                   relation_names: List[str],
                                   tail_names: List[str],
                                   add_if_missing: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get embeddings formatted for scoring module input
        """
        # Get base embeddings
        head_embs = self.entity_embeddings.get_embeddings_batch(head_names, add_if_missing=add_if_missing)
        rel_embs = self.relation_embeddings.get_embeddings_batch(relation_names, add_if_missing=add_if_missing)
        tail_embs = self.entity_embeddings.get_embeddings_batch(tail_names, add_if_missing=add_if_missing)

        # Update entity embeddings through neural network
        head_embs = self.update_entity_embeddings(head_embs)
        tail_embs = self.update_entity_embeddings(tail_embs)

        return head_embs, rel_embs, tail_embs

    @property
    def entity_vocab_size(self) -> int:
        """Number of entities in vocabulary"""
        return self.entity_embeddings.vocab_size

    @property
    def relation_vocab_size(self) -> int:
        """Number of relations in vocabulary"""
        return self.relation_embeddings.vocab_size

    @property
    def relation_num_groups(self) -> int:
        """Size of relation embedding table (for per-relation temperature groups)."""
        return int(self.relation_embeddings.embedding.num_embeddings)

    def get_relation_indices_batch(self, relation_names: List[str], add_if_missing: bool = False) -> torch.Tensor:
        idxs = [self.relation_embeddings.get_relation_index(n, add_if_missing=add_if_missing) for n in relation_names]
        device = self.relation_embeddings.embedding.weight.device
        return torch.tensor(idxs, dtype=torch.long, device=device)

    def relation_group_ids_for_triples(self, triples: List[Triple], add_if_missing: bool = False) -> torch.Tensor:
        rels = [t.relation.name for t in triples]
        return self.get_relation_indices_batch(rels, add_if_missing=add_if_missing)

    def save_vocabulary(self, filepath: str):
        """Save vocabulary mappings to file"""
        vocab_data = {
            'entity_to_idx': self.entity_embeddings.entity_to_idx,
            'relation_to_idx': self.relation_embeddings.relation_to_idx,
            'config': self.config
        }
        torch.save(vocab_data, filepath)
        logger.info(f"Saved vocabulary to {filepath}")

    def load_vocabulary(self, filepath: str):
        """Load vocabulary mappings from file"""
        vocab_data = torch.load(filepath)

        # Load entity vocabulary
        self.entity_embeddings.entity_to_idx = vocab_data['entity_to_idx']
        self.entity_embeddings.idx_to_entity = {
            v: k for k, v in vocab_data['entity_to_idx'].items()
        }
        self.entity_embeddings._next_idx = len(vocab_data['entity_to_idx']) + 1

        # Load relation vocabulary
        self.relation_embeddings.relation_to_idx = vocab_data['relation_to_idx']
        self.relation_embeddings.idx_to_relation = {
            v: k for k, v in vocab_data['relation_to_idx'].items()
        }
        self.relation_embeddings._next_idx = len(vocab_data['relation_to_idx']) + 1

        logger.info(f"Loaded vocabulary from {filepath}")


def initialize_embeddings_from_pretrained(embedding_manager: EmbeddingManager,
                                          pretrained_entity_embeddings: Optional[Dict[str, torch.Tensor]] = None,
                                          pretrained_relation_embeddings: Optional[Dict[str, torch.Tensor]] = None):
    """
    Initialize embeddings from pretrained vectors
    """
    if pretrained_entity_embeddings:
        logger.info(f"Initializing {len(pretrained_entity_embeddings)} entity embeddings from pretrained")

        with torch.no_grad():
            for entity_name, embedding_vector in pretrained_entity_embeddings.items():
                idx = embedding_manager.entity_embeddings.add_entity(entity_name)
                if embedding_vector.size(0) == embedding_manager.config.entity_embedding_dim:
                    embedding_manager.entity_embeddings.embedding.weight[idx] = embedding_vector
                else:
                    logger.warning(f"Dimension mismatch for entity {entity_name}: "
                                   f"expected {embedding_manager.config.entity_embedding_dim}, "
                                   f"got {embedding_vector.size(0)}")

    if pretrained_relation_embeddings:
        logger.info(f"Initializing {len(pretrained_relation_embeddings)} relation embeddings from pretrained")

        with torch.no_grad():
            for relation_name, embedding_vector in pretrained_relation_embeddings.items():
                idx = embedding_manager.relation_embeddings.add_relation(relation_name)
                if embedding_vector.size(0) == embedding_manager.config.relation_embedding_dim:
                    embedding_manager.relation_embeddings.embedding.weight[idx] = embedding_vector
                else:
                    logger.warning(f"Dimension mismatch for relation {relation_name}")


# Factory function for creating embedding manager
def create_embedding_manager(config: NPLLConfig, kg: Optional[KnowledgeGraph] = None) -> EmbeddingManager:
    """
    Factory function to create and initialize EmbeddingManager
    """
    manager = EmbeddingManager(config)

    if kg is not None:
        manager.build_vocabulary_from_kg(kg)

    return manager
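For orientation, a minimal usage sketch of this module under stated assumptions: the keyword arguments passed to NPLLConfig (entity_embedding_dim, relation_embedding_dim, dropout) and the empty KnowledgeGraph() call are illustrative, since neither constructor appears in this diff; only the create_embedding_manager and get_embeddings_for_scoring calls follow directly from the code above.

import torch
from npll.utils.config import NPLLConfig
from npll.core import KnowledgeGraph
from npll.scoring.embeddings import create_embedding_manager

# Hypothetical construction; the real NPLLConfig/KnowledgeGraph signatures are not shown in this diff.
config = NPLLConfig(entity_embedding_dim=128, relation_embedding_dim=64, dropout=0.1)
kg = KnowledgeGraph()  # assumed to expose .entities / .relations with .name attributes

# Build the manager and its vocabulary from the knowledge graph.
manager = create_embedding_manager(config, kg)

# Batch lookup for the scoring module: entity vectors pass through the
# entity_update_network, relation vectors are returned as stored.
head_embs, rel_embs, tail_embs = manager.get_embeddings_for_scoring(
    ["alice", "bob"], ["knows", "knows"], ["carol", "dave"],
    add_if_missing=True,  # unseen names are added to the vocabulary instead of mapping to OOV
)
print(head_embs.shape, rel_embs.shape, tail_embs.shape)

With add_if_missing left at its default of False, unknown names map to the reserved OOV index 1, so lookups never raise; index 0 stays zeroed as padding.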