bloqade-circuit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. See the package's registry page for more details.

Files changed (153):
  1. bloqade/analysis/__init__.py +0 -0
  2. bloqade/analysis/address/__init__.py +11 -0
  3. bloqade/analysis/address/analysis.py +60 -0
  4. bloqade/analysis/address/impls.py +228 -0
  5. bloqade/analysis/address/lattice.py +85 -0
  6. bloqade/noise/__init__.py +1 -0
  7. bloqade/noise/native/__init__.py +20 -0
  8. bloqade/noise/native/_dialect.py +3 -0
  9. bloqade/noise/native/_wrappers.py +34 -0
  10. bloqade/noise/native/model.py +347 -0
  11. bloqade/noise/native/rewrite.py +35 -0
  12. bloqade/noise/native/stmts.py +46 -0
  13. bloqade/pyqrack/__init__.py +18 -0
  14. bloqade/pyqrack/base.py +131 -0
  15. bloqade/pyqrack/noise/__init__.py +0 -0
  16. bloqade/pyqrack/noise/native.py +100 -0
  17. bloqade/pyqrack/qasm2/__init__.py +0 -0
  18. bloqade/pyqrack/qasm2/core.py +79 -0
  19. bloqade/pyqrack/qasm2/parallel.py +46 -0
  20. bloqade/pyqrack/qasm2/uop.py +247 -0
  21. bloqade/pyqrack/reg.py +109 -0
  22. bloqade/pyqrack/target.py +112 -0
  23. bloqade/qasm2/__init__.py +19 -0
  24. bloqade/qasm2/_wrappers.py +674 -0
  25. bloqade/qasm2/dialects/__init__.py +10 -0
  26. bloqade/qasm2/dialects/core/__init__.py +3 -0
  27. bloqade/qasm2/dialects/core/_dialect.py +3 -0
  28. bloqade/qasm2/dialects/core/_emit.py +68 -0
  29. bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
  30. bloqade/qasm2/dialects/core/address.py +38 -0
  31. bloqade/qasm2/dialects/core/stmts.py +94 -0
  32. bloqade/qasm2/dialects/expr/__init__.py +3 -0
  33. bloqade/qasm2/dialects/expr/_dialect.py +3 -0
  34. bloqade/qasm2/dialects/expr/_emit.py +103 -0
  35. bloqade/qasm2/dialects/expr/_from_python.py +86 -0
  36. bloqade/qasm2/dialects/expr/_interp.py +75 -0
  37. bloqade/qasm2/dialects/expr/stmts.py +262 -0
  38. bloqade/qasm2/dialects/glob.py +45 -0
  39. bloqade/qasm2/dialects/indexing.py +64 -0
  40. bloqade/qasm2/dialects/inline.py +76 -0
  41. bloqade/qasm2/dialects/noise.py +16 -0
  42. bloqade/qasm2/dialects/parallel.py +110 -0
  43. bloqade/qasm2/dialects/uop/__init__.py +4 -0
  44. bloqade/qasm2/dialects/uop/_dialect.py +3 -0
  45. bloqade/qasm2/dialects/uop/_emit.py +211 -0
  46. bloqade/qasm2/dialects/uop/schedule.py +89 -0
  47. bloqade/qasm2/dialects/uop/stmts.py +325 -0
  48. bloqade/qasm2/emit/__init__.py +1 -0
  49. bloqade/qasm2/emit/base.py +72 -0
  50. bloqade/qasm2/emit/gate.py +102 -0
  51. bloqade/qasm2/emit/main.py +106 -0
  52. bloqade/qasm2/emit/target.py +165 -0
  53. bloqade/qasm2/glob.py +24 -0
  54. bloqade/qasm2/groups.py +120 -0
  55. bloqade/qasm2/parallel.py +48 -0
  56. bloqade/qasm2/parse/__init__.py +37 -0
  57. bloqade/qasm2/parse/ast.py +235 -0
  58. bloqade/qasm2/parse/build.py +289 -0
  59. bloqade/qasm2/parse/lowering.py +553 -0
  60. bloqade/qasm2/parse/parser.py +5 -0
  61. bloqade/qasm2/parse/print.py +293 -0
  62. bloqade/qasm2/parse/qasm2.lark +75 -0
  63. bloqade/qasm2/parse/visitor.py +16 -0
  64. bloqade/qasm2/parse/visitor.pyi +39 -0
  65. bloqade/qasm2/passes/__init__.py +5 -0
  66. bloqade/qasm2/passes/fold.py +94 -0
  67. bloqade/qasm2/passes/glob.py +119 -0
  68. bloqade/qasm2/passes/noise.py +61 -0
  69. bloqade/qasm2/passes/parallel.py +176 -0
  70. bloqade/qasm2/passes/py2qasm.py +63 -0
  71. bloqade/qasm2/passes/qasm2py.py +61 -0
  72. bloqade/qasm2/rewrite/__init__.py +12 -0
  73. bloqade/qasm2/rewrite/desugar.py +28 -0
  74. bloqade/qasm2/rewrite/glob.py +103 -0
  75. bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
  76. bloqade/qasm2/rewrite/native_gates.py +447 -0
  77. bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
  78. bloqade/qasm2/rewrite/register.py +45 -0
  79. bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
  80. bloqade/qasm2/types.py +39 -0
  81. bloqade/qbraid/__init__.py +2 -0
  82. bloqade/qbraid/lowering.py +324 -0
  83. bloqade/qbraid/schema.py +252 -0
  84. bloqade/qbraid/simulation_result.py +99 -0
  85. bloqade/qbraid/target.py +86 -0
  86. bloqade/squin/__init__.py +2 -0
  87. bloqade/squin/analysis/__init__.py +0 -0
  88. bloqade/squin/analysis/nsites/__init__.py +8 -0
  89. bloqade/squin/analysis/nsites/analysis.py +52 -0
  90. bloqade/squin/analysis/nsites/impls.py +69 -0
  91. bloqade/squin/analysis/nsites/lattice.py +49 -0
  92. bloqade/squin/analysis/schedule.py +244 -0
  93. bloqade/squin/groups.py +38 -0
  94. bloqade/squin/op/__init__.py +132 -0
  95. bloqade/squin/op/_dialect.py +3 -0
  96. bloqade/squin/op/complex.py +6 -0
  97. bloqade/squin/op/stmts.py +220 -0
  98. bloqade/squin/op/traits.py +43 -0
  99. bloqade/squin/op/types.py +10 -0
  100. bloqade/squin/qubit.py +118 -0
  101. bloqade/squin/wire.py +103 -0
  102. bloqade/stim/__init__.py +6 -0
  103. bloqade/stim/_wrappers.py +186 -0
  104. bloqade/stim/dialects/__init__.py +5 -0
  105. bloqade/stim/dialects/aux/__init__.py +11 -0
  106. bloqade/stim/dialects/aux/_dialect.py +3 -0
  107. bloqade/stim/dialects/aux/emit.py +102 -0
  108. bloqade/stim/dialects/aux/interp.py +39 -0
  109. bloqade/stim/dialects/aux/lowering.py +40 -0
  110. bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
  111. bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
  112. bloqade/stim/dialects/aux/stmts/const.py +95 -0
  113. bloqade/stim/dialects/aux/types.py +19 -0
  114. bloqade/stim/dialects/collapse/__init__.py +3 -0
  115. bloqade/stim/dialects/collapse/_dialect.py +3 -0
  116. bloqade/stim/dialects/collapse/emit.py +68 -0
  117. bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
  118. bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
  119. bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
  120. bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
  121. bloqade/stim/dialects/gate/__init__.py +3 -0
  122. bloqade/stim/dialects/gate/_dialect.py +3 -0
  123. bloqade/stim/dialects/gate/emit.py +87 -0
  124. bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
  125. bloqade/stim/dialects/gate/stmts/base.py +31 -0
  126. bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
  127. bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
  128. bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
  129. bloqade/stim/dialects/gate/stmts/pp.py +15 -0
  130. bloqade/stim/dialects/noise/__init__.py +3 -0
  131. bloqade/stim/dialects/noise/_dialect.py +3 -0
  132. bloqade/stim/dialects/noise/emit.py +66 -0
  133. bloqade/stim/dialects/noise/stmts.py +77 -0
  134. bloqade/stim/emit/__init__.py +1 -0
  135. bloqade/stim/emit/stim.py +54 -0
  136. bloqade/stim/groups.py +26 -0
  137. bloqade/test_utils.py +35 -0
  138. bloqade/types.py +24 -0
  139. bloqade/visual/__init__.py +1 -0
  140. bloqade/visual/animation/__init__.py +0 -0
  141. bloqade/visual/animation/animate.py +267 -0
  142. bloqade/visual/animation/base.py +346 -0
  143. bloqade/visual/animation/gate_event.py +24 -0
  144. bloqade/visual/animation/runtime/__init__.py +0 -0
  145. bloqade/visual/animation/runtime/aod.py +36 -0
  146. bloqade/visual/animation/runtime/atoms.py +55 -0
  147. bloqade/visual/animation/runtime/ppoly.py +50 -0
  148. bloqade/visual/animation/runtime/qpustate.py +119 -0
  149. bloqade/visual/animation/runtime/utils.py +43 -0
  150. bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
  151. bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
  152. bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
  153. bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,61 @@
1
+ from dataclasses import field, dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass
5
+ from kirin.rewrite import (
6
+ Walk,
7
+ Chain,
8
+ Fixpoint,
9
+ ConstantFold,
10
+ DeadCodeElimination,
11
+ CommonSubexpressionElimination,
12
+ )
13
+ from kirin.rewrite.result import RewriteResult
14
+
15
+ from bloqade.noise import native
16
+ from bloqade.analysis import address
17
+ from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
18
+
19
+
20
@dataclass
class NoisePass(Pass):
    """Apply a noise model to a quantum circuit.

    The method is first analysed with ``address.AddressAnalysis``. Its body is
    then walked with a ``NoiseRewriteRule`` configured with the analysis
    entries, ``noise_model`` and ``gate_noise_params``. Finally, constant
    folding, dead-code elimination and common-subexpression elimination are
    applied until a fixpoint is reached.

    NOTE: This pass is not guaranteed to be supported long-term in bloqade. We will be
    moving towards a more general approach to noise modeling in the future.

    Attributes:
        noise_model: move-noise model. Defaults to a freshly constructed
            ``native.TwoRowZoneModel``.
        gate_noise_params: gate noise parameters. Defaults to a freshly
            constructed ``native.GateNoiseParams``.
        address_analysis: address analysis created from ``self.dialects`` in
            ``__post_init__``; it is not an ``__init__`` argument.
    """

    noise_model: native.MoveNoiseModelABC = field(
        default_factory=native.TwoRowZoneModel
    )
    gate_noise_params: native.GateNoiseParams = field(
        default_factory=native.GateNoiseParams
    )
    address_analysis: address.AddressAnalysis = field(init=False)

    def __post_init__(self):
        self.address_analysis = address.AddressAnalysis(self.dialects)

    def unsafe_run(self, mt: ir.Method):
        """Insert noise into ``mt`` and clean up the result; returns the joined RewriteResult."""
        baseline = RewriteResult()

        # no_raise=False presumably makes analysis failures raise instead of
        # being swallowed -- NOTE(review): confirm against kirin's run_analysis.
        frame, _ = self.address_analysis.run_analysis(mt, no_raise=False)

        noise_rule = NoiseRewriteRule(
            address_analysis=frame.entries,
            noise_model=self.noise_model,
            gate_noise_params=self.gate_noise_params,
        )
        noise_result = Walk(noise_rule).rewrite(mt.code).join(baseline)

        cleanup = Chain(
            ConstantFold(),
            DeadCodeElimination(),
            CommonSubexpressionElimination(),
        )
        return Fixpoint(Walk(cleanup)).rewrite(mt.code).join(noise_result)
@@ -0,0 +1,176 @@
1
+ """
2
+ Passes for converting parallel gates into multiple single gates as well as
3
+ converting multiple single gates to parallel gates.
4
+ """
5
+
6
+ from typing import Type
7
+ from dataclasses import field, dataclass
8
+
9
+ from kirin import ir
10
+ from kirin.passes import Pass
11
+ from kirin.rewrite import (
12
+ Walk,
13
+ Chain,
14
+ Fixpoint,
15
+ WrapConst,
16
+ ConstantFold,
17
+ DeadCodeElimination,
18
+ CommonSubexpressionElimination,
19
+ result,
20
+ )
21
+ from kirin.analysis import const
22
+
23
+ from bloqade.analysis import address
24
+ from bloqade.qasm2.rewrite import (
25
+ MergePolicyABC,
26
+ ParallelToUOpRule,
27
+ RaiseRegisterRule,
28
+ UOpToParallelRule,
29
+ SimpleOptimalMergePolicy,
30
+ )
31
+ from bloqade.squin.analysis import schedule
32
+
33
+
34
@dataclass
class ParallelToUOp(Pass):
    """Pass to convert parallel gates into single gates.

    Every parallel gate from the `qasm2.parallel` dialect is rewritten into a
    sequence of single gates from the `qasm2.uop` dialect. This brings the
    program closer to standard QASM2 syntax. Afterwards, constant folding,
    dead-code elimination and common-subexpression elimination are applied
    until a fixpoint is reached.

    ## Usage Examples
    ```
    # Define kernel
    @qasm2.extended
    def main():
        q = qasm2.qreg(4)

        qasm2.parallel.cz(ctrls=[q[0], q[2]], qargs=[q[1], q[3]])

    # Run rewrite
    ParallelToUOp(main.dialects)(main)
    ```

    The `qasm2.parallel.cz` statement has been rewritten to individual gates:

    ```
    qasm2.uop.cz(ctrl=q[0], qarg=q[1])
    qasm2.uop.cz(ctrl=q[2], qarg=q[3])
    ```

    """

    def generate_rule(self, mt: ir.Method) -> ParallelToUOpRule:
        """Build the ``ParallelToUOpRule`` for ``mt`` from its address analysis.

        Each qubit id found by the analysis is mapped to the first SSA value
        (in ``frame.entries`` iteration order) whose address is that qubit.
        """
        frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)

        first_ssa_by_qubit = {}
        for ssa, addr in frame.entries.items():
            # Only qubit addresses matter. setdefault keeps the first SSA
            # value seen for each qubit id and ignores later ones.
            if isinstance(addr, address.AddressQubit):
                first_ssa_by_qubit.setdefault(addr.data, ssa)

        return ParallelToUOpRule(
            id_map=first_ssa_by_qubit, address_analysis=frame.entries
        )

    def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
        """Split parallel gates in ``mt``, then simplify the body to a fixpoint."""
        # The local is not named ``result``, which would shadow the module.
        split_result = Walk(self.generate_rule(mt)).rewrite(mt.code)

        cleanup = Chain(
            ConstantFold(),
            DeadCodeElimination(),
            CommonSubexpressionElimination(),
        )
        return Fixpoint(Walk(cleanup)).rewrite(mt.code).join(split_result)
95
+
96
+
97
@dataclass
class UOpToParallel(Pass):
    """Pass to convert single gates into parallel gates.

    This pass looks for single gates from the `qasm2.uop` dialect that can be combined
    into parallel gates from the `qasm2.parallel` dialect and performs a rewrite to do so.

    The pass runs in four stages:

    1. ``RaiseRegisterRule`` is walked over the body. If it reports no change,
       the pass returns immediately and no parallelization is attempted.
    2. Constant propagation runs, and its results are attached with ``WrapConst``.
    3. Address analysis feeds ``schedule.DagScheduleAnalysis``, which yields one
       dependency DAG per block. A ``merge_policy_type`` instance is built from
       each DAG, and ``UOpToParallelRule`` performs the merge.
    4. Constant folding, dead-code elimination and common-subexpression
       elimination are applied until a fixpoint is reached.

    ## Usage Examples
    ```
    # Define kernel
    @qasm2.main
    def test():
        q = qasm2.qreg(4)

        theta = 0.1
        phi = 0.2
        lam = 0.3

        qasm2.u(q[1], theta, phi, lam)
        qasm2.u(q[3], theta, phi, lam)
        qasm2.cx(q[1], q[3])
        qasm2.u(q[2], theta, phi, lam)
        qasm2.u(q[0], theta, phi, lam)
        qasm2.cx(q[0], q[2])

    # Run rewrite
    UOpToParallel(test.dialects)(test)
    ```

    The individual `qasm2.u` statements have now been combined
    into a single `qasm2.parallel.u` statement.

    ```
    qasm2.parallel.u(qargs = [q[0], q[1], q[2], q[3]], theta, phi, lam)
    qasm2.uop.CX(q[1], q[3])
    qasm2.uop.CX(q[0], q[2])
    ```

    Attributes:
        merge_policy_type: merge-policy class whose ``from_analysis`` builds
            the per-block policy. Defaults to ``SimpleOptimalMergePolicy``.
        constprop: constant-propagation analysis created from
            ``self.dialects`` in ``__post_init__``; it is not an ``__init__``
            argument.
    """

    merge_policy_type: Type[MergePolicyABC] = SimpleOptimalMergePolicy
    constprop: const.Propagate = field(init=False)

    def __post_init__(self):
        self.constprop = const.Propagate(self.dialects)

    def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
        """Merge compatible single gates in ``mt`` into parallel gates.

        Returns the joined RewriteResult of every stage that ran.
        """
        # The accumulator is deliberately not named ``result``: that name is
        # the imported ``kirin.rewrite.result`` module used in the annotation.
        rewrite_result = Walk(RaiseRegisterRule()).rewrite(mt.code)

        # do not run the parallelization because registers are not at the top
        # NOTE(review): this also skips the pass whenever RaiseRegisterRule makes
        # no change -- confirm the rule reports a change in every case where
        # parallelization is wanted (e.g. registers already at the top).
        if not rewrite_result.has_done_something:
            return rewrite_result

        const_frame, _ = self.constprop.run_analysis(mt)
        rewrite_result = (
            Walk(WrapConst(const_frame)).rewrite(mt.code).join(rewrite_result)
        )

        addr_frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
        dags = schedule.DagScheduleAnalysis(
            mt.dialects, address_analysis=addr_frame.entries
        ).get_dags(mt)

        # One merge policy per block, each built from that block's DAG.
        merge_policies = {
            block: self.merge_policy_type.from_analysis(dag, addr_frame.entries)
            for block, dag in dags.items()
        }
        rewrite_result = (
            Walk(UOpToParallelRule(merge_policies))
            .rewrite(mt.code)
            .join(rewrite_result)
        )

        cleanup = Chain(
            ConstantFold(),
            DeadCodeElimination(),
            CommonSubexpressionElimination(),
        )
        return Fixpoint(Walk(cleanup)).rewrite(mt.code).join(rewrite_result)
@@ -0,0 +1,63 @@
1
+ """Rewrite py dialects into qasm dialects."""
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass
5
+ from kirin.rewrite import Walk, Fixpoint
6
+ from kirin.dialects import py, math
7
+ from kirin.rewrite.abc import RewriteRule
8
+ from kirin.rewrite.result import RewriteResult
9
+
10
+ from bloqade.qasm2.dialects import core, expr
11
+
12
+
13
class _Py2QASM(RewriteRule):
    """Rewrite py dialects into qasm dialects.

    Supported rewrites:

    * ``py.Constant`` holding an ``int`` -> ``expr.ConstInt``, and holding a
      ``float`` -> ``expr.ConstFloat``. Other constants are left alone.
    * ``py`` binary operators listed in ``BINARY_OPS`` -> the matching
      ``expr`` statement.
    * ``py`` unary operators listed in ``UNARY_OPS`` -> the matching ``expr``
      statement.
    * math-dialect statements listed in ``MATH_OPS`` -> the matching ``expr``
      statement.
    * ``py.cmp.Eq`` -> ``core.CRegEq``.
    * ``py.assign.Alias`` -> removed, with its uses forwarded to the aliased
      value.
    """

    # py unary operators (operand field ``value``) -> qasm2 expr statements.
    UNARY_OPS = {
        py.USub: expr.Neg,
    }

    # Math-dialect statements -> qasm2 expr statements. These are not
    # ``py.UnaryOp`` subclasses, so the ``py.UnaryOp`` branch below can never
    # reach them. They also carry their operand in ``x`` rather than ``value``.
    # The keys use ``math.stmts`` (the statement classes), mirroring the
    # inverse table in the qasm2 -> py rewrite.
    MATH_OPS = {
        math.stmts.sin: expr.Sin,
        math.stmts.cos: expr.Cos,
        math.stmts.tan: expr.Tan,
        math.stmts.exp: expr.Exp,
        math.stmts.sqrt: expr.Sqrt,
    }

    BINARY_OPS = {
        py.Add: expr.Add,
        py.Sub: expr.Sub,
        py.Mult: expr.Mul,
        py.Div: expr.Div,
        py.Pow: expr.Pow,
    }

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if isinstance(node, py.Constant):
            value = node.value.unwrap()
            # NOTE(review): bool is a subclass of int, so True/False constants
            # are rewritten to ConstInt as well -- confirm this is intended.
            if isinstance(value, int):
                node.replace_by(expr.ConstInt(value=value))
                return RewriteResult(has_done_something=True)
            elif isinstance(value, float):
                node.replace_by(expr.ConstFloat(value=value))
                return RewriteResult(has_done_something=True)
        elif isinstance(node, py.BinOp):
            if (qasm_stmt := self.BINARY_OPS.get(type(node))) is not None:
                node.replace_by(qasm_stmt(node.lhs, node.rhs))
                return RewriteResult(has_done_something=True)
        elif isinstance(node, py.UnaryOp):
            if (qasm_stmt := self.UNARY_OPS.get(type(node))) is not None:
                node.replace_by(qasm_stmt(node.value))
                return RewriteResult(has_done_something=True)
        elif (qasm_stmt := self.MATH_OPS.get(type(node))) is not None:
            # Operand moves from the math statement's ``x`` to expr's ``value``.
            node.replace_by(qasm_stmt(value=node.x))
            return RewriteResult(has_done_something=True)
        elif isinstance(node, py.cmp.Eq):
            node.replace_by(core.CRegEq(node.lhs, node.rhs))
            return RewriteResult(has_done_something=True)
        elif isinstance(node, py.assign.Alias):
            # Forward every use of the alias to the aliased value, then drop it.
            node.result.replace_by(node.value)
            node.delete()
            return RewriteResult(has_done_something=True)
        return RewriteResult()
58
+
59
+
60
class Py2QASM(Pass):
    """Pass rewriting py-dialect statements of a method into qasm2 dialects.

    ``_Py2QASM`` is walked over the method body repeatedly until it no longer
    changes anything.
    """

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        until_stable = Fixpoint(Walk(_Py2QASM()))
        return until_stable.rewrite(mt.code)
@@ -0,0 +1,61 @@
1
+ """Rewrite qasm dialects into py dialects."""
2
+
3
+ import math as pymath
4
+
5
+ from kirin import ir
6
+ from kirin.passes import Pass
7
+ from kirin.rewrite import Walk, Fixpoint
8
+ from kirin.dialects import py, math
9
+ from kirin.rewrite.abc import RewriteRule
10
+ from kirin.rewrite.result import RewriteResult
11
+
12
+ from bloqade.qasm2.dialects import core, expr
13
+
14
+
15
class _QASM2Py(RewriteRule):
    """Rewrite qasm dialects into py dialects.

    Rewrites performed:

    * ``expr.ConstInt`` / ``expr.ConstFloat`` -> ``py.Constant`` with the same
      value.
    * ``expr.Neg`` -> ``py.USub``.
    * ``expr.Sin`` / ``Cos`` / ``Tan`` / ``Exp`` / ``Sqrt`` -> the matching
      ``math.stmts`` statement, whose operand is passed as ``x``.
    * ``expr`` binary arithmetic -> the matching ``py`` binary operator.
    * ``core.CRegEq`` -> ``py.cmp.Eq``.
    * ``expr.ConstPI`` -> ``py.Constant(math.pi)``.
    """

    UNARY_OPS = {
        expr.Neg: py.USub,
        expr.Sin: math.stmts.sin,
        expr.Cos: math.stmts.cos,
        expr.Tan: math.stmts.tan,
        expr.Exp: math.stmts.exp,
        expr.Sqrt: math.stmts.sqrt,
    }

    BINARY_OPS = {
        expr.Add: py.Add,
        expr.Sub: py.Sub,
        expr.Mul: py.Mult,
        expr.Div: py.Div,
        expr.Pow: py.Pow,
    }

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        replacement = self._translate(node)
        if replacement is None:
            return RewriteResult()
        node.replace_by(replacement)
        return RewriteResult(has_done_something=True)

    def _translate(self, node: ir.Statement):
        """Return the py-dialect statement replacing ``node``, or None if unsupported."""
        kind = type(node)
        if isinstance(node, (expr.ConstInt, expr.ConstFloat)):
            return py.Constant(value=node.value)
        if isinstance(node, expr.Neg):
            return self.UNARY_OPS[kind](value=node.value)
        if isinstance(node, (expr.Sin, expr.Cos, expr.Tan, expr.Exp, expr.Sqrt)):
            # math statements take their operand as ``x``.
            return self.UNARY_OPS[kind](x=node.value)
        if isinstance(node, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Pow)):
            return self.BINARY_OPS[kind](lhs=node.lhs, rhs=node.rhs)
        if isinstance(node, core.CRegEq):
            return py.cmp.Eq(node.lhs, node.rhs)
        if isinstance(node, expr.ConstPI):
            return py.Constant(value=pymath.pi)
        return None
56
+
57
+
58
class QASM2Py(Pass):
    """Pass rewriting qasm2-dialect statements of a method into py dialects.

    ``_QASM2Py`` is walked over the method body repeatedly until it no longer
    changes anything.
    """

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        until_stable = Fixpoint(Walk(_QASM2Py()))
        return until_stable.rewrite(mt.code)
@@ -0,0 +1,12 @@
1
+ from .glob import (
2
+ GlobalToUOpRule as GlobalToUOpRule,
3
+ GlobalToParallelRule as GlobalToParallelRule,
4
+ )
5
+ from .register import RaiseRegisterRule as RaiseRegisterRule
6
+ from .parallel_to_uop import ParallelToUOpRule as ParallelToUOpRule
7
+ from .uop_to_parallel import (
8
+ MergePolicyABC as MergePolicyABC,
9
+ UOpToParallelRule as UOpToParallelRule,
10
+ SimpleGreedyMergePolicy as SimpleGreedyMergePolicy,
11
+ SimpleOptimalMergePolicy as SimpleOptimalMergePolicy,
12
+ )
@@ -0,0 +1,28 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass
5
+ from kirin.rewrite import abc, walk, result
6
+ from kirin.dialects import py
7
+
8
+ from bloqade.qasm2.dialects import core
9
+
10
+
11
class IndexingDesugarRule(abc.RewriteRule):
    """Desugar ``py.indexing.GetItem`` on qasm2 registers.

    A ``GetItem`` whose object type is a subset of ``core.QRegType`` becomes a
    ``core.QRegGet``. One whose object type is a subset of ``core.CRegType``
    becomes a ``core.CRegGet``. Everything else is left untouched.
    """

    def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
        if not isinstance(node, py.indexing.GetItem):
            return result.RewriteResult()

        obj_type = node.obj.type
        if obj_type.is_subseteq(core.QRegType):
            getter = core.QRegGet
        elif obj_type.is_subseteq(core.CRegType):
            getter = core.CRegGet
        else:
            return result.RewriteResult()

        node.replace_by(getter(reg=node.obj, idx=node.index))
        return result.RewriteResult(has_done_something=True)
22
+
23
+
24
@dataclass
class IndexingDesugarPass(Pass):
    """Pass that applies ``IndexingDesugarRule`` in a single walk over a method body."""

    def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
        walker = walk.Walk(IndexingDesugarRule())
        return walker.rewrite(mt.code)
@@ -0,0 +1,103 @@
1
+ from typing import Dict, List
2
+ from dataclasses import dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.rewrite import abc, result
6
+ from kirin.dialects import py, ilist
7
+
8
+ from bloqade import qasm2
9
+ from bloqade.analysis import address
10
+ from bloqade.qasm2.dialects import glob
11
+
12
+
13
@dataclass
class GlobalRewriteBase:
    """Shared helper for expanding ``glob`` statements into per-qubit form.

    Attributes:
        address_analysis: mapping from SSA values to the addresses computed by
            address analysis.
    """

    address_analysis: Dict[ir.SSAValue, address.Address]

    def get_qubit_ssa(self, node: glob.UGate):
        """Build one qubit SSA value per qubit of the registers ``node`` acts on.

        Returns:
            A ``(new_stmts, qubit_ssa)`` pair. ``new_stmts`` holds the
            ``Constant``/``QRegGet`` statements, not yet inserted, that produce
            the values listed in ``qubit_ssa``. If ``node.registers`` is not
            the result of an ``ilist.New`` statement, or any register does not
            analyse to an ``AddressReg``, the pair is ``([], None)``.
        """
        registers = node.registers
        # can't rewrite if the registers are coming from a block argument
        if not isinstance(registers, ir.ResultValue) or not isinstance(
            registers.owner, ilist.New
        ):
            return [], None

        statements: List[ir.Statement] = []
        qubits: List[ir.SSAValue] = []
        for reg_ssa in registers.owner.values:
            reg_addr = self.address_analysis.get(reg_ssa, address.Address.top())
            if not isinstance(reg_addr, address.AddressReg):
                return [], None

            # One constant index plus one QRegGet per qubit in the register.
            for offset in range(len(reg_addr.data)):
                index_stmt = py.constant.Constant(value=offset)
                get_stmt = qasm2.core.QRegGet(reg=reg_ssa, idx=index_stmt.result)
                statements.extend((index_stmt, get_stmt))
                qubits.append(get_stmt.result)

        return statements, qubits
45
+
46
+
47
@dataclass
class GlobalToParallelRule(abc.RewriteRule, GlobalRewriteBase):
    """Rewrite ``glob`` statements into ``qasm2.parallel`` statements.

    Statements belonging to the ``glob`` dialect are dispatched to the method
    ``rewrite_<statement name>``. All other statements are ignored.
    """

    def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
        if type(node) not in glob.dialect.stmts:
            return result.RewriteResult()
        handler = getattr(self, f"rewrite_{node.name}")
        return handler(node)

    def rewrite_ugate(self, node: glob.UGate):
        """Replace a global U gate with one ``parallel.UGate`` over all of its qubits."""
        prefix, qubits = self.get_qubit_ssa(node)
        if qubits is None:
            return result.RewriteResult()

        qargs = ilist.New(values=qubits)
        parallel_gate = qasm2.dialects.parallel.UGate(
            qargs=qargs.result, theta=node.theta, phi=node.phi, lam=node.lam
        )
        for stmt in [*prefix, qargs, parallel_gate]:
            stmt.insert_before(node)

        node.delete()
        return result.RewriteResult(has_done_something=True)
76
+
77
+
78
@dataclass
class GlobalToUOpRule(abc.RewriteRule, GlobalRewriteBase):
    """Rewrite ``glob`` statements into individual ``qasm2.uop`` statements.

    Statements belonging to the ``glob`` dialect are dispatched to the method
    ``rewrite_<statement name>``. All other statements are ignored.
    """

    def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
        if type(node) not in glob.dialect.stmts:
            return result.RewriteResult()
        handler = getattr(self, f"rewrite_{node.name}")
        return handler(node)

    def rewrite_ugate(self, node: glob.UGate):
        """Replace a global U gate with one ``uop.UGate`` per qubit."""
        prefix, qubits = self.get_qubit_ssa(node)
        if qubits is None:
            return result.RewriteResult()

        gates = [
            qasm2.uop.UGate(qarg=qubit, theta=node.theta, phi=node.phi, lam=node.lam)
            for qubit in qubits
        ]
        for stmt in prefix + gates:
            stmt.insert_before(node)

        node.delete()
        return result.RewriteResult(has_done_something=True)