bloqade-circuit 0.2.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic; see the registry's release page for more details.

Files changed (101)
  1. bloqade/analysis/address/impls.py +3 -2
  2. bloqade/pyqrack/device.py +1 -3
  3. bloqade/pyqrack/noise/native.py +8 -8
  4. bloqade/pyqrack/qasm2/core.py +4 -1
  5. bloqade/pyqrack/squin/op.py +7 -0
  6. bloqade/pyqrack/squin/qubit.py +5 -27
  7. bloqade/pyqrack/squin/runtime.py +18 -0
  8. bloqade/pyqrack/squin/wire.py +4 -22
  9. bloqade/pyqrack/task.py +13 -5
  10. bloqade/qasm2/__init__.py +1 -0
  11. bloqade/qasm2/_qasm_loading.py +151 -0
  12. bloqade/qasm2/dialects/core/__init__.py +9 -1
  13. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  14. bloqade/{noise/native → qasm2/dialects/noise}/__init__.py +1 -7
  15. bloqade/qasm2/dialects/noise/_dialect.py +3 -0
  16. bloqade/{noise → qasm2/dialects/noise}/fidelity.py +4 -4
  17. bloqade/qasm2/dialects/noise/model.py +278 -0
  18. bloqade/{noise/native → qasm2/dialects/noise}/stmts.py +1 -1
  19. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  20. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  21. bloqade/qasm2/emit/impls/__init__.py +1 -0
  22. bloqade/qasm2/emit/impls/noise.py +89 -0
  23. bloqade/qasm2/emit/main.py +23 -4
  24. bloqade/qasm2/emit/target.py +19 -4
  25. bloqade/qasm2/noise.py +67 -0
  26. bloqade/qasm2/parse/__init__.py +7 -4
  27. bloqade/qasm2/parse/lowering.py +20 -130
  28. bloqade/qasm2/parse/qasm2.lark +1 -1
  29. bloqade/qasm2/passes/__init__.py +1 -0
  30. bloqade/qasm2/passes/fold.py +6 -0
  31. bloqade/qasm2/passes/glob.py +12 -8
  32. bloqade/qasm2/passes/noise.py +27 -16
  33. bloqade/qasm2/passes/parallel.py +9 -0
  34. bloqade/qasm2/passes/unroll_if.py +25 -0
  35. bloqade/qasm2/rewrite/__init__.py +3 -0
  36. bloqade/qasm2/rewrite/desugar.py +3 -2
  37. bloqade/qasm2/rewrite/native_gates.py +67 -4
  38. bloqade/qasm2/rewrite/noise/__init__.py +0 -0
  39. bloqade/qasm2/rewrite/{heuristic_noise.py → noise/heuristic_noise.py} +32 -62
  40. bloqade/{noise/native/rewrite.py → qasm2/rewrite/noise/remove_noise.py} +2 -2
  41. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  42. bloqade/qbraid/lowering.py +8 -8
  43. bloqade/squin/__init__.py +7 -1
  44. bloqade/squin/analysis/nsites/__init__.py +1 -0
  45. bloqade/squin/analysis/nsites/impls.py +16 -1
  46. bloqade/squin/groups.py +4 -4
  47. bloqade/squin/lowering.py +27 -0
  48. bloqade/squin/noise/__init__.py +7 -26
  49. bloqade/squin/noise/_wrapper.py +25 -0
  50. bloqade/squin/op/__init__.py +34 -159
  51. bloqade/squin/op/_wrapper.py +105 -0
  52. bloqade/squin/op/stdlib.py +62 -0
  53. bloqade/squin/op/stmts.py +10 -0
  54. bloqade/squin/passes/__init__.py +1 -0
  55. bloqade/squin/passes/stim.py +68 -0
  56. bloqade/squin/qubit.py +32 -37
  57. bloqade/squin/rewrite/__init__.py +11 -0
  58. bloqade/squin/rewrite/desugar.py +65 -0
  59. bloqade/squin/rewrite/qubit_to_stim.py +61 -0
  60. bloqade/squin/rewrite/squin_measure.py +73 -0
  61. bloqade/squin/rewrite/stim_rewrite_util.py +153 -0
  62. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  63. bloqade/squin/rewrite/wire_to_stim.py +52 -0
  64. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  65. bloqade/squin/wire.py +5 -22
  66. bloqade/stim/__init__.py +40 -5
  67. bloqade/stim/_wrappers.py +18 -12
  68. bloqade/stim/dialects/__init__.py +1 -5
  69. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +13 -1
  70. bloqade/stim/dialects/{aux → auxiliary}/emit.py +18 -3
  71. bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +1 -0
  72. bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +8 -0
  73. bloqade/stim/dialects/collapse/__init__.py +13 -2
  74. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +4 -2
  75. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  76. bloqade/stim/dialects/gate/__init__.py +16 -1
  77. bloqade/stim/dialects/gate/emit.py +10 -3
  78. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  79. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  80. bloqade/stim/dialects/noise/emit.py +33 -2
  81. bloqade/stim/dialects/noise/stmts.py +29 -0
  82. bloqade/stim/emit/__init__.py +1 -1
  83. bloqade/stim/groups.py +4 -2
  84. bloqade/stim/parse/__init__.py +1 -0
  85. bloqade/stim/parse/lowering.py +686 -0
  86. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/METADATA +5 -3
  87. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/RECORD +95 -77
  88. bloqade/noise/__init__.py +0 -2
  89. bloqade/noise/native/_dialect.py +0 -3
  90. bloqade/noise/native/_wrappers.py +0 -34
  91. bloqade/noise/native/model.py +0 -346
  92. bloqade/qasm2/dialects/noise.py +0 -16
  93. bloqade/squin/rewrite/measure_desugar.py +0 -33
  94. bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  95. bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  96. bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  97. bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  98. bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  99. bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  100. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/WHEEL +0 -0
  101. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,153 @@
1
+ from kirin import ir
2
+ from kirin.dialects import py
3
+ from kirin.rewrite.abc import RewriteResult
4
+
5
+ from bloqade.squin import op, wire, qubit
6
+ from bloqade.stim.dialects import gate, collapse
7
+ from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple
8
+ from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9
+
# Lookup table from squin single-qubit operator statement types to the stim
# statement type that replaces them. ``op.stmts.Reset`` is mapped onto
# ``collapse.RZ`` (presumably stim's Z-basis reset; confirm against stim docs).
SQUIN_STIM_GATE_MAPPING = dict(
    [
        (op.stmts.X, gate.X),
        (op.stmts.Y, gate.Y),
        (op.stmts.Z, gate.Z),
        (op.stmts.H, gate.H),
        (op.stmts.S, gate.S),
        (op.stmts.Identity, gate.Identity),
        (op.stmts.Reset, collapse.RZ),
    ]
)
19
+
20
+
def insert_qubit_idx_from_address(
    address: AddressAttribute, stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from an AddressAttribute and insert them into the SSA form.

    One ``py.Constant`` holding a qubit index is inserted before
    ``stmt_to_insert_before`` for every qubit described by ``address``.

    Args:
        address: The address hint to read. It must wrap either an
            ``AddressTuple`` whose elements are all ``AddressQubit``s, or an
            ``AddressWire`` (whose origin qubit is used).
        stmt_to_insert_before: Statement in front of which the constants are
            inserted.

    Returns:
        The result SSA values of the inserted constants, in order. Returns
        ``None`` if the address has any other shape; in that case no
        statement is inserted.
    """
    address_data = address.address

    if isinstance(address_data, AddressTuple):
        element_addresses = address_data.data
        # Validate every element before touching the IR so that an
        # unsupported element does not leave orphaned constants behind.
        if not all(isinstance(addr, AddressQubit) for addr in element_addresses):
            return None
        qubit_indices = [addr.data for addr in element_addresses]
    elif isinstance(address_data, AddressWire):
        qubit_indices = [address_data.origin_qubit.data]
    else:
        return None

    qubit_idx_ssas = []
    for qubit_idx in qubit_indices:
        qubit_idx_stmt = py.Constant(qubit_idx)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
48
+
49
+
def insert_qubit_idx_from_wire_ssa(
    wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement
) -> tuple[ir.SSAValue, ...] | None:
    """
    Extract qubit indices from wire SSA values and insert them into the SSA form.

    For every wire, the index of its origin qubit (read from the wire's
    ``"address"`` hint) is inserted as a ``py.Constant`` before
    ``stmt_to_insert_before``.

    Args:
        wire_ssas: The wire SSA values to resolve.
        stmt_to_insert_before: Statement in front of which the constants are
            inserted.

    Returns:
        The result SSA values of the inserted constants, one per wire, in
        input order. Returns ``None`` if any wire has no ``"address"`` hint
        or its address is not an ``AddressWire``; in that case no statement
        is inserted.
    """
    # Resolve every index first so that a failure part-way through does not
    # leave constants for the earlier wires dangling in the IR.
    qubit_indices = []
    for wire_ssa in wire_ssas:
        address_attribute = wire_ssa.hints.get("address")
        if address_attribute is None:
            return None
        assert isinstance(address_attribute, AddressAttribute)
        wire_address = address_attribute.address
        if not isinstance(wire_address, AddressWire):
            return None
        qubit_indices.append(wire_address.origin_qubit.data)

    qubit_idx_ssas = []
    for qubit_idx in qubit_indices:
        qubit_idx_stmt = py.Constant(qubit_idx)
        qubit_idx_stmt.insert_before(stmt_to_insert_before)
        qubit_idx_ssas.append(qubit_idx_stmt.result)

    return tuple(qubit_idx_ssas)
70
+
71
+
def insert_qubit_idx_after_apply(
    stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast,
) -> tuple[ir.SSAValue, ...] | None:
    """
    Insert constants for the qubit indices an Apply or Broadcast acts on.

    The constants are placed directly before ``stmt``. Qubit-dialect
    statements are resolved through the ``"address"`` hint of their
    ``qubits`` argument; wire-dialect statements through the hints of their
    input wires.

    Returns:
        The inserted constants' SSA values, or ``None`` when the indices
        cannot be resolved (or ``stmt`` is none of the supported types).
    """
    if isinstance(stmt, (qubit.Apply, qubit.Broadcast)):
        hint = stmt.qubits.hints.get("address")
        if hint is None:
            return None
        assert isinstance(hint, AddressAttribute)
        return insert_qubit_idx_from_address(hint, stmt)

    if isinstance(stmt, (wire.Apply, wire.Broadcast)):
        return insert_qubit_idx_from_wire_ssa(stmt.inputs, stmt)

    return None
92
+
93
+
def rewrite_Control(
    stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast,
) -> RewriteResult:
    """
    Handle control gates for Apply and Broadcast statements.

    An Apply/Broadcast whose operator is a controlled X, Y or Z is replaced
    by the stim CX, CY or CZ statement. The targeted qubit indices are read
    as interleaved ``(control, target)`` pairs: even positions are controls,
    odd positions are targets.

    Args:
        stmt_with_ctrl: An Apply/Broadcast whose operator is produced by an
            ``op.stmts.Control`` statement.

    Returns:
        ``RewriteResult(has_done_something=True)`` if the statement was
        replaced. Otherwise returns an empty ``RewriteResult`` and leaves the
        IR unchanged.
    """
    ctrl_op = stmt_with_ctrl.operator.owner
    assert isinstance(ctrl_op, op.stmts.Control)

    ctrl_op_target_gate = ctrl_op.op.owner
    if not isinstance(ctrl_op_target_gate, op.stmts.Operator):
        # The controlled operator is not produced by an operator statement
        # (presumably e.g. a block argument), so there is nothing to map.
        return RewriteResult()

    supported_gate_mapping = {
        op.stmts.X: gate.CX,
        op.stmts.Y: gate.CY,
        op.stmts.Z: gate.CZ,
    }
    # Decide on the stim gate before inserting any index constants so that
    # unsupported gates do not leave orphaned constants in the IR.
    stim_gate = supported_gate_mapping.get(type(ctrl_op_target_gate))
    if stim_gate is None:
        return RewriteResult()

    qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl)
    if qubit_idx_ssas is None:
        return RewriteResult()

    if len(qubit_idx_ssas) % 2 != 0:
        # An odd number of indices cannot be split into (control, target)
        # pairs; undo the constant insertion and leave the statement alone.
        for qubit_idx_ssa in qubit_idx_ssas:
            qubit_idx_ssa.owner.delete()
        return RewriteResult()

    # Separate control (even positions) and target (odd positions) qubits.
    ctrl_qubits = qubit_idx_ssas[0::2]
    target_qubits = qubit_idx_ssas[1::2]

    stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)

    if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
        # Reroute the uses of each output wire to the matching input wire so
        # that nothing depends on this statement once it is replaced.
        for input_wire, output_wire in zip(
            stmt_with_ctrl.inputs, stmt_with_ctrl.results
        ):
            output_wire.replace_by(input_wire)

    stmt_with_ctrl.replace_by(stim_stmt)

    return RewriteResult(has_done_something=True)
145
+
146
+
def is_measure_result_used(
    stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure,
) -> bool:
    """
    Report whether any statement consumes the result of a measurement.

    Returns:
        ``True`` when ``stmt.result`` has at least one use, ``False`` otherwise.
    """
    if stmt.result.uses:
        return True
    return False
@@ -0,0 +1,24 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
3
+
4
+ from bloqade.squin import wire
5
+
6
+
7
+ class SquinWireIdentityElimination(RewriteRule):
8
+
9
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
10
+ """
11
+ Handle the case where an unwrap feeds a wire directly into a wrap,
12
+ equivalent to nothing happening/identity operation
13
+
14
+ w = unwrap(qubit)
15
+ wrap(qubit, w)
16
+ """
17
+ if isinstance(node, wire.Wrap):
18
+ wire_origin_stmt = node.wire.owner
19
+ if isinstance(wire_origin_stmt, wire.Unwrap):
20
+ node.delete() # get rid of wrap
21
+ wire_origin_stmt.delete() # get rid of the unwrap
22
+ return RewriteResult(has_done_something=True)
23
+
24
+ return RewriteResult()
@@ -0,0 +1,52 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
3
+
4
+ from bloqade.squin import op, wire
5
+ from bloqade.squin.rewrite.stim_rewrite_util import (
6
+ SQUIN_STIM_GATE_MAPPING,
7
+ rewrite_Control,
8
+ insert_qubit_idx_from_wire_ssa,
9
+ )
10
+
11
+
12
+ class SquinWireToStim(RewriteRule):
13
+
14
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
15
+ match node:
16
+ case wire.Apply() | wire.Broadcast():
17
+ return self.rewrite_Apply_and_Broadcast(node)
18
+ case _:
19
+ return RewriteResult()
20
+
21
+ def rewrite_Apply_and_Broadcast(
22
+ self, stmt: wire.Apply | wire.Broadcast
23
+ ) -> RewriteResult:
24
+
25
+ # this is an SSAValue, need it to be the actual operator
26
+ applied_op = stmt.operator.owner
27
+ assert isinstance(applied_op, op.stmts.Operator)
28
+
29
+ if isinstance(applied_op, op.stmts.Control):
30
+ return rewrite_Control(stmt)
31
+
32
+ stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
33
+ if stim_1q_op is None:
34
+ return RewriteResult()
35
+
36
+ qubit_idx_ssas = insert_qubit_idx_from_wire_ssa(
37
+ wire_ssas=stmt.inputs, stmt_to_insert_before=stmt
38
+ )
39
+ if qubit_idx_ssas is None:
40
+ return RewriteResult()
41
+
42
+ stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
43
+
44
+ # Get the wires from the inputs of Apply or Broadcast,
45
+ # then put those as the result of the current stmt
46
+ # before replacing it entirely
47
+ for input_wire, output_wire in zip(stmt.inputs, stmt.results):
48
+ output_wire.replace_by(input_wire)
49
+
50
+ stmt.replace_by(stim_1q_stmt)
51
+
52
+ return RewriteResult(has_done_something=True)
@@ -0,0 +1,72 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+ from kirin.print.printer import Printer
6
+
7
+ from bloqade.squin import op, wire
8
+ from bloqade.analysis.address import Address
9
+ from bloqade.squin.analysis.nsites import Sites
10
+
11
+
@wire.dialect.register
@dataclass
class AddressAttribute(ir.Attribute):
    """IR attribute wrapping an ``Address`` so it can be stored as an SSA hint."""

    name = "Address"
    address: Address

    def __hash__(self) -> int:
        # Identity of the attribute is fully determined by the wrapped address.
        wrapped = self.address
        return hash(wrapped)

    def print_impl(self, printer: Printer) -> None:
        # TODO: dedicated pretty-printing; for now defer to the printer.
        wrapped = self.address
        printer.print(wrapped)
25
+
26
+
@op.dialect.register
@dataclass
class SitesAttribute(ir.Attribute):
    """IR attribute wrapping a ``Sites`` value so it can be stored as an SSA hint."""

    name = "Sites"
    sites: Sites

    def __hash__(self) -> int:
        # Identity of the attribute is fully determined by the wrapped sites.
        wrapped = self.sites
        return hash(wrapped)

    def print_impl(self, printer: Printer) -> None:
        # TODO: dedicated pretty-printing; for now defer to the printer.
        wrapped = self.sites
        printer.print(wrapped)
40
+
41
+
@dataclass
class WrapSquinAnalysis(RewriteRule):
    """
    Attach analysis results to SSA values as hints.

    The address-analysis result for a value is stored in
    ``value.hints["address"]`` as an ``AddressAttribute``. The
    operator-site-analysis result is stored in ``value.hints["sites"]`` as a
    ``SitesAttribute``. Both block arguments and statement results are
    annotated.
    """

    # Address-analysis result per SSA value.
    address_analysis: dict[ir.SSAValue, Address]
    # Operator-site-analysis result per SSA value.
    op_site_analysis: dict[ir.SSAValue, Sites]

    def wrap(self, value: ir.SSAValue) -> bool:
        """
        Add whichever analysis hints ``value`` is still missing.

        A hint is written only if it is absent and the corresponding analysis
        has a result for ``value``. Values that an analysis did not cover are
        skipped rather than raising ``KeyError``. Repeated calls are no-ops,
        so a fixpoint driver terminates.

        Returns:
            ``True`` if at least one hint was added.
        """
        changed = False

        if value.hints.get("address") is None:
            address_result = self.address_analysis.get(value)
            if address_result is not None:
                value.hints["address"] = AddressAttribute(address_result)
                changed = True

        if value.hints.get("sites") is None:
            sites_result = self.op_site_analysis.get(value)
            if sites_result is not None:
                value.hints["sites"] = SitesAttribute(sites_result)
                changed = True

        return changed

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        # Build the full list first: every argument must be wrapped, so
        # short-circuiting with any() over a generator would be wrong.
        wrapped = [self.wrap(arg) for arg in node.args]
        return RewriteResult(has_done_something=any(wrapped))

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        # Same as rewrite_Block: wrap every result before combining flags.
        wrapped = [self.wrap(result) for result in node.results]
        return RewriteResult(has_done_something=any(wrapped))
bloqade/squin/wire.py CHANGED
@@ -6,7 +6,7 @@ circuits. Thus we do not define wrapping functions for the statements in this
6
6
  dialect.
7
7
  """
8
8
 
9
- from kirin import ir, types, interp, lowering
9
+ from kirin import ir, types, lowering
10
10
  from kirin.decl import info, statement
11
11
  from kirin.lowering import wraps
12
12
 
@@ -95,35 +95,18 @@ class Broadcast(ir.Statement):
95
95
  class Measure(ir.Statement):
96
96
  traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
97
97
  wire: ir.SSAValue = info.argument(WireType)
98
+ qubit: ir.SSAValue = info.argument(QubitType)
98
99
  result: ir.ResultValue = info.result(types.Int)
99
100
 
100
101
 
101
102
  @statement(dialect=dialect)
102
- class MeasureAndReset(ir.Statement):
103
- traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
104
- wire: ir.SSAValue = info.argument(WireType)
103
+ class NonDestructiveMeasure(ir.Statement):
104
+ traits = frozenset({lowering.FromPythonCall()})
105
+ input_wire: ir.SSAValue = info.argument(WireType)
105
106
  result: ir.ResultValue = info.result(types.Int)
106
107
  out_wire: ir.ResultValue = info.result(WireType)
107
108
 
108
109
 
109
- @statement(dialect=dialect)
110
- class Reset(ir.Statement):
111
- traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
112
- wire: ir.SSAValue = info.argument(WireType)
113
-
114
-
115
- # Issue where constant propagation can't handle
116
- # multiple return values from Apply properly
117
- @dialect.register(key="constprop")
118
- class ConstPropWire(interp.MethodTable):
119
-
120
- @interp.impl(Apply)
121
- @interp.impl(Broadcast)
122
- def apply(self, interp, frame, stmt: Apply):
123
-
124
- return frame.get_values(stmt.inputs)
125
-
126
-
127
110
  @wraps(Unwrap)
128
111
  def unwrap(qubit: Qubit) -> Wire: ...
129
112
 
bloqade/stim/__init__.py CHANGED
@@ -1,6 +1,41 @@
1
+ from . import emit as emit, parse as parse, dialects as dialects
1
2
  from .groups import main as main
2
- from ._wrappers import * # noqa: F403
3
- from .dialects.aux import * # noqa F403
4
- from .dialects.gate import * # noqa F403
5
- from .dialects.noise import * # noqa F403
6
- from .dialects.collapse import * # noqa F403
3
+ from ._wrappers import (
4
+ h as h,
5
+ s as s,
6
+ x as x,
7
+ y as y,
8
+ z as z,
9
+ cx as cx,
10
+ cy as cy,
11
+ cz as cz,
12
+ mx as mx,
13
+ my as my,
14
+ mz as mz,
15
+ rx as rx,
16
+ ry as ry,
17
+ rz as rz,
18
+ mpp as mpp,
19
+ mxx as mxx,
20
+ myy as myy,
21
+ mzz as mzz,
22
+ rec as rec,
23
+ spp as spp,
24
+ swap as swap,
25
+ tick as tick,
26
+ sqrt_x as sqrt_x,
27
+ sqrt_y as sqrt_y,
28
+ sqrt_z as sqrt_z,
29
+ x_error as x_error,
30
+ y_error as y_error,
31
+ z_error as z_error,
32
+ detector as detector,
33
+ identity as identity,
34
+ depolarize1 as depolarize1,
35
+ depolarize2 as depolarize2,
36
+ pauli_string as pauli_string,
37
+ qubit_coords as qubit_coords,
38
+ pauli_channel1 as pauli_channel1,
39
+ pauli_channel2 as pauli_channel2,
40
+ observable_include as observable_include,
41
+ )
bloqade/stim/_wrappers.py CHANGED
@@ -2,7 +2,7 @@ from typing import Union
2
2
 
3
3
  from kirin.lowering import wraps
4
4
 
5
- from .dialects import aux, gate, noise, collapse
5
+ from .dialects import gate, noise, collapse, auxiliary
6
6
 
7
7
 
8
8
  # dialect:: gate
@@ -69,32 +69,38 @@ def cz(
69
69
 
70
70
  ## pp
71
71
  @wraps(gate.SPP)
72
- def spp(targets: tuple[aux.PauliString, ...], dagger=False) -> None: ...
72
+ def spp(targets: tuple[auxiliary.PauliString, ...], dagger=False) -> None: ...
73
73
 
74
74
 
75
75
  # dialect:: aux
76
- @wraps(aux.GetRecord)
77
- def rec(id: int) -> aux.RecordResult: ...
76
+ @wraps(auxiliary.GetRecord)
77
+ def rec(id: int) -> auxiliary.RecordResult: ...
78
78
 
79
79
 
80
- @wraps(aux.Detector)
80
+ @wraps(auxiliary.Detector)
81
81
  def detector(
82
- coord: tuple[Union[int, float], ...], targets: tuple[aux.RecordResult, ...]
82
+ coord: tuple[Union[int, float], ...], targets: tuple[auxiliary.RecordResult, ...]
83
83
  ) -> None: ...
84
84
 
85
85
 
86
- @wraps(aux.ObservableInclude)
87
- def observable_include(idx: int, targets: tuple[aux.RecordResult, ...]) -> None: ...
86
+ @wraps(auxiliary.ObservableInclude)
87
+ def observable_include(
88
+ idx: int, targets: tuple[auxiliary.RecordResult, ...]
89
+ ) -> None: ...
88
90
 
89
91
 
90
- @wraps(aux.Tick)
92
+ @wraps(auxiliary.Tick)
91
93
  def tick() -> None: ...
92
94
 
93
95
 
94
- @wraps(aux.NewPauliString)
96
+ @wraps(auxiliary.NewPauliString)
95
97
  def pauli_string(
96
98
  string: tuple[str, ...], flipped: tuple[bool, ...], targets: tuple[int, ...]
97
- ) -> aux.PauliString: ...
99
+ ) -> auxiliary.PauliString: ...
100
+
101
+
@wraps(auxiliary.QubitCoordinates)
def qubit_coords(coord: tuple[Union[int, float], ...], target: int) -> None:
    """Lowering stub for ``auxiliary.QubitCoordinates``; the body is never executed."""
98
104
 
99
105
 
100
106
  # dialect:: collapse
@@ -123,7 +129,7 @@ def mxx(p: float, targets: tuple[int, ...]) -> None: ...
123
129
 
124
130
 
125
131
  @wraps(collapse.PPMeasurement)
126
- def mpp(p: float, targets: tuple[aux.PauliString, ...]) -> None: ...
132
+ def mpp(p: float, targets: tuple[auxiliary.PauliString, ...]) -> None: ...
127
133
 
128
134
 
129
135
  @wraps(collapse.RZ)
@@ -1,5 +1 @@
1
- from . import aux as aux, gate as gate, noise as noise, collapse as collapse
2
- from .aux.stmts import * # noqa F403
3
- from .gate.stmts import * # noqa F403
4
- from .noise.stmts import * # noqa F403
5
- from .collapse.stmts import * # noqa F403
1
+ from . import gate as gate, noise as noise, collapse as collapse, auxiliary as auxiliary
@@ -1,6 +1,18 @@
1
1
  from . import lowering as lowering
2
2
  from .emit import EmitStimAuxMethods as EmitStimAuxMethods
3
- from .stmts import * # noqa F403
3
+ from .stmts import (
4
+ Neg as Neg,
5
+ Tick as Tick,
6
+ ConstInt as ConstInt,
7
+ ConstStr as ConstStr,
8
+ Detector as Detector,
9
+ ConstBool as ConstBool,
10
+ GetRecord as GetRecord,
11
+ ConstFloat as ConstFloat,
12
+ NewPauliString as NewPauliString,
13
+ QubitCoordinates as QubitCoordinates,
14
+ ObservableInclude as ObservableInclude,
15
+ )
4
16
  from .types import (
5
17
  RecordType as RecordType,
6
18
  PauliString as PauliString,
@@ -1,7 +1,7 @@
1
1
  from kirin.emit import EmitStrFrame
2
2
  from kirin.interp import MethodTable, impl
3
3
 
4
- from bloqade.stim.emit.stim import EmitStimMain
4
+ from bloqade.stim.emit.stim_str import EmitStimMain
5
5
 
6
6
  from . import stmts
7
7
  from ._dialect import dialect
@@ -69,8 +69,10 @@ class EmitStimAuxMethods(MethodTable):
69
69
 
70
70
  coord_str: str = ", ".join(coords)
71
71
  target_str: str = " ".join(targets)
72
- emit.writeln(frame, f"DETECTOR({coord_str}) {target_str}")
73
-
72
+ if len(coords):
73
+ emit.writeln(frame, f"DETECTOR({coord_str}) {target_str}")
74
+ else:
75
+ emit.writeln(frame, f"DETECTOR {target_str}")
74
76
  return ()
75
77
 
76
78
  @impl(stmts.ObservableInclude)
@@ -100,3 +102,16 @@ class EmitStimAuxMethods(MethodTable):
100
102
  )
101
103
 
102
104
  return (out,)
105
+
@impl(stmts.QubitCoordinates)
def qubit_coordinates(
    self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.QubitCoordinates
):
    """Write a ``QUBIT_COORDS(<coords>) <target>`` line for ``stmt``."""
    coord_values = frame.get_values(stmt.coord)
    target_value = frame.get(stmt.target)

    joined_coords = ", ".join(coord_values)
    line = f"QUBIT_COORDS({joined_coords}) {target_value}"
    emit.writeln(frame, line)

    return ()
@@ -10,5 +10,6 @@ from .annotate import (
10
10
  Detector as Detector,
11
11
  GetRecord as GetRecord,
12
12
  NewPauliString as NewPauliString,
13
+ QubitCoordinates as QubitCoordinates,
13
14
  ObservableInclude as ObservableInclude,
14
15
  )
@@ -45,3 +45,11 @@ class NewPauliString(ir.Statement):
45
45
  flipped: tuple[ir.SSAValue, ...] = info.argument(types.Bool)
46
46
  targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
47
47
  result: ir.ResultValue = info.result(type=PauliStringType)
48
+
49
+
@statement(dialect=dialect)
class QubitCoordinates(ir.Statement):
    """Attach coordinates ``coord`` to qubit ``target`` (emitted as ``QUBIT_COORDS``)."""

    traits = frozenset({lowering.FromPythonCall()})
    name = "qubit_coordinates"
    # Argument order is significant: coordinates first, then the target qubit.
    coord: tuple[ir.SSAValue, ...] = info.argument(PyNum)
    target: ir.SSAValue = info.argument(types.Int)
@@ -1,3 +1,14 @@
1
- from .emit import EmitStimCollapseMethods as EmitStimCollapseMethods
2
- from .stmts import * # noqa F403
1
+ from .stmts import (
2
+ MX as MX,
3
+ MY as MY,
4
+ MZ as MZ,
5
+ RX as RX,
6
+ RY as RY,
7
+ RZ as RZ,
8
+ MXX as MXX,
9
+ MYY as MYY,
10
+ MZZ as MZZ,
11
+ PPMeasurement as PPMeasurement,
12
+ )
3
13
  from ._dialect import dialect as dialect
14
+ from .emit_str import EmitStimCollapseMethods as EmitStimCollapseMethods
@@ -1,7 +1,7 @@
1
1
  from kirin.emit import EmitStrFrame
2
2
  from kirin.interp import MethodTable, impl
3
3
 
4
- from bloqade.stim.emit.stim import EmitStimMain
4
+ from bloqade.stim.emit.stim_str import EmitStimMain
5
5
 
6
6
  from . import stmts
7
7
  from ._dialect import dialect
@@ -60,7 +60,9 @@ class EmitStimCollapseMethods(MethodTable):
60
60
  self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PPMeasurement
61
61
  ):
62
62
  probability: str = frame.get(stmt.p)
63
- targets: tuple[str, ...] = frame.get_values(stmt.targets)
63
+ targets: tuple[str, ...] = tuple(
64
+ targ.upper() for targ in frame.get_values(stmt.targets)
65
+ )
64
66
 
65
67
  out = f"MPP({probability}) " + " ".join(targets)
66
68
  emit.writeln(frame, out)
@@ -2,7 +2,7 @@ from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
 
4
4
  from .._dialect import dialect
5
- from ...aux.types import PauliStringType
5
+ from ...auxiliary.types import PauliStringType
6
6
 
7
7
 
8
8
  @statement(dialect=dialect)
@@ -1,3 +1,18 @@
1
1
  from .emit import EmitStimGateMethods as EmitStimGateMethods
2
- from .stmts import * # noqa F403
2
+ from .stmts import (
3
+ CX as CX,
4
+ CY as CY,
5
+ CZ as CZ,
6
+ SPP as SPP,
7
+ H as H,
8
+ S as S,
9
+ X as X,
10
+ Y as Y,
11
+ Z as Z,
12
+ Swap as Swap,
13
+ SqrtX as SqrtX,
14
+ SqrtY as SqrtY,
15
+ SqrtZ as SqrtZ,
16
+ Identity as Identity,
17
+ )
3
18
  from ._dialect import dialect as dialect
@@ -1,7 +1,7 @@
1
1
  from kirin.emit import EmitStrFrame
2
2
  from kirin.interp import MethodTable, impl
3
3
 
4
- from bloqade.stim.emit.stim import EmitStimMain
4
+ from bloqade.stim.emit.stim_str import EmitStimMain
5
5
 
6
6
  from . import stmts
7
7
  from ._dialect import dialect
@@ -12,6 +12,7 @@ from .stmts.base import SingleQubitGate, ControlledTwoQubitGate
12
12
  class EmitStimGateMethods(MethodTable):
13
13
 
14
14
  gate_1q_map: dict[str, tuple[str, str]] = {
15
+ stmts.Identity.name: ("I", "I"),
15
16
  stmts.X.name: ("X", "X"),
16
17
  stmts.Y.name: ("Y", "Y"),
17
18
  stmts.Z.name: ("Z", "Z"),
@@ -22,6 +23,7 @@ class EmitStimGateMethods(MethodTable):
22
23
  stmts.SqrtZ.name: ("SQRT_Z", "SQRT_Z_DAG"),
23
24
  }
24
25
 
26
+ @impl(stmts.Identity)
25
27
  @impl(stmts.X)
26
28
  @impl(stmts.Y)
27
29
  @impl(stmts.Z)
@@ -80,8 +82,13 @@ class EmitStimGateMethods(MethodTable):
80
82
  @impl(stmts.SPP)
81
83
  def spp(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.SPP):
82
84
 
83
- targets: tuple[str, ...] = frame.get_values(stmt.targets)
84
- res = "SPP " + " ".join(targets)
85
+ targets: tuple[str, ...] = tuple(
86
+ targ.upper() for targ in frame.get_values(stmt.targets)
87
+ )
88
+ if stmt.dagger:
89
+ res = "SPP_DAG " + " ".join(targets)
90
+ else:
91
+ res = "SPP " + " ".join(targets)
85
92
  emit.writeln(frame, res)
86
93
 
87
94
  return ()
@@ -1,7 +1,7 @@
1
1
  from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
 
4
- from bloqade.stim.dialects.aux import RecordType
4
+ from bloqade.stim.dialects.auxiliary import RecordType
5
5
 
6
6
 
7
7
  @statement
@@ -2,7 +2,7 @@ from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
 
4
4
  from .._dialect import dialect
5
- from ...aux.types import PauliStringType
5
+ from ...auxiliary.types import PauliStringType
6
6
 
7
7
 
8
8
  # Generalized Pauli-product gates