kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/ir/attrs/types.py ADDED
@@ -0,0 +1,522 @@
1
+ import typing
2
+ from abc import abstractmethod
3
+ from dataclasses import dataclass
4
+ from collections.abc import Hashable
5
+
6
+ from beartype.door import TupleVariableTypeHint # type: ignore
7
+ from beartype.door import TypeHint, ClassTypeHint, LiteralTypeHint, TypeVarTypeHint
8
+ from typing_extensions import Never
9
+
10
+ from kirin.print import Printer
11
+ from kirin.lattice import (
12
+ UnionMeta,
13
+ SingletonMeta,
14
+ BoundedLattice,
15
+ IsSubsetEqMixin,
16
+ SimpleMeetMixin,
17
+ )
18
+
19
+ from .abc import Attribute, LatticeAttributeMeta
20
+ from ._types import _TypeAttribute
21
+
22
+
23
class TypeAttributeMeta(LatticeAttributeMeta):
    """Metaclass shared by all type attributes.

    Adds nothing on top of `LatticeAttributeMeta`; it gives type attributes a
    dedicated metaclass that more specific metaclasses (singleton, union,
    Python-class, literal) derive from.
    """
27
+
28
+
29
class SingletonTypeMeta(TypeAttributeMeta, SingletonMeta):
    """Metaclass for type attributes that have exactly one instance.

    Combines the type-attribute metaclass with `SingletonMeta` from
    `kirin.lattice`. Used by:

    - `AnyType`
    - `BottomType`
    """
40
+
41
+
42
class UnionTypeMeta(TypeAttributeMeta, UnionMeta):
    """Metaclass for the `Union` type attribute.

    Mixes the type-attribute metaclass with `UnionMeta` from `kirin.lattice`.
    """
44
+
45
+
46
@dataclass
class TypeAttribute(
    _TypeAttribute,
    SimpleMeetMixin["TypeAttribute"],
    IsSubsetEqMixin["TypeAttribute"],
    BoundedLattice["TypeAttribute"],
    metaclass=TypeAttributeMeta,
):
    """Base class of all type attributes, which form a bounded lattice.

    The top element is `AnyType` and the bottom element is `BottomType`.
    Equality is delegated to `is_equal`, `a | b` is the lattice join, and
    `__hash__` is declared abstract for subclasses to provide.
    """

    @classmethod
    def top(cls) -> "TypeAttribute":
        """Return the top element of the type lattice (`AnyType`)."""
        return AnyType()

    @classmethod
    def bottom(cls) -> "TypeAttribute":
        """Return the bottom element of the type lattice (`BottomType`)."""
        return BottomType()

    def join(self, other: "TypeAttribute") -> "TypeAttribute":
        """Return an upper bound of `self` and `other`.

        If one operand already contains the other, the larger one is returned
        unchanged; otherwise a `Union` of both is built. An operand that is
        not a `TypeAttribute` widens the result to `AnyType`.
        """
        if self.is_subseteq(other):
            return other
        if other.is_subseteq(self):
            return self
        if not isinstance(other, TypeAttribute):
            # don't know how to join
            return AnyType()
        return Union(self, other)

    def print_impl(self, printer: Printer) -> None:
        """Print the attribute name prefixed with `!`."""
        printer.print_name(self, prefix="!")

    def __or__(self, other: "TypeAttribute"):
        # `a | b` is sugar for the lattice join.
        return self.join(other)

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, TypeAttribute):
            return False
        return self.is_equal(value)

    @abstractmethod
    def __hash__(self) -> int: ...
83
+
84
+
85
@typing.final
@dataclass(eq=False)
class AnyType(TypeAttribute, metaclass=SingletonTypeMeta):
    """The top element of the type lattice (see `TypeAttribute.top`); a singleton."""

    name = "Any"

    def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
        # A type variable is compared through its upper bound.
        upper = other.bound
        return self.is_subseteq(upper)

    def __hash__(self) -> int:
        # Only one instance exists, so identity is a sufficient hash.
        return id(self)
95
+
96
+
97
@typing.final
@dataclass(eq=False)
class BottomType(TypeAttribute, metaclass=SingletonTypeMeta):
    """The bottom element of the type lattice (see `TypeAttribute.bottom`); a singleton."""

    name = "Bottom"

    def is_subseteq(self, other: TypeAttribute) -> bool:
        """Bottom is a subset of every type attribute.

        A `TypeVar` operand is checked against its bound, which recurses back
        into this method.
        """
        return self.is_subseteq(other.bound) if isinstance(other, TypeVar) else True

    def __hash__(self) -> int:
        # Only one instance exists, so identity is a sufficient hash.
        return id(self)
109
+
110
+
111
class PyClassMeta(TypeAttributeMeta):
    """Metaclass for `PyClass` that normalizes its argument and caches instances.

    Calling `PyClass(typ)` goes through `__call__` below, which:

    - maps `typing.Any` to `AnyType()` and `typing.NoReturn` / `Never` to
      `BottomType()`;
    - routes type variables (kirin `TypeVar` or `typing.TypeVar`) through
      `hint2type`;
    - normalizes the `typing.Tuple` / `typing.List` aliases to the builtin
      `tuple` / `list`;
    - returns one shared instance per Python class.
    """

    def __init__(self, *args, **kwargs):
        super(PyClassMeta, self).__init__(*args, **kwargs)
        # One cached PyClass instance per Python class (keys are classes only).
        self._cache = {}

    def __call__(self, typ):
        if typ is typing.Any:
            return AnyType()
        if typ is typing.NoReturn or typ is Never:
            return BottomType()
        if isinstance(typ, (TypeVar, typing.TypeVar)):
            # A `typing.TypeVar` is not a class; convert it to a kirin TypeVar
            # rather than wrapping it in a PyClass.
            return hint2type(typ)

        # Normalize the typing aliases *before* the cache lookup so that
        # `PyClass(typing.Tuple)` and `PyClass(tuple)` share one instance.
        if typ is typing.Tuple:
            typ = tuple
        elif typ is typing.List:
            typ = list

        if not isinstance(typ, type):
            # Non-class objects are not cached: they may not be hashable.
            return super(PyClassMeta, self).__call__(typ)

        instance = self._cache.get(typ)
        if instance is None:
            instance = super(PyClassMeta, self).__call__(typ)
            self._cache[typ] = instance
        return instance
134
+
135
+
136
PyClassType = typing.TypeVar("PyClassType")


@typing.final
@dataclass(eq=False)
class PyClass(TypeAttribute, typing.Generic[PyClassType], metaclass=PyClassMeta):
    """A type attribute wrapping a plain Python class, printed as `!py.<name>`."""

    name = "PyClass"
    typ: type[PyClassType]

    def __init__(self, typ: type[PyClassType]) -> None:
        self.typ = typ

    def is_subseteq_PyClass(self, other: "PyClass") -> bool:
        # Nominal subtyping, delegated to Python's class hierarchy.
        return issubclass(self.typ, other.typ)

    def is_subseteq_Union(self, other: "Union") -> bool:
        # Contained in a union iff contained in at least one of its members.
        for member in other.types:
            if self.is_subseteq(member):
                return True
        return False

    def is_subseteq_Generic(self, other: "Generic") -> bool:
        # NOTE: a bare class behaves like the generic with every parameter set
        # to Any, so every parameter bound (and the vararg) must accept Any.
        top = AnyType()
        if not self.is_subseteq(other.body):
            return False
        if not all(top.is_subseteq(bound) for bound in other.vars):
            return False
        return other.vararg is None or top.is_subseteq(other.vararg.typ)

    def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
        return self.is_subseteq(other.bound)

    def __hash__(self) -> int:
        return hash((PyClass, self.typ))

    def __repr__(self) -> str:
        return self.typ.__name__

    def print_impl(self, printer: Printer) -> None:
        printer.plain_print("!py.", self.typ.__name__)
174
+
175
+
176
class LiteralMeta(TypeAttributeMeta):
    """Metaclass for `Literal` that interns one instance per literal value.

    `Literal(data)` returns `data` unchanged when it is already an
    `Attribute`. It falls back to `PyClass(type(data))` when `data` cannot be
    hashed, and therefore cannot be interned.
    """

    def __init__(self, *args, **kwargs):
        super(LiteralMeta, self).__init__(*args, **kwargs)
        self._cache = {}

    def __call__(self, data):
        if isinstance(data, Attribute):
            return data

        # Key on the value's type as well as the value: 1, 1.0 and True
        # compare and hash equal, so a value-only key would make
        # `Literal(True)` return a previously cached `Literal(1)`.
        # NOTE: containers still compare element-wise, so e.g. (1,) and
        # (True,) share a key.
        key = (type(data), data)
        try:
            instance = self._cache.get(key)
        except TypeError:
            # `isinstance(x, Hashable)` is not sufficient: a tuple holding a
            # list passes that check but raises when actually hashed.
            return PyClass(type(data))

        if instance is None:
            instance = super(LiteralMeta, self).__call__(data)
            self._cache[key] = instance
        return instance
193
+
194
+
195
LiteralType = typing.TypeVar("LiteralType")


@typing.final
@dataclass(eq=False)
class Literal(TypeAttribute, typing.Generic[LiteralType], metaclass=LiteralMeta):
    """A type attribute for a single Python literal value, e.g. `Literal(1)`."""

    name = "Literal"
    data: LiteralType

    def is_equal(self, other: TypeAttribute) -> bool:
        # Instances are interned by `LiteralMeta`, so identity is equality.
        return self is other

    def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
        return self.is_subseteq(other.bound)

    def is_subseteq_Union(self, other: "Union") -> bool:
        return any(self.is_subseteq(t) for t in other.types)

    def is_subseteq_Literal(self, other: "Literal") -> bool:
        # Literal values of different types are distinct types (PEP 586):
        # Literal[1], Literal[1.0] and Literal[True] differ even though
        # 1 == 1.0 == True at runtime.
        return type(self.data) is type(other.data) and self.data == other.data

    def is_subseteq_fallback(self, other: TypeAttribute) -> bool:
        # Otherwise compare through the literal's runtime class,
        # e.g. Literal(1) <= PyClass(int).
        return PyClass(type(self.data)).is_subseteq(other)

    def __hash__(self) -> int:
        return hash((Literal, self.data))

    def print_impl(self, printer: Printer) -> None:
        printer.plain_print(repr(self.data))
224
+
225
+
226
@typing.final
@dataclass(eq=False)
class Union(TypeAttribute, metaclass=UnionTypeMeta):
    """A union of type attributes, stored as a frozenset of member types.

    `__init__` flattens nested unions into a single member set.
    """

    name = "Union"
    types: frozenset[TypeAttribute]

    def __init__(
        self,
        typ_or_set: TypeAttribute | typing.Iterable[TypeAttribute],
        *typs: TypeAttribute,
    ):
        """Build a union from several types or from one iterable of types.

        Args:
            typ_or_set: the first member type, or an iterable of member types.
            *typs: further member types; only allowed when `typ_or_set` is a
                single type attribute.
        """
        if isinstance(typ_or_set, TypeAttribute):
            params: typing.Iterable[TypeAttribute] = (typ_or_set, *typs)
        else:
            params = typ_or_set
            assert not typs, "Cannot pass multiple arguments when passing a set"

        # Accumulate into one mutable set instead of rebuilding a frozenset
        # for every member.
        members: set[TypeAttribute] = set()
        for typ in params:
            if isinstance(typ, Union):
                members.update(typ.types)  # flatten nested unions
            else:
                members.add(typ)
        self.types = frozenset(members)

    def is_equal(self, other: TypeAttribute) -> bool:
        return isinstance(other, Union) and self.types == other.types

    def is_subseteq_fallback(self, other: TypeAttribute) -> bool:
        # A union is contained in `other` iff every member is.
        return all(t.is_subseteq(other) for t in self.types)

    def join(self, other: TypeAttribute) -> TypeAttribute:
        """Return an upper bound of `self` and `other`."""
        if self.is_subseteq(other):
            return other
        elif other.is_subseteq(self):
            return self
        elif isinstance(other, Union):
            return Union(self.types | other.types)
        elif isinstance(other, TypeAttribute):
            return Union(self.types | {other})
        # Unknown operand: widen to the top type. A join must be an upper
        # bound of both operands, so Bottom would be unsound here; this also
        # matches `TypeAttribute.join`.
        return AnyType()

    def meet(self, other: TypeAttribute) -> TypeAttribute:
        """Return a lower bound of `self` and `other`.

        When neither operand contains the other, only the members the two
        share exactly are kept. That result can be smaller than the greatest
        lower bound (e.g. a `Literal` member is dropped when meeting with its
        class), but it is always a lower bound.
        """
        if self.is_subseteq(other):
            return self
        elif other.is_subseteq(self):
            return other
        elif isinstance(other, Union):
            return Union(self.types & other.types)
        elif isinstance(other, TypeAttribute):
            return Union(self.types & {other})
        return BottomType()

    def __hash__(self) -> int:
        return hash((Union, self.types))

    def print_impl(self, printer: Printer) -> None:
        printer.print_name(self, prefix="!")
        printer.print_seq(self.types, delim=", ", prefix="[", suffix="]")
285
+
286
+
287
@typing.final
@dataclass(eq=False)
class TypeVar(TypeAttribute):
    """A named type variable with an upper bound (`AnyType` when omitted)."""

    name = "TypeVar"
    varname: str
    bound: TypeAttribute

    def __init__(self, name: str, bound: TypeAttribute | None = None):
        self.varname = name
        if not bound:
            bound = AnyType()
        self.bound = bound

    def is_equal(self, other: TypeAttribute) -> bool:
        if not isinstance(other, TypeVar):
            return False
        return self.varname == other.varname and self.bound.is_equal(other.bound)

    def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
        # Two type variables are compared through their bounds.
        return self.bound.is_subseteq(other.bound)

    def is_subseteq_Union(self, other: Union) -> bool:
        return any(self.is_subseteq(member) for member in other.types)

    def is_subseteq_fallback(self, other: TypeAttribute) -> bool:
        return self.bound.is_subseteq(other)

    def __hash__(self) -> int:
        return hash((TypeVar, self.varname, self.bound))

    def print_impl(self, printer: Printer) -> None:
        # `~T`, followed by ` : <bound>` unless the bound is the top type.
        printer.plain_print(f"~{self.varname}")
        if self.bound is self.bound.top():
            return
        printer.plain_print(" : ")
        printer.print(self.bound)
322
+
323
+
324
@typing.final
@dataclass(eq=False)
class Vararg(Attribute):
    """Marks a variadic tail of type parameters; printed as `*<typ>`."""

    name = "Vararg"
    typ: TypeAttribute

    def __hash__(self) -> int:
        return hash((Vararg, self.typ))

    def print_impl(self, printer: Printer) -> None:
        printer.plain_print("*")
        printer.print(self.typ)


# A type-parameter value as written by users: a type, a vararg marker, or a
# list (sugar for a tuple type, expanded by `_typeparams_list2tuple`).
TypeVarValue: typing.TypeAlias = TypeAttribute | Vararg | list
# A type parameter after list sugar has been expanded.
TypeOrVararg: typing.TypeAlias = TypeAttribute | Vararg
340
+
341
+
342
@typing.final
@dataclass(eq=False)
class Generic(TypeAttribute, typing.Generic[PyClassType]):
    """A Python class with type parameters, optionally ending in a variadic tail.

    Attributes:
        body: the underlying Python class.
        vars: the fixed positional type parameters.
        vararg: optional variadic tail; extra parameters must fit `vararg.typ`.
    """

    name = "Generic"
    body: PyClass[PyClassType]
    vars: tuple[TypeAttribute, ...]
    vararg: Vararg | None = None

    def __init__(
        self,
        body: type[PyClassType] | PyClass[PyClassType],
        *vars: TypeAttribute | list | Vararg,
    ):
        if isinstance(body, PyClass):
            self.body = body
        else:
            self.body = PyClass(body)
        self.vars, self.vararg = _split_type_args(vars)

    def is_subseteq_Literal(self, other: Literal) -> bool:
        return False

    def is_subseteq_PyClass(self, other: PyClass) -> bool:
        # A parameterized class is contained in a bare class via its body.
        return self.body.is_subseteq(other)

    def is_subseteq_Union(self, other: Union) -> bool:
        return any(self.is_subseteq(t) for t in other.types)

    def is_subseteq_TypeVar(self, other: TypeVar) -> bool:
        return self.body.is_subseteq(other.bound)

    def is_subseteq_Generic(self, other: "Generic") -> bool:
        """Parameter-wise subset check between two generics.

        Parameters are compared covariantly. A variadic generic is never a
        subset of a fixed-arity one, because it admits parameter lists of
        other lengths.
        """
        if not self.body.is_subseteq(other.body):
            return False

        if other.vararg is None:
            # Fixed arity on the right: `self` must have exactly as many
            # parameters and no variadic tail of its own.
            return (
                self.vararg is None
                and len(self.vars) == len(other.vars)
                and all(v.is_subseteq(o) for v, o in zip(self.vars, other.vars))
            )

        tail = other.vararg.typ
        return (
            len(self.vars) >= len(other.vars)
            and all(v.is_subseteq(o) for v, o in zip(self.vars, other.vars))
            # extra fixed parameters of `self` must fit other's variadic tail
            and all(v.is_subseteq(tail) for v in self.vars[len(other.vars) :])
            and (self.vararg is None or self.vararg.typ.is_subseteq(tail))
        )

    def __hash__(self) -> int:
        return hash((Generic, self.body, self.vars, self.vararg))

    def __repr__(self) -> str:
        params = [repr(v) for v in self.vars]
        if self.vararg is not None:
            # Mirror `print_impl`: the vararg element type followed by `...`.
            params += [repr(self.vararg.typ), "..."]
        return f"{self.body}[{', '.join(params)}]"

    def print_impl(self, printer: Printer) -> None:
        printer.print(self.body)
        printer.plain_print("[")
        if self.vars:
            printer.print_seq(self.vars)
        if self.vararg is not None:
            if self.vars:
                printer.plain_print(", ")
            printer.print(self.vararg.typ)
            printer.plain_print(", ...")
        printer.plain_print("]")

    def __getitem__(self, typ: TypeVarValue | tuple[TypeVarValue, ...]) -> "Generic":
        """`G[A, B]` is sugar for `G.where((A, B))`."""
        return self.where(typ)

    def where(self, typ: TypeVarValue | tuple[TypeVarValue, ...]) -> "Generic":
        """Instantiate this generic's parameters with the given types.

        Each argument must be a subset of the bound it replaces (fixed
        parameters) or of the vararg element type (extra arguments).

        Raises:
            TypeError: if an argument does not fit its bound, or if a vararg
                is supplied to a generic without one.
            AssertionError: (via `assert`) if the argument count does not fit
                a generic without a supplied vararg.
        """
        if isinstance(typ, tuple):
            typs = typ
        else:
            typs = (typ,)

        args, vararg = _split_type_args(typs)
        if self.vararg is None and vararg is None:
            # Fixed arity: a prefix may be bound; the rest keep their bounds.
            assert len(args) <= len(
                self.vars
            ), "Number of type arguments does not match"
            if all(v.is_subseteq(bound) for v, bound in zip(args, self.vars)):
                return Generic(self.body, *args, *self.vars[len(args) :])
            else:
                raise TypeError("Type arguments do not match")
        elif self.vararg is not None and vararg is None:
            # Variadic generic instantiated with a concrete argument list.
            assert len(args) >= len(
                self.vars
            ), "Number of type arguments does not match"
            if all(v.is_subseteq(bound) for v, bound in zip(args, self.vars)) and all(
                v.is_subseteq(self.vararg.typ) for v in args[len(self.vars) :]
            ):
                return Generic(self.body, *args)
        elif self.vararg is not None and vararg is not None:
            if len(args) < len(self.vars):
                # The supplied vararg also has to fit the unbound fixed parameters.
                if (
                    all(v.is_subseteq(bound) for v, bound in zip(args, self.vars))
                    and all(
                        vararg.typ.is_subseteq(bound)
                        for bound in self.vars[len(args) :]
                    )
                    and vararg.typ.is_subseteq(self.vararg.typ)
                ):
                    return Generic(self.body, *args, vararg)
            else:
                if (
                    all(v.is_subseteq(bound) for v, bound in zip(args, self.vars))
                    and all(v.is_subseteq(vararg.typ) for v in args[len(self.vars) :])
                    and vararg.typ.is_subseteq(self.vararg.typ)
                ):
                    return Generic(self.body, *args, vararg)
        raise TypeError("Type arguments do not match")
460
+
461
+
462
def _typeparams_list2tuple(args: tuple[TypeVarValue, ...]) -> tuple[TypeOrVararg, ...]:
    """Expand list sugar in type parameters.

    A list argument `[A, B, C]` is shorthand for `Generic(tuple, A, B, C)`;
    every other argument is passed through unchanged.
    """
    expanded: list[TypeOrVararg] = []
    for arg in args:
        if isinstance(arg, list):
            arg = Generic(tuple, *arg)
        expanded.append(arg)
    return tuple(expanded)
465
+
466
+
467
def _split_type_args(
    args: tuple[TypeVarValue, ...]
) -> tuple[tuple[TypeAttribute, ...], Vararg | None]:
    """Split type parameters into the fixed ones and an optional trailing vararg.

    List sugar is expanded first via `_typeparams_list2tuple`.

    Raises:
        TypeError: if a `Vararg` (or any other non-type) appears anywhere
            other than the last position.
    """
    params = _typeparams_list2tuple(args)
    if not params:
        return (), None

    last = params[-1]
    if not isinstance(last, Vararg):
        if not is_tuple_of(params, TypeAttribute):
            raise TypeError("Vararg must be the last argument")
        return params, None

    fixed = params[:-1]
    if not is_tuple_of(fixed, TypeAttribute):
        raise TypeError("Multiple varargs are not allowed")
    return fixed, last
483
+
484
+
485
+ T = typing.TypeVar("T")
486
+
487
+
488
+ def is_tuple_of(xs: tuple, typ: type[T]) -> typing.TypeGuard[tuple[T, ...]]:
489
+ return all(isinstance(x, typ) for x in xs)
490
+
491
+
492
def hint2type(hint) -> TypeAttribute:
    """Convert a Python type hint into a kirin `TypeAttribute`.

    - a `TypeAttribute` is returned unchanged; `None` maps to `PyClass(NoneType)`;
    - `typing.Literal[a]` maps to `Literal(a)`; `typing.Literal[a, b, ...]`
      maps to a `Union` of the individual `Literal`s;
    - a type variable maps to a kirin `TypeVar` (its bound is converted too);
    - a plain class maps to `PyClass`;
    - a variadic tuple hint `tuple[X, ...]` maps to `Generic(tuple, Vararg(X))`;
    - `typing.Union[...]` / `X | Y` map to a kirin `Union` of the converted
      members;
    - any other parameterized hint `C[A, B]` maps to `Generic(C, A, B)`.

    Raises:
        TypeError: for a variadic tuple hint without exactly one argument.
    """
    import types as pytypes

    if isinstance(hint, TypeAttribute):
        return hint
    elif hint is None:
        return PyClass(type(None))

    bear_hint = TypeHint(hint)
    if isinstance(bear_hint, LiteralTypeHint):
        values = typing.get_args(hint)
        if len(values) == 1:
            return Literal(values[0])
        # `Literal[1, 2]` means `Literal[1] | Literal[2]`; keeping only the
        # first value would silently drop the others.
        return Union(frozenset(Literal(value) for value in values))
    elif isinstance(bear_hint, TypeVarTypeHint):
        return TypeVar(
            hint.__name__,
            hint2type(hint.__bound__) if hint.__bound__ else None,
        )
    elif isinstance(bear_hint, ClassTypeHint):
        return PyClass(hint)
    elif isinstance(bear_hint, TupleVariableTypeHint):
        if len(bear_hint.args) != 1:
            raise TypeError("Tuple hint must have exactly one argument")
        return Generic(tuple, Vararg(hint2type(bear_hint.args[0])))

    origin: type | None = typing.get_origin(hint)
    if origin is None:  # non-generic
        return PyClass(hint)

    args = typing.get_args(hint)
    if origin is typing.Union or origin is pytypes.UnionType:
        # `typing.Union` / `X | Y` are not classes; wrapping them in a PyClass
        # would produce a Generic whose body cannot be subclass-checked.
        return Union(frozenset(hint2type(arg) for arg in args))

    body = PyClass(origin)
    params = [hint2type(arg) for arg in args]
    return Generic(body, *params)
kirin/ir/dialect.py ADDED
@@ -0,0 +1,125 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, TypeVar
4
+ from dataclasses import field, dataclass
5
+
6
+ from typing_extensions import dataclass_transform
7
+
8
+ from kirin.ir.nodes import Statement
9
+ from kirin.ir.attrs.abc import Attribute
10
+
11
# Type parameter for `Dialect.register`'s inner decorator, which returns the
# class it is given.
T = TypeVar("T")
12
+
13
+ if TYPE_CHECKING:
14
+ from kirin.interp.table import MethodTable
15
+ from kirin.lowering.dialect import FromPythonAST
16
+
17
+
18
# TODO: add an option to generate default lowering at dialect construction
@dataclass
class Dialect:
    """Dialect is a collection of statements, attributes, interpreters, lowerings, and codegen.

    Example:
        ```python
        from kirin import ir

        my_dialect = ir.Dialect(name="my_dialect")
        ```
    """

    name: str
    """The name of the dialect."""
    stmts: list[type[Statement]] = field(default_factory=list, init=True)
    """A list of statements in the dialect."""
    attrs: list[type[Attribute]] = field(default_factory=list, init=True)
    """A list of attributes in the dialect."""
    interps: dict[str, MethodTable] = field(default_factory=dict, init=True)
    """A dictionary of registered method tables in the dialect, keyed by interpreter key."""
    lowering: dict[str, FromPythonAST] = field(default_factory=dict, init=True)
    """A dictionary of registered Python lowering implementations in the dialect."""

    def __post_init__(self) -> None:
        # Imported locally, presumably to avoid an import cycle between
        # kirin.ir and kirin.lowering.
        from kirin.lowering.dialect import NoSpecialLowering

        # NOTE: this overwrites any "default" entry passed to the constructor.
        self.lowering["default"] = NoSpecialLowering()

    def __repr__(self) -> str:
        return f"Dialect(name={self.name}, ...)"

    def __hash__(self) -> int:
        return hash(self.name)

    @dataclass_transform()
    def register(self, node: type | None = None, key: str | None = None):
        """Register a class to the dialect.

        Usable as `@dialect.register` or `@dialect.register(key=...)`.

        Statements and attributes are appended to `stmts` / `attrs`; `key` is
        ignored for them, and an attribute additionally gets its `dialect`
        attribute set to this dialect. Method tables and Python lowerings are
        instantiated and stored under `key` in `interps` / `lowering`.

        Args:
            node (type | None): The class to register. When omitted, a
                decorator is returned instead.
            key (str | None): The key under which a method table or lowering
                is registered. Defaults to "main".

        Returns:
            The registered class unchanged, or a decorator when `node` is None.

        Raises:
            ValueError: If the node is not a subclass of Statement, Attribute,
                MethodTable, or FromPythonAST, or if a method table/lowering
                is already registered under `key`.

        Example:
            * Register a method table for the concrete interpreter (by default key="main") to the dialect:
            ```python
            from kirin import ir
            from kirin.interp.table import MethodTable

            my_dialect = ir.Dialect(name="my_dialect")

            @my_dialect.register
            class MyMethodTable(MethodTable):
                ...
            ```

            * Register a method table for the interpreter specified by `key` to the dialect:
            ```python
            from kirin import ir
            from kirin.interp.table import MethodTable

            my_dialect = ir.Dialect(name="my_dialect")

            @my_dialect.register(key="my_interp")
            class MyMethodTable(MethodTable):
                ...
            ```
        """
        from kirin.interp.table import MethodTable
        from kirin.lowering.dialect import FromPythonAST

        if key is None:
            key = "main"

        def wrapper(node: type[T]) -> type[T]:
            if not isinstance(node, type):
                # `issubclass` would raise an unrelated TypeError for
                # non-class arguments; report the documented ValueError.
                raise ValueError(f"Cannot register {node} to Dialect")

            if issubclass(node, Statement):
                self.stmts.append(node)
            elif issubclass(node, Attribute):
                # presumably rejects virtual (ABC-registered) subclasses,
                # which satisfy issubclass without inheriting from Attribute
                assert (
                    Attribute in node.__mro__
                ), f"{node} is not a subclass of Attribute"
                setattr(node, "dialect", self)
                assert hasattr(node, "name"), f"{node} does not have a name attribute"
                self.attrs.append(node)
            elif issubclass(node, MethodTable):
                if key in self.interps:
                    raise ValueError(
                        f"Cannot register {node} to Dialect, key {key} exists"
                    )
                self.interps[key] = node()
            elif issubclass(node, FromPythonAST):
                if key in self.lowering:
                    raise ValueError(
                        f"Cannot register {node} to Dialect, key {key} exists"
                    )
                self.lowering[key] = node()
            else:
                raise ValueError(f"Cannot register {node} to Dialect")
            return node

        if node is None:
            return wrapper

        return wrapper(node)