kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,90 @@
1
+ from typing import TypeVar, final
2
+
3
+ from kirin import ir, types, interp
4
+ from kirin.decl import fields
5
+ from kirin.analysis import const
6
+ from kirin.interp.impl import Signature
7
+ from kirin.analysis.forward import Forward, ForwardFrame
8
+
9
+ from .solve import TypeResolution
10
+
11
+
12
+ @final
13
+ class TypeInference(Forward[types.TypeAttribute]):
14
+ """Type inference analysis for kirin.
15
+
16
+ This analysis uses the forward dataflow analysis framework to infer the types of
17
+ the IR. The analysis uses the type information within the IR to determine the
18
+ method dispatch.
19
+
20
+ The analysis will fallback to a type resolution algorithm if the type information
21
+ is not available in the IR but the type information is available in the abstract
22
+ values.
23
+ """
24
+
25
+ keys = ["typeinfer"]
26
+ lattice = types.TypeAttribute
27
+
28
+ def run_analysis(
29
+ self, method: ir.Method, args: tuple[types.TypeAttribute, ...] | None = None
30
+ ) -> tuple[ForwardFrame[types.TypeAttribute], types.TypeAttribute]:
31
+ if args is None:
32
+ args = method.arg_types
33
+ return super().run_analysis(method, args)
34
+
35
+ # NOTE: unlike concrete interpreter, instead of using type information
36
+ # within the IR. Type inference will use the interpreted
37
+ # value (which is a type) to determine the method dispatch.
38
+ def build_signature(
39
+ self, frame: ForwardFrame[types.TypeAttribute], stmt: ir.Statement
40
+ ) -> Signature:
41
+ _args = ()
42
+ for x in frame.get_values(stmt.args):
43
+ # TODO: remove this after we have multiple dispatch...
44
+ if isinstance(x, types.Generic):
45
+ _args += (x.body,)
46
+ else:
47
+ _args += (x,)
48
+ return Signature(stmt.__class__, _args)
49
+
50
+ def eval_stmt_fallback(
51
+ self, frame: ForwardFrame[types.TypeAttribute], stmt: ir.Statement
52
+ ) -> tuple[types.TypeAttribute, ...] | interp.SpecialValue[types.TypeAttribute]:
53
+ resolve = TypeResolution()
54
+ fs = fields(stmt)
55
+ for f, value in zip(fs.args.values(), frame.get_values(stmt.args)):
56
+ resolve.solve(f.type, value)
57
+
58
+ for arg, f in zip(stmt.args, fs.args.values()):
59
+ frame.set(arg, frame.get(arg).meet(resolve.substitute(f.type)))
60
+ return tuple(resolve.substitute(result.type) for result in stmt.results)
61
+
62
+ def run_method(
63
+ self, method: ir.Method, args: tuple[types.TypeAttribute, ...]
64
+ ) -> tuple[ForwardFrame[types.TypeAttribute], types.TypeAttribute]:
65
+ return self.run_callable(method.code, (method.self_type,) + args)
66
+
67
+ T = TypeVar("T")
68
+
69
+ @classmethod
70
+ def maybe_const(cls, value: ir.SSAValue, type_: type[T]) -> T | None:
71
+ """Get a constant value of a given type.
72
+
73
+ If the value is not a constant or the constant is not of the given type, return
74
+ `None`.
75
+ """
76
+ hint = value.hints.get("const")
77
+ if isinstance(hint, const.Value) and isinstance(hint.data, type_):
78
+ return hint.data
79
+
80
+ @classmethod
81
+ def expect_const(cls, value: ir.SSAValue, type_: type[T]):
82
+ """Expect a constant value of a given type.
83
+
84
+ If the value is not a constant or the constant is not of the given type, raise
85
+ an `InterpreterError`.
86
+ """
87
+ hint = cls.maybe_const(value, type_)
88
+ if hint is None:
89
+ raise interp.InterpreterError(f"expected {type_}, got {hint}")
90
+ return hint
@@ -0,0 +1,141 @@
1
+ """Type resolution for type inference.
2
+
3
+ This module contains the type resolution algorithm for type inference.
4
+ A simple algorithm is used to resolve the types of the IR by comparing
5
+ the input types with the output types.
6
+ """
7
+
8
+ from dataclasses import field, dataclass
9
+
10
+ from kirin import types
11
+
12
+
13
+ @dataclass
14
+ class TypeResolutionResult:
15
+ """Base class for type resolution results."""
16
+
17
+ pass
18
+
19
+
20
+ @dataclass
21
+ class ResolutionOk(TypeResolutionResult):
22
+ """Type resolution result for successful resolution."""
23
+
24
+ def __bool__(self):
25
+ return True
26
+
27
+
28
+ Ok = ResolutionOk()
29
+
30
+
31
+ @dataclass
32
+ class ResolutionError(TypeResolutionResult):
33
+ """Type resolution result for failed resolution."""
34
+
35
+ expr: types.TypeAttribute
36
+ value: types.TypeAttribute
37
+
38
+ def __bool__(self):
39
+ return False
40
+
41
+ def __str__(self):
42
+ return f"expected {self.expr}, got {self.value}"
43
+
44
+
45
+ @dataclass
46
+ class TypeResolution:
47
+ """Type resolution algorithm for type inference."""
48
+
49
+ vars: dict[types.TypeVar, types.TypeAttribute] = field(default_factory=dict)
50
+
51
+ def substitute(self, typ: types.TypeAttribute) -> types.TypeAttribute:
52
+ """Substitute type variables in the type with their values.
53
+
54
+ This method substitutes type variables in the given type with their
55
+ values. If the type is a generic type, the method recursively
56
+ substitutes the type variables in the type arguments.
57
+
58
+ Args:
59
+ typ: The type to substitute.
60
+
61
+ Returns:
62
+ The type with the type variables substituted.
63
+ """
64
+ if isinstance(typ, types.TypeVar):
65
+ return self.vars.get(typ, typ)
66
+ elif isinstance(typ, types.Generic):
67
+ return types.Generic(
68
+ typ.body, *tuple(self.substitute(var) for var in typ.vars)
69
+ )
70
+ elif isinstance(typ, types.Union):
71
+ return types.Union(self.substitute(t) for t in typ.types)
72
+ return typ
73
+
74
+ def solve(
75
+ self, annot: types.TypeAttribute, value: types.TypeAttribute
76
+ ) -> TypeResolutionResult:
77
+ """Solve the type resolution problem.
78
+
79
+ This method compares the expected type `annot` with the actual
80
+ type `value` and returns a result indicating whether the types
81
+ match or not.
82
+
83
+ Args:
84
+ annot: The expected type.
85
+ value: The actual type.
86
+
87
+ Returns:
88
+ A `TypeResolutionResult` object indicating the result of the
89
+ resolution.
90
+ """
91
+ if isinstance(annot, types.TypeVar):
92
+ return self.solve_TypeVar(annot, value)
93
+ elif isinstance(annot, types.Generic):
94
+ return self.solve_Generic(annot, value)
95
+ elif isinstance(annot, types.Union):
96
+ return self.solve_Union(annot, value)
97
+
98
+ if annot.is_subseteq(value):
99
+ return Ok
100
+ else:
101
+ return ResolutionError(annot, value)
102
+
103
+ def solve_TypeVar(self, annot: types.TypeVar, value: types.TypeAttribute):
104
+ if annot in self.vars:
105
+ if value.is_subseteq(self.vars[annot]):
106
+ self.vars[annot] = value
107
+ elif self.vars[annot].is_subseteq(value):
108
+ pass
109
+ else:
110
+ return ResolutionError(annot, value)
111
+ else:
112
+ self.vars[annot] = value
113
+ return Ok
114
+
115
+ def solve_Generic(self, annot: types.Generic, value: types.TypeAttribute):
116
+ if not isinstance(value, types.Generic):
117
+ return ResolutionError(annot, value)
118
+
119
+ if not value.body.is_subseteq(annot.body):
120
+ return ResolutionError(annot.body, value.body)
121
+
122
+ for var, val in zip(annot.vars, value.vars):
123
+ result = self.solve(var, val)
124
+ if not result:
125
+ return result
126
+
127
+ if not annot.vararg:
128
+ return Ok
129
+
130
+ for val in value.vars[len(annot.vars) :]:
131
+ result = self.solve(annot.vararg.typ, val)
132
+ if not result:
133
+ return result
134
+ return Ok
135
+
136
+ def solve_Union(self, annot: types.Union, value: types.TypeAttribute):
137
+ for typ in annot.types:
138
+ result = self.solve(typ, value)
139
+ if result:
140
+ return Ok
141
+ return ResolutionError(annot, value)
kirin/decl/__init__.py ADDED
@@ -0,0 +1,108 @@
1
+ from typing import TypeVar, Callable
2
+
3
+ from typing_extensions import Unpack, dataclass_transform
4
+
5
+ from kirin.ir import Statement
6
+ from kirin.decl import info
7
+ from kirin.decl.base import StatementOptions
8
+ from kirin.decl.verify import Verify
9
+ from kirin.decl.emit.init import EmitInit
10
+ from kirin.decl.emit.name import EmitName
11
+ from kirin.decl.emit.repr import EmitRepr
12
+ from kirin.decl.emit.traits import EmitTraits
13
+ from kirin.decl.emit.verify import EmitVerify
14
+ from kirin.decl.scan_fields import ScanFields
15
+ from kirin.decl.emit.dialect import EmitDialect
16
+ from kirin.decl.emit.property import EmitProperty
17
+ from kirin.decl.emit.typecheck import EmitTypeCheck
18
+
19
+
20
+ class StatementDecl(
21
+ ScanFields,
22
+ Verify,
23
+ EmitInit,
24
+ EmitProperty,
25
+ EmitDialect,
26
+ EmitName,
27
+ EmitRepr,
28
+ EmitTraits,
29
+ EmitVerify,
30
+ EmitTypeCheck,
31
+ ):
32
+ pass
33
+
34
+
35
+ StmtType = TypeVar("StmtType", bound=Statement)
36
+
37
+
38
+ @dataclass_transform(
39
+ field_specifiers=(
40
+ info.attribute,
41
+ info.argument,
42
+ info.region,
43
+ info.result,
44
+ info.block,
45
+ )
46
+ )
47
+ def statement(
48
+ cls=None,
49
+ **kwargs: Unpack[StatementOptions],
50
+ ) -> Callable[[type[StmtType]], type[StmtType]]:
51
+ """Declare a new statement class.
52
+
53
+ This decorator is used to declare a new statement class. It is used to
54
+ generate the necessary boilerplate code for the class. The class should
55
+ inherit from `kirin.ir.Statement`.
56
+
57
+ Args:
58
+ init(bool): Whether to generate an `__init__` method.
59
+ repr(bool): Whether to generate a `__repr__` method.
60
+ kw_only(bool): Whether to use keyword-only arguments in the `__init__`
61
+ method.
62
+ dialect(Optional[Dialect]): The dialect of the statement.
63
+ property(bool): Whether to generate property methods for attributes.
64
+
65
+ Example:
66
+ The following is an example of how to use the `statement` decorator.
67
+
68
+ ```python
69
+ @statement
70
+ class MyStatement(ir.Statement):
71
+ name = "some_name"
72
+ traits = frozenset({TraitA(), TraitB()})
73
+ some_input: ir.SSAValue = info.argument()
74
+ some_output: ir.ResultValue = info.result()
75
+ body: ir.Region = info.region()
76
+ successor: ir.Block = info.block()
77
+ ```
78
+
79
+ If the `name` field is not specified, a lowercase name field will be auto generated.
80
+
81
+ In addition, one can optionally register the statement to a dialect
82
+ by providing the `dialect` argument to the decorator.
83
+
84
+ The following example register the statement to a dialect `my_dialect_object`, and `name = "myfoo"` field is autogenerated
85
+
86
+ ```python
87
+ @statement(dialect=my_dialect_object)
88
+ class MyFoo(ir.Statement):
89
+ traits = frozenset({ir.FromPythonCall()})
90
+ value: str = info.attribute()
91
+ ```
92
+ """
93
+
94
+ def wrap(cls):
95
+ decl = StatementDecl(cls, **kwargs)
96
+ decl.scan_fields()
97
+ decl.verify()
98
+ decl.emit()
99
+ decl.register()
100
+ return cls
101
+
102
+ if cls is None:
103
+ return wrap
104
+ return wrap(cls)
105
+
106
+
107
+ def fields(cls: type[Statement] | Statement) -> info.StatementFields:
108
+ return getattr(cls, ScanFields._FIELDS)
kirin/decl/base.py ADDED
@@ -0,0 +1,65 @@
1
+ import sys
2
+ import inspect
3
+ from typing import Any, TypedDict
4
+
5
+ from typing_extensions import Unpack, Optional
6
+
7
+ from kirin.ir import Dialect
8
+ from kirin.decl.info import StatementFields
9
+
10
+
11
+ class StatementOptions(TypedDict, total=False):
12
+ init: bool
13
+ repr: bool
14
+ kw_only: bool
15
+ dialect: Optional[Dialect]
16
+ property: bool
17
+
18
+
19
+ class BaseModifier:
20
+ _PARAMS = "__kirin_stmt_params"
21
+
22
+ def __init__(self, cls: type, **kwargs: Unpack[StatementOptions]) -> None:
23
+ self.cls = cls
24
+ self.cls_module = sys.modules.get(cls.__module__)
25
+
26
+ if "dialect" in kwargs:
27
+ self.dialect = kwargs["dialect"]
28
+ else:
29
+ self.dialect = None
30
+ self.params = kwargs
31
+ setattr(cls, self._PARAMS, self.params)
32
+
33
+ if cls.__module__ in sys.modules:
34
+ self.globals = sys.modules[cls.__module__].__dict__
35
+ else:
36
+ # Theoretically this can happen if someone writes
37
+ # a custom string to cls.__module__. In which case
38
+ # such dataclass won't be fully introspectable
39
+ # (w.r.t. typing.get_type_hints) but will still function
40
+ # correctly.
41
+ self.globals: dict[str, Any] = {}
42
+
43
+ # analysis state, used by scan_field, etc.
44
+ self.fields = StatementFields()
45
+ self.has_statement_bases = False
46
+ self.kw_only = self.params.get("kw_only", False)
47
+ self.KW_ONLY_seen = False
48
+
49
+ def register(self) -> None:
50
+ if self.dialect is None:
51
+ return
52
+ self.dialect.register(self.cls)
53
+
54
+ def emit(self):
55
+ self._self_name = "__kirin_stmt_self" if "self" in self.fields else "self"
56
+ self._class_name = "__kirin_stmt_cls" if "cls" in self.fields else "cls"
57
+ self._run_passes("emit_")
58
+
59
+ def verify(self):
60
+ self._run_passes("verify_")
61
+
62
+ def _run_passes(self, prefix: str):
63
+ for name, method in inspect.getmembers(self, inspect.ismethod):
64
+ if name.startswith(prefix):
65
+ method()
@@ -0,0 +1,2 @@
1
+ def camel2snake(name: str) -> str:
2
+ return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_")
File without changes
@@ -0,0 +1,29 @@
1
+ """This module provides a function to create a function dynamically.
2
+
3
+ Copied from `dataclasses._create_fn` in Python 3.10.13.
4
+ """
5
+
6
+ from dataclasses import MISSING
7
+
8
+
9
+ def create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING):
10
+ # Note that we may mutate locals. Callers beware!
11
+ # The only callers are internal to this module, so no
12
+ # worries about external callers.
13
+ if locals is None:
14
+ locals = {}
15
+ return_annotation = ""
16
+ if return_type is not MISSING:
17
+ locals["_return_type"] = return_type
18
+ return_annotation = "->_return_type"
19
+ args = ",".join(args)
20
+ body = "\n".join(f" {b}" for b in body)
21
+
22
+ # Compute the text of the entire function.
23
+ txt = f" def {name}({args}){return_annotation}:\n{body}"
24
+
25
+ local_vars = ", ".join(locals.keys())
26
+ txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
27
+ ns = {}
28
+ exec(txt, globals, ns)
29
+ return ns["__create_fn__"](**locals)
@@ -0,0 +1,22 @@
1
+ """Copied from dataclasses in Python 3.10.13.
2
+ """
3
+
4
+ from types import FunctionType
5
+
6
+
7
+ def set_qualname(cls: type, value):
8
+ # Ensure that the functions returned from _create_fn uses the proper
9
+ # __qualname__ (the class they belong to).
10
+ if isinstance(value, FunctionType):
11
+ value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
12
+ return value
13
+
14
+
15
+ def set_new_attribute(cls: type, name: str, value):
16
+ # Never overwrites an existing attribute. Returns True if the
17
+ # attribute already exists.
18
+ if name in cls.__dict__:
19
+ return True
20
+ set_qualname(cls, value)
21
+ setattr(cls, name, value)
22
+ return False
@@ -0,0 +1,8 @@
1
+ from kirin.decl.base import BaseModifier
2
+
3
+
4
+ class EmitDialect(BaseModifier):
5
+
6
+ def emit_dialect(self):
7
+ setattr(self.cls, "dialect", self.dialect)
8
+ return