syncraft-0.2.5-py3-none-any.whl → syncraft-0.2.7-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of syncraft might be problematic.
- syncraft/__init__.py +30 -9
- syncraft/algebra.py +143 -214
- syncraft/ast.py +62 -7
- syncraft/cache.py +113 -0
- syncraft/constraint.py +184 -134
- syncraft/dev.py +9 -0
- syncraft/finder.py +17 -12
- syncraft/generator.py +80 -78
- syncraft/lexer.py +131 -0
- syncraft/parser.py +75 -224
- syncraft/syntax.py +187 -100
- syncraft/utils.py +214 -0
- syncraft/walker.py +147 -0
- syncraft-0.2.7.dist-info/METADATA +56 -0
- syncraft-0.2.7.dist-info/RECORD +20 -0
- syncraft/diagnostic.py +0 -70
- syncraft-0.2.5.dist-info/METADATA +0 -113
- syncraft-0.2.5.dist-info/RECORD +0 -16
- {syncraft-0.2.5.dist-info → syncraft-0.2.7.dist-info}/WHEEL +0 -0
- {syncraft-0.2.5.dist-info → syncraft-0.2.7.dist-info}/licenses/LICENSE +0 -0
- {syncraft-0.2.5.dist-info → syncraft-0.2.7.dist-info}/top_level.txt +0 -0
syncraft/ast.py
CHANGED
@@ -13,9 +13,16 @@ from dataclasses import dataclass, replace, is_dataclass, fields
 from enum import Enum
 
 
+class SyncraftError(Exception):
+    def __init__(self, message: str, offending: Any, expect: Any = None, **kwargs: Any) -> None:
+        super().__init__(message)
+        self.offending = offending
+        self.expect = expect
+        self.data = kwargs
+
 def shallow_dict(a: Any)->Dict[str, Any]:
     if not is_dataclass(a):
-        raise
+        raise SyncraftError("Expected dataclass instance for collector inverse", offending=a, expect="dataclass")
     return {f.name: getattr(a, f.name) for f in fields(a)}
 
 
@@ -130,7 +137,7 @@ class Bimap(Generic[A, B]):
                return c, inv
            return Bimap(bimap_then_run)
        else:
-           raise
+           raise SyncraftError("Unsupported type for Bimap >>", offending=other, expect=(Bimap, Biarrow))
    def __rrshift__(self, other: Bimap[C, A] | Biarrow[C, A]) -> Bimap[C, B]:
        """Right-composition so arrows or bimaps can be on the left of ``>>``."""
        if isinstance(other, Biarrow):
@@ -152,7 +159,7 @@ class Bimap(Generic[A, B]):
                return b2, inv
            return Bimap(bimap_then_run)
        else:
-           raise
+           raise SyncraftError("Unsupported type for Bimap <<", offending=other, expect=(Bimap, Biarrow))
 
 
    @staticmethod
@@ -397,7 +404,7 @@ class Collect(Generic[A, E], AST):
 
        def inv_one_positional(e: E) -> B:
            if not is_dataclass(e):
-               raise
+               raise SyncraftError("Expected dataclass instance for collector inverse", offending=e, expect="dataclass")
            named_dict = shallow_dict(e)
            return named_dict[fields(e)[0].name]
 
@@ -417,7 +424,7 @@ class Collect(Generic[A, E], AST):
            ret: E = self.collector(*unnamed, **named)
            def invf(e: E) -> Tuple[Any, ...]:
                if not is_dataclass(e):
-                   raise
+                   raise SyncraftError("Expected dataclass instance for collector inverse", offending=e, expect="dataclass")
                named_dict = shallow_dict(e)
                unnamed = []
                for f in fields(e):
@@ -432,6 +439,21 @@ class Collect(Generic[A, E], AST):
                return tuple(tmp)
            return ret, lambda e: replace(self, value=inner_f(invf(e))) # type: ignore
        return self.collector(b), lambda e: replace(self, value=inner_f(inv_one_positional(e))) # type: ignore
+
+
+
+@dataclass(frozen=True)
+class Custom(Generic[A, B], AST):
+    """A custom AST node wrapping an arbitrary value.
+
+    Used when the parse result does not fit into other AST node types.
+    """
+    meta: B
+    value: A
+    def bimap(self, r: Bimap[A, C]=Bimap.identity()) -> Tuple[C, Callable[[C], Custom[A, B]]]:
+        """Defer to the provided mapping ``r``."""
+        v, inv = r(self.value)
+        return v, lambda c: replace(self, value=inv(c))
 
 #########################################################################################################################
 @dataclass(frozen=True)
@@ -455,13 +477,46 @@ class TokenProtocol(Protocol):
 T = TypeVar('T', bound=TokenProtocol)
 
 
+
+@dataclass(frozen=True)
+class SyntaxSpec:
+    pass
+@dataclass(frozen=True)
+class ChoiceSpec(SyntaxSpec, Generic[A, B]):
+    left: A
+    right: B
+
+@dataclass(frozen=True)
+class LazySpec(SyntaxSpec, Generic[A]):
+    value: A
+@dataclass(frozen=True)
+class ThenSpec(SyntaxSpec, Generic[A, B]):
+    left: A
+    right: B
+
 @dataclass(frozen=True)
-class
+class ManySpec(SyntaxSpec, Generic[A]):
+    value: A
+    at_least: int
+    at_most: Optional[int]
+
+
+@dataclass(frozen=True)
+class TokenSpec(SyntaxSpec):
     token_type: Optional[Enum] = None
     text: Optional[str] = None
     case_sensitive: bool = False
     regex: Optional[re.Pattern[str]] = None
-
+
+    @classmethod
+    def create(cls,
+               *,
+               token_type: Optional[Enum] = None,
+               text: Optional[str] = None,
+               case_sensitive: bool = False,
+               regex: Optional[re.Pattern[str]] = None) -> TokenSpec:
+        return cls(token_type=token_type, text=text, case_sensitive=case_sensitive, regex=regex)
+
    def is_valid(self, token: TokenProtocol) -> bool:
        type_match = self.token_type is None or token.token_type == self.token_type
        value_match = self.text is None or (token.text.strip() == self.text.strip() if self.case_sensitive else
syncraft/cache.py
ADDED
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Dict, TypeVar, Hashable, Generic, Callable, Any, Generator, overload, Literal
+from weakref import WeakKeyDictionary
+from syncraft.ast import SyncraftError
+
+
+class RecursionError(SyncraftError):
+    def __init__(self, message: str, offending: Any, expect: Any = None, **kwargs: Any) -> None:
+        super().__init__(message, offending, expect, **kwargs)
+
+
+@dataclass(frozen=True)
+class InProgress:
+    _instance = None
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super(InProgress, cls).__new__(cls)
+        return cls._instance
+    def __str__(self)->str:
+        return self.__class__.__name__
+    def __repr__(self)->str:
+        return self.__str__()
+
+
+
+
+Args = TypeVar('Args', bound=Hashable)
+Ret = TypeVar('Ret')
+
+@dataclass
+class Cache(Generic[Args, Ret]):
+    cache: WeakKeyDictionary[Callable[..., Any], Dict[Args, Ret | InProgress]] = field(default_factory=WeakKeyDictionary)
+
+    def __contains__(self, f: Callable[..., Any]) -> bool:
+        return f in self.cache
+
+    def __repr__(self) -> str:
+        return f"Cache({({f.__name__: list(c.keys()) for f, c in self.cache.items()})})"
+
+
+    def __or__(self, other: Cache[Args, Any]) -> Cache[Args, Any]:
+        assert self.cache is other.cache, "There should be only one global cache"
+        if self.cache is other.cache:
+            return self
+        elif len(self.cache) == 0:
+            return other
+        elif len(other.cache) == 0:
+            return self
+        merged = Cache[Args, Ret]()
+        for f, c in self.cache.items():
+            merged.cache[f] = c.copy()
+        for f, c in other.cache.items():
+            merged.cache.setdefault(f, {}).update(c)
+        return merged
+
+    @overload
+    def _execute(self,
+                 f: Callable[[Args, bool], Ret],
+                 args: Args,
+                 use_cache: bool,
+                 is_gen: Literal[False]) -> Ret: ...
+    @overload
+    def _execute(self,
+                 f: Callable[[Args, bool], Generator[Any, Any, Ret]],
+                 args: Args,
+                 use_cache: bool,
+                 is_gen: Literal[True]) -> Generator[Any, Any, Ret]: ...
+
+
+    def _execute(self,
+                 f: Callable[[Args, bool], Any],
+                 args: Args,
+                 use_cache:bool,
+                 is_gen: bool
+                 ) -> Ret | Generator[Any, Any, Ret]:
+        if f not in self.cache:
+            self.cache.setdefault(f, dict())
+        c: Dict[Args, Ret | InProgress] = self.cache[f]
+        if args in c:
+            v = c[args]
+            if isinstance(v, InProgress):
+                raise RecursionError("Left-recursion detected in parser", offending=f, state=args)
+            else:
+                return v
+        try:
+            c[args] = InProgress()
+            if is_gen:
+                result = yield from f(args, use_cache)
+            else:
+                result = f(args, use_cache)
+            c[args] = result
+            if not use_cache:
+                c.pop(args, None)
+            return result
+        except Exception as e:
+            c.pop(args, None)
+            raise e
+
+    def gen(self,
+            f: Callable[[Args, bool], Generator[Any, Any, Ret]],
+            args: Args,
+            use_cache:bool) -> Generator[Any, Any, Ret]:
+        return (yield from self._execute(f, args, use_cache, is_gen=True))
+
+    def call(self,
+             f: Callable[[Args, bool], Ret],
+             args: Args,
+             use_cache:bool) -> Ret:
+        return self._execute(f, args, use_cache, is_gen=False)
+
+
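
The new cache.py keeps a per-function memo table keyed by hashable arguments and marks in-flight entries with an `InProgress` sentinel, so re-entering the same `(function, args)` pair before the first call finishes raises `RecursionError` (used to flag left recursion in the parser). A rough usage sketch, assuming results are produced by generator-based functions driven through `Cache.gen`; the `parse_expr` function and the `drive` loop below are illustrative, not part of the package:

    from syncraft.cache import Cache

    cache: Cache[int, int] = Cache()

    def parse_expr(pos: int, use_cache: bool):
        # Stand-in for a grammar rule: yields nothing and returns a value.
        # A left-recursive rule would re-enter cache.gen(parse_expr, pos, ...)
        # here and hit the InProgress sentinel, raising RecursionError.
        return pos * 2
        yield  # unreachable; makes this a generator function, as Cache.gen expects

    def drive(gen):
        # Run a generator to completion and return its final (StopIteration) value.
        try:
            while True:
                next(gen)
        except StopIteration as stop:
            return stop.value

    print(drive(cache.gen(parse_expr, 21, True)))  # 42, now memoized for pos=21
    print(parse_expr in cache)                     # True
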
syncraft/constraint.py
CHANGED
@@ -1,13 +1,16 @@
 from __future__ import annotations
-from typing import
+from typing import (
+    Callable, Generic, Tuple, TypeVar, Optional, Any, Self,
+    Generator, List, Set, Union, Dict, Iterable,
+)
 from enum import Enum
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field, replace, is_dataclass, fields
 import collections.abc
 from collections import defaultdict
 from itertools import product
 from inspect import Signature
 import inspect
-
+from syncraft.ast import SyncraftError
 K = TypeVar('K')
 V = TypeVar('V')
 class FrozenDict(collections.abc.Mapping, Generic[K, V]):
@@ -171,7 +174,9 @@ class Constraint:
            elif param.kind == inspect.Parameter.KEYWORD_ONLY:
                kw_params.append(pname)
            else:
-               raise
+               raise SyncraftError(f"Unsupported parameter kind: {param.kind}",
+                                   offending=param.kind,
+                                   expect=(inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY))
        def run_f(bound: FrozenDict[str, Tuple[Any, ...]]) -> ConstraintResult:
            # positional argument values
            pos_values = [bound.get(pname, ()) for pname in pos_params]
@@ -254,134 +259,179 @@ def all_binding(a: FrozenDict[str, Tuple[Any, ...]], *names: str) -> Generator[F
 
 
 
-
-
-
-
+####################################################################################################################################
 @dataclass(frozen=True)
-class
[... roughly 125 further removed lines from the previous constraint implementation; their text is not legible in this diff view ...]
+class Var:
+    name: str
+
+
+Subst = Dict[str, Any]
+Fact = Tuple[str, Tuple[Any, ...]]
+Rule = Tuple[str, Tuple[Any, ...], List[Fact]]
+
+def is_var(x): return isinstance(x, Var)
+
+# ---------- Unification ----------
+def unify(x, y, subst: Subst) -> Subst | None:
+    if x == y:
+        return subst
+    if is_var(x):
+        return unify_var(x, y, subst)
+    if is_var(y):
+        return unify_var(y, x, subst)
+    if isinstance(x, tuple) and isinstance(y, tuple) and len(x) == len(y):
+        for a, b in zip(x, y):
+            tmp = unify(a, b, subst)
+            if tmp is None:
+                return None
+            else:
+                subst = tmp
+        return subst
+    return None
+
+def unify_var(var: Var, val: Any, subst: Subst) -> Subst | None:
+    if var.name in subst:
+        return unify(subst[var.name], val, subst)
+    if occurs_check(var, val, subst):
+        return None
+    subst = subst.copy()
+    subst[var.name] = val
+    return subst
+
+def occurs_check(var: Var, val: Any, subst: Subst) -> bool:
+    if var == val:
+        return True
+    if is_var(val) and val.name in subst:
+        return occurs_check(var, subst[val.name], subst)
+    if isinstance(val, tuple):
+        return any(occurs_check(var, v, subst) for v in val)
+    return False
+
+# ---------- Substitution ----------
+def apply_subst_fact(fact: Fact, subst: Subst) -> Fact:
+    pred, args = fact
+    return (pred, tuple(apply_subst_term(a, subst) for a in args))
+
+def apply_subst_term(term, subst: Subst):
+    if is_var(term) and term.name in subst:
+        return apply_subst_term(subst[term.name], subst)
+    return term
+
+# ---------- Engine ----------
+class DatalogEngine:
+    def __init__(self):
+        self.facts: List[Fact] = []
+        self.rules: List[Rule] = []
+
+    def add_fact(self, fact: Fact):
+        self.facts.append(fact)
+
+    def add_rule(self, head: Fact, body: List[Fact]):
+        self.rules.append((head[0], head[1], body))
+
+    # ----- Forward chaining -----
+    def infer(self) -> List[Fact]:
+        changed = True
+        inferred = set(self.facts)
+        while changed:
+            changed = False
+            for (hpred, hargs, body) in self.rules:
+                for subst in self._prove_body(body, {}):
+                    head = apply_subst_fact((hpred, hargs), subst)
+                    if head not in inferred:
+                        inferred.add(head)
+                        changed = True
+        return list(inferred)
+
+    # ----- Backward chaining -----
+    def query(self, goal: Fact, subst: Subst | None = None) -> Generator[Subst, None, None]:
+        if subst is None:
+            subst = {}
+        pred, args = goal
+
+        # Match against facts
+        for (fpred, fargs) in self.facts:
+            if fpred != pred:
+                continue
+            s = unify(args, fargs, subst)
+            if s is not None:
+                yield s
+
+        # Match against rules
+        for (hpred, hargs, body) in self.rules:
+            if hpred != pred:
+                continue
+            s = unify(args, hargs, subst)
+            if s is None:
+                continue
+            yield from self._prove_body(body, s)
+
+    def _prove_body(self, goals: List[Fact], subst: Subst) -> Generator[Subst, None, None]:
+        if not goals:
+            yield subst
+            return
+        first, *rest = goals
+        for s in self.query(apply_subst_fact(first, subst), subst):
+            yield from self._prove_body(rest, s)
+
+#####################################################################################################################################
+def dataclass_to_facts(obj: Any, *, extended: bool = False, parent: Any = None) -> List[Fact]:
+    facts: List[Fact] = []
+
+    if not is_dataclass(obj):
+        raise TypeError(f"Expected dataclass instance, got {type(obj)}")
+
+    cls = type(obj)
+    pred = cls.__name__  # use class name as predicate
+    args = tuple(getattr(obj, f.name) for f in fields(obj))
+    facts.append((pred, args))
+
+    for f in fields(obj):
+        val = getattr(obj, f.name)
+
+        if is_dataclass(val):
+            # recurse into child dataclass
+            facts.extend(dataclass_to_facts(val, extended=extended, parent=obj))
+
+            if extended:
+                facts.append(("Contains", (obj, val)))
+                facts.append(("Field", (obj, f.name, val)))
+
+        elif isinstance(val, list):
+            for item in val:
+                if is_dataclass(item):
+                    facts.extend(dataclass_to_facts(item, extended=extended, parent=obj))
+                    if extended:
+                        facts.append(("Contains", (obj, item)))
+                        facts.append(("Field", (obj, f.name, item)))
+                else:
+                    if extended:
+                        facts.append(("Field", (obj, f.name, item)))
+        else:
+            if extended:
+                facts.append(("Field", (obj, f.name, val)))
+
+    return facts
+
+
+
+def test()->None:
+    X, Y, Z = Var("X"), Var("Y"), Var("Z")
+
+    db = DatalogEngine()
+    db.add_fact(("parent", ("alice", "bob")))
+    db.add_fact(("parent", ("bob", "carol")))
+
+    # Rules
+    db.add_rule(("ancestor", (X, Y)), [("parent", (X, Y))])
+    db.add_rule(("ancestor", (X, Y)), [("parent", (X, Z)), ("ancestor", (Z, Y))])
+
+    print("Forward infer:")
+    print(db.infer())
+    # [('parent', ('alice', 'bob')), ('parent', ('bob', 'carol')),
+    #  ('ancestor', ('alice', 'bob')), ('ancestor', ('bob', 'carol')),
+    #  ('ancestor', ('alice', 'carol'))]
+
+    print("Backward query:")
+    print(list(db.query(("ancestor", (X, "carol")))))
+    # [{'X': 'bob'}, {'X': 'alice'}]