odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/npll_model.py
CHANGED
|
@@ -1,632 +1,632 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Main NPLL Model - Neural Probabilistic Logic Learning
|
|
3
|
-
Integrates all components: MLN, Scoring Module, E-step, M-step
|
|
4
|
-
Exact implementation of the complete NPLL framework from the paper
|
|
5
|
-
"""
|
|
6
|
-
|
|
7
|
-
import torch
|
|
8
|
-
import torch.nn as nn
|
|
9
|
-
from typing import List, Dict, Set, Tuple, Optional, Any, Union
|
|
10
|
-
import logging
|
|
11
|
-
from dataclasses import dataclass
|
|
12
|
-
import time
|
|
13
|
-
import os
|
|
14
|
-
|
|
15
|
-
from .core import (
|
|
16
|
-
KnowledgeGraph, Entity, Relation, Triple, LogicalRule, GroundRule,
|
|
17
|
-
MarkovLogicNetwork, create_mln_from_kg_and_rules
|
|
18
|
-
)
|
|
19
|
-
from .scoring import (
|
|
20
|
-
NPLLScoringModule, create_scoring_module,
|
|
21
|
-
ProbabilityTransform, create_probability_components
|
|
22
|
-
)
|
|
23
|
-
from .inference import (
|
|
24
|
-
EStepRunner, MStepRunner, ELBOComputer, EStepResult, MStepResult,
|
|
25
|
-
create_e_step_runner, create_m_step_runner
|
|
26
|
-
)
|
|
27
|
-
from .utils import NPLLConfig, get_config
|
|
28
|
-
|
|
29
|
-
logger = logging.getLogger(__name__)
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
@dataclass
|
|
33
|
-
class NPLLTrainingState:
|
|
34
|
-
"""
|
|
35
|
-
Training state for NPLL model
|
|
36
|
-
Tracks progress through E-M iterations
|
|
37
|
-
"""
|
|
38
|
-
epoch: int
|
|
39
|
-
em_iteration: int
|
|
40
|
-
elbo_history: List[float]
|
|
41
|
-
rule_weight_history: List[List[float]]
|
|
42
|
-
convergence_info: Dict[str, Any]
|
|
43
|
-
training_time: float
|
|
44
|
-
best_elbo: float
|
|
45
|
-
best_weights: Optional[torch.Tensor] = None
|
|
46
|
-
|
|
47
|
-
def __str__(self) -> str:
|
|
48
|
-
return (f"NPLL Training State:\n"
|
|
49
|
-
f" Epoch: {self.epoch}\n"
|
|
50
|
-
f" EM Iteration: {self.em_iteration}\n"
|
|
51
|
-
f" Current ELBO: {self.elbo_history[-1] if self.elbo_history else 'N/A'}\n"
|
|
52
|
-
f" Best ELBO: {self.best_elbo:.6f}\n"
|
|
53
|
-
f" Training Time: {self.training_time:.2f}s")
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
class NPLLModel(nn.Module):
|
|
57
|
-
"""
|
|
58
|
-
Complete Neural Probabilistic Logic Learning Model
|
|
59
|
-
|
|
60
|
-
Integrates all paper components:
|
|
61
|
-
- Knowledge Graph representation (Section 3)
|
|
62
|
-
- Scoring Module with bilinear function (Section 4.1, Equation 7)
|
|
63
|
-
- Markov Logic Network (Sections 3-4, Equations 1-2)
|
|
64
|
-
- E-M Algorithm (Sections 4.2-4.3, Equations 8-14)
|
|
65
|
-
- ELBO optimization (Equations 3-5)
|
|
66
|
-
"""
|
|
67
|
-
|
|
68
|
-
def __init__(self, config: NPLLConfig):
|
|
69
|
-
super().__init__()
|
|
70
|
-
self.config = config
|
|
71
|
-
|
|
72
|
-
# Core components (initialized when knowledge graph is provided)
|
|
73
|
-
self.knowledge_graph: Optional[KnowledgeGraph] = None
|
|
74
|
-
self.mln: Optional[MarkovLogicNetwork] = None
|
|
75
|
-
self.scoring_module: Optional[NPLLScoringModule] = None
|
|
76
|
-
|
|
77
|
-
# Inference components
|
|
78
|
-
self.e_step_runner: Optional[EStepRunner] = None
|
|
79
|
-
self.m_step_runner: Optional[MStepRunner] = None
|
|
80
|
-
self.elbo_computer: Optional[ELBOComputer] = None
|
|
81
|
-
|
|
82
|
-
# Probability transformation
|
|
83
|
-
self.probability_transform: Optional[ProbabilityTransform] = None
|
|
84
|
-
|
|
85
|
-
# Training state
|
|
86
|
-
self.training_state: Optional[NPLLTrainingState] = None
|
|
87
|
-
self.is_initialized = False
|
|
88
|
-
|
|
89
|
-
# Model metadata
|
|
90
|
-
self.model_version = "1.0"
|
|
91
|
-
self.creation_time = time.time()
|
|
92
|
-
self.calibration_version = None
|
|
93
|
-
|
|
94
|
-
def initialize(self,
|
|
95
|
-
knowledge_graph: KnowledgeGraph,
|
|
96
|
-
logical_rules: List[LogicalRule]):
|
|
97
|
-
"""
|
|
98
|
-
Initialize NPLL model with knowledge graph and rules
|
|
99
|
-
|
|
100
|
-
Args:
|
|
101
|
-
knowledge_graph: Knowledge graph K = (E, L, F)
|
|
102
|
-
logical_rules: List of logical rules R
|
|
103
|
-
"""
|
|
104
|
-
logger.info("Initializing NPLL model...")
|
|
105
|
-
|
|
106
|
-
# Store knowledge graph
|
|
107
|
-
self.knowledge_graph = knowledge_graph
|
|
108
|
-
|
|
109
|
-
# Create and initialize MLN
|
|
110
|
-
self.mln = create_mln_from_kg_and_rules(knowledge_graph, logical_rules, self.config)
|
|
111
|
-
|
|
112
|
-
# Create scoring module
|
|
113
|
-
self.scoring_module = create_scoring_module(self.config, knowledge_graph)
|
|
114
|
-
|
|
115
|
-
# Create inference components
|
|
116
|
-
self.e_step_runner = create_e_step_runner(self.config)
|
|
117
|
-
self.m_step_runner = create_m_step_runner(self.config)
|
|
118
|
-
|
|
119
|
-
# Create ELBO computer
|
|
120
|
-
from .inference import create_elbo_computer
|
|
121
|
-
self.elbo_computer = create_elbo_computer(self.config)
|
|
122
|
-
|
|
123
|
-
# Create probability transformation (enable per-relation groups if available)
|
|
124
|
-
num_rel = None
|
|
125
|
-
if self.scoring_module is not None and hasattr(self.scoring_module, 'embedding_manager'):
|
|
126
|
-
emb_mgr = self.scoring_module.embedding_manager
|
|
127
|
-
if hasattr(emb_mgr, 'relation_num_groups'):
|
|
128
|
-
num_rel = emb_mgr.relation_num_groups
|
|
129
|
-
per_relation = num_rel is not None
|
|
130
|
-
prob_transform, _ = create_probability_components(
|
|
131
|
-
self.config.temperature,
|
|
132
|
-
per_relation=per_relation,
|
|
133
|
-
num_relations=(num_rel or 1)
|
|
134
|
-
)
|
|
135
|
-
self.probability_transform = prob_transform
|
|
136
|
-
|
|
137
|
-
# Initialize training state
|
|
138
|
-
self.training_state = NPLLTrainingState(
|
|
139
|
-
epoch=0,
|
|
140
|
-
em_iteration=0,
|
|
141
|
-
elbo_history=[],
|
|
142
|
-
rule_weight_history=[],
|
|
143
|
-
convergence_info={},
|
|
144
|
-
training_time=0.0,
|
|
145
|
-
best_elbo=float('-inf')
|
|
146
|
-
)
|
|
147
|
-
|
|
148
|
-
self.is_initialized = True
|
|
149
|
-
|
|
150
|
-
logger.info(f"NPLL model initialized with {len(logical_rules)} rules, "
|
|
151
|
-
f"{len(knowledge_graph.entities)} entities, "
|
|
152
|
-
f"{len(knowledge_graph.relations)} relations")
|
|
153
|
-
|
|
154
|
-
def forward(self, triples: List[Triple]) -> Dict[str, torch.Tensor]:
|
|
155
|
-
"""
|
|
156
|
-
Forward pass through NPLL model
|
|
157
|
-
|
|
158
|
-
Args:
|
|
159
|
-
triples: List of triples to score
|
|
160
|
-
|
|
161
|
-
Returns:
|
|
162
|
-
Dictionary with scores and probabilities
|
|
163
|
-
"""
|
|
164
|
-
if not self.is_initialized:
|
|
165
|
-
raise RuntimeError("NPLL model not initialized. Call initialize() first.")
|
|
166
|
-
|
|
167
|
-
# Get raw scores from scoring module
|
|
168
|
-
raw_scores = self.scoring_module.forward(triples)
|
|
169
|
-
|
|
170
|
-
# Group IDs per relation (no vocab mutation)
|
|
171
|
-
group_ids = None
|
|
172
|
-
if hasattr(self.scoring_module, 'embedding_manager') and getattr(self.probability_transform, 'per_group', False):
|
|
173
|
-
emb_mgr = self.scoring_module.embedding_manager
|
|
174
|
-
group_ids = emb_mgr.relation_group_ids_for_triples(triples, add_if_missing=False)
|
|
175
|
-
# Ensure transform capacity if table grew
|
|
176
|
-
if hasattr(self.probability_transform, 'ensure_num_groups'):
|
|
177
|
-
self.probability_transform.ensure_num_groups(emb_mgr.relation_num_groups)
|
|
178
|
-
|
|
179
|
-
# Transform to probabilities
|
|
180
|
-
probabilities = self.probability_transform(raw_scores, apply_temperature=True, group_ids=group_ids)
|
|
181
|
-
|
|
182
|
-
# Get log probabilities
|
|
183
|
-
log_probabilities = self.probability_transform.get_log_probabilities(raw_scores, apply_temperature=True, group_ids=group_ids)
|
|
184
|
-
|
|
185
|
-
return {
|
|
186
|
-
'raw_scores': raw_scores,
|
|
187
|
-
'probabilities': probabilities,
|
|
188
|
-
'log_probabilities': log_probabilities
|
|
189
|
-
}
|
|
190
|
-
|
|
191
|
-
def predict_single_triple(self, head: str, relation: str, tail: str, transient: bool = True) -> Dict[str, float]:
|
|
192
|
-
"""
|
|
193
|
-
Predict probability for a single triple
|
|
194
|
-
|
|
195
|
-
Args:
|
|
196
|
-
head: Head entity name
|
|
197
|
-
relation: Relation name
|
|
198
|
-
tail: Tail entity name
|
|
199
|
-
transient: If True, do not mutate the underlying knowledge graph
|
|
200
|
-
|
|
201
|
-
Returns:
|
|
202
|
-
Dictionary with prediction results
|
|
203
|
-
"""
|
|
204
|
-
if not self.is_initialized:
|
|
205
|
-
raise RuntimeError("NPLL model not initialized")
|
|
206
|
-
|
|
207
|
-
# Create triple object without mutating KG by default
|
|
208
|
-
if transient:
|
|
209
|
-
head_entity = Entity(head)
|
|
210
|
-
relation_obj = Relation(relation)
|
|
211
|
-
tail_entity = Entity(tail)
|
|
212
|
-
else:
|
|
213
|
-
head_entity = self.knowledge_graph.get_entity(head) or self.knowledge_graph.add_entity(head)
|
|
214
|
-
relation_obj = self.knowledge_graph.get_relation(relation) or self.knowledge_graph.add_relation(relation)
|
|
215
|
-
tail_entity = self.knowledge_graph.get_entity(tail) or self.knowledge_graph.add_entity(tail)
|
|
216
|
-
|
|
217
|
-
triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
|
|
218
|
-
|
|
219
|
-
# Get predictions
|
|
220
|
-
self.eval()
|
|
221
|
-
with torch.no_grad():
|
|
222
|
-
results = self.forward([triple])
|
|
223
|
-
|
|
224
|
-
return {
|
|
225
|
-
'probability': results['probabilities'][0].item(),
|
|
226
|
-
'log_probability': results['log_probabilities'][0].item(),
|
|
227
|
-
'raw_score': results['raw_scores'][0].item()
|
|
228
|
-
}
|
|
229
|
-
|
|
230
|
-
def run_single_em_iteration(self) -> Dict[str, Any]:
|
|
231
|
-
"""
|
|
232
|
-
Run a single E-M iteration
|
|
233
|
-
|
|
234
|
-
Returns:
|
|
235
|
-
Dictionary with iteration results
|
|
236
|
-
"""
|
|
237
|
-
if not self.is_initialized:
|
|
238
|
-
raise RuntimeError("NPLL model not initialized")
|
|
239
|
-
|
|
240
|
-
iteration_start_time = time.time()
|
|
241
|
-
|
|
242
|
-
logger.debug(f"Starting E-M iteration {self.training_state.em_iteration}")
|
|
243
|
-
|
|
244
|
-
# E-step: Optimize Q(U)
|
|
245
|
-
logger.debug("Running E-step...")
|
|
246
|
-
e_step_result = self.e_step_runner.run_e_step(
|
|
247
|
-
self.mln, self.scoring_module, self.knowledge_graph
|
|
248
|
-
)
|
|
249
|
-
|
|
250
|
-
# M-step: Optimize rule weights ω
|
|
251
|
-
logger.debug("Running M-step...")
|
|
252
|
-
m_step_result = self.m_step_runner.run_m_step(self.mln, e_step_result)
|
|
253
|
-
|
|
254
|
-
# Update training state
|
|
255
|
-
current_elbo = e_step_result.elbo_value.item()
|
|
256
|
-
self.training_state.elbo_history.append(current_elbo)
|
|
257
|
-
|
|
258
|
-
if self.mln.rule_weights is not None:
|
|
259
|
-
current_weights = self.mln.rule_weights.data.tolist()
|
|
260
|
-
self.training_state.rule_weight_history.append(current_weights)
|
|
261
|
-
|
|
262
|
-
# Track best snapshot (MLN + scoring)
|
|
263
|
-
if current_elbo > self.training_state.best_elbo:
|
|
264
|
-
self.training_state.best_elbo = current_elbo
|
|
265
|
-
self.training_state.best_weights = {
|
|
266
|
-
'mln': {k: v.clone() for k, v in self.mln.state_dict().items()},
|
|
267
|
-
'scoring': {k: v.clone() for k, v in self.scoring_module.state_dict().items()} if self.scoring_module else {},
|
|
268
|
-
}
|
|
269
|
-
|
|
270
|
-
self.training_state.em_iteration += 1
|
|
271
|
-
iteration_time = time.time() - iteration_start_time
|
|
272
|
-
self.training_state.training_time += iteration_time
|
|
273
|
-
|
|
274
|
-
# Check convergence
|
|
275
|
-
converged = self._check_em_convergence()
|
|
276
|
-
|
|
277
|
-
iteration_result = {
|
|
278
|
-
'em_iteration': self.training_state.em_iteration - 1,
|
|
279
|
-
'e_step_result': e_step_result,
|
|
280
|
-
'm_step_result': m_step_result,
|
|
281
|
-
'elbo': current_elbo,
|
|
282
|
-
'iteration_time': iteration_time,
|
|
283
|
-
'converged': converged,
|
|
284
|
-
'convergence_info': {
|
|
285
|
-
'e_step_converged': e_step_result.convergence_info.get('converged', False),
|
|
286
|
-
'm_step_converged': m_step_result.convergence_info.get('converged', False)
|
|
287
|
-
}
|
|
288
|
-
}
|
|
289
|
-
|
|
290
|
-
logger.debug(f"E-M iteration completed: ELBO={current_elbo:.6f}, "
|
|
291
|
-
f"Time={iteration_time:.2f}s, Converged={converged}")
|
|
292
|
-
|
|
293
|
-
return iteration_result
|
|
294
|
-
|
|
295
|
-
def train_epoch(self, max_em_iterations: Optional[int] = None) -> Dict[str, Any]:
|
|
296
|
-
"""
|
|
297
|
-
Train for one epoch (multiple E-M iterations until convergence)
|
|
298
|
-
|
|
299
|
-
Args:
|
|
300
|
-
max_em_iterations: Maximum E-M iterations per epoch
|
|
301
|
-
|
|
302
|
-
Returns:
|
|
303
|
-
Dictionary with epoch results
|
|
304
|
-
"""
|
|
305
|
-
if not self.is_initialized:
|
|
306
|
-
raise RuntimeError("NPLL model not initialized")
|
|
307
|
-
|
|
308
|
-
max_iterations = max_em_iterations or self.config.em_iterations
|
|
309
|
-
epoch_start_time = time.time()
|
|
310
|
-
|
|
311
|
-
logger.info(f"Starting training epoch {self.training_state.epoch}")
|
|
312
|
-
|
|
313
|
-
epoch_results = []
|
|
314
|
-
converged = False
|
|
315
|
-
|
|
316
|
-
for em_iter in range(max_iterations):
|
|
317
|
-
iteration_result = self.run_single_em_iteration()
|
|
318
|
-
epoch_results.append(iteration_result)
|
|
319
|
-
|
|
320
|
-
if iteration_result['converged']:
|
|
321
|
-
converged = True
|
|
322
|
-
logger.info(f"Converged after {em_iter + 1} E-M iterations")
|
|
323
|
-
break
|
|
324
|
-
|
|
325
|
-
epoch_time = time.time() - epoch_start_time
|
|
326
|
-
self.training_state.epoch += 1
|
|
327
|
-
|
|
328
|
-
epoch_summary = {
|
|
329
|
-
'epoch': self.training_state.epoch - 1,
|
|
330
|
-
'em_iterations': len(epoch_results),
|
|
331
|
-
'converged': converged,
|
|
332
|
-
'final_elbo': epoch_results[-1]['elbo'] if epoch_results else float('-inf'),
|
|
333
|
-
'best_elbo_this_epoch': max(r['elbo'] for r in epoch_results) if epoch_results else float('-inf'),
|
|
334
|
-
'epoch_time': epoch_time,
|
|
335
|
-
'iteration_results': epoch_results
|
|
336
|
-
}
|
|
337
|
-
|
|
338
|
-
logger.info(f"Epoch {self.training_state.epoch - 1} completed: "
|
|
339
|
-
f"ELBO={epoch_summary['final_elbo']:.6f}, "
|
|
340
|
-
f"EM iterations={epoch_summary['em_iterations']}, "
|
|
341
|
-
f"Time={epoch_time:.2f}s")
|
|
342
|
-
|
|
343
|
-
return epoch_summary
|
|
344
|
-
|
|
345
|
-
def _check_em_convergence(self) -> bool:
|
|
346
|
-
"""Check if E-M algorithm has converged with patience and relative tolerance"""
|
|
347
|
-
if len(self.training_state.elbo_history) < 2:
|
|
348
|
-
return False
|
|
349
|
-
h = self.training_state.elbo_history
|
|
350
|
-
rel = abs(h[-1] - h[-2]) / (abs(h[-2]) + 1e-8)
|
|
351
|
-
elbo_ok = rel < getattr(self.config, 'elbo_rel_tol', self.config.convergence_threshold)
|
|
352
|
-
|
|
353
|
-
weight_ok = True
|
|
354
|
-
if len(self.training_state.rule_weight_history) >= 2:
|
|
355
|
-
current_weights = torch.tensor(self.training_state.rule_weight_history[-1])
|
|
356
|
-
prev_weights = torch.tensor(self.training_state.rule_weight_history[-2])
|
|
357
|
-
weight_change = torch.norm(current_weights - prev_weights).item()
|
|
358
|
-
weight_ok = weight_change < getattr(self.config, 'weight_abs_tol', self.config.convergence_threshold)
|
|
359
|
-
|
|
360
|
-
if elbo_ok and weight_ok:
|
|
361
|
-
hits = self.training_state.convergence_info.get('hits', 0) + 1
|
|
362
|
-
self.training_state.convergence_info['hits'] = hits
|
|
363
|
-
else:
|
|
364
|
-
self.training_state.convergence_info['hits'] = 0
|
|
365
|
-
patience = getattr(self.config, 'convergence_patience', 3)
|
|
366
|
-
return self.training_state.convergence_info.get('hits', 0) >= patience
|
|
367
|
-
|
|
368
|
-
def get_rule_confidences(self) -> Dict[str, float]:
|
|
369
|
-
"""Get learned confidence scores for all rules"""
|
|
370
|
-
if not self.is_initialized or self.mln is None or self.mln.rule_weights is None:
|
|
371
|
-
return {}
|
|
372
|
-
|
|
373
|
-
confidences = {}
|
|
374
|
-
for i, rule in enumerate(self.mln.logical_rules):
|
|
375
|
-
# Convert rule weight to confidence using sigmoid
|
|
376
|
-
weight = self.mln.rule_weights[i].item()
|
|
377
|
-
confidence = torch.sigmoid(torch.tensor(weight)).item()
|
|
378
|
-
confidences[rule.rule_id] = confidence
|
|
379
|
-
|
|
380
|
-
return confidences
|
|
381
|
-
|
|
382
|
-
def predict_unknown_facts(self, top_k: int = 10) -> List[Tuple[Triple, float]]:
|
|
383
|
-
"""
|
|
384
|
-
Predict probabilities for unknown facts
|
|
385
|
-
|
|
386
|
-
Args:
|
|
387
|
-
top_k: Number of top predictions to return
|
|
388
|
-
|
|
389
|
-
Returns:
|
|
390
|
-
List of (triple, probability) tuples sorted by probability
|
|
391
|
-
"""
|
|
392
|
-
if not self.is_initialized:
|
|
393
|
-
raise RuntimeError("NPLL model not initialized")
|
|
394
|
-
|
|
395
|
-
unknown_facts = list(self.knowledge_graph.unknown_facts)
|
|
396
|
-
if not unknown_facts:
|
|
397
|
-
return []
|
|
398
|
-
|
|
399
|
-
# Get predictions for unknown facts
|
|
400
|
-
results = self.forward(unknown_facts)
|
|
401
|
-
probabilities = results['probabilities']
|
|
402
|
-
|
|
403
|
-
# Create (triple, probability) pairs
|
|
404
|
-
predictions = list(zip(unknown_facts, probabilities.tolist()))
|
|
405
|
-
|
|
406
|
-
# Sort by probability (descending)
|
|
407
|
-
predictions.sort(key=lambda x: x[1], reverse=True)
|
|
408
|
-
|
|
409
|
-
return predictions[:top_k]
|
|
410
|
-
|
|
411
|
-
def save_model(self, filepath: str):
|
|
412
|
-
"""Save complete NPLL model state with KG and rules serialization"""
|
|
413
|
-
if not self.is_initialized:
|
|
414
|
-
raise RuntimeError("Cannot save uninitialized model")
|
|
415
|
-
|
|
416
|
-
rules = self.mln.logical_rules if self.mln else []
|
|
417
|
-
payload = {
|
|
418
|
-
'schema_version': self.model_version,
|
|
419
|
-
'config': self.config.__dict__,
|
|
420
|
-
'creation_time': self.creation_time,
|
|
421
|
-
'training_state': self.training_state,
|
|
422
|
-
'mln_state': self.mln.state_dict() if self.mln else None,
|
|
423
|
-
'scoring_state': self.scoring_module.state_dict() if self.scoring_module else None,
|
|
424
|
-
'prob_state': self.probability_transform.state_dict() if hasattr(self.probability_transform, 'state_dict') else None,
|
|
425
|
-
'knowledge_graph': self.knowledge_graph.serialize() if self.knowledge_graph else None,
|
|
426
|
-
'rules': [rule.serialize() for rule in rules],
|
|
427
|
-
}
|
|
428
|
-
torch.save(payload, filepath)
|
|
429
|
-
logger.info(f"NPLL model saved to {filepath}")
|
|
430
|
-
|
|
431
|
-
def load_model(self, filepath: str):
|
|
432
|
-
"""
|
|
433
|
-
Load complete NPLL model state, rebuilding KG and rules before weights.
|
|
434
|
-
|
|
435
|
-
PRODUCTION FIX: Respects device configuration (cpu/cuda) instead of hardcoding cpu.
|
|
436
|
-
"""
|
|
437
|
-
if not os.path.exists(filepath):
|
|
438
|
-
raise FileNotFoundError(f"Model file not found: {filepath}")
|
|
439
|
-
|
|
440
|
-
from .core import KnowledgeGraph, LogicalRule
|
|
441
|
-
|
|
442
|
-
# Respect device config: load to configured device (cpu or cuda)
|
|
443
|
-
# Note: weights_only=False required for custom objects (NPLLTrainingState, etc.)
|
|
444
|
-
device = getattr(self, 'device', 'cpu')
|
|
445
|
-
payload = torch.load(filepath, map_location=device, weights_only=False)
|
|
446
|
-
|
|
447
|
-
# Rebuild config
|
|
448
|
-
self.config = NPLLConfig(**payload['config']) if isinstance(payload.get('config'), dict) else payload['config']
|
|
449
|
-
# Rebuild KG and rules
|
|
450
|
-
kg = KnowledgeGraph.deserialize(payload['knowledge_graph']) if payload.get('knowledge_graph') else None
|
|
451
|
-
rules = [LogicalRule.deserialize(x) for x in payload.get('rules', [])]
|
|
452
|
-
# Recreate components
|
|
453
|
-
self.__init__(self.config)
|
|
454
|
-
if kg is not None:
|
|
455
|
-
self.initialize(kg, rules)
|
|
456
|
-
# Load weights
|
|
457
|
-
if payload.get('mln_state') and self.mln:
|
|
458
|
-
self.mln.load_state_dict(payload['mln_state'])
|
|
459
|
-
if payload.get('scoring_state') and self.scoring_module:
|
|
460
|
-
self.scoring_module.load_state_dict(payload['scoring_state'])
|
|
461
|
-
if payload.get('prob_state') and hasattr(self.probability_transform, 'load_state_dict'):
|
|
462
|
-
self.probability_transform.load_state_dict(payload['prob_state'])
|
|
463
|
-
# Metadata
|
|
464
|
-
self.model_version = payload.get('schema_version', '1.0')
|
|
465
|
-
self.creation_time = payload.get('creation_time', time.time())
|
|
466
|
-
self.training_state = payload.get('training_state', self.training_state)
|
|
467
|
-
logger.info(f"NPLL model loaded from {filepath}")
|
|
468
|
-
|
|
469
|
-
def save_to_buffer(self, buffer):
|
|
470
|
-
"""Save complete NPLL model state to a BytesIO buffer (for database storage)."""
|
|
471
|
-
if not self.is_initialized:
|
|
472
|
-
raise RuntimeError("Cannot save uninitialized model")
|
|
473
|
-
|
|
474
|
-
rules = self.mln.logical_rules if self.mln else []
|
|
475
|
-
payload = {
|
|
476
|
-
'schema_version': self.model_version,
|
|
477
|
-
'config': self.config.__dict__,
|
|
478
|
-
'creation_time': self.creation_time,
|
|
479
|
-
'training_state': self.training_state,
|
|
480
|
-
'mln_state': self.mln.state_dict() if self.mln else None,
|
|
481
|
-
'scoring_state': self.scoring_module.state_dict() if self.scoring_module else None,
|
|
482
|
-
'prob_state': self.probability_transform.state_dict() if hasattr(self.probability_transform, 'state_dict') else None,
|
|
483
|
-
'knowledge_graph': self.knowledge_graph.serialize() if self.knowledge_graph else None,
|
|
484
|
-
'rules': [rule.serialize() for rule in rules],
|
|
485
|
-
}
|
|
486
|
-
torch.save(payload, buffer)
|
|
487
|
-
logger.info("NPLL model saved to buffer")
|
|
488
|
-
|
|
489
|
-
def load_from_buffer(self, buffer):
|
|
490
|
-
"""Load complete NPLL model state from a BytesIO buffer (for database storage)."""
|
|
491
|
-
from .core import KnowledgeGraph, LogicalRule
|
|
492
|
-
|
|
493
|
-
# Respect device config
|
|
494
|
-
device = getattr(self, 'device', 'cpu')
|
|
495
|
-
buffer.seek(0)
|
|
496
|
-
payload = torch.load(buffer, map_location=device, weights_only=False)
|
|
497
|
-
|
|
498
|
-
# Rebuild config
|
|
499
|
-
self.config = NPLLConfig(**payload['config']) if isinstance(payload.get('config'), dict) else payload['config']
|
|
500
|
-
# Rebuild KG and rules
|
|
501
|
-
kg = KnowledgeGraph.deserialize(payload['knowledge_graph']) if payload.get('knowledge_graph') else None
|
|
502
|
-
rules = [LogicalRule.deserialize(x) for x in payload.get('rules', [])]
|
|
503
|
-
# Recreate components
|
|
504
|
-
self.__init__(self.config)
|
|
505
|
-
if kg is not None:
|
|
506
|
-
self.initialize(kg, rules)
|
|
507
|
-
# Load weights
|
|
508
|
-
if payload.get('mln_state') and self.mln:
|
|
509
|
-
self.mln.load_state_dict(payload['mln_state'])
|
|
510
|
-
if payload.get('scoring_state') and self.scoring_module:
|
|
511
|
-
self.scoring_module.load_state_dict(payload['scoring_state'])
|
|
512
|
-
if payload.get('prob_state') and hasattr(self.probability_transform, 'load_state_dict'):
|
|
513
|
-
self.probability_transform.load_state_dict(payload['prob_state'])
|
|
514
|
-
# Metadata
|
|
515
|
-
self.model_version = payload.get('schema_version', '1.0')
|
|
516
|
-
self.creation_time = payload.get('creation_time', time.time())
|
|
517
|
-
self.training_state = payload.get('training_state', self.training_state)
|
|
518
|
-
logger.info("NPLL model loaded from buffer")
|
|
519
|
-
|
|
520
|
-
def _get_device(self) -> str:
|
|
521
|
-
for p in self.parameters():
|
|
522
|
-
return str(p.device)
|
|
523
|
-
return 'cpu'
|
|
524
|
-
|
|
525
|
-
def get_model_summary(self) -> Dict[str, Any]:
|
|
526
|
-
"""Get comprehensive model summary (device-safe)"""
|
|
527
|
-
summary = {
|
|
528
|
-
'model_version': self.model_version,
|
|
529
|
-
'is_initialized': self.is_initialized,
|
|
530
|
-
'config': self.config.__dict__ if self.config else {},
|
|
531
|
-
'creation_time': self.creation_time,
|
|
532
|
-
'calibration_version': self.calibration_version,
|
|
533
|
-
'training_state': self.training_state.__dict__ if self.training_state else {},
|
|
534
|
-
'num_parameters': sum(p.numel() for p in self.parameters()),
|
|
535
|
-
'device': self._get_device(),
|
|
536
|
-
}
|
|
537
|
-
if self.is_initialized:
|
|
538
|
-
summary.update({
|
|
539
|
-
'knowledge_graph_stats': self.knowledge_graph.get_statistics() if self.knowledge_graph else {},
|
|
540
|
-
'mln_stats': self.mln.get_rule_statistics() if self.mln else {},
|
|
541
|
-
'rule_confidences': self.get_rule_confidences(),
|
|
542
|
-
})
|
|
543
|
-
return summary
|
|
544
|
-
|
|
545
|
-
# --- Utilities: checkpoints and calibration ---
|
|
546
|
-
def restore_best(self) -> bool:
|
|
547
|
-
"""Restore the best-scoring MLN and scoring-module weights if available."""
|
|
548
|
-
if not self.training_state or not getattr(self.training_state, 'best_weights', None):
|
|
549
|
-
return False
|
|
550
|
-
best = self.training_state.best_weights
|
|
551
|
-
if 'mln' in best and self.mln:
|
|
552
|
-
self.mln.load_state_dict(best['mln'])
|
|
553
|
-
if 'scoring' in best and self.scoring_module:
|
|
554
|
-
self.scoring_module.load_state_dict(best['scoring'])
|
|
555
|
-
return True
|
|
556
|
-
|
|
557
|
-
def calibrate_temperature_on_data(self,
|
|
558
|
-
triples: List[Triple],
|
|
559
|
-
labels: torch.Tensor,
|
|
560
|
-
max_iter: int = 100,
|
|
561
|
-
version: Optional[str] = None) -> float:
|
|
562
|
-
"""
|
|
563
|
-
Calibrate the ProbabilityTransform temperature on a holdout set.
|
|
564
|
-
Stores the learned temperature in the transform and records a calibration version.
|
|
565
|
-
Returns the optimized temperature value.
|
|
566
|
-
"""
|
|
567
|
-
if not self.is_initialized:
|
|
568
|
-
raise RuntimeError("NPLL model not initialized")
|
|
569
|
-
self.eval()
|
|
570
|
-
with torch.no_grad():
|
|
571
|
-
scores = self.scoring_module.forward(triples)
|
|
572
|
-
|
|
573
|
-
# Extract group_ids for calibration if per_group is enabled
|
|
574
|
-
group_ids = None
|
|
575
|
-
if hasattr(self.scoring_module, 'embedding_manager') and getattr(self.probability_transform, 'per_group', False):
|
|
576
|
-
emb_mgr = self.scoring_module.embedding_manager
|
|
577
|
-
group_ids = emb_mgr.relation_group_ids_for_triples(triples, add_if_missing=False)
|
|
578
|
-
if hasattr(self.probability_transform, 'ensure_num_groups'):
|
|
579
|
-
self.probability_transform.ensure_num_groups(emb_mgr.relation_num_groups)
|
|
580
|
-
|
|
581
|
-
optimized_temp = self.probability_transform.calibrate_temperature(scores, labels.float(), max_iter=max_iter, group_ids=group_ids)
|
|
582
|
-
self.calibration_version = version or f"temp@{time.time():.0f}"
|
|
583
|
-
return optimized_temp
|
|
584
|
-
|
|
585
|
-
def __str__(self) -> str:
|
|
586
|
-
if not self.is_initialized:
|
|
587
|
-
return "NPLL Model (not initialized)"
|
|
588
|
-
|
|
589
|
-
stats = self.knowledge_graph.get_statistics() if self.knowledge_graph else {}
|
|
590
|
-
return (f"NPLL Model:\n"
|
|
591
|
-
f" Entities: {stats.get('num_entities', 0)}\n"
|
|
592
|
-
f" Relations: {stats.get('num_relations', 0)}\n"
|
|
593
|
-
f" Known Facts: {stats.get('num_known_facts', 0)}\n"
|
|
594
|
-
f" Unknown Facts: {stats.get('num_unknown_facts', 0)}\n"
|
|
595
|
-
f" Rules: {len(self.mln.logical_rules) if self.mln else 0}\n"
|
|
596
|
-
f" Training State: {self.training_state}")
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
def create_npll_model(config: NPLLConfig) -> NPLLModel:
|
|
600
|
-
"""
|
|
601
|
-
Factory function to create NPLL model
|
|
602
|
-
|
|
603
|
-
Args:
|
|
604
|
-
config: NPLL configuration
|
|
605
|
-
|
|
606
|
-
Returns:
|
|
607
|
-
Uninitialized NPLL model
|
|
608
|
-
"""
|
|
609
|
-
return NPLLModel(config)
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
def create_initialized_npll_model(knowledge_graph: KnowledgeGraph,
|
|
613
|
-
logical_rules: List[LogicalRule],
|
|
614
|
-
config: Optional[NPLLConfig] = None) -> NPLLModel:
|
|
615
|
-
"""
|
|
616
|
-
Factory function to create and initialize NPLL model
|
|
617
|
-
|
|
618
|
-
Args:
|
|
619
|
-
knowledge_graph: Knowledge graph
|
|
620
|
-
logical_rules: List of logical rules
|
|
621
|
-
config: Optional configuration (uses default if not provided)
|
|
622
|
-
|
|
623
|
-
Returns:
|
|
624
|
-
Initialized NPLL model ready for training
|
|
625
|
-
"""
|
|
626
|
-
if config is None:
|
|
627
|
-
config = get_config("ArangoDB_Triples") # Default configuration
|
|
628
|
-
|
|
629
|
-
model = NPLLModel(config)
|
|
630
|
-
model.initialize(knowledge_graph, logical_rules)
|
|
631
|
-
|
|
1
|
+
"""
|
|
2
|
+
Main NPLL Model - Neural Probabilistic Logic Learning
|
|
3
|
+
Integrates all components: MLN, Scoring Module, E-step, M-step
|
|
4
|
+
Exact implementation of the complete NPLL framework from the paper
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
from typing import List, Dict, Set, Tuple, Optional, Any, Union
|
|
10
|
+
import logging
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
import time
|
|
13
|
+
import os
|
|
14
|
+
|
|
15
|
+
from .core import (
|
|
16
|
+
KnowledgeGraph, Entity, Relation, Triple, LogicalRule, GroundRule,
|
|
17
|
+
MarkovLogicNetwork, create_mln_from_kg_and_rules
|
|
18
|
+
)
|
|
19
|
+
from .scoring import (
|
|
20
|
+
NPLLScoringModule, create_scoring_module,
|
|
21
|
+
ProbabilityTransform, create_probability_components
|
|
22
|
+
)
|
|
23
|
+
from .inference import (
|
|
24
|
+
EStepRunner, MStepRunner, ELBOComputer, EStepResult, MStepResult,
|
|
25
|
+
create_e_step_runner, create_m_step_runner
|
|
26
|
+
)
|
|
27
|
+
from .utils import NPLLConfig, get_config
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class NPLLTrainingState:
    """
    Training state for NPLL model.

    Tracks progress through E-M iterations: ELBO trajectory, rule-weight
    trajectory, convergence bookkeeping, and the best snapshot seen so far.
    """
    epoch: int                                  # completed training epochs
    em_iteration: int                           # completed E-M iterations (global counter)
    elbo_history: List[float]                   # ELBO value recorded after each E-step
    rule_weight_history: List[List[float]]      # MLN rule weights recorded after each iteration
    convergence_info: Dict[str, Any]            # scratch space for convergence checks (e.g. patience 'hits')
    training_time: float                        # cumulative wall-clock training time in seconds
    best_elbo: float                            # highest ELBO observed so far
    # FIX: this field holds the best snapshot written by run_single_em_iteration,
    # which is a dict of the form {'mln': state_dict, 'scoring': state_dict}
    # (each state_dict maps parameter names to tensors) — not a single tensor.
    best_weights: Optional[Dict[str, Dict[str, torch.Tensor]]] = None

    def __str__(self) -> str:
        """Human-readable multi-line progress summary for logging."""
        return (f"NPLL Training State:\n"
                f" Epoch: {self.epoch}\n"
                f" EM Iteration: {self.em_iteration}\n"
                f" Current ELBO: {self.elbo_history[-1] if self.elbo_history else 'N/A'}\n"
                f" Best ELBO: {self.best_elbo:.6f}\n"
                f" Training Time: {self.training_time:.2f}s")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class NPLLModel(nn.Module):
|
|
57
|
+
"""
|
|
58
|
+
Complete Neural Probabilistic Logic Learning Model
|
|
59
|
+
|
|
60
|
+
Integrates all paper components:
|
|
61
|
+
- Knowledge Graph representation (Section 3)
|
|
62
|
+
- Scoring Module with bilinear function (Section 4.1, Equation 7)
|
|
63
|
+
- Markov Logic Network (Sections 3-4, Equations 1-2)
|
|
64
|
+
- E-M Algorithm (Sections 4.2-4.3, Equations 8-14)
|
|
65
|
+
- ELBO optimization (Equations 3-5)
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
    def __init__(self, config: NPLLConfig):
        """
        Create an empty NPLL model shell.

        Only the configuration and metadata are set here; every heavy
        component stays None until initialize() is called with a knowledge
        graph and a rule set.

        Args:
            config: NPLL configuration shared by all sub-components
        """
        super().__init__()
        self.config = config

        # Core components (initialized when knowledge graph is provided)
        self.knowledge_graph: Optional[KnowledgeGraph] = None
        self.mln: Optional[MarkovLogicNetwork] = None
        self.scoring_module: Optional[NPLLScoringModule] = None

        # Inference components
        self.e_step_runner: Optional[EStepRunner] = None
        self.m_step_runner: Optional[MStepRunner] = None
        self.elbo_computer: Optional[ELBOComputer] = None

        # Probability transformation
        self.probability_transform: Optional[ProbabilityTransform] = None

        # Training state
        self.training_state: Optional[NPLLTrainingState] = None
        # Guard flag checked by every public entry point (forward, training, save).
        self.is_initialized: bool = False

        # Model metadata
        self.model_version = "1.0"
        self.creation_time = time.time()
        # Set by calibrate_temperature_on_data() after temperature calibration.
        self.calibration_version: Optional[str] = None
|
|
93
|
+
|
|
94
|
+
    def initialize(self,
                   knowledge_graph: KnowledgeGraph,
                   logical_rules: List[LogicalRule]):
        """
        Initialize NPLL model with knowledge graph and rules.

        Builds the MLN, scoring module, E/M-step runners, ELBO computer,
        probability transform and a fresh training state, then marks the
        model as initialized.

        Args:
            knowledge_graph: Knowledge graph K = (E, L, F)
            logical_rules: List of logical rules R
        """
        logger.info("Initializing NPLL model...")

        # Store knowledge graph
        self.knowledge_graph = knowledge_graph

        # Create and initialize MLN
        self.mln = create_mln_from_kg_and_rules(knowledge_graph, logical_rules, self.config)

        # Create scoring module
        self.scoring_module = create_scoring_module(self.config, knowledge_graph)

        # Create inference components
        self.e_step_runner = create_e_step_runner(self.config)
        self.m_step_runner = create_m_step_runner(self.config)

        # Create ELBO computer (imported locally, mirroring the other factories)
        from .inference import create_elbo_computer
        self.elbo_computer = create_elbo_computer(self.config)

        # Create probability transformation (enable per-relation groups if available)
        # Per-relation calibration is only turned on when the embedding manager
        # exposes a relation-group table; otherwise a single global group is used.
        num_rel = None
        if self.scoring_module is not None and hasattr(self.scoring_module, 'embedding_manager'):
            emb_mgr = self.scoring_module.embedding_manager
            if hasattr(emb_mgr, 'relation_num_groups'):
                num_rel = emb_mgr.relation_num_groups
        per_relation = num_rel is not None
        prob_transform, _ = create_probability_components(
            self.config.temperature,
            per_relation=per_relation,
            num_relations=(num_rel or 1)
        )
        self.probability_transform = prob_transform

        # Initialize training state (best_elbo starts at -inf so the first
        # iteration always records a best snapshot)
        self.training_state = NPLLTrainingState(
            epoch=0,
            em_iteration=0,
            elbo_history=[],
            rule_weight_history=[],
            convergence_info={},
            training_time=0.0,
            best_elbo=float('-inf')
        )

        self.is_initialized = True

        logger.info(f"NPLL model initialized with {len(logical_rules)} rules, "
                    f"{len(knowledge_graph.entities)} entities, "
                    f"{len(knowledge_graph.relations)} relations")
|
|
153
|
+
|
|
154
|
+
    def forward(self, triples: List[Triple]) -> Dict[str, torch.Tensor]:
        """
        Forward pass through NPLL model.

        Scores the given triples with the scoring module and maps the raw
        scores to (log-)probabilities via the temperature transform.

        Args:
            triples: List of triples to score

        Returns:
            Dictionary with 'raw_scores', 'probabilities' and
            'log_probabilities' tensors (one entry per input triple)

        Raises:
            RuntimeError: if initialize() has not been called yet
        """
        if not self.is_initialized:
            raise RuntimeError("NPLL model not initialized. Call initialize() first.")

        # Get raw scores from scoring module
        raw_scores = self.scoring_module.forward(triples)

        # Group IDs per relation (no vocab mutation: add_if_missing=False keeps
        # unseen relations out of the group table)
        group_ids = None
        if hasattr(self.scoring_module, 'embedding_manager') and getattr(self.probability_transform, 'per_group', False):
            emb_mgr = self.scoring_module.embedding_manager
            group_ids = emb_mgr.relation_group_ids_for_triples(triples, add_if_missing=False)
            # Ensure transform capacity if table grew
            if hasattr(self.probability_transform, 'ensure_num_groups'):
                self.probability_transform.ensure_num_groups(emb_mgr.relation_num_groups)

        # Transform to probabilities
        probabilities = self.probability_transform(raw_scores, apply_temperature=True, group_ids=group_ids)

        # Get log probabilities (computed from raw scores, not from log(prob),
        # so the transform can use a numerically stable path)
        log_probabilities = self.probability_transform.get_log_probabilities(raw_scores, apply_temperature=True, group_ids=group_ids)

        return {
            'raw_scores': raw_scores,
            'probabilities': probabilities,
            'log_probabilities': log_probabilities
        }
|
|
190
|
+
|
|
191
|
+
def predict_single_triple(self, head: str, relation: str, tail: str, transient: bool = True) -> Dict[str, float]:
|
|
192
|
+
"""
|
|
193
|
+
Predict probability for a single triple
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
head: Head entity name
|
|
197
|
+
relation: Relation name
|
|
198
|
+
tail: Tail entity name
|
|
199
|
+
transient: If True, do not mutate the underlying knowledge graph
|
|
200
|
+
|
|
201
|
+
Returns:
|
|
202
|
+
Dictionary with prediction results
|
|
203
|
+
"""
|
|
204
|
+
if not self.is_initialized:
|
|
205
|
+
raise RuntimeError("NPLL model not initialized")
|
|
206
|
+
|
|
207
|
+
# Create triple object without mutating KG by default
|
|
208
|
+
if transient:
|
|
209
|
+
head_entity = Entity(head)
|
|
210
|
+
relation_obj = Relation(relation)
|
|
211
|
+
tail_entity = Entity(tail)
|
|
212
|
+
else:
|
|
213
|
+
head_entity = self.knowledge_graph.get_entity(head) or self.knowledge_graph.add_entity(head)
|
|
214
|
+
relation_obj = self.knowledge_graph.get_relation(relation) or self.knowledge_graph.add_relation(relation)
|
|
215
|
+
tail_entity = self.knowledge_graph.get_entity(tail) or self.knowledge_graph.add_entity(tail)
|
|
216
|
+
|
|
217
|
+
triple = Triple(head=head_entity, relation=relation_obj, tail=tail_entity)
|
|
218
|
+
|
|
219
|
+
# Get predictions
|
|
220
|
+
self.eval()
|
|
221
|
+
with torch.no_grad():
|
|
222
|
+
results = self.forward([triple])
|
|
223
|
+
|
|
224
|
+
return {
|
|
225
|
+
'probability': results['probabilities'][0].item(),
|
|
226
|
+
'log_probability': results['log_probabilities'][0].item(),
|
|
227
|
+
'raw_score': results['raw_scores'][0].item()
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
    def run_single_em_iteration(self) -> Dict[str, Any]:
        """
        Run a single E-M iteration.

        Performs one E-step (variational posterior update), one M-step
        (rule-weight update), records ELBO/weight history, snapshots the
        best-so-far parameters, and checks convergence.

        Returns:
            Dictionary with iteration results (step results, ELBO, timing,
            convergence flags)

        Raises:
            RuntimeError: if initialize() has not been called yet
        """
        if not self.is_initialized:
            raise RuntimeError("NPLL model not initialized")

        iteration_start_time = time.time()

        logger.debug(f"Starting E-M iteration {self.training_state.em_iteration}")

        # E-step: Optimize Q(U)
        logger.debug("Running E-step...")
        e_step_result = self.e_step_runner.run_e_step(
            self.mln, self.scoring_module, self.knowledge_graph
        )

        # M-step: Optimize rule weights ω (consumes the E-step posterior)
        logger.debug("Running M-step...")
        m_step_result = self.m_step_runner.run_m_step(self.mln, e_step_result)

        # Update training state
        current_elbo = e_step_result.elbo_value.item()
        self.training_state.elbo_history.append(current_elbo)

        if self.mln.rule_weights is not None:
            current_weights = self.mln.rule_weights.data.tolist()
            self.training_state.rule_weight_history.append(current_weights)

        # Track best snapshot (MLN + scoring); tensors are cloned so later
        # training steps cannot mutate the stored snapshot in place.
        if current_elbo > self.training_state.best_elbo:
            self.training_state.best_elbo = current_elbo
            self.training_state.best_weights = {
                'mln': {k: v.clone() for k, v in self.mln.state_dict().items()},
                'scoring': {k: v.clone() for k, v in self.scoring_module.state_dict().items()} if self.scoring_module else {},
            }

        self.training_state.em_iteration += 1
        iteration_time = time.time() - iteration_start_time
        self.training_state.training_time += iteration_time

        # Check convergence (must run AFTER histories were appended above)
        converged = self._check_em_convergence()

        iteration_result = {
            # counter was already incremented, so report the index just run
            'em_iteration': self.training_state.em_iteration - 1,
            'e_step_result': e_step_result,
            'm_step_result': m_step_result,
            'elbo': current_elbo,
            'iteration_time': iteration_time,
            'converged': converged,
            'convergence_info': {
                'e_step_converged': e_step_result.convergence_info.get('converged', False),
                'm_step_converged': m_step_result.convergence_info.get('converged', False)
            }
        }

        logger.debug(f"E-M iteration completed: ELBO={current_elbo:.6f}, "
                     f"Time={iteration_time:.2f}s, Converged={converged}")

        return iteration_result
|
|
294
|
+
|
|
295
|
+
def train_epoch(self, max_em_iterations: Optional[int] = None) -> Dict[str, Any]:
|
|
296
|
+
"""
|
|
297
|
+
Train for one epoch (multiple E-M iterations until convergence)
|
|
298
|
+
|
|
299
|
+
Args:
|
|
300
|
+
max_em_iterations: Maximum E-M iterations per epoch
|
|
301
|
+
|
|
302
|
+
Returns:
|
|
303
|
+
Dictionary with epoch results
|
|
304
|
+
"""
|
|
305
|
+
if not self.is_initialized:
|
|
306
|
+
raise RuntimeError("NPLL model not initialized")
|
|
307
|
+
|
|
308
|
+
max_iterations = max_em_iterations or self.config.em_iterations
|
|
309
|
+
epoch_start_time = time.time()
|
|
310
|
+
|
|
311
|
+
logger.info(f"Starting training epoch {self.training_state.epoch}")
|
|
312
|
+
|
|
313
|
+
epoch_results = []
|
|
314
|
+
converged = False
|
|
315
|
+
|
|
316
|
+
for em_iter in range(max_iterations):
|
|
317
|
+
iteration_result = self.run_single_em_iteration()
|
|
318
|
+
epoch_results.append(iteration_result)
|
|
319
|
+
|
|
320
|
+
if iteration_result['converged']:
|
|
321
|
+
converged = True
|
|
322
|
+
logger.info(f"Converged after {em_iter + 1} E-M iterations")
|
|
323
|
+
break
|
|
324
|
+
|
|
325
|
+
epoch_time = time.time() - epoch_start_time
|
|
326
|
+
self.training_state.epoch += 1
|
|
327
|
+
|
|
328
|
+
epoch_summary = {
|
|
329
|
+
'epoch': self.training_state.epoch - 1,
|
|
330
|
+
'em_iterations': len(epoch_results),
|
|
331
|
+
'converged': converged,
|
|
332
|
+
'final_elbo': epoch_results[-1]['elbo'] if epoch_results else float('-inf'),
|
|
333
|
+
'best_elbo_this_epoch': max(r['elbo'] for r in epoch_results) if epoch_results else float('-inf'),
|
|
334
|
+
'epoch_time': epoch_time,
|
|
335
|
+
'iteration_results': epoch_results
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
logger.info(f"Epoch {self.training_state.epoch - 1} completed: "
|
|
339
|
+
f"ELBO={epoch_summary['final_elbo']:.6f}, "
|
|
340
|
+
f"EM iterations={epoch_summary['em_iterations']}, "
|
|
341
|
+
f"Time={epoch_time:.2f}s")
|
|
342
|
+
|
|
343
|
+
return epoch_summary
|
|
344
|
+
|
|
345
|
+
    def _check_em_convergence(self) -> bool:
        """
        Check if E-M algorithm has converged with patience and relative tolerance.

        Convergence requires BOTH a small relative ELBO change and a small
        absolute rule-weight change, sustained for `convergence_patience`
        consecutive iterations (the 'hits' counter resets on any miss).
        """
        if len(self.training_state.elbo_history) < 2:
            return False
        h = self.training_state.elbo_history
        # Relative ELBO change; 1e-8 guards against division by zero.
        rel = abs(h[-1] - h[-2]) / (abs(h[-2]) + 1e-8)
        elbo_ok = rel < getattr(self.config, 'elbo_rel_tol', self.config.convergence_threshold)

        # Weight criterion defaults to True when fewer than two weight
        # snapshots exist (e.g. the MLN has no learnable rule weights).
        weight_ok = True
        if len(self.training_state.rule_weight_history) >= 2:
            current_weights = torch.tensor(self.training_state.rule_weight_history[-1])
            prev_weights = torch.tensor(self.training_state.rule_weight_history[-2])
            weight_change = torch.norm(current_weights - prev_weights).item()
            weight_ok = weight_change < getattr(self.config, 'weight_abs_tol', self.config.convergence_threshold)

        # Patience bookkeeping lives in convergence_info so it persists
        # across calls; any non-converged iteration resets the streak.
        if elbo_ok and weight_ok:
            hits = self.training_state.convergence_info.get('hits', 0) + 1
            self.training_state.convergence_info['hits'] = hits
        else:
            self.training_state.convergence_info['hits'] = 0
        patience = getattr(self.config, 'convergence_patience', 3)
        return self.training_state.convergence_info.get('hits', 0) >= patience
|
|
367
|
+
|
|
368
|
+
def get_rule_confidences(self) -> Dict[str, float]:
|
|
369
|
+
"""Get learned confidence scores for all rules"""
|
|
370
|
+
if not self.is_initialized or self.mln is None or self.mln.rule_weights is None:
|
|
371
|
+
return {}
|
|
372
|
+
|
|
373
|
+
confidences = {}
|
|
374
|
+
for i, rule in enumerate(self.mln.logical_rules):
|
|
375
|
+
# Convert rule weight to confidence using sigmoid
|
|
376
|
+
weight = self.mln.rule_weights[i].item()
|
|
377
|
+
confidence = torch.sigmoid(torch.tensor(weight)).item()
|
|
378
|
+
confidences[rule.rule_id] = confidence
|
|
379
|
+
|
|
380
|
+
return confidences
|
|
381
|
+
|
|
382
|
+
def predict_unknown_facts(self, top_k: int = 10) -> List[Tuple[Triple, float]]:
|
|
383
|
+
"""
|
|
384
|
+
Predict probabilities for unknown facts
|
|
385
|
+
|
|
386
|
+
Args:
|
|
387
|
+
top_k: Number of top predictions to return
|
|
388
|
+
|
|
389
|
+
Returns:
|
|
390
|
+
List of (triple, probability) tuples sorted by probability
|
|
391
|
+
"""
|
|
392
|
+
if not self.is_initialized:
|
|
393
|
+
raise RuntimeError("NPLL model not initialized")
|
|
394
|
+
|
|
395
|
+
unknown_facts = list(self.knowledge_graph.unknown_facts)
|
|
396
|
+
if not unknown_facts:
|
|
397
|
+
return []
|
|
398
|
+
|
|
399
|
+
# Get predictions for unknown facts
|
|
400
|
+
results = self.forward(unknown_facts)
|
|
401
|
+
probabilities = results['probabilities']
|
|
402
|
+
|
|
403
|
+
# Create (triple, probability) pairs
|
|
404
|
+
predictions = list(zip(unknown_facts, probabilities.tolist()))
|
|
405
|
+
|
|
406
|
+
# Sort by probability (descending)
|
|
407
|
+
predictions.sort(key=lambda x: x[1], reverse=True)
|
|
408
|
+
|
|
409
|
+
return predictions[:top_k]
|
|
410
|
+
|
|
411
|
+
    def save_model(self, filepath: str):
        """
        Save complete NPLL model state with KG and rules serialization.

        The payload bundles config, training state, component state_dicts and
        a serialized copy of the KG and rules so load_model() can rebuild the
        model without external inputs.

        Args:
            filepath: destination path passed to torch.save

        Raises:
            RuntimeError: if the model was never initialized
        """
        if not self.is_initialized:
            raise RuntimeError("Cannot save uninitialized model")

        rules = self.mln.logical_rules if self.mln else []
        payload = {
            'schema_version': self.model_version,
            'config': self.config.__dict__,
            'creation_time': self.creation_time,
            # stored as a whole object, hence weights_only=False on load
            'training_state': self.training_state,
            'mln_state': self.mln.state_dict() if self.mln else None,
            'scoring_state': self.scoring_module.state_dict() if self.scoring_module else None,
            'prob_state': self.probability_transform.state_dict() if hasattr(self.probability_transform, 'state_dict') else None,
            'knowledge_graph': self.knowledge_graph.serialize() if self.knowledge_graph else None,
            'rules': [rule.serialize() for rule in rules],
        }
        torch.save(payload, filepath)
        logger.info(f"NPLL model saved to {filepath}")
|
|
430
|
+
|
|
431
|
+
    def load_model(self, filepath: str):
        """
        Load complete NPLL model state, rebuilding KG and rules before weights.

        PRODUCTION FIX: Respects device configuration (cpu/cuda) instead of hardcoding cpu.

        Args:
            filepath: path to a checkpoint written by save_model()

        Raises:
            FileNotFoundError: when filepath does not exist
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Model file not found: {filepath}")

        from .core import KnowledgeGraph, LogicalRule

        # Respect device config: load to configured device (cpu or cuda)
        # Note: weights_only=False required for custom objects (NPLLTrainingState, etc.)
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        device = getattr(self, 'device', 'cpu')
        payload = torch.load(filepath, map_location=device, weights_only=False)

        # Rebuild config (older checkpoints may store the config object itself)
        self.config = NPLLConfig(**payload['config']) if isinstance(payload.get('config'), dict) else payload['config']
        # Rebuild KG and rules
        kg = KnowledgeGraph.deserialize(payload['knowledge_graph']) if payload.get('knowledge_graph') else None
        rules = [LogicalRule.deserialize(x) for x in payload.get('rules', [])]
        # Recreate components: re-running __init__ resets the module in place
        # so the surrounding object identity (references held by callers) survives.
        self.__init__(self.config)
        if kg is not None:
            self.initialize(kg, rules)
        # Load weights (only after initialize() has built the modules)
        if payload.get('mln_state') and self.mln:
            self.mln.load_state_dict(payload['mln_state'])
        if payload.get('scoring_state') and self.scoring_module:
            self.scoring_module.load_state_dict(payload['scoring_state'])
        if payload.get('prob_state') and hasattr(self.probability_transform, 'load_state_dict'):
            self.probability_transform.load_state_dict(payload['prob_state'])
        # Metadata
        self.model_version = payload.get('schema_version', '1.0')
        self.creation_time = payload.get('creation_time', time.time())
        self.training_state = payload.get('training_state', self.training_state)
        logger.info(f"NPLL model loaded from {filepath}")
|
|
468
|
+
|
|
469
|
+
    def save_to_buffer(self, buffer):
        """
        Save complete NPLL model state to a BytesIO buffer (for database storage).

        Payload layout is identical to save_model(); only the torch.save
        target differs (an in-memory buffer instead of a file path).

        Args:
            buffer: writable binary buffer (e.g. io.BytesIO)

        Raises:
            RuntimeError: if the model was never initialized
        """
        if not self.is_initialized:
            raise RuntimeError("Cannot save uninitialized model")

        rules = self.mln.logical_rules if self.mln else []
        payload = {
            'schema_version': self.model_version,
            'config': self.config.__dict__,
            'creation_time': self.creation_time,
            'training_state': self.training_state,
            'mln_state': self.mln.state_dict() if self.mln else None,
            'scoring_state': self.scoring_module.state_dict() if self.scoring_module else None,
            'prob_state': self.probability_transform.state_dict() if hasattr(self.probability_transform, 'state_dict') else None,
            'knowledge_graph': self.knowledge_graph.serialize() if self.knowledge_graph else None,
            'rules': [rule.serialize() for rule in rules],
        }
        torch.save(payload, buffer)
        logger.info("NPLL model saved to buffer")
|
|
488
|
+
|
|
489
|
+
    def load_from_buffer(self, buffer):
        """
        Load complete NPLL model state from a BytesIO buffer (for database storage).

        Mirrors load_model(): rebuilds config, KG and rules first, then loads
        component weights, then restores metadata.

        Args:
            buffer: seekable binary buffer containing a save_to_buffer payload
        """
        from .core import KnowledgeGraph, LogicalRule

        # Respect device config
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # buffers from trusted sources.
        device = getattr(self, 'device', 'cpu')
        buffer.seek(0)  # rewind in case the caller already read from it
        payload = torch.load(buffer, map_location=device, weights_only=False)

        # Rebuild config
        self.config = NPLLConfig(**payload['config']) if isinstance(payload.get('config'), dict) else payload['config']
        # Rebuild KG and rules
        kg = KnowledgeGraph.deserialize(payload['knowledge_graph']) if payload.get('knowledge_graph') else None
        rules = [LogicalRule.deserialize(x) for x in payload.get('rules', [])]
        # Recreate components (in-place re-init keeps object identity)
        self.__init__(self.config)
        if kg is not None:
            self.initialize(kg, rules)
        # Load weights
        if payload.get('mln_state') and self.mln:
            self.mln.load_state_dict(payload['mln_state'])
        if payload.get('scoring_state') and self.scoring_module:
            self.scoring_module.load_state_dict(payload['scoring_state'])
        if payload.get('prob_state') and hasattr(self.probability_transform, 'load_state_dict'):
            self.probability_transform.load_state_dict(payload['prob_state'])
        # Metadata
        self.model_version = payload.get('schema_version', '1.0')
        self.creation_time = payload.get('creation_time', time.time())
        self.training_state = payload.get('training_state', self.training_state)
        logger.info("NPLL model loaded from buffer")
|
|
519
|
+
|
|
520
|
+
def _get_device(self) -> str:
|
|
521
|
+
for p in self.parameters():
|
|
522
|
+
return str(p.device)
|
|
523
|
+
return 'cpu'
|
|
524
|
+
|
|
525
|
+
def get_model_summary(self) -> Dict[str, Any]:
|
|
526
|
+
"""Get comprehensive model summary (device-safe)"""
|
|
527
|
+
summary = {
|
|
528
|
+
'model_version': self.model_version,
|
|
529
|
+
'is_initialized': self.is_initialized,
|
|
530
|
+
'config': self.config.__dict__ if self.config else {},
|
|
531
|
+
'creation_time': self.creation_time,
|
|
532
|
+
'calibration_version': self.calibration_version,
|
|
533
|
+
'training_state': self.training_state.__dict__ if self.training_state else {},
|
|
534
|
+
'num_parameters': sum(p.numel() for p in self.parameters()),
|
|
535
|
+
'device': self._get_device(),
|
|
536
|
+
}
|
|
537
|
+
if self.is_initialized:
|
|
538
|
+
summary.update({
|
|
539
|
+
'knowledge_graph_stats': self.knowledge_graph.get_statistics() if self.knowledge_graph else {},
|
|
540
|
+
'mln_stats': self.mln.get_rule_statistics() if self.mln else {},
|
|
541
|
+
'rule_confidences': self.get_rule_confidences(),
|
|
542
|
+
})
|
|
543
|
+
return summary
|
|
544
|
+
|
|
545
|
+
# --- Utilities: checkpoints and calibration ---
|
|
546
|
+
def restore_best(self) -> bool:
|
|
547
|
+
"""Restore the best-scoring MLN and scoring-module weights if available."""
|
|
548
|
+
if not self.training_state or not getattr(self.training_state, 'best_weights', None):
|
|
549
|
+
return False
|
|
550
|
+
best = self.training_state.best_weights
|
|
551
|
+
if 'mln' in best and self.mln:
|
|
552
|
+
self.mln.load_state_dict(best['mln'])
|
|
553
|
+
if 'scoring' in best and self.scoring_module:
|
|
554
|
+
self.scoring_module.load_state_dict(best['scoring'])
|
|
555
|
+
return True
|
|
556
|
+
|
|
557
|
+
    def calibrate_temperature_on_data(self,
                                      triples: List[Triple],
                                      labels: torch.Tensor,
                                      max_iter: int = 100,
                                      version: Optional[str] = None) -> float:
        """
        Calibrate the ProbabilityTransform temperature on a holdout set.
        Stores the learned temperature in the transform and records a calibration version.
        Returns the optimized temperature value.

        Args:
            triples: holdout triples to score
            labels: target labels (cast to float before calibration)
            max_iter: maximum optimizer iterations for temperature fitting
            version: optional calibration-version tag; auto-generated if omitted

        Raises:
            RuntimeError: if initialize() has not been called yet
        """
        if not self.is_initialized:
            raise RuntimeError("NPLL model not initialized")
        self.eval()
        # Scores are fixed inputs to calibration; no gradients needed here.
        with torch.no_grad():
            scores = self.scoring_module.forward(triples)

        # Extract group_ids for calibration if per_group is enabled
        group_ids = None
        if hasattr(self.scoring_module, 'embedding_manager') and getattr(self.probability_transform, 'per_group', False):
            emb_mgr = self.scoring_module.embedding_manager
            group_ids = emb_mgr.relation_group_ids_for_triples(triples, add_if_missing=False)
            if hasattr(self.probability_transform, 'ensure_num_groups'):
                self.probability_transform.ensure_num_groups(emb_mgr.relation_num_groups)

        optimized_temp = self.probability_transform.calibrate_temperature(scores, labels.float(), max_iter=max_iter, group_ids=group_ids)
        # Tag the calibration so downstream consumers can detect staleness.
        self.calibration_version = version or f"temp@{time.time():.0f}"
        return optimized_temp
|
|
584
|
+
|
|
585
|
+
def __str__(self) -> str:
|
|
586
|
+
if not self.is_initialized:
|
|
587
|
+
return "NPLL Model (not initialized)"
|
|
588
|
+
|
|
589
|
+
stats = self.knowledge_graph.get_statistics() if self.knowledge_graph else {}
|
|
590
|
+
return (f"NPLL Model:\n"
|
|
591
|
+
f" Entities: {stats.get('num_entities', 0)}\n"
|
|
592
|
+
f" Relations: {stats.get('num_relations', 0)}\n"
|
|
593
|
+
f" Known Facts: {stats.get('num_known_facts', 0)}\n"
|
|
594
|
+
f" Unknown Facts: {stats.get('num_unknown_facts', 0)}\n"
|
|
595
|
+
f" Rules: {len(self.mln.logical_rules) if self.mln else 0}\n"
|
|
596
|
+
f" Training State: {self.training_state}")
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
def create_npll_model(config: NPLLConfig) -> NPLLModel:
    """
    Factory function to create NPLL model

    Args:
        config: NPLL configuration

    Returns:
        Uninitialized NPLL model

    Note:
        The returned model still requires initialize(knowledge_graph, rules)
        before any scoring or training call will work.
    """
    # Thin factory kept for API symmetry with create_initialized_npll_model.
    return NPLLModel(config)
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def create_initialized_npll_model(knowledge_graph: KnowledgeGraph,
                                  logical_rules: List[LogicalRule],
                                  config: Optional[NPLLConfig] = None) -> NPLLModel:
    """
    Factory function to create and initialize NPLL model

    Args:
        knowledge_graph: Knowledge graph
        logical_rules: List of logical rules
        config: Optional configuration (uses default if not provided)

    Returns:
        Initialized NPLL model ready for training
    """
    if config is None:
        config = get_config("ArangoDB_Triples")  # Default configuration

    model = NPLLModel(config)
    model.initialize(knowledge_graph, logical_rules)

    return model
|