bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bloqade/analysis/address/__init__.py +8 -4
- bloqade/analysis/address/analysis.py +123 -33
- bloqade/analysis/address/impls.py +293 -90
- bloqade/analysis/address/lattice.py +209 -24
- bloqade/analysis/fidelity/analysis.py +11 -23
- bloqade/analysis/measure_id/__init__.py +4 -1
- bloqade/analysis/measure_id/analysis.py +29 -20
- bloqade/analysis/measure_id/impls.py +72 -31
- bloqade/annotate/__init__.py +6 -0
- bloqade/annotate/_dialect.py +3 -0
- bloqade/annotate/_interface.py +22 -0
- bloqade/annotate/stmts.py +29 -0
- bloqade/annotate/types.py +13 -0
- bloqade/cirq_utils/__init__.py +4 -2
- bloqade/cirq_utils/emit/__init__.py +3 -0
- bloqade/cirq_utils/emit/base.py +246 -0
- bloqade/cirq_utils/emit/gate.py +104 -0
- bloqade/cirq_utils/emit/noise.py +90 -0
- bloqade/cirq_utils/emit/qubit.py +35 -0
- bloqade/cirq_utils/lowering.py +660 -0
- bloqade/cirq_utils/noise/__init__.py +0 -2
- bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
- bloqade/cirq_utils/noise/model.py +151 -191
- bloqade/cirq_utils/noise/transform.py +2 -2
- bloqade/cirq_utils/parallelize.py +9 -6
- bloqade/gemini/__init__.py +1 -0
- bloqade/gemini/analysis/__init__.py +3 -0
- bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
- bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
- bloqade/gemini/analysis/logical_validation/impls.py +101 -0
- bloqade/gemini/groups.py +67 -0
- bloqade/native/__init__.py +23 -0
- bloqade/native/_prelude.py +45 -0
- bloqade/native/dialects/__init__.py +0 -0
- bloqade/native/dialects/gate/__init__.py +2 -0
- bloqade/native/dialects/gate/_dialect.py +3 -0
- bloqade/native/dialects/gate/_interface.py +32 -0
- bloqade/native/dialects/gate/stmts.py +31 -0
- bloqade/native/stdlib/__init__.py +0 -0
- bloqade/native/stdlib/broadcast.py +246 -0
- bloqade/native/stdlib/simple.py +220 -0
- bloqade/native/upstream/__init__.py +4 -0
- bloqade/native/upstream/squin2native.py +79 -0
- bloqade/pyqrack/__init__.py +2 -2
- bloqade/pyqrack/base.py +7 -1
- bloqade/pyqrack/device.py +190 -4
- bloqade/pyqrack/native.py +49 -0
- bloqade/pyqrack/reg.py +6 -6
- bloqade/pyqrack/squin/gate/__init__.py +1 -0
- bloqade/pyqrack/squin/gate/gate.py +136 -0
- bloqade/pyqrack/squin/noise/native.py +120 -54
- bloqade/pyqrack/squin/qubit.py +39 -36
- bloqade/pyqrack/target.py +5 -4
- bloqade/pyqrack/task.py +114 -7
- bloqade/qasm2/_qasm_loading.py +3 -3
- bloqade/qasm2/dialects/core/address.py +21 -12
- bloqade/qasm2/dialects/expr/_emit.py +19 -8
- bloqade/qasm2/dialects/expr/stmts.py +7 -7
- bloqade/qasm2/dialects/noise/fidelity.py +4 -8
- bloqade/qasm2/dialects/noise/model.py +2 -1
- bloqade/qasm2/emit/base.py +16 -11
- bloqade/qasm2/emit/gate.py +11 -8
- bloqade/qasm2/emit/main.py +103 -3
- bloqade/qasm2/emit/target.py +9 -5
- bloqade/qasm2/groups.py +3 -2
- bloqade/qasm2/parse/lowering.py +0 -1
- bloqade/qasm2/passes/fold.py +14 -73
- bloqade/qasm2/passes/glob.py +2 -2
- bloqade/qasm2/passes/noise.py +1 -1
- bloqade/qasm2/passes/parallel.py +7 -5
- bloqade/qasm2/rewrite/__init__.py +0 -1
- bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
- bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
- bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
- bloqade/qasm2/rewrite/register.py +2 -2
- bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
- bloqade/qbraid/lowering.py +1 -0
- bloqade/qbraid/schema.py +2 -2
- bloqade/qubit/__init__.py +12 -0
- bloqade/qubit/_dialect.py +3 -0
- bloqade/qubit/_interface.py +49 -0
- bloqade/qubit/_prelude.py +45 -0
- bloqade/qubit/analysis/__init__.py +1 -0
- bloqade/qubit/analysis/address_impl.py +40 -0
- bloqade/qubit/stdlib/__init__.py +2 -0
- bloqade/qubit/stdlib/_new.py +34 -0
- bloqade/qubit/stdlib/broadcast.py +62 -0
- bloqade/qubit/stdlib/simple.py +59 -0
- bloqade/qubit/stmts.py +60 -0
- bloqade/rewrite/passes/__init__.py +6 -0
- bloqade/rewrite/passes/aggressive_unroll.py +103 -0
- bloqade/rewrite/passes/callgraph.py +116 -0
- bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
- bloqade/rewrite/rules/split_ifs.py +18 -1
- bloqade/squin/__init__.py +47 -14
- bloqade/squin/analysis/__init__.py +0 -1
- bloqade/squin/analysis/schedule.py +10 -11
- bloqade/squin/gate/__init__.py +2 -0
- bloqade/squin/gate/_dialect.py +3 -0
- bloqade/squin/gate/_interface.py +98 -0
- bloqade/squin/gate/stmts.py +125 -0
- bloqade/squin/groups.py +5 -22
- bloqade/squin/noise/__init__.py +1 -10
- bloqade/squin/noise/_dialect.py +1 -1
- bloqade/squin/noise/_interface.py +45 -0
- bloqade/squin/noise/stmts.py +66 -28
- bloqade/squin/rewrite/U3_to_clifford.py +70 -51
- bloqade/squin/rewrite/__init__.py +0 -2
- bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
- bloqade/squin/rewrite/wrap_analysis.py +4 -35
- bloqade/squin/stdlib/__init__.py +0 -0
- bloqade/squin/stdlib/broadcast/__init__.py +34 -0
- bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
- bloqade/squin/stdlib/broadcast/gate.py +260 -0
- bloqade/squin/stdlib/broadcast/noise.py +144 -0
- bloqade/squin/stdlib/simple/__init__.py +33 -0
- bloqade/squin/stdlib/simple/gate.py +242 -0
- bloqade/squin/stdlib/simple/noise.py +126 -0
- bloqade/stim/__init__.py +1 -0
- bloqade/stim/_wrappers.py +6 -0
- bloqade/stim/dialects/auxiliary/emit.py +19 -18
- bloqade/stim/dialects/collapse/emit_str.py +7 -8
- bloqade/stim/dialects/gate/emit.py +9 -10
- bloqade/stim/dialects/noise/emit.py +17 -13
- bloqade/stim/dialects/noise/stmts.py +5 -3
- bloqade/stim/emit/__init__.py +1 -0
- bloqade/stim/emit/impls.py +16 -0
- bloqade/stim/emit/stim_str.py +48 -31
- bloqade/stim/groups.py +12 -2
- bloqade/stim/parse/lowering.py +14 -17
- bloqade/stim/passes/__init__.py +3 -1
- bloqade/stim/passes/flatten.py +26 -0
- bloqade/stim/passes/simplify_ifs.py +16 -2
- bloqade/stim/passes/squin_to_stim.py +18 -60
- bloqade/stim/rewrite/__init__.py +3 -4
- bloqade/stim/rewrite/get_record_util.py +24 -0
- bloqade/stim/rewrite/ifs_to_stim.py +29 -31
- bloqade/stim/rewrite/qubit_to_stim.py +90 -41
- bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
- bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
- bloqade/stim/rewrite/squin_measure.py +11 -79
- bloqade/stim/rewrite/squin_noise.py +134 -108
- bloqade/stim/rewrite/util.py +5 -192
- bloqade/test_utils.py +1 -1
- bloqade/types.py +10 -0
- bloqade/validation/__init__.py +2 -0
- bloqade/validation/analysis/__init__.py +5 -0
- bloqade/validation/analysis/analysis.py +41 -0
- bloqade/validation/analysis/lattice.py +58 -0
- bloqade/validation/kernel_validation.py +77 -0
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
- bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
- bloqade/pyqrack/squin/op.py +0 -166
- bloqade/pyqrack/squin/runtime.py +0 -535
- bloqade/pyqrack/squin/wire.py +0 -51
- bloqade/rewrite/rules/flatten_ilist.py +0 -51
- bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
- bloqade/squin/_typeinfer.py +0 -20
- bloqade/squin/analysis/address_impl.py +0 -71
- bloqade/squin/analysis/nsites/__init__.py +0 -9
- bloqade/squin/analysis/nsites/analysis.py +0 -50
- bloqade/squin/analysis/nsites/impls.py +0 -92
- bloqade/squin/analysis/nsites/lattice.py +0 -49
- bloqade/squin/cirq/__init__.py +0 -265
- bloqade/squin/cirq/emit/emit_circuit.py +0 -109
- bloqade/squin/cirq/emit/noise.py +0 -49
- bloqade/squin/cirq/emit/op.py +0 -125
- bloqade/squin/cirq/emit/qubit.py +0 -60
- bloqade/squin/cirq/emit/runtime.py +0 -242
- bloqade/squin/cirq/lowering.py +0 -440
- bloqade/squin/lowering.py +0 -54
- bloqade/squin/noise/_wrapper.py +0 -40
- bloqade/squin/noise/rewrite.py +0 -111
- bloqade/squin/op/__init__.py +0 -41
- bloqade/squin/op/_dialect.py +0 -3
- bloqade/squin/op/_wrapper.py +0 -121
- bloqade/squin/op/number.py +0 -5
- bloqade/squin/op/rewrite.py +0 -46
- bloqade/squin/op/stdlib.py +0 -62
- bloqade/squin/op/stmts.py +0 -276
- bloqade/squin/op/traits.py +0 -43
- bloqade/squin/op/types.py +0 -26
- bloqade/squin/qubit.py +0 -184
- bloqade/squin/rewrite/canonicalize.py +0 -60
- bloqade/squin/rewrite/desugar.py +0 -124
- bloqade/squin/types.py +0 -8
- bloqade/squin/wire.py +0 -201
- bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
- bloqade/stim/rewrite/wire_to_stim.py +0 -57
- bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
- {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
bloqade/stim/rewrite/util.py
CHANGED
|
@@ -1,35 +1,8 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
from kirin import ir, interp
|
|
4
|
-
from kirin.analysis import const
|
|
1
|
+
from kirin import ir
|
|
5
2
|
from kirin.dialects import py
|
|
6
|
-
from kirin.rewrite.abc import RewriteResult
|
|
7
3
|
|
|
8
|
-
from bloqade.squin import op, wire, noise as squin_noise, qubit
|
|
9
4
|
from bloqade.squin.rewrite import AddressAttribute
|
|
10
|
-
from bloqade.
|
|
11
|
-
from bloqade.analysis.address import AddressReg, AddressWire, AddressQubit, AddressTuple
|
|
12
|
-
|
|
13
|
-
SQUIN_STIM_OP_MAPPING = {
|
|
14
|
-
op.stmts.X: gate.X,
|
|
15
|
-
op.stmts.Y: gate.Y,
|
|
16
|
-
op.stmts.Z: gate.Z,
|
|
17
|
-
op.stmts.H: gate.H,
|
|
18
|
-
op.stmts.S: gate.S,
|
|
19
|
-
op.stmts.SqrtX: gate.SqrtX,
|
|
20
|
-
op.stmts.SqrtY: gate.SqrtY,
|
|
21
|
-
op.stmts.Identity: gate.Identity,
|
|
22
|
-
op.stmts.Reset: collapse.RZ,
|
|
23
|
-
squin_noise.stmts.QubitLoss: stim_noise.QubitLoss,
|
|
24
|
-
}
|
|
25
|
-
|
|
26
|
-
# Squin allows creation of control gates where the gate can be any operator,
|
|
27
|
-
# but Stim only supports CX, CY, and CZ as control gates.
|
|
28
|
-
SQUIN_STIM_CONTROL_GATE_MAPPING = {
|
|
29
|
-
op.stmts.X: gate.CX,
|
|
30
|
-
op.stmts.Y: gate.CY,
|
|
31
|
-
op.stmts.Z: gate.CZ,
|
|
32
|
-
}
|
|
5
|
+
from bloqade.analysis.address import AddressReg, AddressQubit
|
|
33
6
|
|
|
34
7
|
|
|
35
8
|
def create_and_insert_qubit_idx_stmt(
|
|
@@ -46,177 +19,17 @@ def insert_qubit_idx_from_address(
|
|
|
46
19
|
"""
|
|
47
20
|
Extract qubit indices from an AddressAttribute and insert them into the SSA form.
|
|
48
21
|
"""
|
|
49
|
-
address_data = address.address
|
|
50
22
|
qubit_idx_ssas = []
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
for address_qubit in address_data.data:
|
|
54
|
-
if not isinstance(address_qubit, AddressQubit):
|
|
55
|
-
return
|
|
23
|
+
if isinstance(address_data := address.address, AddressReg):
|
|
24
|
+
for qubit_idx in address_data.qubits:
|
|
56
25
|
create_and_insert_qubit_idx_stmt(
|
|
57
|
-
|
|
58
|
-
)
|
|
59
|
-
elif isinstance(address_data, AddressReg):
|
|
60
|
-
for qubit_idx in address_data.data:
|
|
61
|
-
create_and_insert_qubit_idx_stmt(
|
|
62
|
-
qubit_idx, stmt_to_insert_before, qubit_idx_ssas
|
|
26
|
+
qubit_idx.data, stmt_to_insert_before, qubit_idx_ssas
|
|
63
27
|
)
|
|
64
28
|
elif isinstance(address_data, AddressQubit):
|
|
65
29
|
create_and_insert_qubit_idx_stmt(
|
|
66
30
|
address_data.data, stmt_to_insert_before, qubit_idx_ssas
|
|
67
31
|
)
|
|
68
|
-
elif isinstance(address_data, AddressWire):
|
|
69
|
-
address_qubit = address_data.origin_qubit
|
|
70
|
-
create_and_insert_qubit_idx_stmt(
|
|
71
|
-
address_qubit.data, stmt_to_insert_before, qubit_idx_ssas
|
|
72
|
-
)
|
|
73
32
|
else:
|
|
74
33
|
return
|
|
75
34
|
|
|
76
35
|
return tuple(qubit_idx_ssas)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def insert_qubit_idx_from_wire_ssa(
|
|
80
|
-
wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement
|
|
81
|
-
) -> tuple[ir.SSAValue, ...] | None:
|
|
82
|
-
"""
|
|
83
|
-
Extract qubit indices from wire SSA values and insert them into the SSA form.
|
|
84
|
-
"""
|
|
85
|
-
qubit_idx_ssas = []
|
|
86
|
-
for wire_ssa in wire_ssas:
|
|
87
|
-
address_attribute = wire_ssa.hints.get("address")
|
|
88
|
-
if address_attribute is None:
|
|
89
|
-
return
|
|
90
|
-
assert isinstance(address_attribute, AddressAttribute)
|
|
91
|
-
wire_address = address_attribute.address
|
|
92
|
-
assert isinstance(wire_address, AddressWire)
|
|
93
|
-
qubit_idx = wire_address.origin_qubit.data
|
|
94
|
-
qubit_idx_stmt = py.Constant(qubit_idx)
|
|
95
|
-
qubit_idx_ssas.append(qubit_idx_stmt.result)
|
|
96
|
-
qubit_idx_stmt.insert_before(stmt_to_insert_before)
|
|
97
|
-
|
|
98
|
-
return tuple(qubit_idx_ssas)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def insert_qubit_idx_after_apply(
|
|
102
|
-
stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast,
|
|
103
|
-
) -> tuple[ir.SSAValue, ...] | None:
|
|
104
|
-
"""
|
|
105
|
-
Extract qubit indices from Apply or Broadcast statements.
|
|
106
|
-
"""
|
|
107
|
-
if isinstance(stmt, (qubit.Apply, qubit.Broadcast)):
|
|
108
|
-
qubits = stmt.qubits
|
|
109
|
-
address_attribute = qubits.hints.get("address")
|
|
110
|
-
if address_attribute is None:
|
|
111
|
-
return
|
|
112
|
-
assert isinstance(address_attribute, AddressAttribute)
|
|
113
|
-
return insert_qubit_idx_from_address(
|
|
114
|
-
address=address_attribute, stmt_to_insert_before=stmt
|
|
115
|
-
)
|
|
116
|
-
elif isinstance(stmt, (wire.Apply, wire.Broadcast)):
|
|
117
|
-
wire_ssas = stmt.inputs
|
|
118
|
-
return insert_qubit_idx_from_wire_ssa(
|
|
119
|
-
wire_ssas=wire_ssas, stmt_to_insert_before=stmt
|
|
120
|
-
)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def rewrite_Control(
|
|
124
|
-
stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast,
|
|
125
|
-
) -> RewriteResult:
|
|
126
|
-
"""
|
|
127
|
-
Handle control gates for Apply and Broadcast statements.
|
|
128
|
-
"""
|
|
129
|
-
ctrl_op = stmt_with_ctrl.operator.owner
|
|
130
|
-
assert isinstance(ctrl_op, op.stmts.Control)
|
|
131
|
-
|
|
132
|
-
ctrl_op_target_gate = ctrl_op.op.owner
|
|
133
|
-
assert isinstance(ctrl_op_target_gate, op.stmts.Operator)
|
|
134
|
-
|
|
135
|
-
qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl)
|
|
136
|
-
if qubit_idx_ssas is None:
|
|
137
|
-
return RewriteResult()
|
|
138
|
-
|
|
139
|
-
# Separate control and target qubits
|
|
140
|
-
target_qubits = []
|
|
141
|
-
ctrl_qubits = []
|
|
142
|
-
for i in range(len(qubit_idx_ssas)):
|
|
143
|
-
if (i % 2) == 0:
|
|
144
|
-
ctrl_qubits.append(qubit_idx_ssas[i])
|
|
145
|
-
else:
|
|
146
|
-
target_qubits.append(qubit_idx_ssas[i])
|
|
147
|
-
|
|
148
|
-
target_qubits = tuple(target_qubits)
|
|
149
|
-
ctrl_qubits = tuple(ctrl_qubits)
|
|
150
|
-
|
|
151
|
-
stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate))
|
|
152
|
-
if stim_gate is None:
|
|
153
|
-
return RewriteResult()
|
|
154
|
-
|
|
155
|
-
stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)
|
|
156
|
-
|
|
157
|
-
if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
|
|
158
|
-
create_wire_passthrough(stmt_with_ctrl)
|
|
159
|
-
|
|
160
|
-
stmt_with_ctrl.replace_by(stim_stmt)
|
|
161
|
-
|
|
162
|
-
return RewriteResult(has_done_something=True)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def rewrite_QubitLoss(
|
|
166
|
-
stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
|
|
167
|
-
) -> RewriteResult:
|
|
168
|
-
"""
|
|
169
|
-
Rewrite QubitLoss statements to Stim's TrivialError.
|
|
170
|
-
"""
|
|
171
|
-
|
|
172
|
-
squin_loss_op = stmt.operator.owner
|
|
173
|
-
assert isinstance(squin_loss_op, squin_noise.stmts.QubitLoss)
|
|
174
|
-
|
|
175
|
-
qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt)
|
|
176
|
-
if qubit_idx_ssas is None:
|
|
177
|
-
return RewriteResult()
|
|
178
|
-
|
|
179
|
-
stim_loss_stmt = stim_noise.QubitLoss(
|
|
180
|
-
targets=qubit_idx_ssas,
|
|
181
|
-
probs=(squin_loss_op.p,),
|
|
182
|
-
)
|
|
183
|
-
|
|
184
|
-
if isinstance(stmt, (wire.Apply, wire.Broadcast)):
|
|
185
|
-
create_wire_passthrough(stmt)
|
|
186
|
-
|
|
187
|
-
stmt.replace_by(stim_loss_stmt)
|
|
188
|
-
|
|
189
|
-
return RewriteResult(has_done_something=True)
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
def create_wire_passthrough(stmt: wire.Apply | wire.Broadcast) -> None:
|
|
193
|
-
|
|
194
|
-
for input_wire, output_wire in zip(stmt.inputs, stmt.results):
|
|
195
|
-
# have to "reroute" the input of these statements to directly plug in
|
|
196
|
-
# to subsequent statements, remove dependency on the current statement
|
|
197
|
-
output_wire.replace_by(input_wire)
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
def is_measure_result_used(
|
|
201
|
-
stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure,
|
|
202
|
-
) -> bool:
|
|
203
|
-
"""
|
|
204
|
-
Check if the result of a measure statement is used in the program.
|
|
205
|
-
"""
|
|
206
|
-
return bool(stmt.result.uses)
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
T = TypeVar("T")
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
def get_const_value(typ: type[T], value: ir.SSAValue) -> T:
|
|
213
|
-
if isinstance(hint := value.hints.get("const"), const.Value):
|
|
214
|
-
data = hint.data
|
|
215
|
-
if isinstance(data, typ):
|
|
216
|
-
return hint.data
|
|
217
|
-
raise interp.InterpreterError(
|
|
218
|
-
f"Expected constant value <type = {typ}>, got {data}"
|
|
219
|
-
)
|
|
220
|
-
raise interp.InterpreterError(
|
|
221
|
-
f"Expected constant value <type = {typ}>, got {value}"
|
|
222
|
-
)
|
bloqade/test_utils.py
CHANGED
|
@@ -25,7 +25,7 @@ def print_diff(node: pprint.Printable, expected_node: pprint.Printable):
|
|
|
25
25
|
|
|
26
26
|
def assert_nodes(node: ir.IRNode, expected_node: ir.IRNode):
    """Check that two IR nodes are structurally equal.

    On mismatch, a diff of the two nodes is printed via `print_diff` before
    the failure is re-raised.

    Raises:
        AssertionError: if `node` is not structurally equal to `expected_node`.
    """
    try:
        # Explicit check instead of a bare `assert`: `python -O` strips assert
        # statements, which would silently turn this helper into a no-op.
        if not node.is_structurally_equal(expected_node):
            raise AssertionError("IR nodes are not structurally equal")
    except AssertionError:
        print_diff(node, expected_node)
        # Bare `raise` re-raises the active exception with its original traceback.
        raise
|
bloqade/types.py
CHANGED
|
@@ -22,3 +22,13 @@ class Qubit(ABC):
|
|
|
22
22
|
|
|
23
23
|
QubitType = types.PyClass(Qubit)
"""Kirin type for a qubit."""


class MeasurementResult:
    """Placeholder class standing for the runtime outcome of measuring a qubit."""


MeasurementResultType = types.PyClass(MeasurementResult)
"""Kirin type for a measurement result."""
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from dataclasses import field, dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir
|
|
5
|
+
from kirin.analysis import ForwardExtra, ForwardFrame
|
|
6
|
+
|
|
7
|
+
from .lattice import ErrorType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class ValidationFrame(ForwardFrame[ErrorType]):
    """Forward-analysis frame that additionally collects validation errors."""

    # NOTE: cannot be set[Error] since that's not hashable
    errors: list[ir.ValidationError] = field(default_factory=list)
    """List of all encountered errors.

    Append a `kirin.ir.ValidationError` to this list in the method implementation
    in order for it to get picked up by the `KernelValidation` run.
    """
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class ValidationAnalysis(ForwardExtra[ValidationFrame, ErrorType], ABC):
    """Forward analysis that reports IR errors through per-dialect method tables.

    Subclass this when validation has to be implemented for a dialect that is
    shared by many groups (for example, to check that statements have a
    specific form). Implementations record problems by appending a
    `kirin.ir.ValidationError` to the frame's `errors` list.
    """

    lattice = ErrorType

    def eval_fallback(self, frame: ValidationFrame, node: ir.Statement):
        """Default handling: report the lattice top (no error) for every result."""
        no_error = self.lattice.top
        return tuple(no_error() for _result in node.results)

    def initialize_frame(
        self, node: ir.Statement, *, has_parent_access: bool = False
    ) -> ValidationFrame:
        """Build a fresh `ValidationFrame` (with an empty error list) for `node`."""
        return ValidationFrame(node, has_parent_access=has_parent_access)

    def method_self(self, method: ir.Method) -> ErrorType:
        """The method's `self` value is treated as carrying no error (lattice top)."""
        return self.lattice.top()
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from typing import final
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin.lattice import (
|
|
5
|
+
SingletonMeta,
|
|
6
|
+
BoundedLattice,
|
|
7
|
+
IsSubsetEqMixin,
|
|
8
|
+
SimpleJoinMixin,
|
|
9
|
+
SimpleMeetMixin,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class ErrorType(
    SimpleJoinMixin["ErrorType"],
    SimpleMeetMixin["ErrorType"],
    IsSubsetEqMixin["ErrorType"],
    BoundedLattice["ErrorType"],
):
    """Lattice describing the error state of a value in the validation analysis.

    The top element is `NoError`; the bottom element is `InvalidErrorType`.
    """

    @classmethod
    def top(cls) -> "ErrorType":
        """Return the top element, the `NoError` singleton."""
        return NoError()

    @classmethod
    def bottom(cls) -> "ErrorType":
        """Return the bottom element, the `InvalidErrorType` singleton."""
        return InvalidErrorType()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@final
@dataclass
class InvalidErrorType(ErrorType, metaclass=SingletonMeta):
    """Bottom of the lattice: the analysis itself could not decide.

    Seeing this value means there might be an error, but running the
    analysis did not allow us to tell.
    """
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@final
@dataclass
class Error(ErrorType):
    """Marks a value as being associated with an error in the IR."""

    message: str = ""
    """Optional error message to show in the IR.

    NOTE: this is just to show a message when printing the IR. Actual errors
    are collected by appending ir.ValidationError to the frame in the method
    implementation.
    """
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@final
@dataclass
class NoError(ErrorType, metaclass=SingletonMeta):
    """Top of the lattice: no error is associated with the value."""
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from kirin import ir, exception
|
|
5
|
+
from rich.console import Console
|
|
6
|
+
|
|
7
|
+
from .analysis import ValidationAnalysis
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ValidationErrorGroup(BaseException):
|
|
11
|
+
def __init__(self, *args: object, errors=[]) -> None:
|
|
12
|
+
super().__init__(*args)
|
|
13
|
+
self.errors = errors
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# TODO: this overrides kirin's exception handler and should be upstreamed
def exception_handler(exc_type, exc_value, exc_tb):
    """`sys.excepthook` replacement that pretty-prints a `ValidationErrorGroup`.

    For a `ValidationErrorGroup`, each contained error is printed with its
    index, its type name and its `args`. If it has a `source`, its `hint()`
    is printed as well. A final summary line follows. Every other exception
    is delegated to `kirin.exception.exception_handler`.

    Args:
        exc_type: The exception class.
        exc_value: The exception instance.
        exc_tb: The traceback object.
    """
    if not issubclass(exc_type, ValidationErrorGroup):
        return exception.exception_handler(exc_type, exc_value, exc_tb)

    # Write rich output to stderr too, so the headers, separators and summary
    # land on the same stream as the plain `print(..., file=sys.stderr)` calls.
    # Previously they went to stdout, splitting one error report in two.
    console = Console(force_terminal=True, stderr=True)
    for i, err in enumerate(exc_value.errors):
        # Render the styled header to a string so it can be printed on the
        # same line as the error's arguments.
        with console.capture() as capture:
            console.print(f"==== Error {i} ====")
            console.print(f"[bold red]{type(err).__name__}:[/bold red]", end="")
        print(capture.get(), *err.args, file=sys.stderr)
        if err.source:
            print("Source Traceback:", file=sys.stderr)
            print(err.hint(), file=sys.stderr, end="")
        console.print("=" * 40)
    console.print(
        "[bold red]Kernel validation failed:[/bold red] There were multiple errors encountered during validation, see above"
    )


# NOTE: installed at import time. Importing this module replaces the
# interpreter-wide excepthook.
sys.excepthook = exception_handler
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class KernelValidation:
    """Run a `ValidationAnalysis` on a kernel and raise any errors it collects.

    A thin wrapper: it instantiates the analysis for the method's dialects,
    runs it, and inspects the resulting `ValidationFrame.errors`.
    """

    validation_analysis_cls: type[ValidationAnalysis]
    """The analysis that you want to run in order to validate the kernel."""

    def run(self, mt: ir.Method, no_raise: bool = True) -> None:
        """Validate `mt`, raising if the analysis recorded any errors.

        Args:
            mt (ir.Method): The method to validate.
            no_raise (bool): If True (the default), the analysis is run via
                `run_no_raise`; otherwise via `run`. This only affects failures
                of the analysis itself. Validation errors recorded in the
                frame are raised either way.

        Raises:
            ir.ValidationError: The single recorded error, if there is exactly one.
            ValidationErrorGroup: If more than one error was recorded.
        """
        analysis = self.validation_analysis_cls(mt.dialects)
        execute = analysis.run_no_raise if no_raise else analysis.run
        frame, _ = execute(mt)

        collected = frame.errors
        if not collected:
            # Valid program
            return
        if len(collected) == 1:
            raise collected[0]
        raise ValidationErrorGroup(errors=collected)
|
|
@@ -1,16 +1,15 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: bloqade-circuit
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.9.1
|
|
4
4
|
Summary: The software development toolkit for neutral atom arrays.
|
|
5
5
|
Author-email: Roger-luo <rluo@quera.com>, kaihsin <khwu@quera.com>, weinbe58 <pweinberg@quera.com>, johnzl-777 <jlong@quera.com>
|
|
6
6
|
License-File: LICENSE
|
|
7
7
|
Requires-Python: >=3.10
|
|
8
|
-
Requires-Dist: kirin-toolchain~=0.
|
|
8
|
+
Requires-Dist: kirin-toolchain~=0.21.0
|
|
9
9
|
Requires-Dist: numpy>=1.22.0
|
|
10
10
|
Requires-Dist: pandas>=2.2.3
|
|
11
11
|
Requires-Dist: pydantic<2.11.0,>=1.3.0
|
|
12
|
-
Requires-Dist: pyqrack-cpu
|
|
13
|
-
Requires-Dist: pyqrack<1.41,>=1.38.2; sys_platform == 'darwin'
|
|
12
|
+
Requires-Dist: pyqrack-cpu~=1.69.1
|
|
14
13
|
Requires-Dist: rich>=13.9.4
|
|
15
14
|
Requires-Dist: scipy>=1.13.1
|
|
16
15
|
Provides-Extra: cirq
|
|
@@ -18,9 +17,9 @@ Requires-Dist: cirq-core>=1.4.1; extra == 'cirq'
|
|
|
18
17
|
Requires-Dist: cirq-core[contrib]>=1.4.1; extra == 'cirq'
|
|
19
18
|
Requires-Dist: qpsolvers[clarabel]>=4.7.0; extra == 'cirq'
|
|
20
19
|
Provides-Extra: pyqrack-cuda
|
|
21
|
-
Requires-Dist: pyqrack-cuda
|
|
20
|
+
Requires-Dist: pyqrack-cuda~=1.69.1; extra == 'pyqrack-cuda'
|
|
22
21
|
Provides-Extra: pyqrack-opencl
|
|
23
|
-
Requires-Dist: pyqrack
|
|
22
|
+
Requires-Dist: pyqrack~=1.69.1; extra == 'pyqrack-opencl'
|
|
24
23
|
Provides-Extra: qasm2
|
|
25
24
|
Requires-Dist: lark>=1.2.2; extra == 'qasm2'
|
|
26
25
|
Provides-Extra: qbraid
|