bloqade-circuit 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (79)
  1. bloqade/analysis/address/impls.py +14 -0
  2. bloqade/noise/fidelity.py +3 -3
  3. bloqade/noise/native/_dialect.py +1 -1
  4. bloqade/noise/native/_wrappers.py +35 -6
  5. bloqade/noise/native/stmts.py +1 -1
  6. bloqade/pyqrack/device.py +1 -3
  7. bloqade/pyqrack/qasm2/core.py +4 -1
  8. bloqade/pyqrack/squin/qubit.py +16 -9
  9. bloqade/pyqrack/squin/wire.py +22 -4
  10. bloqade/pyqrack/task.py +13 -5
  11. bloqade/qasm2/__init__.py +1 -0
  12. bloqade/qasm2/_qasm_loading.py +151 -0
  13. bloqade/qasm2/dialects/core/__init__.py +9 -1
  14. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  15. bloqade/qasm2/dialects/noise.py +33 -1
  16. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  17. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  18. bloqade/qasm2/emit/impls/__init__.py +1 -0
  19. bloqade/qasm2/emit/impls/noise_native.py +89 -0
  20. bloqade/qasm2/emit/main.py +21 -0
  21. bloqade/qasm2/emit/target.py +20 -5
  22. bloqade/qasm2/groups.py +2 -0
  23. bloqade/qasm2/parse/__init__.py +7 -4
  24. bloqade/qasm2/parse/lowering.py +20 -130
  25. bloqade/qasm2/parse/qasm2.lark +1 -1
  26. bloqade/qasm2/passes/__init__.py +1 -0
  27. bloqade/qasm2/passes/fold.py +6 -0
  28. bloqade/qasm2/passes/noise.py +22 -2
  29. bloqade/qasm2/passes/parallel.py +9 -0
  30. bloqade/qasm2/passes/unroll_if.py +25 -0
  31. bloqade/qasm2/rewrite/__init__.py +1 -0
  32. bloqade/qasm2/rewrite/desugar.py +3 -2
  33. bloqade/qasm2/rewrite/heuristic_noise.py +1 -9
  34. bloqade/qasm2/rewrite/native_gates.py +67 -4
  35. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  36. bloqade/squin/analysis/nsites/__init__.py +1 -0
  37. bloqade/squin/analysis/nsites/impls.py +25 -1
  38. bloqade/squin/noise/__init__.py +7 -26
  39. bloqade/squin/noise/_wrapper.py +25 -0
  40. bloqade/squin/op/__init__.py +33 -159
  41. bloqade/squin/op/_wrapper.py +101 -0
  42. bloqade/squin/op/stdlib.py +62 -0
  43. bloqade/squin/passes/__init__.py +1 -0
  44. bloqade/squin/passes/stim.py +68 -0
  45. bloqade/squin/rewrite/__init__.py +11 -0
  46. bloqade/squin/rewrite/qubit_to_stim.py +84 -0
  47. bloqade/squin/rewrite/squin_measure.py +98 -0
  48. bloqade/squin/rewrite/stim_rewrite_util.py +158 -0
  49. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  50. bloqade/squin/rewrite/wire_to_stim.py +73 -0
  51. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  52. bloqade/squin/wire.py +1 -13
  53. bloqade/stim/__init__.py +39 -5
  54. bloqade/stim/_wrappers.py +14 -12
  55. bloqade/stim/dialects/__init__.py +1 -5
  56. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +12 -1
  57. bloqade/stim/dialects/{aux → auxiliary}/emit.py +1 -1
  58. bloqade/stim/dialects/collapse/__init__.py +13 -2
  59. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +1 -1
  60. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  61. bloqade/stim/dialects/gate/__init__.py +16 -1
  62. bloqade/stim/dialects/gate/emit.py +1 -1
  63. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  64. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  65. bloqade/stim/dialects/noise/emit.py +1 -1
  66. bloqade/stim/emit/__init__.py +1 -1
  67. bloqade/stim/groups.py +4 -2
  68. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/METADATA +3 -3
  69. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/RECORD +79 -63
  70. /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  71. /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  72. /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  73. /bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +0 -0
  74. /bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +0 -0
  75. /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  76. /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  77. /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  78. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/WHEEL +0 -0
  79. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,98 @@
1
+ # create rewrite rule name SquinMeasureToStim using kirin
2
+ from kirin import ir
3
+ from kirin.dialects import py
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+
6
+ from bloqade import stim
7
+ from bloqade.squin import wire, qubit
8
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9
+ from bloqade.squin.rewrite.stim_rewrite_util import (
10
+ is_measure_result_used,
11
+ insert_qubit_idx_from_address,
12
+ )
13
+
14
+
15
+ class SquinMeasureToStim(RewriteRule):
16
+ """
17
+ Rewrite squin measure-related statements to stim statements.
18
+ """
19
+
20
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21
+
22
+ match node:
23
+ case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure():
24
+ return self.rewrite_Measure(node)
25
+ case qubit.MeasureAndReset() | wire.MeasureAndReset():
26
+ return self.rewrite_MeasureAndReset(node)
27
+ case _:
28
+ return RewriteResult()
29
+
30
+ def rewrite_Measure(
31
+ self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
32
+ ) -> RewriteResult:
33
+ if is_measure_result_used(measure_stmt):
34
+ return RewriteResult()
35
+
36
+ qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt)
37
+ if qubit_idx_ssas is None:
38
+ return RewriteResult()
39
+
40
+ prob_noise_stmt = py.constant.Constant(0.0)
41
+ stim_measure_stmt = stim.collapse.MZ(
42
+ p=prob_noise_stmt.result,
43
+ targets=qubit_idx_ssas,
44
+ )
45
+ prob_noise_stmt.insert_before(measure_stmt)
46
+ measure_stmt.replace_by(stim_measure_stmt)
47
+
48
+ return RewriteResult(has_done_something=True)
49
+
50
+ def rewrite_MeasureAndReset(
51
+ self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset
52
+ ) -> RewriteResult:
53
+ if not is_measure_result_used(meas_and_reset_stmt):
54
+ return RewriteResult()
55
+
56
+ qubit_idx_ssas = self.get_qubit_idx_ssas(meas_and_reset_stmt)
57
+
58
+ if qubit_idx_ssas is None:
59
+ return RewriteResult()
60
+
61
+ error_p_stmt = py.Constant(0.0)
62
+ stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result)
63
+ stim_rz_stmt = stim.collapse.RZ(
64
+ targets=qubit_idx_ssas,
65
+ )
66
+
67
+ error_p_stmt.insert_before(meas_and_reset_stmt)
68
+ stim_mz_stmt.insert_before(meas_and_reset_stmt)
69
+ meas_and_reset_stmt.replace_by(stim_rz_stmt)
70
+
71
+ return RewriteResult(has_done_something=True)
72
+
73
+ def get_qubit_idx_ssas(
74
+ self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
75
+ ) -> tuple[ir.SSAValue, ...] | None:
76
+ """
77
+ Extract the address attribute and insert qubit indices for the given measure statement.
78
+ """
79
+ match measure_stmt:
80
+ case qubit.MeasureQubit():
81
+ address_attr = measure_stmt.qubit.hints.get("address")
82
+ case qubit.MeasureQubitList():
83
+ address_attr = measure_stmt.qubits.hints.get("address")
84
+ case wire.Measure():
85
+ address_attr = measure_stmt.wire.hints.get("address")
86
+ case _:
87
+ return None
88
+
89
+ if address_attr is None:
90
+ return None
91
+
92
+ assert isinstance(address_attr, AddressAttribute)
93
+
94
+ qubit_idx_ssas = insert_qubit_idx_from_address(
95
+ address=address_attr, stmt_to_insert_before=measure_stmt
96
+ )
97
+
98
+ return qubit_idx_ssas
@@ -0,0 +1,158 @@
1
+ from kirin import ir
2
+ from kirin.dialects import py
3
+ from kirin.rewrite.abc import RewriteResult
4
+
5
+ from bloqade.squin import op, wire, qubit
6
+ from bloqade.stim.dialects import gate
7
+ from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple
8
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9
+
10
# Squin operator statement types that translate one-to-one into a stim gate
# statement. Callers look gates up by `type(operator_stmt)`.
SQUIN_STIM_GATE_MAPPING = {
    # Pauli gates
    op.stmts.X: gate.X,
    op.stmts.Y: gate.Y,
    op.stmts.Z: gate.Z,
    # Hadamard and phase gates
    op.stmts.H: gate.H,
    op.stmts.S: gate.S,
    # Identity (no-op)
    op.stmts.Identity: gate.Identity,
}
18
+
19
+
20
def insert_qubit_idx_from_address(
    address: AddressAttribute, stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from an AddressAttribute and insert them into the SSA form.

    One ``py.Constant`` per qubit index is inserted immediately before
    ``stmt_to_insert_before``.

    Args:
        address: Address hint attached to a qubit, qubit collection, or wire.
        stmt_to_insert_before: Statement the new constants are placed in front of.

    Returns:
        The results of the inserted constants, in address order. Returns
        ``None`` if the address is neither an ``AddressTuple`` made only of
        ``AddressQubit`` elements nor an ``AddressWire``. Nothing is inserted
        into the IR when ``None`` is returned.
    """
    address_data = address.address

    if isinstance(address_data, AddressTuple):
        address_qubits = address_data.data
        # Validate every element before touching the IR, so that bailing out
        # does not leave orphaned constant statements behind.
        if not all(isinstance(aq, AddressQubit) for aq in address_qubits):
            return None
        qubit_indices = [aq.data for aq in address_qubits]
    elif isinstance(address_data, AddressWire):
        qubit_indices = [address_data.origin_qubit.data]
    else:
        return None

    qubit_idx_ssas = []
    for qubit_idx in qubit_indices:
        qubit_idx_stmt = py.Constant(qubit_idx)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
47
+
48
+
49
def insert_qubit_idx_from_wire_ssa(
    wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from wire SSA values and insert them into the SSA form.

    Each wire's "address" hint must wrap an ``AddressWire``; the index of its
    origin qubit becomes a ``py.Constant`` inserted before
    ``stmt_to_insert_before``.

    Returns:
        The results of the inserted constants, one per wire and in wire order.
        Returns ``None`` (and inserts nothing) if any wire lacks an address
        hint or its address is not an ``AddressWire``.
    """
    # First resolve every index; only mutate the IR once all wires resolved,
    # so a failure part-way through does not leave dangling constants.
    qubit_indices = []
    for wire_ssa in wire_ssas:
        address_attribute = wire_ssa.hints.get("address")
        if address_attribute is None:
            return None
        assert isinstance(address_attribute, AddressAttribute)
        wire_address = address_attribute.address
        if not isinstance(wire_address, AddressWire):
            return None
        qubit_indices.append(wire_address.origin_qubit.data)

    qubit_idx_ssas = []
    for qubit_idx in qubit_indices:
        qubit_idx_stmt = py.Constant(qubit_idx)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
69
+
70
+
71
+ def insert_qubit_idx_after_apply(
72
+ stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast,
73
+ ) -> tuple[ir.SSAValue, ...] | None:
74
+ """
75
+ Extract qubit indices from Apply or Broadcast statements.
76
+ """
77
+ if isinstance(stmt, (qubit.Apply, qubit.Broadcast)):
78
+ qubits = stmt.qubits
79
+ address_attribute = qubits.hints.get("address")
80
+ if address_attribute is None:
81
+ return
82
+ assert isinstance(address_attribute, AddressAttribute)
83
+ return insert_qubit_idx_from_address(
84
+ address=address_attribute, stmt_to_insert_before=stmt
85
+ )
86
+ elif isinstance(stmt, (wire.Apply, wire.Broadcast)):
87
+ wire_ssas = stmt.inputs
88
+ return insert_qubit_idx_from_wire_ssa(
89
+ wire_ssas=wire_ssas, stmt_to_insert_before=stmt
90
+ )
91
+
92
+
93
def rewrite_Control(
    stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast,
) -> RewriteResult:
    """
    Handle control gates for Apply and Broadcast statements.

    A controlled X/Y/Z is replaced by the stim CX/CY/CZ statement. For the wire
    dialect, each output wire is first rerouted to its input wire so that later
    statements no longer depend on the replaced statement.

    Returns:
        ``RewriteResult(has_done_something=True)`` if the statement was
        replaced. Otherwise an empty ``RewriteResult``, and the IR is left
        untouched.
    """
    supported_gate_mapping = {
        op.stmts.X: gate.CX,
        op.stmts.Y: gate.CY,
        op.stmts.Z: gate.CZ,
    }

    ctrl_op = stmt_with_ctrl.operator.owner
    assert isinstance(ctrl_op, op.stmts.Control)

    ctrl_op_target_gate = ctrl_op.op.owner
    # The controlled operator may not be produced by a statement (e.g. it is a
    # block argument); nothing can be rewritten in that case.
    if not isinstance(ctrl_op_target_gate, op.stmts.Operator):
        return RewriteResult()

    # Check gate support before inserting index constants, so that an
    # unsupported gate does not leave unused constants in the IR.
    stim_gate = supported_gate_mapping.get(type(ctrl_op_target_gate))
    if stim_gate is None:
        return RewriteResult()

    qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl)
    if qubit_idx_ssas is None:
        return RewriteResult()

    # Indices alternate control, target, control, target, ...
    # NOTE(review): assumes one control per target; confirm for multi-control ops.
    ctrl_qubits = qubit_idx_ssas[0::2]
    target_qubits = qubit_idx_ssas[1::2]

    stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)

    if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
        # have to "reroute" the input of these statements to directly plug in
        # to subsequent statements, remove dependency on the current statement
        for input_wire, output_wire in zip(
            stmt_with_ctrl.inputs, stmt_with_ctrl.results
        ):
            output_wire.replace_by(input_wire)

    stmt_with_ctrl.replace_by(stim_stmt)

    return RewriteResult(has_done_something=True)
144
+
145
+
146
def is_measure_result_used(
    stmt: (
        qubit.MeasureAndReset
        | qubit.MeasureQubit
        | qubit.MeasureQubitList
        | wire.MeasureAndReset
        | wire.Measure
    ),
) -> bool:
    """
    Report whether the measurement outcome of ``stmt`` has any consumers.

    Returns ``True`` when ``stmt.result`` appears as an operand somewhere in the
    program, and ``False`` otherwise.
    """
    result_uses = stmt.result.uses
    return len(result_uses) > 0
@@ -0,0 +1,24 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
3
+
4
+ from bloqade.squin import wire
5
+
6
+
7
class SquinWireIdentityElimination(RewriteRule):
    """Remove ``unwrap``/``wrap`` pairs that round-trip a wire without using it."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """
        Handle the case where an unwrap feeds a wire directly into a wrap,
        equivalent to nothing happening/identity operation

        w = unwrap(qubit)
        wrap(qubit, w)
        """
        if not isinstance(node, wire.Wrap):
            return RewriteResult()

        wire_origin_stmt = node.wire.owner
        if not isinstance(wire_origin_stmt, wire.Unwrap):
            return RewriteResult()

        # The unwrapped wire must feed only this wrap. If anything else consumes
        # it, the unwrap is not dead and must not be deleted.
        if len(node.wire.uses) != 1:
            return RewriteResult()

        # NOTE(review): the qubit operands of the wrap and unwrap are not
        # compared; this assumes they refer to the same qubit.
        node.delete()  # get rid of wrap
        wire_origin_stmt.delete()  # get rid of the unwrap
        return RewriteResult(has_done_something=True)
@@ -0,0 +1,73 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
3
+
4
+ from bloqade import stim
5
+ from bloqade.squin import op, wire
6
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
7
+ from bloqade.squin.rewrite.stim_rewrite_util import (
8
+ SQUIN_STIM_GATE_MAPPING,
9
+ rewrite_Control,
10
+ insert_qubit_idx_from_address,
11
+ insert_qubit_idx_from_wire_ssa,
12
+ )
13
+
14
+
15
+ class SquinWireToStim(RewriteRule):
16
+
17
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
18
+ match node:
19
+ case wire.Apply() | wire.Broadcast():
20
+ return self.rewrite_Apply_and_Broadcast(node)
21
+ case wire.Reset():
22
+ return self.rewrite_Reset(node)
23
+ case _:
24
+ return RewriteResult()
25
+
26
+ def rewrite_Apply_and_Broadcast(
27
+ self, stmt: wire.Apply | wire.Broadcast
28
+ ) -> RewriteResult:
29
+
30
+ # this is an SSAValue, need it to be the actual operator
31
+ applied_op = stmt.operator.owner
32
+ assert isinstance(applied_op, op.stmts.Operator)
33
+
34
+ if isinstance(applied_op, op.stmts.Control):
35
+ return rewrite_Control(stmt)
36
+
37
+ stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
38
+ if stim_1q_op is None:
39
+ return RewriteResult()
40
+
41
+ qubit_idx_ssas = insert_qubit_idx_from_wire_ssa(
42
+ wire_ssas=stmt.inputs, stmt_to_insert_before=stmt
43
+ )
44
+ if qubit_idx_ssas is None:
45
+ return RewriteResult()
46
+
47
+ stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
48
+
49
+ # Get the wires from the inputs of Apply or Broadcast,
50
+ # then put those as the result of the current stmt
51
+ # before replacing it entirely
52
+ for input_wire, output_wire in zip(stmt.inputs, stmt.results):
53
+ output_wire.replace_by(input_wire)
54
+
55
+ stmt.replace_by(stim_1q_stmt)
56
+
57
+ return RewriteResult(has_done_something=True)
58
+
59
+ def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult:
60
+ address_attr = reset_stmt.wire.hints.get("address")
61
+ if address_attr is None:
62
+ return RewriteResult()
63
+ assert isinstance(address_attr, AddressAttribute)
64
+ qubit_idx_ssas = insert_qubit_idx_from_address(
65
+ address=address_attr, stmt_to_insert_before=reset_stmt
66
+ )
67
+ if qubit_idx_ssas is None:
68
+ return RewriteResult()
69
+
70
+ stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas)
71
+ reset_stmt.replace_by(stim_rz_stmt)
72
+
73
+ return RewriteResult(has_done_something=True)
@@ -0,0 +1,72 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+ from kirin.print.printer import Printer
6
+
7
+ from bloqade.squin import op, wire
8
+ from bloqade.analysis.address import Address
9
+ from bloqade.squin.analysis.nsites import Sites
10
+
11
+
12
@wire.dialect.register
@dataclass
class AddressAttribute(ir.Attribute):
    """IR attribute wrapping an address-analysis result so it can live in SSA hints."""

    name = "Address"
    address: Address

    def __hash__(self) -> int:
        # Identity for hashing is the wrapped address value.
        wrapped_address = self.address
        return hash(wrapped_address)

    def print_impl(self, printer: Printer) -> None:
        # Placeholder: delegate straight to the printer for the wrapped value
        # until a dedicated textual form is implemented.
        wrapped_address = self.address
        printer.print(wrapped_address)
25
+
26
+
27
@op.dialect.register
@dataclass
class SitesAttribute(ir.Attribute):
    """IR attribute wrapping a site-count analysis result so it can live in SSA hints."""

    name = "Sites"
    sites: Sites

    def __hash__(self) -> int:
        # Identity for hashing is the wrapped sites value.
        wrapped_sites = self.sites
        return hash(wrapped_sites)

    def print_impl(self, printer: Printer) -> None:
        # Placeholder: delegate straight to the printer for the wrapped value
        # until a dedicated textual form is implemented.
        wrapped_sites = self.sites
        printer.print(wrapped_sites)
40
+
41
+
42
@dataclass
class WrapSquinAnalysis(RewriteRule):
    """
    Copy address and site analysis results into the ``hints`` of SSA values.

    Every block argument and statement result gets an ``"address"`` hint
    (``AddressAttribute``) and a ``"sites"`` hint (``SitesAttribute``) when
    the corresponding analysis has an entry for it.
    """

    address_analysis: dict[ir.SSAValue, Address]
    op_site_analysis: dict[ir.SSAValue, Sites]

    def wrap(self, value: ir.SSAValue) -> bool:
        """
        Attach any missing analysis hints to ``value``.

        Existing hints are kept, and a value absent from an analysis simply
        does not receive that hint; it does not raise ``KeyError``.

        Returns:
            ``True`` if at least one hint was added.
        """
        changed = False
        for key, analysis, attribute_type in (
            ("address", self.address_analysis, AddressAttribute),
            ("sites", self.op_site_analysis, SitesAttribute),
        ):
            if value.hints.get(key):
                continue
            analysis_result = analysis.get(value)
            if analysis_result is None:
                continue
            value.hints[key] = attribute_type(analysis_result)
            changed = True
        return changed

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        # Build the full list (no short-circuiting `any`) so every argument is wrapped.
        wrapped = [self.wrap(arg) for arg in node.args]
        return RewriteResult(has_done_something=any(wrapped))

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        # Build the full list (no short-circuiting `any`) so every result is wrapped.
        wrapped = [self.wrap(result) for result in node.results]
        return RewriteResult(has_done_something=any(wrapped))
bloqade/squin/wire.py CHANGED
@@ -6,7 +6,7 @@ circuits. Thus we do not define wrapping functions for the statements in this
6
6
  dialect.
7
7
  """
8
8
 
9
- from kirin import ir, types, interp, lowering
9
+ from kirin import ir, types, lowering
10
10
  from kirin.decl import info, statement
11
11
  from kirin.lowering import wraps
12
12
 
@@ -112,18 +112,6 @@ class Reset(ir.Statement):
112
112
  wire: ir.SSAValue = info.argument(WireType)
113
113
 
114
114
 
115
- # Issue where constant propagation can't handle
116
- # multiple return values from Apply properly
117
- @dialect.register(key="constprop")
118
- class ConstPropWire(interp.MethodTable):
119
-
120
- @interp.impl(Apply)
121
- @interp.impl(Broadcast)
122
- def apply(self, interp, frame, stmt: Apply):
123
-
124
- return frame.get_values(stmt.inputs)
125
-
126
-
127
115
  @wraps(Unwrap)
128
116
  def unwrap(qubit: Qubit) -> Wire: ...
129
117
 
bloqade/stim/__init__.py CHANGED
@@ -1,6 +1,40 @@
1
+ from . import emit as emit, dialects as dialects
1
2
  from .groups import main as main
2
- from ._wrappers import * # noqa: F403
3
- from .dialects.aux import * # noqa F403
4
- from .dialects.gate import * # noqa F403
5
- from .dialects.noise import * # noqa F403
6
- from .dialects.collapse import * # noqa F403
3
+ from ._wrappers import (
4
+ h as h,
5
+ s as s,
6
+ x as x,
7
+ y as y,
8
+ z as z,
9
+ cx as cx,
10
+ cy as cy,
11
+ cz as cz,
12
+ mx as mx,
13
+ my as my,
14
+ mz as mz,
15
+ rx as rx,
16
+ ry as ry,
17
+ rz as rz,
18
+ mpp as mpp,
19
+ mxx as mxx,
20
+ myy as myy,
21
+ mzz as mzz,
22
+ rec as rec,
23
+ spp as spp,
24
+ swap as swap,
25
+ tick as tick,
26
+ sqrt_x as sqrt_x,
27
+ sqrt_y as sqrt_y,
28
+ sqrt_z as sqrt_z,
29
+ x_error as x_error,
30
+ y_error as y_error,
31
+ z_error as z_error,
32
+ detector as detector,
33
+ identity as identity,
34
+ depolarize1 as depolarize1,
35
+ depolarize2 as depolarize2,
36
+ pauli_string as pauli_string,
37
+ pauli_channel1 as pauli_channel1,
38
+ pauli_channel2 as pauli_channel2,
39
+ observable_include as observable_include,
40
+ )
bloqade/stim/_wrappers.py CHANGED
@@ -2,7 +2,7 @@ from typing import Union
2
2
 
3
3
  from kirin.lowering import wraps
4
4
 
5
- from .dialects import aux, gate, noise, collapse
5
+ from .dialects import gate, noise, collapse, auxiliary
6
6
 
7
7
 
8
8
  # dialect:: gate
@@ -69,32 +69,34 @@ def cz(
69
69
 
70
70
  ## pp
71
71
  @wraps(gate.SPP)
72
- def spp(targets: tuple[aux.PauliString, ...], dagger=False) -> None: ...
72
+ def spp(targets: tuple[auxiliary.PauliString, ...], dagger=False) -> None: ...
73
73
 
74
74
 
75
75
  # dialect:: aux
76
- @wraps(aux.GetRecord)
77
- def rec(id: int) -> aux.RecordResult: ...
76
+ @wraps(auxiliary.GetRecord)
77
+ def rec(id: int) -> auxiliary.RecordResult: ...
78
78
 
79
79
 
80
- @wraps(aux.Detector)
80
+ @wraps(auxiliary.Detector)
81
81
  def detector(
82
- coord: tuple[Union[int, float], ...], targets: tuple[aux.RecordResult, ...]
82
+ coord: tuple[Union[int, float], ...], targets: tuple[auxiliary.RecordResult, ...]
83
83
  ) -> None: ...
84
84
 
85
85
 
86
- @wraps(aux.ObservableInclude)
87
- def observable_include(idx: int, targets: tuple[aux.RecordResult, ...]) -> None: ...
86
+ @wraps(auxiliary.ObservableInclude)
87
+ def observable_include(
88
+ idx: int, targets: tuple[auxiliary.RecordResult, ...]
89
+ ) -> None: ...
88
90
 
89
91
 
90
- @wraps(aux.Tick)
92
+ @wraps(auxiliary.Tick)
91
93
  def tick() -> None: ...
92
94
 
93
95
 
94
- @wraps(aux.NewPauliString)
96
+ @wraps(auxiliary.NewPauliString)
95
97
  def pauli_string(
96
98
  string: tuple[str, ...], flipped: tuple[bool, ...], targets: tuple[int, ...]
97
- ) -> aux.PauliString: ...
99
+ ) -> auxiliary.PauliString: ...
98
100
 
99
101
 
100
102
  # dialect:: collapse
@@ -123,7 +125,7 @@ def mxx(p: float, targets: tuple[int, ...]) -> None: ...
123
125
 
124
126
 
125
127
  @wraps(collapse.PPMeasurement)
126
- def mpp(p: float, targets: tuple[aux.PauliString, ...]) -> None: ...
128
+ def mpp(p: float, targets: tuple[auxiliary.PauliString, ...]) -> None: ...
127
129
 
128
130
 
129
131
  @wraps(collapse.RZ)
@@ -1,5 +1 @@
1
- from . import aux as aux, gate as gate, noise as noise, collapse as collapse
2
- from .aux.stmts import * # noqa F403
3
- from .gate.stmts import * # noqa F403
4
- from .noise.stmts import * # noqa F403
5
- from .collapse.stmts import * # noqa F403
1
+ from . import gate as gate, noise as noise, collapse as collapse, auxiliary as auxiliary
@@ -1,6 +1,17 @@
1
1
  from . import lowering as lowering
2
2
  from .emit import EmitStimAuxMethods as EmitStimAuxMethods
3
- from .stmts import * # noqa F403
3
+ from .stmts import (
4
+ Neg as Neg,
5
+ Tick as Tick,
6
+ ConstInt as ConstInt,
7
+ ConstStr as ConstStr,
8
+ Detector as Detector,
9
+ ConstBool as ConstBool,
10
+ GetRecord as GetRecord,
11
+ ConstFloat as ConstFloat,
12
+ NewPauliString as NewPauliString,
13
+ ObservableInclude as ObservableInclude,
14
+ )
4
15
  from .types import (
5
16
  RecordType as RecordType,
6
17
  PauliString as PauliString,
@@ -1,7 +1,7 @@
1
1
  from kirin.emit import EmitStrFrame
2
2
  from kirin.interp import MethodTable, impl
3
3
 
4
- from bloqade.stim.emit.stim import EmitStimMain
4
+ from bloqade.stim.emit.stim_str import EmitStimMain
5
5
 
6
6
  from . import stmts
7
7
  from ._dialect import dialect
@@ -1,3 +1,14 @@
1
- from .emit import EmitStimCollapseMethods as EmitStimCollapseMethods
2
- from .stmts import * # noqa F403
1
+ from .stmts import (
2
+ MX as MX,
3
+ MY as MY,
4
+ MZ as MZ,
5
+ RX as RX,
6
+ RY as RY,
7
+ RZ as RZ,
8
+ MXX as MXX,
9
+ MYY as MYY,
10
+ MZZ as MZZ,
11
+ PPMeasurement as PPMeasurement,
12
+ )
3
13
  from ._dialect import dialect as dialect
14
+ from .emit_str import EmitStimCollapseMethods as EmitStimCollapseMethods
@@ -1,7 +1,7 @@
1
1
  from kirin.emit import EmitStrFrame
2
2
  from kirin.interp import MethodTable, impl
3
3
 
4
- from bloqade.stim.emit.stim import EmitStimMain
4
+ from bloqade.stim.emit.stim_str import EmitStimMain
5
5
 
6
6
  from . import stmts
7
7
  from ._dialect import dialect
@@ -2,7 +2,7 @@ from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
 
4
4
  from .._dialect import dialect
5
- from ...aux.types import PauliStringType
5
+ from ...auxiliary.types import PauliStringType
6
6
 
7
7
 
8
8
  @statement(dialect=dialect)
@@ -1,3 +1,18 @@
1
1
  from .emit import EmitStimGateMethods as EmitStimGateMethods
2
- from .stmts import * # noqa F403
2
+ from .stmts import (
3
+ CX as CX,
4
+ CY as CY,
5
+ CZ as CZ,
6
+ SPP as SPP,
7
+ H as H,
8
+ S as S,
9
+ X as X,
10
+ Y as Y,
11
+ Z as Z,
12
+ Swap as Swap,
13
+ SqrtX as SqrtX,
14
+ SqrtY as SqrtY,
15
+ SqrtZ as SqrtZ,
16
+ Identity as Identity,
17
+ )
3
18
  from ._dialect import dialect as dialect
@@ -1,7 +1,7 @@
1
1
  from kirin.emit import EmitStrFrame
2
2
  from kirin.interp import MethodTable, impl
3
3
 
4
- from bloqade.stim.emit.stim import EmitStimMain
4
+ from bloqade.stim.emit.stim_str import EmitStimMain
5
5
 
6
6
  from . import stmts
7
7
  from ._dialect import dialect
@@ -1,7 +1,7 @@
1
1
  from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
 
4
- from bloqade.stim.dialects.aux import RecordType
4
+ from bloqade.stim.dialects.auxiliary import RecordType
5
5
 
6
6
 
7
7
  @statement
@@ -2,7 +2,7 @@ from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
 
4
4
  from .._dialect import dialect
5
- from ...aux.types import PauliStringType
5
+ from ...auxiliary.types import PauliStringType
6
6
 
7
7
 
8
8
  # Generalized Pauli-product gates