odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/bootstrap.py
CHANGED
|
@@ -1,474 +1,474 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Bootstrap module for NPLL.
|
|
3
|
-
Handles the end-to-end lifecycle of the NPLL model:
|
|
4
|
-
1. Extracting data from ArangoDB
|
|
5
|
-
2. Generating domain-appropriate logical rules
|
|
6
|
-
3. Training the model
|
|
7
|
-
4. Storing ONLY WEIGHTS in database (not full model)
|
|
8
|
-
|
|
9
|
-
Architecture:
|
|
10
|
-
- Weights stored in OdinModels collection (~1 KB)
|
|
11
|
-
- Model rebuilt from KG on each load (~30 sec)
|
|
12
|
-
- No external files needed
|
|
13
|
-
"""
|
|
14
|
-
|
|
15
|
-
import os
|
|
16
|
-
import hashlib
|
|
17
|
-
import json
|
|
18
|
-
import logging
|
|
19
|
-
import random
|
|
20
|
-
import time
|
|
21
|
-
import torch
|
|
22
|
-
from datetime import datetime
|
|
23
|
-
from typing import List, Tuple, Dict, Optional, Any
|
|
24
|
-
from arango.database import StandardDatabase
|
|
25
|
-
|
|
26
|
-
from .core.knowledge_graph import KnowledgeGraph, load_knowledge_graph_from_triples
|
|
27
|
-
from .core.logical_rules import LogicalRule, Atom, Variable, RuleType
|
|
28
|
-
from .npll_model import create_initialized_npll_model, NPLLModel
|
|
29
|
-
from .training.npll_trainer import TrainingConfig, create_trainer
|
|
30
|
-
from .utils.config import get_config
|
|
31
|
-
|
|
32
|
-
logger = logging.getLogger(__name__)
|
|
33
|
-
|
|
34
|
-
# Collection name for storing model weights
|
|
35
|
-
ODIN_MODELS_COLLECTION = "OdinModels"
|
|
36
|
-
NPLL_MODEL_KEY = "npll_current"
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class KnowledgeBootstrapper:
|
|
40
|
-
"""
|
|
41
|
-
Manages the lifecycle of the NPLL model.
|
|
42
|
-
|
|
43
|
-
Storage Strategy:
|
|
44
|
-
- Only rule weights are saved to database (~1 KB)
|
|
45
|
-
- Model is rebuilt from KG data on each load (~30 sec)
|
|
46
|
-
- No external .pt files needed
|
|
47
|
-
"""
|
|
48
|
-
|
|
49
|
-
def __init__(self, db: StandardDatabase):
|
|
50
|
-
"""
|
|
51
|
-
Initialize the bootstrapper.
|
|
52
|
-
|
|
53
|
-
Args:
|
|
54
|
-
db: An already-connected ArangoDB database instance
|
|
55
|
-
"""
|
|
56
|
-
self.db = db
|
|
57
|
-
self._ensure_collection_exists()
|
|
58
|
-
|
|
59
|
-
def _ensure_collection_exists(self):
|
|
60
|
-
"""Ensure OdinModels collection exists."""
|
|
61
|
-
try:
|
|
62
|
-
if not self.db.has_collection(ODIN_MODELS_COLLECTION):
|
|
63
|
-
self.db.create_collection(ODIN_MODELS_COLLECTION)
|
|
64
|
-
logger.info(f"Created {ODIN_MODELS_COLLECTION} collection")
|
|
65
|
-
except Exception as e:
|
|
66
|
-
# Suppress permission errors (common in read-only environments)
|
|
67
|
-
logger.warning(f"Could not verify/create {ODIN_MODELS_COLLECTION} (Permission Error?): {e}")
|
|
68
|
-
|
|
69
|
-
def ensure_model_ready(self, force_retrain: bool = False) -> Optional[NPLLModel]:
|
|
70
|
-
"""
|
|
71
|
-
Ensures a trained NPLL model is available.
|
|
72
|
-
|
|
73
|
-
Flow:
|
|
74
|
-
1. Compute current data hash
|
|
75
|
-
2. Check OdinModels for saved weights with matching hash
|
|
76
|
-
3. If found: rebuild model from KG, apply saved weights
|
|
77
|
-
4. If not found: train new model, save weights to DB
|
|
78
|
-
|
|
79
|
-
Args:
|
|
80
|
-
force_retrain: If True, ignores cached weights and retrains
|
|
81
|
-
|
|
82
|
-
Returns:
|
|
83
|
-
Loaded NPLLModel ready for inference, or None if failed
|
|
84
|
-
"""
|
|
85
|
-
current_hash = self._compute_data_hash()
|
|
86
|
-
logger.info(f"Current data hash: {current_hash[:16]}...")
|
|
87
|
-
|
|
88
|
-
if not force_retrain:
|
|
89
|
-
# Try to load existing weights and rebuild model
|
|
90
|
-
model = self._load_model_with_weights(current_hash)
|
|
91
|
-
if model:
|
|
92
|
-
return model
|
|
93
|
-
|
|
94
|
-
# Train new model
|
|
95
|
-
logger.info("Training new NPLL model...")
|
|
96
|
-
return self._train_and_save_weights(current_hash)
|
|
97
|
-
|
|
98
|
-
def _compute_data_hash(self) -> str:
|
|
99
|
-
"""
|
|
100
|
-
Compute a hash of the current schema to detect when retraining is needed.
|
|
101
|
-
"""
|
|
102
|
-
try:
|
|
103
|
-
# Get relation names
|
|
104
|
-
rel_query = """
|
|
105
|
-
FOR e IN ExtractedRelationships
|
|
106
|
-
COLLECT rel = e.relationship WITH COUNT INTO cnt
|
|
107
|
-
SORT rel
|
|
108
|
-
RETURN {rel: rel, count: cnt}
|
|
109
|
-
"""
|
|
110
|
-
relations = list(self.db.aql.execute(rel_query))
|
|
111
|
-
relation_names = sorted([r['rel'] for r in relations if r['rel']])
|
|
112
|
-
|
|
113
|
-
# Get counts
|
|
114
|
-
entity_count = self.db.collection("ExtractedEntities").count()
|
|
115
|
-
fact_count = self.db.collection("ExtractedRelationships").count()
|
|
116
|
-
|
|
117
|
-
hash_input = {
|
|
118
|
-
"relations": relation_names,
|
|
119
|
-
"entity_count": entity_count,
|
|
120
|
-
"fact_count": fact_count,
|
|
121
|
-
}
|
|
122
|
-
|
|
123
|
-
hash_str = json.dumps(hash_input, sort_keys=True)
|
|
124
|
-
return hashlib.sha256(hash_str.encode()).hexdigest()
|
|
125
|
-
|
|
126
|
-
except Exception as e:
|
|
127
|
-
logger.warning(f"Could not compute data hash: {e}")
|
|
128
|
-
return hashlib.sha256(str(time.time()).encode()).hexdigest()
|
|
129
|
-
|
|
130
|
-
def _load_model_with_weights(self, expected_hash: str) -> Optional[NPLLModel]:
|
|
131
|
-
"""
|
|
132
|
-
Load saved weights from DB and rebuild the model.
|
|
133
|
-
|
|
134
|
-
Flow:
|
|
135
|
-
1. Check if saved weights exist with matching hash
|
|
136
|
-
2. Extract triples from DB → build KG
|
|
137
|
-
3. Generate rules (same code = same rules)
|
|
138
|
-
4. Initialize fresh model
|
|
139
|
-
5. Apply saved weights
|
|
140
|
-
"""
|
|
141
|
-
try:
|
|
142
|
-
collection = self.db.collection(ODIN_MODELS_COLLECTION)
|
|
143
|
-
doc = collection.get(NPLL_MODEL_KEY)
|
|
144
|
-
|
|
145
|
-
if not doc:
|
|
146
|
-
logger.info("No saved weights found in database")
|
|
147
|
-
return None
|
|
148
|
-
|
|
149
|
-
stored_hash = doc.get("data_hash", "")
|
|
150
|
-
if stored_hash != expected_hash:
|
|
151
|
-
logger.info(f"Data has changed. Stored: {stored_hash[:16]}..., Current: {expected_hash[:16]}...")
|
|
152
|
-
return None
|
|
153
|
-
|
|
154
|
-
# Get saved weights
|
|
155
|
-
saved_weights = doc.get("rule_weights")
|
|
156
|
-
if not saved_weights:
|
|
157
|
-
logger.warning("No rule_weights in saved document")
|
|
158
|
-
return None
|
|
159
|
-
|
|
160
|
-
logger.info("Rebuilding model from KG and applying saved weights...")
|
|
161
|
-
|
|
162
|
-
# 1. Extract triples
|
|
163
|
-
triples = self._extract_triples()
|
|
164
|
-
if not triples:
|
|
165
|
-
return None
|
|
166
|
-
|
|
167
|
-
# 2. Build KG
|
|
168
|
-
kg = load_knowledge_graph_from_triples(triples, "ArangoDB_KG")
|
|
169
|
-
|
|
170
|
-
# 3. Generate rules (same code = same rules)
|
|
171
|
-
rules = self._generate_smart_rules(kg)
|
|
172
|
-
|
|
173
|
-
if len(rules) != len(saved_weights):
|
|
174
|
-
logger.warning(f"Rule count mismatch: {len(rules)} rules, {len(saved_weights)} weights. Retraining.")
|
|
175
|
-
return None
|
|
176
|
-
|
|
177
|
-
# 4. Initialize model
|
|
178
|
-
config = get_config("ArangoDB_Triples")
|
|
179
|
-
model = create_initialized_npll_model(kg, rules, config)
|
|
180
|
-
|
|
181
|
-
# 5. Apply saved weights
|
|
182
|
-
with torch.no_grad():
|
|
183
|
-
model.mln.rule_weights.copy_(torch.tensor(saved_weights, dtype=torch.float32))
|
|
184
|
-
|
|
185
|
-
trained_at = doc.get("trained_at", "unknown")
|
|
186
|
-
logger.info(f"✓ Model rebuilt with saved weights (trained: {trained_at})")
|
|
187
|
-
|
|
188
|
-
return model
|
|
189
|
-
|
|
190
|
-
except Exception as e:
|
|
191
|
-
logger.warning(f"Failed to load model: {e}")
|
|
192
|
-
return None
|
|
193
|
-
|
|
194
|
-
def _train_and_save_weights(self, data_hash: str) -> Optional[NPLLModel]:
|
|
195
|
-
"""
|
|
196
|
-
Train a new NPLL model and save ONLY the weights to database.
|
|
197
|
-
"""
|
|
198
|
-
# 1. Extract Triples
|
|
199
|
-
triples = self._extract_triples()
|
|
200
|
-
if not triples:
|
|
201
|
-
logger.error("No triples extracted. Cannot train.")
|
|
202
|
-
return None
|
|
203
|
-
|
|
204
|
-
# 2. Build KG
|
|
205
|
-
kg = load_knowledge_graph_from_triples(triples, "ArangoDB_KG")
|
|
206
|
-
logger.info(f"Built KG: {len(kg.entities)} entities, {len(kg.relations)} relations, {len(kg.known_facts)} facts")
|
|
207
|
-
|
|
208
|
-
# Create unknown facts for training (10%)
|
|
209
|
-
known_facts_list = list(kg.known_facts)
|
|
210
|
-
random.seed(42)
|
|
211
|
-
num_unknown = max(1, len(known_facts_list) // 10)
|
|
212
|
-
unknown_facts = random.sample(known_facts_list, num_unknown)
|
|
213
|
-
|
|
214
|
-
for fact in unknown_facts:
|
|
215
|
-
kg.known_facts.remove(fact)
|
|
216
|
-
kg.add_unknown_fact(fact.head.name, fact.relation.name, fact.tail.name)
|
|
217
|
-
|
|
218
|
-
# 3. Generate Rules
|
|
219
|
-
rules = self._generate_smart_rules(kg)
|
|
220
|
-
logger.info(f"Generated {len(rules)} logical rules")
|
|
221
|
-
|
|
222
|
-
if not rules:
|
|
223
|
-
logger.error("No rules generated. Cannot train.")
|
|
224
|
-
return None
|
|
225
|
-
|
|
226
|
-
# 4. Initialize Model
|
|
227
|
-
config = get_config("ArangoDB_Triples")
|
|
228
|
-
model = create_initialized_npll_model(kg, rules, config)
|
|
229
|
-
|
|
230
|
-
# 5. Train
|
|
231
|
-
train_config = TrainingConfig(
|
|
232
|
-
num_epochs=10,
|
|
233
|
-
max_em_iterations_per_epoch=5,
|
|
234
|
-
early_stopping_patience=3,
|
|
235
|
-
save_checkpoints=False
|
|
236
|
-
)
|
|
237
|
-
trainer = create_trainer(model, train_config)
|
|
238
|
-
|
|
239
|
-
training_result = None
|
|
240
|
-
try:
|
|
241
|
-
logger.info("Starting NPLL training...")
|
|
242
|
-
training_result = trainer.train()
|
|
243
|
-
logger.info(f"Training completed. Final ELBO: {training_result.final_elbo}")
|
|
244
|
-
except Exception as e:
|
|
245
|
-
logger.error(f"Training failed: {e}", exc_info=True)
|
|
246
|
-
return None
|
|
247
|
-
|
|
248
|
-
# 6. Save ONLY weights to database
|
|
249
|
-
self._save_weights_to_db(model, kg, rules, data_hash, training_result)
|
|
250
|
-
|
|
251
|
-
return model
|
|
252
|
-
|
|
253
|
-
def _save_weights_to_db(self, model: NPLLModel, kg: KnowledgeGraph,
|
|
254
|
-
rules: List[LogicalRule], data_hash: str,
|
|
255
|
-
training_result: Any):
|
|
256
|
-
"""
|
|
257
|
-
Save ONLY the learned weights to OdinModels collection.
|
|
258
|
-
This is tiny (~1 KB) compared to the full model (280 MB).
|
|
259
|
-
"""
|
|
260
|
-
try:
|
|
261
|
-
# Extract just the rule weights
|
|
262
|
-
rule_weights = model.mln.rule_weights.detach().cpu().tolist()
|
|
263
|
-
|
|
264
|
-
doc = {
|
|
265
|
-
"_key": NPLL_MODEL_KEY,
|
|
266
|
-
"model_type": "npll",
|
|
267
|
-
"storage_type": "weights_only", # Mark this as weights-only storage
|
|
268
|
-
"trained_at": datetime.utcnow().isoformat() + "Z",
|
|
269
|
-
"data_hash": data_hash,
|
|
270
|
-
"rule_weights": rule_weights, # The learned weights - this is all we need!
|
|
271
|
-
"schema_snapshot": {
|
|
272
|
-
"entity_count": len(kg.entities),
|
|
273
|
-
"relation_count": len(kg.relations),
|
|
274
|
-
"fact_count": len(kg.known_facts),
|
|
275
|
-
"relation_names": sorted([r.name for r in kg.relations])[:50],
|
|
276
|
-
},
|
|
277
|
-
"training_result": {
|
|
278
|
-
"final_elbo": float(training_result.final_elbo) if training_result else 0,
|
|
279
|
-
"best_elbo": float(training_result.best_elbo) if training_result else 0,
|
|
280
|
-
"converged": training_result.converged if training_result else False,
|
|
281
|
-
"training_time_seconds": training_result.total_training_time if training_result else 0,
|
|
282
|
-
},
|
|
283
|
-
"rules": [
|
|
284
|
-
{
|
|
285
|
-
"rule_id": r.rule_id,
|
|
286
|
-
"rule_text": str(r),
|
|
287
|
-
"confidence": r.confidence,
|
|
288
|
-
}
|
|
289
|
-
for r in rules
|
|
290
|
-
],
|
|
291
|
-
"version": "2.0", # Version 2 = weights-only storage
|
|
292
|
-
}
|
|
293
|
-
|
|
294
|
-
# Upsert
|
|
295
|
-
collection = self.db.collection(ODIN_MODELS_COLLECTION)
|
|
296
|
-
if collection.has(NPLL_MODEL_KEY):
|
|
297
|
-
collection.update(doc)
|
|
298
|
-
else:
|
|
299
|
-
collection.insert(doc)
|
|
300
|
-
|
|
301
|
-
weights_size = len(json.dumps(rule_weights))
|
|
302
|
-
logger.info(f"✓ Saved rule weights to database ({weights_size} bytes)")
|
|
303
|
-
|
|
304
|
-
except Exception as e:
|
|
305
|
-
logger.error(f"Failed to save weights: {e}", exc_info=True)
|
|
306
|
-
|
|
307
|
-
def _extract_triples(self) -> List[Tuple[str, str, str]]:
|
|
308
|
-
"""Extracts S-P-O triples from ArangoDB."""
|
|
309
|
-
logger.info("Extracting triples from database...")
|
|
310
|
-
triples = []
|
|
311
|
-
|
|
312
|
-
# Extract Relationships
|
|
313
|
-
query = """
|
|
314
|
-
FOR rel IN ExtractedRelationships
|
|
315
|
-
LET source = DOCUMENT(rel._from)
|
|
316
|
-
LET target = DOCUMENT(rel._to)
|
|
317
|
-
FILTER source != null AND target != null
|
|
318
|
-
FILTER source._key != null AND target._key != null
|
|
319
|
-
RETURN {
|
|
320
|
-
source: source._key,
|
|
321
|
-
target: target._key,
|
|
322
|
-
relation: rel.relationship || "related_to"
|
|
323
|
-
}
|
|
324
|
-
"""
|
|
325
|
-
try:
|
|
326
|
-
cursor = self.db.aql.execute(query)
|
|
327
|
-
for doc in cursor:
|
|
328
|
-
s, t = doc['source'], doc['target']
|
|
329
|
-
r = str(doc['relation']).replace(' ', '_').lower()
|
|
330
|
-
triples.append((s, r, t))
|
|
331
|
-
logger.info(f"Extracted {len(triples)} relationship triples")
|
|
332
|
-
except Exception as e:
|
|
333
|
-
logger.error(f"Extraction error: {e}")
|
|
334
|
-
return []
|
|
335
|
-
|
|
336
|
-
# Extract Entity Types
|
|
337
|
-
query_types = """
|
|
338
|
-
FOR entity IN ExtractedEntities
|
|
339
|
-
FILTER entity._key != null AND entity.type != null
|
|
340
|
-
RETURN { key: entity._key, type: entity.type }
|
|
341
|
-
"""
|
|
342
|
-
try:
|
|
343
|
-
cursor = self.db.aql.execute(query_types)
|
|
344
|
-
type_count = 0
|
|
345
|
-
for doc in cursor:
|
|
346
|
-
triples.append((doc['key'], 'has_type', doc['type']))
|
|
347
|
-
type_count += 1
|
|
348
|
-
logger.info(f"Extracted {type_count} entity type triples")
|
|
349
|
-
except Exception as e:
|
|
350
|
-
logger.error(f"Type extraction error: {e}")
|
|
351
|
-
|
|
352
|
-
logger.info(f"Total triples: {len(triples)}")
|
|
353
|
-
return triples
|
|
354
|
-
|
|
355
|
-
    def _generate_smart_rules(self, kg: KnowledgeGraph) -> List[LogicalRule]:
        """
        Generates domain-appropriate rules based on available relations.

        Rule emission is deterministic in the KG's relation vocabulary:
        rebuilding from the same data yields the same rules in the same
        order, which is what keeps saved weight vectors aligned on reload.

        Args:
            kg: Knowledge graph whose relation names drive rule selection.

        Returns:
            Non-empty list of LogicalRule objects (a fallback rule is
            created when no domain pattern matches).
        """
        rules = []
        # Relation name -> relation object, for pattern lookups below.
        relations = {r.name: r for r in kg.relations}
        x, y, z = Variable("?x"), Variable("?y"), Variable("?z")

        logger.info(f"Generating rules for {len(relations)} relation types...")

        # --- HEALTHCARE DOMAIN ---
        # has_claim(x,y) ∧ submitted_by_provider(y,z) => treated_by(x,z)
        if 'has_claim' in relations and 'submitted_by_provider' in relations and 'treated_by' in relations:
            rules.append(LogicalRule(
                rule_id="hc_claim_provider_link",
                body=[
                    Atom(relations['has_claim'], (x, y)),
                    Atom(relations['submitted_by_provider'], (y, z))
                ],
                head=Atom(relations['treated_by'], (x, z)),
                confidence=0.7
            ))
            logger.info(" + Added: hc_claim_provider_link")

        # diagnosed_with(x,y) ∧ indicates(y,z) => recommended_procedure(x,z)
        # (falls back to the generic related_to head if the specific
        # relation is absent from this KG)
        if 'diagnosed_with' in relations and 'indicates' in relations:
            target_rel = relations.get('recommended_procedure') or relations.get('related_to')
            if target_rel:
                rules.append(LogicalRule(
                    rule_id="hc_diagnosis_procedure",
                    body=[
                        Atom(relations['diagnosed_with'], (x, y)),
                        Atom(relations['indicates'], (y, z))
                    ],
                    head=Atom(target_rel, (x, z)),
                    confidence=0.6
                ))
                logger.info(" + Added: hc_diagnosis_procedure")

        # works_at(x,y) ∧ located_at(y,z) => affiliated_with(x,z)
        # (same related_to fallback as above)
        if 'works_at' in relations and 'located_at' in relations:
            target_rel = relations.get('affiliated_with') or relations.get('related_to')
            if target_rel:
                rules.append(LogicalRule(
                    rule_id="hc_provider_facility",
                    body=[
                        Atom(relations['works_at'], (x, y)),
                        Atom(relations['located_at'], (y, z))
                    ],
                    head=Atom(target_rel, (x, z)),
                    confidence=0.6
                ))
                logger.info(" + Added: hc_provider_facility")

        # --- INSURANCE DOMAIN ---
        # policyholder(x,y) ∧ claim_number(x,z) => related_to(y,z)
        if 'policyholder' in relations and 'claim_number' in relations and 'related_to' in relations:
            rules.append(LogicalRule(
                rule_id="ins_policy_claim",
                body=[
                    Atom(relations['policyholder'], (x, y)),
                    Atom(relations['claim_number'], (x, z))
                ],
                head=Atom(relations['related_to'], (y, z)),
                confidence=0.8
            ))
            logger.info(" + Added: ins_policy_claim")

        # assessor(x,y) ∧ insurer(z,y) => related_to(x,z)
        # (note the shared second argument y linking assessor and insurer)
        if 'assessor' in relations and 'insurer' in relations and 'related_to' in relations:
            rules.append(LogicalRule(
                rule_id="ins_assessor_insurer",
                body=[
                    Atom(relations['assessor'], (x, y)),
                    Atom(relations['insurer'], (z, y))
                ],
                head=Atom(relations['related_to'], (x, z)),
                confidence=0.7
            ))
            logger.info(" + Added: ins_assessor_insurer")

        # --- GENERIC RULES ---
        # Transitivity of related_to; low confidence since it over-generates.
        if 'related_to' in relations:
            rules.append(LogicalRule(
                rule_id="gen_transitivity",
                body=[
                    Atom(relations['related_to'], (x, y)),
                    Atom(relations['related_to'], (y, z))
                ],
                head=Atom(relations['related_to'], (x, z)),
                rule_type=RuleType.TRANSITIVITY,
                confidence=0.5
            ))
            logger.info(" + Added: gen_transitivity")

        # Entities sharing a type are weakly related (lowest confidence).
        if 'has_type' in relations and 'related_to' in relations:
            rules.append(LogicalRule(
                rule_id="gen_type_cooccurrence",
                body=[
                    Atom(relations['has_type'], (x, y)),
                    Atom(relations['has_type'], (z, y))
                ],
                head=Atom(relations['related_to'], (x, z)),
                confidence=0.3
            ))
            logger.info(" + Added: gen_type_cooccurrence")

        # Fallback: guarantee at least one rule so the MLN can be built.
        # Uses an arbitrary relation as a trivial self-implication.
        if not rules:
            logger.warning("No domain rules matched. Creating fallback.")
            rel = next(iter(kg.relations))
            rules.append(LogicalRule(
                rule_id="fallback_self",
                body=[Atom(rel, (x, y))],
                head=Atom(rel, (x, y)),
                confidence=0.5
            ))

        logger.info(f"Total rules: {len(rules)}")
        return rules
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
def create_bootstrapper(db: StandardDatabase) -> KnowledgeBootstrapper:
    """Factory: build a KnowledgeBootstrapper bound to *db*."""
    bootstrapper = KnowledgeBootstrapper(db)
    return bootstrapper
|
|
1
|
+
"""
|
|
2
|
+
Bootstrap module for NPLL.
|
|
3
|
+
Handles the end-to-end lifecycle of the NPLL model:
|
|
4
|
+
1. Extracting data from ArangoDB
|
|
5
|
+
2. Generating domain-appropriate logical rules
|
|
6
|
+
3. Training the model
|
|
7
|
+
4. Storing ONLY WEIGHTS in database (not full model)
|
|
8
|
+
|
|
9
|
+
Architecture:
|
|
10
|
+
- Weights stored in OdinModels collection (~1 KB)
|
|
11
|
+
- Model rebuilt from KG on each load (~30 sec)
|
|
12
|
+
- No external files needed
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import hashlib
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import random
|
|
20
|
+
import time
|
|
21
|
+
import torch
|
|
22
|
+
from datetime import datetime
|
|
23
|
+
from typing import List, Tuple, Dict, Optional, Any
|
|
24
|
+
from arango.database import StandardDatabase
|
|
25
|
+
|
|
26
|
+
from .core.knowledge_graph import KnowledgeGraph, load_knowledge_graph_from_triples
|
|
27
|
+
from .core.logical_rules import LogicalRule, Atom, Variable, RuleType
|
|
28
|
+
from .npll_model import create_initialized_npll_model, NPLLModel
|
|
29
|
+
from .training.npll_trainer import TrainingConfig, create_trainer
|
|
30
|
+
from .utils.config import get_config
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
# Collection name for storing model weights
|
|
35
|
+
ODIN_MODELS_COLLECTION = "OdinModels"
|
|
36
|
+
NPLL_MODEL_KEY = "npll_current"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class KnowledgeBootstrapper:
|
|
40
|
+
"""
|
|
41
|
+
Manages the lifecycle of the NPLL model.
|
|
42
|
+
|
|
43
|
+
Storage Strategy:
|
|
44
|
+
- Only rule weights are saved to database (~1 KB)
|
|
45
|
+
- Model is rebuilt from KG data on each load (~30 sec)
|
|
46
|
+
- No external .pt files needed
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
def __init__(self, db: StandardDatabase):
|
|
50
|
+
"""
|
|
51
|
+
Initialize the bootstrapper.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
db: An already-connected ArangoDB database instance
|
|
55
|
+
"""
|
|
56
|
+
self.db = db
|
|
57
|
+
self._ensure_collection_exists()
|
|
58
|
+
|
|
59
|
+
def _ensure_collection_exists(self):
|
|
60
|
+
"""Ensure OdinModels collection exists."""
|
|
61
|
+
try:
|
|
62
|
+
if not self.db.has_collection(ODIN_MODELS_COLLECTION):
|
|
63
|
+
self.db.create_collection(ODIN_MODELS_COLLECTION)
|
|
64
|
+
logger.info(f"Created {ODIN_MODELS_COLLECTION} collection")
|
|
65
|
+
except Exception as e:
|
|
66
|
+
# Suppress permission errors (common in read-only environments)
|
|
67
|
+
logger.warning(f"Could not verify/create {ODIN_MODELS_COLLECTION} (Permission Error?): {e}")
|
|
68
|
+
|
|
69
|
+
def ensure_model_ready(self, force_retrain: bool = False) -> Optional[NPLLModel]:
|
|
70
|
+
"""
|
|
71
|
+
Ensures a trained NPLL model is available.
|
|
72
|
+
|
|
73
|
+
Flow:
|
|
74
|
+
1. Compute current data hash
|
|
75
|
+
2. Check OdinModels for saved weights with matching hash
|
|
76
|
+
3. If found: rebuild model from KG, apply saved weights
|
|
77
|
+
4. If not found: train new model, save weights to DB
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
force_retrain: If True, ignores cached weights and retrains
|
|
81
|
+
|
|
82
|
+
Returns:
|
|
83
|
+
Loaded NPLLModel ready for inference, or None if failed
|
|
84
|
+
"""
|
|
85
|
+
current_hash = self._compute_data_hash()
|
|
86
|
+
logger.info(f"Current data hash: {current_hash[:16]}...")
|
|
87
|
+
|
|
88
|
+
if not force_retrain:
|
|
89
|
+
# Try to load existing weights and rebuild model
|
|
90
|
+
model = self._load_model_with_weights(current_hash)
|
|
91
|
+
if model:
|
|
92
|
+
return model
|
|
93
|
+
|
|
94
|
+
# Train new model
|
|
95
|
+
logger.info("Training new NPLL model...")
|
|
96
|
+
return self._train_and_save_weights(current_hash)
|
|
97
|
+
|
|
98
|
+
def _compute_data_hash(self) -> str:
|
|
99
|
+
"""
|
|
100
|
+
Compute a hash of the current schema to detect when retraining is needed.
|
|
101
|
+
"""
|
|
102
|
+
try:
|
|
103
|
+
# Get relation names
|
|
104
|
+
rel_query = """
|
|
105
|
+
FOR e IN ExtractedRelationships
|
|
106
|
+
COLLECT rel = e.relationship WITH COUNT INTO cnt
|
|
107
|
+
SORT rel
|
|
108
|
+
RETURN {rel: rel, count: cnt}
|
|
109
|
+
"""
|
|
110
|
+
relations = list(self.db.aql.execute(rel_query))
|
|
111
|
+
relation_names = sorted([r['rel'] for r in relations if r['rel']])
|
|
112
|
+
|
|
113
|
+
# Get counts
|
|
114
|
+
entity_count = self.db.collection("ExtractedEntities").count()
|
|
115
|
+
fact_count = self.db.collection("ExtractedRelationships").count()
|
|
116
|
+
|
|
117
|
+
hash_input = {
|
|
118
|
+
"relations": relation_names,
|
|
119
|
+
"entity_count": entity_count,
|
|
120
|
+
"fact_count": fact_count,
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
hash_str = json.dumps(hash_input, sort_keys=True)
|
|
124
|
+
return hashlib.sha256(hash_str.encode()).hexdigest()
|
|
125
|
+
|
|
126
|
+
except Exception as e:
|
|
127
|
+
logger.warning(f"Could not compute data hash: {e}")
|
|
128
|
+
return hashlib.sha256(str(time.time()).encode()).hexdigest()
|
|
129
|
+
|
|
130
|
+
def _load_model_with_weights(self, expected_hash: str) -> Optional[NPLLModel]:
|
|
131
|
+
"""
|
|
132
|
+
Load saved weights from DB and rebuild the model.
|
|
133
|
+
|
|
134
|
+
Flow:
|
|
135
|
+
1. Check if saved weights exist with matching hash
|
|
136
|
+
2. Extract triples from DB → build KG
|
|
137
|
+
3. Generate rules (same code = same rules)
|
|
138
|
+
4. Initialize fresh model
|
|
139
|
+
5. Apply saved weights
|
|
140
|
+
"""
|
|
141
|
+
try:
|
|
142
|
+
collection = self.db.collection(ODIN_MODELS_COLLECTION)
|
|
143
|
+
doc = collection.get(NPLL_MODEL_KEY)
|
|
144
|
+
|
|
145
|
+
if not doc:
|
|
146
|
+
logger.info("No saved weights found in database")
|
|
147
|
+
return None
|
|
148
|
+
|
|
149
|
+
stored_hash = doc.get("data_hash", "")
|
|
150
|
+
if stored_hash != expected_hash:
|
|
151
|
+
logger.info(f"Data has changed. Stored: {stored_hash[:16]}..., Current: {expected_hash[:16]}...")
|
|
152
|
+
return None
|
|
153
|
+
|
|
154
|
+
# Get saved weights
|
|
155
|
+
saved_weights = doc.get("rule_weights")
|
|
156
|
+
if not saved_weights:
|
|
157
|
+
logger.warning("No rule_weights in saved document")
|
|
158
|
+
return None
|
|
159
|
+
|
|
160
|
+
logger.info("Rebuilding model from KG and applying saved weights...")
|
|
161
|
+
|
|
162
|
+
# 1. Extract triples
|
|
163
|
+
triples = self._extract_triples()
|
|
164
|
+
if not triples:
|
|
165
|
+
return None
|
|
166
|
+
|
|
167
|
+
# 2. Build KG
|
|
168
|
+
kg = load_knowledge_graph_from_triples(triples, "ArangoDB_KG")
|
|
169
|
+
|
|
170
|
+
# 3. Generate rules (same code = same rules)
|
|
171
|
+
rules = self._generate_smart_rules(kg)
|
|
172
|
+
|
|
173
|
+
if len(rules) != len(saved_weights):
|
|
174
|
+
logger.warning(f"Rule count mismatch: {len(rules)} rules, {len(saved_weights)} weights. Retraining.")
|
|
175
|
+
return None
|
|
176
|
+
|
|
177
|
+
# 4. Initialize model
|
|
178
|
+
config = get_config("ArangoDB_Triples")
|
|
179
|
+
model = create_initialized_npll_model(kg, rules, config)
|
|
180
|
+
|
|
181
|
+
# 5. Apply saved weights
|
|
182
|
+
with torch.no_grad():
|
|
183
|
+
model.mln.rule_weights.copy_(torch.tensor(saved_weights, dtype=torch.float32))
|
|
184
|
+
|
|
185
|
+
trained_at = doc.get("trained_at", "unknown")
|
|
186
|
+
logger.info(f"✓ Model rebuilt with saved weights (trained: {trained_at})")
|
|
187
|
+
|
|
188
|
+
return model
|
|
189
|
+
|
|
190
|
+
except Exception as e:
|
|
191
|
+
logger.warning(f"Failed to load model: {e}")
|
|
192
|
+
return None
|
|
193
|
+
|
|
194
|
+
def _train_and_save_weights(self, data_hash: str) -> Optional[NPLLModel]:
|
|
195
|
+
"""
|
|
196
|
+
Train a new NPLL model and save ONLY the weights to database.
|
|
197
|
+
"""
|
|
198
|
+
# 1. Extract Triples
|
|
199
|
+
triples = self._extract_triples()
|
|
200
|
+
if not triples:
|
|
201
|
+
logger.error("No triples extracted. Cannot train.")
|
|
202
|
+
return None
|
|
203
|
+
|
|
204
|
+
# 2. Build KG
|
|
205
|
+
kg = load_knowledge_graph_from_triples(triples, "ArangoDB_KG")
|
|
206
|
+
logger.info(f"Built KG: {len(kg.entities)} entities, {len(kg.relations)} relations, {len(kg.known_facts)} facts")
|
|
207
|
+
|
|
208
|
+
# Create unknown facts for training (10%)
|
|
209
|
+
known_facts_list = list(kg.known_facts)
|
|
210
|
+
random.seed(42)
|
|
211
|
+
num_unknown = max(1, len(known_facts_list) // 10)
|
|
212
|
+
unknown_facts = random.sample(known_facts_list, num_unknown)
|
|
213
|
+
|
|
214
|
+
for fact in unknown_facts:
|
|
215
|
+
kg.known_facts.remove(fact)
|
|
216
|
+
kg.add_unknown_fact(fact.head.name, fact.relation.name, fact.tail.name)
|
|
217
|
+
|
|
218
|
+
# 3. Generate Rules
|
|
219
|
+
rules = self._generate_smart_rules(kg)
|
|
220
|
+
logger.info(f"Generated {len(rules)} logical rules")
|
|
221
|
+
|
|
222
|
+
if not rules:
|
|
223
|
+
logger.error("No rules generated. Cannot train.")
|
|
224
|
+
return None
|
|
225
|
+
|
|
226
|
+
# 4. Initialize Model
|
|
227
|
+
config = get_config("ArangoDB_Triples")
|
|
228
|
+
model = create_initialized_npll_model(kg, rules, config)
|
|
229
|
+
|
|
230
|
+
# 5. Train
|
|
231
|
+
train_config = TrainingConfig(
|
|
232
|
+
num_epochs=10,
|
|
233
|
+
max_em_iterations_per_epoch=5,
|
|
234
|
+
early_stopping_patience=3,
|
|
235
|
+
save_checkpoints=False
|
|
236
|
+
)
|
|
237
|
+
trainer = create_trainer(model, train_config)
|
|
238
|
+
|
|
239
|
+
training_result = None
|
|
240
|
+
try:
|
|
241
|
+
logger.info("Starting NPLL training...")
|
|
242
|
+
training_result = trainer.train()
|
|
243
|
+
logger.info(f"Training completed. Final ELBO: {training_result.final_elbo}")
|
|
244
|
+
except Exception as e:
|
|
245
|
+
logger.error(f"Training failed: {e}", exc_info=True)
|
|
246
|
+
return None
|
|
247
|
+
|
|
248
|
+
# 6. Save ONLY weights to database
|
|
249
|
+
self._save_weights_to_db(model, kg, rules, data_hash, training_result)
|
|
250
|
+
|
|
251
|
+
return model
|
|
252
|
+
|
|
253
|
+
def _save_weights_to_db(self, model: NPLLModel, kg: KnowledgeGraph,
|
|
254
|
+
rules: List[LogicalRule], data_hash: str,
|
|
255
|
+
training_result: Any):
|
|
256
|
+
"""
|
|
257
|
+
Save ONLY the learned weights to OdinModels collection.
|
|
258
|
+
This is tiny (~1 KB) compared to the full model (280 MB).
|
|
259
|
+
"""
|
|
260
|
+
try:
|
|
261
|
+
# Extract just the rule weights
|
|
262
|
+
rule_weights = model.mln.rule_weights.detach().cpu().tolist()
|
|
263
|
+
|
|
264
|
+
doc = {
|
|
265
|
+
"_key": NPLL_MODEL_KEY,
|
|
266
|
+
"model_type": "npll",
|
|
267
|
+
"storage_type": "weights_only", # Mark this as weights-only storage
|
|
268
|
+
"trained_at": datetime.utcnow().isoformat() + "Z",
|
|
269
|
+
"data_hash": data_hash,
|
|
270
|
+
"rule_weights": rule_weights, # The learned weights - this is all we need!
|
|
271
|
+
"schema_snapshot": {
|
|
272
|
+
"entity_count": len(kg.entities),
|
|
273
|
+
"relation_count": len(kg.relations),
|
|
274
|
+
"fact_count": len(kg.known_facts),
|
|
275
|
+
"relation_names": sorted([r.name for r in kg.relations])[:50],
|
|
276
|
+
},
|
|
277
|
+
"training_result": {
|
|
278
|
+
"final_elbo": float(training_result.final_elbo) if training_result else 0,
|
|
279
|
+
"best_elbo": float(training_result.best_elbo) if training_result else 0,
|
|
280
|
+
"converged": training_result.converged if training_result else False,
|
|
281
|
+
"training_time_seconds": training_result.total_training_time if training_result else 0,
|
|
282
|
+
},
|
|
283
|
+
"rules": [
|
|
284
|
+
{
|
|
285
|
+
"rule_id": r.rule_id,
|
|
286
|
+
"rule_text": str(r),
|
|
287
|
+
"confidence": r.confidence,
|
|
288
|
+
}
|
|
289
|
+
for r in rules
|
|
290
|
+
],
|
|
291
|
+
"version": "2.0", # Version 2 = weights-only storage
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
# Upsert
|
|
295
|
+
collection = self.db.collection(ODIN_MODELS_COLLECTION)
|
|
296
|
+
if collection.has(NPLL_MODEL_KEY):
|
|
297
|
+
collection.update(doc)
|
|
298
|
+
else:
|
|
299
|
+
collection.insert(doc)
|
|
300
|
+
|
|
301
|
+
weights_size = len(json.dumps(rule_weights))
|
|
302
|
+
logger.info(f"✓ Saved rule weights to database ({weights_size} bytes)")
|
|
303
|
+
|
|
304
|
+
except Exception as e:
|
|
305
|
+
logger.error(f"Failed to save weights: {e}", exc_info=True)
|
|
306
|
+
|
|
307
|
+
def _extract_triples(self) -> List[Tuple[str, str, str]]:
|
|
308
|
+
"""Extracts S-P-O triples from ArangoDB."""
|
|
309
|
+
logger.info("Extracting triples from database...")
|
|
310
|
+
triples = []
|
|
311
|
+
|
|
312
|
+
# Extract Relationships
|
|
313
|
+
query = """
|
|
314
|
+
FOR rel IN ExtractedRelationships
|
|
315
|
+
LET source = DOCUMENT(rel._from)
|
|
316
|
+
LET target = DOCUMENT(rel._to)
|
|
317
|
+
FILTER source != null AND target != null
|
|
318
|
+
FILTER source._key != null AND target._key != null
|
|
319
|
+
RETURN {
|
|
320
|
+
source: source._key,
|
|
321
|
+
target: target._key,
|
|
322
|
+
relation: rel.relationship || "related_to"
|
|
323
|
+
}
|
|
324
|
+
"""
|
|
325
|
+
try:
|
|
326
|
+
cursor = self.db.aql.execute(query)
|
|
327
|
+
for doc in cursor:
|
|
328
|
+
s, t = doc['source'], doc['target']
|
|
329
|
+
r = str(doc['relation']).replace(' ', '_').lower()
|
|
330
|
+
triples.append((s, r, t))
|
|
331
|
+
logger.info(f"Extracted {len(triples)} relationship triples")
|
|
332
|
+
except Exception as e:
|
|
333
|
+
logger.error(f"Extraction error: {e}")
|
|
334
|
+
return []
|
|
335
|
+
|
|
336
|
+
# Extract Entity Types
|
|
337
|
+
query_types = """
|
|
338
|
+
FOR entity IN ExtractedEntities
|
|
339
|
+
FILTER entity._key != null AND entity.type != null
|
|
340
|
+
RETURN { key: entity._key, type: entity.type }
|
|
341
|
+
"""
|
|
342
|
+
try:
|
|
343
|
+
cursor = self.db.aql.execute(query_types)
|
|
344
|
+
type_count = 0
|
|
345
|
+
for doc in cursor:
|
|
346
|
+
triples.append((doc['key'], 'has_type', doc['type']))
|
|
347
|
+
type_count += 1
|
|
348
|
+
logger.info(f"Extracted {type_count} entity type triples")
|
|
349
|
+
except Exception as e:
|
|
350
|
+
logger.error(f"Type extraction error: {e}")
|
|
351
|
+
|
|
352
|
+
logger.info(f"Total triples: {len(triples)}")
|
|
353
|
+
return triples
|
|
354
|
+
|
|
355
|
+
def _generate_smart_rules(self, kg: KnowledgeGraph) -> List[LogicalRule]:
|
|
356
|
+
"""
|
|
357
|
+
Generates domain-appropriate rules based on available relations.
|
|
358
|
+
"""
|
|
359
|
+
rules = []
|
|
360
|
+
relations = {r.name: r for r in kg.relations}
|
|
361
|
+
x, y, z = Variable("?x"), Variable("?y"), Variable("?z")
|
|
362
|
+
|
|
363
|
+
logger.info(f"Generating rules for {len(relations)} relation types...")
|
|
364
|
+
|
|
365
|
+
# --- HEALTHCARE DOMAIN ---
|
|
366
|
+
if 'has_claim' in relations and 'submitted_by_provider' in relations and 'treated_by' in relations:
|
|
367
|
+
rules.append(LogicalRule(
|
|
368
|
+
rule_id="hc_claim_provider_link",
|
|
369
|
+
body=[
|
|
370
|
+
Atom(relations['has_claim'], (x, y)),
|
|
371
|
+
Atom(relations['submitted_by_provider'], (y, z))
|
|
372
|
+
],
|
|
373
|
+
head=Atom(relations['treated_by'], (x, z)),
|
|
374
|
+
confidence=0.7
|
|
375
|
+
))
|
|
376
|
+
logger.info(" + Added: hc_claim_provider_link")
|
|
377
|
+
|
|
378
|
+
if 'diagnosed_with' in relations and 'indicates' in relations:
|
|
379
|
+
target_rel = relations.get('recommended_procedure') or relations.get('related_to')
|
|
380
|
+
if target_rel:
|
|
381
|
+
rules.append(LogicalRule(
|
|
382
|
+
rule_id="hc_diagnosis_procedure",
|
|
383
|
+
body=[
|
|
384
|
+
Atom(relations['diagnosed_with'], (x, y)),
|
|
385
|
+
Atom(relations['indicates'], (y, z))
|
|
386
|
+
],
|
|
387
|
+
head=Atom(target_rel, (x, z)),
|
|
388
|
+
confidence=0.6
|
|
389
|
+
))
|
|
390
|
+
logger.info(" + Added: hc_diagnosis_procedure")
|
|
391
|
+
|
|
392
|
+
if 'works_at' in relations and 'located_at' in relations:
|
|
393
|
+
target_rel = relations.get('affiliated_with') or relations.get('related_to')
|
|
394
|
+
if target_rel:
|
|
395
|
+
rules.append(LogicalRule(
|
|
396
|
+
rule_id="hc_provider_facility",
|
|
397
|
+
body=[
|
|
398
|
+
Atom(relations['works_at'], (x, y)),
|
|
399
|
+
Atom(relations['located_at'], (y, z))
|
|
400
|
+
],
|
|
401
|
+
head=Atom(target_rel, (x, z)),
|
|
402
|
+
confidence=0.6
|
|
403
|
+
))
|
|
404
|
+
logger.info(" + Added: hc_provider_facility")
|
|
405
|
+
|
|
406
|
+
# --- INSURANCE DOMAIN ---
|
|
407
|
+
if 'policyholder' in relations and 'claim_number' in relations and 'related_to' in relations:
|
|
408
|
+
rules.append(LogicalRule(
|
|
409
|
+
rule_id="ins_policy_claim",
|
|
410
|
+
body=[
|
|
411
|
+
Atom(relations['policyholder'], (x, y)),
|
|
412
|
+
Atom(relations['claim_number'], (x, z))
|
|
413
|
+
],
|
|
414
|
+
head=Atom(relations['related_to'], (y, z)),
|
|
415
|
+
confidence=0.8
|
|
416
|
+
))
|
|
417
|
+
logger.info(" + Added: ins_policy_claim")
|
|
418
|
+
|
|
419
|
+
if 'assessor' in relations and 'insurer' in relations and 'related_to' in relations:
|
|
420
|
+
rules.append(LogicalRule(
|
|
421
|
+
rule_id="ins_assessor_insurer",
|
|
422
|
+
body=[
|
|
423
|
+
Atom(relations['assessor'], (x, y)),
|
|
424
|
+
Atom(relations['insurer'], (z, y))
|
|
425
|
+
],
|
|
426
|
+
head=Atom(relations['related_to'], (x, z)),
|
|
427
|
+
confidence=0.7
|
|
428
|
+
))
|
|
429
|
+
logger.info(" + Added: ins_assessor_insurer")
|
|
430
|
+
|
|
431
|
+
# --- GENERIC RULES ---
|
|
432
|
+
if 'related_to' in relations:
|
|
433
|
+
rules.append(LogicalRule(
|
|
434
|
+
rule_id="gen_transitivity",
|
|
435
|
+
body=[
|
|
436
|
+
Atom(relations['related_to'], (x, y)),
|
|
437
|
+
Atom(relations['related_to'], (y, z))
|
|
438
|
+
],
|
|
439
|
+
head=Atom(relations['related_to'], (x, z)),
|
|
440
|
+
rule_type=RuleType.TRANSITIVITY,
|
|
441
|
+
confidence=0.5
|
|
442
|
+
))
|
|
443
|
+
logger.info(" + Added: gen_transitivity")
|
|
444
|
+
|
|
445
|
+
if 'has_type' in relations and 'related_to' in relations:
|
|
446
|
+
rules.append(LogicalRule(
|
|
447
|
+
rule_id="gen_type_cooccurrence",
|
|
448
|
+
body=[
|
|
449
|
+
Atom(relations['has_type'], (x, y)),
|
|
450
|
+
Atom(relations['has_type'], (z, y))
|
|
451
|
+
],
|
|
452
|
+
head=Atom(relations['related_to'], (x, z)),
|
|
453
|
+
confidence=0.3
|
|
454
|
+
))
|
|
455
|
+
logger.info(" + Added: gen_type_cooccurrence")
|
|
456
|
+
|
|
457
|
+
# Fallback
|
|
458
|
+
if not rules:
|
|
459
|
+
logger.warning("No domain rules matched. Creating fallback.")
|
|
460
|
+
rel = next(iter(kg.relations))
|
|
461
|
+
rules.append(LogicalRule(
|
|
462
|
+
rule_id="fallback_self",
|
|
463
|
+
body=[Atom(rel, (x, y))],
|
|
464
|
+
head=Atom(rel, (x, y)),
|
|
465
|
+
confidence=0.5
|
|
466
|
+
))
|
|
467
|
+
|
|
468
|
+
logger.info(f"Total rules: {len(rules)}")
|
|
469
|
+
return rules
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def create_bootstrapper(db: StandardDatabase) -> KnowledgeBootstrapper:
    """Factory: build a KnowledgeBootstrapper bound to the given database.

    Args:
        db: Connected ArangoDB database handle.

    Returns:
        A fresh KnowledgeBootstrapper instance.
    """
    bootstrapper = KnowledgeBootstrapper(db)
    return bootstrapper
|