bloqade-circuit 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (80) hide show
  1. bloqade/analysis/address/impls.py +14 -0
  2. bloqade/analysis/fidelity/analysis.py +27 -2
  3. bloqade/noise/fidelity.py +3 -3
  4. bloqade/noise/native/_dialect.py +1 -1
  5. bloqade/noise/native/_wrappers.py +35 -6
  6. bloqade/noise/native/stmts.py +1 -1
  7. bloqade/pyqrack/device.py +109 -21
  8. bloqade/pyqrack/qasm2/core.py +4 -1
  9. bloqade/pyqrack/squin/qubit.py +16 -9
  10. bloqade/pyqrack/squin/wire.py +22 -4
  11. bloqade/pyqrack/task.py +13 -5
  12. bloqade/qasm2/__init__.py +1 -0
  13. bloqade/qasm2/_qasm_loading.py +151 -0
  14. bloqade/qasm2/dialects/core/__init__.py +9 -1
  15. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  16. bloqade/qasm2/dialects/noise.py +33 -1
  17. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  18. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  19. bloqade/qasm2/emit/impls/__init__.py +1 -0
  20. bloqade/qasm2/emit/impls/noise_native.py +89 -0
  21. bloqade/qasm2/emit/main.py +21 -0
  22. bloqade/qasm2/emit/target.py +20 -5
  23. bloqade/qasm2/groups.py +2 -0
  24. bloqade/qasm2/parse/__init__.py +7 -4
  25. bloqade/qasm2/parse/lowering.py +20 -130
  26. bloqade/qasm2/parse/qasm2.lark +1 -1
  27. bloqade/qasm2/passes/__init__.py +1 -0
  28. bloqade/qasm2/passes/fold.py +6 -0
  29. bloqade/qasm2/passes/noise.py +50 -2
  30. bloqade/qasm2/passes/parallel.py +9 -0
  31. bloqade/qasm2/passes/unroll_if.py +25 -0
  32. bloqade/qasm2/rewrite/__init__.py +1 -0
  33. bloqade/qasm2/rewrite/desugar.py +3 -2
  34. bloqade/qasm2/rewrite/heuristic_noise.py +1 -9
  35. bloqade/qasm2/rewrite/native_gates.py +67 -4
  36. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  37. bloqade/squin/analysis/nsites/__init__.py +1 -0
  38. bloqade/squin/analysis/nsites/impls.py +25 -1
  39. bloqade/squin/noise/__init__.py +7 -26
  40. bloqade/squin/noise/_wrapper.py +25 -0
  41. bloqade/squin/op/__init__.py +33 -159
  42. bloqade/squin/op/_wrapper.py +101 -0
  43. bloqade/squin/op/stdlib.py +62 -0
  44. bloqade/squin/passes/__init__.py +1 -0
  45. bloqade/squin/passes/stim.py +68 -0
  46. bloqade/squin/rewrite/__init__.py +11 -0
  47. bloqade/squin/rewrite/qubit_to_stim.py +84 -0
  48. bloqade/squin/rewrite/squin_measure.py +98 -0
  49. bloqade/squin/rewrite/stim_rewrite_util.py +158 -0
  50. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  51. bloqade/squin/rewrite/wire_to_stim.py +73 -0
  52. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  53. bloqade/squin/wire.py +1 -13
  54. bloqade/stim/__init__.py +39 -5
  55. bloqade/stim/_wrappers.py +14 -12
  56. bloqade/stim/dialects/__init__.py +1 -5
  57. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +12 -1
  58. bloqade/stim/dialects/{aux → auxiliary}/emit.py +1 -1
  59. bloqade/stim/dialects/collapse/__init__.py +13 -2
  60. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +1 -1
  61. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  62. bloqade/stim/dialects/gate/__init__.py +16 -1
  63. bloqade/stim/dialects/gate/emit.py +1 -1
  64. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  65. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  66. bloqade/stim/dialects/noise/emit.py +1 -1
  67. bloqade/stim/emit/__init__.py +1 -1
  68. bloqade/stim/groups.py +4 -2
  69. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/METADATA +3 -3
  70. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/RECORD +80 -64
  71. /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  72. /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  73. /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  74. /bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +0 -0
  75. /bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +0 -0
  76. /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  77. /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  78. /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  79. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/WHEEL +0 -0
  80. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,62 @@
1
+ from kirin import ir
2
+ from kirin.prelude import structural_no_opt
3
+
4
+ from . import types
5
+ from ._dialect import dialect
6
+ from ._wrapper import h, x, y, z, rot, phase, control
7
+
8
+
9
@ir.dialect_group(structural_no_opt.add(dialect))
def op(self):
    """Dialect group for kernels using the squin ``op`` dialect.

    The group is ``structural_no_opt`` extended with this package's
    ``dialect``. The pass hook it returns does nothing.
    """

    def run_pass(method):
        # Intentionally a no-op: no passes are run on kernels of this group.
        return None

    return run_pass
15
+
16
+
17
@op
def rx(theta: float) -> types.Op:
    """Rotation X gate: ``rot`` applied to the X operator with angle ``theta``."""
    axis = x()
    return rot(axis, theta)
21
+
22
+
23
@op
def ry(theta: float) -> types.Op:
    """Rotation Y gate: ``rot`` applied to the Y operator with angle ``theta``."""
    axis = y()
    return rot(axis, theta)
27
+
28
+
29
@op
def rz(theta: float) -> types.Op:
    """Rotation Z gate: ``rot`` applied to the Z operator with angle ``theta``."""
    axis = z()
    return rot(axis, theta)
33
+
34
+
35
@op
def cx() -> types.Op:
    """Controlled X gate (the X operator with one control qubit)."""
    target_op = x()
    return control(target_op, n_controls=1)
39
+
40
+
41
@op
def cy() -> types.Op:
    """Controlled Y gate (the Y operator with one control qubit)."""
    target_op = y()
    return control(target_op, n_controls=1)
45
+
46
+
47
@op
def cz() -> types.Op:
    """Controlled Z gate (the Z operator with one control qubit)."""
    target_op = z()
    return control(target_op, n_controls=1)
51
+
52
+
53
@op
def ch() -> types.Op:
    """Controlled H gate (the Hadamard operator with one control qubit)."""
    target_op = h()
    return control(target_op, n_controls=1)
57
+
58
+
59
@op
def cphase(theta: float) -> types.Op:
    """Controlled Phase gate: ``phase(theta)`` with one control qubit."""
    phase_op = phase(theta)
    return control(phase_op, n_controls=1)
@@ -0,0 +1 @@
1
+ from .stim import SquinToStim as SquinToStim
@@ -0,0 +1,68 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin.passes import Fold
4
+ from kirin.rewrite import (
5
+ Walk,
6
+ Chain,
7
+ Fixpoint,
8
+ DeadCodeElimination,
9
+ CommonSubexpressionElimination,
10
+ )
11
+ from kirin.ir.method import Method
12
+ from kirin.passes.abc import Pass
13
+ from kirin.rewrite.abc import RewriteResult
14
+
15
+ from bloqade.squin.rewrite import (
16
+ SquinWireToStim,
17
+ SquinQubitToStim,
18
+ WrapSquinAnalysis,
19
+ SquinMeasureToStim,
20
+ SquinWireIdentityElimination,
21
+ )
22
+ from bloqade.analysis.address import AddressAnalysis
23
+ from bloqade.squin.analysis.nsites import (
24
+ NSitesAnalysis,
25
+ )
26
+
27
+
28
@dataclass
class SquinToStim(Pass):
    """Lower a squin kernel to stim statements.

    Steps:
      1. fold constants;
      2. run address analysis and n-sites analysis and attach their results
         to SSA values as ``"address"`` / ``"sites"`` hints;
      3. rewrite squin qubit, wire and measurement statements into stim
         statements;
      4. clean up with dead-code and common-subexpression elimination until
         a fixpoint is reached.
    """

    def unsafe_run(self, mt: Method) -> RewriteResult:
        # Propagate constants so the analyses see concrete values.
        rewrite_result = Fold(mt.dialects)(mt)

        # Analyses whose results the stim rewrites consume through hints.
        address_frame, _ = AddressAnalysis(mt.dialects).run_analysis(mt)
        sites_frame, _ = NSitesAnalysis(mt.dialects).run_analysis(mt)

        # Attach all hints in a dedicated walk *before* any rewriting. The
        # rewrites read the hints of a statement's operands, so doing both in
        # one walk made correctness depend on the walk's visiting order.
        rewrite_result = (
            Walk(
                WrapSquinAnalysis(
                    address_analysis=address_frame.entries,
                    op_site_analysis=sites_frame.entries,
                )
            )
            .rewrite(mt.code)
            .join(rewrite_result)
        )

        # Lower squin statements to their stim equivalents.
        rewrite_result = (
            Walk(
                Chain(
                    SquinQubitToStim(),
                    SquinWireToStim(),
                    SquinMeasureToStim(),  # measurement logic kept in its own rule
                    SquinWireIdentityElimination(),
                )
            )
            .rewrite(mt.code)
            .join(rewrite_result)
        )

        # Drop leftovers (e.g. unused constants) and deduplicate.
        rewrite_result = (
            Fixpoint(
                Walk(Chain(DeadCodeElimination(), CommonSubexpressionElimination()))
            )
            .rewrite(mt.code)
            .join(rewrite_result)
        )

        return rewrite_result
@@ -0,0 +1,11 @@
1
+ from .wire_to_stim import SquinWireToStim as SquinWireToStim
2
+ from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
3
+ from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
4
+ from .wrap_analysis import (
5
+ SitesAttribute as SitesAttribute,
6
+ AddressAttribute as AddressAttribute,
7
+ WrapSquinAnalysis as WrapSquinAnalysis,
8
+ )
9
+ from .wire_identity_elimination import (
10
+ SquinWireIdentityElimination as SquinWireIdentityElimination,
11
+ )
@@ -0,0 +1,84 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
3
+
4
+ from bloqade import stim
5
+ from bloqade.squin import op, qubit
6
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
7
+ from bloqade.squin.rewrite.stim_rewrite_util import (
8
+ SQUIN_STIM_GATE_MAPPING,
9
+ rewrite_Control,
10
+ insert_qubit_idx_from_address,
11
+ )
12
+
13
+
14
class SquinQubitToStim(RewriteRule):
    """Rewrite qubit-level squin statements (apply, broadcast, reset) to stim.

    Qubit indices are taken from the ``"address"`` hint on the qubit SSA
    value. Statements without that hint, or whose operator has no stim
    counterpart, are left unchanged.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if isinstance(node, (qubit.Apply, qubit.Broadcast)):
            return self.rewrite_Apply_and_Broadcast(node)
        if isinstance(node, qubit.Reset):
            return self.rewrite_Reset(node)
        return RewriteResult()

    def rewrite_Apply_and_Broadcast(
        self, stmt: qubit.Apply | qubit.Broadcast
    ) -> RewriteResult:
        """
        Rewrite Apply and Broadcast nodes to their stim equivalent statements.

        Controlled operators go to ``rewrite_Control``. Operators listed in
        ``SQUIN_STIM_GATE_MAPPING`` become the matching stim gate on the
        qubits named by the ``"address"`` hint.
        """
        # ``stmt.operator`` is an SSA value; its owner is the operator statement.
        operator_stmt = stmt.operator.owner
        assert isinstance(operator_stmt, op.stmts.Operator)

        if isinstance(operator_stmt, op.stmts.Control):
            return rewrite_Control(stmt)

        stim_gate_cls = SQUIN_STIM_GATE_MAPPING.get(type(operator_stmt))
        if stim_gate_cls is None:
            return RewriteResult()

        address_hint = stmt.qubits.hints.get("address")
        if address_hint is None:
            return RewriteResult()
        assert isinstance(address_hint, AddressAttribute)

        target_ssas = insert_qubit_idx_from_address(
            address=address_hint, stmt_to_insert_before=stmt
        )
        if target_ssas is None:
            return RewriteResult()

        stmt.replace_by(stim_gate_cls(targets=tuple(target_ssas)))
        return RewriteResult(has_done_something=True)

    def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult:
        """Replace a qubit reset by a stim ``RZ`` on the same qubits."""
        # The qubits live in an ilist, whose address is an AddressTuple.
        address_hint = reset_stmt.qubits.hints.get("address")
        if address_hint is None:
            return RewriteResult()
        assert isinstance(address_hint, AddressAttribute)

        target_ssas = insert_qubit_idx_from_address(
            address=address_hint, stmt_to_insert_before=reset_stmt
        )
        if target_ssas is None:
            return RewriteResult()

        reset_stmt.replace_by(stim.collapse.stmts.RZ(targets=target_ssas))
        return RewriteResult(has_done_something=True)
82
+
83
+
84
+ # Measure statements are rewritten by a separate rule (SquinMeasureToStim); this rule only handles apply/broadcast/reset.
@@ -0,0 +1,98 @@
1
+ # Rewrite rule named SquinMeasureToStim, built on kirin's RewriteRule.
2
+ from kirin import ir
3
+ from kirin.dialects import py
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+
6
+ from bloqade import stim
7
+ from bloqade.squin import wire, qubit
8
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9
+ from bloqade.squin.rewrite.stim_rewrite_util import (
10
+ is_measure_result_used,
11
+ insert_qubit_idx_from_address,
12
+ )
13
+
14
+
15
+ class SquinMeasureToStim(RewriteRule):
16
+ """
17
+ Rewrite squin measure-related statements to stim statements.
18
+ """
19
+
20
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21
+
22
+ match node:
23
+ case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure():
24
+ return self.rewrite_Measure(node)
25
+ case qubit.MeasureAndReset() | wire.MeasureAndReset():
26
+ return self.rewrite_MeasureAndReset(node)
27
+ case _:
28
+ return RewriteResult()
29
+
30
+ def rewrite_Measure(
31
+ self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
32
+ ) -> RewriteResult:
33
+ if is_measure_result_used(measure_stmt):
34
+ return RewriteResult()
35
+
36
+ qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
37
+ if qubit_idx_ssas is None:
38
+ return RewriteResult()
39
+
40
+ prob_noise_stmt = py.constant.Constant(0.0)
41
+ stim_measure_stmt = stim.collapse.MZ(
42
+ p=prob_noise_stmt.result,
43
+ targets=qubit_idx_ssas,
44
+ )
45
+ prob_noise_stmt.insert_before(measure_stmt)
46
+ measure_stmt.replace_by(stim_measure_stmt)
47
+
48
+ return RewriteResult(has_done_something=True)
49
+
50
+ def rewrite_MeasureAndReset(
51
+ self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset
52
+ ) -> RewriteResult:
53
+ if not is_measure_result_used(meas_and_reset_stmt):
54
+ return RewriteResult()
55
+
56
+ qubit_idx_ssas = self.get_qubit_idx_ssas(meas_and_reset_stmt)
57
+
58
+ if qubit_idx_ssas is None:
59
+ return RewriteResult()
60
+
61
+ error_p_stmt = py.Constant(0.0)
62
+ stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result)
63
+ stim_rz_stmt = stim.collapse.RZ(
64
+ targets=qubit_idx_ssas,
65
+ )
66
+
67
+ error_p_stmt.insert_before(meas_and_reset_stmt)
68
+ stim_mz_stmt.insert_before(meas_and_reset_stmt)
69
+ meas_and_reset_stmt.replace_by(stim_rz_stmt)
70
+
71
+ return RewriteResult(has_done_something=True)
72
+
73
+ def get_qubit_idx_ssas(
74
+ self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
75
+ ) -> tuple[ir.SSAValue, ...] | None:
76
+ """
77
+ Extract the address attribute and insert qubit indices for the given measure statement.
78
+ """
79
+ match measure_stmt:
80
+ case qubit.MeasureQubit():
81
+ address_attr = measure_stmt.qubit.hints.get("address")
82
+ case qubit.MeasureQubitList():
83
+ address_attr = measure_stmt.qubits.hints.get("address")
84
+ case wire.Measure():
85
+ address_attr = measure_stmt.wire.hints.get("address")
86
+ case _:
87
+ return None
88
+
89
+ if address_attr is None:
90
+ return None
91
+
92
+ assert isinstance(address_attr, AddressAttribute)
93
+
94
+ qubit_idx_ssas = insert_qubit_idx_from_address(
95
+ address=address_attr, stmt_to_insert_before=measure_stmt
96
+ )
97
+
98
+ return qubit_idx_ssas
@@ -0,0 +1,158 @@
1
+ from kirin import ir
2
+ from kirin.dialects import py
3
+ from kirin.rewrite.abc import RewriteResult
4
+
5
+ from bloqade.squin import op, wire, qubit
6
+ from bloqade.stim.dialects import gate
7
+ from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple
8
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9
+
10
# Uncontrolled squin operator statements that translate one-to-one into a
# stim gate statement on the same qubits. Controlled operators are handled
# separately by ``rewrite_Control``.
SQUIN_STIM_GATE_MAPPING = dict(
    (
        (op.stmts.X, gate.X),
        (op.stmts.Y, gate.Y),
        (op.stmts.Z, gate.Z),
        (op.stmts.H, gate.H),
        (op.stmts.S, gate.S),
        (op.stmts.Identity, gate.Identity),
    )
)
18
+
19
+
20
def insert_qubit_idx_from_address(
    address: AddressAttribute, stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from an AddressAttribute and insert them into the SSA form.

    For every qubit the address refers to, a ``py.Constant`` holding its
    integer index is inserted before ``stmt_to_insert_before``. The
    resulting SSA values are returned in order.

    Returns ``None``, and leaves the IR unmodified, when the address is
    neither a tuple made only of ``AddressQubit`` entries nor an
    ``AddressWire``.
    """
    address_data = address.address

    # Resolve every index before mutating the IR, so an unsupported entry
    # does not leave a partial run of constants behind.
    if isinstance(address_data, AddressTuple):
        qubit_indices = []
        for address_qubit in address_data.data:
            if not isinstance(address_qubit, AddressQubit):
                return None
            qubit_indices.append(address_qubit.data)
    elif isinstance(address_data, AddressWire):
        qubit_indices = [address_data.origin_qubit.data]
    else:
        return None

    qubit_idx_ssas = []
    for qubit_idx in qubit_indices:
        qubit_idx_stmt = py.Constant(qubit_idx)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
47
+
48
+
49
def insert_qubit_idx_from_wire_ssa(
    wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from wire SSA values and insert them into the SSA form.

    Each wire must carry an ``"address"`` hint whose address is an
    ``AddressWire``. For every wire, a ``py.Constant`` holding the index of
    its origin qubit is inserted before ``stmt_to_insert_before``.

    Returns ``None``, and leaves the IR unmodified, if any wire lacks the
    hint or its address does not resolve to a single origin qubit.
    """
    # Resolve all indices first so a failure part-way does not leave
    # dangling constants in the IR.
    qubit_indices = []
    for wire_ssa in wire_ssas:
        address_attribute = wire_ssa.hints.get("address")
        if address_attribute is None:
            return None
        assert isinstance(address_attribute, AddressAttribute)
        wire_address = address_attribute.address
        if not isinstance(wire_address, AddressWire):
            # The analysis did not pin this wire to an origin qubit; decline
            # the rewrite instead of crashing the whole pass.
            return None
        qubit_indices.append(wire_address.origin_qubit.data)

    qubit_idx_ssas = []
    for qubit_idx in qubit_indices:
        qubit_idx_stmt = py.Constant(qubit_idx)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
69
+
70
+
71
def insert_qubit_idx_after_apply(
    stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast,
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from Apply or Broadcast statements.

    Wire statements resolve indices from their input wires. Qubit
    statements resolve them from the ``"address"`` hint on their qubit list.
    The index constants are inserted *before* ``stmt``. Returns ``None`` if
    the indices cannot be resolved or ``stmt`` is of another kind.
    """
    if isinstance(stmt, (wire.Apply, wire.Broadcast)):
        return insert_qubit_idx_from_wire_ssa(
            wire_ssas=stmt.inputs, stmt_to_insert_before=stmt
        )

    if not isinstance(stmt, (qubit.Apply, qubit.Broadcast)):
        return None

    address_hint = stmt.qubits.hints.get("address")
    if address_hint is None:
        return None
    assert isinstance(address_hint, AddressAttribute)
    return insert_qubit_idx_from_address(
        address=address_hint, stmt_to_insert_before=stmt
    )
91
+
92
+
93
def rewrite_Control(
    stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast,
) -> RewriteResult:
    """
    Handle control gates for Apply and Broadcast statements.

    Lowers an application of ``control(P)``, with ``P`` one of X, Y or Z, to
    the stim ``CX``/``CY``/``CZ`` gate. The resolved qubit indices are read
    as consecutive ``(control, target)`` pairs: even positions are controls,
    odd positions are targets.

    NOTE(review): that pairing only describes single-control operators
    (``n_controls=1``); confirm multi-control operators never reach here.

    Returns an empty ``RewriteResult`` when the controlled operator has no
    stim counterpart, the qubit indices cannot be resolved, or they do not
    form complete pairs.
    """
    # ``operator`` is an SSA value; its owner is the ``Control`` statement.
    ctrl_op = stmt_with_ctrl.operator.owner
    assert isinstance(ctrl_op, op.stmts.Control)

    ctrl_op_target_gate = ctrl_op.op.owner
    assert isinstance(ctrl_op_target_gate, op.stmts.Operator)

    supported_gate_mapping = {
        op.stmts.X: gate.CX,
        op.stmts.Y: gate.CY,
        op.stmts.Z: gate.CZ,
    }

    # Check support *before* inserting index constants, so an unsupported
    # controlled operator leaves the IR untouched.
    stim_gate = supported_gate_mapping.get(type(ctrl_op_target_gate))
    if stim_gate is None:
        return RewriteResult()

    qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl)
    if qubit_idx_ssas is None:
        return RewriteResult()

    if len(qubit_idx_ssas) % 2 != 0:
        # Cannot be split into (control, target) pairs; emitting a gate with
        # mismatched control/target lists would be wrong. The constants just
        # inserted are left unused for a later dead-code elimination.
        return RewriteResult()

    # Separate control and target qubits.
    ctrl_qubits = qubit_idx_ssas[0::2]
    target_qubits = qubit_idx_ssas[1::2]

    stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)

    if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
        # The stim gate produces no wires: reroute each output wire's users
        # to the matching input wire so the statement can be removed.
        for input_wire, output_wire in zip(
            stmt_with_ctrl.inputs, stmt_with_ctrl.results
        ):
            output_wire.replace_by(input_wire)

    stmt_with_ctrl.replace_by(stim_stmt)

    return RewriteResult(has_done_something=True)
144
+
145
+
146
def is_measure_result_used(
    stmt: (
        qubit.MeasureAndReset
        | qubit.MeasureQubit
        | qubit.MeasureQubitList
        | wire.MeasureAndReset
        | wire.Measure
    ),
) -> bool:
    """Return True if the statement's measurement ``result`` has any use."""
    measurement_uses = stmt.result.uses
    return len(measurement_uses) > 0
@@ -0,0 +1,24 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
3
+
4
+ from bloqade.squin import wire
5
+
6
+
7
class SquinWireIdentityElimination(RewriteRule):
    """Remove an ``unwrap``/``wrap`` pair whose wire does nothing in between."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """
        Handle the case where an unwrap feeds a wire directly into a wrap,
        equivalent to nothing happening/identity operation

        w = unwrap(qubit)
        wrap(qubit, w)

        The pair is only removed when this wrap is the sole user of the
        unwrapped wire.
        """
        if not isinstance(node, wire.Wrap):
            return RewriteResult()

        wire_origin_stmt = node.wire.owner
        if not isinstance(wire_origin_stmt, wire.Unwrap):
            return RewriteResult()

        # If anything else consumes the unwrapped wire, deleting the unwrap
        # would leave those users dangling, so this is not an identity pair.
        if len(node.wire.uses) != 1:
            return RewriteResult()

        node.delete()  # get rid of wrap
        wire_origin_stmt.delete()  # get rid of the unwrap
        return RewriteResult(has_done_something=True)
@@ -0,0 +1,73 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
3
+
4
+ from bloqade import stim
5
+ from bloqade.squin import op, wire
6
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
7
+ from bloqade.squin.rewrite.stim_rewrite_util import (
8
+ SQUIN_STIM_GATE_MAPPING,
9
+ rewrite_Control,
10
+ insert_qubit_idx_from_address,
11
+ insert_qubit_idx_from_wire_ssa,
12
+ )
13
+
14
+
15
class SquinWireToStim(RewriteRule):
    """Rewrite wire-level squin statements (apply, broadcast, reset) to stim.

    Qubit indices come from the ``"address"`` hints on wire SSA values.
    Statements that cannot be resolved, or whose operator has no stim
    counterpart, are left unchanged.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if isinstance(node, (wire.Apply, wire.Broadcast)):
            return self.rewrite_Apply_and_Broadcast(node)
        if isinstance(node, wire.Reset):
            return self.rewrite_Reset(node)
        return RewriteResult()

    def rewrite_Apply_and_Broadcast(
        self, stmt: wire.Apply | wire.Broadcast
    ) -> RewriteResult:
        """Replace a wire apply/broadcast of a supported operator by a stim gate."""
        # ``stmt.operator`` is an SSA value; its owner is the operator statement.
        operator_stmt = stmt.operator.owner
        assert isinstance(operator_stmt, op.stmts.Operator)

        if isinstance(operator_stmt, op.stmts.Control):
            return rewrite_Control(stmt)

        stim_gate_cls = SQUIN_STIM_GATE_MAPPING.get(type(operator_stmt))
        if stim_gate_cls is None:
            return RewriteResult()

        target_ssas = insert_qubit_idx_from_wire_ssa(
            wire_ssas=stmt.inputs, stmt_to_insert_before=stmt
        )
        if target_ssas is None:
            return RewriteResult()

        stim_gate_stmt = stim_gate_cls(targets=tuple(target_ssas))

        # The stim gate yields no wires, so every user of an output wire is
        # pointed at the corresponding input wire before the swap.
        for old_wire, new_wire in zip(stmt.results, stmt.inputs):
            old_wire.replace_by(new_wire)

        stmt.replace_by(stim_gate_stmt)
        return RewriteResult(has_done_something=True)

    def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult:
        """Replace a wire reset by a stim ``RZ`` on the wire's qubit."""
        address_hint = reset_stmt.wire.hints.get("address")
        if address_hint is None:
            return RewriteResult()
        assert isinstance(address_hint, AddressAttribute)

        target_ssas = insert_qubit_idx_from_address(
            address=address_hint, stmt_to_insert_before=reset_stmt
        )
        if target_ssas is None:
            return RewriteResult()

        reset_stmt.replace_by(stim.collapse.stmts.RZ(targets=target_ssas))
        return RewriteResult(has_done_something=True)
@@ -0,0 +1,72 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+ from kirin.print.printer import Printer
6
+
7
+ from bloqade.squin import op, wire
8
+ from bloqade.analysis.address import Address
9
+ from bloqade.squin.analysis.nsites import Sites
10
+
11
+
12
@wire.dialect.register
@dataclass
class AddressAttribute(ir.Attribute):
    """IR attribute wrapping an ``Address`` produced by address analysis.

    Used as the value of the ``"address"`` hint on SSA values.
    """

    name = "Address"
    address: Address

    def __hash__(self) -> int:
        # Identity for hashing is the wrapped analysis value alone.
        return hash(self.address)

    def print_impl(self, printer: Printer) -> None:
        # Delegate to the printer for the wrapped value (custom format TODO).
        wrapped = self.address
        printer.print(wrapped)
25
+
26
+
27
@op.dialect.register
@dataclass
class SitesAttribute(ir.Attribute):
    """IR attribute wrapping a ``Sites`` value produced by n-sites analysis.

    Used as the value of the ``"sites"`` hint on SSA values.
    """

    name = "Sites"
    sites: Sites

    def __hash__(self) -> int:
        # Identity for hashing is the wrapped analysis value alone.
        return hash(self.sites)

    def print_impl(self, printer: Printer) -> None:
        # Delegate to the printer for the wrapped value (custom format TODO).
        wrapped = self.sites
        printer.print(wrapped)
40
+
41
+
42
@dataclass
class WrapSquinAnalysis(RewriteRule):
    """Attach analysis results to SSA values as ``"address"``/``"sites"`` hints.

    ``address_analysis`` and ``op_site_analysis`` map SSA values to their
    analysis results, e.g. the ``entries`` of the analysis frames.
    """

    address_analysis: dict[ir.SSAValue, Address]
    op_site_analysis: dict[ir.SSAValue, Sites]

    def wrap(self, value: ir.SSAValue) -> bool:
        """Attach missing hints to ``value``; return True if any hint was added.

        A value missing from an analysis mapping (for example one created
        after the analysis ran) simply does not get that hint, instead of
        raising ``KeyError``. Hints already present are kept, so repeated
        application reports no change and is safe under ``Fixpoint``.
        """
        if value.hints.get("address") and value.hints.get("sites"):
            return False

        changed = False

        if not value.hints.get("address"):
            address = self.address_analysis.get(value)
            if address is not None:
                value.hints["address"] = AddressAttribute(address)
                changed = True

        if not value.hints.get("sites"):
            sites = self.op_site_analysis.get(value)
            if sites is not None:
                value.hints["sites"] = SitesAttribute(sites)
                changed = True

        return changed

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        # Build the full list (not ``any`` over a generator) so every
        # argument is wrapped, not just those before the first change.
        changes = [self.wrap(arg) for arg in node.args]
        return RewriteResult(has_done_something=any(changes))

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        # Same as ``rewrite_Block``: wrap every result before aggregating.
        changes = [self.wrap(result) for result in node.results]
        return RewriteResult(has_done_something=any(changes))
bloqade/squin/wire.py CHANGED
@@ -6,7 +6,7 @@ circuits. Thus we do not define wrapping functions for the statements in this
6
6
  dialect.
7
7
  """
8
8
 
9
- from kirin import ir, types, interp, lowering
9
+ from kirin import ir, types, lowering
10
10
  from kirin.decl import info, statement
11
11
  from kirin.lowering import wraps
12
12
 
@@ -112,18 +112,6 @@ class Reset(ir.Statement):
112
112
  wire: ir.SSAValue = info.argument(WireType)
113
113
 
114
114
 
115
- # Issue where constant propagation can't handle
116
- # multiple return values from Apply properly
117
- @dialect.register(key="constprop")
118
- class ConstPropWire(interp.MethodTable):
119
-
120
- @interp.impl(Apply)
121
- @interp.impl(Broadcast)
122
- def apply(self, interp, frame, stmt: Apply):
123
-
124
- return frame.get_values(stmt.inputs)
125
-
126
-
127
115
  @wraps(Unwrap)
128
116
  def unwrap(qubit: Qubit) -> Wire: ...
129
117