unicode-fol-kit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,779 @@
1
+ """AST node definitions for a first-order logic (FOL) language.
2
+
3
+ Each node class represents a syntactic construct — terms (Variable, Constant,
4
+ Number, Function) or formulas (Atom, Not, And, Or, Xor, Implies, Iff,
5
+ Quantifier) — and provides serialisation to dict, Z3, Prover9, and TPTP.
6
+ FOLTransformer bridges the Lark parse tree to these AST nodes.
7
+ """
8
+
9
+ from typing import List, Union, Dict
10
+ from lark import Transformer
11
+ from dataclasses import dataclass, fields
12
+
13
+ import z3
14
+
15
# The single uninterpreted sort used for every term: variables, constants,
# numbers and function results all live in S (the encoding is untyped FOL).
_SORT: z3.SortRef = z3.DeclareSort("S")
16
+
17
+
18
+ # =========================
19
+ # Z3 Environment
20
+ # =========================
21
+
22
class Z3Env:
    """Tracks declared Z3 symbols. Single sort for all terms.

    Constants and variables share one symbol table. Functions (S^n -> S) and
    predicates (S^n -> Bool) are kept in separate tables, so one name may be
    used for both.

    A name may also be reused with several arities (e.g. ``P(x)`` and
    ``P(x, y)``). The first declaration is stored in ``funcs``/``preds`` keyed
    by name, as before. Any further arity gets its own overloaded Z3
    declaration, so the returned declaration always accepts the requested
    number of arguments.
    """

    def __init__(self):
        """Initialise empty symbol, function, and predicate tables."""
        self.symbols: Dict[str, z3.ExprRef] = {}
        self.funcs: Dict[str, z3.FuncDeclRef] = {}
        self.preds: Dict[str, z3.FuncDeclRef] = {}
        # Extra declarations for names reused with a different arity,
        # keyed by (table kind, name, arity).
        self._overloads: Dict[tuple, z3.FuncDeclRef] = {}

    def get_symbol(self, name: str) -> z3.ExprRef:
        """Get or create a Z3 constant (used for both variables and constants)."""
        if name not in self.symbols:
            self.symbols[name] = z3.Const(name, _SORT)
        return self.symbols[name]

    def _declare(self, kind: str, table: Dict[str, z3.FuncDeclRef], name: str,
                 arity: int, range_sort) -> z3.FuncDeclRef:
        """Return the cached declaration for (name, arity), creating it on first use.

        Args:
            kind: "func" or "pred"; keeps overloads of the two tables apart.
            table: The public name-keyed table (``self.funcs`` or ``self.preds``).
            name: Symbol name.
            arity: Number of S-sorted arguments.
            range_sort: Result sort (``_SORT`` or ``z3.BoolSort()``).
        """
        decl = table.get(name)
        if decl is None:
            decl = z3.Function(name, *([_SORT] * arity), range_sort)
            table[name] = decl
            return decl
        if decl.arity() == arity:
            return decl
        # Same name, different arity: returning the cached declaration would
        # make the subsequent call fail, so declare a separate overload.
        key = (kind, name, arity)
        overload = self._overloads.get(key)
        if overload is None:
            overload = z3.Function(name, *([_SORT] * arity), range_sort)
            self._overloads[key] = overload
        return overload

    def get_func(self, name: str, arity: int) -> z3.FuncDeclRef:
        """Get or create an uninterpreted Z3 function of the given arity mapping S^arity -> S."""
        return self._declare("func", self.funcs, name, arity, _SORT)

    def get_pred(self, name: str, arity: int) -> z3.FuncDeclRef:
        """Get or create an uninterpreted Z3 predicate of the given arity mapping S^arity -> Bool."""
        return self._declare("pred", self.preds, name, arity, z3.BoolSort())
48
+
49
+
50
+ # =========================
51
+ # Base Node
52
+ # =========================
53
+
54
class Node:
    """Base class for all AST nodes.

    Concrete node classes override the four serialisers below and provide a
    static ``from_dict``. ``_tree_parts``/``tree_str`` are shared
    pretty-printing helpers.
    """

    def to_dict(self) -> dict:
        """Serialise this node to a JSON-compatible dictionary."""
        raise NotImplementedError

    def to_z3(self, env: Z3Env = None) -> z3.ExprRef:
        """Translate this node into a Z3 expression using the given environment."""
        raise NotImplementedError

    def to_prover9(self) -> str:
        """Render this node as a Prover9-syntax string."""
        raise NotImplementedError

    def to_tptp(self) -> str:
        """Render this node as a TPTP-syntax string."""
        raise NotImplementedError

    @staticmethod
    def from_dict(d: dict) -> "Node":
        """Deserialise a node from a dictionary produced by to_dict.

        Dispatches on the ``"_type"`` tag via the module-level NODE_CLASSES
        registry.

        Raises:
            ValueError: If the tag is not a registered node class.
        """
        t = d["_type"]
        if t not in NODE_CLASSES:
            raise ValueError(f"Unknown type: {t}")
        return NODE_CLASSES[t].from_dict(d)

    # Display labels for connective nodes in tree_str output.
    _TREE_LABELS = {
        "And": "∧", "Or": "∨", "Xor": "⊕",
        "Implies": "→", "Iff": "↔", "Not": "¬",
    }

    def _tree_parts(self):
        """Return (label, children) for tree rendering.

        Leaf terms render their value in the label and have no children.
        Atom and Function render the symbol in the label and expose their
        argument nodes. Quantifier shows its type and bound variable.
        Everything else falls back to its dataclass fields, treating any
        Node-valued field as a child.
        """
        cls = type(self).__name__
        if cls in ("Variable", "Constant"):
            return f"{cls}: {self.name}", []
        if cls == "Number":
            return f"Number: {self.value}", []
        if cls == "Atom":
            return f"Atom: {self.predicate}", list(self.args)
        if cls == "Function":
            return f"Function: {self.name}", list(self.args)
        if cls == "Quantifier":
            return f"{self.type} {self.variable.name}", [self.formula]

        label = self._TREE_LABELS.get(cls, cls)
        children = []
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, Node):
                children.append(value)
            elif isinstance(value, list):
                children.extend(c for c in value if isinstance(c, Node))
        return label, children

    def tree_str(self) -> str:
        """Render the AST as a multi-line tree using ├──/└── connectors.

        Each connector is four columns wide, so the prefix added to a child's
        deeper lines is also four columns: "│   " while more siblings follow,
        and four spaces after the last one. This keeps every subtree aligned
        under its parent's label, whatever its position among its siblings.
        """
        label, children = self._tree_parts()
        lines = [label]
        for i, child in enumerate(children):
            last = i == len(children) - 1
            branch = "└── " if last else "├── "
            prefix = "    " if last else "│   "
            sub = child.tree_str().split("\n")
            lines.append(branch + sub[0])
            lines.extend(prefix + s for s in sub[1:])
        return "\n".join(lines)
129
+
130
+
131
+ # =========================
132
+ # Term Nodes
133
+ # =========================
134
+
135
@dataclass
class Variable(Node):
    """A logical variable; the grammar spells these as a single lowercase letter."""

    name: str

    def to_dict(self):
        """Return the tagged mapping ``{"_type": "Variable", "name": ...}``."""
        return dict(_type="Variable", name=self.name)

    @staticmethod
    def from_dict(d):
        """Rebuild a Variable from the mapping produced by ``to_dict``."""
        var_name = d["name"]
        return Variable(var_name)

    def to_z3(self, env: Z3Env = None):
        """Return the Z3 constant of sort S registered under this variable's name."""
        active_env = env if env else Z3Env()
        return active_env.get_symbol(self.name)

    def to_prover9(self) -> str:
        """Return the variable name unchanged for Prover9 output."""
        return self.name

    def to_tptp(self) -> str:
        """Return the upper-cased name; TPTP variables must start with an uppercase letter."""
        return self.name.upper()
161
+
162
+
163
@dataclass
class Constant(Node):
    """A ground constant, built from a bare NAME or from the c_-prefixed CONSTANT terminal."""

    name: str

    def to_dict(self):
        """Return the tagged mapping ``{"_type": "Constant", "name": ...}``."""
        return dict(_type="Constant", name=self.name)

    @staticmethod
    def from_dict(d):
        """Rebuild a Constant from the mapping produced by ``to_dict``."""
        const_name = d["name"]
        return Constant(const_name)

    def to_z3(self, env: Z3Env = None):
        """Return the Z3 constant of sort S registered under this constant's name."""
        active_env = env if env else Z3Env()
        return active_env.get_symbol(self.name)

    def to_prover9(self) -> str:
        """Return the constant name unchanged for Prover9 output."""
        return self.name

    def to_tptp(self) -> str:
        """Return the lower-cased name, since TPTP constants must start with a lowercase letter.

        Note: this assumes the name begins with a letter; lowercasing alone
        does not fix names that start with a digit or underscore.
        """
        return self.name.lower()
189
+
190
+
191
@dataclass
class Number(Node):
    """A numeric literal (int or float) produced by the grammar's NUMBER terminal."""

    value: Union[int, float]

    def to_dict(self):
        """Return the tagged mapping ``{"_type": "Number", "value": ...}``."""
        return dict(_type="Number", value=self.value)

    @staticmethod
    def from_dict(d):
        """Rebuild a Number from the mapping produced by ``to_dict``."""
        literal = d["value"]
        return Number(literal)

    def to_z3(self, env: Z3Env = None):
        """Represent the number as an opaque constant of sort S named by its text.

        Numbers carry no arithmetic meaning in the Z3 encoding; ``1`` and
        ``1.0`` therefore become different symbols.
        """
        active_env = env if env else Z3Env()
        symbol_name = str(self.value)
        return active_env.get_symbol(symbol_name)

    def to_prover9(self) -> str:
        """Return the number's plain decimal text."""
        return str(self.value)

    def to_tptp(self) -> str:
        """Return the number's text, which TPTP reads as an integer or real literal."""
        return str(self.value)
217
+
218
+
219
@dataclass
class Function(Node):
    """A function application node, covering both named functions and arithmetic operators.

    Arithmetic operators are stored with their symbol as ``name``
    ("+", "-", "*", "/"). With exactly two arguments they render infix in
    Prover9 and as TPTP dollar-word functions. An application with no
    arguments renders as the bare symbol, because neither Prover9 nor TPTP
    accepts an empty argument list such as ``f()``.
    """

    name: str
    args: List[Node]

    # Operators rendered infix by to_prover9 (when applied to two arguments).
    INFIX_OPS = {"+", "-", "*", "/"}

    # Arithmetic operators mapped to TPTP's prefix dollar-word functions.
    TPTP_ARITH_OPS = {
        "+": "$sum",
        "-": "$difference",
        "*": "$product",
        "/": "$quotient",
    }

    def to_dict(self):
        """Serialise to dict with type tag, function name, and recursively serialised arguments."""
        return {
            "_type": "Function",
            "name": self.name,
            "args": [a.to_dict() for a in self.args]
        }

    @staticmethod
    def from_dict(d):
        """Deserialise a Function from a dict produced by to_dict."""
        return Function(d["name"], [Node.from_dict(a) for a in d["args"]])

    def to_z3(self, env: Z3Env = None):
        """Translate to an uninterpreted Z3 function application in sort S."""
        env = env or Z3Env()
        z3_args = [a.to_z3(env) for a in self.args]
        func = env.get_func(self.name, len(self.args))
        return func(*z3_args)

    def to_prover9(self) -> str:
        """Render in Prover9 syntax, using infix notation for arithmetic operators."""
        if self.name in self.INFIX_OPS and len(self.args) == 2:
            left = self.args[0].to_prover9()
            right = self.args[1].to_prover9()
            return f"({left} {self.name} {right})"

        if not self.args:
            # Prover9 writes nullary symbols without parentheses.
            return self.name

        args_str = ", ".join(a.to_prover9() for a in self.args)
        return f"{self.name}({args_str})"

    def to_tptp(self) -> str:
        """Render function application in TPTP syntax.

        Arithmetic operators (+, -, *, /) are mapped to their TPTP dollar-word
        equivalents ($sum, $difference, $product, $quotient) and emitted in
        prefix notation. All other functions are emitted as lowercase
        identifiers with a parenthesised argument list. A nullary application
        is emitted as the bare symbol, because TPTP argument lists must be
        non-empty.
        """
        tptp_name = self.TPTP_ARITH_OPS.get(self.name, self.name.lower())
        if not self.args:
            return tptp_name
        args_str = ",".join(a.to_tptp() for a in self.args)
        return f"{tptp_name}({args_str})"
276
+
277
+
278
+ # =========================
279
+ # Formula Nodes
280
+ # =========================
281
+
282
@dataclass
class Atom(Node):
    """An atomic formula: either a named predicate application or an infix comparison."""

    predicate: str
    args: List[Node]

    # Comparison predicates rendered infix by to_prover9 (with two arguments).
    INFIX_PREDS_P9 = {
        "=": "=", "<": "<", ">": ">",
        "≤": "<=", "≥": ">=", "≠": "!=",
    }

    def to_dict(self):
        """Serialise to dict with type tag, predicate name, and recursively serialised arguments."""
        return {
            "_type": "Atom",
            "predicate": self.predicate,
            "args": [a.to_dict() for a in self.args]
        }

    @staticmethod
    def from_dict(d):
        """Deserialise an Atom from a dict produced by to_dict."""
        return Atom(d["predicate"], [Node.from_dict(a) for a in d["args"]])

    def to_z3(self, env: Z3Env = None):
        """Translate to a Z3 boolean expression.

        Equality and disequality map to native Z3 operators; all other
        predicates (including <, >, ≤, ≥) become uninterpreted Z3 functions
        returning Bool.
        """
        env = env or Z3Env()
        z3_args = [a.to_z3(env) for a in self.args]

        if self.predicate == "=" and len(self.args) == 2:
            return z3_args[0] == z3_args[1]
        if self.predicate == "≠" and len(self.args) == 2:
            return z3_args[0] != z3_args[1]

        pred = env.get_pred(self.predicate, len(self.args))
        return pred(*z3_args)

    def to_prover9(self) -> str:
        """Render in Prover9 syntax, using infix notation for comparison predicates.

        A zero-arity atom is emitted as the bare predicate symbol, because
        Prover9 does not accept an empty argument list.
        """
        if self.predicate in self.INFIX_PREDS_P9 and len(self.args) == 2:
            left = self.args[0].to_prover9()
            right = self.args[1].to_prover9()
            op = self.INFIX_PREDS_P9[self.predicate]
            return f"({left} {op} {right})"

        if not self.args:
            return self.predicate

        args_str = ", ".join(a.to_prover9() for a in self.args)
        return f"{self.predicate}({args_str})"

    # TPTP symbols for comparison predicates. "=" and "!=" are infix
    # connectives; the $-words are ordinary prefix predicates.
    INFIX_PREDS_TPTP = {
        "=": "=",
        "≠": "!=",
        "<": "$less",
        ">": "$greater",
        "≤": "$lesseq",
        "≥": "$greatereq",
    }

    def to_tptp(self) -> str:
        """Render an atom in TPTP syntax.

        Equality and disequality are emitted infix: ``(a = b)``, ``(a != b)``.
        Arithmetic comparisons use TPTP's dollar-word predicates, which are
        prefix: ``$less(a,b)``. Writing ``(a $less b)`` is a TPTP syntax
        error. All other predicates are emitted as lowercase identifiers with
        a parenthesised argument list, or as the bare symbol when there are no
        arguments.
        """
        if self.predicate in self.INFIX_PREDS_TPTP and len(self.args) == 2:
            left = self.args[0].to_tptp()
            right = self.args[1].to_tptp()
            op = self.INFIX_PREDS_TPTP[self.predicate]
            if op.startswith("$"):
                return f"{op}({left},{right})"
            return f"({left} {op} {right})"

        if not self.args:
            return f"{self.predicate.lower()}"

        args_str = ",".join(a.to_tptp() for a in self.args)
        return f"{self.predicate.lower()}({args_str})"
363
+
364
+
365
@dataclass
class Not(Node):
    """Negation ``¬formula``."""

    formula: Node

    def to_dict(self):
        """Return ``{"_type": "Not", "formula": ...}`` with the operand serialised."""
        inner = self.formula.to_dict()
        return {"_type": "Not", "formula": inner}

    @staticmethod
    def from_dict(d):
        """Rebuild a Not from the mapping produced by ``to_dict``."""
        inner = Node.from_dict(d["formula"])
        return Not(inner)

    def to_z3(self, env: Z3Env = None):
        """Wrap the operand's Z3 translation in ``z3.Not``."""
        inner = self.formula.to_z3(env if env else Z3Env())
        return z3.Not(inner)

    def to_prover9(self) -> str:
        """Return ``-(formula)`` in Prover9 syntax."""
        return "-({})".format(self.formula.to_prover9())

    def to_tptp(self) -> str:
        """Return ``~(formula)`` in TPTP syntax."""
        return "~({})".format(self.formula.to_tptp())
391
+
392
+
393
@dataclass
class And(Node):
    """Binary conjunction ``left ∧ right``."""

    left: Node
    right: Node

    def to_dict(self):
        """Return ``{"_type": "And", "left": ..., "right": ...}`` with both operands serialised."""
        lhs = self.left.to_dict()
        rhs = self.right.to_dict()
        return {"_type": "And", "left": lhs, "right": rhs}

    @staticmethod
    def from_dict(d):
        """Rebuild an And from the mapping produced by ``to_dict``."""
        lhs = Node.from_dict(d["left"])
        rhs = Node.from_dict(d["right"])
        return And(lhs, rhs)

    def to_z3(self, env: Z3Env = None):
        """Build ``z3.And`` over both operands, translated with one shared environment."""
        shared = env or Z3Env()
        lhs = self.left.to_z3(shared)
        rhs = self.right.to_z3(shared)
        return z3.And(lhs, rhs)

    def to_prover9(self) -> str:
        """Return ``(left & right)`` in Prover9 syntax."""
        return "({} & {})".format(self.left.to_prover9(), self.right.to_prover9())

    def to_tptp(self) -> str:
        """Return ``(left & right)`` in TPTP syntax."""
        return "({} & {})".format(self.left.to_tptp(), self.right.to_tptp())
421
+
422
+
423
@dataclass
class Or(Node):
    """Binary disjunction ``left ∨ right``."""

    left: Node
    right: Node

    def to_dict(self):
        """Return ``{"_type": "Or", "left": ..., "right": ...}`` with both operands serialised."""
        lhs = self.left.to_dict()
        rhs = self.right.to_dict()
        return {"_type": "Or", "left": lhs, "right": rhs}

    @staticmethod
    def from_dict(d):
        """Rebuild an Or from the mapping produced by ``to_dict``."""
        lhs = Node.from_dict(d["left"])
        rhs = Node.from_dict(d["right"])
        return Or(lhs, rhs)

    def to_z3(self, env: Z3Env = None):
        """Build ``z3.Or`` over both operands, translated with one shared environment."""
        shared = env or Z3Env()
        lhs = self.left.to_z3(shared)
        rhs = self.right.to_z3(shared)
        return z3.Or(lhs, rhs)

    def to_prover9(self) -> str:
        """Return ``(left | right)`` in Prover9 syntax."""
        return "({} | {})".format(self.left.to_prover9(), self.right.to_prover9())

    def to_tptp(self) -> str:
        """Return ``(left | right)`` in TPTP syntax."""
        return "({} | {})".format(self.left.to_tptp(), self.right.to_tptp())
451
+
452
+
453
@dataclass
class Xor(Node):
    """Exclusive disjunction of two formulas."""

    left: Node
    right: Node

    def to_dict(self):
        """Serialise to dict with type tag and recursively serialised operands."""
        return {"_type": "Xor", "left": self.left.to_dict(), "right": self.right.to_dict()}

    @staticmethod
    def from_dict(d):
        """Deserialise an Xor from a dict produced by to_dict."""
        return Xor(Node.from_dict(d["left"]), Node.from_dict(d["right"]))

    def to_z3(self, env: Z3Env = None):
        """Translate to a Z3 Xor expression."""
        env = env or Z3Env()
        return z3.Xor(self.left.to_z3(env), self.right.to_z3(env))

    def to_prover9(self) -> str:
        """Render exclusive or in Prover9 syntax by expanding to (l | r) & -(l & r)."""
        l = self.left.to_prover9()
        r = self.right.to_prover9()
        return f"(({l} | {r}) & -(({l}) & ({r})))"

    def to_tptp(self) -> str:
        """Render exclusive or in TPTP syntax using the non-equivalence connective ``<~>``.

        ``<~>`` is TPTP's XOR. The similar-looking ``~|`` is NOR (not-or),
        which would silently change the meaning of the formula.
        """
        return f"({self.left.to_tptp()} <~> {self.right.to_tptp()})"
483
+
484
+
485
@dataclass
class Implies(Node):
    """Material implication ``left → right``."""

    left: Node
    right: Node

    def to_dict(self):
        """Return ``{"_type": "Implies", "left": ..., "right": ...}`` with both operands serialised."""
        antecedent = self.left.to_dict()
        consequent = self.right.to_dict()
        return {"_type": "Implies", "left": antecedent, "right": consequent}

    @staticmethod
    def from_dict(d):
        """Rebuild an Implies from the mapping produced by ``to_dict``."""
        antecedent = Node.from_dict(d["left"])
        consequent = Node.from_dict(d["right"])
        return Implies(antecedent, consequent)

    def to_z3(self, env: Z3Env = None):
        """Build ``z3.Implies`` over both operands, translated with one shared environment."""
        shared = env or Z3Env()
        antecedent = self.left.to_z3(shared)
        consequent = self.right.to_z3(shared)
        return z3.Implies(antecedent, consequent)

    def to_prover9(self) -> str:
        """Return ``(left -> right)`` in Prover9 syntax."""
        return "({} -> {})".format(self.left.to_prover9(), self.right.to_prover9())

    def to_tptp(self) -> str:
        """Return ``(left => right)`` in TPTP syntax."""
        return "({} => {})".format(self.left.to_tptp(), self.right.to_tptp())
513
+
514
+
515
@dataclass
class Iff(Node):
    """Biconditional ``left ↔ right``."""

    left: Node
    right: Node

    def to_dict(self):
        """Return ``{"_type": "Iff", "left": ..., "right": ...}`` with both operands serialised."""
        lhs = self.left.to_dict()
        rhs = self.right.to_dict()
        return {"_type": "Iff", "left": lhs, "right": rhs}

    @staticmethod
    def from_dict(d):
        """Rebuild an Iff from the mapping produced by ``to_dict``."""
        lhs = Node.from_dict(d["left"])
        rhs = Node.from_dict(d["right"])
        return Iff(lhs, rhs)

    def to_z3(self, env: Z3Env = None):
        """Express the biconditional as Z3 equality between the two Bool operands."""
        shared = env or Z3Env()
        lhs = self.left.to_z3(shared)
        rhs = self.right.to_z3(shared)
        return lhs == rhs

    def to_prover9(self) -> str:
        """Return ``(left <-> right)`` in Prover9 syntax."""
        return "({} <-> {})".format(self.left.to_prover9(), self.right.to_prover9())

    def to_tptp(self) -> str:
        """Return ``(left <=> right)`` in TPTP syntax."""
        return "({} <=> {})".format(self.left.to_tptp(), self.right.to_tptp())
543
+
544
+
545
@dataclass
class Quantifier(Node):
    """A quantified formula binding exactly one variable.

    ``type`` accepts either spelling of each quantifier: "forall"/"∀" or
    "exists"/"∃". Any other value makes the serialisers raise ValueError.
    """

    type: str
    variable: Variable
    formula: Node

    # Accepted spellings, checked by tuple membership in every serialiser.
    _UNIVERSAL = ("forall", "∀")
    _EXISTENTIAL = ("exists", "∃")

    def to_dict(self):
        """Return the tagged mapping with quantifier type, bound variable and serialised body."""
        bound = self.variable.to_dict()
        body = self.formula.to_dict()
        return {"_type": "Quantifier", "type": self.type, "variable": bound, "formula": body}

    @staticmethod
    def from_dict(d):
        """Rebuild a Quantifier from the mapping produced by ``to_dict``."""
        kind = d["type"]
        bound = Node.from_dict(d["variable"])
        body = Node.from_dict(d["formula"])
        return Quantifier(kind, bound, body)

    def to_z3(self, env: Z3Env = None):
        """Wrap the translated body in ``z3.ForAll`` or ``z3.Exists`` over the bound variable."""
        shared = env or Z3Env()
        bound = self.variable.to_z3(shared)
        body = self.formula.to_z3(shared)
        if self.type in self._UNIVERSAL:
            return z3.ForAll([bound], body)
        if self.type in self._EXISTENTIAL:
            return z3.Exists([bound], body)
        raise ValueError(f"Unknown quantifier: {self.type}")

    def to_prover9(self) -> str:
        """Return ``(all v body)`` or ``(exists v body)`` in Prover9 syntax."""
        bound = self.variable.to_prover9()
        body = self.formula.to_prover9()
        if self.type in self._UNIVERSAL:
            keyword = "all"
        elif self.type in self._EXISTENTIAL:
            keyword = "exists"
        else:
            raise ValueError(f"Unknown quantifier: {self.type}")
        return f"({keyword} {bound} {body})"

    def to_tptp(self) -> str:
        """Return ``(![X]: body)`` or ``(?[X]: body)`` in TPTP syntax."""
        bound = self.variable.to_tptp()
        body = self.formula.to_tptp()
        if self.type in self._UNIVERSAL:
            symbol = "!"
        elif self.type in self._EXISTENTIAL:
            symbol = "?"
        else:
            raise ValueError(f"Unknown quantifier: {self.type}")
        return f"({symbol}[{bound}]: {body})"
604
+
605
+
606
+ # =========================
607
+ # Registry
608
+ # =========================
609
+
610
# Maps each "_type" tag written by a node's to_dict() to the class that
# rebuilds it; Node.from_dict looks classes up here. Every tag equals the
# class's own __name__, so the table is derived from the class objects.
NODE_CLASSES = {
    node_cls.__name__: node_cls
    for node_cls in (
        Variable, Constant, Number,
        Function, Atom, Not, And,
        Or, Xor, Implies, Iff,
        Quantifier,
    )
}
616
+
617
+ # =========================
618
+ # Transformer
619
+ # =========================
620
class FOLTransformer(Transformer):
    """Transforms parsed tokens from Lark parser into AST nodes.

    Each public method is a Lark callback named after a grammar rule or
    terminal. Rule callbacks receive the already-transformed children in
    ``items``; terminal callbacks (VARIABLE, NAME) receive the token itself.
    """

    @staticmethod
    def _fold_binary(items, node_cls):
        """Left-fold a variable-length item list into nested binary nodes."""
        node = items[0]
        for item in items[1:]:
            node = node_cls(node, item)
        return node

    @staticmethod
    def _fold_right(items, node_cls):
        """Right-fold an item list into nested binary nodes: [a, b, c] -> cls(a, cls(b, c))."""
        node = items[-1]
        for item in reversed(items[:-1]):
            node = node_cls(item, node)
        return node

    @staticmethod
    def _symbol_name(item):
        """Return the bare symbol text for a function or predicate head.

        The head may be a raw token, or a Constant/Variable node if the
        NAME/VARIABLE terminal callbacks already converted it. str() of such
        a node yields its dataclass repr (e.g. "Variable(name='f')"), not the
        symbol, so the name is read from the node instead.
        """
        if isinstance(item, (Constant, Variable)):
            return item.name
        return str(item)

    def atom0_(self, items):
        """Transform bare predicate symbol into a zero-arity Atom node."""
        pred = self._symbol_name(items[0])
        return Atom(pred, [])

    def VARIABLE(self, items):
        """Transform variable token into Variable node."""
        return Variable(str(items))

    def NAME(self, items):
        """Transform name token into Constant node."""
        return Constant(str(items))

    def const_(self, items):
        """Transform c_-prefixed constant token into Constant node."""
        return Constant(str(items[0]))

    def number_(self, items):
        """Transform numeric literal token into Number node.

        Text that parses as an int becomes an int. Anything else (decimals
        such as "2.5" or exponent forms such as "1e5") becomes a float.
        """
        text = str(items[0])
        try:
            value = int(text)
        except ValueError:
            value = float(text)
        return Number(value)

    def function_(self, items):
        """Transform function application into Function node."""
        name = self._symbol_name(items[0])
        args = items[1:]
        # A termlist child arrives as a single list; unwrap it.
        if args and isinstance(args[0], list):
            args = args[0]
        return Function(name, args)

    def add_(self, items):
        """Transform addition into Function node."""
        left, right = items
        return Function("+", [left, right])

    def sub_(self, items):
        """Transform subtraction into Function node."""
        left, right = items
        return Function("-", [left, right])

    def mul_(self, items):
        """Transform multiplication into Function node."""
        left, right = items
        return Function("*", [left, right])

    def div_(self, items):
        """Transform division into Function node."""
        left, right = items
        return Function("/", [left, right])

    def atom_term(self, items):
        """Pass through atom term."""
        return items[0]

    def term(self, items):
        """Pass through term."""
        return items[0]

    def sum(self, items):
        """Pass through sum expression."""
        return items[0]

    def product(self, items):
        """Pass through product expression."""
        return items[0]

    def termlist(self, items):
        """Transform term list."""
        return items

    def infix_predicate(self, items):
        """Pass through infix predicate."""
        return items[0]

    def atom(self, items):
        """Pass through atom."""
        return items[0]

    def atom_(self, items):
        """Transform predicate application into Atom node."""
        pred = self._symbol_name(items[0])
        # A single argument arrives bare; several arrive as a termlist list.
        if not isinstance(items[1], list):
            args = [items[1]]
        else:
            args = items[1]
        return Atom(pred, args)

    def lt_(self, items):
        """Transform less-than comparison into Atom node."""
        left, right = items
        return Atom("<", [left, right])

    def gt_(self, items):
        """Transform greater-than comparison into Atom node."""
        left, right = items
        return Atom(">", [left, right])

    def eq_(self, items):
        """Transform equality comparison into Atom node."""
        left, right = items
        return Atom("=", [left, right])

    def le_(self, items):
        """Transform less-than-or-equal comparison into Atom node."""
        left, right = items
        return Atom("≤", [left, right])

    def ge_(self, items):
        """Transform greater-than-or-equal comparison into Atom node."""
        left, right = items
        return Atom("≥", [left, right])

    def ne_(self, items):
        """Transform not-equal comparison into Atom node."""
        left, right = items
        return Atom("≠", [left, right])

    def not_(self, items):
        """Transform negation into Not node."""
        return Not(items[0])

    def and_(self, items):
        """Transform conjunction into And node."""
        return self._fold_binary(items, And)

    def or_(self, items):
        """Transform disjunction into Or node."""
        return self._fold_binary(items, Or)

    def xor_(self, items):
        """Transform exclusive or into Xor node."""
        return self._fold_binary(items, Xor)

    def implies_(self, items):
        """Transform implication into Implies node.

        Two operands give Implies(a, b). If the grammar delivers a longer
        chain, it is grouped to the right (the usual convention for
        implication), a → b → c becoming a → (b → c), instead of silently
        dropping the extra operands.
        """
        return self._fold_right(items, Implies)

    def iff_(self, items):
        """Transform biconditional into Iff node.

        Two operands give Iff(a, b). Longer chains are left-folded like the
        other associative connectives instead of being truncated.
        """
        return self._fold_binary(items, Iff)

    def quantifier_(self, items):
        """Transform quantifier expression into Quantifier node."""
        quant = items[0]
        var = items[1]
        formula = items[2]
        return Quantifier(str(quant), var, formula)