bloqade-circuit 0.2.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bloqade-circuit might be problematic; see the registry's page for this release for more details.
- bloqade/analysis/address/impls.py +3 -2
- bloqade/pyqrack/device.py +1 -3
- bloqade/pyqrack/noise/native.py +8 -8
- bloqade/pyqrack/qasm2/core.py +4 -1
- bloqade/pyqrack/squin/op.py +7 -0
- bloqade/pyqrack/squin/qubit.py +5 -27
- bloqade/pyqrack/squin/runtime.py +18 -0
- bloqade/pyqrack/squin/wire.py +4 -22
- bloqade/pyqrack/task.py +13 -5
- bloqade/qasm2/__init__.py +1 -0
- bloqade/qasm2/_qasm_loading.py +151 -0
- bloqade/qasm2/dialects/core/__init__.py +9 -1
- bloqade/qasm2/dialects/expr/__init__.py +18 -1
- bloqade/{noise/native → qasm2/dialects/noise}/__init__.py +1 -7
- bloqade/qasm2/dialects/noise/_dialect.py +3 -0
- bloqade/{noise → qasm2/dialects/noise}/fidelity.py +4 -4
- bloqade/qasm2/dialects/noise/model.py +278 -0
- bloqade/{noise/native → qasm2/dialects/noise}/stmts.py +1 -1
- bloqade/qasm2/dialects/uop/__init__.py +39 -3
- bloqade/qasm2/dialects/uop/schedule.py +1 -1
- bloqade/qasm2/emit/impls/__init__.py +1 -0
- bloqade/qasm2/emit/impls/noise.py +89 -0
- bloqade/qasm2/emit/main.py +23 -4
- bloqade/qasm2/emit/target.py +19 -4
- bloqade/qasm2/noise.py +67 -0
- bloqade/qasm2/parse/__init__.py +7 -4
- bloqade/qasm2/parse/lowering.py +20 -130
- bloqade/qasm2/parse/qasm2.lark +1 -1
- bloqade/qasm2/passes/__init__.py +1 -0
- bloqade/qasm2/passes/fold.py +6 -0
- bloqade/qasm2/passes/glob.py +12 -8
- bloqade/qasm2/passes/noise.py +27 -16
- bloqade/qasm2/passes/parallel.py +9 -0
- bloqade/qasm2/passes/unroll_if.py +25 -0
- bloqade/qasm2/rewrite/__init__.py +3 -0
- bloqade/qasm2/rewrite/desugar.py +3 -2
- bloqade/qasm2/rewrite/native_gates.py +67 -4
- bloqade/qasm2/rewrite/noise/__init__.py +0 -0
- bloqade/qasm2/rewrite/{heuristic_noise.py → noise/heuristic_noise.py} +32 -62
- bloqade/{noise/native/rewrite.py → qasm2/rewrite/noise/remove_noise.py} +2 -2
- bloqade/qasm2/rewrite/split_ifs.py +66 -0
- bloqade/qbraid/lowering.py +8 -8
- bloqade/squin/__init__.py +7 -1
- bloqade/squin/analysis/nsites/__init__.py +1 -0
- bloqade/squin/analysis/nsites/impls.py +16 -1
- bloqade/squin/groups.py +4 -4
- bloqade/squin/lowering.py +27 -0
- bloqade/squin/noise/__init__.py +7 -26
- bloqade/squin/noise/_wrapper.py +25 -0
- bloqade/squin/op/__init__.py +34 -159
- bloqade/squin/op/_wrapper.py +105 -0
- bloqade/squin/op/stdlib.py +62 -0
- bloqade/squin/op/stmts.py +10 -0
- bloqade/squin/passes/__init__.py +1 -0
- bloqade/squin/passes/stim.py +68 -0
- bloqade/squin/qubit.py +32 -37
- bloqade/squin/rewrite/__init__.py +11 -0
- bloqade/squin/rewrite/desugar.py +65 -0
- bloqade/squin/rewrite/qubit_to_stim.py +61 -0
- bloqade/squin/rewrite/squin_measure.py +73 -0
- bloqade/squin/rewrite/stim_rewrite_util.py +153 -0
- bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
- bloqade/squin/rewrite/wire_to_stim.py +52 -0
- bloqade/squin/rewrite/wrap_analysis.py +72 -0
- bloqade/squin/wire.py +5 -22
- bloqade/stim/__init__.py +40 -5
- bloqade/stim/_wrappers.py +18 -12
- bloqade/stim/dialects/__init__.py +1 -5
- bloqade/stim/dialects/{aux → auxiliary}/__init__.py +13 -1
- bloqade/stim/dialects/{aux → auxiliary}/emit.py +18 -3
- bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +1 -0
- bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +8 -0
- bloqade/stim/dialects/collapse/__init__.py +13 -2
- bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +4 -2
- bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
- bloqade/stim/dialects/gate/__init__.py +16 -1
- bloqade/stim/dialects/gate/emit.py +10 -3
- bloqade/stim/dialects/gate/stmts/base.py +1 -1
- bloqade/stim/dialects/gate/stmts/pp.py +1 -1
- bloqade/stim/dialects/noise/emit.py +33 -2
- bloqade/stim/dialects/noise/stmts.py +29 -0
- bloqade/stim/emit/__init__.py +1 -1
- bloqade/stim/groups.py +4 -2
- bloqade/stim/parse/__init__.py +1 -0
- bloqade/stim/parse/lowering.py +686 -0
- {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/METADATA +5 -3
- {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/RECORD +95 -77
- bloqade/noise/__init__.py +0 -2
- bloqade/noise/native/_dialect.py +0 -3
- bloqade/noise/native/_wrappers.py +0 -34
- bloqade/noise/native/model.py +0 -346
- bloqade/qasm2/dialects/noise.py +0 -16
- bloqade/squin/rewrite/measure_desugar.py +0 -33
- /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
- /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
- /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
- {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.dialects import py
|
|
3
|
+
from kirin.rewrite.abc import RewriteResult
|
|
4
|
+
|
|
5
|
+
from bloqade.squin import op, wire, qubit
|
|
6
|
+
from bloqade.stim.dialects import gate, collapse
|
|
7
|
+
from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple
|
|
8
|
+
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
|
|
9
|
+
|
|
10
|
+
# Maps squin single-qubit operator statement types to the stim statement class
# that replaces them. Callers look entries up with ``type(op_stmt)``; operator
# types that are absent from this table are not rewritten.
SQUIN_STIM_GATE_MAPPING = {
    op.stmts.X: gate.X,
    op.stmts.Y: gate.Y,
    op.stmts.Z: gate.Z,
    op.stmts.H: gate.H,
    op.stmts.S: gate.S,
    op.stmts.Identity: gate.Identity,
    # Reset maps to a collapse-dialect statement rather than a gate statement.
    op.stmts.Reset: collapse.RZ,
}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def insert_qubit_idx_from_address(
    address: AddressAttribute, stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from an AddressAttribute and insert them into the SSA form.

    For every qubit the address refers to, a ``py.Constant`` holding the qubit
    index is inserted in front of ``stmt_to_insert_before``.

    Args:
        address: address hint attached to a qubit-list or wire SSA value.
        stmt_to_insert_before: statement the new constants are placed before.

    Returns:
        The result SSA values of the inserted constants, in address order, or
        ``None`` when the address cannot be resolved to concrete qubit indices.
        When ``None`` is returned the IR has not been modified.
    """
    address_data = address.address

    if isinstance(address_data, AddressTuple):
        # Validate every element before touching the IR, so an unsupported
        # element does not leave already-inserted constants orphaned.
        if not all(isinstance(elem, AddressQubit) for elem in address_data.data):
            return None
        qubit_indices = [elem.data for elem in address_data.data]
    elif isinstance(address_data, AddressWire):
        qubit_indices = [address_data.origin_qubit.data]
    else:
        return None

    qubit_idx_ssas = []
    for qubit_idx in qubit_indices:
        qubit_idx_stmt = py.Constant(qubit_idx)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def insert_qubit_idx_from_wire_ssa(
    wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from wire SSA values and insert them into the SSA form.

    Each wire's ``"address"`` hint is resolved to the index of the qubit the
    wire originated from. For each index, a ``py.Constant`` is inserted before
    ``stmt_to_insert_before``.

    Args:
        wire_ssas: wire values whose qubit indices are needed.
        stmt_to_insert_before: statement the new constants are placed before.

    Returns:
        The result SSA values of the inserted constants, one per wire and in the
        same order, or ``None`` if any wire lacks a usable address hint. When
        ``None`` is returned the IR has not been modified.
    """
    # Resolve every wire first; only mutate the IR once all lookups succeeded,
    # so a failure on a later wire cannot leave orphaned constants behind.
    qubit_indices = []
    for wire_ssa in wire_ssas:
        address_attribute = wire_ssa.hints.get("address")
        if address_attribute is None:
            return None
        assert isinstance(address_attribute, AddressAttribute)
        wire_address = address_attribute.address
        if not isinstance(wire_address, AddressWire):
            # Address analysis could not pin this wire to a single qubit.
            return None
        qubit_indices.append(wire_address.origin_qubit.data)

    qubit_idx_ssas = []
    for qubit_idx in qubit_indices:
        qubit_idx_stmt = py.Constant(qubit_idx)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def insert_qubit_idx_after_apply(
    stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast,
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from Apply or Broadcast statements.

    Wire-dialect statements resolve indices from the address hints on their
    input wires. Qubit-dialect statements resolve them from the address hint
    on their ``qubits`` argument. The index constants are inserted in front of
    ``stmt``.

    Returns:
        The inserted index SSA values, or ``None`` when the statement kind is
        not handled or no address information is available.
    """
    if isinstance(stmt, (wire.Apply, wire.Broadcast)):
        return insert_qubit_idx_from_wire_ssa(
            wire_ssas=stmt.inputs, stmt_to_insert_before=stmt
        )

    if not isinstance(stmt, (qubit.Apply, qubit.Broadcast)):
        return None

    qubits_hint = stmt.qubits.hints.get("address")
    if qubits_hint is None:
        return None

    assert isinstance(qubits_hint, AddressAttribute)
    return insert_qubit_idx_from_address(
        address=qubits_hint, stmt_to_insert_before=stmt
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def rewrite_Control(
    stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast,
) -> RewriteResult:
    """
    Handle control gates for Apply and Broadcast statements.

    Rewrites an application of an ``op.stmts.Control`` operator whose wrapped
    operator is X, Y or Z into stim's CX, CY or CZ. The resolved qubit indices
    are read as alternating ``(control, target)`` pairs: even positions become
    controls and odd positions become targets.

    Args:
        stmt_with_ctrl: the Apply/Broadcast statement. Its ``operator`` must be
            produced by an ``op.stmts.Control`` statement.

    Returns:
        ``RewriteResult(has_done_something=True)`` when the statement was
        replaced. Otherwise returns an empty ``RewriteResult``, and the IR is
        left unchanged.
    """
    ctrl_op = stmt_with_ctrl.operator.owner
    assert isinstance(ctrl_op, op.stmts.Control)

    ctrl_op_target_gate = ctrl_op.op.owner
    if not isinstance(ctrl_op_target_gate, op.stmts.Operator):
        # The controlled operator is not produced by a statement (for example,
        # it is a block argument), so it cannot be mapped to a stim gate.
        return RewriteResult()

    supported_gate_mapping = {
        op.stmts.X: gate.CX,
        op.stmts.Y: gate.CY,
        op.stmts.Z: gate.CZ,
    }

    # Resolve the stim gate BEFORE inserting any index constants, so that an
    # unsupported controlled operator does not leave orphaned constants.
    stim_gate = supported_gate_mapping.get(type(ctrl_op_target_gate))
    if stim_gate is None:
        return RewriteResult()

    qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl)
    if qubit_idx_ssas is None:
        return RewriteResult()

    if len(qubit_idx_ssas) % 2 != 0:
        # An odd number of qubits cannot be split into (control, target) pairs,
        # e.g. a multi-controlled gate on three qubits. Roll back the freshly
        # inserted (still unused) index constants and leave the statement alone.
        for idx_ssa in qubit_idx_ssas:
            idx_stmt = idx_ssa.owner
            if isinstance(idx_stmt, ir.Statement):
                idx_stmt.delete()
        return RewriteResult()

    # Separate control and target qubits: control, target, control, target, ...
    ctrl_qubits = tuple(qubit_idx_ssas[0::2])
    target_qubits = tuple(qubit_idx_ssas[1::2])

    stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)

    if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
        # have to "reroute" the input of these statements to directly plug in
        # to subsequent statements, remove dependency on the current statement
        for input_wire, output_wire in zip(
            stmt_with_ctrl.inputs, stmt_with_ctrl.results
        ):
            output_wire.replace_by(input_wire)

    stmt_with_ctrl.replace_by(stim_stmt)

    return RewriteResult(has_done_something=True)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def is_measure_result_used(
    stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure,
) -> bool:
    """
    Tell whether any statement consumes the outcome of a measurement.

    Returns ``True`` when the measurement's ``result`` SSA value has at least
    one use, ``False`` otherwise.
    """
    result_uses = stmt.result.uses
    return len(result_uses) > 0
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
|
+
|
|
4
|
+
from bloqade.squin import wire
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SquinWireIdentityElimination(RewriteRule):
    """Remove ``unwrap``/``wrap`` pairs that hand a wire straight back to its qubit."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """
        Handle the case where an unwrap feeds a wire directly into a wrap,
        equivalent to nothing happening/identity operation

            w = unwrap(qubit)
            wrap(qubit, w)

        The pair is only removed when both of these hold:

        * the wrap targets the *same* qubit SSA value the wire was unwrapped
          from. Wrapping into a different qubit is not an identity.
        * the wrap is the only use of the wire. Otherwise the unwrap is still
          needed, and deleting it would fail or break other users.
        """
        if not isinstance(node, wire.Wrap):
            return RewriteResult()

        wire_origin_stmt = node.wire.owner
        if not isinstance(wire_origin_stmt, wire.Unwrap):
            return RewriteResult()

        if node.qubit is not wire_origin_stmt.qubit:
            return RewriteResult()

        if len(node.wire.uses) != 1:
            return RewriteResult()

        node.delete()  # get rid of wrap
        wire_origin_stmt.delete()  # get rid of the unwrap
        return RewriteResult(has_done_something=True)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from kirin import ir
|
|
2
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
3
|
+
|
|
4
|
+
from bloqade.squin import op, wire
|
|
5
|
+
from bloqade.squin.rewrite.stim_rewrite_util import (
|
|
6
|
+
SQUIN_STIM_GATE_MAPPING,
|
|
7
|
+
rewrite_Control,
|
|
8
|
+
insert_qubit_idx_from_wire_ssa,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SquinWireToStim(RewriteRule):
|
|
13
|
+
|
|
14
|
+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
|
|
15
|
+
match node:
|
|
16
|
+
case wire.Apply() | wire.Broadcast():
|
|
17
|
+
return self.rewrite_Apply_and_Broadcast(node)
|
|
18
|
+
case _:
|
|
19
|
+
return RewriteResult()
|
|
20
|
+
|
|
21
|
+
def rewrite_Apply_and_Broadcast(
|
|
22
|
+
self, stmt: wire.Apply | wire.Broadcast
|
|
23
|
+
) -> RewriteResult:
|
|
24
|
+
|
|
25
|
+
# this is an SSAValue, need it to be the actual operator
|
|
26
|
+
applied_op = stmt.operator.owner
|
|
27
|
+
assert isinstance(applied_op, op.stmts.Operator)
|
|
28
|
+
|
|
29
|
+
if isinstance(applied_op, op.stmts.Control):
|
|
30
|
+
return rewrite_Control(stmt)
|
|
31
|
+
|
|
32
|
+
stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
|
|
33
|
+
if stim_1q_op is None:
|
|
34
|
+
return RewriteResult()
|
|
35
|
+
|
|
36
|
+
qubit_idx_ssas = insert_qubit_idx_from_wire_ssa(
|
|
37
|
+
wire_ssas=stmt.inputs, stmt_to_insert_before=stmt
|
|
38
|
+
)
|
|
39
|
+
if qubit_idx_ssas is None:
|
|
40
|
+
return RewriteResult()
|
|
41
|
+
|
|
42
|
+
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
|
|
43
|
+
|
|
44
|
+
# Get the wires from the inputs of Apply or Broadcast,
|
|
45
|
+
# then put those as the result of the current stmt
|
|
46
|
+
# before replacing it entirely
|
|
47
|
+
for input_wire, output_wire in zip(stmt.inputs, stmt.results):
|
|
48
|
+
output_wire.replace_by(input_wire)
|
|
49
|
+
|
|
50
|
+
stmt.replace_by(stim_1q_stmt)
|
|
51
|
+
|
|
52
|
+
return RewriteResult(has_done_something=True)
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kirin import ir
|
|
4
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
|
5
|
+
from kirin.print.printer import Printer
|
|
6
|
+
|
|
7
|
+
from bloqade.squin import op, wire
|
|
8
|
+
from bloqade.analysis.address import Address
|
|
9
|
+
from bloqade.squin.analysis.nsites import Sites
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@wire.dialect.register
@dataclass
class AddressAttribute(ir.Attribute):
    """IR attribute that wraps an address-analysis result so it can be stored
    in an SSA value's ``hints`` (under the ``"address"`` key by
    ``WrapSquinAnalysis``). Registered with the squin ``wire`` dialect.
    """

    name = "Address"
    # the wrapped address-analysis result
    address: Address

    def __hash__(self) -> int:
        # Hash delegates to the wrapped address.
        return hash(self.address)

    def print_impl(self, printer: Printer) -> None:
        # Can return to implementing this later
        printer.print(self.address)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@op.dialect.register
@dataclass
class SitesAttribute(ir.Attribute):
    """IR attribute that wraps an nsites-analysis result so it can be stored
    in an SSA value's ``hints`` (under the ``"sites"`` key by
    ``WrapSquinAnalysis``). Registered with the squin ``op`` dialect.
    """

    name = "Sites"
    # the wrapped site-count analysis result
    sites: Sites

    def __hash__(self) -> int:
        # Hash delegates to the wrapped sites value.
        return hash(self.sites)

    def print_impl(self, printer: Printer) -> None:
        # Can return to implementing this later
        printer.print(self.sites)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class WrapSquinAnalysis(RewriteRule):
    """Attach address and site analysis results to SSA values as ``hints``.

    Block arguments and statement results receive an ``"address"`` hint
    (``AddressAttribute``) and a ``"sites"`` hint (``SitesAttribute``).
    """

    # address-analysis results, keyed by SSA value
    address_analysis: dict[ir.SSAValue, Address]
    # operator site (nsites) analysis results, keyed by SSA value
    op_site_analysis: dict[ir.SSAValue, Sites]

    def wrap(self, value: ir.SSAValue) -> bool:
        """Add any missing hints to ``value``.

        A hint is only added when the corresponding analysis has a result for
        ``value``. Values that an analysis did not visit are skipped rather
        than raising ``KeyError``. Existing hints are never overwritten.

        Returns:
            ``True`` if at least one hint was added, ``False`` otherwise.
        """
        changed = False

        if value.hints.get("address") is None:
            address_result = self.address_analysis.get(value)
            if address_result is not None:
                value.hints["address"] = AddressAttribute(address_result)
                changed = True

        if value.hints.get("sites") is None:
            sites_result = self.op_site_analysis.get(value)
            if sites_result is not None:
                value.hints["sites"] = SitesAttribute(sites_result)
                changed = True

        return changed

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        has_done_something = False
        # Wrap every argument; do not short-circuit after the first change.
        for arg in node.args:
            if self.wrap(arg):
                has_done_something = True
        return RewriteResult(has_done_something=has_done_something)

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        has_done_something = False
        # Wrap every result; do not short-circuit after the first change.
        for result in node.results:
            if self.wrap(result):
                has_done_something = True
        return RewriteResult(has_done_something=has_done_something)
|
bloqade/squin/wire.py
CHANGED
|
@@ -6,7 +6,7 @@ circuits. Thus we do not define wrapping functions for the statements in this
|
|
|
6
6
|
dialect.
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
-
from kirin import ir, types,
|
|
9
|
+
from kirin import ir, types, lowering
|
|
10
10
|
from kirin.decl import info, statement
|
|
11
11
|
from kirin.lowering import wraps
|
|
12
12
|
|
|
@@ -95,35 +95,18 @@ class Broadcast(ir.Statement):
|
|
|
95
95
|
class Measure(ir.Statement):
|
|
96
96
|
traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
|
|
97
97
|
wire: ir.SSAValue = info.argument(WireType)
|
|
98
|
+
qubit: ir.SSAValue = info.argument(QubitType)
|
|
98
99
|
result: ir.ResultValue = info.result(types.Int)
|
|
99
100
|
|
|
100
101
|
|
|
101
102
|
@statement(dialect=dialect)
class NonDestructiveMeasure(ir.Statement):
    """Measure a wire and also return an output wire.

    Unlike ``Measure`` (which carries the ``WireTerminator`` trait and returns
    only the result), this statement yields both the integer ``result`` and
    ``out_wire``, so the wire can continue to be used afterwards.
    """

    traits = frozenset({lowering.FromPythonCall()})
    # wire being measured
    input_wire: ir.SSAValue = info.argument(WireType)
    # integer measurement outcome
    result: ir.ResultValue = info.result(types.Int)
    # wire continuing after the measurement
    out_wire: ir.ResultValue = info.result(WireType)
|
|
107
108
|
|
|
108
109
|
|
|
109
|
-
@statement(dialect=dialect)
|
|
110
|
-
class Reset(ir.Statement):
|
|
111
|
-
traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
|
|
112
|
-
wire: ir.SSAValue = info.argument(WireType)
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
# Issue where constant propagation can't handle
|
|
116
|
-
# multiple return values from Apply properly
|
|
117
|
-
@dialect.register(key="constprop")
|
|
118
|
-
class ConstPropWire(interp.MethodTable):
|
|
119
|
-
|
|
120
|
-
@interp.impl(Apply)
|
|
121
|
-
@interp.impl(Broadcast)
|
|
122
|
-
def apply(self, interp, frame, stmt: Apply):
|
|
123
|
-
|
|
124
|
-
return frame.get_values(stmt.inputs)
|
|
125
|
-
|
|
126
|
-
|
|
127
110
|
@wraps(Unwrap)
def unwrap(qubit: Qubit) -> Wire:
    """Python-callable stub lowered to the ``Unwrap`` statement via ``@wraps``."""
    ...
|
|
129
112
|
|
bloqade/stim/__init__.py
CHANGED
|
@@ -1,6 +1,41 @@
|
|
|
1
|
+
from . import emit as emit, parse as parse, dialects as dialects
|
|
1
2
|
from .groups import main as main
|
|
2
|
-
from ._wrappers import
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
3
|
+
from ._wrappers import (
|
|
4
|
+
h as h,
|
|
5
|
+
s as s,
|
|
6
|
+
x as x,
|
|
7
|
+
y as y,
|
|
8
|
+
z as z,
|
|
9
|
+
cx as cx,
|
|
10
|
+
cy as cy,
|
|
11
|
+
cz as cz,
|
|
12
|
+
mx as mx,
|
|
13
|
+
my as my,
|
|
14
|
+
mz as mz,
|
|
15
|
+
rx as rx,
|
|
16
|
+
ry as ry,
|
|
17
|
+
rz as rz,
|
|
18
|
+
mpp as mpp,
|
|
19
|
+
mxx as mxx,
|
|
20
|
+
myy as myy,
|
|
21
|
+
mzz as mzz,
|
|
22
|
+
rec as rec,
|
|
23
|
+
spp as spp,
|
|
24
|
+
swap as swap,
|
|
25
|
+
tick as tick,
|
|
26
|
+
sqrt_x as sqrt_x,
|
|
27
|
+
sqrt_y as sqrt_y,
|
|
28
|
+
sqrt_z as sqrt_z,
|
|
29
|
+
x_error as x_error,
|
|
30
|
+
y_error as y_error,
|
|
31
|
+
z_error as z_error,
|
|
32
|
+
detector as detector,
|
|
33
|
+
identity as identity,
|
|
34
|
+
depolarize1 as depolarize1,
|
|
35
|
+
depolarize2 as depolarize2,
|
|
36
|
+
pauli_string as pauli_string,
|
|
37
|
+
qubit_coords as qubit_coords,
|
|
38
|
+
pauli_channel1 as pauli_channel1,
|
|
39
|
+
pauli_channel2 as pauli_channel2,
|
|
40
|
+
observable_include as observable_include,
|
|
41
|
+
)
|
bloqade/stim/_wrappers.py
CHANGED
|
@@ -2,7 +2,7 @@ from typing import Union
|
|
|
2
2
|
|
|
3
3
|
from kirin.lowering import wraps
|
|
4
4
|
|
|
5
|
-
from .dialects import
|
|
5
|
+
from .dialects import gate, noise, collapse, auxiliary
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
# dialect:: gate
|
|
@@ -69,32 +69,38 @@ def cz(
|
|
|
69
69
|
|
|
70
70
|
## pp
|
|
71
71
|
@wraps(gate.SPP)
def spp(targets: tuple[auxiliary.PauliString, ...], dagger=False) -> None:
    """Stub lowered to ``gate.SPP``; ``dagger=True`` requests the daggered form."""
    ...
|
|
73
73
|
|
|
74
74
|
|
|
75
75
|
# dialect:: aux
|
|
76
|
-
@wraps(
|
|
77
|
-
def rec(id: int) ->
|
|
76
|
+
@wraps(auxiliary.GetRecord)
def rec(id: int) -> auxiliary.RecordResult:
    """Stub lowered to ``auxiliary.GetRecord``; returns an ``auxiliary.RecordResult``."""
    ...
|
|
78
78
|
|
|
79
79
|
|
|
80
|
-
@wraps(
|
|
80
|
+
@wraps(auxiliary.Detector)
def detector(
    coord: tuple[Union[int, float], ...], targets: tuple[auxiliary.RecordResult, ...]
) -> None:
    """Stub lowered to ``auxiliary.Detector`` with optional coordinates and record targets."""
    ...
|
|
84
84
|
|
|
85
85
|
|
|
86
|
-
@wraps(
|
|
87
|
-
def observable_include(
|
|
86
|
+
@wraps(auxiliary.ObservableInclude)
def observable_include(
    idx: int, targets: tuple[auxiliary.RecordResult, ...]
) -> None:
    """Stub lowered to ``auxiliary.ObservableInclude`` for observable ``idx``."""
    ...
|
|
88
90
|
|
|
89
91
|
|
|
90
|
-
@wraps(
|
|
92
|
+
@wraps(auxiliary.Tick)
def tick() -> None:
    """Stub lowered to ``auxiliary.Tick``."""
    ...
|
|
92
94
|
|
|
93
95
|
|
|
94
|
-
@wraps(
|
|
96
|
+
@wraps(auxiliary.NewPauliString)
def pauli_string(
    string: tuple[str, ...], flipped: tuple[bool, ...], targets: tuple[int, ...]
) -> auxiliary.PauliString:
    """Stub lowered to ``auxiliary.NewPauliString``; returns an ``auxiliary.PauliString``."""
    ...
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@wraps(auxiliary.QubitCoordinates)
def qubit_coords(coord: tuple[Union[int, float], ...], target: int) -> None:
    """Stub lowered to ``auxiliary.QubitCoordinates`` (attach ``coord`` to ``target``)."""
    ...
|
|
98
104
|
|
|
99
105
|
|
|
100
106
|
# dialect:: collapse
|
|
@@ -123,7 +129,7 @@ def mxx(p: float, targets: tuple[int, ...]) -> None: ...
|
|
|
123
129
|
|
|
124
130
|
|
|
125
131
|
@wraps(collapse.PPMeasurement)
def mpp(p: float, targets: tuple[auxiliary.PauliString, ...]) -> None:
    """Stub lowered to ``collapse.PPMeasurement`` with parameter ``p`` over Pauli-string targets."""
    ...
|
|
127
133
|
|
|
128
134
|
|
|
129
135
|
@wraps(collapse.RZ)
|
|
@@ -1,5 +1 @@
|
|
|
1
|
-
from . import
|
|
2
|
-
from .aux.stmts import * # noqa F403
|
|
3
|
-
from .gate.stmts import * # noqa F403
|
|
4
|
-
from .noise.stmts import * # noqa F403
|
|
5
|
-
from .collapse.stmts import * # noqa F403
|
|
1
|
+
from . import gate as gate, noise as noise, collapse as collapse, auxiliary as auxiliary
|
|
@@ -1,6 +1,18 @@
|
|
|
1
1
|
from . import lowering as lowering
|
|
2
2
|
from .emit import EmitStimAuxMethods as EmitStimAuxMethods
|
|
3
|
-
from .stmts import
|
|
3
|
+
from .stmts import (
|
|
4
|
+
Neg as Neg,
|
|
5
|
+
Tick as Tick,
|
|
6
|
+
ConstInt as ConstInt,
|
|
7
|
+
ConstStr as ConstStr,
|
|
8
|
+
Detector as Detector,
|
|
9
|
+
ConstBool as ConstBool,
|
|
10
|
+
GetRecord as GetRecord,
|
|
11
|
+
ConstFloat as ConstFloat,
|
|
12
|
+
NewPauliString as NewPauliString,
|
|
13
|
+
QubitCoordinates as QubitCoordinates,
|
|
14
|
+
ObservableInclude as ObservableInclude,
|
|
15
|
+
)
|
|
4
16
|
from .types import (
|
|
5
17
|
RecordType as RecordType,
|
|
6
18
|
PauliString as PauliString,
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from kirin.emit import EmitStrFrame
|
|
2
2
|
from kirin.interp import MethodTable, impl
|
|
3
3
|
|
|
4
|
-
from bloqade.stim.emit.
|
|
4
|
+
from bloqade.stim.emit.stim_str import EmitStimMain
|
|
5
5
|
|
|
6
6
|
from . import stmts
|
|
7
7
|
from ._dialect import dialect
|
|
@@ -69,8 +69,10 @@ class EmitStimAuxMethods(MethodTable):
|
|
|
69
69
|
|
|
70
70
|
coord_str: str = ", ".join(coords)
|
|
71
71
|
target_str: str = " ".join(targets)
|
|
72
|
-
|
|
73
|
-
|
|
72
|
+
if len(coords):
|
|
73
|
+
emit.writeln(frame, f"DETECTOR({coord_str}) {target_str}")
|
|
74
|
+
else:
|
|
75
|
+
emit.writeln(frame, f"DETECTOR {target_str}")
|
|
74
76
|
return ()
|
|
75
77
|
|
|
76
78
|
@impl(stmts.ObservableInclude)
|
|
@@ -100,3 +102,16 @@ class EmitStimAuxMethods(MethodTable):
|
|
|
100
102
|
)
|
|
101
103
|
|
|
102
104
|
return (out,)
|
|
105
|
+
|
|
106
|
+
@impl(stmts.QubitCoordinates)
|
|
107
|
+
def qubit_coordinates(
|
|
108
|
+
self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.QubitCoordinates
|
|
109
|
+
):
|
|
110
|
+
|
|
111
|
+
coords: tuple[str, ...] = frame.get_values(stmt.coord)
|
|
112
|
+
target: str = frame.get(stmt.target)
|
|
113
|
+
|
|
114
|
+
coord_str: str = ", ".join(coords)
|
|
115
|
+
emit.writeln(frame, f"QUBIT_COORDS({coord_str}) {target}")
|
|
116
|
+
|
|
117
|
+
return ()
|
|
@@ -45,3 +45,11 @@ class NewPauliString(ir.Statement):
|
|
|
45
45
|
flipped: tuple[ir.SSAValue, ...] = info.argument(types.Bool)
|
|
46
46
|
targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
|
|
47
47
|
result: ir.ResultValue = info.result(type=PauliStringType)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@statement(dialect=dialect)
class QubitCoordinates(ir.Statement):
    """Attach coordinates to an integer target; the stim emitter writes it as
    ``QUBIT_COORDS(...) target``.
    """

    name = "qubit_coordinates"
    traits = frozenset({lowering.FromPythonCall()})
    # coordinate values, one SSA value per coordinate
    coord: tuple[ir.SSAValue, ...] = info.argument(PyNum)
    # integer target the coordinates are attached to
    target: ir.SSAValue = info.argument(types.Int)
|
|
@@ -1,3 +1,14 @@
|
|
|
1
|
-
from .
|
|
2
|
-
|
|
1
|
+
from .stmts import (
|
|
2
|
+
MX as MX,
|
|
3
|
+
MY as MY,
|
|
4
|
+
MZ as MZ,
|
|
5
|
+
RX as RX,
|
|
6
|
+
RY as RY,
|
|
7
|
+
RZ as RZ,
|
|
8
|
+
MXX as MXX,
|
|
9
|
+
MYY as MYY,
|
|
10
|
+
MZZ as MZZ,
|
|
11
|
+
PPMeasurement as PPMeasurement,
|
|
12
|
+
)
|
|
3
13
|
from ._dialect import dialect as dialect
|
|
14
|
+
from .emit_str import EmitStimCollapseMethods as EmitStimCollapseMethods
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from kirin.emit import EmitStrFrame
|
|
2
2
|
from kirin.interp import MethodTable, impl
|
|
3
3
|
|
|
4
|
-
from bloqade.stim.emit.
|
|
4
|
+
from bloqade.stim.emit.stim_str import EmitStimMain
|
|
5
5
|
|
|
6
6
|
from . import stmts
|
|
7
7
|
from ._dialect import dialect
|
|
@@ -60,7 +60,9 @@ class EmitStimCollapseMethods(MethodTable):
|
|
|
60
60
|
self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PPMeasurement
|
|
61
61
|
):
|
|
62
62
|
probability: str = frame.get(stmt.p)
|
|
63
|
-
targets: tuple[str, ...] =
|
|
63
|
+
targets: tuple[str, ...] = tuple(
|
|
64
|
+
targ.upper() for targ in frame.get_values(stmt.targets)
|
|
65
|
+
)
|
|
64
66
|
|
|
65
67
|
out = f"MPP({probability}) " + " ".join(targets)
|
|
66
68
|
emit.writeln(frame, out)
|
|
@@ -1,3 +1,18 @@
|
|
|
1
1
|
from .emit import EmitStimGateMethods as EmitStimGateMethods
|
|
2
|
-
from .stmts import
|
|
2
|
+
from .stmts import (
|
|
3
|
+
CX as CX,
|
|
4
|
+
CY as CY,
|
|
5
|
+
CZ as CZ,
|
|
6
|
+
SPP as SPP,
|
|
7
|
+
H as H,
|
|
8
|
+
S as S,
|
|
9
|
+
X as X,
|
|
10
|
+
Y as Y,
|
|
11
|
+
Z as Z,
|
|
12
|
+
Swap as Swap,
|
|
13
|
+
SqrtX as SqrtX,
|
|
14
|
+
SqrtY as SqrtY,
|
|
15
|
+
SqrtZ as SqrtZ,
|
|
16
|
+
Identity as Identity,
|
|
17
|
+
)
|
|
3
18
|
from ._dialect import dialect as dialect
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from kirin.emit import EmitStrFrame
|
|
2
2
|
from kirin.interp import MethodTable, impl
|
|
3
3
|
|
|
4
|
-
from bloqade.stim.emit.
|
|
4
|
+
from bloqade.stim.emit.stim_str import EmitStimMain
|
|
5
5
|
|
|
6
6
|
from . import stmts
|
|
7
7
|
from ._dialect import dialect
|
|
@@ -12,6 +12,7 @@ from .stmts.base import SingleQubitGate, ControlledTwoQubitGate
|
|
|
12
12
|
class EmitStimGateMethods(MethodTable):
|
|
13
13
|
|
|
14
14
|
gate_1q_map: dict[str, tuple[str, str]] = {
|
|
15
|
+
stmts.Identity.name: ("I", "I"),
|
|
15
16
|
stmts.X.name: ("X", "X"),
|
|
16
17
|
stmts.Y.name: ("Y", "Y"),
|
|
17
18
|
stmts.Z.name: ("Z", "Z"),
|
|
@@ -22,6 +23,7 @@ class EmitStimGateMethods(MethodTable):
|
|
|
22
23
|
stmts.SqrtZ.name: ("SQRT_Z", "SQRT_Z_DAG"),
|
|
23
24
|
}
|
|
24
25
|
|
|
26
|
+
@impl(stmts.Identity)
|
|
25
27
|
@impl(stmts.X)
|
|
26
28
|
@impl(stmts.Y)
|
|
27
29
|
@impl(stmts.Z)
|
|
@@ -80,8 +82,13 @@ class EmitStimGateMethods(MethodTable):
|
|
|
80
82
|
@impl(stmts.SPP)
|
|
81
83
|
def spp(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.SPP):
|
|
82
84
|
|
|
83
|
-
targets: tuple[str, ...] =
|
|
84
|
-
|
|
85
|
+
targets: tuple[str, ...] = tuple(
|
|
86
|
+
targ.upper() for targ in frame.get_values(stmt.targets)
|
|
87
|
+
)
|
|
88
|
+
if stmt.dagger:
|
|
89
|
+
res = "SPP_DAG " + " ".join(targets)
|
|
90
|
+
else:
|
|
91
|
+
res = "SPP " + " ".join(targets)
|
|
85
92
|
emit.writeln(frame, res)
|
|
86
93
|
|
|
87
94
|
return ()
|