odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/scoring/embeddings.py
@@ -1,442 +1,442 @@
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import Dict, List, Optional, Set, Tuple
-import logging
-from collections import OrderedDict
-
-from ..core import KnowledgeGraph, Entity, Relation, Triple
-from ..utils.config import NPLLConfig
-
-logger = logging.getLogger(__name__)
-
-
-class EntityEmbedding(nn.Module):
-    """
-    Entity embedding layer with dynamic vocabulary expansion
-
-    """
-
-    def __init__(self, config: NPLLConfig, initial_vocab_size: int = 10000):
-        super().__init__()
-        self.config = config
-        self.embedding_dim = config.entity_embedding_dim
-
-        # Entity vocabulary mapping: entity_name -> index
-        self.entity_to_idx: Dict[str, int] = {}
-        self.idx_to_entity: Dict[int, str] = {}
-
-        # Reserve indices: 0 padding, 1 OOV; new entries from 2+
-        self.padding_idx = 0
-        self.oov_idx = 1
-        # Embedding layer - will expand dynamically
-        self.embedding = nn.Embedding(
-            num_embeddings=initial_vocab_size,
-            embedding_dim=self.embedding_dim,
-            padding_idx=self.padding_idx
-        )
-
-        # Initialize embeddings using Xavier initialization
-        nn.init.xavier_uniform_(self.embedding.weight.data)
-        # Set padding embedding to zero; init OOV to a valid vector
-        with torch.no_grad():
-            self.embedding.weight[self.padding_idx].zero_()
-        # Track next available index (0 pad, 1 oov reserved)
-        self._next_idx = 2
-
-    def add_entity(self, entity_name: str) -> int:
-        """Add entity to vocabulary and return its index"""
-        if entity_name not in self.entity_to_idx:
-            # Check if we need to expand embedding layer
-            if self._next_idx >= self.embedding.num_embeddings:
-                self._expand_embeddings()
-
-            # Add entity to vocabulary
-            idx = self._next_idx
-            self.entity_to_idx[entity_name] = idx
-            self.idx_to_entity[idx] = entity_name
-            self._next_idx += 1
-
-            logger.debug(f"Added entity '{entity_name}' with index {idx}")
-
-        return self.entity_to_idx[entity_name]
-
-    def _expand_embeddings(self, grow_by: int = 1000, on_expand=None):
-        """Expand embedding layer when vocabulary grows (preserve device/dtype)."""
-        old = self.embedding
-        old_size = old.num_embeddings
-        new_size = max(old_size * 2, self._next_idx + grow_by)
-        new = nn.Embedding(
-            num_embeddings=new_size,
-            embedding_dim=self.embedding_dim,
-            padding_idx=self.padding_idx,
-            dtype=old.weight.dtype,
-            device=old.weight.device,
-        )
-        with torch.no_grad():
-            new.weight[:old_size].copy_(old.weight)
-            nn.init.xavier_uniform_(new.weight[old_size:])
-            new.weight[self.padding_idx].zero_()
-        self.embedding = new
-        if on_expand is not None:
-            on_expand(self.embedding)
-        logger.info(f"Expanded entity embeddings from {old_size} to {new_size}")
-
-    def get_entity_index(self, entity_name: str, add_if_missing: bool = False) -> int:
-        """Get index for entity; returns OOV if missing and add_if_missing=False."""
-        idx = self.entity_to_idx.get(entity_name)
-        if idx is None:
-            return self.add_entity(entity_name) if add_if_missing else self.oov_idx
-        return idx
-
-    def get_entity_name(self, idx: int) -> Optional[str]:
-        """Get entity name from index"""
-        return self.idx_to_entity.get(idx)
-
-    def get_embedding(self, entity_name: str, add_if_missing: bool = False) -> torch.Tensor:
-        """Get embedding vector for entity (device-safe)."""
-        idx = self.get_entity_index(entity_name, add_if_missing=add_if_missing)
-        device = self.embedding.weight.device
-        return self.embedding(torch.tensor([idx], dtype=torch.long, device=device)).squeeze(0)
-
-    def get_embeddings_batch(self, entity_names: List[str], add_if_missing: bool = False) -> torch.Tensor:
-        """Get embedding vectors for batch of entities (device-safe)."""
-        indices = [self.get_entity_index(name, add_if_missing=add_if_missing) for name in entity_names]
-        device = self.embedding.weight.device
-        indices_tensor = torch.tensor(indices, dtype=torch.long, device=device)
-        return self.embedding(indices_tensor)
-
-    def forward(self, entity_indices: torch.Tensor) -> torch.Tensor:
-        """Forward pass for entity embeddings"""
-        return self.embedding(entity_indices)
-
-    @property
-    def vocab_size(self) -> int:
-        """Current vocabulary size"""
-        return len(self.entity_to_idx)
-
-    def state_dict_with_vocab(self) -> Dict:
-        """Get state dict including vocabulary mappings"""
-        state = super().state_dict()
-        state['entity_to_idx'] = self.entity_to_idx.copy()
-        state['idx_to_entity'] = self.idx_to_entity.copy()
-        state['_next_idx'] = self._next_idx
-        state['padding_idx'] = self.padding_idx
-        state['oov_idx'] = self.oov_idx
-        return state
-
-    def load_state_dict_with_vocab(self, state_dict: Dict):
-        """Load state dict including vocabulary mappings"""
-        # Load vocabulary mappings first
-        self.entity_to_idx = state_dict.pop('entity_to_idx', {})
-        self.idx_to_entity = state_dict.pop('idx_to_entity', {})
-        self._next_idx = state_dict.pop('_next_idx', 2)
-        self.padding_idx = state_dict.pop('padding_idx', 0)
-        self.oov_idx = state_dict.pop('oov_idx', 1)
-        needed = max(self._next_idx, self.embedding.num_embeddings)
-        if needed > self.embedding.num_embeddings:
-            old = self.embedding
-            new = nn.Embedding(
-                num_embeddings=needed,
-                embedding_dim=self.embedding_dim,
-                padding_idx=self.padding_idx,
-                dtype=old.weight.dtype,
-                device=old.weight.device,
-            )
-            with torch.no_grad():
-                new.weight[:old.num_embeddings].copy_(old.weight)
-                nn.init.xavier_uniform_(new.weight[old.num_embeddings:])
-                new.weight[self.padding_idx].zero_()
-            self.embedding = new
-        # Load model parameters leniently in case of size diffs
-        super().load_state_dict(state_dict, strict=False)
-
-
-class RelationEmbedding(nn.Module):
-    """
-    Relation embedding layer with dynamic vocabulary expansion
-    """
-
-    def __init__(self, config: NPLLConfig, initial_vocab_size: int = 1000):
-        super().__init__()
-        self.config = config
-        self.embedding_dim = config.relation_embedding_dim
-
-        # Relation vocabulary mapping
-        self.relation_to_idx: Dict[str, int] = {}
-        self.idx_to_relation: Dict[int, str] = {}
-
-        # Embedding layer
-        self.padding_idx = 0
-        self.oov_idx = 1
-        self.embedding = nn.Embedding(
-            num_embeddings=initial_vocab_size,
-            embedding_dim=self.embedding_dim,
-            padding_idx=self.padding_idx
-        )
-
-        # Initialize embeddings
-        nn.init.xavier_uniform_(self.embedding.weight.data)
-        with torch.no_grad():
-            self.embedding.weight[self.padding_idx].zero_()
-
-        self._next_idx = 2
-
-    def add_relation(self, relation_name: str) -> int:
-        """Add relation to vocabulary and return its index"""
-        if relation_name not in self.relation_to_idx:
-            if self._next_idx >= self.embedding.num_embeddings:
-                self._expand_embeddings()
-
-            idx = self._next_idx
-            self.relation_to_idx[relation_name] = idx
-            self.idx_to_relation[idx] = relation_name
-            self._next_idx += 1
-
-            logger.debug(f"Added relation '{relation_name}' with index {idx}")
-
-        return self.relation_to_idx[relation_name]
-
-    def _expand_embeddings(self, grow_by: int = 100, on_expand=None):
-        """Expand embedding layer when vocabulary grows (preserve device/dtype)."""
-        old = self.embedding
-        old_size = old.num_embeddings
-        new_size = max(old_size * 2, self._next_idx + grow_by)
-        new = nn.Embedding(
-            num_embeddings=new_size,
-            embedding_dim=self.embedding_dim,
-            padding_idx=self.padding_idx,
-            dtype=old.weight.dtype,
-            device=old.weight.device,
-        )
-        with torch.no_grad():
-            new.weight[:old_size].copy_(old.weight)
-            nn.init.xavier_uniform_(new.weight[old_size:])
-            new.weight[self.padding_idx].zero_()
-        self.embedding = new
-        if on_expand is not None:
-            on_expand(self.embedding)
-        logger.info(f"Expanded relation embeddings from {old_size} to {new_size}")
-
-    def get_relation_index(self, relation_name: str, add_if_missing: bool = False) -> int:
-        """Get index for relation; returns OOV if missing and add_if_missing=False."""
-        idx = self.relation_to_idx.get(relation_name)
-        if idx is None:
-            return self.add_relation(relation_name) if add_if_missing else self.oov_idx
-        return idx
-
-    def get_embedding(self, relation_name: str, add_if_missing: bool = False) -> torch.Tensor:
-        """Get embedding vector for relation (device-safe)."""
-        idx = self.get_relation_index(relation_name, add_if_missing=add_if_missing)
-        device = self.embedding.weight.device
-        return self.embedding(torch.tensor([idx], dtype=torch.long, device=device)).squeeze(0)
-
-    def get_embeddings_batch(self, relation_names: List[str], add_if_missing: bool = False) -> torch.Tensor:
-        """Get embedding vectors for batch of relations (device-safe)."""
-        indices = [self.get_relation_index(name, add_if_missing=add_if_missing) for name in relation_names]
-        device = self.embedding.weight.device
-        indices_tensor = torch.tensor(indices, dtype=torch.long, device=device)
-        return self.embedding(indices_tensor)
-
-    def forward(self, relation_indices: torch.Tensor) -> torch.Tensor:
-        """Forward pass for relation embeddings"""
-        return self.embedding(relation_indices)
-
-    @property
-    def vocab_size(self) -> int:
-        """Current vocabulary size"""
-        return len(self.relation_to_idx)
-
-
-class EmbeddingManager(nn.Module):
-    """
-    Manages both entity and relation embeddings
-    Handles vocabulary building from knowledge graph
-    """
-
-    def __init__(self, config: NPLLConfig):
-        super().__init__()
-        self.config = config
-
-        # Entity and relation embedding layers
-        self.entity_embeddings = EntityEmbedding(config)
-        self.relation_embeddings = RelationEmbedding(config)
-
-        # Neural network for updating embeddings (paper Section 4.1)
-        self.entity_update_network = nn.Sequential(
-            nn.Linear(config.entity_embedding_dim, config.entity_embedding_dim),
-            nn.ReLU(),
-            nn.Dropout(config.dropout),
-            nn.Linear(config.entity_embedding_dim, config.entity_embedding_dim)
-        )
-
-    def build_vocabulary_from_kg(self, kg: KnowledgeGraph):
-        """Build vocabulary from knowledge graph entities and relations"""
-        logger.info("Building vocabulary from knowledge graph...")
-
-        # Add all entities
-        for entity in kg.entities:
-            self.entity_embeddings.add_entity(entity.name)
-
-        # Add all relations
-        for relation in kg.relations:
-            self.relation_embeddings.add_relation(relation.name)
-
-        logger.info(f"Vocabulary built: {self.entity_embeddings.vocab_size} entities, "
-                    f"{self.relation_embeddings.vocab_size} relations")
-
-    def get_triple_embeddings(self, triple: Triple, add_if_missing: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Get embeddings for a triple (head, relation, tail)
-        """
-        head_emb = self.entity_embeddings.get_embedding(triple.head.name, add_if_missing=add_if_missing)
-        rel_emb = self.relation_embeddings.get_embedding(triple.relation.name, add_if_missing=add_if_missing)
-        tail_emb = self.entity_embeddings.get_embedding(triple.tail.name, add_if_missing=add_if_missing)
-
-        return head_emb, rel_emb, tail_emb
-
-    def get_triple_embeddings_batch(self, triples: List[Triple], add_if_missing: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Get embeddings for batch of triples"""
-        head_names = [t.head.name for t in triples]
-        rel_names = [t.relation.name for t in triples]
-        tail_names = [t.tail.name for t in triples]
-
-        head_embs = self.entity_embeddings.get_embeddings_batch(head_names, add_if_missing=add_if_missing)
-        rel_embs = self.relation_embeddings.get_embeddings_batch(rel_names, add_if_missing=add_if_missing)
-        tail_embs = self.entity_embeddings.get_embeddings_batch(tail_names, add_if_missing=add_if_missing)
-
-        return head_embs, rel_embs, tail_embs
-
-    def update_entity_embeddings(self, entity_embeddings: torch.Tensor) -> torch.Tensor:
-        """
-        Update entity embeddings using neural network
-
-        """
-        return self.entity_update_network(entity_embeddings)
-
-    def get_entity_embedding(self, entity_name: str) -> torch.Tensor:
-        """Get embedding for single entity"""
-        return self.entity_embeddings.get_embedding(entity_name)
-
-    def get_relation_embedding(self, relation_name: str) -> torch.Tensor:
-        """Get embedding for single relation"""
-        return self.relation_embeddings.get_embedding(relation_name)
-
-    def get_embeddings_for_scoring(self, head_names: List[str],
-                                   relation_names: List[str],
-                                   tail_names: List[str],
-                                   add_if_missing: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Get embeddings formatted for scoring module input
-
-        """
-        # Get base embeddings
-        head_embs = self.entity_embeddings.get_embeddings_batch(head_names, add_if_missing=add_if_missing)
-        rel_embs = self.relation_embeddings.get_embeddings_batch(relation_names, add_if_missing=add_if_missing)
-        tail_embs = self.entity_embeddings.get_embeddings_batch(tail_names, add_if_missing=add_if_missing)
-
-        # Update entity embeddings through neural network
-        head_embs = self.update_entity_embeddings(head_embs)
-        tail_embs = self.update_entity_embeddings(tail_embs)
-
-        return head_embs, rel_embs, tail_embs
-
-    @property
-    def entity_vocab_size(self) -> int:
-        """Number of entities in vocabulary"""
-        return self.entity_embeddings.vocab_size
-
-    @property
-    def relation_vocab_size(self) -> int:
-        """Number of relations in vocabulary"""
-        return self.relation_embeddings.vocab_size
-
-    @property
-    def relation_num_groups(self) -> int:
-        """Size of relation embedding table (for per-relation temperature groups)."""
-        return int(self.relation_embeddings.embedding.num_embeddings)
-
-    def get_relation_indices_batch(self, relation_names: List[str], add_if_missing: bool = False) -> torch.Tensor:
-        idxs = [self.relation_embeddings.get_relation_index(n, add_if_missing=add_if_missing) for n in relation_names]
-        device = self.relation_embeddings.embedding.weight.device
-        return torch.tensor(idxs, dtype=torch.long, device=device)
-
-    def relation_group_ids_for_triples(self, triples: List[Triple], add_if_missing: bool = False) -> torch.Tensor:
-        rels = [t.relation.name for t in triples]
-        return self.get_relation_indices_batch(rels, add_if_missing=add_if_missing)
-
-    def save_vocabulary(self, filepath: str):
-        """Save vocabulary mappings to file"""
-        vocab_data = {
-            'entity_to_idx': self.entity_embeddings.entity_to_idx,
-            'relation_to_idx': self.relation_embeddings.relation_to_idx,
-            'config': self.config
-        }
-        torch.save(vocab_data, filepath)
-        logger.info(f"Saved vocabulary to {filepath}")
-
-    def load_vocabulary(self, filepath: str):
-        """Load vocabulary mappings from file"""
-        vocab_data = torch.load(filepath)
-
-        # Load entity vocabulary
-        self.entity_embeddings.entity_to_idx = vocab_data['entity_to_idx']
-        self.entity_embeddings.idx_to_entity = {
-            v: k for k, v in vocab_data['entity_to_idx'].items()
-        }
-        self.entity_embeddings._next_idx = len(vocab_data['entity_to_idx']) + 1
-
-        # Load relation vocabulary
-        self.relation_embeddings.relation_to_idx = vocab_data['relation_to_idx']
-        self.relation_embeddings.idx_to_relation = {
-            v: k for k, v in vocab_data['relation_to_idx'].items()
-        }
-        self.relation_embeddings._next_idx = len(vocab_data['relation_to_idx']) + 1
-
-        logger.info(f"Loaded vocabulary from {filepath}")
-
-
-def initialize_embeddings_from_pretrained(embedding_manager: EmbeddingManager,
-                                          pretrained_entity_embeddings: Optional[Dict[str, torch.Tensor]] = None,
-                                          pretrained_relation_embeddings: Optional[Dict[str, torch.Tensor]] = None):
-    """
-    Initialize embeddings from pretrained vectors
-    """
-    if pretrained_entity_embeddings:
-        logger.info(f"Initializing {len(pretrained_entity_embeddings)} entity embeddings from pretrained")
-
-        with torch.no_grad():
-            for entity_name, embedding_vector in pretrained_entity_embeddings.items():
-                idx = embedding_manager.entity_embeddings.add_entity(entity_name)
-                if embedding_vector.size(0) == embedding_manager.config.entity_embedding_dim:
-                    embedding_manager.entity_embeddings.embedding.weight[idx] = embedding_vector
-                else:
-                    logger.warning(f"Dimension mismatch for entity {entity_name}: "
-                                   f"expected {embedding_manager.config.entity_embedding_dim}, "
-                                   f"got {embedding_vector.size(0)}")
-
-    if pretrained_relation_embeddings:
-        logger.info(f"Initializing {len(pretrained_relation_embeddings)} relation embeddings from pretrained")
-
-        with torch.no_grad():
-            for relation_name, embedding_vector in pretrained_relation_embeddings.items():
-                idx = embedding_manager.relation_embeddings.add_relation(relation_name)
-                if embedding_vector.size(0) == embedding_manager.config.relation_embedding_dim:
-                    embedding_manager.relation_embeddings.embedding.weight[idx] = embedding_vector
-                else:
-                    logger.warning(f"Dimension mismatch for relation {relation_name}")
-
-
-# Factory function for creating embedding manager
-def create_embedding_manager(config: NPLLConfig, kg: Optional[KnowledgeGraph] = None) -> EmbeddingManager:
-    """
-    Factory function to create and initialize EmbeddingManager
-
-    """
-    manager = EmbeddingManager(config)
-
-    if kg is not None:
-        manager.build_vocabulary_from_kg(kg)
-
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, List, Optional, Set, Tuple
+import logging
+from collections import OrderedDict
+
+from ..core import KnowledgeGraph, Entity, Relation, Triple
+from ..utils.config import NPLLConfig
+
+logger = logging.getLogger(__name__)
+
+
+class EntityEmbedding(nn.Module):
+    """
+    Entity embedding layer with dynamic vocabulary expansion
+
+    """
+
+    def __init__(self, config: NPLLConfig, initial_vocab_size: int = 10000):
+        super().__init__()
+        self.config = config
+        self.embedding_dim = config.entity_embedding_dim
+
+        # Entity vocabulary mapping: entity_name -> index
+        self.entity_to_idx: Dict[str, int] = {}
+        self.idx_to_entity: Dict[int, str] = {}
+
+        # Reserve indices: 0 padding, 1 OOV; new entries from 2+
+        self.padding_idx = 0
+        self.oov_idx = 1
+        # Embedding layer - will expand dynamically
+        self.embedding = nn.Embedding(
+            num_embeddings=initial_vocab_size,
+            embedding_dim=self.embedding_dim,
+            padding_idx=self.padding_idx
+        )
+
+        # Initialize embeddings using Xavier initialization
+        nn.init.xavier_uniform_(self.embedding.weight.data)
+        # Set padding embedding to zero; init OOV to a valid vector
+        with torch.no_grad():
+            self.embedding.weight[self.padding_idx].zero_()
+        # Track next available index (0 pad, 1 oov reserved)
+        self._next_idx = 2
+
+    def add_entity(self, entity_name: str) -> int:
+        """Add entity to vocabulary and return its index"""
+        if entity_name not in self.entity_to_idx:
+            # Check if we need to expand embedding layer
+            if self._next_idx >= self.embedding.num_embeddings:
+                self._expand_embeddings()
+
+            # Add entity to vocabulary
+            idx = self._next_idx
+            self.entity_to_idx[entity_name] = idx
+            self.idx_to_entity[idx] = entity_name
+            self._next_idx += 1
+
+            logger.debug(f"Added entity '{entity_name}' with index {idx}")
+
+        return self.entity_to_idx[entity_name]
+
+    def _expand_embeddings(self, grow_by: int = 1000, on_expand=None):
+        """Expand embedding layer when vocabulary grows (preserve device/dtype)."""
+        old = self.embedding
+        old_size = old.num_embeddings
+        new_size = max(old_size * 2, self._next_idx + grow_by)
+        new = nn.Embedding(
+            num_embeddings=new_size,
+            embedding_dim=self.embedding_dim,
+            padding_idx=self.padding_idx,
+            dtype=old.weight.dtype,
+            device=old.weight.device,
+        )
+        with torch.no_grad():
+            new.weight[:old_size].copy_(old.weight)
+            nn.init.xavier_uniform_(new.weight[old_size:])
+            new.weight[self.padding_idx].zero_()
+        self.embedding = new
+        if on_expand is not None:
+            on_expand(self.embedding)
+        logger.info(f"Expanded entity embeddings from {old_size} to {new_size}")
+
+    def get_entity_index(self, entity_name: str, add_if_missing: bool = False) -> int:
+        """Get index for entity; returns OOV if missing and add_if_missing=False."""
+        idx = self.entity_to_idx.get(entity_name)
+        if idx is None:
+            return self.add_entity(entity_name) if add_if_missing else self.oov_idx
+        return idx
+
+    def get_entity_name(self, idx: int) -> Optional[str]:
+        """Get entity name from index"""
+        return self.idx_to_entity.get(idx)
+
+    def get_embedding(self, entity_name: str, add_if_missing: bool = False) -> torch.Tensor:
+        """Get embedding vector for entity (device-safe)."""
+        idx = self.get_entity_index(entity_name, add_if_missing=add_if_missing)
+        device = self.embedding.weight.device
+        return self.embedding(torch.tensor([idx], dtype=torch.long, device=device)).squeeze(0)
+
+    def get_embeddings_batch(self, entity_names: List[str], add_if_missing: bool = False) -> torch.Tensor:
+        """Get embedding vectors for batch of entities (device-safe)."""
+        indices = [self.get_entity_index(name, add_if_missing=add_if_missing) for name in entity_names]
+        device = self.embedding.weight.device
+        indices_tensor = torch.tensor(indices, dtype=torch.long, device=device)
+        return self.embedding(indices_tensor)
+
+    def forward(self, entity_indices: torch.Tensor) -> torch.Tensor:
+        """Forward pass for entity embeddings"""
+        return self.embedding(entity_indices)
+
+    @property
+    def vocab_size(self) -> int:
+        """Current vocabulary size"""
+        return len(self.entity_to_idx)
+
+    def state_dict_with_vocab(self) -> Dict:
+        """Get state dict including vocabulary mappings"""
+        state = super().state_dict()
+        state['entity_to_idx'] = self.entity_to_idx.copy()
+        state['idx_to_entity'] = self.idx_to_entity.copy()
+        state['_next_idx'] = self._next_idx
+        state['padding_idx'] = self.padding_idx
+        state['oov_idx'] = self.oov_idx
+        return state
+
+    def load_state_dict_with_vocab(self, state_dict: Dict):
+        """Load state dict including vocabulary mappings"""
+        # Load vocabulary mappings first
+        self.entity_to_idx = state_dict.pop('entity_to_idx', {})
+        self.idx_to_entity = state_dict.pop('idx_to_entity', {})
+        self._next_idx = state_dict.pop('_next_idx', 2)
+        self.padding_idx = state_dict.pop('padding_idx', 0)
+        self.oov_idx = state_dict.pop('oov_idx', 1)
+        needed = max(self._next_idx, self.embedding.num_embeddings)
+        if needed > self.embedding.num_embeddings:
+            old = self.embedding
+            new = nn.Embedding(
+                num_embeddings=needed,
+                embedding_dim=self.embedding_dim,
+                padding_idx=self.padding_idx,
+                dtype=old.weight.dtype,
+                device=old.weight.device,
+            )
+            with torch.no_grad():
+                new.weight[:old.num_embeddings].copy_(old.weight)
+                nn.init.xavier_uniform_(new.weight[old.num_embeddings:])
+                new.weight[self.padding_idx].zero_()
+            self.embedding = new
+        # Load model parameters leniently in case of size diffs
+        super().load_state_dict(state_dict, strict=False)
+
+
+class RelationEmbedding(nn.Module):
+    """
+    Relation embedding layer with dynamic vocabulary expansion
+    """
+
+    def __init__(self, config: NPLLConfig, initial_vocab_size: int = 1000):
+        super().__init__()
+        self.config = config
+        self.embedding_dim = config.relation_embedding_dim
+
+        # Relation vocabulary mapping
+        self.relation_to_idx: Dict[str, int] = {}
+        self.idx_to_relation: Dict[int, str] = {}
+
+        # Embedding layer
+        self.padding_idx = 0
+        self.oov_idx = 1
+        self.embedding = nn.Embedding(
+            num_embeddings=initial_vocab_size,
+            embedding_dim=self.embedding_dim,
+            padding_idx=self.padding_idx
+        )
+
+        # Initialize embeddings
+        nn.init.xavier_uniform_(self.embedding.weight.data)
+        with torch.no_grad():
+            self.embedding.weight[self.padding_idx].zero_()
+
+        self._next_idx = 2
+
+    def add_relation(self, relation_name: str) -> int:
+        """Add relation to vocabulary and return its index"""
+        if relation_name not in self.relation_to_idx:
+            if self._next_idx >= self.embedding.num_embeddings:
+                self._expand_embeddings()
+
+            idx = self._next_idx
+            self.relation_to_idx[relation_name] = idx
+            self.idx_to_relation[idx] = relation_name
+            self._next_idx += 1
+
+            logger.debug(f"Added relation '{relation_name}' with index {idx}")
+
+        return self.relation_to_idx[relation_name]
+
+    def _expand_embeddings(self, grow_by: int = 100, on_expand=None):
+        """Expand embedding layer when vocabulary grows (preserve device/dtype)."""
+        old = self.embedding
+        old_size = old.num_embeddings
+        new_size = max(old_size * 2, self._next_idx + grow_by)
+        new = nn.Embedding(
+            num_embeddings=new_size,
+            embedding_dim=self.embedding_dim,
+            padding_idx=self.padding_idx,
+            dtype=old.weight.dtype,
+            device=old.weight.device,
+        )
+        with torch.no_grad():
+            new.weight[:old_size].copy_(old.weight)
+            nn.init.xavier_uniform_(new.weight[old_size:])
+            new.weight[self.padding_idx].zero_()
+        self.embedding = new
+        if on_expand is not None:
+            on_expand(self.embedding)
+        logger.info(f"Expanded relation embeddings from {old_size} to {new_size}")
+
+    def get_relation_index(self, relation_name: str, add_if_missing: bool = False) -> int:
+        """Get index for relation; returns OOV if missing and add_if_missing=False."""
+        idx = self.relation_to_idx.get(relation_name)
+        if idx is None:
+            return self.add_relation(relation_name) if add_if_missing else self.oov_idx
+        return idx
+
+    def get_embedding(self, relation_name: str, add_if_missing: bool = False) -> torch.Tensor:
+        """Get embedding vector for relation (device-safe)."""
+        idx = self.get_relation_index(relation_name, add_if_missing=add_if_missing)
+        device = self.embedding.weight.device
+        return self.embedding(torch.tensor([idx], dtype=torch.long, device=device)).squeeze(0)
+
+    def get_embeddings_batch(self, relation_names: List[str], add_if_missing: bool = False) -> torch.Tensor:
+        """Get embedding vectors for batch of relations (device-safe)."""
+        indices = [self.get_relation_index(name, add_if_missing=add_if_missing) for name in relation_names]
+        device = self.embedding.weight.device
+        indices_tensor = torch.tensor(indices, dtype=torch.long, device=device)
+        return self.embedding(indices_tensor)
+
+    def forward(self, relation_indices: torch.Tensor) -> torch.Tensor:
+        """Forward pass for relation embeddings"""
+        return self.embedding(relation_indices)
+
+    @property
+    def vocab_size(self) -> int:
+        """Current vocabulary size"""
+        return len(self.relation_to_idx)
+
+
+class EmbeddingManager(nn.Module):
+    """
+    Manages both entity and relation embeddings
+    Handles vocabulary building from knowledge graph
+    """
+
+    def __init__(self, config: NPLLConfig):
+        super().__init__()
+        self.config = config
+
+        # Entity and relation embedding layers
+        self.entity_embeddings = EntityEmbedding(config)
+        self.relation_embeddings = RelationEmbedding(config)
+
+        # Neural network for updating embeddings (paper Section 4.1)
+        self.entity_update_network = nn.Sequential(
+            nn.Linear(config.entity_embedding_dim, config.entity_embedding_dim),
+            nn.ReLU(),
+            nn.Dropout(config.dropout),
+            nn.Linear(config.entity_embedding_dim, config.entity_embedding_dim)
+        )
+
+    def build_vocabulary_from_kg(self, kg: KnowledgeGraph):
+        """Build vocabulary from knowledge graph entities and relations"""
+        logger.info("Building vocabulary from knowledge graph...")
+
+        # Add all entities
+        for entity in kg.entities:
+            self.entity_embeddings.add_entity(entity.name)
+
+        # Add all relations
+        for relation in kg.relations:
+            self.relation_embeddings.add_relation(relation.name)
+
+        logger.info(f"Vocabulary built: {self.entity_embeddings.vocab_size} entities, "
+                    f"{self.relation_embeddings.vocab_size} relations")
+
+    def get_triple_embeddings(self, triple: Triple, add_if_missing: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Get embeddings for a triple (head, relation, tail)
+        """
+        head_emb = self.entity_embeddings.get_embedding(triple.head.name, add_if_missing=add_if_missing)
+        rel_emb = self.relation_embeddings.get_embedding(triple.relation.name, add_if_missing=add_if_missing)
+        tail_emb = self.entity_embeddings.get_embedding(triple.tail.name, add_if_missing=add_if_missing)
+
+        return head_emb, rel_emb, tail_emb
+
+    def get_triple_embeddings_batch(self, triples: List[Triple], add_if_missing: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Get embeddings for batch of triples"""
+        head_names = [t.head.name for t in triples]
+        rel_names = [t.relation.name for t in triples]
+        tail_names = [t.tail.name for t in triples]
+
+        head_embs = self.entity_embeddings.get_embeddings_batch(head_names, add_if_missing=add_if_missing)
+        rel_embs = self.relation_embeddings.get_embeddings_batch(rel_names, add_if_missing=add_if_missing)
+        tail_embs = self.entity_embeddings.get_embeddings_batch(tail_names, add_if_missing=add_if_missing)
+
+        return head_embs, rel_embs, tail_embs
+
+    def update_entity_embeddings(self, entity_embeddings: torch.Tensor) -> torch.Tensor:
+        """
+        Update entity embeddings using neural network
+
+        """
+        return self.entity_update_network(entity_embeddings)
+
+    def get_entity_embedding(self, entity_name: str) -> torch.Tensor:
+        """Get embedding for single entity"""
+        return self.entity_embeddings.get_embedding(entity_name)
+
+    def get_relation_embedding(self, relation_name: str) -> torch.Tensor:
+        """Get embedding for single relation"""
+        return self.relation_embeddings.get_embedding(relation_name)
+
+    def get_embeddings_for_scoring(self, head_names: List[str],
+                                   relation_names: List[str],
+                                   tail_names: List[str],
+                                   add_if_missing: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Get embeddings formatted for scoring module input
+
+        """
+        # Get base embeddings
+        head_embs = self.entity_embeddings.get_embeddings_batch(head_names, add_if_missing=add_if_missing)
+        rel_embs = self.relation_embeddings.get_embeddings_batch(relation_names, add_if_missing=add_if_missing)
+        tail_embs = self.entity_embeddings.get_embeddings_batch(tail_names, add_if_missing=add_if_missing)
+
+        # Update entity embeddings through neural network
+        head_embs = self.update_entity_embeddings(head_embs)
+        tail_embs = self.update_entity_embeddings(tail_embs)
+
+        return head_embs, rel_embs, tail_embs
+
+    @property
+    def entity_vocab_size(self) -> int:
+        """Number of entities in vocabulary"""
+        return self.entity_embeddings.vocab_size
+
+    @property
+    def relation_vocab_size(self) -> int:
+        """Number of relations in vocabulary"""
+        return self.relation_embeddings.vocab_size
+
+    @property
+    def relation_num_groups(self) -> int:
+        """Size of relation embedding table (for per-relation temperature groups)."""
+        return int(self.relation_embeddings.embedding.num_embeddings)
+
+    def get_relation_indices_batch(self, relation_names: List[str], add_if_missing: bool = False) -> torch.Tensor:
+        idxs = [self.relation_embeddings.get_relation_index(n, add_if_missing=add_if_missing) for n in relation_names]
+        device = self.relation_embeddings.embedding.weight.device
+        return torch.tensor(idxs, dtype=torch.long, device=device)
+
+    def relation_group_ids_for_triples(self, triples: List[Triple], add_if_missing: bool = False) -> torch.Tensor:
+        rels = [t.relation.name for t in triples]
+        return self.get_relation_indices_batch(rels, add_if_missing=add_if_missing)
+
+    def save_vocabulary(self, filepath: str):
+        """Save vocabulary mappings to file"""
+        vocab_data = {
+            'entity_to_idx': self.entity_embeddings.entity_to_idx,
+            'relation_to_idx': self.relation_embeddings.relation_to_idx,
+            'config': self.config
+        }
+        torch.save(vocab_data, filepath)
+        logger.info(f"Saved vocabulary to {filepath}")
+
+    def load_vocabulary(self, filepath: str):
+        """Load vocabulary mappings from file"""
+        vocab_data = torch.load(filepath)
+
+        # Load entity vocabulary
+        self.entity_embeddings.entity_to_idx = vocab_data['entity_to_idx']
+        self.entity_embeddings.idx_to_entity = {
+            v: k for k, v in vocab_data['entity_to_idx'].items()
+        }
+        self.entity_embeddings._next_idx = len(vocab_data['entity_to_idx']) + 1
+
+        # Load relation vocabulary
+        self.relation_embeddings.relation_to_idx = vocab_data['relation_to_idx']
+        self.relation_embeddings.idx_to_relation = {
+            v: k for k, v in vocab_data['relation_to_idx'].items()
+        }
+        self.relation_embeddings._next_idx = len(vocab_data['relation_to_idx']) + 1
+
+        logger.info(f"Loaded vocabulary from {filepath}")
+
+
+def initialize_embeddings_from_pretrained(embedding_manager: EmbeddingManager,
+                                          pretrained_entity_embeddings: Optional[Dict[str, torch.Tensor]] = None,
+                                          pretrained_relation_embeddings: Optional[Dict[str, torch.Tensor]] = None):
+    """
+    Initialize embeddings from pretrained vectors
+    """
+    if pretrained_entity_embeddings:
+        logger.info(f"Initializing {len(pretrained_entity_embeddings)} entity embeddings from pretrained")
+
+        with torch.no_grad():
+            for entity_name, embedding_vector in pretrained_entity_embeddings.items():
+                idx = embedding_manager.entity_embeddings.add_entity(entity_name)
+                if embedding_vector.size(0) == embedding_manager.config.entity_embedding_dim:
+                    embedding_manager.entity_embeddings.embedding.weight[idx] = embedding_vector
+                else:
+                    logger.warning(f"Dimension mismatch for entity {entity_name}: "
+                                   f"expected {embedding_manager.config.entity_embedding_dim}, "
+                                   f"got {embedding_vector.size(0)}")
+
+    if pretrained_relation_embeddings:
+        logger.info(f"Initializing {len(pretrained_relation_embeddings)} relation embeddings from pretrained")
+
+        with torch.no_grad():
+            for relation_name, embedding_vector in pretrained_relation_embeddings.items():
+                idx = embedding_manager.relation_embeddings.add_relation(relation_name)
+                if embedding_vector.size(0) == embedding_manager.config.relation_embedding_dim:
+                    embedding_manager.relation_embeddings.embedding.weight[idx] = embedding_vector
+                else:
+                    logger.warning(f"Dimension mismatch for relation {relation_name}")
+
+
+# Factory function for creating embedding manager
+def create_embedding_manager(config: NPLLConfig, kg: Optional[KnowledgeGraph] = None) -> EmbeddingManager:
+    """
+    Factory function to create and initialize EmbeddingManager
+
+    """
+    manager = EmbeddingManager(config)
+
+    if kg is not None:
+        manager.build_vocabulary_from_kg(kg)
+
     return manager
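
For orientation, here is a minimal usage sketch of the module diffed above (npll/scoring/embeddings.py). It is not part of the package: the SimpleNamespace config is a hypothetical stand-in carrying only the three attributes this module reads (entity_embedding_dim, relation_embedding_dim, dropout), and the entity/relation names are invented.

# Usage sketch under the assumptions stated above; exercises the public
# surface of EmbeddingManager exactly as it appears in the diff.
from types import SimpleNamespace

import torch

from npll.scoring.embeddings import EmbeddingManager

# Stand-in config: only the attributes the embeddings module actually reads.
config = SimpleNamespace(entity_embedding_dim=64, relation_embedding_dim=32, dropout=0.1)
manager = EmbeddingManager(config)

# Vocabulary grows lazily: add_if_missing=True registers unseen names,
# starting at index 2 (0 is reserved for padding, 1 for out-of-vocabulary).
heads = ["entity_a", "entity_b"]
relations = ["relates_to", "relates_to"]
tails = ["entity_c", "entity_d"]
head_embs, rel_embs, tail_embs = manager.get_embeddings_for_scoring(
    heads, relations, tails, add_if_missing=True
)
print(head_embs.shape, rel_embs.shape, tail_embs.shape)  # (2, 64) (2, 32) (2, 64)

# With add_if_missing=False (the default), unknown names fall back to the reserved OOV row.
oov_vec = manager.entity_embeddings.get_embedding("never_seen")
print(torch.equal(oov_vec, manager.entity_embeddings.embedding.weight[1]))  # True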