odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. benchmarks/__init__.py +17 -17
  2. benchmarks/datasets.py +284 -284
  3. benchmarks/metrics.py +275 -275
  4. benchmarks/run_ablation.py +279 -279
  5. benchmarks/run_npll_benchmark.py +270 -270
  6. npll/__init__.py +10 -10
  7. npll/bootstrap.py +474 -474
  8. npll/core/__init__.py +33 -33
  9. npll/core/knowledge_graph.py +308 -308
  10. npll/core/logical_rules.py +496 -496
  11. npll/core/mln.py +474 -474
  12. npll/inference/__init__.py +40 -40
  13. npll/inference/e_step.py +419 -419
  14. npll/inference/elbo.py +434 -434
  15. npll/inference/m_step.py +576 -576
  16. npll/npll_model.py +631 -631
  17. npll/scoring/__init__.py +42 -42
  18. npll/scoring/embeddings.py +441 -441
  19. npll/scoring/probability.py +402 -402
  20. npll/scoring/scoring_module.py +369 -369
  21. npll/training/__init__.py +24 -24
  22. npll/training/evaluation.py +496 -496
  23. npll/training/npll_trainer.py +520 -520
  24. npll/utils/__init__.py +47 -47
  25. npll/utils/batch_utils.py +492 -492
  26. npll/utils/config.py +144 -144
  27. npll/utils/math_utils.py +338 -338
  28. odin/__init__.py +21 -20
  29. odin/engine.py +264 -264
  30. odin/schema.py +210 -0
  31. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
  32. odin_engine-0.2.0.dist-info/RECORD +63 -0
  33. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
  34. retrieval/__init__.py +50 -50
  35. retrieval/adapters.py +140 -140
  36. retrieval/adapters_arango.py +1418 -1418
  37. retrieval/aggregators.py +707 -707
  38. retrieval/beam.py +127 -127
  39. retrieval/budget.py +60 -60
  40. retrieval/cache.py +159 -159
  41. retrieval/confidence.py +88 -88
  42. retrieval/eval.py +49 -49
  43. retrieval/linker.py +87 -87
  44. retrieval/metrics.py +105 -105
  45. retrieval/metrics_motifs.py +36 -36
  46. retrieval/orchestrator.py +571 -571
  47. retrieval/ppr/__init__.py +12 -12
  48. retrieval/ppr/anchors.py +41 -41
  49. retrieval/ppr/bippr.py +61 -61
  50. retrieval/ppr/engines.py +257 -257
  51. retrieval/ppr/global_pr.py +76 -76
  52. retrieval/ppr/indexes.py +78 -78
  53. retrieval/ppr.py +156 -156
  54. retrieval/ppr_cache.py +25 -25
  55. retrieval/scoring.py +294 -294
  56. retrieval/utils/pii_redaction.py +36 -36
  57. retrieval/writers/__init__.py +9 -9
  58. retrieval/writers/arango_writer.py +28 -28
  59. retrieval/writers/base.py +21 -21
  60. retrieval/writers/janus_writer.py +36 -36
  61. odin_engine-0.1.0.dist-info/RECORD +0 -62
  62. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
  63. {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,497 +1,497 @@
1
- """
2
- Logical Rules and Ground Rules for NPLL
3
- """
4
-
5
- from typing import List, Set, Dict, Tuple, Optional, Iterator, Any
6
- from dataclasses import dataclass, field
7
- from enum import Enum
8
- import itertools
9
- import logging
10
- from collections import defaultdict
11
-
12
- from .knowledge_graph import Entity, Relation, Triple, KnowledgeGraph
13
-
14
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
15
-
16
-
17
class Variable:
    """
    Named logical variable (such as ``x``) used as an atom argument.

    Equality holds only between ``Variable`` instances with equal names, and
    the hash is the hash of the name. Variables can therefore key the
    substitution dictionaries used during grounding.
    """

    def __init__(self, name: str):
        self.name = name

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return "Var({})".format(self.name)

    def __eq__(self, other) -> bool:
        if not isinstance(other, Variable):
            return False
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)
34
-
35
-
36
@dataclass(frozen=True)
class Atom:
    """
    Atomic formula in first-order logic: ``Pred(arg1, ..., argN)``.

    ``arguments`` may contain only ``Variable`` and ``Entity`` objects.
    ``__post_init__`` accepts any iterable and normalises it to a tuple, which
    keeps the frozen dataclass hashable. A list here would make
    ``hash(atom)`` raise ``TypeError``.

    Raises:
        ValueError: if an argument is neither a ``Variable`` nor an ``Entity``.
    """
    predicate: Relation
    arguments: Tuple[Any, ...]

    def __post_init__(self):
        # The dataclass is frozen, so plain attribute assignment would raise;
        # object.__setattr__ is the documented way to adjust a field here.
        if not isinstance(self.arguments, tuple):
            object.__setattr__(self, "arguments", tuple(self.arguments))
        for arg in self.arguments:
            if not isinstance(arg, (Variable, Entity)):
                raise ValueError(f"Atom arguments must be Variable or Entity, got {type(arg)}")

    def __str__(self) -> str:
        args_str = ", ".join(str(arg) for arg in self.arguments)
        return f"{self.predicate.name}({args_str})"

    def is_ground(self) -> bool:
        """Return True if no argument is a variable (vacuously True for zero arguments)."""
        return all(isinstance(arg, Entity) for arg in self.arguments)

    def get_variables(self) -> Set[Variable]:
        """Return the set of variables among the arguments."""
        return {arg for arg in self.arguments if isinstance(arg, Variable)}

    def ground_with_substitution(self, substitution: Dict[Variable, Entity]) -> 'Atom':
        """
        Return a new atom with every variable found in ``substitution`` replaced.

        Variables missing from ``substitution`` are kept as they are, so the
        result may still be non-ground.
        """
        new_args = tuple(
            substitution.get(arg, arg) if isinstance(arg, Variable) else arg
            for arg in self.arguments
        )
        return Atom(predicate=self.predicate, arguments=new_args)

    def to_triple(self) -> Optional[Triple]:
        """Return ``Triple(head, predicate, tail)`` for a ground binary atom, else None."""
        if len(self.arguments) == 2 and self.is_ground():
            head, tail = self.arguments
            return Triple(head=head, relation=self.predicate, tail=tail)
        return None
79
-
80
-
81
class RuleType(Enum):
    """
    Category tag attached to a logical rule.

    Members (value in parentheses):
        HORN_CLAUSE ("horn")            standard Horn clause
        EQUALITY ("equality")           equality rule
        TRANSITIVITY ("transitivity")   transitive rule
        SYMMETRY ("symmetry")           symmetric rule
        GENERAL ("general")             any other first-order rule
    """

    HORN_CLAUSE = "horn"
    EQUALITY = "equality"
    TRANSITIVITY = "transitivity"
    SYMMETRY = "symmetry"
    GENERAL = "general"
88
-
89
-
90
@dataclass
class LogicalRule:
    """
    First-order logical rule: ``Body ⇒ Head``.

    Fields:
        rule_id: identifier of the rule.
        body: premise atoms, read as a conjunction; must be non-empty.
        head: conclusion atom; must be an ``Atom``.
        rule_type: category tag (defaults to ``RuleType.GENERAL``).
        confidence: initial confidence score (defaults to 0.5).
        support: number of supporting instances (defaults to 0).
    """

    rule_id: str
    body: List[Atom]  # Body atoms (premise)
    head: Atom  # Head atom (conclusion)
    rule_type: RuleType = RuleType.GENERAL
    confidence: float = 0.5  # Initial confidence score
    support: int = 0  # Number of supporting instances

    def __post_init__(self):
        """
        Validate rule structure.

        Raises:
            ValueError: if ``body`` is empty or ``head`` is not an ``Atom``.
        """
        if not self.body:
            raise ValueError("Rule body cannot be empty")

        if not isinstance(self.head, Atom):
            raise ValueError("Rule head must be an Atom")

        body_vars: Set[Variable] = set()
        for atom in self.body:
            body_vars.update(atom.get_variables())

        # Head variables missing from the body only produce a warning; during
        # grounding they are enumerated like any other variable.
        if not self.head.get_variables().issubset(body_vars):
            logger.warning("Rule %s: Head variables not in body", self.rule_id)

    def get_all_variables(self) -> Set[Variable]:
        """Return the set of variables occurring in the body or the head."""
        variables: Set[Variable] = set()
        for atom in self.body:
            variables.update(atom.get_variables())
        variables.update(self.head.get_variables())
        return variables

    def get_predicates(self) -> Set[Relation]:
        """Return the set of predicates used in the body and the head."""
        predicates = {atom.predicate for atom in self.body}
        predicates.add(self.head.predicate)
        return predicates

    def generate_ground_rules(self, kg: KnowledgeGraph,
                              max_groundings: int = 1000) -> List['GroundRule']:
        """
        Generate ground rules by substituting variables with entities.

        Substitutions are enumerated as a Cartesian product over a prefix of
        ``list(kg.entities)`` of length
        ``min(len(entities), max_groundings // n_variables + 1)``. Enumeration
        stops once ``max_groundings`` ground rules exist. Groundings are not
        checked against ``kg``, so their body facts need not be present in it.

        Args:
            kg: knowledge graph supplying the entity domain.
            max_groundings: upper bound on the number of ground rules returned.

        Returns:
            A list of ``GroundRule``. It is empty when ``kg`` has no entities.
            A rule without variables yields at most one ground rule (with id
            ``rule_id``), and only when all of its atoms are binary.
        """
        entities = list(kg.entities)
        if not entities:
            return []

        # Sort by name: set iteration order of Variables follows str hashing,
        # which is randomised per process. Without this, which groundings
        # survive the max_groundings cap differed between runs.
        variables = sorted(self.get_all_variables(), key=lambda v: str(v.name))

        if not variables:
            return self._ground_variable_free()

        ground_rules: List['GroundRule'] = []

        # Limit combinations to prevent combinatorial explosion.
        # NOTE(review): only the first ``max_entities_per_var`` entities are
        # ever used. itertools.product varies the last variable fastest, so the
        # cap is typically reached with the leading variables fixed to the
        # first entities. The sample is therefore heavily biased.
        max_entities_per_var = min(len(entities), max_groundings // len(variables) + 1)
        domain = entities[:max_entities_per_var]

        for values in itertools.product(domain, repeat=len(variables)):
            if len(ground_rules) >= max_groundings:
                break
            substitution = dict(zip(variables, values))
            try:
                ground_rule = self._instantiate(substitution, len(ground_rules))
            except Exception as e:  # best-effort: skip substitutions that fail
                logger.debug("Failed to ground rule %s: %s", self.rule_id, e)
                continue
            if ground_rule is not None:
                ground_rules.append(ground_rule)

        logger.info("Generated %d ground rules for %s", len(ground_rules), self.rule_id)
        return ground_rules

    def _ground_variable_free(self) -> List['GroundRule']:
        """Wrap an already-ground rule as one GroundRule, or return [] if any atom is not binary."""
        body_triples = [atom.to_triple() for atom in self.body]
        head_triple = self.head.to_triple()
        if head_triple is None or any(t is None for t in body_triples):
            return []
        return [GroundRule(
            rule_id=self.rule_id,
            body_facts=body_triples,
            head_fact=head_triple,
            parent_rule=self,
        )]

    def _instantiate(self, substitution: Dict[Variable, Entity],
                     index: int) -> Optional['GroundRule']:
        """Apply one substitution; return None unless every atom becomes a ground binary triple."""
        body_triples = [atom.ground_with_substitution(substitution).to_triple()
                        for atom in self.body]
        head_triple = self.head.ground_with_substitution(substitution).to_triple()
        if head_triple is None or any(t is None for t in body_triples):
            return None
        return GroundRule(
            rule_id=f"{self.rule_id}_ground_{index}",
            body_facts=body_triples,
            head_fact=head_triple,
            parent_rule=self,
            substitution=substitution.copy(),
        )

    def to_cnf(self) -> str:
        """
        Render the rule as a single CNF clause: ``¬B1 ∨ ... ∨ ¬Bn ∨ H``.
        """
        # Body atoms become negated disjuncts; the head stays positive.
        cnf_parts = [f"¬{atom}" for atom in self.body]
        cnf_parts.append(str(self.head))
        return " ∨ ".join(cnf_parts)

    def __str__(self) -> str:
        body_str = " ∧ ".join(str(atom) for atom in self.body)
        return f"{body_str} ⇒ {self.head}"

    def __repr__(self) -> str:
        return f"LogicalRule(id={self.rule_id}, {self})"

    # --- Serialization helpers ---
    def serialize(self) -> Dict[str, Any]:
        """Return a plain-dict representation that ``deserialize`` accepts."""
        def _atom_to_dict(atom: Atom) -> Dict[str, Any]:
            args = []
            for a in atom.arguments:
                if isinstance(a, Variable):
                    args.append({'type': 'var', 'name': a.name})
                elif isinstance(a, Entity):
                    args.append({'type': 'ent', 'name': a.name})
                else:
                    args.append({'type': 'raw', 'value': str(a)})
            return {'predicate': atom.predicate.name, 'arguments': args}

        return {
            'rule_id': self.rule_id,
            'rule_type': self.rule_type.value,
            'confidence': self.confidence,
            'support': self.support,
            'body': [_atom_to_dict(a) for a in self.body],
            'head': _atom_to_dict(self.head),
        }

    @staticmethod
    def deserialize(data: Dict[str, Any]) -> 'LogicalRule':
        """
        Rebuild a rule from ``serialize`` output.

        Missing ``rule_type``, ``confidence`` and ``support`` fall back to the
        field defaults. Arguments whose type is neither ``'var'`` nor ``'ent'``
        become entities named after their ``'value'``.
        """
        rule_id = data['rule_id']
        rule_type = RuleType(data.get('rule_type', RuleType.GENERAL.value))
        confidence = data.get('confidence', 0.5)
        support = data.get('support', 0)

        def _atom_from_dict(d: Dict[str, Any]) -> Atom:
            pred = Relation(name=d['predicate'])
            args = []
            for a in d['arguments']:
                if a.get('type') == 'var':
                    args.append(Variable(a['name']))
                elif a.get('type') == 'ent':
                    args.append(Entity(name=a['name']))
                else:
                    # Fallback: treat as entity by name
                    args.append(Entity(name=str(a.get('value', ''))))
            return Atom(predicate=pred, arguments=tuple(args))

        body_atoms = [_atom_from_dict(x) for x in data['body']]
        head_atom = _atom_from_dict(data['head'])
        return LogicalRule(
            rule_id=rule_id,
            body=body_atoms,
            head=head_atom,
            rule_type=rule_type,
            confidence=confidence,
            support=support,
        )
280
-
281
-
282
@dataclass
class GroundRule:
    """
    Variable-free instance of a ``LogicalRule``.

    The body and head are stored as ``Triple`` facts. ``substitution``
    optionally records the variable binding that produced this instance.
    """

    rule_id: str
    body_facts: List[Triple]  # Ground body facts
    head_fact: Triple  # Ground head fact
    parent_rule: LogicalRule  # Original rule this was grounded from
    substitution: Optional[Dict[Variable, Entity]] = None

    def __post_init__(self):
        """Raise ValueError for an empty body or a head that is not a ``Triple``."""
        if not self.body_facts:
            raise ValueError("Ground rule body cannot be empty")
        if not isinstance(self.head_fact, Triple):
            raise ValueError("Ground rule head must be a Triple")

    @staticmethod
    def _holds(kg: KnowledgeGraph, fact: Triple):
        """Ask ``kg`` whether ``fact`` is present, looked up by component names."""
        return kg.contains_fact(fact.head.name, fact.relation.name, fact.tail.name)

    def evaluate_truth_value(self, kg: KnowledgeGraph) -> bool:
        """
        Truth value of the implication ``body ⇒ head`` in ``kg``.

        Returns True as soon as a body fact is missing (vacuous truth).
        Otherwise returns the lookup result for the head fact.
        """
        for fact in self.body_facts:
            if not self._holds(kg, fact):
                return True
        return self._holds(kg, self.head_fact)

    def get_all_facts(self) -> List[Triple]:
        """Return a new list of the body facts followed by the head fact."""
        return self.body_facts + [self.head_fact]

    def get_fact_truth_values(self, kg: KnowledgeGraph) -> Dict[Triple, bool]:
        """Map each body fact, then the head fact, to its presence in ``kg``."""
        truth_values = {fact: self._holds(kg, fact) for fact in self.body_facts}
        truth_values[self.head_fact] = self._holds(kg, self.head_fact)
        return truth_values

    def __str__(self) -> str:
        premise = " ∧ ".join(map(str, self.body_facts))
        return f"{premise} ⇒ {self.head_fact}"

    def __repr__(self) -> str:
        return f"GroundRule(id={self.rule_id}, {self})"
353
-
354
-
355
class RuleGenerator:
    """
    Generate candidate logical rules from knowledge-graph patterns.

    Candidates are built from fixed templates over ``kg.relations``. A
    candidate is kept when at least ``min_support`` of its sampled groundings
    are fully supported by ``kg``, meaning every body fact and the head fact
    are present.
    """

    def __init__(self, kg: KnowledgeGraph):
        self.kg = kg

    def _count_support(self, rule: LogicalRule, max_groundings: int = 100) -> int:
        """
        Count sampled groundings of ``rule`` whose body and head all hold in ``self.kg``.

        A grounding with a false body is vacuously true under
        ``GroundRule.evaluate_truth_value``, but it is not evidence for the
        rule. Counting such groundings would let candidates whose body never
        matches the KG pass ``min_support``.
        """
        kg = self.kg
        support = 0
        for ground_rule in rule.generate_ground_rules(kg, max_groundings=max_groundings):
            if all(kg.contains_fact(f.head.name, f.relation.name, f.tail.name)
                   for f in ground_rule.get_all_facts()):
                support += 1
        return support

    def generate_simple_rules(self, min_support: int = 2,
                              max_rule_length: int = 3) -> List[LogicalRule]:
        """
        Generate chain rules ``R1(x,y) ∧ R2(y,z) ⇒ R3(x,z)``.

        Every ordered triple of KG relations is tried, except the one where all
        three relations are the same.

        Args:
            min_support: minimum number of sampled groundings whose body and
                head facts are all present in the KG.
            max_rule_length: currently unused; only two-atom bodies are built.

        Returns:
            Rules that met ``min_support``, with ``support`` set.
        """
        rules: List[LogicalRule] = []
        relations = list(self.kg.relations)
        x, y, z = Variable('x'), Variable('y'), Variable('z')

        # Body atoms are ordered, so R1∘R2 and R2∘R1 are different rules.
        # Enumerate ordered triples. combinations_with_replacement yields only
        # sorted tuples and silently dropped most chains.
        for r1, r2, r3 in itertools.product(relations, repeat=3):
            if r1 == r2 == r3:  # Skip trivial cases
                continue

            rule = LogicalRule(
                rule_id=f"trans_{r1.name}_{r2.name}_{r3.name}",
                body=[Atom(predicate=r1, arguments=(x, y)),
                      Atom(predicate=r2, arguments=(y, z))],
                head=Atom(predicate=r3, arguments=(x, z)),
                rule_type=RuleType.TRANSITIVITY,
                confidence=0.5,  # Will be learned
            )

            support = self._count_support(rule)
            if support >= min_support:
                rule.support = support
                rules.append(rule)

        logger.info("Generated %d rules with min support %s", len(rules), min_support)
        return rules

    def generate_symmetry_rules(self, min_support: int = 2) -> List[LogicalRule]:
        """
        Generate symmetry rules ``R(x,y) ⇒ R(y,x)``, one candidate per relation.

        Args:
            min_support: minimum number of sampled groundings where both
                ``R(a,b)`` and ``R(b,a)`` are present in the KG.

        Returns:
            Rules that met ``min_support``, with ``support`` set.
        """
        rules: List[LogicalRule] = []
        x, y = Variable('x'), Variable('y')

        for relation in self.kg.relations:
            rule = LogicalRule(
                rule_id=f"sym_{relation.name}",
                body=[Atom(predicate=relation, arguments=(x, y))],
                head=Atom(predicate=relation, arguments=(y, x)),
                rule_type=RuleType.SYMMETRY,
                confidence=0.5,
            )

            support = self._count_support(rule)
            if support >= min_support:
                rule.support = support
                rules.append(rule)

        return rules
433
-
434
-
435
def parse_rule_from_string(rule_str: str, entities: Dict[str, Entity],
                           relations: Dict[str, Relation]) -> LogicalRule:
    """
    Parse a rule written as ``B1(a,b) ∧ B2(b,c) ⇒ H(a,c)``.

    The text is split on the first ``⇒``, and the body is split on ``∧``.
    Each atom is parsed by ``_parse_atom_from_string``, which receives
    ``entities`` and ``relations`` as shared lookup tables.

    The rule id is ``parsed_<12 hex digits>``, taken from a SHA-256 digest of
    ``rule_str``, so identical text always gets the same id. The built-in
    ``hash`` of a ``str`` is salted per process, and reducing it modulo 10000
    also made collisions likely.

    Raises:
        ValueError: if ``⇒`` is missing or an atom is malformed.
    """
    import hashlib  # local import: only needed for the stable rule id

    if "⇒" not in rule_str:
        raise ValueError(f"Rule must contain ⇒: {rule_str}")

    body_str, head_str = rule_str.split("⇒", 1)

    body_atoms = [
        _parse_atom_from_string(part.strip(), entities, relations)
        for part in body_str.split("∧")
    ]
    head_atom = _parse_atom_from_string(head_str.strip(), entities, relations)

    digest = hashlib.sha256(rule_str.encode("utf-8")).hexdigest()[:12]

    return LogicalRule(
        rule_id=f"parsed_{digest}",
        body=body_atoms,
        head=head_atom,
        rule_type=RuleType.GENERAL,
        confidence=0.5,
    )
466
-
467
-
468
def _parse_atom_from_string(atom_str: str, entities: Dict[str, Entity],
                            relations: Dict[str, Relation]) -> Atom:
    """
    Parse a single atom such as ``'Pred(arg1, arg2)'``.

    An argument becomes a ``Variable`` when ``str.islower()`` is true for it
    or it starts with ``?``. Any other argument is an ``Entity``. Unknown
    predicate and entity names are created and inserted into ``relations`` /
    ``entities`` in place, so repeated names resolve to the same object.
    ``'Pred()'`` yields a zero-argument atom.

    Raises:
        ValueError: if the text is not of the form ``Name(args)``. This covers
            missing or misordered parentheses, text after the closing
            parenthesis, an empty predicate name, and an empty or
            parenthesised argument (e.g. ``P(a,,b)`` or a missing ``∧``
            between two atoms). Nothing is inserted into the lookup dicts
            when parsing fails.
    """
    open_idx = atom_str.find("(")
    close_idx = atom_str.rfind(")")
    if open_idx == -1 or close_idx < open_idx or atom_str[close_idx + 1:].strip():
        raise ValueError(f"Invalid atom format: {atom_str}")

    pred_name = atom_str[:open_idx].strip()
    if not pred_name or ")" in pred_name:
        raise ValueError(f"Invalid atom format: {atom_str}")

    args_str = atom_str[open_idx + 1:close_idx].strip()
    # An empty argument list means a zero-argument atom. Splitting "" would
    # yield [""] and create an Entity with an empty name.
    arg_names = [arg.strip() for arg in args_str.split(",")] if args_str else []
    for arg_name in arg_names:
        if not arg_name or "(" in arg_name or ")" in arg_name:
            raise ValueError(f"Invalid atom format: {atom_str}")

    # Get or create relation (only after the whole atom has validated).
    if pred_name not in relations:
        relations[pred_name] = Relation(name=pred_name)
    predicate = relations[pred_name]

    arguments = []
    for arg_name in arg_names:
        # Lowercase or '?'-prefixed names are variables; anything else is an entity.
        if arg_name.islower() or arg_name.startswith('?'):
            arguments.append(Variable(arg_name))
        else:
            # Entity - create if doesn't exist
            if arg_name not in entities:
                entities[arg_name] = Entity(name=arg_name)
            arguments.append(entities[arg_name])

    return Atom(predicate=predicate, arguments=tuple(arguments))
496
-
1
+ """
2
+ Logical Rules and Ground Rules for NPLL
3
+ """
4
+
5
+ from typing import List, Set, Dict, Tuple, Optional, Iterator, Any
6
+ from dataclasses import dataclass, field
7
+ from enum import Enum
8
+ import itertools
9
+ import logging
10
+ from collections import defaultdict
11
+
12
+ from .knowledge_graph import Entity, Relation, Triple, KnowledgeGraph
13
+
14
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
15
+
16
+
17
class Variable:
    """
    Named logical variable (such as ``x``) used as an atom argument.

    Equality holds only between ``Variable`` instances with equal names, and
    the hash is the hash of the name. Variables can therefore key the
    substitution dictionaries used during grounding.
    """

    def __init__(self, name: str):
        self.name = name

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return "Var({})".format(self.name)

    def __eq__(self, other) -> bool:
        if not isinstance(other, Variable):
            return False
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)
34
+
35
+
36
@dataclass(frozen=True)
class Atom:
    """
    Atomic formula in first-order logic: ``Pred(arg1, ..., argN)``.

    ``arguments`` may contain only ``Variable`` and ``Entity`` objects.
    ``__post_init__`` accepts any iterable and normalises it to a tuple, which
    keeps the frozen dataclass hashable. A list here would make
    ``hash(atom)`` raise ``TypeError``.

    Raises:
        ValueError: if an argument is neither a ``Variable`` nor an ``Entity``.
    """
    predicate: Relation
    arguments: Tuple[Any, ...]

    def __post_init__(self):
        # The dataclass is frozen, so plain attribute assignment would raise;
        # object.__setattr__ is the documented way to adjust a field here.
        if not isinstance(self.arguments, tuple):
            object.__setattr__(self, "arguments", tuple(self.arguments))
        for arg in self.arguments:
            if not isinstance(arg, (Variable, Entity)):
                raise ValueError(f"Atom arguments must be Variable or Entity, got {type(arg)}")

    def __str__(self) -> str:
        args_str = ", ".join(str(arg) for arg in self.arguments)
        return f"{self.predicate.name}({args_str})"

    def is_ground(self) -> bool:
        """Return True if no argument is a variable (vacuously True for zero arguments)."""
        return all(isinstance(arg, Entity) for arg in self.arguments)

    def get_variables(self) -> Set[Variable]:
        """Return the set of variables among the arguments."""
        return {arg for arg in self.arguments if isinstance(arg, Variable)}

    def ground_with_substitution(self, substitution: Dict[Variable, Entity]) -> 'Atom':
        """
        Return a new atom with every variable found in ``substitution`` replaced.

        Variables missing from ``substitution`` are kept as they are, so the
        result may still be non-ground.
        """
        new_args = tuple(
            substitution.get(arg, arg) if isinstance(arg, Variable) else arg
            for arg in self.arguments
        )
        return Atom(predicate=self.predicate, arguments=new_args)

    def to_triple(self) -> Optional[Triple]:
        """Return ``Triple(head, predicate, tail)`` for a ground binary atom, else None."""
        if len(self.arguments) == 2 and self.is_ground():
            head, tail = self.arguments
            return Triple(head=head, relation=self.predicate, tail=tail)
        return None
79
+
80
+
81
class RuleType(Enum):
    """
    Category tag attached to a logical rule.

    Members (value in parentheses):
        HORN_CLAUSE ("horn")            standard Horn clause
        EQUALITY ("equality")           equality rule
        TRANSITIVITY ("transitivity")   transitive rule
        SYMMETRY ("symmetry")           symmetric rule
        GENERAL ("general")             any other first-order rule
    """

    HORN_CLAUSE = "horn"
    EQUALITY = "equality"
    TRANSITIVITY = "transitivity"
    SYMMETRY = "symmetry"
    GENERAL = "general"
88
+
89
+
90
@dataclass
class LogicalRule:
    """
    First-order logical rule: ``Body ⇒ Head``.

    Fields:
        rule_id: identifier of the rule.
        body: premise atoms, read as a conjunction; must be non-empty.
        head: conclusion atom; must be an ``Atom``.
        rule_type: category tag (defaults to ``RuleType.GENERAL``).
        confidence: initial confidence score (defaults to 0.5).
        support: number of supporting instances (defaults to 0).
    """

    rule_id: str
    body: List[Atom]  # Body atoms (premise)
    head: Atom  # Head atom (conclusion)
    rule_type: RuleType = RuleType.GENERAL
    confidence: float = 0.5  # Initial confidence score
    support: int = 0  # Number of supporting instances

    def __post_init__(self):
        """
        Validate rule structure.

        Raises:
            ValueError: if ``body`` is empty or ``head`` is not an ``Atom``.
        """
        if not self.body:
            raise ValueError("Rule body cannot be empty")

        if not isinstance(self.head, Atom):
            raise ValueError("Rule head must be an Atom")

        body_vars: Set[Variable] = set()
        for atom in self.body:
            body_vars.update(atom.get_variables())

        # Head variables missing from the body only produce a warning; during
        # grounding they are enumerated like any other variable.
        if not self.head.get_variables().issubset(body_vars):
            logger.warning("Rule %s: Head variables not in body", self.rule_id)

    def get_all_variables(self) -> Set[Variable]:
        """Return the set of variables occurring in the body or the head."""
        variables: Set[Variable] = set()
        for atom in self.body:
            variables.update(atom.get_variables())
        variables.update(self.head.get_variables())
        return variables

    def get_predicates(self) -> Set[Relation]:
        """Return the set of predicates used in the body and the head."""
        predicates = {atom.predicate for atom in self.body}
        predicates.add(self.head.predicate)
        return predicates

    def generate_ground_rules(self, kg: KnowledgeGraph,
                              max_groundings: int = 1000) -> List['GroundRule']:
        """
        Generate ground rules by substituting variables with entities.

        Substitutions are enumerated as a Cartesian product over a prefix of
        ``list(kg.entities)`` of length
        ``min(len(entities), max_groundings // n_variables + 1)``. Enumeration
        stops once ``max_groundings`` ground rules exist. Groundings are not
        checked against ``kg``, so their body facts need not be present in it.

        Args:
            kg: knowledge graph supplying the entity domain.
            max_groundings: upper bound on the number of ground rules returned.

        Returns:
            A list of ``GroundRule``. It is empty when ``kg`` has no entities.
            A rule without variables yields at most one ground rule (with id
            ``rule_id``), and only when all of its atoms are binary.
        """
        entities = list(kg.entities)
        if not entities:
            return []

        # Sort by name: set iteration order of Variables follows str hashing,
        # which is randomised per process. Without this, which groundings
        # survive the max_groundings cap differed between runs.
        variables = sorted(self.get_all_variables(), key=lambda v: str(v.name))

        if not variables:
            return self._ground_variable_free()

        ground_rules: List['GroundRule'] = []

        # Limit combinations to prevent combinatorial explosion.
        # NOTE(review): only the first ``max_entities_per_var`` entities are
        # ever used. itertools.product varies the last variable fastest, so the
        # cap is typically reached with the leading variables fixed to the
        # first entities. The sample is therefore heavily biased.
        max_entities_per_var = min(len(entities), max_groundings // len(variables) + 1)
        domain = entities[:max_entities_per_var]

        for values in itertools.product(domain, repeat=len(variables)):
            if len(ground_rules) >= max_groundings:
                break
            substitution = dict(zip(variables, values))
            try:
                ground_rule = self._instantiate(substitution, len(ground_rules))
            except Exception as e:  # best-effort: skip substitutions that fail
                logger.debug("Failed to ground rule %s: %s", self.rule_id, e)
                continue
            if ground_rule is not None:
                ground_rules.append(ground_rule)

        logger.info("Generated %d ground rules for %s", len(ground_rules), self.rule_id)
        return ground_rules

    def _ground_variable_free(self) -> List['GroundRule']:
        """Wrap an already-ground rule as one GroundRule, or return [] if any atom is not binary."""
        body_triples = [atom.to_triple() for atom in self.body]
        head_triple = self.head.to_triple()
        if head_triple is None or any(t is None for t in body_triples):
            return []
        return [GroundRule(
            rule_id=self.rule_id,
            body_facts=body_triples,
            head_fact=head_triple,
            parent_rule=self,
        )]

    def _instantiate(self, substitution: Dict[Variable, Entity],
                     index: int) -> Optional['GroundRule']:
        """Apply one substitution; return None unless every atom becomes a ground binary triple."""
        body_triples = [atom.ground_with_substitution(substitution).to_triple()
                        for atom in self.body]
        head_triple = self.head.ground_with_substitution(substitution).to_triple()
        if head_triple is None or any(t is None for t in body_triples):
            return None
        return GroundRule(
            rule_id=f"{self.rule_id}_ground_{index}",
            body_facts=body_triples,
            head_fact=head_triple,
            parent_rule=self,
            substitution=substitution.copy(),
        )

    def to_cnf(self) -> str:
        """
        Render the rule as a single CNF clause: ``¬B1 ∨ ... ∨ ¬Bn ∨ H``.
        """
        # Body atoms become negated disjuncts; the head stays positive.
        cnf_parts = [f"¬{atom}" for atom in self.body]
        cnf_parts.append(str(self.head))
        return " ∨ ".join(cnf_parts)

    def __str__(self) -> str:
        body_str = " ∧ ".join(str(atom) for atom in self.body)
        return f"{body_str} ⇒ {self.head}"

    def __repr__(self) -> str:
        return f"LogicalRule(id={self.rule_id}, {self})"

    # --- Serialization helpers ---
    def serialize(self) -> Dict[str, Any]:
        """Return a plain-dict representation that ``deserialize`` accepts."""
        def _atom_to_dict(atom: Atom) -> Dict[str, Any]:
            args = []
            for a in atom.arguments:
                if isinstance(a, Variable):
                    args.append({'type': 'var', 'name': a.name})
                elif isinstance(a, Entity):
                    args.append({'type': 'ent', 'name': a.name})
                else:
                    args.append({'type': 'raw', 'value': str(a)})
            return {'predicate': atom.predicate.name, 'arguments': args}

        return {
            'rule_id': self.rule_id,
            'rule_type': self.rule_type.value,
            'confidence': self.confidence,
            'support': self.support,
            'body': [_atom_to_dict(a) for a in self.body],
            'head': _atom_to_dict(self.head),
        }

    @staticmethod
    def deserialize(data: Dict[str, Any]) -> 'LogicalRule':
        """
        Rebuild a rule from ``serialize`` output.

        Missing ``rule_type``, ``confidence`` and ``support`` fall back to the
        field defaults. Arguments whose type is neither ``'var'`` nor ``'ent'``
        become entities named after their ``'value'``.
        """
        rule_id = data['rule_id']
        rule_type = RuleType(data.get('rule_type', RuleType.GENERAL.value))
        confidence = data.get('confidence', 0.5)
        support = data.get('support', 0)

        def _atom_from_dict(d: Dict[str, Any]) -> Atom:
            pred = Relation(name=d['predicate'])
            args = []
            for a in d['arguments']:
                if a.get('type') == 'var':
                    args.append(Variable(a['name']))
                elif a.get('type') == 'ent':
                    args.append(Entity(name=a['name']))
                else:
                    # Fallback: treat as entity by name
                    args.append(Entity(name=str(a.get('value', ''))))
            return Atom(predicate=pred, arguments=tuple(args))

        body_atoms = [_atom_from_dict(x) for x in data['body']]
        head_atom = _atom_from_dict(data['head'])
        return LogicalRule(
            rule_id=rule_id,
            body=body_atoms,
            head=head_atom,
            rule_type=rule_type,
            confidence=confidence,
            support=support,
        )
280
+
281
+
282
@dataclass
class GroundRule:
    """
    Variable-free instance of a ``LogicalRule``.

    The body and head are stored as ``Triple`` facts. ``substitution``
    optionally records the variable binding that produced this instance.
    """

    rule_id: str
    body_facts: List[Triple]  # Ground body facts
    head_fact: Triple  # Ground head fact
    parent_rule: LogicalRule  # Original rule this was grounded from
    substitution: Optional[Dict[Variable, Entity]] = None

    def __post_init__(self):
        """Raise ValueError for an empty body or a head that is not a ``Triple``."""
        if not self.body_facts:
            raise ValueError("Ground rule body cannot be empty")
        if not isinstance(self.head_fact, Triple):
            raise ValueError("Ground rule head must be a Triple")

    @staticmethod
    def _holds(kg: KnowledgeGraph, fact: Triple):
        """Ask ``kg`` whether ``fact`` is present, looked up by component names."""
        return kg.contains_fact(fact.head.name, fact.relation.name, fact.tail.name)

    def evaluate_truth_value(self, kg: KnowledgeGraph) -> bool:
        """
        Truth value of the implication ``body ⇒ head`` in ``kg``.

        Returns True as soon as a body fact is missing (vacuous truth).
        Otherwise returns the lookup result for the head fact.
        """
        for fact in self.body_facts:
            if not self._holds(kg, fact):
                return True
        return self._holds(kg, self.head_fact)

    def get_all_facts(self) -> List[Triple]:
        """Return a new list of the body facts followed by the head fact."""
        return self.body_facts + [self.head_fact]

    def get_fact_truth_values(self, kg: KnowledgeGraph) -> Dict[Triple, bool]:
        """Map each body fact, then the head fact, to its presence in ``kg``."""
        truth_values = {fact: self._holds(kg, fact) for fact in self.body_facts}
        truth_values[self.head_fact] = self._holds(kg, self.head_fact)
        return truth_values

    def __str__(self) -> str:
        premise = " ∧ ".join(map(str, self.body_facts))
        return f"{premise} ⇒ {self.head_fact}"

    def __repr__(self) -> str:
        return f"GroundRule(id={self.rule_id}, {self})"
353
+
354
+
355
class RuleGenerator:
    """
    Generate candidate logical rules from knowledge-graph patterns.

    Candidates are built from fixed templates over ``kg.relations``. A
    candidate is kept when at least ``min_support`` of its sampled groundings
    are fully supported by ``kg``, meaning every body fact and the head fact
    are present.
    """

    def __init__(self, kg: KnowledgeGraph):
        self.kg = kg

    def _count_support(self, rule: LogicalRule, max_groundings: int = 100) -> int:
        """
        Count sampled groundings of ``rule`` whose body and head all hold in ``self.kg``.

        A grounding with a false body is vacuously true under
        ``GroundRule.evaluate_truth_value``, but it is not evidence for the
        rule. Counting such groundings would let candidates whose body never
        matches the KG pass ``min_support``.
        """
        kg = self.kg
        support = 0
        for ground_rule in rule.generate_ground_rules(kg, max_groundings=max_groundings):
            if all(kg.contains_fact(f.head.name, f.relation.name, f.tail.name)
                   for f in ground_rule.get_all_facts()):
                support += 1
        return support

    def generate_simple_rules(self, min_support: int = 2,
                              max_rule_length: int = 3) -> List[LogicalRule]:
        """
        Generate chain rules ``R1(x,y) ∧ R2(y,z) ⇒ R3(x,z)``.

        Every ordered triple of KG relations is tried, except the one where all
        three relations are the same.

        Args:
            min_support: minimum number of sampled groundings whose body and
                head facts are all present in the KG.
            max_rule_length: currently unused; only two-atom bodies are built.

        Returns:
            Rules that met ``min_support``, with ``support`` set.
        """
        rules: List[LogicalRule] = []
        relations = list(self.kg.relations)
        x, y, z = Variable('x'), Variable('y'), Variable('z')

        # Body atoms are ordered, so R1∘R2 and R2∘R1 are different rules.
        # Enumerate ordered triples. combinations_with_replacement yields only
        # sorted tuples and silently dropped most chains.
        for r1, r2, r3 in itertools.product(relations, repeat=3):
            if r1 == r2 == r3:  # Skip trivial cases
                continue

            rule = LogicalRule(
                rule_id=f"trans_{r1.name}_{r2.name}_{r3.name}",
                body=[Atom(predicate=r1, arguments=(x, y)),
                      Atom(predicate=r2, arguments=(y, z))],
                head=Atom(predicate=r3, arguments=(x, z)),
                rule_type=RuleType.TRANSITIVITY,
                confidence=0.5,  # Will be learned
            )

            support = self._count_support(rule)
            if support >= min_support:
                rule.support = support
                rules.append(rule)

        logger.info("Generated %d rules with min support %s", len(rules), min_support)
        return rules

    def generate_symmetry_rules(self, min_support: int = 2) -> List[LogicalRule]:
        """
        Generate symmetry rules ``R(x,y) ⇒ R(y,x)``, one candidate per relation.

        Args:
            min_support: minimum number of sampled groundings where both
                ``R(a,b)`` and ``R(b,a)`` are present in the KG.

        Returns:
            Rules that met ``min_support``, with ``support`` set.
        """
        rules: List[LogicalRule] = []
        x, y = Variable('x'), Variable('y')

        for relation in self.kg.relations:
            rule = LogicalRule(
                rule_id=f"sym_{relation.name}",
                body=[Atom(predicate=relation, arguments=(x, y))],
                head=Atom(predicate=relation, arguments=(y, x)),
                rule_type=RuleType.SYMMETRY,
                confidence=0.5,
            )

            support = self._count_support(rule)
            if support >= min_support:
                rule.support = support
                rules.append(rule)

        return rules
433
+
434
+
435
def parse_rule_from_string(rule_str: str, entities: Dict[str, Entity],
                           relations: Dict[str, Relation]) -> LogicalRule:
    """
    Parse a rule written as ``B1(a,b) ∧ B2(b,c) ⇒ H(a,c)``.

    The text is split on the first ``⇒``, and the body is split on ``∧``.
    Each atom is parsed by ``_parse_atom_from_string``, which receives
    ``entities`` and ``relations`` as shared lookup tables.

    The rule id is ``parsed_<12 hex digits>``, taken from a SHA-256 digest of
    ``rule_str``, so identical text always gets the same id. The built-in
    ``hash`` of a ``str`` is salted per process, and reducing it modulo 10000
    also made collisions likely.

    Raises:
        ValueError: if ``⇒`` is missing or an atom is malformed.
    """
    import hashlib  # local import: only needed for the stable rule id

    if "⇒" not in rule_str:
        raise ValueError(f"Rule must contain ⇒: {rule_str}")

    body_str, head_str = rule_str.split("⇒", 1)

    body_atoms = [
        _parse_atom_from_string(part.strip(), entities, relations)
        for part in body_str.split("∧")
    ]
    head_atom = _parse_atom_from_string(head_str.strip(), entities, relations)

    digest = hashlib.sha256(rule_str.encode("utf-8")).hexdigest()[:12]

    return LogicalRule(
        rule_id=f"parsed_{digest}",
        body=body_atoms,
        head=head_atom,
        rule_type=RuleType.GENERAL,
        confidence=0.5,
    )
466
+
467
+
468
def _parse_atom_from_string(atom_str: str, entities: Dict[str, Entity],
                            relations: Dict[str, Relation]) -> Atom:
    """
    Parse a single atom such as ``'Pred(arg1, arg2)'``.

    An argument becomes a ``Variable`` when ``str.islower()`` is true for it
    or it starts with ``?``. Any other argument is an ``Entity``. Unknown
    predicate and entity names are created and inserted into ``relations`` /
    ``entities`` in place, so repeated names resolve to the same object.
    ``'Pred()'`` yields a zero-argument atom.

    Raises:
        ValueError: if the text is not of the form ``Name(args)``. This covers
            missing or misordered parentheses, text after the closing
            parenthesis, an empty predicate name, and an empty or
            parenthesised argument (e.g. ``P(a,,b)`` or a missing ``∧``
            between two atoms). Nothing is inserted into the lookup dicts
            when parsing fails.
    """
    open_idx = atom_str.find("(")
    close_idx = atom_str.rfind(")")
    if open_idx == -1 or close_idx < open_idx or atom_str[close_idx + 1:].strip():
        raise ValueError(f"Invalid atom format: {atom_str}")

    pred_name = atom_str[:open_idx].strip()
    if not pred_name or ")" in pred_name:
        raise ValueError(f"Invalid atom format: {atom_str}")

    args_str = atom_str[open_idx + 1:close_idx].strip()
    # An empty argument list means a zero-argument atom. Splitting "" would
    # yield [""] and create an Entity with an empty name.
    arg_names = [arg.strip() for arg in args_str.split(",")] if args_str else []
    for arg_name in arg_names:
        if not arg_name or "(" in arg_name or ")" in arg_name:
            raise ValueError(f"Invalid atom format: {atom_str}")

    # Get or create relation (only after the whole atom has validated).
    if pred_name not in relations:
        relations[pred_name] = Relation(name=pred_name)
    predicate = relations[pred_name]

    arguments = []
    for arg_name in arg_names:
        # Lowercase or '?'-prefixed names are variables; anything else is an entity.
        if arg_name.islower() or arg_name.startswith('?'):
            arguments.append(Variable(arg_name))
        else:
            # Entity - create if doesn't exist
            if arg_name not in entities:
                entities[arg_name] = Entity(name=arg_name)
            arguments.append(entities[arg_name])

    return Atom(predicate=predicate, arguments=tuple(arguments))