bloqade-circuit 0.6.8__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (33)
  1. bloqade/analysis/measure_id/analysis.py +10 -11
  2. bloqade/analysis/measure_id/impls.py +15 -2
  3. bloqade/cirq_utils/noise/__init__.py +0 -2
  4. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  5. bloqade/cirq_utils/noise/model.py +141 -188
  6. bloqade/cirq_utils/noise/transform.py +2 -2
  7. bloqade/pyqrack/squin/qubit.py +4 -2
  8. bloqade/pyqrack/squin/runtime.py +14 -6
  9. bloqade/qasm2/emit/target.py +5 -1
  10. bloqade/squin/cirq/emit/op.py +37 -5
  11. bloqade/squin/cirq/emit/qubit.py +4 -4
  12. bloqade/squin/cirq/emit/runtime.py +0 -15
  13. bloqade/squin/cirq/lowering.py +3 -9
  14. bloqade/squin/gate.py +7 -0
  15. bloqade/squin/lowering.py +26 -0
  16. bloqade/squin/noise/__init__.py +0 -1
  17. bloqade/squin/noise/_wrapper.py +2 -6
  18. bloqade/squin/noise/rewrite.py +0 -11
  19. bloqade/squin/noise/stmts.py +2 -14
  20. bloqade/squin/op/_wrapper.py +4 -4
  21. bloqade/squin/op/stmts.py +33 -9
  22. bloqade/squin/op/types.py +104 -2
  23. bloqade/squin/qubit.py +27 -40
  24. bloqade/squin/rewrite/desugar.py +44 -66
  25. bloqade/stim/passes/squin_to_stim.py +21 -4
  26. bloqade/stim/rewrite/ifs_to_stim.py +6 -1
  27. bloqade/stim/rewrite/qubit_to_stim.py +1 -1
  28. bloqade/stim/rewrite/squin_noise.py +9 -7
  29. bloqade/stim/rewrite/util.py +15 -3
  30. {bloqade_circuit-0.6.8.dist-info → bloqade_circuit-0.7.1.dist-info}/METADATA +2 -2
  31. {bloqade_circuit-0.6.8.dist-info → bloqade_circuit-0.7.1.dist-info}/RECORD +33 -33
  32. {bloqade_circuit-0.6.8.dist-info → bloqade_circuit-0.7.1.dist-info}/WHEEL +0 -0
  33. {bloqade_circuit-0.6.8.dist-info → bloqade_circuit-0.7.1.dist-info}/licenses/LICENSE +0 -0
@@ -2,6 +2,7 @@ import math
2
2
 
3
3
  import cirq
4
4
  import numpy as np
5
+ from kirin.emit import EmitError
5
6
  from kirin.interp import MethodTable, impl
6
7
 
7
8
  from ... import op
@@ -9,11 +10,11 @@ from .runtime import (
9
10
  SnRuntime,
10
11
  SpRuntime,
11
12
  U3Runtime,
12
- RotRuntime,
13
13
  KronRuntime,
14
14
  MultRuntime,
15
15
  ScaleRuntime,
16
16
  AdjointRuntime,
17
+ BasicOpRuntime,
17
18
  ControlRuntime,
18
19
  UnitaryRuntime,
19
20
  HermitianRuntime,
@@ -117,7 +118,7 @@ class EmitCirqOpMethods(MethodTable):
117
118
 
118
119
  @impl(op.stmts.Reset)
119
120
  def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset):
120
- return (HermitianRuntime(cirq.ResetChannel()),)
121
+ return (BasicOpRuntime(cirq.ResetChannel()),)
121
122
 
122
123
  @impl(op.stmts.PauliString)
123
124
  def pauli_string(
@@ -127,11 +128,42 @@ class EmitCirqOpMethods(MethodTable):
127
128
 
128
129
  @impl(op.stmts.Rot)
129
130
  def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot):
130
- axis_op: HermitianRuntime = frame.get(stmt.axis)
131
+ axis: OperatorRuntimeABC = frame.get(stmt.axis)
132
+
133
+ if not isinstance(axis, HermitianRuntime):
134
+ raise EmitError(
135
+ f"Circuit emission only supported for Pauli operators! Got axis {axis}"
136
+ )
137
+
131
138
  angle = frame.get(stmt.angle)
132
139
 
133
- axis_name = str(axis_op.gate).lower()
134
- return (RotRuntime(axis=axis_name, angle=angle),)
140
+ match axis.gate:
141
+ case cirq.X:
142
+ gate = cirq.Rx(rads=angle)
143
+ case cirq.Y:
144
+ gate = cirq.Ry(rads=angle)
145
+ case cirq.Z:
146
+ gate = cirq.Rz(rads=angle)
147
+ case _:
148
+ raise EmitError(
149
+ f"Circuit emission only supported for Pauli operators! Got axis {axis.gate}"
150
+ )
151
+
152
+ return (HermitianRuntime(gate=gate),)
153
+
154
+ @impl(op.stmts.ResetToOne)
155
+ def reset_to_one(
156
+ self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ResetToOne
157
+ ):
158
+ # NOTE: just apply a reset to 0 and flip in sequence (we re-use the multiplication runtime since it does exactly that)
159
+ gate1 = cirq.ResetChannel()
160
+ gate2 = cirq.X
161
+
162
+ rt1 = BasicOpRuntime(gate1)
163
+ rt2 = HermitianRuntime(gate2)
164
+
165
+ # NOTE: mind the order: rhs is applied first
166
+ return (MultRuntime(rt2, rt1),)
135
167
 
136
168
  @impl(op.stmts.SqrtX)
137
169
  def sqrt_x(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtX):
@@ -25,7 +25,7 @@ class EmitCirqQubitMethods(MethodTable):
25
25
  @impl(qubit.Apply)
26
26
  def apply(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Apply):
27
27
  op: OperatorRuntimeABC = frame.get(stmt.operator)
28
- qbits = frame.get(stmt.qubits)
28
+ qbits = [frame.get(qbit) for qbit in stmt.qubits]
29
29
  operations = op.apply(qbits)
30
30
  for operation in operations:
31
31
  frame.circuit.append(operation)
@@ -34,11 +34,11 @@ class EmitCirqQubitMethods(MethodTable):
34
34
  @impl(qubit.Broadcast)
35
35
  def broadcast(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Broadcast):
36
36
  op = frame.get(stmt.operator)
37
- qbits = frame.get(stmt.qubits)
37
+ qbit_lists = [frame.get(qbit) for qbit in stmt.qubits]
38
38
 
39
39
  cirq_ops = []
40
- for qbit in qbits:
41
- cirq_ops.extend(op.apply([qbit]))
40
+ for qbits in zip(*qbit_lists):
41
+ cirq_ops.extend(op.apply(qbits))
42
42
 
43
43
  frame.circuit.append(cirq.Moment(cirq_ops))
44
44
  return ()
@@ -240,18 +240,3 @@ class PauliStringRuntime(OperatorRuntimeABC):
240
240
  qbit: pauli_label for (qbit, pauli_label) in zip(qubits, self.string)
241
241
  }
242
242
  return [cirq.PauliString(pauli_mapping)]
243
-
244
-
245
- @dataclass
246
- class RotRuntime(OperatorRuntimeABC):
247
- axis: str
248
- angle: float
249
-
250
- def num_qubits(self) -> int:
251
- return 1
252
-
253
- def unsafe_apply(
254
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
255
- ) -> list[cirq.Operation]:
256
- rot = getattr(cirq, "R" + self.axis.lower())(rads=self.angle)
257
- return [rot(*qubits)]
@@ -44,14 +44,7 @@ class Squin(lowering.LoweringABC[CirqNode]):
44
44
  self, state: lowering.State[CirqNode], qids: list[cirq.Qid]
45
45
  ):
46
46
  qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids]
47
- qbits_stmt = ilist.New(values=qbits_getitem, elem_type=qubit.QubitType)
48
- qbits_result = state.current_frame.get(qbits_stmt.name)
49
-
50
- if qbits_result is not None:
51
- return qbits_result
52
-
53
- state.current_frame.push(qbits_stmt)
54
- return qbits_stmt.result
47
+ return tuple(qbits_getitem)
55
48
 
56
49
  def run(
57
50
  self,
@@ -159,7 +152,8 @@ class Squin(lowering.LoweringABC[CirqNode]):
159
152
  stmt = state.current_frame.push(qubit.MeasureQubit(qbit))
160
153
  else:
161
154
  qbits = self.lower_qubit_getindices(state, node.qubits)
162
- stmt = state.current_frame.push(qubit.MeasureQubitList(qbits))
155
+ qbits_list = state.current_frame.push(ilist.New(values=qbits))
156
+ stmt = state.current_frame.push(qubit.MeasureQubitList(qbits_list.result))
163
157
 
164
158
  key = node.gate.key
165
159
  if isinstance(key, cirq.MeasurementKey):
bloqade/squin/gate.py CHANGED
@@ -137,6 +137,13 @@ def reset(qubit: Qubit) -> None:
137
137
  _qubit.apply(op, qubit)
138
138
 
139
139
 
140
+ @kernel
141
+ def reset_to_one(qubit: Qubit) -> None:
142
+ """Reset qubit to 1."""
143
+ op = _op.reset_to_one()
144
+ _qubit.apply(op, qubit)
145
+
146
+
140
147
  @kernel
141
148
  def cx(control: Qubit, target: Qubit) -> None:
142
149
  """Controlled x gate applied to control and target"""
bloqade/squin/lowering.py CHANGED
@@ -52,3 +52,29 @@ class ApplyAnyCallLowering(lowering.FromPythonCall["qubit.ApplyAny"]):
52
52
  return op, qubits.elts
53
53
 
54
54
  return op, [qubits]
55
+
56
+
57
+ @dataclass(frozen=True)
58
+ class BroadcastCallLowering(lowering.FromPythonCall["qubit.Broadcast"]):
59
+ """
60
+ Custom lowering for broadcast vararg call.
61
+
62
+ NOTE: we can re-use this to lower Apply too once we remove the deprecated syntax
63
+ """
64
+
65
+ def lower(
66
+ self, stmt: type["qubit.Broadcast"], state: lowering.State, node: ast.Call
67
+ ):
68
+ if len(node.args) < 2:
69
+ raise lowering.BuildError(
70
+ "Broadcast requires at least one operator and one qubit list argument"
71
+ )
72
+
73
+ op, *qubit_lists = node.args
74
+
75
+ op_lowered = state.lower(op).expect_one()
76
+ qubits_lists_lowered = [
77
+ state.lower(qubit_list).expect_one() for qubit_list in qubit_lists
78
+ ]
79
+
80
+ return state.current_frame.push(stmt(op_lowered, tuple(qubits_lists_lowered)))
@@ -1,7 +1,6 @@
1
1
  from . import stmts as stmts
2
2
  from ._dialect import dialect as dialect
3
3
  from ._wrapper import (
4
- pp_error as pp_error,
5
4
  depolarize as depolarize,
6
5
  qubit_loss as qubit_loss,
7
6
  depolarize2 as depolarize2,
@@ -3,17 +3,13 @@ from typing import Literal
3
3
  from kirin.dialects import ilist
4
4
  from kirin.lowering import wraps
5
5
 
6
- from bloqade.squin.op.types import Op
6
+ from bloqade.squin.op.types import Op, MultiQubitPauliOp
7
7
 
8
8
  from . import stmts
9
9
 
10
10
 
11
11
  @wraps(stmts.PauliError)
12
- def pauli_error(basis: Op, p: float) -> Op: ...
13
-
14
-
15
- @wraps(stmts.PPError)
16
- def pp_error(op: Op, p: float) -> Op: ...
12
+ def pauli_error(basis: MultiQubitPauliOp, p: float) -> Op: ...
17
13
 
18
14
 
19
15
  @wraps(stmts.Depolarize)
@@ -7,7 +7,6 @@ from kirin.dialects import py, ilist
7
7
  from kirin.rewrite.abc import RewriteRule, RewriteResult
8
8
 
9
9
  from .stmts import (
10
- PPError,
11
10
  QubitLoss,
12
11
  Depolarize,
13
12
  PauliError,
@@ -86,16 +85,6 @@ class _RewriteNoiseStmts(RewriteRule):
86
85
  (operator_list := ilist.New(values=operators)).insert_before(node)
87
86
  return operator_list.result
88
87
 
89
- def rewrite_p_p_error(self, node: PPError) -> RewriteResult:
90
- (operators := ilist.New(values=(node.op,))).insert_before(node)
91
- (ps := ilist.New(values=(node.p,))).insert_before(node)
92
- stochastic_channel = StochasticUnitaryChannel(
93
- operators=operators.result, probabilities=ps.result
94
- )
95
-
96
- node.replace_by(stochastic_channel)
97
- return RewriteResult(has_done_something=True)
98
-
99
88
  def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
100
89
  paulis = (X(), Y(), Z())
101
90
  operators: list[ir.SSAValue] = []
@@ -2,10 +2,8 @@ from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
  from kirin.dialects import ilist
4
4
 
5
- from bloqade.squin.op.types import OpType
6
-
7
5
  from ._dialect import dialect
8
- from ..op.types import NumOperators
6
+ from ..op.types import OpType, NumOperators, MultiQubitPauliOpType
9
7
 
10
8
 
11
9
  @statement
@@ -16,17 +14,7 @@ class NoiseChannel(ir.Statement):
16
14
 
17
15
  @statement(dialect=dialect)
18
16
  class PauliError(NoiseChannel):
19
- basis: ir.SSAValue = info.argument(OpType)
20
- p: ir.SSAValue = info.argument(types.Float)
21
-
22
-
23
- @statement(dialect=dialect)
24
- class PPError(NoiseChannel):
25
- """
26
- Pauli Product Error
27
- """
28
-
29
- op: ir.SSAValue = info.argument(OpType)
17
+ basis: ir.SSAValue = info.argument(MultiQubitPauliOpType)
30
18
  p: ir.SSAValue = info.argument(types.Float)
31
19
 
32
20
 
@@ -62,15 +62,15 @@ def phase(theta: float) -> types.Op: ...
62
62
 
63
63
 
64
64
  @wraps(stmts.X)
65
- def x() -> types.Op: ...
65
+ def x() -> types.PauliOp: ...
66
66
 
67
67
 
68
68
  @wraps(stmts.Y)
69
- def y() -> types.Op: ...
69
+ def y() -> types.PauliOp: ...
70
70
 
71
71
 
72
72
  @wraps(stmts.Z)
73
- def z() -> types.Op: ...
73
+ def z() -> types.PauliOp: ...
74
74
 
75
75
 
76
76
  @wraps(stmts.SqrtX)
@@ -118,4 +118,4 @@ def u(theta: float, phi: float, lam: float) -> types.Op: ...
118
118
 
119
119
 
120
120
  @wraps(stmts.PauliString)
121
- def pauli_string(*, string: str) -> types.Op: ...
121
+ def pauli_string(*, string: str) -> types.PauliStringOp: ...
bloqade/squin/op/stmts.py CHANGED
@@ -1,7 +1,19 @@
1
1
  from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
 
4
- from .types import OpType
4
+ from .types import (
5
+ OpType,
6
+ ROpType,
7
+ XOpType,
8
+ YOpType,
9
+ ZOpType,
10
+ KronType,
11
+ MultType,
12
+ PauliOpType,
13
+ ControlOpType,
14
+ PauliStringType,
15
+ ControlledOpType,
16
+ )
5
17
  from .number import NumberType
6
18
  from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
7
19
  from ._dialect import dialect
@@ -22,22 +34,28 @@ class CompositeOp(Operator):
22
34
  pass
23
35
 
24
36
 
37
+ LhsType = types.TypeVar("Lhs", bound=OpType)
38
+ RhsType = types.TypeVar("Rhs", bound=OpType)
39
+
40
+
25
41
  @statement
26
42
  class BinaryOp(CompositeOp):
27
- lhs: ir.SSAValue = info.argument(OpType)
28
- rhs: ir.SSAValue = info.argument(OpType)
43
+ lhs: ir.SSAValue = info.argument(LhsType)
44
+ rhs: ir.SSAValue = info.argument(RhsType)
29
45
 
30
46
 
31
47
  @statement(dialect=dialect)
32
48
  class Kron(BinaryOp):
33
49
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
34
50
  is_unitary: bool = info.attribute(default=False)
51
+ result: ir.ResultValue = info.result(KronType[LhsType, RhsType])
35
52
 
36
53
 
37
54
  @statement(dialect=dialect)
38
55
  class Mult(BinaryOp):
39
56
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
40
57
  is_unitary: bool = info.attribute(default=False)
58
+ result: ir.ResultValue = info.result(MultType[LhsType, RhsType])
41
59
 
42
60
 
43
61
  @statement(dialect=dialect)
@@ -59,15 +77,20 @@ class Scale(CompositeOp):
59
77
  class Control(CompositeOp):
60
78
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
61
79
  is_unitary: bool = info.attribute(default=False)
62
- op: ir.SSAValue = info.argument(OpType)
80
+ op: ir.SSAValue = info.argument(ControlledOpType)
63
81
  n_controls: int = info.attribute()
82
+ result: ir.ResultValue = info.result(ControlOpType[ControlledOpType])
83
+
84
+
85
+ RotationAxisType = types.TypeVar("RotationAxis", bound=OpType)
64
86
 
65
87
 
66
88
  @statement(dialect=dialect)
67
89
  class Rot(CompositeOp):
68
90
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary()})
69
- axis: ir.SSAValue = info.argument(OpType)
91
+ axis: ir.SSAValue = info.argument(RotationAxisType)
70
92
  angle: ir.SSAValue = info.argument(types.Float)
93
+ result: ir.ResultValue = info.result(ROpType[RotationAxisType])
71
94
 
72
95
 
73
96
  @statement(dialect=dialect)
@@ -166,13 +189,14 @@ class CliffordOp(ConstantUnitary):
166
189
 
167
190
  @statement
168
191
  class PauliOp(CliffordOp):
169
- pass
192
+ result: ir.ResultValue = info.result(type=PauliOpType)
170
193
 
171
194
 
172
195
  @statement(dialect=dialect)
173
196
  class PauliString(ConstantUnitary):
174
197
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
175
198
  string: str = info.attribute()
199
+ result: ir.ResultValue = info.result(type=PauliStringType)
176
200
 
177
201
  def verify(self) -> None:
178
202
  if not set("XYZ").issuperset(self.string):
@@ -183,17 +207,17 @@ class PauliString(ConstantUnitary):
183
207
 
184
208
  @statement(dialect=dialect)
185
209
  class X(PauliOp):
186
- pass
210
+ result: ir.ResultValue = info.result(XOpType)
187
211
 
188
212
 
189
213
  @statement(dialect=dialect)
190
214
  class Y(PauliOp):
191
- pass
215
+ result: ir.ResultValue = info.result(YOpType)
192
216
 
193
217
 
194
218
  @statement(dialect=dialect)
195
219
  class Z(PauliOp):
196
- pass
220
+ result: ir.ResultValue = info.result(ZOpType)
197
221
 
198
222
 
199
223
  @statement(dialect=dialect)
bloqade/squin/op/types.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import overload
1
+ from typing import Generic, TypeVar, overload
2
2
 
3
3
  from kirin import types
4
4
 
@@ -23,4 +23,106 @@ class Op:
23
23
 
24
24
  OpType = types.PyClass(Op)
25
25
 
26
- NumOperators = types.TypeVar("NumOperators")
26
+
27
+ class CompositeOp(Op):
28
+ pass
29
+
30
+
31
+ CompositeOpType = types.PyClass(CompositeOp)
32
+
33
+ LhsType = TypeVar("LhsType", bound=Op)
34
+ RhsType = TypeVar("RhsType", bound=Op)
35
+
36
+
37
+ class BinaryOp(Op, Generic[LhsType, RhsType]):
38
+ lhs: LhsType
39
+ rhs: RhsType
40
+
41
+
42
+ BinaryOpType = types.Generic(BinaryOp, OpType, OpType)
43
+
44
+
45
+ class Mult(BinaryOp[LhsType, RhsType]):
46
+ pass
47
+
48
+
49
+ MultType = types.Generic(Mult, OpType, OpType)
50
+
51
+
52
+ class Kron(BinaryOp[LhsType, RhsType]):
53
+ pass
54
+
55
+
56
+ KronType = types.Generic(Kron, OpType, OpType)
57
+
58
+
59
+ class MultiQubitPauliOp(Op):
60
+ pass
61
+
62
+
63
+ MultiQubitPauliOpType = types.PyClass(MultiQubitPauliOp)
64
+
65
+
66
+ class PauliStringOp(MultiQubitPauliOp):
67
+ pass
68
+
69
+
70
+ PauliStringType = types.PyClass(PauliStringOp)
71
+
72
+
73
+ class PauliOp(MultiQubitPauliOp):
74
+ pass
75
+
76
+
77
+ PauliOpType = types.PyClass(PauliOp)
78
+
79
+
80
+ class XOp(PauliOp):
81
+ pass
82
+
83
+
84
+ XOpType = types.PyClass(XOp)
85
+
86
+
87
+ class YOp(PauliOp):
88
+ pass
89
+
90
+
91
+ YOpType = types.PyClass(YOp)
92
+
93
+
94
+ class ZOp(PauliOp):
95
+ pass
96
+
97
+
98
+ ZOpType = types.PyClass(ZOp)
99
+
100
+
101
+ ControlledOp = TypeVar("ControlledOp", bound=Op)
102
+
103
+
104
+ class ControlOp(CompositeOp, Generic[ControlledOp]):
105
+ op: ControlledOp
106
+
107
+
108
+ ControlledOpType = types.TypeVar("ControlledOp", bound=OpType)
109
+ ControlOpType = types.Generic(ControlOp, ControlledOpType)
110
+ CXOpType = ControlOpType[XOpType]
111
+ CYOpType = ControlOpType[YOpType]
112
+ CZOpType = ControlOpType[ZOpType]
113
+
114
+ RotationAxis = TypeVar("RotationAxis", bound=Op)
115
+
116
+
117
+ class ROp(CompositeOp, Generic[RotationAxis]):
118
+ axis: RotationAxis
119
+ angle: float
120
+
121
+
122
+ ROpType = types.Generic(ROp, OpType)
123
+ RxOpType = ROpType[XOpType]
124
+ RyOpType = ROpType[YOpType]
125
+ RzOpType = ROpType[ZOpType]
126
+
127
+
128
+ NumOperators = types.TypeVar("NumOperators", bound=types.Int)
bloqade/squin/qubit.py CHANGED
@@ -7,7 +7,7 @@ Depends on:
7
7
  - `kirin.dialects.ilist`: provides the `ilist.IListType` type for lists of qubits.
8
8
  """
9
9
 
10
- from typing import Any, overload
10
+ from typing import Any, TypeVar, overload
11
11
 
12
12
  from kirin import ir, types, lowering
13
13
  from kirin.decl import info, statement
@@ -18,7 +18,7 @@ from bloqade.types import Qubit, QubitType
18
18
  from bloqade.squin.op.types import Op, OpType
19
19
 
20
20
  from .types import MeasurementResult, MeasurementResultType
21
- from .lowering import ApplyAnyCallLowering
21
+ from .lowering import ApplyAnyCallLowering, BroadcastCallLowering
22
22
 
23
23
  dialect = ir.Dialect("squin.qubit")
24
24
 
@@ -34,7 +34,7 @@ class New(ir.Statement):
34
34
  class Apply(ir.Statement):
35
35
  traits = frozenset({lowering.FromPythonCall()})
36
36
  operator: ir.SSAValue = info.argument(OpType)
37
- qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
37
+ qubits: tuple[ir.SSAValue, ...] = info.argument(QubitType)
38
38
 
39
39
 
40
40
  @statement(dialect=dialect)
@@ -47,9 +47,9 @@ class ApplyAny(ir.Statement):
47
47
 
48
48
  @statement(dialect=dialect)
49
49
  class Broadcast(ir.Statement):
50
- traits = frozenset({lowering.FromPythonCall()})
50
+ traits = frozenset({BroadcastCallLowering()})
51
51
  operator: ir.SSAValue = info.argument(OpType)
52
- qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
52
+ qubits: tuple[ir.SSAValue, ...] = info.argument(ilist.IListType[QubitType])
53
53
 
54
54
 
55
55
  @statement(dialect=dialect)
@@ -93,26 +93,10 @@ def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
93
93
  ...
94
94
 
95
95
 
96
- @overload
97
- def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
98
- """Apply an operator to a list of qubits.
99
-
100
- Note, that when considering atom loss, lost qubits will be skipped.
101
-
102
- Args:
103
- operator: The operator to apply.
104
- qubits: The list of qubits to apply the operator to. The size of the list
105
- must be inferable and match the number of qubits expected by the operator.
106
-
107
- Returns:
108
- None
109
- """
110
- ...
111
-
112
-
113
- @overload
96
+ @wraps(ApplyAny)
114
97
  def apply(operator: Op, *qubits: Qubit) -> None:
115
- """Apply and operator to any number of qubits.
98
+ """Apply an operator to qubits. The number of qubit arguments must match the
99
+ size of the operator.
116
100
 
117
101
  Note, that when considering atom loss, lost qubits will be skipped.
118
102
 
@@ -127,10 +111,6 @@ def apply(operator: Op, *qubits: Qubit) -> None:
127
111
  ...
128
112
 
129
113
 
130
- @wraps(ApplyAny)
131
- def apply(operator: Op, *qubits) -> None: ...
132
-
133
-
134
114
  @overload
135
115
  def measure(input: Qubit) -> MeasurementResult: ...
136
116
  @overload
@@ -154,23 +134,30 @@ def measure(input: Any) -> Any:
154
134
  ...
155
135
 
156
136
 
137
+ OpSize = TypeVar("OpSize")
138
+
139
+
157
140
  @wraps(Broadcast)
158
- def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
159
- """Broadcast and apply an operator to a list of qubits. For example, an operator
160
- that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0.
141
+ def broadcast(operator: Op, *qubits: ilist.IList[Qubit, OpSize] | list[Qubit]) -> None:
142
+ """Broadcast and apply an operator to lists of qubits. The number of qubit lists must
143
+ match the size of the operator and the lists must be of same length. The operator is
144
+ then applied to the list elements similar to what python's map function does.
161
145
 
162
- For controlled operators, the list of qubits is interpreted as sets of (controls, targets).
163
- For example
146
+ ## Usage examples
164
147
 
165
- ```
166
- apply(CX, [q0, q1, q2, q3])
167
- ```
148
+ ```python
149
+ from bloqade import squin
168
150
 
169
- is equivalent to
151
+ @squin.kernel
152
+ def ghz():
153
+ controls = squin.qubit.new(4)
154
+ targets = squin.qubit.new(4)
170
155
 
171
- ```
172
- apply(CX, [q0, q1])
173
- apply(CX, [q2, q3])
156
+ h = squin.op.h()
157
+ squin.qubit.broadcast(h, controls)
158
+
159
+ cx = squin.op.cx()
160
+ squin.qubit.broadcast(cx, controls, targets)
174
161
  ```
175
162
 
176
163
  Args: