bloqade-circuit 0.2.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (101) hide show
  1. bloqade/analysis/address/impls.py +3 -2
  2. bloqade/pyqrack/device.py +1 -3
  3. bloqade/pyqrack/noise/native.py +8 -8
  4. bloqade/pyqrack/qasm2/core.py +4 -1
  5. bloqade/pyqrack/squin/op.py +7 -0
  6. bloqade/pyqrack/squin/qubit.py +5 -27
  7. bloqade/pyqrack/squin/runtime.py +18 -0
  8. bloqade/pyqrack/squin/wire.py +4 -22
  9. bloqade/pyqrack/task.py +13 -5
  10. bloqade/qasm2/__init__.py +1 -0
  11. bloqade/qasm2/_qasm_loading.py +151 -0
  12. bloqade/qasm2/dialects/core/__init__.py +9 -1
  13. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  14. bloqade/{noise/native → qasm2/dialects/noise}/__init__.py +1 -7
  15. bloqade/qasm2/dialects/noise/_dialect.py +3 -0
  16. bloqade/{noise → qasm2/dialects/noise}/fidelity.py +4 -4
  17. bloqade/qasm2/dialects/noise/model.py +278 -0
  18. bloqade/{noise/native → qasm2/dialects/noise}/stmts.py +1 -1
  19. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  20. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  21. bloqade/qasm2/emit/impls/__init__.py +1 -0
  22. bloqade/qasm2/emit/impls/noise.py +89 -0
  23. bloqade/qasm2/emit/main.py +23 -4
  24. bloqade/qasm2/emit/target.py +19 -4
  25. bloqade/qasm2/noise.py +67 -0
  26. bloqade/qasm2/parse/__init__.py +7 -4
  27. bloqade/qasm2/parse/lowering.py +20 -130
  28. bloqade/qasm2/parse/qasm2.lark +1 -1
  29. bloqade/qasm2/passes/__init__.py +1 -0
  30. bloqade/qasm2/passes/fold.py +6 -0
  31. bloqade/qasm2/passes/glob.py +12 -8
  32. bloqade/qasm2/passes/noise.py +27 -16
  33. bloqade/qasm2/passes/parallel.py +9 -0
  34. bloqade/qasm2/passes/unroll_if.py +25 -0
  35. bloqade/qasm2/rewrite/__init__.py +3 -0
  36. bloqade/qasm2/rewrite/desugar.py +3 -2
  37. bloqade/qasm2/rewrite/native_gates.py +67 -4
  38. bloqade/qasm2/rewrite/noise/__init__.py +0 -0
  39. bloqade/qasm2/rewrite/{heuristic_noise.py → noise/heuristic_noise.py} +32 -62
  40. bloqade/{noise/native/rewrite.py → qasm2/rewrite/noise/remove_noise.py} +2 -2
  41. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  42. bloqade/qbraid/lowering.py +8 -8
  43. bloqade/squin/__init__.py +7 -1
  44. bloqade/squin/analysis/nsites/__init__.py +1 -0
  45. bloqade/squin/analysis/nsites/impls.py +16 -1
  46. bloqade/squin/groups.py +4 -4
  47. bloqade/squin/lowering.py +27 -0
  48. bloqade/squin/noise/__init__.py +7 -26
  49. bloqade/squin/noise/_wrapper.py +25 -0
  50. bloqade/squin/op/__init__.py +34 -159
  51. bloqade/squin/op/_wrapper.py +105 -0
  52. bloqade/squin/op/stdlib.py +62 -0
  53. bloqade/squin/op/stmts.py +10 -0
  54. bloqade/squin/passes/__init__.py +1 -0
  55. bloqade/squin/passes/stim.py +68 -0
  56. bloqade/squin/qubit.py +32 -37
  57. bloqade/squin/rewrite/__init__.py +11 -0
  58. bloqade/squin/rewrite/desugar.py +65 -0
  59. bloqade/squin/rewrite/qubit_to_stim.py +61 -0
  60. bloqade/squin/rewrite/squin_measure.py +73 -0
  61. bloqade/squin/rewrite/stim_rewrite_util.py +153 -0
  62. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  63. bloqade/squin/rewrite/wire_to_stim.py +52 -0
  64. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  65. bloqade/squin/wire.py +5 -22
  66. bloqade/stim/__init__.py +40 -5
  67. bloqade/stim/_wrappers.py +18 -12
  68. bloqade/stim/dialects/__init__.py +1 -5
  69. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +13 -1
  70. bloqade/stim/dialects/{aux → auxiliary}/emit.py +18 -3
  71. bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +1 -0
  72. bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +8 -0
  73. bloqade/stim/dialects/collapse/__init__.py +13 -2
  74. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +4 -2
  75. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  76. bloqade/stim/dialects/gate/__init__.py +16 -1
  77. bloqade/stim/dialects/gate/emit.py +10 -3
  78. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  79. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  80. bloqade/stim/dialects/noise/emit.py +33 -2
  81. bloqade/stim/dialects/noise/stmts.py +29 -0
  82. bloqade/stim/emit/__init__.py +1 -1
  83. bloqade/stim/groups.py +4 -2
  84. bloqade/stim/parse/__init__.py +1 -0
  85. bloqade/stim/parse/lowering.py +686 -0
  86. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/METADATA +5 -3
  87. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/RECORD +95 -77
  88. bloqade/noise/__init__.py +0 -2
  89. bloqade/noise/native/_dialect.py +0 -3
  90. bloqade/noise/native/_wrappers.py +0 -34
  91. bloqade/noise/native/model.py +0 -346
  92. bloqade/qasm2/dialects/noise.py +0 -16
  93. bloqade/squin/rewrite/measure_desugar.py +0 -33
  94. /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  95. /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  96. /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  97. /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  98. /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  99. /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  100. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/WHEEL +0 -0
  101. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -5,9 +5,8 @@ from kirin import ir
5
5
  from kirin.rewrite import abc as rewrite_abc
6
6
  from kirin.dialects import ilist
7
7
 
8
- from bloqade.noise import native
9
8
  from bloqade.analysis import address
10
- from bloqade.qasm2.dialects import uop, glob, parallel
9
+ from bloqade.qasm2.dialects import uop, glob, noise, parallel
11
10
 
12
11
 
13
12
  @dataclass
@@ -18,21 +17,8 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
18
17
  """
19
18
 
20
19
  address_analysis: Dict[ir.SSAValue, address.Address]
21
- gate_noise_params: native.GateNoiseParams = field(
22
- default_factory=native.GateNoiseParams
23
- )
24
- noise_model: native.MoveNoiseModelABC = field(
25
- default_factory=native.TwoRowZoneModel
26
- )
27
-
28
- def __post_init__(self):
29
- self.qubit_ssa_value: Dict[int, ir.SSAValue] = {}
30
- for ssa_value, addr in self.address_analysis.items():
31
- if (
32
- isinstance(addr, address.AddressQubit)
33
- and ssa_value not in self.qubit_ssa_value
34
- ):
35
- self.qubit_ssa_value[addr.data] = ssa_value
20
+ qubit_ssa_value: Dict[int, ir.SSAValue]
21
+ noise_model: noise.MoveNoiseModelABC = field(default_factory=noise.TwoRowZoneModel)
36
22
 
37
23
  def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
38
24
  if isinstance(node, uop.SingleQubitGate):
@@ -54,22 +40,18 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
54
40
  qargs: ir.SSAValue,
55
41
  probs: Tuple[float, float, float, float],
56
42
  ):
57
- native.PauliChannel(qargs, px=probs[0], py=probs[1], pz=probs[2]).insert_before(
43
+ noise.PauliChannel(qargs, px=probs[0], py=probs[1], pz=probs[2]).insert_before(
58
44
  node
59
45
  )
60
- native.AtomLossChannel(qargs, prob=probs[3]).insert_before(node)
46
+ noise.AtomLossChannel(qargs, prob=probs[3]).insert_before(node)
61
47
 
62
48
  return rewrite_abc.RewriteResult(has_done_something=True)
63
49
 
64
50
  def rewrite_single_qubit_gate(self, node: uop.SingleQubitGate):
65
- probs = (
66
- self.gate_noise_params.local_px,
67
- self.gate_noise_params.local_py,
68
- self.gate_noise_params.local_pz,
69
- self.gate_noise_params.local_loss_prob,
70
- )
71
51
  (qargs := ilist.New(values=(node.qarg,))).insert_before(node)
72
- return self.insert_single_qubit_noise(node, qargs.result, probs)
52
+ return self.insert_single_qubit_noise(
53
+ node, qargs.result, self.noise_model.local_errors
54
+ )
73
55
 
74
56
  def rewrite_global_single_qubit_gate(self, node: glob.UGate):
75
57
  addrs = self.address_analysis[node.registers]
@@ -85,14 +67,10 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
85
67
  for qid in addr.data:
86
68
  qargs.append(self.qubit_ssa_value[qid])
87
69
 
88
- probs = (
89
- self.gate_noise_params.global_px,
90
- self.gate_noise_params.global_py,
91
- self.gate_noise_params.global_pz,
92
- self.gate_noise_params.global_loss_prob,
93
- )
94
70
  (qargs := ilist.New(values=tuple(qargs))).insert_before(node)
95
- return self.insert_single_qubit_noise(node, qargs.result, probs)
71
+ return self.insert_single_qubit_noise(
72
+ node, qargs.result, self.noise_model.global_errors
73
+ )
96
74
 
97
75
  def rewrite_parallel_single_qubit_gate(self, node: parallel.RZ | parallel.UGate):
98
76
  addrs = self.address_analysis[node.qargs]
@@ -102,15 +80,11 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
102
80
  if not all(isinstance(addr, address.AddressQubit) for addr in addrs.data):
103
81
  return rewrite_abc.RewriteResult()
104
82
 
105
- probs = (
106
- self.gate_noise_params.local_px,
107
- self.gate_noise_params.local_py,
108
- self.gate_noise_params.local_pz,
109
- self.gate_noise_params.local_loss_prob,
110
- )
111
83
  assert isinstance(node.qargs, ir.ResultValue)
112
84
  assert isinstance(node.qargs.stmt, ilist.New)
113
- return self.insert_single_qubit_noise(node, node.qargs, probs)
85
+ return self.insert_single_qubit_noise(
86
+ node, node.qargs, self.noise_model.local_errors
87
+ )
114
88
 
115
89
  def move_noise_stmts(
116
90
  self,
@@ -126,9 +100,9 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
126
100
  nodes.append(
127
101
  qargs := ilist.New(tuple(self.qubit_ssa_value[q] for q in qubits))
128
102
  )
129
- nodes.append(native.AtomLossChannel(qargs.result, prob=probs[3]))
103
+ nodes.append(noise.AtomLossChannel(qargs.result, prob=probs[3]))
130
104
  nodes.append(
131
- native.PauliChannel(qargs.result, px=probs[0], py=probs[1], pz=probs[2])
105
+ noise.PauliChannel(qargs.result, px=probs[0], py=probs[1], pz=probs[2])
132
106
  )
133
107
 
134
108
  return nodes
@@ -139,34 +113,30 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
139
113
  qargs: ir.SSAValue,
140
114
  ) -> list[ir.Statement]:
141
115
  return [
142
- native.CZPauliChannel(
116
+ noise.CZPauliChannel(
143
117
  ctrls,
144
118
  qargs,
145
- px_ctrl=self.gate_noise_params.cz_paired_gate_px,
146
- py_ctrl=self.gate_noise_params.cz_paired_gate_py,
147
- pz_ctrl=self.gate_noise_params.cz_paired_gate_pz,
148
- px_qarg=self.gate_noise_params.cz_paired_gate_px,
149
- py_qarg=self.gate_noise_params.cz_paired_gate_py,
150
- pz_qarg=self.gate_noise_params.cz_paired_gate_pz,
119
+ px_ctrl=self.noise_model.cz_paired_gate_px,
120
+ py_ctrl=self.noise_model.cz_paired_gate_py,
121
+ pz_ctrl=self.noise_model.cz_paired_gate_pz,
122
+ px_qarg=self.noise_model.cz_paired_gate_px,
123
+ py_qarg=self.noise_model.cz_paired_gate_py,
124
+ pz_qarg=self.noise_model.cz_paired_gate_pz,
151
125
  paired=True,
152
126
  ),
153
- native.CZPauliChannel(
127
+ noise.CZPauliChannel(
154
128
  ctrls,
155
129
  qargs,
156
- px_ctrl=self.gate_noise_params.cz_unpaired_gate_px,
157
- py_ctrl=self.gate_noise_params.cz_unpaired_gate_py,
158
- pz_ctrl=self.gate_noise_params.cz_unpaired_gate_pz,
159
- px_qarg=self.gate_noise_params.cz_unpaired_gate_px,
160
- py_qarg=self.gate_noise_params.cz_unpaired_gate_py,
161
- pz_qarg=self.gate_noise_params.cz_unpaired_gate_pz,
130
+ px_ctrl=self.noise_model.cz_unpaired_gate_px,
131
+ py_ctrl=self.noise_model.cz_unpaired_gate_py,
132
+ pz_ctrl=self.noise_model.cz_unpaired_gate_pz,
133
+ px_qarg=self.noise_model.cz_unpaired_gate_px,
134
+ py_qarg=self.noise_model.cz_unpaired_gate_py,
135
+ pz_qarg=self.noise_model.cz_unpaired_gate_pz,
162
136
  paired=False,
163
137
  ),
164
- native.AtomLossChannel(
165
- ctrls, prob=self.gate_noise_params.cz_gate_loss_prob
166
- ),
167
- native.AtomLossChannel(
168
- qargs, prob=self.gate_noise_params.cz_gate_loss_prob
169
- ),
138
+ noise.AtomLossChannel(ctrls, prob=self.noise_model.cz_gate_loss_prob),
139
+ noise.AtomLossChannel(qargs, prob=self.noise_model.cz_gate_loss_prob),
170
140
  ]
171
141
 
172
142
  def rewrite_cz_gate(self, node: uop.CZ):
@@ -4,8 +4,8 @@ from kirin import ir
4
4
  from kirin.rewrite import abc, dce, walk, fixpoint
5
5
  from kirin.passes.abc import Pass
6
6
 
7
- from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
8
- from ._dialect import dialect
7
+ from ...dialects.noise.stmts import PauliChannel, CZPauliChannel, AtomLossChannel
8
+ from ...dialects.noise._dialect import dialect
9
9
 
10
10
 
11
11
  class RemoveNoiseRewrite(abc.RewriteRule):
@@ -0,0 +1,66 @@
1
+ from kirin import ir
2
+ from kirin.dialects import scf, func
3
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
4
+
5
+ from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
6
+ from ..dialects.core.stmts import Reset, Measure
7
+
8
+ # TODO: unify with PR #248
9
+ AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset
10
+
11
+ DontLiftType = AllowedThenType | scf.Yield | func.Return | func.Invoke
12
+
13
+
14
+ class LiftThenBody(RewriteRule):
15
+ """Lifts anything that's not a UOP or a yield/return out of the then body"""
16
+
17
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
18
+ if not isinstance(node, scf.IfElse):
19
+ return RewriteResult()
20
+
21
+ then_stmts = node.then_body.stmts()
22
+
23
+ lift_stmts = [stmt for stmt in then_stmts if not isinstance(stmt, DontLiftType)]
24
+
25
+ if len(lift_stmts) == 0:
26
+ return RewriteResult()
27
+
28
+ for stmt in lift_stmts:
29
+ stmt.detach()
30
+ stmt.insert_before(node)
31
+
32
+ return RewriteResult(has_done_something=True)
33
+
34
+
35
+ class SplitIfStmts(RewriteRule):
36
+ """Splits the then body of an if-else statement into multiple if statements"""
37
+
38
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
39
+ if not isinstance(node, scf.IfElse):
40
+ return RewriteResult()
41
+
42
+ *stmts, yield_or_return = node.then_body.stmts()
43
+
44
+ if len(stmts) == 1:
45
+ return RewriteResult()
46
+
47
+ is_yield = isinstance(yield_or_return, scf.Yield)
48
+
49
+ for stmt in stmts:
50
+ stmt.detach()
51
+
52
+ yield_or_return = scf.Yield() if is_yield else func.Return()
53
+
54
+ then_block = ir.Block((stmt, yield_or_return), argtypes=(node.cond.type,))
55
+ then_body = ir.Region(then_block)
56
+ else_body = node.else_body.clone()
57
+ else_body.detach()
58
+ new_if = scf.IfElse(
59
+ cond=node.cond, then_body=then_body, else_body=else_body
60
+ )
61
+
62
+ new_if.insert_before(node)
63
+
64
+ node.delete()
65
+
66
+ return RewriteResult(has_done_something=True)
@@ -4,13 +4,13 @@ from dataclasses import field, dataclass
4
4
  from kirin import ir, types, passes
5
5
  from kirin.dialects import func, ilist
6
6
 
7
- from bloqade import noise, qasm2
7
+ from bloqade import qasm2
8
8
  from bloqade.qbraid import schema
9
- from bloqade.qasm2.dialects import glob, parallel
9
+ from bloqade.qasm2.dialects import glob, noise, parallel
10
10
 
11
11
 
12
12
  @ir.dialect_group(
13
- [func, qasm2.core, qasm2.uop, parallel, glob, qasm2.expr, noise.native, ilist]
13
+ [func, qasm2.core, qasm2.uop, parallel, glob, qasm2.expr, noise, ilist]
14
14
  )
15
15
  def qbraid_noise(
16
16
  self,
@@ -192,7 +192,7 @@ class Lowering:
192
192
  qargs := ilist.New(values=tuple(self.qubit_id_map[q] for q in qubits))
193
193
  )
194
194
  self.block_list.append(
195
- noise.native.PauliChannel(px=px, py=py, pz=pz, qargs=qargs.result)
195
+ noise.PauliChannel(px=px, py=py, pz=pz, qargs=qargs.result)
196
196
  )
197
197
 
198
198
  for (p_ctrl, p_qarg), qubits in paired_layers.items():
@@ -204,7 +204,7 @@ class Lowering:
204
204
  qargs := ilist.New(values=tuple(self.qubit_id_map[q] for q in qargs))
205
205
  )
206
206
  self.block_list.append(
207
- noise.native.CZPauliChannel(
207
+ noise.CZPauliChannel(
208
208
  paired=True,
209
209
  px_ctrl=p_ctrl[0],
210
210
  py_ctrl=p_ctrl[1],
@@ -226,7 +226,7 @@ class Lowering:
226
226
  qargs := ilist.New(values=tuple(self.qubit_id_map[q] for q in qargs))
227
227
  )
228
228
  self.block_list.append(
229
- noise.native.CZPauliChannel(
229
+ noise.CZPauliChannel(
230
230
  paired=False,
231
231
  px_ctrl=p_ctrl[0],
232
232
  py_ctrl=p_ctrl[1],
@@ -285,7 +285,7 @@ class Lowering:
285
285
  qargs := ilist.New(values=tuple(self.qubit_id_map[q] for q in qubits))
286
286
  )
287
287
  self.block_list.append(
288
- noise.native.PauliChannel(px=px, py=py, pz=pz, qargs=qargs.result)
288
+ noise.PauliChannel(px=px, py=py, pz=pz, qargs=qargs.result)
289
289
  )
290
290
 
291
291
  def lower_measurement(self, operation: schema.Measurement):
@@ -303,7 +303,7 @@ class Lowering:
303
303
  for survival_prob, qubits in layers.items():
304
304
  self.block_list.append(qargs := ilist.New(values=qubits))
305
305
  self.block_list.append(
306
- noise.native.AtomLossChannel(prob=survival_prob, qargs=qargs.result)
306
+ noise.AtomLossChannel(prob=survival_prob, qargs=qargs.result)
307
307
  )
308
308
 
309
309
  def lower_number(self, value: float | int) -> ir.SSAValue:
bloqade/squin/__init__.py CHANGED
@@ -1,2 +1,8 @@
1
- from . import op as op, wire as wire, noise as noise, qubit as qubit
1
+ from . import (
2
+ op as op,
3
+ wire as wire,
4
+ noise as noise,
5
+ qubit as qubit,
6
+ lowering as lowering,
7
+ )
2
8
  from .groups import wired as wired, kernel as kernel
@@ -1,6 +1,7 @@
1
1
  # Need this for impl registration to work properly!
2
2
  from . import impls as impls
3
3
  from .lattice import (
4
+ Sites as Sites,
4
5
  NoSites as NoSites,
5
6
  AnySites as AnySites,
6
7
  NumberSites as NumberSites,
@@ -1,6 +1,6 @@
1
1
  from kirin import interp
2
2
 
3
- from bloqade.squin import op
3
+ from bloqade.squin import op, wire
4
4
 
5
5
  from .lattice import (
6
6
  NoSites,
@@ -9,6 +9,21 @@ from .lattice import (
9
9
  from .analysis import NSitesAnalysis
10
10
 
11
11
 
12
+ @wire.dialect.register(key="op.nsites")
13
+ class SquinWire(interp.MethodTable):
14
+
15
+ @interp.impl(wire.Apply)
16
+ @interp.impl(wire.Broadcast)
17
+ def apply(
18
+ self,
19
+ interp: NSitesAnalysis,
20
+ frame: interp.Frame,
21
+ stmt: wire.Apply | wire.Broadcast,
22
+ ):
23
+
24
+ return tuple(frame.get(input) for input in stmt.inputs)
25
+
26
+
12
27
  @op.dialect.register(key="op.nsites")
13
28
  class SquinOp(interp.MethodTable):
14
29
 
bloqade/squin/groups.py CHANGED
@@ -1,11 +1,11 @@
1
1
  from kirin import ir, passes
2
2
  from kirin.prelude import structural_no_opt
3
+ from kirin.rewrite import Walk, Chain
3
4
  from kirin.dialects import ilist
4
- from kirin.rewrite.walk import Walk
5
5
 
6
6
  from . import op, wire, qubit
7
7
  from .op.rewrite import PyMultToSquinMult
8
- from .rewrite.measure_desugar import MeasureDesugarRule
8
+ from .rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule
9
9
 
10
10
 
11
11
  @ir.dialect_group(structural_no_opt.union([op, qubit]))
@@ -13,7 +13,7 @@ def kernel(self):
13
13
  fold_pass = passes.Fold(self)
14
14
  typeinfer_pass = passes.TypeInfer(self)
15
15
  ilist_desugar_pass = ilist.IListDesugar(self)
16
- measure_desugar_pass = Walk(MeasureDesugarRule())
16
+ desugar_pass = Walk(Chain(MeasureDesugarRule(), ApplyDesugarRule()))
17
17
  py_mult_to_mult_pass = PyMultToSquinMult(self)
18
18
 
19
19
  def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
@@ -25,7 +25,7 @@ def kernel(self):
25
25
 
26
26
  if typeinfer:
27
27
  typeinfer_pass(method)
28
- measure_desugar_pass.rewrite(method.code)
28
+ desugar_pass.rewrite(method.code)
29
29
 
30
30
  ilist_desugar_pass(method)
31
31
 
@@ -0,0 +1,27 @@
1
+ import ast
2
+ from dataclasses import dataclass
3
+
4
+ from kirin import lowering
5
+
6
+ from . import qubit
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class ApplyAnyCallLowering(lowering.FromPythonCall["qubit.ApplyAny"]):
11
+ """
12
+ Custom lowering for ApplyAny that collects vararg qubits into a single tuple argument
13
+ """
14
+
15
+ def lower(
16
+ self, stmt: type["qubit.ApplyAny"], state: lowering.State, node: ast.Call
17
+ ):
18
+ if len(node.args) < 2:
19
+ raise lowering.BuildError(
20
+ "Apply requires at least one operator and one qubit as arguments!"
21
+ )
22
+ op, *qubits = node.args
23
+ op_ssa = state.lower(op).expect_one()
24
+ qubits_lowered = [state.lower(qbit).expect_one() for qbit in qubits]
25
+
26
+ s = stmt(op_ssa, tuple(qubits_lowered))
27
+ return state.current_frame.push(s)
@@ -1,27 +1,8 @@
1
- # Put all the proper wrappers here
2
-
3
- from kirin.lowering import wraps as _wraps
4
-
5
- from bloqade.squin.op.types import Op
6
-
7
1
  from . import stmts as stmts
8
-
9
-
10
- @_wraps(stmts.PauliError)
11
- def pauli_error(basis: Op, p: float) -> Op: ...
12
-
13
-
14
- @_wraps(stmts.PPError)
15
- def pp_error(op: Op, p: float) -> Op: ...
16
-
17
-
18
- @_wraps(stmts.Depolarize)
19
- def depolarize(n_qubits: int, p: float) -> Op: ...
20
-
21
-
22
- @_wraps(stmts.PauliChannel)
23
- def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
24
-
25
-
26
- @_wraps(stmts.QubitLoss)
27
- def qubit_loss(p: float) -> Op: ...
2
+ from ._dialect import dialect as dialect
3
+ from ._wrapper import (
4
+ pp_error as pp_error,
5
+ depolarize as depolarize,
6
+ qubit_loss as qubit_loss,
7
+ pauli_channel as pauli_channel,
8
+ )
@@ -0,0 +1,25 @@
1
+ from kirin.lowering import wraps
2
+
3
+ from bloqade.squin.op.types import Op
4
+
5
+ from . import stmts
6
+
7
+
8
+ @wraps(stmts.PauliError)
9
+ def pauli_error(basis: Op, p: float) -> Op: ...
10
+
11
+
12
+ @wraps(stmts.PPError)
13
+ def pp_error(op: Op, p: float) -> Op: ...
14
+
15
+
16
+ @wraps(stmts.Depolarize)
17
+ def depolarize(n_qubits: int, p: float) -> Op: ...
18
+
19
+
20
+ @wraps(stmts.PauliChannel)
21
+ def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
22
+
23
+
24
+ @wraps(stmts.QubitLoss)
25
+ def qubit_loss(p: float) -> Op: ...
@@ -1,162 +1,37 @@
1
- from kirin import ir as _ir
2
- from kirin.prelude import structural_no_opt as _structural_no_opt
3
- from kirin.lowering import wraps as _wraps
4
-
5
1
  from . import stmts as stmts, types as types, rewrite as rewrite
2
+ from .stdlib import (
3
+ ch as ch,
4
+ cx as cx,
5
+ cy as cy,
6
+ cz as cz,
7
+ rx as rx,
8
+ ry as ry,
9
+ rz as rz,
10
+ cphase as cphase,
11
+ )
6
12
  from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
7
13
  from ._dialect import dialect as dialect
8
-
9
-
10
- @_wraps(stmts.Kron)
11
- def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
12
-
13
-
14
- @_wraps(stmts.Mult)
15
- def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ...
16
-
17
-
18
- @_wraps(stmts.Scale)
19
- def scale(op: types.Op, factor: complex) -> types.Op: ...
20
-
21
-
22
- @_wraps(stmts.Adjoint)
23
- def adjoint(op: types.Op) -> types.Op: ...
24
-
25
-
26
- @_wraps(stmts.Control)
27
- def control(op: types.Op, *, n_controls: int) -> types.Op:
28
- """
29
- Create a controlled operator.
30
-
31
- Note, that when considering atom loss, the operator will not be applied if
32
- any of the controls has been lost.
33
-
34
- Args:
35
- operator: The operator to apply under the control.
36
- n_controls: The number qubits to be used as control.
37
-
38
- Returns:
39
- Operator
40
- """
41
- ...
42
-
43
-
44
- @_wraps(stmts.Identity)
45
- def identity(*, sites: int) -> types.Op: ...
46
-
47
-
48
- @_wraps(stmts.Rot)
49
- def rot(axis: types.Op, angle: float) -> types.Op: ...
50
-
51
-
52
- @_wraps(stmts.ShiftOp)
53
- def shift(theta: float) -> types.Op: ...
54
-
55
-
56
- @_wraps(stmts.PhaseOp)
57
- def phase(theta: float) -> types.Op: ...
58
-
59
-
60
- @_wraps(stmts.X)
61
- def x() -> types.Op: ...
62
-
63
-
64
- @_wraps(stmts.Y)
65
- def y() -> types.Op: ...
66
-
67
-
68
- @_wraps(stmts.Z)
69
- def z() -> types.Op: ...
70
-
71
-
72
- @_wraps(stmts.H)
73
- def h() -> types.Op: ...
74
-
75
-
76
- @_wraps(stmts.S)
77
- def s() -> types.Op: ...
78
-
79
-
80
- @_wraps(stmts.T)
81
- def t() -> types.Op: ...
82
-
83
-
84
- @_wraps(stmts.P0)
85
- def p0() -> types.Op: ...
86
-
87
-
88
- @_wraps(stmts.P1)
89
- def p1() -> types.Op: ...
90
-
91
-
92
- @_wraps(stmts.Sn)
93
- def spin_n() -> types.Op: ...
94
-
95
-
96
- @_wraps(stmts.Sp)
97
- def spin_p() -> types.Op: ...
98
-
99
-
100
- @_wraps(stmts.U3)
101
- def u(theta: float, phi: float, lam: float) -> types.Op: ...
102
-
103
-
104
- @_wraps(stmts.PauliString)
105
- def pauli_string(*, string: str) -> types.Op: ...
106
-
107
-
108
- # stdlibs
109
- @_ir.dialect_group(_structural_no_opt.add(dialect))
110
- def op(self):
111
- def run_pass(method):
112
- pass
113
-
114
- return run_pass
115
-
116
-
117
- @op
118
- def rx(theta: float) -> types.Op:
119
- """Rotation X gate."""
120
- return rot(x(), theta)
121
-
122
-
123
- @op
124
- def ry(theta: float) -> types.Op:
125
- """Rotation Y gate."""
126
- return rot(y(), theta)
127
-
128
-
129
- @op
130
- def rz(theta: float) -> types.Op:
131
- """Rotation Z gate."""
132
- return rot(z(), theta)
133
-
134
-
135
- @op
136
- def cx() -> types.Op:
137
- """Controlled X gate."""
138
- return control(x(), n_controls=1)
139
-
140
-
141
- @op
142
- def cy() -> types.Op:
143
- """Controlled Y gate."""
144
- return control(y(), n_controls=1)
145
-
146
-
147
- @op
148
- def cz() -> types.Op:
149
- """Control Z gate."""
150
- return control(z(), n_controls=1)
151
-
152
-
153
- @op
154
- def ch() -> types.Op:
155
- """Control H gate."""
156
- return control(h(), n_controls=1)
157
-
158
-
159
- @op
160
- def cphase(theta: float) -> types.Op:
161
- """Control Phase gate."""
162
- return control(phase(theta), n_controls=1)
14
+ from ._wrapper import (
15
+ h as h,
16
+ s as s,
17
+ t as t,
18
+ u as u,
19
+ x as x,
20
+ y as y,
21
+ z as z,
22
+ p0 as p0,
23
+ p1 as p1,
24
+ rot as rot,
25
+ kron as kron,
26
+ mult as mult,
27
+ phase as phase,
28
+ reset as reset,
29
+ scale as scale,
30
+ shift as shift,
31
+ spin_n as spin_n,
32
+ spin_p as spin_p,
33
+ adjoint as adjoint,
34
+ control as control,
35
+ identity as identity,
36
+ pauli_string as pauli_string,
37
+ )