odin-engine 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/bootstrap.py CHANGED
@@ -1,474 +1,474 @@
- """
- Bootstrap module for NPLL.
- Handles the end-to-end lifecycle of the NPLL model:
- 1. Extracting data from ArangoDB
- 2. Generating domain-appropriate logical rules
- 3. Training the model
- 4. Storing ONLY WEIGHTS in database (not full model)
-
- Architecture:
- - Weights stored in OdinModels collection (~1 KB)
- - Model rebuilt from KG on each load (~30 sec)
- - No external files needed
- """
-
- import os
- import hashlib
- import json
- import logging
- import random
- import time
- import torch
- from datetime import datetime
- from typing import List, Tuple, Dict, Optional, Any
- from arango.database import StandardDatabase
-
- from .core.knowledge_graph import KnowledgeGraph, load_knowledge_graph_from_triples
- from .core.logical_rules import LogicalRule, Atom, Variable, RuleType
- from .npll_model import create_initialized_npll_model, NPLLModel
- from .training.npll_trainer import TrainingConfig, create_trainer
- from .utils.config import get_config
-
- logger = logging.getLogger(__name__)
-
- # Collection name for storing model weights
- ODIN_MODELS_COLLECTION = "OdinModels"
- NPLL_MODEL_KEY = "npll_current"
-
-
- class KnowledgeBootstrapper:
-     """
-     Manages the lifecycle of the NPLL model.
-
-     Storage Strategy:
-     - Only rule weights are saved to database (~1 KB)
-     - Model is rebuilt from KG data on each load (~30 sec)
-     - No external .pt files needed
-     """
-
-     def __init__(self, db: StandardDatabase):
-         """
-         Initialize the bootstrapper.
-
-         Args:
-             db: An already-connected ArangoDB database instance
-         """
-         self.db = db
-         self._ensure_collection_exists()
-
-     def _ensure_collection_exists(self):
-         """Ensure OdinModels collection exists."""
-         try:
-             if not self.db.has_collection(ODIN_MODELS_COLLECTION):
-                 self.db.create_collection(ODIN_MODELS_COLLECTION)
-                 logger.info(f"Created {ODIN_MODELS_COLLECTION} collection")
-         except Exception as e:
-             # Suppress permission errors (common in read-only environments)
-             logger.warning(f"Could not verify/create {ODIN_MODELS_COLLECTION} (Permission Error?): {e}")
-
-     def ensure_model_ready(self, force_retrain: bool = False) -> Optional[NPLLModel]:
-         """
-         Ensures a trained NPLL model is available.
-
-         Flow:
-         1. Compute current data hash
-         2. Check OdinModels for saved weights with matching hash
-         3. If found: rebuild model from KG, apply saved weights
-         4. If not found: train new model, save weights to DB
-
-         Args:
-             force_retrain: If True, ignores cached weights and retrains
-
-         Returns:
-             Loaded NPLLModel ready for inference, or None if failed
-         """
-         current_hash = self._compute_data_hash()
-         logger.info(f"Current data hash: {current_hash[:16]}...")
-
-         if not force_retrain:
-             # Try to load existing weights and rebuild model
-             model = self._load_model_with_weights(current_hash)
-             if model:
-                 return model
-
-         # Train new model
-         logger.info("Training new NPLL model...")
-         return self._train_and_save_weights(current_hash)
-
-     def _compute_data_hash(self) -> str:
-         """
-         Compute a hash of the current schema to detect when retraining is needed.
-         """
-         try:
-             # Get relation names
-             rel_query = """
-             FOR e IN ExtractedRelationships
-                 COLLECT rel = e.relationship WITH COUNT INTO cnt
-                 SORT rel
-                 RETURN {rel: rel, count: cnt}
-             """
-             relations = list(self.db.aql.execute(rel_query))
-             relation_names = sorted([r['rel'] for r in relations if r['rel']])
-
-             # Get counts
-             entity_count = self.db.collection("ExtractedEntities").count()
-             fact_count = self.db.collection("ExtractedRelationships").count()
-
-             hash_input = {
-                 "relations": relation_names,
-                 "entity_count": entity_count,
-                 "fact_count": fact_count,
-             }
-
-             hash_str = json.dumps(hash_input, sort_keys=True)
-             return hashlib.sha256(hash_str.encode()).hexdigest()
-
-         except Exception as e:
-             logger.warning(f"Could not compute data hash: {e}")
-             return hashlib.sha256(str(time.time()).encode()).hexdigest()
-
-     def _load_model_with_weights(self, expected_hash: str) -> Optional[NPLLModel]:
-         """
-         Load saved weights from DB and rebuild the model.
-
-         Flow:
-         1. Check if saved weights exist with matching hash
-         2. Extract triples from DB → build KG
-         3. Generate rules (same code = same rules)
-         4. Initialize fresh model
-         5. Apply saved weights
-         """
-         try:
-             collection = self.db.collection(ODIN_MODELS_COLLECTION)
-             doc = collection.get(NPLL_MODEL_KEY)
-
-             if not doc:
-                 logger.info("No saved weights found in database")
-                 return None
-
-             stored_hash = doc.get("data_hash", "")
-             if stored_hash != expected_hash:
-                 logger.info(f"Data has changed. Stored: {stored_hash[:16]}..., Current: {expected_hash[:16]}...")
-                 return None
-
-             # Get saved weights
-             saved_weights = doc.get("rule_weights")
-             if not saved_weights:
-                 logger.warning("No rule_weights in saved document")
-                 return None
-
-             logger.info("Rebuilding model from KG and applying saved weights...")
-
-             # 1. Extract triples
-             triples = self._extract_triples()
-             if not triples:
-                 return None
-
-             # 2. Build KG
-             kg = load_knowledge_graph_from_triples(triples, "ArangoDB_KG")
-
-             # 3. Generate rules (same code = same rules)
-             rules = self._generate_smart_rules(kg)
-
-             if len(rules) != len(saved_weights):
-                 logger.warning(f"Rule count mismatch: {len(rules)} rules, {len(saved_weights)} weights. Retraining.")
-                 return None
-
-             # 4. Initialize model
-             config = get_config("ArangoDB_Triples")
-             model = create_initialized_npll_model(kg, rules, config)
-
-             # 5. Apply saved weights
-             with torch.no_grad():
-                 model.mln.rule_weights.copy_(torch.tensor(saved_weights, dtype=torch.float32))
-
-             trained_at = doc.get("trained_at", "unknown")
-             logger.info(f"✓ Model rebuilt with saved weights (trained: {trained_at})")
-
-             return model
-
-         except Exception as e:
-             logger.warning(f"Failed to load model: {e}")
-             return None
-
-     def _train_and_save_weights(self, data_hash: str) -> Optional[NPLLModel]:
-         """
-         Train a new NPLL model and save ONLY the weights to database.
-         """
-         # 1. Extract Triples
-         triples = self._extract_triples()
-         if not triples:
-             logger.error("No triples extracted. Cannot train.")
-             return None
-
-         # 2. Build KG
-         kg = load_knowledge_graph_from_triples(triples, "ArangoDB_KG")
-         logger.info(f"Built KG: {len(kg.entities)} entities, {len(kg.relations)} relations, {len(kg.known_facts)} facts")
-
-         # Create unknown facts for training (10%)
-         known_facts_list = list(kg.known_facts)
-         random.seed(42)
-         num_unknown = max(1, len(known_facts_list) // 10)
-         unknown_facts = random.sample(known_facts_list, num_unknown)
-
-         for fact in unknown_facts:
-             kg.known_facts.remove(fact)
-             kg.add_unknown_fact(fact.head.name, fact.relation.name, fact.tail.name)
-
-         # 3. Generate Rules
-         rules = self._generate_smart_rules(kg)
-         logger.info(f"Generated {len(rules)} logical rules")
-
-         if not rules:
-             logger.error("No rules generated. Cannot train.")
-             return None
-
-         # 4. Initialize Model
-         config = get_config("ArangoDB_Triples")
-         model = create_initialized_npll_model(kg, rules, config)
-
-         # 5. Train
-         train_config = TrainingConfig(
-             num_epochs=10,
-             max_em_iterations_per_epoch=5,
-             early_stopping_patience=3,
-             save_checkpoints=False
-         )
-         trainer = create_trainer(model, train_config)
-
-         training_result = None
-         try:
-             logger.info("Starting NPLL training...")
-             training_result = trainer.train()
-             logger.info(f"Training completed. Final ELBO: {training_result.final_elbo}")
-         except Exception as e:
-             logger.error(f"Training failed: {e}", exc_info=True)
-             return None
-
-         # 6. Save ONLY weights to database
-         self._save_weights_to_db(model, kg, rules, data_hash, training_result)
-
-         return model
-
-     def _save_weights_to_db(self, model: NPLLModel, kg: KnowledgeGraph,
-                             rules: List[LogicalRule], data_hash: str,
-                             training_result: Any):
-         """
-         Save ONLY the learned weights to OdinModels collection.
-         This is tiny (~1 KB) compared to the full model (280 MB).
-         """
-         try:
-             # Extract just the rule weights
-             rule_weights = model.mln.rule_weights.detach().cpu().tolist()
-
-             doc = {
-                 "_key": NPLL_MODEL_KEY,
-                 "model_type": "npll",
-                 "storage_type": "weights_only",  # Mark this as weights-only storage
-                 "trained_at": datetime.utcnow().isoformat() + "Z",
-                 "data_hash": data_hash,
-                 "rule_weights": rule_weights,  # The learned weights - this is all we need!
-                 "schema_snapshot": {
-                     "entity_count": len(kg.entities),
-                     "relation_count": len(kg.relations),
-                     "fact_count": len(kg.known_facts),
-                     "relation_names": sorted([r.name for r in kg.relations])[:50],
-                 },
-                 "training_result": {
-                     "final_elbo": float(training_result.final_elbo) if training_result else 0,
-                     "best_elbo": float(training_result.best_elbo) if training_result else 0,
-                     "converged": training_result.converged if training_result else False,
-                     "training_time_seconds": training_result.total_training_time if training_result else 0,
-                 },
-                 "rules": [
-                     {
-                         "rule_id": r.rule_id,
-                         "rule_text": str(r),
-                         "confidence": r.confidence,
-                     }
-                     for r in rules
-                 ],
-                 "version": "2.0",  # Version 2 = weights-only storage
-             }
-
-             # Upsert
-             collection = self.db.collection(ODIN_MODELS_COLLECTION)
-             if collection.has(NPLL_MODEL_KEY):
-                 collection.update(doc)
-             else:
-                 collection.insert(doc)
-
-             weights_size = len(json.dumps(rule_weights))
-             logger.info(f"✓ Saved rule weights to database ({weights_size} bytes)")
-
-         except Exception as e:
-             logger.error(f"Failed to save weights: {e}", exc_info=True)
-
-     def _extract_triples(self) -> List[Tuple[str, str, str]]:
-         """Extracts S-P-O triples from ArangoDB."""
-         logger.info("Extracting triples from database...")
-         triples = []
-
-         # Extract Relationships
-         query = """
-         FOR rel IN ExtractedRelationships
-             LET source = DOCUMENT(rel._from)
-             LET target = DOCUMENT(rel._to)
-             FILTER source != null AND target != null
-             FILTER source._key != null AND target._key != null
-             RETURN {
-                 source: source._key,
-                 target: target._key,
-                 relation: rel.relationship || "related_to"
-             }
-         """
-         try:
-             cursor = self.db.aql.execute(query)
-             for doc in cursor:
-                 s, t = doc['source'], doc['target']
-                 r = str(doc['relation']).replace(' ', '_').lower()
-                 triples.append((s, r, t))
-             logger.info(f"Extracted {len(triples)} relationship triples")
-         except Exception as e:
-             logger.error(f"Extraction error: {e}")
-             return []
-
-         # Extract Entity Types
-         query_types = """
-         FOR entity IN ExtractedEntities
-             FILTER entity._key != null AND entity.type != null
-             RETURN { key: entity._key, type: entity.type }
-         """
-         try:
-             cursor = self.db.aql.execute(query_types)
-             type_count = 0
-             for doc in cursor:
-                 triples.append((doc['key'], 'has_type', doc['type']))
-                 type_count += 1
-             logger.info(f"Extracted {type_count} entity type triples")
-         except Exception as e:
-             logger.error(f"Type extraction error: {e}")
-
-         logger.info(f"Total triples: {len(triples)}")
-         return triples
-
-     def _generate_smart_rules(self, kg: KnowledgeGraph) -> List[LogicalRule]:
-         """
-         Generates domain-appropriate rules based on available relations.
-         """
-         rules = []
-         relations = {r.name: r for r in kg.relations}
-         x, y, z = Variable("?x"), Variable("?y"), Variable("?z")
-
-         logger.info(f"Generating rules for {len(relations)} relation types...")
-
-         # --- HEALTHCARE DOMAIN ---
-         if 'has_claim' in relations and 'submitted_by_provider' in relations and 'treated_by' in relations:
-             rules.append(LogicalRule(
-                 rule_id="hc_claim_provider_link",
-                 body=[
-                     Atom(relations['has_claim'], (x, y)),
-                     Atom(relations['submitted_by_provider'], (y, z))
-                 ],
-                 head=Atom(relations['treated_by'], (x, z)),
-                 confidence=0.7
-             ))
-             logger.info(" + Added: hc_claim_provider_link")
-
-         if 'diagnosed_with' in relations and 'indicates' in relations:
-             target_rel = relations.get('recommended_procedure') or relations.get('related_to')
-             if target_rel:
-                 rules.append(LogicalRule(
-                     rule_id="hc_diagnosis_procedure",
-                     body=[
-                         Atom(relations['diagnosed_with'], (x, y)),
-                         Atom(relations['indicates'], (y, z))
-                     ],
-                     head=Atom(target_rel, (x, z)),
-                     confidence=0.6
-                 ))
-                 logger.info(" + Added: hc_diagnosis_procedure")
-
-         if 'works_at' in relations and 'located_at' in relations:
-             target_rel = relations.get('affiliated_with') or relations.get('related_to')
-             if target_rel:
-                 rules.append(LogicalRule(
-                     rule_id="hc_provider_facility",
-                     body=[
-                         Atom(relations['works_at'], (x, y)),
-                         Atom(relations['located_at'], (y, z))
-                     ],
-                     head=Atom(target_rel, (x, z)),
-                     confidence=0.6
-                 ))
-                 logger.info(" + Added: hc_provider_facility")
-
-         # --- INSURANCE DOMAIN ---
-         if 'policyholder' in relations and 'claim_number' in relations and 'related_to' in relations:
-             rules.append(LogicalRule(
-                 rule_id="ins_policy_claim",
-                 body=[
-                     Atom(relations['policyholder'], (x, y)),
-                     Atom(relations['claim_number'], (x, z))
-                 ],
-                 head=Atom(relations['related_to'], (y, z)),
-                 confidence=0.8
-             ))
-             logger.info(" + Added: ins_policy_claim")
-
-         if 'assessor' in relations and 'insurer' in relations and 'related_to' in relations:
-             rules.append(LogicalRule(
-                 rule_id="ins_assessor_insurer",
-                 body=[
-                     Atom(relations['assessor'], (x, y)),
-                     Atom(relations['insurer'], (z, y))
-                 ],
-                 head=Atom(relations['related_to'], (x, z)),
-                 confidence=0.7
-             ))
-             logger.info(" + Added: ins_assessor_insurer")
-
-         # --- GENERIC RULES ---
-         if 'related_to' in relations:
-             rules.append(LogicalRule(
-                 rule_id="gen_transitivity",
-                 body=[
-                     Atom(relations['related_to'], (x, y)),
-                     Atom(relations['related_to'], (y, z))
-                 ],
-                 head=Atom(relations['related_to'], (x, z)),
-                 rule_type=RuleType.TRANSITIVITY,
-                 confidence=0.5
-             ))
-             logger.info(" + Added: gen_transitivity")
-
-         if 'has_type' in relations and 'related_to' in relations:
-             rules.append(LogicalRule(
-                 rule_id="gen_type_cooccurrence",
-                 body=[
-                     Atom(relations['has_type'], (x, y)),
-                     Atom(relations['has_type'], (z, y))
-                 ],
-                 head=Atom(relations['related_to'], (x, z)),
-                 confidence=0.3
-             ))
-             logger.info(" + Added: gen_type_cooccurrence")
-
-         # Fallback
-         if not rules:
-             logger.warning("No domain rules matched. Creating fallback.")
-             rel = next(iter(kg.relations))
-             rules.append(LogicalRule(
-                 rule_id="fallback_self",
-                 body=[Atom(rel, (x, y))],
-                 head=Atom(rel, (x, y)),
-                 confidence=0.5
-             ))
-
-         logger.info(f"Total rules: {len(rules)}")
-         return rules
-
-
- def create_bootstrapper(db: StandardDatabase) -> KnowledgeBootstrapper:
-     """Factory function to create a KnowledgeBootstrapper."""
-     return KnowledgeBootstrapper(db)
+ """
+ Bootstrap module for NPLL.
+ Handles the end-to-end lifecycle of the NPLL model:
+ 1. Extracting data from ArangoDB
+ 2. Generating domain-appropriate logical rules
+ 3. Training the model
+ 4. Storing ONLY WEIGHTS in database (not full model)
+
+ Architecture:
+ - Weights stored in OdinModels collection (~1 KB)
+ - Model rebuilt from KG on each load (~30 sec)
+ - No external files needed
+ """
+
+ import os
+ import hashlib
+ import json
+ import logging
+ import random
+ import time
+ import torch
+ from datetime import datetime
+ from typing import List, Tuple, Dict, Optional, Any
+ from arango.database import StandardDatabase
+
+ from .core.knowledge_graph import KnowledgeGraph, load_knowledge_graph_from_triples
+ from .core.logical_rules import LogicalRule, Atom, Variable, RuleType
+ from .npll_model import create_initialized_npll_model, NPLLModel
+ from .training.npll_trainer import TrainingConfig, create_trainer
+ from .utils.config import get_config
+
+ logger = logging.getLogger(__name__)
+
+ # Collection name for storing model weights
+ ODIN_MODELS_COLLECTION = "OdinModels"
+ NPLL_MODEL_KEY = "npll_current"
+
+
+ class KnowledgeBootstrapper:
+     """
+     Manages the lifecycle of the NPLL model.
+
+     Storage Strategy:
+     - Only rule weights are saved to database (~1 KB)
+     - Model is rebuilt from KG data on each load (~30 sec)
+     - No external .pt files needed
+     """
+
+     def __init__(self, db: StandardDatabase):
+         """
+         Initialize the bootstrapper.
+
+         Args:
+             db: An already-connected ArangoDB database instance
+         """
+         self.db = db
+         self._ensure_collection_exists()
+
+     def _ensure_collection_exists(self):
+         """Ensure OdinModels collection exists."""
+         try:
+             if not self.db.has_collection(ODIN_MODELS_COLLECTION):
+                 self.db.create_collection(ODIN_MODELS_COLLECTION)
+                 logger.info(f"Created {ODIN_MODELS_COLLECTION} collection")
+         except Exception as e:
+             # Suppress permission errors (common in read-only environments)
+             logger.warning(f"Could not verify/create {ODIN_MODELS_COLLECTION} (Permission Error?): {e}")
+
+     def ensure_model_ready(self, force_retrain: bool = False) -> Optional[NPLLModel]:
+         """
+         Ensures a trained NPLL model is available.
+
+         Flow:
+         1. Compute current data hash
+         2. Check OdinModels for saved weights with matching hash
+         3. If found: rebuild model from KG, apply saved weights
+         4. If not found: train new model, save weights to DB
+
+         Args:
+             force_retrain: If True, ignores cached weights and retrains
+
+         Returns:
+             Loaded NPLLModel ready for inference, or None if failed
+         """
+         current_hash = self._compute_data_hash()
+         logger.info(f"Current data hash: {current_hash[:16]}...")
+
+         if not force_retrain:
+             # Try to load existing weights and rebuild model
+             model = self._load_model_with_weights(current_hash)
+             if model:
+                 return model
+
+         # Train new model
+         logger.info("Training new NPLL model...")
+         return self._train_and_save_weights(current_hash)
+
+     def _compute_data_hash(self) -> str:
+         """
+         Compute a hash of the current schema to detect when retraining is needed.
+         """
+         try:
+             # Get relation names
+             rel_query = """
+             FOR e IN ExtractedRelationships
+                 COLLECT rel = e.relationship WITH COUNT INTO cnt
+                 SORT rel
+                 RETURN {rel: rel, count: cnt}
+             """
+             relations = list(self.db.aql.execute(rel_query))
+             relation_names = sorted([r['rel'] for r in relations if r['rel']])
+
+             # Get counts
+             entity_count = self.db.collection("ExtractedEntities").count()
+             fact_count = self.db.collection("ExtractedRelationships").count()
+
+             hash_input = {
+                 "relations": relation_names,
+                 "entity_count": entity_count,
+                 "fact_count": fact_count,
+             }
+
+             hash_str = json.dumps(hash_input, sort_keys=True)
+             return hashlib.sha256(hash_str.encode()).hexdigest()
+
+         except Exception as e:
+             logger.warning(f"Could not compute data hash: {e}")
+             return hashlib.sha256(str(time.time()).encode()).hexdigest()
+
+     def _load_model_with_weights(self, expected_hash: str) -> Optional[NPLLModel]:
+         """
+         Load saved weights from DB and rebuild the model.
+
+         Flow:
+         1. Check if saved weights exist with matching hash
+         2. Extract triples from DB → build KG
+         3. Generate rules (same code = same rules)
+         4. Initialize fresh model
+         5. Apply saved weights
+         """
+         try:
+             collection = self.db.collection(ODIN_MODELS_COLLECTION)
+             doc = collection.get(NPLL_MODEL_KEY)
+
+             if not doc:
+                 logger.info("No saved weights found in database")
+                 return None
+
+             stored_hash = doc.get("data_hash", "")
+             if stored_hash != expected_hash:
+                 logger.info(f"Data has changed. Stored: {stored_hash[:16]}..., Current: {expected_hash[:16]}...")
+                 return None
+
+             # Get saved weights
+             saved_weights = doc.get("rule_weights")
+             if not saved_weights:
+                 logger.warning("No rule_weights in saved document")
+                 return None
+
+             logger.info("Rebuilding model from KG and applying saved weights...")
+
+             # 1. Extract triples
+             triples = self._extract_triples()
+             if not triples:
+                 return None
+
+             # 2. Build KG
+             kg = load_knowledge_graph_from_triples(triples, "ArangoDB_KG")
+
+             # 3. Generate rules (same code = same rules)
+             rules = self._generate_smart_rules(kg)
+
+             if len(rules) != len(saved_weights):
+                 logger.warning(f"Rule count mismatch: {len(rules)} rules, {len(saved_weights)} weights. Retraining.")
+                 return None
+
+             # 4. Initialize model
+             config = get_config("ArangoDB_Triples")
+             model = create_initialized_npll_model(kg, rules, config)
+
+             # 5. Apply saved weights
+             with torch.no_grad():
+                 model.mln.rule_weights.copy_(torch.tensor(saved_weights, dtype=torch.float32))
+
+             trained_at = doc.get("trained_at", "unknown")
+             logger.info(f"✓ Model rebuilt with saved weights (trained: {trained_at})")
+
+             return model
+
+         except Exception as e:
+             logger.warning(f"Failed to load model: {e}")
+             return None
+
+     def _train_and_save_weights(self, data_hash: str) -> Optional[NPLLModel]:
+         """
+         Train a new NPLL model and save ONLY the weights to database.
+         """
+         # 1. Extract Triples
+         triples = self._extract_triples()
+         if not triples:
+             logger.error("No triples extracted. Cannot train.")
+             return None
+
+         # 2. Build KG
+         kg = load_knowledge_graph_from_triples(triples, "ArangoDB_KG")
+         logger.info(f"Built KG: {len(kg.entities)} entities, {len(kg.relations)} relations, {len(kg.known_facts)} facts")
+
+         # Create unknown facts for training (10%)
+         known_facts_list = list(kg.known_facts)
+         random.seed(42)
+         num_unknown = max(1, len(known_facts_list) // 10)
+         unknown_facts = random.sample(known_facts_list, num_unknown)
+
+         for fact in unknown_facts:
+             kg.known_facts.remove(fact)
+             kg.add_unknown_fact(fact.head.name, fact.relation.name, fact.tail.name)
+
+         # 3. Generate Rules
+         rules = self._generate_smart_rules(kg)
+         logger.info(f"Generated {len(rules)} logical rules")
+
+         if not rules:
+             logger.error("No rules generated. Cannot train.")
+             return None
+
+         # 4. Initialize Model
+         config = get_config("ArangoDB_Triples")
+         model = create_initialized_npll_model(kg, rules, config)
+
+         # 5. Train
+         train_config = TrainingConfig(
+             num_epochs=10,
+             max_em_iterations_per_epoch=5,
+             early_stopping_patience=3,
+             save_checkpoints=False
+         )
+         trainer = create_trainer(model, train_config)
+
+         training_result = None
+         try:
+             logger.info("Starting NPLL training...")
+             training_result = trainer.train()
+             logger.info(f"Training completed. Final ELBO: {training_result.final_elbo}")
+         except Exception as e:
+             logger.error(f"Training failed: {e}", exc_info=True)
+             return None
+
+         # 6. Save ONLY weights to database
+         self._save_weights_to_db(model, kg, rules, data_hash, training_result)
+
+         return model
+
+     def _save_weights_to_db(self, model: NPLLModel, kg: KnowledgeGraph,
+                             rules: List[LogicalRule], data_hash: str,
+                             training_result: Any):
+         """
+         Save ONLY the learned weights to OdinModels collection.
+         This is tiny (~1 KB) compared to the full model (280 MB).
+         """
+         try:
+             # Extract just the rule weights
+             rule_weights = model.mln.rule_weights.detach().cpu().tolist()
+
+             doc = {
+                 "_key": NPLL_MODEL_KEY,
+                 "model_type": "npll",
+                 "storage_type": "weights_only",  # Mark this as weights-only storage
+                 "trained_at": datetime.utcnow().isoformat() + "Z",
+                 "data_hash": data_hash,
+                 "rule_weights": rule_weights,  # The learned weights - this is all we need!
+                 "schema_snapshot": {
+                     "entity_count": len(kg.entities),
+                     "relation_count": len(kg.relations),
+                     "fact_count": len(kg.known_facts),
+                     "relation_names": sorted([r.name for r in kg.relations])[:50],
+                 },
+                 "training_result": {
+                     "final_elbo": float(training_result.final_elbo) if training_result else 0,
+                     "best_elbo": float(training_result.best_elbo) if training_result else 0,
+                     "converged": training_result.converged if training_result else False,
+                     "training_time_seconds": training_result.total_training_time if training_result else 0,
+                 },
+                 "rules": [
+                     {
+                         "rule_id": r.rule_id,
+                         "rule_text": str(r),
+                         "confidence": r.confidence,
+                     }
+                     for r in rules
+                 ],
+                 "version": "2.0",  # Version 2 = weights-only storage
+             }
+
+             # Upsert
+             collection = self.db.collection(ODIN_MODELS_COLLECTION)
+             if collection.has(NPLL_MODEL_KEY):
+                 collection.update(doc)
+             else:
+                 collection.insert(doc)
+
+             weights_size = len(json.dumps(rule_weights))
+             logger.info(f"✓ Saved rule weights to database ({weights_size} bytes)")
+
+         except Exception as e:
+             logger.error(f"Failed to save weights: {e}", exc_info=True)
+
+     def _extract_triples(self) -> List[Tuple[str, str, str]]:
+         """Extracts S-P-O triples from ArangoDB."""
+         logger.info("Extracting triples from database...")
+         triples = []
+
+         # Extract Relationships
+         query = """
+         FOR rel IN ExtractedRelationships
+             LET source = DOCUMENT(rel._from)
+             LET target = DOCUMENT(rel._to)
+             FILTER source != null AND target != null
+             FILTER source._key != null AND target._key != null
+             RETURN {
+                 source: source._key,
+                 target: target._key,
+                 relation: rel.relationship || "related_to"
+             }
+         """
+         try:
+             cursor = self.db.aql.execute(query)
+             for doc in cursor:
+                 s, t = doc['source'], doc['target']
+                 r = str(doc['relation']).replace(' ', '_').lower()
+                 triples.append((s, r, t))
+             logger.info(f"Extracted {len(triples)} relationship triples")
+         except Exception as e:
+             logger.error(f"Extraction error: {e}")
+             return []
+
+         # Extract Entity Types
+         query_types = """
+         FOR entity IN ExtractedEntities
+             FILTER entity._key != null AND entity.type != null
+             RETURN { key: entity._key, type: entity.type }
+         """
+         try:
+             cursor = self.db.aql.execute(query_types)
+             type_count = 0
+             for doc in cursor:
+                 triples.append((doc['key'], 'has_type', doc['type']))
+                 type_count += 1
+             logger.info(f"Extracted {type_count} entity type triples")
+         except Exception as e:
+             logger.error(f"Type extraction error: {e}")
+
+         logger.info(f"Total triples: {len(triples)}")
+         return triples
+
+     def _generate_smart_rules(self, kg: KnowledgeGraph) -> List[LogicalRule]:
+         """
+         Generates domain-appropriate rules based on available relations.
+         """
+         rules = []
+         relations = {r.name: r for r in kg.relations}
+         x, y, z = Variable("?x"), Variable("?y"), Variable("?z")
+
+         logger.info(f"Generating rules for {len(relations)} relation types...")
+
+         # --- HEALTHCARE DOMAIN ---
+         if 'has_claim' in relations and 'submitted_by_provider' in relations and 'treated_by' in relations:
+             rules.append(LogicalRule(
+                 rule_id="hc_claim_provider_link",
+                 body=[
+                     Atom(relations['has_claim'], (x, y)),
+                     Atom(relations['submitted_by_provider'], (y, z))
+                 ],
+                 head=Atom(relations['treated_by'], (x, z)),
+                 confidence=0.7
+             ))
+             logger.info(" + Added: hc_claim_provider_link")
+
+         if 'diagnosed_with' in relations and 'indicates' in relations:
+             target_rel = relations.get('recommended_procedure') or relations.get('related_to')
+             if target_rel:
+                 rules.append(LogicalRule(
+                     rule_id="hc_diagnosis_procedure",
+                     body=[
+                         Atom(relations['diagnosed_with'], (x, y)),
+                         Atom(relations['indicates'], (y, z))
+                     ],
+                     head=Atom(target_rel, (x, z)),
+                     confidence=0.6
+                 ))
+                 logger.info(" + Added: hc_diagnosis_procedure")
+
+         if 'works_at' in relations and 'located_at' in relations:
+             target_rel = relations.get('affiliated_with') or relations.get('related_to')
+             if target_rel:
+                 rules.append(LogicalRule(
+                     rule_id="hc_provider_facility",
+                     body=[
+                         Atom(relations['works_at'], (x, y)),
+                         Atom(relations['located_at'], (y, z))
+                     ],
+                     head=Atom(target_rel, (x, z)),
+                     confidence=0.6
+                 ))
+                 logger.info(" + Added: hc_provider_facility")
+
+         # --- INSURANCE DOMAIN ---
+         if 'policyholder' in relations and 'claim_number' in relations and 'related_to' in relations:
+             rules.append(LogicalRule(
+                 rule_id="ins_policy_claim",
+                 body=[
+                     Atom(relations['policyholder'], (x, y)),
+                     Atom(relations['claim_number'], (x, z))
+                 ],
+                 head=Atom(relations['related_to'], (y, z)),
+                 confidence=0.8
+             ))
+             logger.info(" + Added: ins_policy_claim")
+
+         if 'assessor' in relations and 'insurer' in relations and 'related_to' in relations:
+             rules.append(LogicalRule(
+                 rule_id="ins_assessor_insurer",
+                 body=[
+                     Atom(relations['assessor'], (x, y)),
+                     Atom(relations['insurer'], (z, y))
+                 ],
+                 head=Atom(relations['related_to'], (x, z)),
+                 confidence=0.7
+             ))
+             logger.info(" + Added: ins_assessor_insurer")
+
+         # --- GENERIC RULES ---
+         if 'related_to' in relations:
+             rules.append(LogicalRule(
+                 rule_id="gen_transitivity",
+                 body=[
+                     Atom(relations['related_to'], (x, y)),
+                     Atom(relations['related_to'], (y, z))
+                 ],
+                 head=Atom(relations['related_to'], (x, z)),
+                 rule_type=RuleType.TRANSITIVITY,
+                 confidence=0.5
+             ))
+             logger.info(" + Added: gen_transitivity")
+
+         if 'has_type' in relations and 'related_to' in relations:
+             rules.append(LogicalRule(
+                 rule_id="gen_type_cooccurrence",
+                 body=[
+                     Atom(relations['has_type'], (x, y)),
+                     Atom(relations['has_type'], (z, y))
+                 ],
+                 head=Atom(relations['related_to'], (x, z)),
+                 confidence=0.3
+             ))
+             logger.info(" + Added: gen_type_cooccurrence")
+
+         # Fallback
+         if not rules:
+             logger.warning("No domain rules matched. Creating fallback.")
+             rel = next(iter(kg.relations))
+             rules.append(LogicalRule(
+                 rule_id="fallback_self",
+                 body=[Atom(rel, (x, y))],
+                 head=Atom(rel, (x, y)),
+                 confidence=0.5
+             ))
+
+         logger.info(f"Total rules: {len(rules)}")
+         return rules
+
+
+ def create_bootstrapper(db: StandardDatabase) -> KnowledgeBootstrapper:
+     """Factory function to create a KnowledgeBootstrapper."""
+     return KnowledgeBootstrapper(db)
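
For context, here is a minimal usage sketch of the lifecycle this module implements (schema hash check, cached-weights reload, retrain-and-save). It assumes a reachable ArangoDB instance via the python-arango client; the host, database name, and credentials below are illustrative placeholders, not values shipped with the package. Only create_bootstrapper, ensure_model_ready, ODIN_MODELS_COLLECTION, and NPLL_MODEL_KEY come from the diffed module.

    from arango import ArangoClient

    from npll.bootstrap import (
        NPLL_MODEL_KEY,
        ODIN_MODELS_COLLECTION,
        create_bootstrapper,
    )

    # Placeholder connection details -- adjust for your deployment.
    client = ArangoClient(hosts="http://localhost:8529")
    db = client.db("odin", username="root", password="")

    # ensure_model_ready() hashes the current schema, reloads the cached
    # rule weights from OdinModels when the hash matches, and otherwise
    # retrains and upserts fresh weights (weights-only storage, ~1 KB).
    bootstrapper = create_bootstrapper(db)
    model = bootstrapper.ensure_model_ready()  # pass force_retrain=True to skip the cache

    if model is None:
        raise RuntimeError("Bootstrap failed: no triples, no rules, or training error")

    # Inspect the weights-only document written by _save_weights_to_db().
    doc = db.collection(ODIN_MODELS_COLLECTION).get(NPLL_MODEL_KEY)
    if doc:
        print(doc["trained_at"], doc["data_hash"][:16], len(doc["rule_weights"]))

Note that the reload path is sound only because _generate_smart_rules is deterministic for a given schema: the stored weights are applied positionally, so the data-hash check plus the rule-count guard in _load_model_with_weights are what keep weights aligned with rules.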