odin-engine 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +17 -17
- benchmarks/datasets.py +284 -284
- benchmarks/metrics.py +275 -275
- benchmarks/run_ablation.py +279 -279
- benchmarks/run_npll_benchmark.py +270 -270
- npll/__init__.py +10 -10
- npll/bootstrap.py +474 -474
- npll/core/__init__.py +33 -33
- npll/core/knowledge_graph.py +308 -308
- npll/core/logical_rules.py +496 -496
- npll/core/mln.py +474 -474
- npll/inference/__init__.py +40 -40
- npll/inference/e_step.py +419 -419
- npll/inference/elbo.py +434 -434
- npll/inference/m_step.py +576 -576
- npll/npll_model.py +631 -631
- npll/scoring/__init__.py +42 -42
- npll/scoring/embeddings.py +441 -441
- npll/scoring/probability.py +402 -402
- npll/scoring/scoring_module.py +369 -369
- npll/training/__init__.py +24 -24
- npll/training/evaluation.py +496 -496
- npll/training/npll_trainer.py +520 -520
- npll/utils/__init__.py +47 -47
- npll/utils/batch_utils.py +492 -492
- npll/utils/config.py +144 -144
- npll/utils/math_utils.py +338 -338
- odin/__init__.py +21 -20
- odin/engine.py +264 -264
- odin/schema.py +210 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/METADATA +503 -456
- odin_engine-0.2.0.dist-info/RECORD +63 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/licenses/LICENSE +21 -21
- retrieval/__init__.py +50 -50
- retrieval/adapters.py +140 -140
- retrieval/adapters_arango.py +1418 -1418
- retrieval/aggregators.py +707 -707
- retrieval/beam.py +127 -127
- retrieval/budget.py +60 -60
- retrieval/cache.py +159 -159
- retrieval/confidence.py +88 -88
- retrieval/eval.py +49 -49
- retrieval/linker.py +87 -87
- retrieval/metrics.py +105 -105
- retrieval/metrics_motifs.py +36 -36
- retrieval/orchestrator.py +571 -571
- retrieval/ppr/__init__.py +12 -12
- retrieval/ppr/anchors.py +41 -41
- retrieval/ppr/bippr.py +61 -61
- retrieval/ppr/engines.py +257 -257
- retrieval/ppr/global_pr.py +76 -76
- retrieval/ppr/indexes.py +78 -78
- retrieval/ppr.py +156 -156
- retrieval/ppr_cache.py +25 -25
- retrieval/scoring.py +294 -294
- retrieval/utils/pii_redaction.py +36 -36
- retrieval/writers/__init__.py +9 -9
- retrieval/writers/arango_writer.py +28 -28
- retrieval/writers/base.py +21 -21
- retrieval/writers/janus_writer.py +36 -36
- odin_engine-0.1.0.dist-info/RECORD +0 -62
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/WHEEL +0 -0
- {odin_engine-0.1.0.dist-info → odin_engine-0.2.0.dist-info}/top_level.txt +0 -0
npll/core/logical_rules.py
CHANGED
|
@@ -1,497 +1,497 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Logical Rules and Ground Rules for NPLL
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import List, Set, Dict, Tuple, Optional, Iterator, Any
|
|
6
|
-
from dataclasses import dataclass, field
|
|
7
|
-
from enum import Enum
|
|
8
|
-
import itertools
|
|
9
|
-
import logging
|
|
10
|
-
from collections import defaultdict
|
|
11
|
-
|
|
12
|
-
from .knowledge_graph import Entity, Relation, Triple, KnowledgeGraph
|
|
13
|
-
|
|
14
|
-
logger = logging.getLogger(__name__)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class Variable:
|
|
18
|
-
"""Variable in logical rules"""
|
|
19
|
-
|
|
20
|
-
def __init__(self, name: str):
|
|
21
|
-
self.name = name
|
|
22
|
-
|
|
23
|
-
def __str__(self) -> str:
|
|
24
|
-
return self.name
|
|
25
|
-
|
|
26
|
-
def __repr__(self) -> str:
|
|
27
|
-
return f"Var({self.name})"
|
|
28
|
-
|
|
29
|
-
def __eq__(self, other) -> bool:
|
|
30
|
-
return isinstance(other, Variable) and self.name == other.name
|
|
31
|
-
|
|
32
|
-
def __hash__(self) -> int:
|
|
33
|
-
return hash(self.name)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@dataclass(frozen=True)
|
|
37
|
-
class Atom:
|
|
38
|
-
"""
|
|
39
|
-
Atomic formula in first-order logic: Pred(arg1, arg2)
|
|
40
|
-
"""
|
|
41
|
-
predicate: Relation
|
|
42
|
-
arguments: Tuple[Any, ...]
|
|
43
|
-
|
|
44
|
-
def __post_init__(self):
|
|
45
|
-
# Validate arguments
|
|
46
|
-
for arg in self.arguments:
|
|
47
|
-
if not isinstance(arg, (Variable, Entity)):
|
|
48
|
-
raise ValueError(f"Atom arguments must be Variable or Entity, got {type(arg)}")
|
|
49
|
-
|
|
50
|
-
def __str__(self) -> str:
|
|
51
|
-
args_str = ", ".join(str(arg) for arg in self.arguments)
|
|
52
|
-
return f"{self.predicate.name}({args_str})"
|
|
53
|
-
|
|
54
|
-
def is_ground(self) -> bool:
|
|
55
|
-
"""Check if atom is ground (no variables)"""
|
|
56
|
-
return all(isinstance(arg, Entity) for arg in self.arguments)
|
|
57
|
-
|
|
58
|
-
def get_variables(self) -> Set[Variable]:
|
|
59
|
-
"""Get all variables in this atom"""
|
|
60
|
-
return {arg for arg in self.arguments if isinstance(arg, Variable)}
|
|
61
|
-
|
|
62
|
-
def ground_with_substitution(self, substitution: Dict[Variable, Entity]) -> 'Atom':
|
|
63
|
-
"""Ground atom by substituting variables with entities"""
|
|
64
|
-
new_args = []
|
|
65
|
-
for arg in self.arguments:
|
|
66
|
-
if isinstance(arg, Variable) and arg in substitution:
|
|
67
|
-
new_args.append(substitution[arg])
|
|
68
|
-
else:
|
|
69
|
-
new_args.append(arg)
|
|
70
|
-
|
|
71
|
-
return Atom(predicate=self.predicate, arguments=tuple(new_args))
|
|
72
|
-
|
|
73
|
-
def to_triple(self) -> Optional[Triple]:
|
|
74
|
-
"""Convert ground atom to triple (if binary and ground)"""
|
|
75
|
-
if len(self.arguments) == 2 and self.is_ground():
|
|
76
|
-
head, tail = self.arguments
|
|
77
|
-
return Triple(head=head, relation=self.predicate, tail=tail)
|
|
78
|
-
return None
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
class RuleType(Enum):
|
|
82
|
-
"""Types of logical rules"""
|
|
83
|
-
HORN_CLAUSE = "horn" # Standard Horn clause
|
|
84
|
-
EQUALITY = "equality" # Equality rules
|
|
85
|
-
TRANSITIVITY = "transitivity" # Transitive rules
|
|
86
|
-
SYMMETRY = "symmetry" # Symmetric rules
|
|
87
|
-
GENERAL = "general" # General first-order rules
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
@dataclass
|
|
91
|
-
class LogicalRule:
|
|
92
|
-
"""
|
|
93
|
-
First-order logical rule: Body ⇒ Head
|
|
94
|
-
"""
|
|
95
|
-
|
|
96
|
-
rule_id: str
|
|
97
|
-
body: List[Atom] # Body atoms (premise)
|
|
98
|
-
head: Atom # Head atom (conclusion)
|
|
99
|
-
rule_type: RuleType = RuleType.GENERAL
|
|
100
|
-
confidence: float = 0.5 # Initial confidence score
|
|
101
|
-
support: int = 0 # Number of supporting instances
|
|
102
|
-
|
|
103
|
-
def __post_init__(self):
|
|
104
|
-
"""Validate rule structure"""
|
|
105
|
-
if not self.body:
|
|
106
|
-
raise ValueError("Rule body cannot be empty")
|
|
107
|
-
|
|
108
|
-
if not isinstance(self.head, Atom):
|
|
109
|
-
raise ValueError("Rule head must be an Atom")
|
|
110
|
-
|
|
111
|
-
# Check variable consistency
|
|
112
|
-
body_vars = set()
|
|
113
|
-
for atom in self.body:
|
|
114
|
-
body_vars.update(atom.get_variables())
|
|
115
|
-
|
|
116
|
-
head_vars = self.head.get_variables()
|
|
117
|
-
|
|
118
|
-
# All head variables should appear in body
|
|
119
|
-
if not head_vars.issubset(body_vars):
|
|
120
|
-
logger.warning(f"Rule {self.rule_id}: Head variables not in body")
|
|
121
|
-
|
|
122
|
-
def get_all_variables(self) -> Set[Variable]:
|
|
123
|
-
"""Get all variables in the rule"""
|
|
124
|
-
variables = set()
|
|
125
|
-
for atom in self.body:
|
|
126
|
-
variables.update(atom.get_variables())
|
|
127
|
-
variables.update(self.head.get_variables())
|
|
128
|
-
return variables
|
|
129
|
-
|
|
130
|
-
def get_predicates(self) -> Set[Relation]:
|
|
131
|
-
"""Get all predicates used in the rule"""
|
|
132
|
-
predicates = {atom.predicate for atom in self.body}
|
|
133
|
-
predicates.add(self.head.predicate)
|
|
134
|
-
return predicates
|
|
135
|
-
|
|
136
|
-
def generate_ground_rules(self, kg: KnowledgeGraph,
|
|
137
|
-
max_groundings: int = 1000) -> List['GroundRule']:
|
|
138
|
-
"""
|
|
139
|
-
Generate ground rules by substituting variables with entities
|
|
140
|
-
"""
|
|
141
|
-
ground_rules = []
|
|
142
|
-
entities = list(kg.entities)
|
|
143
|
-
|
|
144
|
-
if not entities:
|
|
145
|
-
return ground_rules
|
|
146
|
-
|
|
147
|
-
# Get all variables that need to be substituted
|
|
148
|
-
variables = list(self.get_all_variables())
|
|
149
|
-
|
|
150
|
-
if not variables:
|
|
151
|
-
# Already ground rule
|
|
152
|
-
ground_body = [atom.to_triple() for atom in self.body if atom.is_ground()]
|
|
153
|
-
ground_head_triple = self.head.to_triple()
|
|
154
|
-
|
|
155
|
-
if ground_head_triple and all(t is not None for t in ground_body):
|
|
156
|
-
ground_rules.append(GroundRule(
|
|
157
|
-
rule_id=self.rule_id,
|
|
158
|
-
body_facts=[t for t in ground_body if t is not None],
|
|
159
|
-
head_fact=ground_head_triple,
|
|
160
|
-
parent_rule=self
|
|
161
|
-
))
|
|
162
|
-
return ground_rules
|
|
163
|
-
|
|
164
|
-
# Generate all possible variable substitutions
|
|
165
|
-
# Limit combinations to prevent explosion
|
|
166
|
-
max_entities_per_var = min(len(entities), max_groundings // len(variables) + 1)
|
|
167
|
-
|
|
168
|
-
substitution_count = 0
|
|
169
|
-
for substitution_values in itertools.product(entities[:max_entities_per_var],
|
|
170
|
-
repeat=len(variables)):
|
|
171
|
-
if substitution_count >= max_groundings:
|
|
172
|
-
break
|
|
173
|
-
|
|
174
|
-
substitution = dict(zip(variables, substitution_values))
|
|
175
|
-
|
|
176
|
-
# Generate ground atoms
|
|
177
|
-
try:
|
|
178
|
-
ground_body_atoms = [atom.ground_with_substitution(substitution)
|
|
179
|
-
for atom in self.body]
|
|
180
|
-
ground_head_atom = self.head.ground_with_substitution(substitution)
|
|
181
|
-
|
|
182
|
-
# Convert to triples
|
|
183
|
-
ground_body_triples = [atom.to_triple() for atom in ground_body_atoms
|
|
184
|
-
if atom.is_ground()]
|
|
185
|
-
ground_head_triple = ground_head_atom.to_triple()
|
|
186
|
-
|
|
187
|
-
# Check if all conversions successful
|
|
188
|
-
if (ground_head_triple and
|
|
189
|
-
len(ground_body_triples) == len(ground_body_atoms) and
|
|
190
|
-
all(t is not None for t in ground_body_triples)):
|
|
191
|
-
|
|
192
|
-
ground_rule = GroundRule(
|
|
193
|
-
rule_id=f"{self.rule_id}_ground_{substitution_count}",
|
|
194
|
-
body_facts=ground_body_triples,
|
|
195
|
-
head_fact=ground_head_triple,
|
|
196
|
-
parent_rule=self,
|
|
197
|
-
substitution=substitution.copy()
|
|
198
|
-
)
|
|
199
|
-
|
|
200
|
-
ground_rules.append(ground_rule)
|
|
201
|
-
substitution_count += 1
|
|
202
|
-
|
|
203
|
-
except Exception as e:
|
|
204
|
-
logger.debug(f"Failed to ground rule {self.rule_id}: {e}")
|
|
205
|
-
continue
|
|
206
|
-
|
|
207
|
-
logger.info(f"Generated {len(ground_rules)} ground rules for {self.rule_id}")
|
|
208
|
-
return ground_rules
|
|
209
|
-
|
|
210
|
-
def to_cnf(self) -> str:
|
|
211
|
-
"""
|
|
212
|
-
Convert rule to Conjunctive Normal Form (CNF)
|
|
213
|
-
"""
|
|
214
|
-
# Body atoms become negated disjuncts
|
|
215
|
-
cnf_parts = [f"¬{atom}" for atom in self.body]
|
|
216
|
-
# Head atom becomes positive disjunct
|
|
217
|
-
cnf_parts.append(str(self.head))
|
|
218
|
-
|
|
219
|
-
return " ∨ ".join(cnf_parts)
|
|
220
|
-
|
|
221
|
-
def __str__(self) -> str:
|
|
222
|
-
body_str = " ∧ ".join(str(atom) for atom in self.body)
|
|
223
|
-
return f"{body_str} ⇒ {self.head}"
|
|
224
|
-
|
|
225
|
-
def __repr__(self) -> str:
|
|
226
|
-
return f"LogicalRule(id={self.rule_id}, {self})"
|
|
227
|
-
|
|
228
|
-
# --- Serialization helpers ---
|
|
229
|
-
def serialize(self) -> Dict[str, Any]:
|
|
230
|
-
def _atom_to_dict(atom: Atom) -> Dict[str, Any]:
|
|
231
|
-
args = []
|
|
232
|
-
for a in atom.arguments:
|
|
233
|
-
if isinstance(a, Variable):
|
|
234
|
-
args.append({'type': 'var', 'name': a.name})
|
|
235
|
-
elif isinstance(a, Entity):
|
|
236
|
-
args.append({'type': 'ent', 'name': a.name})
|
|
237
|
-
else:
|
|
238
|
-
args.append({'type': 'raw', 'value': str(a)})
|
|
239
|
-
return {'predicate': atom.predicate.name, 'arguments': args}
|
|
240
|
-
|
|
241
|
-
return {
|
|
242
|
-
'rule_id': self.rule_id,
|
|
243
|
-
'rule_type': self.rule_type.value,
|
|
244
|
-
'confidence': self.confidence,
|
|
245
|
-
'support': self.support,
|
|
246
|
-
'body': [_atom_to_dict(a) for a in self.body],
|
|
247
|
-
'head': _atom_to_dict(self.head),
|
|
248
|
-
}
|
|
249
|
-
|
|
250
|
-
@staticmethod
|
|
251
|
-
def deserialize(data: Dict[str, Any]) -> 'LogicalRule':
|
|
252
|
-
rule_id = data['rule_id']
|
|
253
|
-
rule_type = RuleType(data.get('rule_type', RuleType.GENERAL.value))
|
|
254
|
-
confidence = data.get('confidence', 0.5)
|
|
255
|
-
support = data.get('support', 0)
|
|
256
|
-
|
|
257
|
-
def _atom_from_dict(d: Dict[str, Any]) -> Atom:
|
|
258
|
-
pred = Relation(d['predicate'])
|
|
259
|
-
args = []
|
|
260
|
-
for a in d['arguments']:
|
|
261
|
-
if a.get('type') == 'var':
|
|
262
|
-
args.append(Variable(a['name']))
|
|
263
|
-
elif a.get('type') == 'ent':
|
|
264
|
-
args.append(Entity(a['name']))
|
|
265
|
-
else:
|
|
266
|
-
# Fallback: treat as entity by name
|
|
267
|
-
args.append(Entity(str(a.get('value', ''))))
|
|
268
|
-
return Atom(predicate=pred, arguments=tuple(args))
|
|
269
|
-
|
|
270
|
-
body_atoms = [_atom_from_dict(x) for x in data['body']]
|
|
271
|
-
head_atom = _atom_from_dict(data['head'])
|
|
272
|
-
return LogicalRule(
|
|
273
|
-
rule_id=rule_id,
|
|
274
|
-
body=body_atoms,
|
|
275
|
-
head=head_atom,
|
|
276
|
-
rule_type=rule_type,
|
|
277
|
-
confidence=confidence,
|
|
278
|
-
support=support,
|
|
279
|
-
)
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
@dataclass
|
|
283
|
-
class GroundRule:
|
|
284
|
-
"""
|
|
285
|
-
Ground instance of a logical rule with all variables substituted
|
|
286
|
-
|
|
287
|
-
"""
|
|
288
|
-
|
|
289
|
-
rule_id: str
|
|
290
|
-
body_facts: List[Triple] # Ground body facts
|
|
291
|
-
head_fact: Triple # Ground head fact
|
|
292
|
-
parent_rule: LogicalRule # Original rule this was grounded from
|
|
293
|
-
substitution: Optional[Dict[Variable, Entity]] = None
|
|
294
|
-
|
|
295
|
-
def __post_init__(self):
|
|
296
|
-
"""Validate ground rule"""
|
|
297
|
-
if not self.body_facts:
|
|
298
|
-
raise ValueError("Ground rule body cannot be empty")
|
|
299
|
-
|
|
300
|
-
if not isinstance(self.head_fact, Triple):
|
|
301
|
-
raise ValueError("Ground rule head must be a Triple")
|
|
302
|
-
|
|
303
|
-
def evaluate_truth_value(self, kg: KnowledgeGraph) -> bool:
|
|
304
|
-
"""
|
|
305
|
-
Evaluate truth value of ground rule given knowledge graph
|
|
306
|
-
"""
|
|
307
|
-
# Check if all body facts are true in KG
|
|
308
|
-
body_satisfied = all(
|
|
309
|
-
kg.contains_fact(fact.head.name, fact.relation.name, fact.tail.name)
|
|
310
|
-
for fact in self.body_facts
|
|
311
|
-
)
|
|
312
|
-
|
|
313
|
-
# If body is false, rule is vacuously true
|
|
314
|
-
if not body_satisfied:
|
|
315
|
-
return True
|
|
316
|
-
|
|
317
|
-
# If body is true, check if head is true
|
|
318
|
-
head_satisfied = kg.contains_fact(
|
|
319
|
-
self.head_fact.head.name,
|
|
320
|
-
self.head_fact.relation.name,
|
|
321
|
-
self.head_fact.tail.name
|
|
322
|
-
)
|
|
323
|
-
|
|
324
|
-
return head_satisfied
|
|
325
|
-
|
|
326
|
-
def get_all_facts(self) -> List[Triple]:
|
|
327
|
-
"""Get all facts (body + head) in this ground rule"""
|
|
328
|
-
return self.body_facts + [self.head_fact]
|
|
329
|
-
|
|
330
|
-
def get_fact_truth_values(self, kg: KnowledgeGraph) -> Dict[Triple, bool]:
|
|
331
|
-
"""Get truth values for all facts in this ground rule"""
|
|
332
|
-
truth_values = {}
|
|
333
|
-
|
|
334
|
-
for fact in self.body_facts:
|
|
335
|
-
truth_values[fact] = kg.contains_fact(
|
|
336
|
-
fact.head.name, fact.relation.name, fact.tail.name
|
|
337
|
-
)
|
|
338
|
-
|
|
339
|
-
truth_values[self.head_fact] = kg.contains_fact(
|
|
340
|
-
self.head_fact.head.name,
|
|
341
|
-
self.head_fact.relation.name,
|
|
342
|
-
self.head_fact.tail.name
|
|
343
|
-
)
|
|
344
|
-
|
|
345
|
-
return truth_values
|
|
346
|
-
|
|
347
|
-
def __str__(self) -> str:
|
|
348
|
-
body_str = " ∧ ".join(str(fact) for fact in self.body_facts)
|
|
349
|
-
return f"{body_str} ⇒ {self.head_fact}"
|
|
350
|
-
|
|
351
|
-
def __repr__(self) -> str:
|
|
352
|
-
return f"GroundRule(id={self.rule_id}, {self})"
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
class RuleGenerator:
|
|
356
|
-
"""Generate logical rules from knowledge graph patterns"""
|
|
357
|
-
|
|
358
|
-
def __init__(self, kg: KnowledgeGraph):
|
|
359
|
-
self.kg = kg
|
|
360
|
-
|
|
361
|
-
def generate_simple_rules(self, min_support: int = 2,
|
|
362
|
-
max_rule_length: int = 3) -> List[LogicalRule]:
|
|
363
|
-
"""
|
|
364
|
-
Generate simple logical rules from knowledge graph patterns
|
|
365
|
-
"""
|
|
366
|
-
rules = []
|
|
367
|
-
|
|
368
|
-
# Generate transitivity rules: R1(x,y) ∧ R2(y,z) ⇒ R3(x,z)
|
|
369
|
-
relations = list(self.kg.relations)
|
|
370
|
-
|
|
371
|
-
for r1, r2, r3 in itertools.combinations_with_replacement(relations, 3):
|
|
372
|
-
if r1 == r2 == r3: # Skip trivial cases
|
|
373
|
-
continue
|
|
374
|
-
|
|
375
|
-
# Create variables
|
|
376
|
-
x, y, z = Variable('x'), Variable('y'), Variable('z')
|
|
377
|
-
|
|
378
|
-
# Create atoms
|
|
379
|
-
atom1 = Atom(predicate=r1, arguments=(x, y))
|
|
380
|
-
atom2 = Atom(predicate=r2, arguments=(y, z))
|
|
381
|
-
head_atom = Atom(predicate=r3, arguments=(x, z))
|
|
382
|
-
|
|
383
|
-
rule = LogicalRule(
|
|
384
|
-
rule_id=f"trans_{r1.name}_{r2.name}_{r3.name}",
|
|
385
|
-
body=[atom1, atom2],
|
|
386
|
-
head=head_atom,
|
|
387
|
-
rule_type=RuleType.TRANSITIVITY,
|
|
388
|
-
confidence=0.5 # Will be learned
|
|
389
|
-
)
|
|
390
|
-
|
|
391
|
-
# Check support by grounding rule
|
|
392
|
-
ground_rules = rule.generate_ground_rules(self.kg, max_groundings=100)
|
|
393
|
-
|
|
394
|
-
# Count supporting instances
|
|
395
|
-
support_count = sum(1 for gr in ground_rules
|
|
396
|
-
if gr.evaluate_truth_value(self.kg))
|
|
397
|
-
|
|
398
|
-
if support_count >= min_support:
|
|
399
|
-
rule.support = support_count
|
|
400
|
-
rules.append(rule)
|
|
401
|
-
|
|
402
|
-
logger.info(f"Generated {len(rules)} rules with min support {min_support}")
|
|
403
|
-
return rules
|
|
404
|
-
|
|
405
|
-
def generate_symmetry_rules(self, min_support: int = 2) -> List[LogicalRule]:
|
|
406
|
-
"""Generate symmetry rules: R(x,y) ⇒ R(y,x)"""
|
|
407
|
-
rules = []
|
|
408
|
-
|
|
409
|
-
for relation in self.kg.relations:
|
|
410
|
-
x, y = Variable('x'), Variable('y')
|
|
411
|
-
|
|
412
|
-
body_atom = Atom(predicate=relation, arguments=(x, y))
|
|
413
|
-
head_atom = Atom(predicate=relation, arguments=(y, x))
|
|
414
|
-
|
|
415
|
-
rule = LogicalRule(
|
|
416
|
-
rule_id=f"sym_{relation.name}",
|
|
417
|
-
body=[body_atom],
|
|
418
|
-
head=head_atom,
|
|
419
|
-
rule_type=RuleType.SYMMETRY,
|
|
420
|
-
confidence=0.5
|
|
421
|
-
)
|
|
422
|
-
|
|
423
|
-
# Check support
|
|
424
|
-
ground_rules = rule.generate_ground_rules(self.kg, max_groundings=100)
|
|
425
|
-
support_count = sum(1 for gr in ground_rules
|
|
426
|
-
if gr.evaluate_truth_value(self.kg))
|
|
427
|
-
|
|
428
|
-
if support_count >= min_support:
|
|
429
|
-
rule.support = support_count
|
|
430
|
-
rules.append(rule)
|
|
431
|
-
|
|
432
|
-
return rules
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
def parse_rule_from_string(rule_str: str, entities: Dict[str, Entity],
|
|
436
|
-
relations: Dict[str, Relation]) -> LogicalRule:
|
|
437
|
-
"""
|
|
438
|
-
Parse logical rule from string format
|
|
439
|
-
"""
|
|
440
|
-
# Split by implication arrow
|
|
441
|
-
if "⇒" not in rule_str:
|
|
442
|
-
raise ValueError(f"Rule must contain ⇒: {rule_str}")
|
|
443
|
-
|
|
444
|
-
body_str, head_str = rule_str.split("⇒", 1)
|
|
445
|
-
|
|
446
|
-
# Parse body atoms (split by ∧)
|
|
447
|
-
body_atom_strs = [atom.strip() for atom in body_str.split("∧")]
|
|
448
|
-
body_atoms = []
|
|
449
|
-
|
|
450
|
-
for atom_str in body_atom_strs:
|
|
451
|
-
atom = _parse_atom_from_string(atom_str.strip(), entities, relations)
|
|
452
|
-
body_atoms.append(atom)
|
|
453
|
-
|
|
454
|
-
# Parse head atom
|
|
455
|
-
head_atom = _parse_atom_from_string(head_str.strip(), entities, relations)
|
|
456
|
-
|
|
457
|
-
rule_id = f"parsed_{hash(rule_str) % 10000}"
|
|
458
|
-
|
|
459
|
-
return LogicalRule(
|
|
460
|
-
rule_id=rule_id,
|
|
461
|
-
body=body_atoms,
|
|
462
|
-
head=head_atom,
|
|
463
|
-
rule_type=RuleType.GENERAL,
|
|
464
|
-
confidence=0.5
|
|
465
|
-
)
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
def _parse_atom_from_string(atom_str: str, entities: Dict[str, Entity],
|
|
469
|
-
relations: Dict[str, Relation]) -> Atom:
|
|
470
|
-
"""Parse single atom from string: 'Pred(arg1,arg2)'"""
|
|
471
|
-
# Extract predicate and arguments
|
|
472
|
-
if "(" not in atom_str or ")" not in atom_str:
|
|
473
|
-
raise ValueError(f"Invalid atom format: {atom_str}")
|
|
474
|
-
|
|
475
|
-
pred_name = atom_str[:atom_str.index("(")].strip()
|
|
476
|
-
args_str = atom_str[atom_str.index("(")+1:atom_str.rindex(")")].strip()
|
|
477
|
-
|
|
478
|
-
# Get or create relation
|
|
479
|
-
if pred_name not in relations:
|
|
480
|
-
relations[pred_name] = Relation(name=pred_name)
|
|
481
|
-
predicate = relations[pred_name]
|
|
482
|
-
|
|
483
|
-
# Parse arguments
|
|
484
|
-
arg_names = [arg.strip() for arg in args_str.split(",")]
|
|
485
|
-
arguments = []
|
|
486
|
-
|
|
487
|
-
for arg_name in arg_names:
|
|
488
|
-
# Check if it's a variable (lowercase) or entity
|
|
489
|
-
if arg_name.islower() or arg_name.startswith('?'):
|
|
490
|
-
arguments.append(Variable(arg_name))
|
|
491
|
-
else:
|
|
492
|
-
# Entity - create if doesn't exist
|
|
493
|
-
if arg_name not in entities:
|
|
494
|
-
entities[arg_name] = Entity(name=arg_name)
|
|
495
|
-
arguments.append(entities[arg_name])
|
|
496
|
-
|
|
1
|
+
"""
|
|
2
|
+
Logical Rules and Ground Rules for NPLL
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List, Set, Dict, Tuple, Optional, Iterator, Any
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from enum import Enum
|
|
8
|
+
import itertools
|
|
9
|
+
import logging
|
|
10
|
+
from collections import defaultdict
|
|
11
|
+
|
|
12
|
+
from .knowledge_graph import Entity, Relation, Triple, KnowledgeGraph
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Variable:
|
|
18
|
+
"""Variable in logical rules"""
|
|
19
|
+
|
|
20
|
+
def __init__(self, name: str):
|
|
21
|
+
self.name = name
|
|
22
|
+
|
|
23
|
+
def __str__(self) -> str:
|
|
24
|
+
return self.name
|
|
25
|
+
|
|
26
|
+
def __repr__(self) -> str:
|
|
27
|
+
return f"Var({self.name})"
|
|
28
|
+
|
|
29
|
+
def __eq__(self, other) -> bool:
|
|
30
|
+
return isinstance(other, Variable) and self.name == other.name
|
|
31
|
+
|
|
32
|
+
def __hash__(self) -> int:
|
|
33
|
+
return hash(self.name)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(frozen=True)
|
|
37
|
+
class Atom:
|
|
38
|
+
"""
|
|
39
|
+
Atomic formula in first-order logic: Pred(arg1, arg2)
|
|
40
|
+
"""
|
|
41
|
+
predicate: Relation
|
|
42
|
+
arguments: Tuple[Any, ...]
|
|
43
|
+
|
|
44
|
+
def __post_init__(self):
|
|
45
|
+
# Validate arguments
|
|
46
|
+
for arg in self.arguments:
|
|
47
|
+
if not isinstance(arg, (Variable, Entity)):
|
|
48
|
+
raise ValueError(f"Atom arguments must be Variable or Entity, got {type(arg)}")
|
|
49
|
+
|
|
50
|
+
def __str__(self) -> str:
|
|
51
|
+
args_str = ", ".join(str(arg) for arg in self.arguments)
|
|
52
|
+
return f"{self.predicate.name}({args_str})"
|
|
53
|
+
|
|
54
|
+
def is_ground(self) -> bool:
|
|
55
|
+
"""Check if atom is ground (no variables)"""
|
|
56
|
+
return all(isinstance(arg, Entity) for arg in self.arguments)
|
|
57
|
+
|
|
58
|
+
def get_variables(self) -> Set[Variable]:
|
|
59
|
+
"""Get all variables in this atom"""
|
|
60
|
+
return {arg for arg in self.arguments if isinstance(arg, Variable)}
|
|
61
|
+
|
|
62
|
+
def ground_with_substitution(self, substitution: Dict[Variable, Entity]) -> 'Atom':
|
|
63
|
+
"""Ground atom by substituting variables with entities"""
|
|
64
|
+
new_args = []
|
|
65
|
+
for arg in self.arguments:
|
|
66
|
+
if isinstance(arg, Variable) and arg in substitution:
|
|
67
|
+
new_args.append(substitution[arg])
|
|
68
|
+
else:
|
|
69
|
+
new_args.append(arg)
|
|
70
|
+
|
|
71
|
+
return Atom(predicate=self.predicate, arguments=tuple(new_args))
|
|
72
|
+
|
|
73
|
+
def to_triple(self) -> Optional[Triple]:
|
|
74
|
+
"""Convert ground atom to triple (if binary and ground)"""
|
|
75
|
+
if len(self.arguments) == 2 and self.is_ground():
|
|
76
|
+
head, tail = self.arguments
|
|
77
|
+
return Triple(head=head, relation=self.predicate, tail=tail)
|
|
78
|
+
return None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class RuleType(Enum):
|
|
82
|
+
"""Types of logical rules"""
|
|
83
|
+
HORN_CLAUSE = "horn" # Standard Horn clause
|
|
84
|
+
EQUALITY = "equality" # Equality rules
|
|
85
|
+
TRANSITIVITY = "transitivity" # Transitive rules
|
|
86
|
+
SYMMETRY = "symmetry" # Symmetric rules
|
|
87
|
+
GENERAL = "general" # General first-order rules
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
|
|
91
|
+
class LogicalRule:
|
|
92
|
+
"""
|
|
93
|
+
First-order logical rule: Body ⇒ Head
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
rule_id: str
|
|
97
|
+
body: List[Atom] # Body atoms (premise)
|
|
98
|
+
head: Atom # Head atom (conclusion)
|
|
99
|
+
rule_type: RuleType = RuleType.GENERAL
|
|
100
|
+
confidence: float = 0.5 # Initial confidence score
|
|
101
|
+
support: int = 0 # Number of supporting instances
|
|
102
|
+
|
|
103
|
+
def __post_init__(self):
|
|
104
|
+
"""Validate rule structure"""
|
|
105
|
+
if not self.body:
|
|
106
|
+
raise ValueError("Rule body cannot be empty")
|
|
107
|
+
|
|
108
|
+
if not isinstance(self.head, Atom):
|
|
109
|
+
raise ValueError("Rule head must be an Atom")
|
|
110
|
+
|
|
111
|
+
# Check variable consistency
|
|
112
|
+
body_vars = set()
|
|
113
|
+
for atom in self.body:
|
|
114
|
+
body_vars.update(atom.get_variables())
|
|
115
|
+
|
|
116
|
+
head_vars = self.head.get_variables()
|
|
117
|
+
|
|
118
|
+
# All head variables should appear in body
|
|
119
|
+
if not head_vars.issubset(body_vars):
|
|
120
|
+
logger.warning(f"Rule {self.rule_id}: Head variables not in body")
|
|
121
|
+
|
|
122
|
+
def get_all_variables(self) -> Set[Variable]:
|
|
123
|
+
"""Get all variables in the rule"""
|
|
124
|
+
variables = set()
|
|
125
|
+
for atom in self.body:
|
|
126
|
+
variables.update(atom.get_variables())
|
|
127
|
+
variables.update(self.head.get_variables())
|
|
128
|
+
return variables
|
|
129
|
+
|
|
130
|
+
def get_predicates(self) -> Set[Relation]:
|
|
131
|
+
"""Get all predicates used in the rule"""
|
|
132
|
+
predicates = {atom.predicate for atom in self.body}
|
|
133
|
+
predicates.add(self.head.predicate)
|
|
134
|
+
return predicates
|
|
135
|
+
|
|
136
|
+
def generate_ground_rules(self, kg: KnowledgeGraph,
|
|
137
|
+
max_groundings: int = 1000) -> List['GroundRule']:
|
|
138
|
+
"""
|
|
139
|
+
Generate ground rules by substituting variables with entities
|
|
140
|
+
"""
|
|
141
|
+
ground_rules = []
|
|
142
|
+
entities = list(kg.entities)
|
|
143
|
+
|
|
144
|
+
if not entities:
|
|
145
|
+
return ground_rules
|
|
146
|
+
|
|
147
|
+
# Get all variables that need to be substituted
|
|
148
|
+
variables = list(self.get_all_variables())
|
|
149
|
+
|
|
150
|
+
if not variables:
|
|
151
|
+
# Already ground rule
|
|
152
|
+
ground_body = [atom.to_triple() for atom in self.body if atom.is_ground()]
|
|
153
|
+
ground_head_triple = self.head.to_triple()
|
|
154
|
+
|
|
155
|
+
if ground_head_triple and all(t is not None for t in ground_body):
|
|
156
|
+
ground_rules.append(GroundRule(
|
|
157
|
+
rule_id=self.rule_id,
|
|
158
|
+
body_facts=[t for t in ground_body if t is not None],
|
|
159
|
+
head_fact=ground_head_triple,
|
|
160
|
+
parent_rule=self
|
|
161
|
+
))
|
|
162
|
+
return ground_rules
|
|
163
|
+
|
|
164
|
+
# Generate all possible variable substitutions
|
|
165
|
+
# Limit combinations to prevent explosion
|
|
166
|
+
max_entities_per_var = min(len(entities), max_groundings // len(variables) + 1)
|
|
167
|
+
|
|
168
|
+
substitution_count = 0
|
|
169
|
+
for substitution_values in itertools.product(entities[:max_entities_per_var],
|
|
170
|
+
repeat=len(variables)):
|
|
171
|
+
if substitution_count >= max_groundings:
|
|
172
|
+
break
|
|
173
|
+
|
|
174
|
+
substitution = dict(zip(variables, substitution_values))
|
|
175
|
+
|
|
176
|
+
# Generate ground atoms
|
|
177
|
+
try:
|
|
178
|
+
ground_body_atoms = [atom.ground_with_substitution(substitution)
|
|
179
|
+
for atom in self.body]
|
|
180
|
+
ground_head_atom = self.head.ground_with_substitution(substitution)
|
|
181
|
+
|
|
182
|
+
# Convert to triples
|
|
183
|
+
ground_body_triples = [atom.to_triple() for atom in ground_body_atoms
|
|
184
|
+
if atom.is_ground()]
|
|
185
|
+
ground_head_triple = ground_head_atom.to_triple()
|
|
186
|
+
|
|
187
|
+
# Check if all conversions successful
|
|
188
|
+
if (ground_head_triple and
|
|
189
|
+
len(ground_body_triples) == len(ground_body_atoms) and
|
|
190
|
+
all(t is not None for t in ground_body_triples)):
|
|
191
|
+
|
|
192
|
+
ground_rule = GroundRule(
|
|
193
|
+
rule_id=f"{self.rule_id}_ground_{substitution_count}",
|
|
194
|
+
body_facts=ground_body_triples,
|
|
195
|
+
head_fact=ground_head_triple,
|
|
196
|
+
parent_rule=self,
|
|
197
|
+
substitution=substitution.copy()
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
ground_rules.append(ground_rule)
|
|
201
|
+
substitution_count += 1
|
|
202
|
+
|
|
203
|
+
except Exception as e:
|
|
204
|
+
logger.debug(f"Failed to ground rule {self.rule_id}: {e}")
|
|
205
|
+
continue
|
|
206
|
+
|
|
207
|
+
logger.info(f"Generated {len(ground_rules)} ground rules for {self.rule_id}")
|
|
208
|
+
return ground_rules
|
|
209
|
+
|
|
210
|
+
def to_cnf(self) -> str:
|
|
211
|
+
"""
|
|
212
|
+
Convert rule to Conjunctive Normal Form (CNF)
|
|
213
|
+
"""
|
|
214
|
+
# Body atoms become negated disjuncts
|
|
215
|
+
cnf_parts = [f"¬{atom}" for atom in self.body]
|
|
216
|
+
# Head atom becomes positive disjunct
|
|
217
|
+
cnf_parts.append(str(self.head))
|
|
218
|
+
|
|
219
|
+
return " ∨ ".join(cnf_parts)
|
|
220
|
+
|
|
221
|
+
def __str__(self) -> str:
|
|
222
|
+
body_str = " ∧ ".join(str(atom) for atom in self.body)
|
|
223
|
+
return f"{body_str} ⇒ {self.head}"
|
|
224
|
+
|
|
225
|
+
def __repr__(self) -> str:
|
|
226
|
+
return f"LogicalRule(id={self.rule_id}, {self})"
|
|
227
|
+
|
|
228
|
+
# --- Serialization helpers ---
|
|
229
|
+
def serialize(self) -> Dict[str, Any]:
|
|
230
|
+
def _atom_to_dict(atom: Atom) -> Dict[str, Any]:
|
|
231
|
+
args = []
|
|
232
|
+
for a in atom.arguments:
|
|
233
|
+
if isinstance(a, Variable):
|
|
234
|
+
args.append({'type': 'var', 'name': a.name})
|
|
235
|
+
elif isinstance(a, Entity):
|
|
236
|
+
args.append({'type': 'ent', 'name': a.name})
|
|
237
|
+
else:
|
|
238
|
+
args.append({'type': 'raw', 'value': str(a)})
|
|
239
|
+
return {'predicate': atom.predicate.name, 'arguments': args}
|
|
240
|
+
|
|
241
|
+
return {
|
|
242
|
+
'rule_id': self.rule_id,
|
|
243
|
+
'rule_type': self.rule_type.value,
|
|
244
|
+
'confidence': self.confidence,
|
|
245
|
+
'support': self.support,
|
|
246
|
+
'body': [_atom_to_dict(a) for a in self.body],
|
|
247
|
+
'head': _atom_to_dict(self.head),
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
@staticmethod
def deserialize(data: Dict[str, Any]) -> 'LogicalRule':
    """
    Rebuild a LogicalRule from a dictionary shaped like the output of ``serialize``.

    'rule_id', 'body' and 'head' are required keys. Missing optional keys
    fall back to RuleType.GENERAL, confidence 0.5 and support 0.
    """
    def _decode_argument(spec: Dict[str, Any]):
        kind = spec.get('type')
        if kind == 'var':
            return Variable(spec['name'])
        if kind == 'ent':
            return Entity(spec['name'])
        # Fallback: any other tag is treated as an entity named by 'value'
        return Entity(str(spec.get('value', '')))

    def _decode_atom(spec: Dict[str, Any]) -> Atom:
        relation = Relation(spec['predicate'])
        terms = tuple(_decode_argument(item) for item in spec['arguments'])
        return Atom(predicate=relation, arguments=terms)

    # Scalar fields are read first, in this order, before any atom is decoded
    rule_id = data['rule_id']
    rule_type = RuleType(data.get('rule_type', RuleType.GENERAL.value))
    confidence = data.get('confidence', 0.5)
    support = data.get('support', 0)

    body_atoms = [_decode_atom(entry) for entry in data['body']]
    head_atom = _decode_atom(data['head'])

    return LogicalRule(
        rule_id=rule_id,
        body=body_atoms,
        head=head_atom,
        rule_type=rule_type,
        confidence=confidence,
        support=support,
    )
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@dataclass
class GroundRule:
    """
    A logical rule in which every variable has been replaced by a concrete
    entity, so that body and head are plain knowledge-graph triples.
    """

    rule_id: str
    body_facts: List[Triple]  # Grounded facts that make up the rule body
    head_fact: Triple  # Grounded fact that makes up the rule head
    parent_rule: LogicalRule  # The rule this grounding was produced from
    substitution: Optional[Dict[Variable, Entity]] = None

    def __post_init__(self):
        """Reject a grounding with an empty body or a head that is not a Triple."""
        if not self.body_facts:
            raise ValueError("Ground rule body cannot be empty")

        if not isinstance(self.head_fact, Triple):
            raise ValueError("Ground rule head must be a Triple")

    @staticmethod
    def _is_known(kg: KnowledgeGraph, fact: Triple):
        """Ask the knowledge graph whether it contains ``fact`` (looked up by names)."""
        return kg.contains_fact(fact.head.name, fact.relation.name, fact.tail.name)

    def evaluate_truth_value(self, kg: KnowledgeGraph) -> bool:
        """
        Evaluate the implication body ⇒ head against the knowledge graph.

        Returns True when at least one body fact is missing from the graph
        (the implication holds vacuously); otherwise returns whatever the
        graph reports for the head fact.
        """
        for fact in self.body_facts:
            if not self._is_known(kg, fact):
                # A false body makes the implication vacuously true
                return True

        # Every body fact holds, so the rule is exactly as true as its head
        return self._is_known(kg, self.head_fact)

    def get_all_facts(self) -> List[Triple]:
        """Return the body facts followed by the head fact as one list."""
        head_only = [self.head_fact]
        return self.body_facts + head_only

    def get_fact_truth_values(self, kg: KnowledgeGraph) -> Dict[Triple, bool]:
        """Map every fact of this ground rule (body first, then head) to its presence in the graph."""
        return {
            fact: self._is_known(kg, fact)
            for fact in (*self.body_facts, self.head_fact)
        }

    def __str__(self) -> str:
        conjunction = " ∧ ".join(map(str, self.body_facts))
        return f"{conjunction} ⇒ {self.head_fact}"

    def __repr__(self) -> str:
        return "GroundRule(id={}, {})".format(self.rule_id, self)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
class RuleGenerator:
    """Generate candidate logical rules from knowledge graph patterns"""

    def __init__(self, kg: KnowledgeGraph):
        self.kg = kg

    def _sorted_relations(self) -> list:
        """Return the graph's relations ordered by name, so rule generation is reproducible."""
        return sorted(self.kg.relations, key=lambda relation: relation.name)

    def _count_support(self, rule: LogicalRule, max_groundings: int = 100) -> int:
        """Ground ``rule`` against the graph and count the groundings that evaluate as true."""
        ground_rules = rule.generate_ground_rules(self.kg, max_groundings=max_groundings)
        return sum(1 for gr in ground_rules if gr.evaluate_truth_value(self.kg))

    def generate_simple_rules(self, min_support: int = 2,
                              max_rule_length: int = 3) -> List[LogicalRule]:
        """
        Generate simple logical rules from knowledge graph patterns

        Candidates have the form R1(x,y) ∧ R2(y,z) ⇒ R3(x,z) and are built for
        every ordered assignment of the graph's relations to R1, R2 and R3,
        except the case where all three are the same relation.

        Args:
            min_support: Minimum number of true groundings (out of at most
                100 generated per candidate) for a rule to be kept.
            max_rule_length: Not used by the current implementation; kept so
                existing callers keep working.

        Returns:
            The kept rules, each with ``support`` set to its counted support.
        """
        rules = []

        # Generate transitivity rules: R1(x,y) ∧ R2(y,z) ⇒ R3(x,z)
        relations = self._sorted_relations()

        # The three positions play different roles, so every ordered triple is
        # a distinct candidate. combinations_with_replacement would only yield
        # unordered triples and silently drop most role assignments.
        for r1, r2, r3 in itertools.product(relations, repeat=3):
            if r1 == r2 == r3:  # Skip trivial cases
                continue

            # Create variables
            x, y, z = Variable('x'), Variable('y'), Variable('z')

            # Create atoms
            atom1 = Atom(predicate=r1, arguments=(x, y))
            atom2 = Atom(predicate=r2, arguments=(y, z))
            head_atom = Atom(predicate=r3, arguments=(x, z))

            rule = LogicalRule(
                rule_id=f"trans_{r1.name}_{r2.name}_{r3.name}",
                body=[atom1, atom2],
                head=head_atom,
                rule_type=RuleType.TRANSITIVITY,
                confidence=0.5  # Will be learned
            )

            # Check support by grounding rule and counting true instances
            support_count = self._count_support(rule)

            if support_count >= min_support:
                rule.support = support_count
                rules.append(rule)

        logger.info(f"Generated {len(rules)} rules with min support {min_support}")
        return rules

    def generate_symmetry_rules(self, min_support: int = 2) -> List[LogicalRule]:
        """
        Generate symmetry rules: R(x,y) ⇒ R(y,x)

        One candidate is built per relation in the graph and kept when at
        least ``min_support`` of its groundings (at most 100 are generated)
        evaluate as true.
        """
        rules = []

        for relation in self._sorted_relations():
            x, y = Variable('x'), Variable('y')

            body_atom = Atom(predicate=relation, arguments=(x, y))
            head_atom = Atom(predicate=relation, arguments=(y, x))

            rule = LogicalRule(
                rule_id=f"sym_{relation.name}",
                body=[body_atom],
                head=head_atom,
                rule_type=RuleType.SYMMETRY,
                confidence=0.5
            )

            # Check support
            support_count = self._count_support(rule)

            if support_count >= min_support:
                rule.support = support_count
                rules.append(rule)

        return rules
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def parse_rule_from_string(rule_str: str, entities: Dict[str, Entity],
                           relations: Dict[str, Relation]) -> LogicalRule:
    """
    Parse logical rule from string format

    Expected shape: 'Atom1 ∧ Atom2 ∧ ... ⇒ HeadAtom', where each atom is
    written as 'Pred(arg1,arg2)'.

    Args:
        rule_str: Rule text; must contain the implication arrow '⇒'.
        entities: Name-to-entity mapping passed to the atom parser, which
            adds any entity it has not seen before.
        relations: Name-to-relation mapping passed to the atom parser, which
            adds any relation it has not seen before.

    Returns:
        A LogicalRule of type GENERAL with confidence 0.5 and an id of the
        form 'parsed_<n>' with 0 <= n < 10000.

    Raises:
        ValueError: If '⇒' is missing, or an atom cannot be parsed.
    """
    import zlib  # local import so the block does not depend on the file's import header

    # Split by implication arrow
    if "⇒" not in rule_str:
        raise ValueError(f"Rule must contain ⇒: {rule_str}")

    body_str, head_str = rule_str.split("⇒", 1)

    # Parse body atoms (split by ∧)
    body_atoms = [
        _parse_atom_from_string(atom_str.strip(), entities, relations)
        for atom_str in body_str.split("∧")
    ]

    # Parse head atom
    head_atom = _parse_atom_from_string(head_str.strip(), entities, relations)

    # Built-in hash() of a str is salted per process, so it would hand out a
    # different id for the same rule text on every run. CRC32 is stable.
    checksum = zlib.crc32(rule_str.encode("utf-8", errors="replace"))
    rule_id = f"parsed_{checksum % 10000}"

    return LogicalRule(
        rule_id=rule_id,
        body=body_atoms,
        head=head_atom,
        rule_type=RuleType.GENERAL,
        confidence=0.5
    )
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def _parse_atom_from_string(atom_str: str, entities: Dict[str, Entity],
                            relations: Dict[str, Relation]) -> Atom:
    """
    Parse single atom from string: 'Pred(arg1,arg2)'

    An argument is read as a variable when it is all lowercase or starts
    with '?'; anything else is read as an entity name.

    Args:
        atom_str: Atom text; the predicate name precedes the first '(' and
            the comma-separated arguments sit between it and the last ')'.
        entities: Name-to-entity mapping; unseen entity names are added.
        relations: Name-to-relation mapping; an unseen predicate is added.

    Raises:
        ValueError: If the parentheses are missing or out of order, or the
            predicate name or any argument is empty.
    """
    # Extract predicate and arguments
    open_idx = atom_str.find("(")
    close_idx = atom_str.rfind(")")
    # Also catches a missing ')' (-1) and a ')' that precedes the '('
    if open_idx == -1 or close_idx < open_idx:
        raise ValueError(f"Invalid atom format: {atom_str}")

    pred_name = atom_str[:open_idx].strip()
    args_str = atom_str[open_idx + 1:close_idx].strip()
    arg_names = [arg.strip() for arg in args_str.split(",")]

    # Validate before touching the caller's dicts, so a malformed atom never
    # registers a relation or entity with an empty name.
    if not pred_name or not all(arg_names):
        raise ValueError(f"Invalid atom format: {atom_str}")

    # Get or create relation
    if pred_name not in relations:
        relations[pred_name] = Relation(name=pred_name)
    predicate = relations[pred_name]

    # Parse arguments
    arguments = []

    for arg_name in arg_names:
        # Check if it's a variable (lowercase) or entity
        if arg_name.islower() or arg_name.startswith('?'):
            arguments.append(Variable(arg_name))
        else:
            # Entity - create if doesn't exist
            if arg_name not in entities:
                entities[arg_name] = Entity(name=arg_name)
            arguments.append(entities[arg_name])

    return Atom(predicate=predicate, arguments=tuple(arguments))
|