kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,90 @@
|
|
1
|
+
from typing import TypeVar, final
|
2
|
+
|
3
|
+
from kirin import ir, types, interp
|
4
|
+
from kirin.decl import fields
|
5
|
+
from kirin.analysis import const
|
6
|
+
from kirin.interp.impl import Signature
|
7
|
+
from kirin.analysis.forward import Forward, ForwardFrame
|
8
|
+
|
9
|
+
from .solve import TypeResolution
|
10
|
+
|
11
|
+
|
12
|
+
@final
class TypeInference(Forward[types.TypeAttribute]):
    """Type inference analysis for kirin.

    This analysis uses the forward dataflow analysis framework to infer the
    types of the IR. Unlike the concrete interpreter, method dispatch is driven
    by the inferred types (the abstract values) of a statement's arguments
    rather than by the type information stored in the IR (see
    `build_signature`).

    Statements that are not handled otherwise (presumably those without a
    registered implementation) fall back to a type resolution algorithm that
    solves the statement's declared field types against the inferred argument
    types (see `eval_stmt_fallback`).
    """

    keys = ["typeinfer"]
    lattice = types.TypeAttribute

    def run_analysis(
        self, method: ir.Method, args: tuple[types.TypeAttribute, ...] | None = None
    ) -> tuple[ForwardFrame[types.TypeAttribute], types.TypeAttribute]:
        """Run type inference on `method`.

        Args:
            method: The method to analyze.
            args: Types for the method's arguments. When `None`, the method's
                declared `arg_types` are used.

        Returns:
            The result of `Forward.run_analysis`: the final frame and the
            inferred result type.
        """
        if args is None:
            args = method.arg_types
        return super().run_analysis(method, args)

    def build_signature(
        self, frame: ForwardFrame[types.TypeAttribute], stmt: ir.Statement
    ) -> Signature:
        """Build the dispatch signature of `stmt` from its inferred argument types.

        NOTE: unlike the concrete interpreter, which uses the type information
        within the IR, type inference uses the interpreted value (which is a
        type) to determine the method dispatch.
        """
        # TODO: remove the Generic unwrapping after we have multiple dispatch...
        _args = tuple(
            x.body if isinstance(x, types.Generic) else x
            for x in frame.get_values(stmt.args)
        )
        return Signature(stmt.__class__, _args)

    def eval_stmt_fallback(
        self, frame: ForwardFrame[types.TypeAttribute], stmt: ir.Statement
    ) -> tuple[types.TypeAttribute, ...] | interp.SpecialValue[types.TypeAttribute]:
        """Infer the result types of `stmt` from its declared field types.

        Each declared argument type is solved against the type currently
        inferred for that argument, which binds type variables. Every argument's
        inferred type is then refined by meeting it with its (substituted)
        declared type. The declared result types, with type variables
        substituted, are returned.
        """
        resolve = TypeResolution()
        fs = fields(stmt)
        # NOTE(review): argument fields are paired with `stmt.args` positionally,
        # which assumes one SSA value per argument field — confirm for statements
        # with variadic argument fields.
        # NOTE: the outcome of each `solve` call is not checked; a mismatch does
        # not stop inference here.
        for f, value in zip(fs.args.values(), frame.get_values(stmt.args)):
            resolve.solve(f.type, value)

        for arg, f in zip(stmt.args, fs.args.values()):
            frame.set(arg, frame.get(arg).meet(resolve.substitute(f.type)))
        return tuple(resolve.substitute(result.type) for result in stmt.results)

    def run_method(
        self, method: ir.Method, args: tuple[types.TypeAttribute, ...]
    ) -> tuple[ForwardFrame[types.TypeAttribute], types.TypeAttribute]:
        """Run `method.code` with `method.self_type` prepended to `args`."""
        return self.run_callable(method.code, (method.self_type,) + args)

    T = TypeVar("T")

    @classmethod
    def maybe_const(cls, value: ir.SSAValue, type_: type[T]) -> T | None:
        """Get a constant value of a given type.

        If the value is not a constant or the constant is not of the given type, return
        `None`.
        """
        hint = value.hints.get("const")
        if isinstance(hint, const.Value) and isinstance(hint.data, type_):
            return hint.data
        return None

    @classmethod
    def expect_const(cls, value: ir.SSAValue, type_: type[T]) -> T:
        """Expect a constant value of a given type.

        If the value is not a constant or the constant is not of the given type, raise
        an `InterpreterError`.
        """
        data = cls.maybe_const(value, type_)
        if data is None:
            # Report the const hint actually found; `data` is always None here.
            found = value.hints.get("const")
            raise interp.InterpreterError(f"expected {type_}, got {found}")
        return data
|
@@ -0,0 +1,141 @@
|
|
1
|
+
"""Type resolution for type inference.
|
2
|
+
|
3
|
+
This module contains the type resolution algorithm for type inference.
|
4
|
+
A simple algorithm is used: the expected (annotated) types are matched against
the actual types, binding type variables along the way, and the resulting
bindings can then be substituted back into other types.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from dataclasses import field, dataclass
|
9
|
+
|
10
|
+
from kirin import types
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class TypeResolutionResult:
    """Base class for type resolution results."""


@dataclass
class ResolutionOk(TypeResolutionResult):
    """Type resolution result for successful resolution."""

    def __bool__(self):
        # Success is truthy, so callers can simply write `if result: ...`.
        return True


# Shared success value returned by the solver.
Ok = ResolutionOk()
|
29
|
+
|
30
|
+
|
31
|
+
@dataclass
class ResolutionError(TypeResolutionResult):
    """Failed type resolution, recording the pair of types that did not match."""

    expr: types.TypeAttribute  # the expected (annotated) type
    value: types.TypeAttribute  # the actual type that failed to match it

    def __bool__(self):
        # Failure is falsy, so callers can write `if not result: return result`.
        return False

    def __str__(self):
        return "expected {}, got {}".format(self.expr, self.value)
|
43
|
+
|
44
|
+
|
45
|
+
@dataclass
class TypeResolution:
    """Type resolution algorithm for type inference.

    Calling `solve` on (annotation, value) pairs accumulates type-variable
    bindings in `vars`; `substitute` then rewrites a type using those
    bindings.
    """

    vars: dict[types.TypeVar, types.TypeAttribute] = field(default_factory=dict)

    def substitute(self, typ: types.TypeAttribute) -> types.TypeAttribute:
        """Replace bound type variables in `typ` with their bound types.

        Type variables without a binding are left as they are. Generic
        parameters and union members are rewritten recursively; any other
        type is returned unchanged.

        Args:
            typ: The type to rewrite.

        Returns:
            The rewritten type.
        """
        if isinstance(typ, types.TypeVar):
            return self.vars.get(typ, typ)
        if isinstance(typ, types.Generic):
            # NOTE(review): only `typ.vars` are carried over; a `vararg` on `typ`
            # is not passed to the new Generic — confirm this is intended.
            new_vars = [self.substitute(v) for v in typ.vars]
            return types.Generic(typ.body, *new_vars)
        if isinstance(typ, types.Union):
            return types.Union(self.substitute(member) for member in typ.types)
        return typ

    def solve(
        self, annot: types.TypeAttribute, value: types.TypeAttribute
    ) -> TypeResolutionResult:
        """Match the expected type `annot` against the actual type `value`.

        Type variables, generics and unions are dispatched to the dedicated
        `solve_*` methods; any other annotation is checked by subset.

        Args:
            annot: The expected type.
            value: The actual type.

        Returns:
            `Ok` on success, otherwise a `ResolutionError`.
        """
        if isinstance(annot, types.TypeVar):
            return self.solve_TypeVar(annot, value)
        if isinstance(annot, types.Generic):
            return self.solve_Generic(annot, value)
        if isinstance(annot, types.Union):
            return self.solve_Union(annot, value)
        # NOTE(review): this tests `annot <= value`, the opposite direction from
        # the `value.body <= annot.body` check in `solve_Generic` — confirm
        # which direction is intended.
        return Ok if annot.is_subseteq(value) else ResolutionError(annot, value)

    def solve_TypeVar(self, annot: types.TypeVar, value: types.TypeAttribute):
        """Bind or check the type variable `annot` against `value`."""
        if annot not in self.vars:
            self.vars[annot] = value
            return Ok
        current = self.vars[annot]
        if value.is_subseteq(current):
            # the new value is at least as specific: replace the binding with it
            self.vars[annot] = value
        elif not current.is_subseteq(value):
            # neither type contains the other
            return ResolutionError(annot, value)
        return Ok

    def solve_Generic(self, annot: types.Generic, value: types.TypeAttribute):
        """Match a generic annotation against `value` parameter by parameter."""
        if not isinstance(value, types.Generic):
            return ResolutionError(annot, value)
        if not value.body.is_subseteq(annot.body):
            return ResolutionError(annot.body, value.body)

        # NOTE(review): zip stops at the shorter sequence, so a mismatch in the
        # number of type parameters is not reported here.
        for expected, actual in zip(annot.vars, value.vars):
            outcome = self.solve(expected, actual)
            if not outcome:
                return outcome

        if annot.vararg:
            # extra parameters of `value` are matched against the vararg type
            for actual in value.vars[len(annot.vars) :]:
                outcome = self.solve(annot.vararg.typ, actual)
                if not outcome:
                    return outcome
        return Ok

    def solve_Union(self, annot: types.Union, value: types.TypeAttribute):
        """Succeed as soon as any member of the union accepts `value`."""
        # Members are tried in order; bindings made by earlier, failed attempts
        # are kept.
        if any(self.solve(member, value) for member in annot.types):
            return Ok
        return ResolutionError(annot, value)
|
kirin/decl/__init__.py
ADDED
@@ -0,0 +1,108 @@
|
|
1
|
+
from typing import TypeVar, Callable
|
2
|
+
|
3
|
+
from typing_extensions import Unpack, dataclass_transform
|
4
|
+
|
5
|
+
from kirin.ir import Statement
|
6
|
+
from kirin.decl import info
|
7
|
+
from kirin.decl.base import StatementOptions
|
8
|
+
from kirin.decl.verify import Verify
|
9
|
+
from kirin.decl.emit.init import EmitInit
|
10
|
+
from kirin.decl.emit.name import EmitName
|
11
|
+
from kirin.decl.emit.repr import EmitRepr
|
12
|
+
from kirin.decl.emit.traits import EmitTraits
|
13
|
+
from kirin.decl.emit.verify import EmitVerify
|
14
|
+
from kirin.decl.scan_fields import ScanFields
|
15
|
+
from kirin.decl.emit.dialect import EmitDialect
|
16
|
+
from kirin.decl.emit.property import EmitProperty
|
17
|
+
from kirin.decl.emit.typecheck import EmitTypeCheck
|
18
|
+
|
19
|
+
|
20
|
+
class StatementDecl(
    ScanFields,
    Verify,
    EmitInit,
    EmitProperty,
    EmitDialect,
    EmitName,
    EmitRepr,
    EmitTraits,
    EmitVerify,
    EmitTypeCheck,
):
    """All statement-declaration passes combined into a single class.

    Field scanning (`ScanFields`), verification (`Verify`) and the `Emit*`
    code-generation mixins are merged through inheritance. The order of the
    bases defines the MRO and should be kept as is.
    """
|
33
|
+
|
34
|
+
|
35
|
+
# Type of the statement class being decorated.
StmtType = TypeVar("StmtType", bound=Statement)


@dataclass_transform(
    field_specifiers=(
        info.attribute,
        info.argument,
        info.region,
        info.result,
        info.block,
    )
)
def statement(
    cls=None,
    **kwargs: Unpack[StatementOptions],
) -> Callable[[type[StmtType]], type[StmtType]]:
    """Declare a new statement class.

    Decorating a subclass of `kirin.ir.Statement` generates its boilerplate.
    Its fields are scanned and verified, the requested methods are emitted, and
    the class is registered with `dialect` when one is given. The decorator
    works both bare (`@statement`) and with options (`@statement(...)`).

    Args:
        init(bool): Whether to generate an `__init__` method.
        repr(bool): Whether to generate a `__repr__` method.
        kw_only(bool): Whether to use keyword-only arguments in the `__init__`
            method.
        dialect(Optional[Dialect]): The dialect to register the statement with.
        property(bool): Whether to generate property methods for attributes.

    Example:
        The following is an example of how to use the `statement` decorator.

        ```python
        @statement
        class MyStatement(ir.Statement):
            name = "some_name"
            traits = frozenset({TraitA(), TraitB()})
            some_input: ir.SSAValue = info.argument()
            some_output: ir.ResultValue = info.result()
            body: ir.Region = info.region()
            successor: ir.Block = info.block()
        ```

        If the `name` field is not specified, a lowercase name field will be
        auto generated.

        In addition, one can optionally register the statement to a dialect
        by providing the `dialect` argument to the decorator.

        The following example registers the statement to a dialect
        `my_dialect_object`, and the `name = "myfoo"` field is autogenerated.

        ```python
        @statement(dialect=my_dialect_object)
        class MyFoo(ir.Statement):
            traits = frozenset({ir.FromPythonCall()})
            value: str = info.attribute()
        ```
    """

    def decorate(target):
        declaration = StatementDecl(target, **kwargs)
        # passes run in a fixed order: scan, verify, emit, register
        declaration.scan_fields()
        declaration.verify()
        declaration.emit()
        declaration.register()
        return target

    return decorate if cls is None else decorate(cls)
|
105
|
+
|
106
|
+
|
107
|
+
def fields(cls: type[Statement] | Statement) -> info.StatementFields:
    """Return the `StatementFields` stored on a statement class or instance.

    Reads the attribute named by `ScanFields._FIELDS` (presumably set while
    `ScanFields` scans the class). An `AttributeError` is raised if it is absent.
    """
    attr_name = ScanFields._FIELDS
    return getattr(cls, attr_name)
|
kirin/decl/base.py
ADDED
@@ -0,0 +1,65 @@
|
|
1
|
+
import sys
|
2
|
+
import inspect
|
3
|
+
from typing import Any, TypedDict
|
4
|
+
|
5
|
+
from typing_extensions import Unpack, Optional
|
6
|
+
|
7
|
+
from kirin.ir import Dialect
|
8
|
+
from kirin.decl.info import StatementFields
|
9
|
+
|
10
|
+
|
11
|
+
class StatementOptions(TypedDict, total=False):
    """Keyword options accepted by the `statement` decorator; all optional."""

    init: bool  # whether to generate `__init__`
    repr: bool  # whether to generate `__repr__`
    kw_only: bool  # whether `__init__` takes keyword-only arguments
    dialect: Optional[Dialect]  # dialect to register the statement with
    property: bool  # whether to generate property methods for attributes
|
17
|
+
|
18
|
+
|
19
|
+
class BaseModifier:
    """Shared state and pass driver for declaring a statement class.

    The constructor records the target class, its declaration options and the
    defining module's globals, and initializes the analysis state used by the
    passes. Mixins contribute passes as methods whose names start with
    `verify_` or `emit_`.
    """

    # Name of the class attribute under which the options are stored on `cls`.
    _PARAMS = "__kirin_stmt_params"

    def __init__(self, cls: type, **kwargs: Unpack[StatementOptions]) -> None:
        self.cls = cls
        # Resolve the defining module once. It can be missing (someone wrote a
        # custom string to `cls.__module__`) or mapped to `None` in
        # `sys.modules`; both cases fall back to empty globals below.
        module = sys.modules.get(cls.__module__)
        self.cls_module = module

        self.dialect = kwargs.get("dialect")
        self.params = kwargs
        setattr(cls, self._PARAMS, self.params)

        # Without a module, such a class won't be fully introspectable
        # (w.r.t. typing.get_type_hints) but will still function correctly.
        self.globals: dict[str, Any] = module.__dict__ if module is not None else {}

        # analysis state, used by scan_fields, etc.
        self.fields = StatementFields()
        self.has_statement_bases = False
        self.kw_only = self.params.get("kw_only", False)
        self.KW_ONLY_seen = False

    def register(self) -> None:
        """Register `cls` with the configured dialect; do nothing without one."""
        if self.dialect is None:
            return
        self.dialect.register(self.cls)

    def emit(self):
        """Run every `emit_*` pass.

        Before the passes run, `_self_name` and `_class_name` are set. Each
        switches to a mangled name when a field is itself called `self` or
        `cls`; these are presumably the parameter names used in generated code.
        """
        self._self_name = "__kirin_stmt_self" if "self" in self.fields else "self"
        self._class_name = "__kirin_stmt_cls" if "cls" in self.fields else "cls"
        self._run_passes("emit_")

    def verify(self):
        """Run every `verify_*` pass."""
        self._run_passes("verify_")

    def _run_passes(self, prefix: str):
        """Call each bound method of `self` whose name starts with `prefix`.

        `inspect.getmembers` returns members sorted by name, so the passes run
        in alphabetical order.
        """
        for name, method in inspect.getmembers(self, inspect.ismethod):
            if name.startswith(prefix):
                method()
|
File without changes
|
@@ -0,0 +1,29 @@
|
|
1
|
+
"""This module provides a function to create a function dynamically.
|
2
|
+
|
3
|
+
Copied from `dataclasses._create_fn` in Python 3.10.13.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from dataclasses import MISSING
|
7
|
+
|
8
|
+
|
9
|
+
def create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING):
    """Compile and return a new function named `name` from source fragments.

    Args:
        name: Name of the generated function.
        args: Parameter source strings, joined with ",".
        body: Body source lines; each is indented one level below the `def`.
        globals: Globals dict for the generated code, passed to `exec`.
        locals: Names made visible to the function as closure variables.
            The mapping is copied, never mutated.
        return_type: Optional return annotation; `MISSING` means none.

    Returns:
        The newly created function object.
    """
    # Copy so that adding `_return_type` below never mutates the caller's dict.
    locals = {} if locals is None else dict(locals)
    return_annotation = ""
    if return_type is not MISSING:
        locals["_return_type"] = return_type
        return_annotation = "->_return_type"
    args = ",".join(args)
    # Body lines must be indented deeper than the nested " def" line below
    # (one space), hence two spaces here.
    body = "\n".join(f"  {b}" for b in body)

    # Compute the text of the entire function. It is nested in a factory whose
    # parameters are the `locals` names, so they become closure variables.
    txt = f" def {name}({args}){return_annotation}:\n{body}"

    local_vars = ", ".join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
    ns = {}
    exec(txt, globals, ns)
    return ns["__create_fn__"](**locals)
|
@@ -0,0 +1,22 @@
|
|
1
|
+
"""Copied from dataclasses in Python 3.10.13.
|
2
|
+
"""
|
3
|
+
|
4
|
+
from types import FunctionType
|
5
|
+
|
6
|
+
|
7
|
+
def set_qualname(cls: type, value):
    """Give a dynamically created function the qualified name of a `cls` member.

    Values that are not plain Python functions are returned untouched.
    """
    if not isinstance(value, FunctionType):
        return value
    value.__qualname__ = "{}.{}".format(cls.__qualname__, value.__name__)
    return value


def set_new_attribute(cls: type, name: str, value):
    """Attach `value` to `cls` as `name` unless `cls` itself already defines it.

    Only `cls.__dict__` is consulted, so inherited attributes may be shadowed.
    Returns True when an existing attribute was left in place, and False when
    `value` was installed.
    """
    already_defined = name in cls.__dict__
    if not already_defined:
        setattr(cls, name, set_qualname(cls, value))
    return already_defined
|