bloqade-circuit 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (59)
  1. bloqade/analysis/address/impls.py +3 -16
  2. bloqade/pyqrack/noise/native.py +8 -8
  3. bloqade/pyqrack/squin/op.py +7 -0
  4. bloqade/pyqrack/squin/qubit.py +0 -29
  5. bloqade/pyqrack/squin/runtime.py +18 -0
  6. bloqade/pyqrack/squin/wire.py +0 -36
  7. bloqade/{noise/native → qasm2/dialects/noise}/__init__.py +1 -7
  8. bloqade/qasm2/dialects/noise/_dialect.py +3 -0
  9. bloqade/{noise → qasm2/dialects/noise}/fidelity.py +2 -2
  10. bloqade/qasm2/dialects/noise/model.py +278 -0
  11. bloqade/qasm2/emit/impls/__init__.py +1 -1
  12. bloqade/qasm2/emit/impls/{noise_native.py → noise.py} +11 -11
  13. bloqade/qasm2/emit/main.py +2 -4
  14. bloqade/qasm2/emit/target.py +3 -3
  15. bloqade/qasm2/groups.py +0 -2
  16. bloqade/{noise/native/_wrappers.py → qasm2/noise.py} +9 -5
  17. bloqade/qasm2/passes/glob.py +12 -8
  18. bloqade/qasm2/passes/noise.py +5 -14
  19. bloqade/qasm2/rewrite/__init__.py +2 -0
  20. bloqade/qasm2/rewrite/noise/__init__.py +0 -0
  21. bloqade/qasm2/rewrite/{heuristic_noise.py → noise/heuristic_noise.py} +31 -53
  22. bloqade/{noise/native/rewrite.py → qasm2/rewrite/noise/remove_noise.py} +2 -2
  23. bloqade/qbraid/lowering.py +8 -8
  24. bloqade/squin/__init__.py +7 -1
  25. bloqade/squin/analysis/nsites/impls.py +0 -9
  26. bloqade/squin/groups.py +4 -4
  27. bloqade/squin/lowering.py +27 -0
  28. bloqade/squin/op/__init__.py +1 -0
  29. bloqade/squin/op/_wrapper.py +4 -0
  30. bloqade/squin/op/stmts.py +10 -0
  31. bloqade/squin/qubit.py +32 -37
  32. bloqade/squin/rewrite/desugar.py +65 -0
  33. bloqade/squin/rewrite/qubit_to_stim.py +0 -23
  34. bloqade/squin/rewrite/squin_measure.py +2 -27
  35. bloqade/squin/rewrite/stim_rewrite_util.py +3 -8
  36. bloqade/squin/rewrite/wire_to_stim.py +0 -21
  37. bloqade/squin/wire.py +4 -9
  38. bloqade/stim/__init__.py +2 -1
  39. bloqade/stim/_wrappers.py +4 -0
  40. bloqade/stim/dialects/auxiliary/__init__.py +1 -0
  41. bloqade/stim/dialects/auxiliary/emit.py +17 -2
  42. bloqade/stim/dialects/auxiliary/stmts/__init__.py +1 -0
  43. bloqade/stim/dialects/auxiliary/stmts/annotate.py +8 -0
  44. bloqade/stim/dialects/collapse/emit_str.py +3 -1
  45. bloqade/stim/dialects/gate/emit.py +9 -2
  46. bloqade/stim/dialects/noise/emit.py +32 -1
  47. bloqade/stim/dialects/noise/stmts.py +29 -0
  48. bloqade/stim/parse/__init__.py +1 -0
  49. bloqade/stim/parse/lowering.py +686 -0
  50. {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.0.dist-info}/METADATA +3 -1
  51. {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.0.dist-info}/RECORD +54 -52
  52. bloqade/noise/__init__.py +0 -2
  53. bloqade/noise/native/_dialect.py +0 -3
  54. bloqade/noise/native/model.py +0 -346
  55. bloqade/qasm2/dialects/noise.py +0 -48
  56. bloqade/squin/rewrite/measure_desugar.py +0 -33
  57. /bloqade/{noise/native → qasm2/dialects/noise}/stmts.py +0 -0
  58. {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.0.dist-info}/WHEEL +0 -0
  59. {bloqade_circuit-0.3.0.dist-info → bloqade_circuit-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -3,11 +3,15 @@ from typing import Any
3
3
  from kirin.dialects import ilist
4
4
  from kirin.lowering import wraps
5
5
 
6
- from bloqade.noise import native
7
- from bloqade.qasm2.types import Qubit
6
+ from .types import Qubit
7
+ from .dialects import noise
8
+ from .dialects.noise import (
9
+ TwoRowZoneModel as TwoRowZoneModel,
10
+ MoveNoiseModelABC as MoveNoiseModelABC,
11
+ )
8
12
 
9
13
 
10
- @wraps(native.AtomLossChannel)
14
+ @wraps(noise.AtomLossChannel)
11
15
  def atom_loss_channel(qargs: ilist.IList[Qubit, Any] | list, *, prob: float) -> None:
12
16
  """Apply an atom loss channel to a list of qubits.
13
17
 
@@ -18,7 +22,7 @@ def atom_loss_channel(qargs: ilist.IList[Qubit, Any] | list, *, prob: float) ->
18
22
  ...
19
23
 
20
24
 
21
- @wraps(native.PauliChannel)
25
+ @wraps(noise.PauliChannel)
22
26
  def pauli_channel(
23
27
  qargs: ilist.IList[Qubit, Any] | list, *, px: float, py: float, pz: float
24
28
  ) -> None:
@@ -32,7 +36,7 @@ def pauli_channel(
32
36
  """
33
37
 
34
38
 
35
- @wraps(native.CZPauliChannel)
39
+ @wraps(noise.CZPauliChannel)
36
40
  def cz_pauli_channel(
37
41
  ctrls: ilist.IList[Qubit, Any] | list,
38
42
  qargs: ilist.IList[Qubit, Any] | list,
@@ -58,13 +58,15 @@ class GlobalToUOP(Pass):
58
58
  rewriter = walk.Walk(self.generate_rule(mt))
59
59
  result = rewriter.rewrite(mt.code)
60
60
 
61
- result = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code)
62
- result = Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination())).rewrite(
63
- mt.code
61
+ result = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code).join(result)
62
+ result = (
63
+ Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination()))
64
+ .rewrite(mt.code)
65
+ .join(result)
64
66
  )
65
67
 
66
68
  # do fold again to get proper hint for inserted const
67
- result = Fold(mt.dialects)(mt)
69
+ result = Fold(mt.dialects)(mt).join(result)
68
70
  return result
69
71
 
70
72
 
@@ -110,10 +112,12 @@ class GlobalToParallel(Pass):
110
112
  rewriter = walk.Walk(self.generate_rule(mt))
111
113
  result = rewriter.rewrite(mt.code)
112
114
 
113
- result = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code)
114
- result = Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination())).rewrite(
115
- mt.code
115
+ result = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code).join(result)
116
+ result = (
117
+ Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination()))
118
+ .rewrite(mt.code)
119
+ .join(result)
116
120
  )
117
121
  # do fold again to get proper hint
118
- result = Fold(mt.dialects)(mt)
122
+ result = Fold(mt.dialects)(mt).join(result)
119
123
  return result
@@ -8,10 +8,10 @@ from kirin.rewrite import (
8
8
  DeadCodeElimination,
9
9
  )
10
10
 
11
- from bloqade.noise import native
11
+ from bloqade.qasm2 import noise
12
12
  from bloqade.analysis import address
13
+ from bloqade.qasm2.rewrite import NoiseRewriteRule
13
14
  from bloqade.qasm2.passes.lift_qubits import LiftQubits
14
- from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
15
15
 
16
16
 
17
17
  @dataclass
@@ -25,12 +25,9 @@ class NoisePass(Pass):
25
25
 
26
26
  ```
27
27
  from bloqade import qasm2
28
- from bloqade.noise import native
29
- from bloqade.qasm2.passes.noise import NoisePass
28
+ from bloqade.qasm2.passes import NoisePass
30
29
 
31
- noise_main = qasm2.extended.add(native.dialect)
32
-
33
- @noise_main
30
+ @qasm2.extended
34
31
  def main():
35
32
  q = qasm2.qreg(2)
36
33
  qasm2.h(q[0])
@@ -51,12 +48,7 @@ class NoisePass(Pass):
51
48
 
52
49
  """
53
50
 
54
- noise_model: native.MoveNoiseModelABC = field(
55
- default_factory=native.TwoRowZoneModel
56
- )
57
- gate_noise_params: native.GateNoiseParams = field(
58
- default_factory=native.GateNoiseParams
59
- )
51
+ noise_model: noise.MoveNoiseModelABC = field(default_factory=noise.TwoRowZoneModel)
60
52
  address_analysis: address.AddressAnalysis = field(init=False)
61
53
 
62
54
  def __post_init__(self):
@@ -89,7 +81,6 @@ class NoisePass(Pass):
89
81
  qubit_ssa_value=qubit_ssa_value,
90
82
  address_analysis=address_analysis,
91
83
  noise_model=self.noise_model,
92
- gate_noise_params=self.gate_noise_params,
93
84
  ),
94
85
  reverse=True,
95
86
  )
@@ -11,3 +11,5 @@ from .uop_to_parallel import (
11
11
  SimpleGreedyMergePolicy as SimpleGreedyMergePolicy,
12
12
  SimpleOptimalMergePolicy as SimpleOptimalMergePolicy,
13
13
  )
14
+ from .noise.remove_noise import RemoveNoisePass as RemoveNoisePass
15
+ from .noise.heuristic_noise import NoiseRewriteRule as NoiseRewriteRule
File without changes
@@ -5,9 +5,8 @@ from kirin import ir
5
5
  from kirin.rewrite import abc as rewrite_abc
6
6
  from kirin.dialects import ilist
7
7
 
8
- from bloqade.noise import native
9
8
  from bloqade.analysis import address
10
- from bloqade.qasm2.dialects import uop, glob, parallel
9
+ from bloqade.qasm2.dialects import uop, glob, noise, parallel
11
10
 
12
11
 
13
12
  @dataclass
@@ -19,12 +18,7 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
19
18
 
20
19
  address_analysis: Dict[ir.SSAValue, address.Address]
21
20
  qubit_ssa_value: Dict[int, ir.SSAValue]
22
- gate_noise_params: native.GateNoiseParams = field(
23
- default_factory=native.GateNoiseParams
24
- )
25
- noise_model: native.MoveNoiseModelABC = field(
26
- default_factory=native.TwoRowZoneModel
27
- )
21
+ noise_model: noise.MoveNoiseModelABC = field(default_factory=noise.TwoRowZoneModel)
28
22
 
29
23
  def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
30
24
  if isinstance(node, uop.SingleQubitGate):
@@ -46,22 +40,18 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
46
40
  qargs: ir.SSAValue,
47
41
  probs: Tuple[float, float, float, float],
48
42
  ):
49
- native.PauliChannel(qargs, px=probs[0], py=probs[1], pz=probs[2]).insert_before(
43
+ noise.PauliChannel(qargs, px=probs[0], py=probs[1], pz=probs[2]).insert_before(
50
44
  node
51
45
  )
52
- native.AtomLossChannel(qargs, prob=probs[3]).insert_before(node)
46
+ noise.AtomLossChannel(qargs, prob=probs[3]).insert_before(node)
53
47
 
54
48
  return rewrite_abc.RewriteResult(has_done_something=True)
55
49
 
56
50
  def rewrite_single_qubit_gate(self, node: uop.SingleQubitGate):
57
- probs = (
58
- self.gate_noise_params.local_px,
59
- self.gate_noise_params.local_py,
60
- self.gate_noise_params.local_pz,
61
- self.gate_noise_params.local_loss_prob,
62
- )
63
51
  (qargs := ilist.New(values=(node.qarg,))).insert_before(node)
64
- return self.insert_single_qubit_noise(node, qargs.result, probs)
52
+ return self.insert_single_qubit_noise(
53
+ node, qargs.result, self.noise_model.local_errors
54
+ )
65
55
 
66
56
  def rewrite_global_single_qubit_gate(self, node: glob.UGate):
67
57
  addrs = self.address_analysis[node.registers]
@@ -77,14 +67,10 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
77
67
  for qid in addr.data:
78
68
  qargs.append(self.qubit_ssa_value[qid])
79
69
 
80
- probs = (
81
- self.gate_noise_params.global_px,
82
- self.gate_noise_params.global_py,
83
- self.gate_noise_params.global_pz,
84
- self.gate_noise_params.global_loss_prob,
85
- )
86
70
  (qargs := ilist.New(values=tuple(qargs))).insert_before(node)
87
- return self.insert_single_qubit_noise(node, qargs.result, probs)
71
+ return self.insert_single_qubit_noise(
72
+ node, qargs.result, self.noise_model.global_errors
73
+ )
88
74
 
89
75
  def rewrite_parallel_single_qubit_gate(self, node: parallel.RZ | parallel.UGate):
90
76
  addrs = self.address_analysis[node.qargs]
@@ -94,15 +80,11 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
94
80
  if not all(isinstance(addr, address.AddressQubit) for addr in addrs.data):
95
81
  return rewrite_abc.RewriteResult()
96
82
 
97
- probs = (
98
- self.gate_noise_params.local_px,
99
- self.gate_noise_params.local_py,
100
- self.gate_noise_params.local_pz,
101
- self.gate_noise_params.local_loss_prob,
102
- )
103
83
  assert isinstance(node.qargs, ir.ResultValue)
104
84
  assert isinstance(node.qargs.stmt, ilist.New)
105
- return self.insert_single_qubit_noise(node, node.qargs, probs)
85
+ return self.insert_single_qubit_noise(
86
+ node, node.qargs, self.noise_model.local_errors
87
+ )
106
88
 
107
89
  def move_noise_stmts(
108
90
  self,
@@ -118,9 +100,9 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
118
100
  nodes.append(
119
101
  qargs := ilist.New(tuple(self.qubit_ssa_value[q] for q in qubits))
120
102
  )
121
- nodes.append(native.AtomLossChannel(qargs.result, prob=probs[3]))
103
+ nodes.append(noise.AtomLossChannel(qargs.result, prob=probs[3]))
122
104
  nodes.append(
123
- native.PauliChannel(qargs.result, px=probs[0], py=probs[1], pz=probs[2])
105
+ noise.PauliChannel(qargs.result, px=probs[0], py=probs[1], pz=probs[2])
124
106
  )
125
107
 
126
108
  return nodes
@@ -131,34 +113,30 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
131
113
  qargs: ir.SSAValue,
132
114
  ) -> list[ir.Statement]:
133
115
  return [
134
- native.CZPauliChannel(
116
+ noise.CZPauliChannel(
135
117
  ctrls,
136
118
  qargs,
137
- px_ctrl=self.gate_noise_params.cz_paired_gate_px,
138
- py_ctrl=self.gate_noise_params.cz_paired_gate_py,
139
- pz_ctrl=self.gate_noise_params.cz_paired_gate_pz,
140
- px_qarg=self.gate_noise_params.cz_paired_gate_px,
141
- py_qarg=self.gate_noise_params.cz_paired_gate_py,
142
- pz_qarg=self.gate_noise_params.cz_paired_gate_pz,
119
+ px_ctrl=self.noise_model.cz_paired_gate_px,
120
+ py_ctrl=self.noise_model.cz_paired_gate_py,
121
+ pz_ctrl=self.noise_model.cz_paired_gate_pz,
122
+ px_qarg=self.noise_model.cz_paired_gate_px,
123
+ py_qarg=self.noise_model.cz_paired_gate_py,
124
+ pz_qarg=self.noise_model.cz_paired_gate_pz,
143
125
  paired=True,
144
126
  ),
145
- native.CZPauliChannel(
127
+ noise.CZPauliChannel(
146
128
  ctrls,
147
129
  qargs,
148
- px_ctrl=self.gate_noise_params.cz_unpaired_gate_px,
149
- py_ctrl=self.gate_noise_params.cz_unpaired_gate_py,
150
- pz_ctrl=self.gate_noise_params.cz_unpaired_gate_pz,
151
- px_qarg=self.gate_noise_params.cz_unpaired_gate_px,
152
- py_qarg=self.gate_noise_params.cz_unpaired_gate_py,
153
- pz_qarg=self.gate_noise_params.cz_unpaired_gate_pz,
130
+ px_ctrl=self.noise_model.cz_unpaired_gate_px,
131
+ py_ctrl=self.noise_model.cz_unpaired_gate_py,
132
+ pz_ctrl=self.noise_model.cz_unpaired_gate_pz,
133
+ px_qarg=self.noise_model.cz_unpaired_gate_px,
134
+ py_qarg=self.noise_model.cz_unpaired_gate_py,
135
+ pz_qarg=self.noise_model.cz_unpaired_gate_pz,
154
136
  paired=False,
155
137
  ),
156
- native.AtomLossChannel(
157
- ctrls, prob=self.gate_noise_params.cz_gate_loss_prob
158
- ),
159
- native.AtomLossChannel(
160
- qargs, prob=self.gate_noise_params.cz_gate_loss_prob
161
- ),
138
+ noise.AtomLossChannel(ctrls, prob=self.noise_model.cz_gate_loss_prob),
139
+ noise.AtomLossChannel(qargs, prob=self.noise_model.cz_gate_loss_prob),
162
140
  ]
163
141
 
164
142
  def rewrite_cz_gate(self, node: uop.CZ):
@@ -4,8 +4,8 @@ from kirin import ir
4
4
  from kirin.rewrite import abc, dce, walk, fixpoint
5
5
  from kirin.passes.abc import Pass
6
6
 
7
- from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
8
- from ._dialect import dialect
7
+ from ...dialects.noise.stmts import PauliChannel, CZPauliChannel, AtomLossChannel
8
+ from ...dialects.noise._dialect import dialect
9
9
 
10
10
 
11
11
  class RemoveNoiseRewrite(abc.RewriteRule):
@@ -4,13 +4,13 @@ from dataclasses import field, dataclass
4
4
  from kirin import ir, types, passes
5
5
  from kirin.dialects import func, ilist
6
6
 
7
- from bloqade import noise, qasm2
7
+ from bloqade import qasm2
8
8
  from bloqade.qbraid import schema
9
- from bloqade.qasm2.dialects import glob, parallel
9
+ from bloqade.qasm2.dialects import glob, noise, parallel
10
10
 
11
11
 
12
12
  @ir.dialect_group(
13
- [func, qasm2.core, qasm2.uop, parallel, glob, qasm2.expr, noise.native, ilist]
13
+ [func, qasm2.core, qasm2.uop, parallel, glob, qasm2.expr, noise, ilist]
14
14
  )
15
15
  def qbraid_noise(
16
16
  self,
@@ -192,7 +192,7 @@ class Lowering:
192
192
  qargs := ilist.New(values=tuple(self.qubit_id_map[q] for q in qubits))
193
193
  )
194
194
  self.block_list.append(
195
- noise.native.PauliChannel(px=px, py=py, pz=pz, qargs=qargs.result)
195
+ noise.PauliChannel(px=px, py=py, pz=pz, qargs=qargs.result)
196
196
  )
197
197
 
198
198
  for (p_ctrl, p_qarg), qubits in paired_layers.items():
@@ -204,7 +204,7 @@ class Lowering:
204
204
  qargs := ilist.New(values=tuple(self.qubit_id_map[q] for q in qargs))
205
205
  )
206
206
  self.block_list.append(
207
- noise.native.CZPauliChannel(
207
+ noise.CZPauliChannel(
208
208
  paired=True,
209
209
  px_ctrl=p_ctrl[0],
210
210
  py_ctrl=p_ctrl[1],
@@ -226,7 +226,7 @@ class Lowering:
226
226
  qargs := ilist.New(values=tuple(self.qubit_id_map[q] for q in qargs))
227
227
  )
228
228
  self.block_list.append(
229
- noise.native.CZPauliChannel(
229
+ noise.CZPauliChannel(
230
230
  paired=False,
231
231
  px_ctrl=p_ctrl[0],
232
232
  py_ctrl=p_ctrl[1],
@@ -285,7 +285,7 @@ class Lowering:
285
285
  qargs := ilist.New(values=tuple(self.qubit_id_map[q] for q in qubits))
286
286
  )
287
287
  self.block_list.append(
288
- noise.native.PauliChannel(px=px, py=py, pz=pz, qargs=qargs.result)
288
+ noise.PauliChannel(px=px, py=py, pz=pz, qargs=qargs.result)
289
289
  )
290
290
 
291
291
  def lower_measurement(self, operation: schema.Measurement):
@@ -303,7 +303,7 @@ class Lowering:
303
303
  for survival_prob, qubits in layers.items():
304
304
  self.block_list.append(qargs := ilist.New(values=qubits))
305
305
  self.block_list.append(
306
- noise.native.AtomLossChannel(prob=survival_prob, qargs=qargs.result)
306
+ noise.AtomLossChannel(prob=survival_prob, qargs=qargs.result)
307
307
  )
308
308
 
309
309
  def lower_number(self, value: float | int) -> ir.SSAValue:
bloqade/squin/__init__.py CHANGED
@@ -1,2 +1,8 @@
1
- from . import op as op, wire as wire, noise as noise, qubit as qubit
1
+ from . import (
2
+ op as op,
3
+ wire as wire,
4
+ noise as noise,
5
+ qubit as qubit,
6
+ lowering as lowering,
7
+ )
2
8
  from .groups import wired as wired, kernel as kernel
@@ -23,15 +23,6 @@ class SquinWire(interp.MethodTable):
23
23
 
24
24
  return tuple(frame.get(input) for input in stmt.inputs)
25
25
 
26
- @interp.impl(wire.MeasureAndReset)
27
- def measure_and_reset(
28
- self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.MeasureAndReset
29
- ):
30
-
31
- # MeasureAndReset produces both a new wire
32
- # and an integer which don't have any sites at all
33
- return (NoSites(), NoSites())
34
-
35
26
 
36
27
  @op.dialect.register(key="op.nsites")
37
28
  class SquinOp(interp.MethodTable):
bloqade/squin/groups.py CHANGED
@@ -1,11 +1,11 @@
1
1
  from kirin import ir, passes
2
2
  from kirin.prelude import structural_no_opt
3
+ from kirin.rewrite import Walk, Chain
3
4
  from kirin.dialects import ilist
4
- from kirin.rewrite.walk import Walk
5
5
 
6
6
  from . import op, wire, qubit
7
7
  from .op.rewrite import PyMultToSquinMult
8
- from .rewrite.measure_desugar import MeasureDesugarRule
8
+ from .rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule
9
9
 
10
10
 
11
11
  @ir.dialect_group(structural_no_opt.union([op, qubit]))
@@ -13,7 +13,7 @@ def kernel(self):
13
13
  fold_pass = passes.Fold(self)
14
14
  typeinfer_pass = passes.TypeInfer(self)
15
15
  ilist_desugar_pass = ilist.IListDesugar(self)
16
- measure_desugar_pass = Walk(MeasureDesugarRule())
16
+ desugar_pass = Walk(Chain(MeasureDesugarRule(), ApplyDesugarRule()))
17
17
  py_mult_to_mult_pass = PyMultToSquinMult(self)
18
18
 
19
19
  def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
@@ -25,7 +25,7 @@ def kernel(self):
25
25
 
26
26
  if typeinfer:
27
27
  typeinfer_pass(method)
28
- measure_desugar_pass.rewrite(method.code)
28
+ desugar_pass.rewrite(method.code)
29
29
 
30
30
  ilist_desugar_pass(method)
31
31
 
@@ -0,0 +1,27 @@
1
+ import ast
2
+ from dataclasses import dataclass
3
+
4
+ from kirin import lowering
5
+
6
+ from . import qubit
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class ApplyAnyCallLowering(lowering.FromPythonCall["qubit.ApplyAny"]):
11
+ """
12
+ Custom lowering for ApplyAny that collects vararg qubits into a single tuple argument
13
+ """
14
+
15
+ def lower(
16
+ self, stmt: type["qubit.ApplyAny"], state: lowering.State, node: ast.Call
17
+ ):
18
+ if len(node.args) < 2:
19
+ raise lowering.BuildError(
20
+ "Apply requires at least one operator and one qubit as arguments!"
21
+ )
22
+ op, *qubits = node.args
23
+ op_ssa = state.lower(op).expect_one()
24
+ qubits_lowered = [state.lower(qbit).expect_one() for qbit in qubits]
25
+
26
+ s = stmt(op_ssa, tuple(qubits_lowered))
27
+ return state.current_frame.push(s)
@@ -25,6 +25,7 @@ from ._wrapper import (
25
25
  kron as kron,
26
26
  mult as mult,
27
27
  phase as phase,
28
+ reset as reset,
28
29
  scale as scale,
29
30
  shift as shift,
30
31
  spin_n as spin_n,
@@ -37,6 +37,10 @@ def control(op: types.Op, *, n_controls: int) -> types.Op:
37
37
  ...
38
38
 
39
39
 
40
+ @wraps(stmts.Reset)
41
+ def reset() -> types.Op: ...
42
+
43
+
40
44
  @wraps(stmts.Identity)
41
45
  def identity(*, sites: int) -> types.Op: ...
42
46
 
bloqade/squin/op/stmts.py CHANGED
@@ -142,6 +142,16 @@ class ShiftOp(PrimitiveOp):
142
142
  result: ir.ResultValue = info.result(OpType)
143
143
 
144
144
 
145
+ @statement(dialect=dialect)
146
+ class Reset(PrimitiveOp):
147
+ """
148
+ Reset operator for qubits or wires.
149
+ """
150
+
151
+ traits = frozenset({ir.Pure(), lowering.FromPythonCall(), FixedSites(1)})
152
+ result: ir.ResultValue = info.result(OpType)
153
+
154
+
145
155
  @statement
146
156
  class PauliOp(ConstantUnitary):
147
157
  pass
bloqade/squin/qubit.py CHANGED
@@ -17,6 +17,8 @@ from kirin.lowering import wraps
17
17
  from bloqade.types import Qubit, QubitType
18
18
  from bloqade.squin.op.types import Op, OpType
19
19
 
20
+ from .lowering import ApplyAnyCallLowering
21
+
20
22
  dialect = ir.Dialect("squin.qubit")
21
23
 
22
24
 
@@ -34,6 +36,14 @@ class Apply(ir.Statement):
34
36
  qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
35
37
 
36
38
 
39
+ @statement(dialect=dialect)
40
+ class ApplyAny(ir.Statement):
41
+ # NOTE: custom lowering to deal with vararg calls
42
+ traits = frozenset({ApplyAnyCallLowering()})
43
+ operator: ir.SSAValue = info.argument(OpType)
44
+ qubits: tuple[ir.SSAValue, ...] = info.argument()
45
+
46
+
37
47
  @statement(dialect=dialect)
38
48
  class Broadcast(ir.Statement):
39
49
  traits = frozenset({lowering.FromPythonCall()})
@@ -68,19 +78,6 @@ class MeasureQubitList(ir.Statement):
68
78
  result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
69
79
 
70
80
 
71
- @statement(dialect=dialect)
72
- class MeasureAndReset(ir.Statement):
73
- traits = frozenset({lowering.FromPythonCall()})
74
- qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
75
- result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
76
-
77
-
78
- @statement(dialect=dialect)
79
- class Reset(ir.Statement):
80
- traits = frozenset({lowering.FromPythonCall()})
81
- qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
82
-
83
-
84
81
  # NOTE: no dependent types in Python, so we have to mark it Any...
85
82
  @wraps(New)
86
83
  def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
@@ -95,7 +92,7 @@ def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
95
92
  ...
96
93
 
97
94
 
98
- @wraps(Apply)
95
+ @overload
99
96
  def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
100
97
  """Apply an operator to a list of qubits.
101
98
 
@@ -112,6 +109,27 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
112
109
  ...
113
110
 
114
111
 
112
+ @overload
113
+ def apply(operator: Op, *qubits: Qubit) -> None:
114
+ """Apply an operator to any number of qubits.
115
+
116
+ Note that when considering atom loss, lost qubits will be skipped.
117
+
118
+ Args:
119
+ operator: The operator to apply.
120
+ *qubits: The qubits to apply the operator to. The number of qubits must
121
+ match the size of the operator.
122
+
123
+ Returns:
124
+ None
125
+ """
126
+ ...
127
+
128
+
129
+ @wraps(ApplyAny)
130
+ def apply(operator: Op, *qubits) -> None: ...
131
+
132
+
115
133
  @overload
116
134
  def measure(input: Qubit) -> bool: ...
117
135
  @overload
@@ -161,26 +179,3 @@ def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> No
161
179
  None
162
180
  """
163
181
  ...
164
-
165
-
166
- @wraps(MeasureAndReset)
167
- def measure_and_reset(qubits: ilist.IList[Qubit, Any]) -> ilist.IList[bool, Any]:
168
- """Measure the qubits in the list and reset them."
169
-
170
- Args:
171
- qubits: The list of qubits to measure and reset.
172
-
173
- Returns:
174
- list[bool]: The result of the measurement.
175
- """
176
- ...
177
-
178
-
179
- @wraps(Reset)
180
- def reset(qubits: ilist.IList[Qubit, Any]) -> None:
181
- """Reset the qubits in the list."
182
-
183
- Args:
184
- qubits: The list of qubits to reset.
185
- """
186
- ...
@@ -0,0 +1,65 @@
1
+ from kirin import ir, types
2
+ from kirin.dialects import ilist
3
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
4
+
5
+ from bloqade.squin.qubit import (
6
+ Apply,
7
+ ApplyAny,
8
+ QubitType,
9
+ MeasureAny,
10
+ MeasureQubit,
11
+ MeasureQubitList,
12
+ )
13
+
14
+
15
+ class MeasureDesugarRule(RewriteRule):
16
+ """
17
+ Desugar measure operations in the circuit.
18
+ """
19
+
20
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
21
+
22
+ if not isinstance(node, MeasureAny):
23
+ return RewriteResult()
24
+
25
+ if node.input.type.is_subseteq(QubitType):
26
+ node.replace_by(
27
+ MeasureQubit(
28
+ qubit=node.input,
29
+ )
30
+ )
31
+ return RewriteResult(has_done_something=True)
32
+ elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]):
33
+ node.replace_by(
34
+ MeasureQubitList(
35
+ qubits=node.input,
36
+ )
37
+ )
38
+ return RewriteResult(has_done_something=True)
39
+
40
+ return RewriteResult()
41
+
42
+
43
+ class ApplyDesugarRule(RewriteRule):
44
+ """
45
+ Desugar apply operators in the kernel.
46
+ """
47
+
48
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
49
+
50
+ if not isinstance(node, ApplyAny):
51
+ return RewriteResult()
52
+
53
+ op = node.operator
54
+ qubits = node.qubits
55
+
56
+ if len(qubits) == 1 and qubits[0].type.is_subseteq(ilist.IListType):
57
+ # NOTE: already calling with just a single argument that is already an ilist
58
+ qubits_ilist = qubits[0]
59
+ else:
60
+ (qubits_ilist_stmt := ilist.New(values=qubits)).insert_before(node)
61
+ qubits_ilist = qubits_ilist_stmt.result
62
+
63
+ stmt = Apply(operator=op, qubits=qubits_ilist)
64
+ node.replace_by(stmt)
65
+ return RewriteResult(has_done_something=True)