bloqade-circuit 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69):
  1. bloqade/analysis/address/impls.py +3 -16
  2. bloqade/pyqrack/__init__.py +1 -1
  3. bloqade/pyqrack/noise/native.py +8 -8
  4. bloqade/pyqrack/squin/noise/__init__.py +1 -0
  5. bloqade/pyqrack/squin/noise/native.py +72 -0
  6. bloqade/pyqrack/squin/op.py +7 -0
  7. bloqade/pyqrack/squin/qubit.py +0 -29
  8. bloqade/pyqrack/squin/runtime.py +18 -0
  9. bloqade/pyqrack/squin/wire.py +0 -36
  10. bloqade/{noise/native → qasm2/dialects/noise}/__init__.py +1 -7
  11. bloqade/qasm2/dialects/noise/_dialect.py +3 -0
  12. bloqade/{noise → qasm2/dialects/noise}/fidelity.py +2 -2
  13. bloqade/qasm2/dialects/noise/model.py +278 -0
  14. bloqade/qasm2/emit/impls/__init__.py +1 -1
  15. bloqade/qasm2/emit/impls/{noise_native.py → noise.py} +11 -11
  16. bloqade/qasm2/emit/main.py +2 -4
  17. bloqade/qasm2/emit/target.py +3 -3
  18. bloqade/qasm2/groups.py +0 -2
  19. bloqade/{noise/native/_wrappers.py → qasm2/noise.py} +9 -5
  20. bloqade/qasm2/passes/glob.py +12 -8
  21. bloqade/qasm2/passes/noise.py +5 -14
  22. bloqade/qasm2/rewrite/__init__.py +2 -0
  23. bloqade/qasm2/rewrite/noise/__init__.py +0 -0
  24. bloqade/qasm2/rewrite/{heuristic_noise.py → noise/heuristic_noise.py} +31 -53
  25. bloqade/{noise/native/rewrite.py → qasm2/rewrite/noise/remove_noise.py} +2 -2
  26. bloqade/qbraid/lowering.py +8 -8
  27. bloqade/squin/__init__.py +16 -1
  28. bloqade/squin/analysis/nsites/impls.py +0 -9
  29. bloqade/squin/cirq/__init__.py +89 -0
  30. bloqade/squin/cirq/lowering.py +303 -0
  31. bloqade/squin/groups.py +7 -7
  32. bloqade/squin/lowering.py +27 -0
  33. bloqade/squin/noise/__init__.py +3 -1
  34. bloqade/squin/noise/_wrapper.py +7 -3
  35. bloqade/squin/noise/rewrite.py +111 -0
  36. bloqade/squin/noise/stmts.py +21 -16
  37. bloqade/squin/op/__init__.py +1 -0
  38. bloqade/squin/op/_wrapper.py +4 -0
  39. bloqade/squin/op/stmts.py +10 -11
  40. bloqade/squin/op/types.py +2 -0
  41. bloqade/squin/qubit.py +32 -37
  42. bloqade/squin/rewrite/desugar.py +65 -0
  43. bloqade/squin/rewrite/qubit_to_stim.py +0 -23
  44. bloqade/squin/rewrite/squin_measure.py +2 -27
  45. bloqade/squin/rewrite/stim_rewrite_util.py +3 -8
  46. bloqade/squin/rewrite/wire_to_stim.py +0 -21
  47. bloqade/squin/wire.py +4 -9
  48. bloqade/stim/__init__.py +2 -1
  49. bloqade/stim/_wrappers.py +4 -0
  50. bloqade/stim/dialects/auxiliary/__init__.py +1 -0
  51. bloqade/stim/dialects/auxiliary/emit.py +17 -2
  52. bloqade/stim/dialects/auxiliary/stmts/__init__.py +1 -0
  53. bloqade/stim/dialects/auxiliary/stmts/annotate.py +8 -0
  54. bloqade/stim/dialects/collapse/emit_str.py +3 -1
  55. bloqade/stim/dialects/gate/emit.py +9 -2
  56. bloqade/stim/dialects/noise/emit.py +32 -1
  57. bloqade/stim/dialects/noise/stmts.py +29 -0
  58. bloqade/stim/parse/__init__.py +1 -0
  59. bloqade/stim/parse/lowering.py +686 -0
  60. {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.1.dist-info}/METADATA +3 -1
  61. {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.1.dist-info}/RECORD +64 -57
  62. bloqade/noise/__init__.py +0 -2
  63. bloqade/noise/native/_dialect.py +0 -3
  64. bloqade/noise/native/model.py +0 -346
  65. bloqade/qasm2/dialects/noise.py +0 -48
  66. bloqade/squin/rewrite/measure_desugar.py +0 -33
  67. bloqade/{noise/native → qasm2/dialects/noise}/stmts.py +0 -0
  68. {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.1.dist-info}/WHEEL +0 -0
  69. {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.1.dist-info}/licenses/LICENSE +0 -0
bloqade/squin/op/stmts.py CHANGED
@@ -9,7 +9,7 @@ from ._dialect import dialect
9
9
 
10
10
  @statement
11
11
  class Operator(ir.Statement):
12
- pass
12
+ result: ir.ResultValue = info.result(OpType)
13
13
 
14
14
 
15
15
  @statement
@@ -26,7 +26,6 @@ class CompositeOp(Operator):
26
26
  class BinaryOp(CompositeOp):
27
27
  lhs: ir.SSAValue = info.argument(OpType)
28
28
  rhs: ir.SSAValue = info.argument(OpType)
29
- result: ir.ResultValue = info.result(OpType)
30
29
 
31
30
 
32
31
  @statement(dialect=dialect)
@@ -46,7 +45,6 @@ class Adjoint(CompositeOp):
46
45
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
47
46
  is_unitary: bool = info.attribute(default=False)
48
47
  op: ir.SSAValue = info.argument(OpType)
49
- result: ir.ResultValue = info.result(OpType)
50
48
 
51
49
 
52
50
  @statement(dialect=dialect)
@@ -55,7 +53,6 @@ class Scale(CompositeOp):
55
53
  is_unitary: bool = info.attribute(default=False)
56
54
  op: ir.SSAValue = info.argument(OpType)
57
55
  factor: ir.SSAValue = info.argument(NumberType)
58
- result: ir.ResultValue = info.result(OpType)
59
56
 
60
57
 
61
58
  @statement(dialect=dialect)
@@ -64,7 +61,6 @@ class Control(CompositeOp):
64
61
  is_unitary: bool = info.attribute(default=False)
65
62
  op: ir.SSAValue = info.argument(OpType)
66
63
  n_controls: int = info.attribute()
67
- result: ir.ResultValue = info.result(OpType)
68
64
 
69
65
 
70
66
  @statement(dialect=dialect)
@@ -72,14 +68,12 @@ class Rot(CompositeOp):
72
68
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary()})
73
69
  axis: ir.SSAValue = info.argument(OpType)
74
70
  angle: ir.SSAValue = info.argument(types.Float)
75
- result: ir.ResultValue = info.result(OpType)
76
71
 
77
72
 
78
73
  @statement(dialect=dialect)
79
74
  class Identity(CompositeOp):
80
75
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
81
76
  sites: int = info.attribute()
82
- result: ir.ResultValue = info.result(OpType)
83
77
 
84
78
 
85
79
  @statement
@@ -87,7 +81,6 @@ class ConstantOp(PrimitiveOp):
87
81
  traits = frozenset(
88
82
  {ir.Pure(), lowering.FromPythonCall(), ir.ConstantLike(), FixedSites(1)}
89
83
  )
90
- result: ir.ResultValue = info.result(OpType)
91
84
 
92
85
 
93
86
  @statement
@@ -109,7 +102,6 @@ class U3(PrimitiveOp):
109
102
  theta: ir.SSAValue = info.argument(types.Float)
110
103
  phi: ir.SSAValue = info.argument(types.Float)
111
104
  lam: ir.SSAValue = info.argument(types.Float)
112
- result: ir.ResultValue = info.result(OpType)
113
105
 
114
106
 
115
107
  @statement(dialect=dialect)
@@ -124,7 +116,6 @@ class PhaseOp(PrimitiveOp):
124
116
 
125
117
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
126
118
  theta: ir.SSAValue = info.argument(types.Float)
127
- result: ir.ResultValue = info.result(OpType)
128
119
 
129
120
 
130
121
  @statement(dialect=dialect)
@@ -139,7 +130,15 @@ class ShiftOp(PrimitiveOp):
139
130
 
140
131
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
141
132
  theta: ir.SSAValue = info.argument(types.Float)
142
- result: ir.ResultValue = info.result(OpType)
133
+
134
+
135
+ @statement(dialect=dialect)
136
+ class Reset(PrimitiveOp):
137
+ """
138
+ Reset operator for qubits or wires.
139
+ """
140
+
141
+ traits = frozenset({ir.Pure(), lowering.FromPythonCall(), FixedSites(1)})
143
142
 
144
143
 
145
144
  @statement
bloqade/squin/op/types.py CHANGED
@@ -22,3 +22,5 @@ class Op:
22
22
 
23
23
 
24
24
  OpType = types.PyClass(Op)
25
+
26
+ NumOperators = types.TypeVar("NumOperators")
bloqade/squin/qubit.py CHANGED
@@ -17,6 +17,8 @@ from kirin.lowering import wraps
17
17
  from bloqade.types import Qubit, QubitType
18
18
  from bloqade.squin.op.types import Op, OpType
19
19
 
20
+ from .lowering import ApplyAnyCallLowering
21
+
20
22
  dialect = ir.Dialect("squin.qubit")
21
23
 
22
24
 
@@ -34,6 +36,14 @@ class Apply(ir.Statement):
34
36
  qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
35
37
 
36
38
 
39
+ @statement(dialect=dialect)
40
+ class ApplyAny(ir.Statement):
41
+ # NOTE: custom lowering to deal with vararg calls
42
+ traits = frozenset({ApplyAnyCallLowering()})
43
+ operator: ir.SSAValue = info.argument(OpType)
44
+ qubits: tuple[ir.SSAValue, ...] = info.argument()
45
+
46
+
37
47
  @statement(dialect=dialect)
38
48
  class Broadcast(ir.Statement):
39
49
  traits = frozenset({lowering.FromPythonCall()})
@@ -68,19 +78,6 @@ class MeasureQubitList(ir.Statement):
68
78
  result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
69
79
 
70
80
 
71
- @statement(dialect=dialect)
72
- class MeasureAndReset(ir.Statement):
73
- traits = frozenset({lowering.FromPythonCall()})
74
- qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
75
- result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
76
-
77
-
78
- @statement(dialect=dialect)
79
- class Reset(ir.Statement):
80
- traits = frozenset({lowering.FromPythonCall()})
81
- qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
82
-
83
-
84
81
  # NOTE: no dependent types in Python, so we have to mark it Any...
85
82
  @wraps(New)
86
83
  def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
@@ -95,7 +92,7 @@ def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
95
92
  ...
96
93
 
97
94
 
98
- @wraps(Apply)
95
+ @overload
99
96
  def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
100
97
  """Apply an operator to a list of qubits.
101
98
 
@@ -112,6 +109,27 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
112
109
  ...
113
110
 
114
111
 
112
+ @overload
113
+ def apply(operator: Op, *qubits: Qubit) -> None:
114
+ """Apply and operator to any number of qubits.
115
+
116
+ Note, that when considering atom loss, lost qubits will be skipped.
117
+
118
+ Args:
119
+ operator: The operator to apply.
120
+ *qubits: The qubits to apply the operator to. The number of qubits must
121
+ match the size of the operator.
122
+
123
+ Returns:
124
+ None
125
+ """
126
+ ...
127
+
128
+
129
+ @wraps(ApplyAny)
130
+ def apply(operator: Op, *qubits) -> None: ...
131
+
132
+
115
133
  @overload
116
134
  def measure(input: Qubit) -> bool: ...
117
135
  @overload
@@ -161,26 +179,3 @@ def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> No
161
179
  None
162
180
  """
163
181
  ...
164
-
165
-
166
- @wraps(MeasureAndReset)
167
- def measure_and_reset(qubits: ilist.IList[Qubit, Any]) -> ilist.IList[bool, Any]:
168
- """Measure the qubits in the list and reset them."
169
-
170
- Args:
171
- qubits: The list of qubits to measure and reset.
172
-
173
- Returns:
174
- list[bool]: The result of the measurement.
175
- """
176
- ...
177
-
178
-
179
- @wraps(Reset)
180
- def reset(qubits: ilist.IList[Qubit, Any]) -> None:
181
- """Reset the qubits in the list."
182
-
183
- Args:
184
- qubits: The list of qubits to reset.
185
- """
186
- ...
@@ -0,0 +1,65 @@
1
+ from kirin import ir, types
2
+ from kirin.dialects import ilist
3
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
4
+
5
+ from bloqade.squin.qubit import (
6
+ Apply,
7
+ ApplyAny,
8
+ QubitType,
9
+ MeasureAny,
10
+ MeasureQubit,
11
+ MeasureQubitList,
12
+ )
13
+
14
+
15
+ class MeasureDesugarRule(RewriteRule):
16
+ """
17
+ Desugar measure operations in the circuit.
18
+ """
19
+
20
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21
+
22
+ if not isinstance(node, MeasureAny):
23
+ return RewriteResult()
24
+
25
+ if node.input.type.is_subseteq(QubitType):
26
+ node.replace_by(
27
+ MeasureQubit(
28
+ qubit=node.input,
29
+ )
30
+ )
31
+ return RewriteResult(has_done_something=True)
32
+ elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]):
33
+ node.replace_by(
34
+ MeasureQubitList(
35
+ qubits=node.input,
36
+ )
37
+ )
38
+ return RewriteResult(has_done_something=True)
39
+
40
+ return RewriteResult()
41
+
42
+
43
+ class ApplyDesugarRule(RewriteRule):
44
+ """
45
+ Desugar apply operators in the kernel.
46
+ """
47
+
48
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
49
+
50
+ if not isinstance(node, ApplyAny):
51
+ return RewriteResult()
52
+
53
+ op = node.operator
54
+ qubits = node.qubits
55
+
56
+ if len(qubits) == 1 and qubits[0].type.is_subseteq(ilist.IListType):
57
+ # NOTE: already calling with just a single argument that is already an ilist
58
+ qubits_ilist = qubits[0]
59
+ else:
60
+ (qubits_ilist_stmt := ilist.New(values=qubits)).insert_before(node)
61
+ qubits_ilist = qubits_ilist_stmt.result
62
+
63
+ stmt = Apply(operator=op, qubits=qubits_ilist)
64
+ node.replace_by(stmt)
65
+ return RewriteResult(has_done_something=True)
@@ -1,7 +1,6 @@
1
1
  from kirin import ir
2
2
  from kirin.rewrite.abc import RewriteRule, RewriteResult
3
3
 
4
- from bloqade import stim
5
4
  from bloqade.squin import op, qubit
6
5
  from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
7
6
  from bloqade.squin.rewrite.stim_rewrite_util import (
@@ -18,8 +17,6 @@ class SquinQubitToStim(RewriteRule):
18
17
  match node:
19
18
  case qubit.Apply() | qubit.Broadcast():
20
19
  return self.rewrite_Apply_and_Broadcast(node)
21
- case qubit.Reset():
22
- return self.rewrite_Reset(node)
23
20
  case _:
24
21
  return RewriteResult()
25
22
 
@@ -60,25 +57,5 @@ class SquinQubitToStim(RewriteRule):
60
57
 
61
58
  return RewriteResult(has_done_something=True)
62
59
 
63
- def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult:
64
- qubit_ilist_ssa = reset_stmt.qubits
65
- # qubits are in an ilist which makes up an AddressTuple
66
- address_attr = qubit_ilist_ssa.hints.get("address")
67
- if address_attr is None:
68
- return RewriteResult()
69
-
70
- assert isinstance(address_attr, AddressAttribute)
71
- qubit_idx_ssas = insert_qubit_idx_from_address(
72
- address=address_attr, stmt_to_insert_before=reset_stmt
73
- )
74
-
75
- if qubit_idx_ssas is None:
76
- return RewriteResult()
77
-
78
- stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas)
79
- reset_stmt.replace_by(stim_rz_stmt)
80
-
81
- return RewriteResult(has_done_something=True)
82
-
83
60
 
84
61
  # put rewrites for measure statements in separate rule, then just have to dispatch
@@ -3,8 +3,8 @@ from kirin import ir
3
3
  from kirin.dialects import py
4
4
  from kirin.rewrite.abc import RewriteRule, RewriteResult
5
5
 
6
- from bloqade import stim
7
6
  from bloqade.squin import wire, qubit
7
+ from bloqade.stim.dialects import collapse
8
8
  from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9
9
  from bloqade.squin.rewrite.stim_rewrite_util import (
10
10
  is_measure_result_used,
@@ -22,8 +22,6 @@ class SquinMeasureToStim(RewriteRule):
22
22
  match node:
23
23
  case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure():
24
24
  return self.rewrite_Measure(node)
25
- case qubit.MeasureAndReset() | wire.MeasureAndReset():
26
- return self.rewrite_MeasureAndReset(node)
27
25
  case _:
28
26
  return RewriteResult()
29
27
 
@@ -38,7 +36,7 @@ class SquinMeasureToStim(RewriteRule):
38
36
  return RewriteResult()
39
37
 
40
38
  prob_noise_stmt = py.constant.Constant(0.0)
41
- stim_measure_stmt = stim.collapse.MZ(
39
+ stim_measure_stmt = collapse.MZ(
42
40
  p=prob_noise_stmt.result,
43
41
  targets=qubit_idx_ssas,
44
42
  )
@@ -47,29 +45,6 @@ class SquinMeasureToStim(RewriteRule):
47
45
 
48
46
  return RewriteResult(has_done_something=True)
49
47
 
50
- def rewrite_MeasureAndReset(
51
- self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset
52
- ) -> RewriteResult:
53
- if not is_measure_result_used(meas_and_reset_stmt):
54
- return RewriteResult()
55
-
56
- qubit_idx_ssas = self.get_qubit_idx_ssas(meas_and_reset_stmt)
57
-
58
- if qubit_idx_ssas is None:
59
- return RewriteResult()
60
-
61
- error_p_stmt = py.Constant(0.0)
62
- stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result)
63
- stim_rz_stmt = stim.collapse.RZ(
64
- targets=qubit_idx_ssas,
65
- )
66
-
67
- error_p_stmt.insert_before(meas_and_reset_stmt)
68
- stim_mz_stmt.insert_before(meas_and_reset_stmt)
69
- meas_and_reset_stmt.replace_by(stim_rz_stmt)
70
-
71
- return RewriteResult(has_done_something=True)
72
-
73
48
  def get_qubit_idx_ssas(
74
49
  self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure
75
50
  ) -> tuple[ir.SSAValue, ...] | None:
@@ -3,7 +3,7 @@ from kirin.dialects import py
3
3
  from kirin.rewrite.abc import RewriteResult
4
4
 
5
5
  from bloqade.squin import op, wire, qubit
6
- from bloqade.stim.dialects import gate
6
+ from bloqade.stim.dialects import gate, collapse
7
7
  from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple
8
8
  from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9
9
 
@@ -14,6 +14,7 @@ SQUIN_STIM_GATE_MAPPING = {
14
14
  op.stmts.H: gate.H,
15
15
  op.stmts.S: gate.S,
16
16
  op.stmts.Identity: gate.Identity,
17
+ op.stmts.Reset: collapse.RZ,
17
18
  }
18
19
 
19
20
 
@@ -144,13 +145,7 @@ def rewrite_Control(
144
145
 
145
146
 
146
147
  def is_measure_result_used(
147
- stmt: (
148
- qubit.MeasureAndReset
149
- | qubit.MeasureQubit
150
- | qubit.MeasureQubitList
151
- | wire.MeasureAndReset
152
- | wire.Measure
153
- ),
148
+ stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure,
154
149
  ) -> bool:
155
150
  """
156
151
  Check if the result of a measure statement is used in the program.
@@ -1,13 +1,10 @@
1
1
  from kirin import ir
2
2
  from kirin.rewrite.abc import RewriteRule, RewriteResult
3
3
 
4
- from bloqade import stim
5
4
  from bloqade.squin import op, wire
6
- from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
7
5
  from bloqade.squin.rewrite.stim_rewrite_util import (
8
6
  SQUIN_STIM_GATE_MAPPING,
9
7
  rewrite_Control,
10
- insert_qubit_idx_from_address,
11
8
  insert_qubit_idx_from_wire_ssa,
12
9
  )
13
10
 
@@ -18,8 +15,6 @@ class SquinWireToStim(RewriteRule):
18
15
  match node:
19
16
  case wire.Apply() | wire.Broadcast():
20
17
  return self.rewrite_Apply_and_Broadcast(node)
21
- case wire.Reset():
22
- return self.rewrite_Reset(node)
23
18
  case _:
24
19
  return RewriteResult()
25
20
 
@@ -55,19 +50,3 @@ class SquinWireToStim(RewriteRule):
55
50
  stmt.replace_by(stim_1q_stmt)
56
51
 
57
52
  return RewriteResult(has_done_something=True)
58
-
59
- def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult:
60
- address_attr = reset_stmt.wire.hints.get("address")
61
- if address_attr is None:
62
- return RewriteResult()
63
- assert isinstance(address_attr, AddressAttribute)
64
- qubit_idx_ssas = insert_qubit_idx_from_address(
65
- address=address_attr, stmt_to_insert_before=reset_stmt
66
- )
67
- if qubit_idx_ssas is None:
68
- return RewriteResult()
69
-
70
- stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas)
71
- reset_stmt.replace_by(stim_rz_stmt)
72
-
73
- return RewriteResult(has_done_something=True)
bloqade/squin/wire.py CHANGED
@@ -95,23 +95,18 @@ class Broadcast(ir.Statement):
95
95
  class Measure(ir.Statement):
96
96
  traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
97
97
  wire: ir.SSAValue = info.argument(WireType)
98
+ qubit: ir.SSAValue = info.argument(QubitType)
98
99
  result: ir.ResultValue = info.result(types.Int)
99
100
 
100
101
 
101
102
  @statement(dialect=dialect)
102
- class MeasureAndReset(ir.Statement):
103
- traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
104
- wire: ir.SSAValue = info.argument(WireType)
103
+ class NonDestructiveMeasure(ir.Statement):
104
+ traits = frozenset({lowering.FromPythonCall()})
105
+ input_wire: ir.SSAValue = info.argument(WireType)
105
106
  result: ir.ResultValue = info.result(types.Int)
106
107
  out_wire: ir.ResultValue = info.result(WireType)
107
108
 
108
109
 
109
- @statement(dialect=dialect)
110
- class Reset(ir.Statement):
111
- traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
112
- wire: ir.SSAValue = info.argument(WireType)
113
-
114
-
115
110
  @wraps(Unwrap)
116
111
  def unwrap(qubit: Qubit) -> Wire: ...
117
112
 
bloqade/stim/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from . import emit as emit, dialects as dialects
1
+ from . import emit as emit, parse as parse, dialects as dialects
2
2
  from .groups import main as main
3
3
  from ._wrappers import (
4
4
  h as h,
@@ -34,6 +34,7 @@ from ._wrappers import (
34
34
  depolarize1 as depolarize1,
35
35
  depolarize2 as depolarize2,
36
36
  pauli_string as pauli_string,
37
+ qubit_coords as qubit_coords,
37
38
  pauli_channel1 as pauli_channel1,
38
39
  pauli_channel2 as pauli_channel2,
39
40
  observable_include as observable_include,
bloqade/stim/_wrappers.py CHANGED
@@ -99,6 +99,10 @@ def pauli_string(
99
99
  ) -> auxiliary.PauliString: ...
100
100
 
101
101
 
102
+ @wraps(auxiliary.QubitCoordinates)
103
+ def qubit_coords(coord: tuple[Union[int, float], ...], target: int) -> None: ...
104
+
105
+
102
106
  # dialect:: collapse
103
107
  @wraps(collapse.MZ)
104
108
  def mz(p: float, targets: tuple[int, ...]) -> None: ...
@@ -10,6 +10,7 @@ from .stmts import (
10
10
  GetRecord as GetRecord,
11
11
  ConstFloat as ConstFloat,
12
12
  NewPauliString as NewPauliString,
13
+ QubitCoordinates as QubitCoordinates,
13
14
  ObservableInclude as ObservableInclude,
14
15
  )
15
16
  from .types import (
@@ -69,8 +69,10 @@ class EmitStimAuxMethods(MethodTable):
69
69
 
70
70
  coord_str: str = ", ".join(coords)
71
71
  target_str: str = " ".join(targets)
72
- emit.writeln(frame, f"DETECTOR({coord_str}) {target_str}")
73
-
72
+ if len(coords):
73
+ emit.writeln(frame, f"DETECTOR({coord_str}) {target_str}")
74
+ else:
75
+ emit.writeln(frame, f"DETECTOR {target_str}")
74
76
  return ()
75
77
 
76
78
  @impl(stmts.ObservableInclude)
@@ -100,3 +102,16 @@ class EmitStimAuxMethods(MethodTable):
100
102
  )
101
103
 
102
104
  return (out,)
105
+
106
+ @impl(stmts.QubitCoordinates)
107
+ def qubit_coordinates(
108
+ self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.QubitCoordinates
109
+ ):
110
+
111
+ coords: tuple[str, ...] = frame.get_values(stmt.coord)
112
+ target: str = frame.get(stmt.target)
113
+
114
+ coord_str: str = ", ".join(coords)
115
+ emit.writeln(frame, f"QUBIT_COORDS({coord_str}) {target}")
116
+
117
+ return ()
@@ -10,5 +10,6 @@ from .annotate import (
10
10
  Detector as Detector,
11
11
  GetRecord as GetRecord,
12
12
  NewPauliString as NewPauliString,
13
+ QubitCoordinates as QubitCoordinates,
13
14
  ObservableInclude as ObservableInclude,
14
15
  )
@@ -45,3 +45,11 @@ class NewPauliString(ir.Statement):
45
45
  flipped: tuple[ir.SSAValue, ...] = info.argument(types.Bool)
46
46
  targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
47
47
  result: ir.ResultValue = info.result(type=PauliStringType)
48
+
49
+
50
+ @statement(dialect=dialect)
51
+ class QubitCoordinates(ir.Statement):
52
+ name = "qubit_coordinates"
53
+ traits = frozenset({lowering.FromPythonCall()})
54
+ coord: tuple[ir.SSAValue, ...] = info.argument(PyNum)
55
+ target: ir.SSAValue = info.argument(types.Int)
@@ -60,7 +60,9 @@ class EmitStimCollapseMethods(MethodTable):
60
60
  self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PPMeasurement
61
61
  ):
62
62
  probability: str = frame.get(stmt.p)
63
- targets: tuple[str, ...] = frame.get_values(stmt.targets)
63
+ targets: tuple[str, ...] = tuple(
64
+ targ.upper() for targ in frame.get_values(stmt.targets)
65
+ )
64
66
 
65
67
  out = f"MPP({probability}) " + " ".join(targets)
66
68
  emit.writeln(frame, out)
@@ -12,6 +12,7 @@ from .stmts.base import SingleQubitGate, ControlledTwoQubitGate
12
12
  class EmitStimGateMethods(MethodTable):
13
13
 
14
14
  gate_1q_map: dict[str, tuple[str, str]] = {
15
+ stmts.Identity.name: ("I", "I"),
15
16
  stmts.X.name: ("X", "X"),
16
17
  stmts.Y.name: ("Y", "Y"),
17
18
  stmts.Z.name: ("Z", "Z"),
@@ -22,6 +23,7 @@ class EmitStimGateMethods(MethodTable):
22
23
  stmts.SqrtZ.name: ("SQRT_Z", "SQRT_Z_DAG"),
23
24
  }
24
25
 
26
+ @impl(stmts.Identity)
25
27
  @impl(stmts.X)
26
28
  @impl(stmts.Y)
27
29
  @impl(stmts.Z)
@@ -80,8 +82,13 @@ class EmitStimGateMethods(MethodTable):
80
82
  @impl(stmts.SPP)
81
83
  def spp(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.SPP):
82
84
 
83
- targets: tuple[str, ...] = frame.get_values(stmt.targets)
84
- res = "SPP " + " ".join(targets)
85
+ targets: tuple[str, ...] = tuple(
86
+ targ.upper() for targ in frame.get_values(stmt.targets)
87
+ )
88
+ if stmt.dagger:
89
+ res = "SPP_DAG " + " ".join(targets)
90
+ else:
91
+ res = "SPP " + " ".join(targets)
85
92
  emit.writeln(frame, res)
86
93
 
87
94
  return ()
@@ -44,7 +44,7 @@ class EmitStimNoiseMethods(MethodTable):
44
44
  px: str = frame.get(stmt.px)
45
45
  py: str = frame.get(stmt.py)
46
46
  pz: str = frame.get(stmt.pz)
47
- res = f"PAULI_CHANNEL_1({px},{py},{pz}) " + " ".join(targets)
47
+ res = f"PAULI_CHANNEL_1({px}, {py}, {pz}) " + " ".join(targets)
48
48
  emit.writeln(frame, res)
49
49
 
50
50
  return ()
@@ -64,3 +64,34 @@ class EmitStimNoiseMethods(MethodTable):
64
64
  emit.writeln(frame, res)
65
65
 
66
66
  return ()
67
+
68
+ @impl(stmts.TrivialError)
69
+ def non_stim_error(
70
+ self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.TrivialError
71
+ ):
72
+
73
+ targets: tuple[str, ...] = frame.get_values(stmt.targets)
74
+ prob: tuple[str, ...] = frame.get_values(stmt.probs)
75
+ prob_str: str = ", ".join(prob)
76
+
77
+ res = f"I_ERROR[{stmt.name}]({prob_str}) " + " ".join(targets)
78
+ emit.writeln(frame, res)
79
+
80
+ return ()
81
+
82
+ @impl(stmts.TrivialCorrelatedError)
83
+ def non_stim_corr_error(
84
+ self,
85
+ emit: EmitStimMain,
86
+ frame: EmitStrFrame,
87
+ stmt: stmts.TrivialCorrelatedError,
88
+ ):
89
+
90
+ targets: tuple[str, ...] = frame.get_values(stmt.targets)
91
+ prob: tuple[str, ...] = frame.get_values(stmt.probs)
92
+ prob_str: str = ", ".join(prob)
93
+
94
+ res = f"I_ERROR[{stmt.name}:{stmt.nonce}]({prob_str}) " + " ".join(targets)
95
+ emit.writeln(frame, res)
96
+
97
+ return ()
@@ -75,3 +75,32 @@ class ZError(ir.Statement):
75
75
  traits = frozenset({lowering.FromPythonCall()})
76
76
  p: ir.SSAValue = info.argument(types.Float)
77
77
  targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
78
+
79
+
80
+ @statement
81
+ class NonStimError(ir.Statement):
82
+ name = "NonStimError"
83
+ traits = frozenset({lowering.FromPythonCall()})
84
+ probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
85
+ targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
86
+
87
+
88
+ @statement
89
+ class NonStimCorrelatedError(ir.Statement):
90
+ name = "NonStimCorrelatedError"
91
+ traits = frozenset({lowering.FromPythonCall()})
92
+ nonce: int = (
93
+ info.attribute()
94
+ ) # Must be a unique value, otherwise stim might merge two correlated errors with equal probabilities
95
+ probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
96
+ targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
97
+
98
+
99
+ @statement(dialect=dialect)
100
+ class TrivialCorrelatedError(NonStimCorrelatedError):
101
+ name = "TRIV_CORR_ERROR"
102
+
103
+
104
+ @statement(dialect=dialect)
105
+ class TrivialError(NonStimError):
106
+ name = "TRIV_ERROR"
@@ -0,0 +1 @@
1
+ from .lowering import loads as loads, loadfile as loadfile