kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/ir/attrs/types.py
ADDED
@@ -0,0 +1,522 @@
|
|
1
|
+
import typing
|
2
|
+
from abc import abstractmethod
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from collections.abc import Hashable
|
5
|
+
|
6
|
+
from beartype.door import TupleVariableTypeHint # type: ignore
|
7
|
+
from beartype.door import TypeHint, ClassTypeHint, LiteralTypeHint, TypeVarTypeHint
|
8
|
+
from typing_extensions import Never
|
9
|
+
|
10
|
+
from kirin.print import Printer
|
11
|
+
from kirin.lattice import (
|
12
|
+
UnionMeta,
|
13
|
+
SingletonMeta,
|
14
|
+
BoundedLattice,
|
15
|
+
IsSubsetEqMixin,
|
16
|
+
SimpleMeetMixin,
|
17
|
+
)
|
18
|
+
|
19
|
+
from .abc import Attribute, LatticeAttributeMeta
|
20
|
+
from ._types import _TypeAttribute
|
21
|
+
|
22
|
+
|
23
|
+
class TypeAttributeMeta(LatticeAttributeMeta):
    """Metaclass shared by every type attribute.

    It adds nothing on top of `LatticeAttributeMeta`; it exists so that the
    more specific type-attribute metaclasses in this module (singleton, union
    and cached variants) have a common base.
    """
|
27
|
+
|
28
|
+
|
29
|
+
class SingletonTypeMeta(TypeAttributeMeta, SingletonMeta):
    """Metaclass for type attributes that only ever have one instance.

    It combines the type-attribute metaclass with `SingletonMeta`.

    Examples:
        - `AnyType`
        - `BottomType`
    """
|
40
|
+
|
41
|
+
|
42
|
+
class UnionTypeMeta(TypeAttributeMeta, UnionMeta):
    """Metaclass for the `Union` type attribute.

    Combines the type-attribute metaclass with the lattice `UnionMeta`.
    """
|
44
|
+
|
45
|
+
|
46
|
+
@dataclass
class TypeAttribute(
    _TypeAttribute,
    SimpleMeetMixin["TypeAttribute"],
    IsSubsetEqMixin["TypeAttribute"],
    BoundedLattice["TypeAttribute"],
    metaclass=TypeAttributeMeta,
):
    """Base class of all type attributes, ordered as a bounded lattice.

    The top element is `AnyType` and the bottom element is `BottomType`.
    Concrete subclasses must implement `__hash__`.
    """

    @classmethod
    def top(cls) -> "TypeAttribute":
        """Return the top element of the type lattice (`AnyType`)."""
        return AnyType()

    @classmethod
    def bottom(cls) -> "TypeAttribute":
        """Return the bottom element of the type lattice (`BottomType`)."""
        return BottomType()

    def join(self, other: "TypeAttribute") -> "TypeAttribute":
        """Return an upper bound of `self` and `other`.

        When one side contains the other, the larger side is returned.
        Otherwise two type attributes are combined into a `Union`. Anything
        that is not a type attribute joins to `AnyType`.
        """
        if self.is_subseteq(other):
            return other
        if other.is_subseteq(self):
            return self
        if not isinstance(other, TypeAttribute):
            # don't know how to join
            return AnyType()
        return Union(self, other)

    def print_impl(self, printer: Printer) -> None:
        printer.print_name(self, prefix="!")

    def __or__(self, other: "TypeAttribute"):
        """`a | b` is shorthand for `a.join(b)`."""
        return self.join(other)

    def __eq__(self, value: object) -> bool:
        # equality is delegated to the lattice's `is_equal`
        if not isinstance(value, TypeAttribute):
            return False
        return self.is_equal(value)

    @abstractmethod
    def __hash__(self) -> int: ...
|
83
|
+
|
84
|
+
|
85
|
+
@typing.final
@dataclass(eq=False)
class AnyType(TypeAttribute, metaclass=SingletonTypeMeta):
    """Singleton top element of the type lattice (see `TypeAttribute.top`)."""

    name = "Any"

    def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
        # a type variable is compared through its upper bound
        upper = other.bound
        return self.is_subseteq(upper)

    def __hash__(self) -> int:
        # singleton, so an identity hash is sufficient
        return id(self)
|
95
|
+
|
96
|
+
|
97
|
+
@typing.final
@dataclass(eq=False)
class BottomType(TypeAttribute, metaclass=SingletonTypeMeta):
    """Singleton bottom element of the type lattice (see `TypeAttribute.bottom`)."""

    name = "Bottom"

    def is_subseteq(self, other: TypeAttribute) -> bool:
        # Bottom is contained in everything; type variables are first
        # resolved through their bound.
        if not isinstance(other, TypeVar):
            return True
        return self.is_subseteq(other.bound)

    def __hash__(self) -> int:
        # singleton, so an identity hash is sufficient
        return id(self)
|
109
|
+
|
110
|
+
|
111
|
+
class PyClassMeta(TypeAttributeMeta):
    """Metaclass of `PyClass` that normalizes arguments and caches instances.

    `PyClass(typ)` is routed through `__call__` below. It maps `typing.Any`,
    `typing.NoReturn`/`Never` and type variables to their kirin equivalents.
    It normalizes `typing.Tuple`/`typing.List` to `tuple`/`list`. It returns
    one cached `PyClass` instance per Python class.
    """

    def __init__(self, *args, **kwargs):
        super(PyClassMeta, self).__init__(*args, **kwargs)
        # maps the wrapped Python object to its PyClass instance
        self._cache = {}

    def __call__(self, typ):
        if typ is typing.Any:
            return AnyType()
        if typ is typing.NoReturn or typ is Never:
            return BottomType()
        if isinstance(typ, (typing.TypeVar, TypeVar)):
            # Python type variables are converted to kirin `TypeVar`s by
            # `hint2type`; kirin `TypeVar`s are returned unchanged by it.
            return hint2type(typ)

        # Normalize the deprecated typing aliases *before* the cache lookup so
        # that `PyClass(typing.Tuple)` and `PyClass(tuple)` share one instance
        # instead of the alias overwriting the cached entry.
        if typ is typing.Tuple:
            typ = tuple
        elif typ is typing.List:
            typ = list

        if isinstance(typ, type) and typ in self._cache:
            return self._cache[typ]

        instance = super(PyClassMeta, self).__call__(typ)
        self._cache[typ] = instance
        return instance
|
134
|
+
|
135
|
+
|
136
|
+
# Type parameter of `PyClass` and `Generic`: the wrapped Python class.
PyClassType = typing.TypeVar("PyClassType")
|
137
|
+
|
138
|
+
|
139
|
+
@typing.final
@dataclass(eq=False)
class PyClass(TypeAttribute, typing.Generic[PyClassType], metaclass=PyClassMeta):
    """Type attribute wrapping a plain Python class, printed as `!py.<name>`.

    Construction goes through `PyClassMeta`, which caches instances per class.
    """

    name = "PyClass"
    typ: type[PyClassType]

    def __init__(self, typ: type[PyClassType]) -> None:
        self.typ = typ

    def is_subseteq_PyClass(self, other: "PyClass") -> bool:
        # ordering follows Python's class hierarchy
        return issubclass(self.typ, other.typ)

    def is_subseteq_Union(self, other: "Union") -> bool:
        for member in other.types:
            if self.is_subseteq(member):
                return True
        return False

    def is_subseteq_Generic(self, other: "Generic") -> bool:
        # NOTE: an unparameterized class is treated like the generic with
        # `Any` substituted for every parameter.
        top = AnyType()
        if not self.is_subseteq(other.body):
            return False
        if not all(top.is_subseteq(bound) for bound in other.vars):
            return False
        return other.vararg is None or top.is_subseteq(other.vararg.typ)

    def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
        return self.is_subseteq(other.bound)

    def __hash__(self) -> int:
        return hash((PyClass, self.typ))

    def __repr__(self) -> str:
        return self.typ.__name__

    def print_impl(self, printer: Printer) -> None:
        printer.plain_print("!py.", self.typ.__name__)
|
174
|
+
|
175
|
+
|
176
|
+
class LiteralMeta(TypeAttributeMeta):
    """Metaclass of `Literal` that normalizes arguments and caches instances.

    `Literal(data)` behaves as follows:
    - If `data` is already an attribute, it is returned unchanged.
    - If `data` is unhashable, it falls back to `PyClass(type(data))`.
    - Otherwise it returns one cached `Literal` per (type, value) pair.
    """

    def __init__(self, *args, **kwargs):
        super(LiteralMeta, self).__init__(*args, **kwargs)
        self._cache = {}

    def __call__(self, data):
        if isinstance(data, Attribute):
            return data
        if not isinstance(data, Hashable):
            return PyClass(type(data))

        # Key on the concrete type as well as the value: `1`, `1.0` and `True`
        # compare and hash equal, but they are different literal types.
        # NOTE: values nested inside containers (e.g. `(1,)` vs `(True,)`) can
        # still share a key.
        key = (type(data), data)
        try:
            cached = self._cache.get(key)
        except TypeError:
            # e.g. a tuple containing a list passes the `Hashable` check but
            # cannot actually be hashed
            return PyClass(type(data))
        if cached is not None:
            return cached

        instance = super(LiteralMeta, self).__call__(data)
        self._cache[key] = instance
        return instance
|
193
|
+
|
194
|
+
|
195
|
+
# Type parameter of `Literal`: the Python type of the wrapped value (`data`).
LiteralType = typing.TypeVar("LiteralType")
|
196
|
+
|
197
|
+
|
198
|
+
@typing.final
@dataclass(eq=False)
class Literal(TypeAttribute, typing.Generic[LiteralType], metaclass=LiteralMeta):
    """Type attribute for a single Python value, printed as `repr(data)`.

    Instances are created via `LiteralMeta`, which caches them. For that
    reason `is_equal` compares by identity.
    """

    name = "Literal"
    data: LiteralType

    def is_equal(self, other: TypeAttribute) -> bool:
        # instances are cached by the metaclass, so identity is equality
        return self is other

    def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
        return self.is_subseteq(other.bound)

    def is_subseteq_Union(self, other: "Union") -> bool:
        return any(self.is_subseteq(t) for t in other.types)

    def is_subseteq_Literal(self, other: "Literal") -> bool:
        # `1 == True == 1.0` in Python, but these are distinct literal types,
        # so the value types must match as well as the values.
        return type(self.data) is type(other.data) and self.data == other.data

    def is_subseteq_fallback(self, other: TypeAttribute) -> bool:
        # otherwise a literal is contained wherever the class of its value is
        return PyClass(type(self.data)).is_subseteq(other)

    def __hash__(self) -> int:
        return hash((Literal, self.data))

    def print_impl(self, printer: Printer) -> None:
        printer.plain_print(repr(self.data))
|
224
|
+
|
225
|
+
|
226
|
+
@typing.final
@dataclass(eq=False)
class Union(TypeAttribute, metaclass=UnionTypeMeta):
    """Union of type attributes, printed as `!Union[...]`.

    Members are stored in a frozenset, and nested unions are flattened into
    it.
    """

    name = "Union"
    types: frozenset[TypeAttribute]

    def __init__(
        self,
        typ_or_set: TypeAttribute | typing.Iterable[TypeAttribute],
        *typs: TypeAttribute,
    ):
        """Build a union from positional members or from one iterable of members."""
        if isinstance(typ_or_set, TypeAttribute):
            params: typing.Iterable[TypeAttribute] = (typ_or_set, *typs)
        else:
            params = typ_or_set
            assert not typs, "Cannot pass multiple arguments when passing a set"

        # accumulate in one mutable set (flattening nested unions) instead of
        # rebuilding a frozenset for every member
        members: set[TypeAttribute] = set()
        for typ in params:
            if isinstance(typ, Union):
                members.update(typ.types)
            else:
                members.add(typ)
        self.types = frozenset(members)

    def is_equal(self, other: TypeAttribute) -> bool:
        return isinstance(other, Union) and self.types == other.types

    def is_subseteq_fallback(self, other: TypeAttribute) -> bool:
        # a union is contained in `other` when every member is
        return all(t.is_subseteq(other) for t in self.types)

    def join(self, other: TypeAttribute) -> TypeAttribute:
        if self.is_subseteq(other):
            return other
        elif other.is_subseteq(self):
            return self
        elif isinstance(other, Union):
            return Union(self.types | other.types)
        elif isinstance(other, TypeAttribute):
            return Union(self.types | {other})
        # Unknown operand: the only safe upper bound is the top of the lattice
        # (the same fallback as `TypeAttribute.join`), never the bottom.
        return AnyType()

    def meet(self, other: TypeAttribute) -> TypeAttribute:
        if self.is_subseteq(other):
            return self
        elif other.is_subseteq(self):
            return other
        elif isinstance(other, Union):
            common = self.types & other.types
        elif isinstance(other, TypeAttribute):
            common = self.types & {other}
        else:
            return BottomType()

        # No shared member means the meet is the bottom type; a single shared
        # member is returned directly instead of a one-element union.
        if not common:
            return BottomType()
        if len(common) == 1:
            (only,) = common
            return only
        return Union(common)

    def __hash__(self) -> int:
        return hash((Union, self.types))

    def print_impl(self, printer: Printer) -> None:
        printer.print_name(self, prefix="!")
        printer.print_seq(self.types, delim=", ", prefix="[", suffix="]")
|
285
|
+
|
286
|
+
|
287
|
+
@typing.final
@dataclass(eq=False)
class TypeVar(TypeAttribute):
    """A named type variable with an upper bound, printed as `~name : bound`.

    The bound is `AnyType` when no (or a falsy) bound is given. In that case
    the ` : bound` suffix is omitted when printing.
    """

    name = "TypeVar"
    varname: str
    bound: TypeAttribute

    def __init__(self, name: str, bound: TypeAttribute | None = None):
        self.varname = name
        if not bound:
            bound = AnyType()
        self.bound = bound

    def is_equal(self, other: TypeAttribute) -> bool:
        if not isinstance(other, TypeVar):
            return False
        return self.varname == other.varname and self.bound.is_equal(other.bound)

    def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
        # two type variables are ordered by their bounds
        return self.bound.is_subseteq(other.bound)

    def is_subseteq_Union(self, other: Union) -> bool:
        return any(self.is_subseteq(member) for member in other.types)

    def is_subseteq_fallback(self, other: TypeAttribute) -> bool:
        return self.bound.is_subseteq(other)

    def __hash__(self) -> int:
        return hash((TypeVar, self.varname, self.bound))

    def print_impl(self, printer: Printer) -> None:
        printer.plain_print(f"~{self.varname}")
        if self.bound is self.bound.top():
            return
        printer.plain_print(" : ")
        printer.print(self.bound)
|
322
|
+
|
323
|
+
|
324
|
+
@typing.final
@dataclass(eq=False)
class Vararg(Attribute):
    """Variadic type argument, printed as `*<typ>`.

    Only valid as the last type argument of a `Generic`.
    """

    name = "Vararg"
    typ: TypeAttribute

    def __hash__(self) -> int:
        return hash((Vararg, self.typ))

    def print_impl(self, printer: Printer) -> None:
        element = self.typ
        printer.plain_print("*")
        printer.print(element)
|
336
|
+
|
337
|
+
|
338
|
+
# A single type argument as stored by `Generic`: a type attribute or a Vararg.
TypeOrVararg: typing.TypeAlias = TypeAttribute | Vararg
# A single type argument as accepted from users: additionally a list, which is
# sugar for a tuple generic (see `_typeparams_list2tuple`).
TypeVarValue: typing.TypeAlias = TypeOrVararg | list
|
340
|
+
|
341
|
+
|
342
|
+
@typing.final
@dataclass(eq=False)
class Generic(TypeAttribute, typing.Generic[PyClassType]):
    """A Python class applied to type arguments, e.g. `tuple[int, str]`.

    `vars` holds the positional type arguments. `vararg`, when present, is the
    type of any number of further arguments (`tuple[int, ...]` is
    `Generic(tuple, Vararg(int))`).
    """

    name = "Generic"
    body: PyClass[PyClassType]
    vars: tuple[TypeAttribute, ...]
    vararg: Vararg | None = None

    def __init__(
        self,
        body: type[PyClassType] | PyClass[PyClassType],
        *vars: TypeAttribute | list | Vararg,
    ):
        if isinstance(body, PyClass):
            self.body = body
        else:
            self.body = PyClass(body)
        self.vars, self.vararg = _split_type_args(vars)

    def is_subseteq_Literal(self, other: Literal) -> bool:
        return False

    def is_subseteq_PyClass(self, other: PyClass) -> bool:
        return self.body.is_subseteq(other)

    def is_subseteq_Union(self, other: Union) -> bool:
        return any(self.is_subseteq(t) for t in other.types)

    def is_subseteq_TypeVar(self, other: TypeVar) -> bool:
        return self.body.is_subseteq(other.bound)

    def is_subseteq_Generic(self, other: "Generic") -> bool:
        if other.vararg is None:
            # `other` has fixed arity: the arities must match exactly, and a
            # variadic `self` (e.g. `tuple[int, ...]`) is never contained in a
            # fixed-arity `other` (e.g. `tuple[int]`).
            return (
                self.vararg is None
                and self.body.is_subseteq(other.body)
                and len(self.vars) == len(other.vars)
                and all(v.is_subseteq(o) for v, o in zip(self.vars, other.vars))
            )
        else:
            # `other` is variadic: `self` must cover its fixed arguments, and
            # everything beyond them must fit `other.vararg`.
            return (
                self.body.is_subseteq(other.body)
                and len(self.vars) >= len(other.vars)
                and all(v.is_subseteq(o) for v, o in zip(self.vars, other.vars))
                and all(
                    v.is_subseteq(other.vararg.typ)
                    for v in self.vars[len(other.vars) :]
                )
                and (
                    self.vararg is None or self.vararg.typ.is_subseteq(other.vararg.typ)
                )
            )

    def __hash__(self) -> int:
        return hash((Generic, self.body, self.vars, self.vararg))

    def __repr__(self) -> str:
        # mirrors `print_impl`: positional args, then `<vararg type>, ...`
        args = [repr(v) for v in self.vars]
        if self.vararg is not None:
            args.append(f"{self.vararg.typ!r}, ...")
        return f"{self.body}[{', '.join(args)}]"

    def print_impl(self, printer: Printer) -> None:
        printer.print(self.body)
        printer.plain_print("[")
        if self.vars:
            printer.print_seq(self.vars)
        if self.vararg is not None:
            if self.vars:
                printer.plain_print(", ")
            printer.print(self.vararg.typ)
            printer.plain_print(", ...")
        printer.plain_print("]")

    def __getitem__(self, typ: TypeVarValue | tuple[TypeVarValue, ...]) -> "Generic":
        """`g[a, b]` is shorthand for `g.where((a, b))`."""
        return self.where(typ)

    def where(self, typ: TypeVarValue | tuple[TypeVarValue, ...]) -> "Generic":
        """Specialize this generic with the given type arguments.

        Each argument must be contained in the corresponding parameter of
        `self`. Arguments beyond the positional parameters must be contained
        in `self.vararg`.

        Raises:
            TypeError: if the arguments do not fit the parameters.
        """
        if isinstance(typ, tuple):
            typs = typ
        else:
            typs = (typ,)

        args, vararg = _split_type_args(typs)
        if self.vararg is None and vararg is None:
            assert len(args) <= len(
                self.vars
            ), "Number of type arguments does not match"
            if all(v.is_subseteq(bound) for v, bound in zip(args, self.vars)):
                return Generic(self.body, *args, *self.vars[len(args) :])
            else:
                raise TypeError("Type arguments do not match")
        elif self.vararg is not None and vararg is None:
            assert len(args) >= len(
                self.vars
            ), "Number of type arguments does not match"
            if all(v.is_subseteq(bound) for v, bound in zip(args, self.vars)) and all(
                v.is_subseteq(self.vararg.typ) for v in args[len(self.vars) :]
            ):
                return Generic(self.body, *args)
        elif self.vararg is not None and vararg is not None:
            if len(args) < len(self.vars):
                # the new vararg has to fill the remaining positional slots
                if (
                    all(v.is_subseteq(bound) for v, bound in zip(args, self.vars))
                    and all(
                        vararg.typ.is_subseteq(bound)
                        for bound in self.vars[len(args) :]
                    )
                    and vararg.typ.is_subseteq(self.vararg.typ)
                ):
                    return Generic(self.body, *args, vararg)
            else:
                # extra positional arguments occupy slots of `self.vararg`, so
                # they must fit it (not the new, possibly narrower, vararg)
                if (
                    all(v.is_subseteq(bound) for v, bound in zip(args, self.vars))
                    and all(
                        v.is_subseteq(self.vararg.typ)
                        for v in args[len(self.vars) :]
                    )
                    and vararg.typ.is_subseteq(self.vararg.typ)
                ):
                    return Generic(self.body, *args, vararg)
        raise TypeError("Type arguments do not match")
|
460
|
+
|
461
|
+
|
462
|
+
def _typeparams_list2tuple(args: tuple[TypeVarValue, ...]) -> tuple[TypeOrVararg, ...]:
    """Expand list type arguments into tuple generics.

    This provides the syntax sugar `[A, B, C]` for `Generic(tuple, A, B, C)`.
    Every other argument is passed through unchanged.
    """
    expanded = []
    for arg in args:
        if isinstance(arg, list):
            arg = Generic(tuple, *arg)
        expanded.append(arg)
    return tuple(expanded)
|
465
|
+
|
466
|
+
|
467
|
+
def _split_type_args(
    args: tuple[TypeVarValue, ...]
) -> tuple[tuple[TypeAttribute, ...], Vararg | None]:
    """Split generic type arguments into positional types and a trailing vararg.

    List arguments are first expanded to tuple generics by
    `_typeparams_list2tuple`.

    Returns:
        `(positional, vararg)`, where `vararg` is `None` unless the last
        argument is a `Vararg`.

    Raises:
        TypeError: if a `Vararg` appears anywhere but last, or if an argument
            is neither a type attribute nor a `Vararg`.
    """
    args = _typeparams_list2tuple(args)
    if not args:
        return (), None

    last = args[-1]
    if isinstance(last, Vararg):
        positional, vararg = args[:-1], last
    else:
        positional, vararg = args, None

    for arg in positional:
        if isinstance(arg, Vararg):
            if vararg is not None:
                raise TypeError("Multiple varargs are not allowed")
            raise TypeError("Vararg must be the last argument")
        if not isinstance(arg, TypeAttribute):
            # e.g. a raw Python class such as `int` instead of `PyClass(int)`
            raise TypeError(
                f"Expected a type attribute or Vararg as type argument, got {arg!r}"
            )

    return typing.cast(tuple[TypeAttribute, ...], positional), vararg
|
483
|
+
|
484
|
+
|
485
|
+
# Element type parameter for the `is_tuple_of` type guard.
T = typing.TypeVar("T")
|
486
|
+
|
487
|
+
|
488
|
+
def is_tuple_of(xs: tuple, typ: type[T]) -> typing.TypeGuard[tuple[T, ...]]:
    """Return True if every element of `xs` is an instance of `typ`.

    An empty tuple trivially qualifies. As a `TypeGuard`, a True result lets
    type checkers narrow `xs` to `tuple[typ, ...]`.
    """
    for x in xs:
        if not isinstance(x, typ):
            return False
    return True
|
490
|
+
|
491
|
+
|
492
|
+
def hint2type(hint) -> TypeAttribute:
    """Convert a Python type hint into a kirin type attribute.

    Supported hints:
    - existing type attributes (returned as-is) and `None`;
    - `typing.Literal` (a multi-value literal becomes a `Union` of `Literal`s);
    - type variables;
    - plain classes;
    - `tuple[X, ...]`;
    - unions (`typing.Union`, `typing.Optional`, `X | Y`);
    - other parameterized generics.

    Raises:
        TypeError: for a variadic tuple hint without exactly one element type.
    """
    # local import: this module is itself named `types`
    import types as _pytypes

    if isinstance(hint, TypeAttribute):
        return hint
    elif hint is None:
        return PyClass(type(None))

    bear_hint = TypeHint(hint)
    if isinstance(bear_hint, LiteralTypeHint):
        values = typing.get_args(hint)
        if len(values) == 1:
            return Literal(values[0])
        # `Literal[a, b]` means "a or b"; keep every value, not just the first
        return Union(*(Literal(value) for value in values))
    elif isinstance(bear_hint, TypeVarTypeHint):
        return TypeVar(
            hint.__name__,
            hint2type(hint.__bound__) if hint.__bound__ else None,
        )
    elif isinstance(bear_hint, ClassTypeHint):
        return PyClass(hint)
    elif isinstance(bear_hint, TupleVariableTypeHint):
        if len(bear_hint.args) != 1:
            raise TypeError("Tuple hint must have exactly one argument")
        return Generic(tuple, Vararg(hint2type(bear_hint.args[0])))

    origin: type | None = typing.get_origin(hint)
    if origin is None:  # non-generic
        return PyClass(hint)

    params = [hint2type(arg) for arg in typing.get_args(hint)]
    if origin is typing.Union or origin is _pytypes.UnionType:
        # `Union[A, B]`, `Optional[A]` and `A | B` are unions, not a generic
        # class applied to arguments
        return Union(*params)
    return Generic(PyClass(origin), *params)
|
kirin/ir/dialect.py
ADDED
@@ -0,0 +1,125 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, TypeVar
|
4
|
+
from dataclasses import field, dataclass
|
5
|
+
|
6
|
+
from typing_extensions import dataclass_transform
|
7
|
+
|
8
|
+
from kirin.ir.nodes import Statement
|
9
|
+
from kirin.ir.attrs.abc import Attribute
|
10
|
+
|
11
|
+
# Type parameter used by `Dialect.register` so its decorator returns the same class type it receives.
T = TypeVar("T")
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
from kirin.interp.table import MethodTable
|
15
|
+
from kirin.lowering.dialect import FromPythonAST
|
16
|
+
|
17
|
+
|
18
|
+
# TODO: add an option to generate default lowering at dialect construction
@dataclass
class Dialect:
    """Dialect is a collection of statements, attributes, interpreters, lowerings, and codegen.

    Example:
        ```python
        from kirin import ir

        my_dialect = ir.Dialect(name="my_dialect")
        ```
    """

    name: str
    """The name of the dialect."""
    stmts: list[type[Statement]] = field(default_factory=list, init=True)
    """A list of statements in the dialect."""
    attrs: list[type[Attribute]] = field(default_factory=list, init=True)
    """A list of attributes in the dialect."""
    interps: dict[str, MethodTable] = field(default_factory=dict, init=True)
    """A dictionary of registered method table in the dialect."""
    lowering: dict[str, FromPythonAST] = field(default_factory=dict, init=True)
    """A dictionary of registered python lowering implementations in the dialect."""

    def __post_init__(self) -> None:
        from kirin.lowering.dialect import NoSpecialLowering

        # every dialect gets a fallback lowering under the "default" key
        self.lowering["default"] = NoSpecialLowering()

    def __repr__(self) -> str:
        return f"Dialect(name={self.name}, ...)"

    def __hash__(self) -> int:
        return hash(self.name)

    @dataclass_transform()
    def register(self, node: type | None = None, key: str | None = None):
        """register is a decorator to register a node to the dialect.

        Args:
            node (type | None): The class to register. When None, a decorator
                is returned instead, so `register(key=...)` can be used.
            key (str | None): The key to register a method table or lowering
                under. Defaults to None, which means "main". It is ignored
                for statements and attributes.

        Raises:
            ValueError: If the node is not a class deriving from Statement,
                Attribute, MethodTable, or FromPythonAST, or if a method table
                or lowering is already registered under `key`.

        Example:
            * Register a method table for concrete interpreter (by default key="main") to the dialect:
            ```python
            from kirin import ir
            from kirin.interp.table import MethodTable

            my_dialect = ir.Dialect(name="my_dialect")

            @my_dialect.register
            class MyMethodTable(MethodTable):
                ...
            ```

            * Register a method table for the interpreter specified by `key` to the dialect:
            ```python
            from kirin import ir
            from kirin.interp.table import MethodTable

            my_dialect = ir.Dialect(name="my_dialect")

            @my_dialect.register(key="my_interp")
            class MyMethodTable(MethodTable):
                ...
            ```
        """
        # imported lazily; they are only type-checking imports at module level
        from kirin.interp.table import MethodTable
        from kirin.lowering.dialect import FromPythonAST

        if key is None:
            key = "main"

        def wrapper(node: type[T]) -> type[T]:
            if not isinstance(node, type):
                # `issubclass` would raise a TypeError for non-classes; report
                # them like any other unsupported node instead
                raise ValueError(f"Cannot register {node} to Dialect")
            if issubclass(node, Statement):
                self.stmts.append(node)
            elif issubclass(node, Attribute):
                assert (
                    Attribute in node.__mro__
                ), f"{node} is not a subclass of Attribute"
                # attributes carry a back-reference to their dialect
                setattr(node, "dialect", self)
                assert hasattr(node, "name"), f"{node} does not have a name attribute"
                self.attrs.append(node)
            elif issubclass(node, MethodTable):
                if key in self.interps:
                    raise ValueError(
                        f"Cannot register {node} to Dialect, key {key} exists"
                    )
                self.interps[key] = node()
            elif issubclass(node, FromPythonAST):
                if key in self.lowering:
                    raise ValueError(
                        f"Cannot register {node} to Dialect, key {key} exists"
                    )
                self.lowering[key] = node()
            else:
                raise ValueError(f"Cannot register {node} to Dialect")
            return node

        if node is None:
            return wrapper

        return wrapper(node)
|