bloqade-circuit 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (32)
  1. bloqade/squin/analysis/nsites/impls.py +7 -0
  2. bloqade/squin/cirq/__init__.py +112 -1
  3. bloqade/squin/cirq/emit/emit_circuit.py +109 -0
  4. bloqade/squin/cirq/emit/op.py +125 -0
  5. bloqade/squin/cirq/emit/qubit.py +60 -0
  6. bloqade/squin/cirq/emit/runtime.py +234 -0
  7. bloqade/squin/cirq/lowering.py +73 -4
  8. bloqade/squin/op/__init__.py +3 -0
  9. bloqade/squin/op/_wrapper.py +12 -0
  10. bloqade/squin/op/stmts.py +19 -1
  11. bloqade/squin/rewrite/U3_to_clifford.py +149 -0
  12. bloqade/squin/rewrite/__init__.py +3 -7
  13. bloqade/squin/rewrite/remove_dangling_qubits.py +19 -0
  14. bloqade/squin/rewrite/wrap_analysis.py +34 -19
  15. bloqade/stim/__init__.py +1 -1
  16. bloqade/stim/dialects/auxiliary/stmts/const.py +1 -1
  17. bloqade/stim/dialects/gate/stmts/__init__.py +6 -0
  18. bloqade/stim/passes/__init__.py +1 -0
  19. bloqade/stim/passes/squin_to_stim.py +86 -0
  20. bloqade/stim/rewrite/__init__.py +7 -0
  21. bloqade/stim/rewrite/py_constant_to_stim.py +42 -0
  22. bloqade/{squin → stim}/rewrite/qubit_to_stim.py +18 -3
  23. bloqade/{squin → stim}/rewrite/squin_measure.py +2 -2
  24. bloqade/{squin/rewrite/stim_rewrite_util.py → stim/rewrite/util.py} +36 -17
  25. bloqade/{squin → stim}/rewrite/wire_to_stim.py +1 -1
  26. {bloqade_circuit-0.4.2.dist-info → bloqade_circuit-0.4.4.dist-info}/METADATA +1 -1
  27. {bloqade_circuit-0.4.2.dist-info → bloqade_circuit-0.4.4.dist-info}/RECORD +30 -22
  28. bloqade/squin/passes/__init__.py +0 -1
  29. bloqade/squin/passes/stim.py +0 -68
  30. /bloqade/{squin → stim}/rewrite/wire_identity_elimination.py +0 -0
  31. {bloqade_circuit-0.4.2.dist-info → bloqade_circuit-0.4.4.dist-info}/WHEEL +0 -0
  32. {bloqade_circuit-0.4.2.dist-info → bloqade_circuit-0.4.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,6 @@
1
1
  from kirin import interp
2
+ from kirin.dialects import scf
3
+ from kirin.dialects.scf.typeinfer import TypeInfer as ScfTypeInfer
2
4
 
3
5
  from bloqade.squin import op, wire
4
6
 
@@ -78,3 +80,8 @@ class SquinOp(interp.MethodTable):
78
80
  def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
79
81
  op_sites = frame.get(stmt.op)
80
82
  return (op_sites,)
83
+
84
+
85
@scf.dialect.register(key="op.nsites")
class ScfSquinOp(ScfTypeInfer):
    """Register the scf type-inference method table under the ``op.nsites`` key.

    Inherits every implementation from ``ScfTypeInfer`` unchanged.
    NOTE(review): presumably this lets ``NSitesAnalysis`` step through scf
    control flow by reusing the type-inference rules — confirm.
    """
@@ -1,12 +1,17 @@
1
- from typing import Any
1
+ from typing import Any, Sequence
2
2
 
3
3
  import cirq
4
4
  from kirin import ir, types
5
+ from kirin.emit import EmitError
5
6
  from kirin.dialects import func
6
7
 
7
8
  from . import lowering as lowering
8
9
  from .. import kernel
10
+
11
+ # NOTE: just to register methods
12
+ from .emit import op as op, qubit as qubit
9
13
  from .lowering import Squin
14
+ from .emit.emit_circuit import EmitCirq
10
15
 
11
16
 
12
17
  def load_circuit(
@@ -87,3 +92,109 @@ def load_circuit(
87
92
  dialects=dialects,
88
93
  code=code,
89
94
  )
95
+
96
+
97
def emit_circuit(
    mt: ir.Method,
    qubits: Sequence[cirq.Qid] | None = None,
) -> cirq.Circuit:
    """Convert a ``squin.kernel`` method into a ``cirq.Circuit``.

    Args:
        mt (ir.Method): The kernel method from which to construct the circuit.

    Keyword Args:
        qubits (Sequence[cirq.Qid] | None):
            The qubits to use in the circuit. Defaults to None, in which case a
            `cirq.LineQubit` is created for every qubit allocated by a
            `squin.qubit.new` statement, in the order those statements appear
            inside the kernel.
            **Note**: If you provide a list of qubits, make sure it contains
            enough qubits for the resulting circuit.

    Returns:
        cirq.Circuit: The emitted circuit.

    Raises:
        EmitError: If ``mt`` is a function whose declared return type is not
            ``None``; returning values from a circuit is not supported.

    ## Examples:

    Here's a very basic example:

    ```python
    from bloqade import squin

    @squin.kernel
    def main():
        q = squin.qubit.new(2)
        h = squin.op.h()
        squin.qubit.apply(h, q[0])
        cx = squin.op.cx()
        squin.qubit.apply(cx, q)

    circuit = squin.cirq.emit_circuit(main)

    print(circuit)
    ```

    You can also compose multiple kernels. Those are emitted as subcircuits
    within the "main" circuit. Subkernels can accept arguments and return a value.

    ```python
    from bloqade import squin
    from kirin.dialects import ilist
    from typing import Literal
    import cirq

    @squin.kernel
    def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):
        h = squin.op.h()
        squin.qubit.apply(h, q[0])
        cx = squin.op.cx()
        squin.qubit.apply(cx, q)
        return cx

    @squin.kernel
    def main():
        q = squin.qubit.new(2)
        cx = entangle(q)
        q2 = squin.qubit.new(3)
        squin.qubit.apply(cx, [q[1], q2[2]])


    # custom list of qubits on grid
    qubits = [cirq.GridQubit(i, i+1) for i in range(5)]

    circuit = squin.cirq.emit_circuit(main, qubits=qubits)
    print(circuit)
    ```

    The example above also passes a custom list of qubits. This lets you choose a
    custom geometry and work with the same qubits in circuits written directly
    in cirq.
    """
    code = mt.code
    # Only function bodies carry a signature; reject any that declare a non-None output.
    returns_value = isinstance(
        code, func.Function
    ) and not code.signature.output.is_subseteq(types.NoneType)
    if returns_value:
        raise EmitError(
            "The method you are trying to convert to a circuit has a return value, but returning from a circuit is not supported."
        )

    return EmitCirq(qubits=qubits).run(mt, args=())
180
+
181
+
182
def dump_circuit(mt: ir.Method, qubits: Sequence[cirq.Qid] | None = None, **kwargs):
    """Emit a ``squin.kernel`` method as a cirq circuit and serialize it to JSON.

    This is ``emit_circuit`` followed by ``cirq.to_json``.

    Args:
        mt (ir.Method): The kernel method from which to construct the circuit.

    Keyword Args:
        qubits (Sequence[cirq.Qid] | None):
            The qubits to use in the circuit. Defaults to None, in which case a
            `cirq.LineQubit` is created for every qubit allocated by a
            `squin.qubit.new` statement, in the order those statements appear
            inside the kernel.
            **Note**: If you provide a list of qubits, make sure it contains
            enough qubits for the resulting circuit.
        **kwargs: Forwarded unchanged to ``cirq.to_json``.

    Returns:
        Whatever ``cirq.to_json(circuit, **kwargs)`` returns.
    """
    return cirq.to_json(emit_circuit(mt, qubits=qubits), **kwargs)
@@ -0,0 +1,109 @@
1
+ from typing import Sequence
2
+ from dataclasses import field, dataclass
3
+
4
+ import cirq
5
+ from kirin import ir
6
+ from kirin.emit import EmitABC, EmitError, EmitFrame
7
+ from kirin.interp import MethodTable, impl
8
+ from kirin.dialects import func
9
+ from typing_extensions import Self
10
+
11
+ from ... import kernel
12
+
13
+
14
@dataclass
class EmitCirqFrame(EmitFrame):
    """Emission frame that accumulates a cirq circuit.

    Extends kirin's ``EmitFrame`` with the state the cirq emitter needs: where
    the next allocated qubit starts, an optional user-supplied qubit list, and
    the circuit built so far.
    """

    # Offset of the next qubit handed out by ``qubit.new``.
    qubit_index: int = 0
    # User-provided qubits; ``None`` means ``cirq.LineQubit``s are created instead.
    qubits: Sequence[cirq.Qid] | None = None
    # Circuit that emitted operations are appended to (a fresh one per frame).
    circuit: cirq.Circuit = field(default_factory=lambda: cirq.Circuit())
19
+
20
+
21
def _default_kernel():
    """Return the squin ``kernel`` dialect group used as ``EmitCirq.dialects``' default."""
    default_group = kernel
    return default_group
23
+
24
+
25
@dataclass
class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
    """Kirin emitter that turns a squin kernel into a ``cirq.Circuit``.

    Method tables are looked up under the keys listed in ``keys``
    (``"emit.cirq"`` and ``"main"``).
    """

    keys = ["emit.cirq", "main"]
    dialects: ir.DialectGroup = field(default_factory=_default_kernel)
    # Empty circuit passed as the ``self`` argument when invoking subroutines.
    void = cirq.Circuit()
    # Optional user-provided qubits, forwarded to every frame this emitter creates.
    qubits: Sequence[cirq.Qid] | None = None
    # Sub-circuits of emitted ``func.Invoke`` calls, keyed by a hash of
    # (callee, inputs); populated by the ``func`` method table.
    _cached_circuit_operations: dict[int, cirq.CircuitOperation] = field(
        init=False, default_factory=dict
    )

    def initialize(self) -> Self:
        """Delegate to the base-class ``initialize``."""
        return super().initialize()

    def initialize_frame(
        self, code: ir.Statement, *, has_parent_access: bool = False
    ) -> EmitCirqFrame:
        """Create a frame for ``code`` that shares this emitter's qubit list."""
        new_frame = EmitCirqFrame(
            code, has_parent_access=has_parent_access, qubits=self.qubits
        )
        return new_frame

    def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
        """Emit ``method`` by running its callable statement with ``args``."""
        callable_stmt = method.code
        return self.run_callable(callable_stmt, args)

    def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
        """Evaluate every statement of ``block`` and return the frame's circuit.

        Only tuple results are bound to the statement's result values; other
        return values are ignored.
        """
        for statement in block.stmts:
            outputs = self.eval_stmt(frame, statement)
            if not isinstance(outputs, tuple):
                continue
            frame.set_values(statement.results, outputs)

        return frame.circuit
55
+
56
+
57
@func.dialect.register(key="emit.cirq")
class FuncEmit(MethodTable):
    """Emit ``func`` dialect statements (function bodies and calls) into cirq."""

    @impl(func.Function)
    def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
        # Emit the function body into the current frame's circuit.
        emit.run_ssacfg_region(frame, stmt.body, ())
        return (frame.circuit,)

    @impl(func.Invoke)
    def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
        """Emit a call to a sub-kernel as a ``cirq.CircuitOperation``.

        The callee is emitted into a sub-frame that continues qubit allocation
        where the caller left off. Any qubits the callee allocates are then
        reserved in the caller too. If the callee's block contains a
        ``func.Return``, its value becomes the result of the call.

        Raises:
            EmitError: If the callee's callable region has more than one block.
        """
        stmt_hash = hash((stmt.callee, stmt.inputs))
        cached_circuit_op = emit._cached_circuit_operations.get(stmt_hash)
        if cached_circuit_op is not None:
            # Only calls that produced a ``None`` result and allocated no qubits
            # are ever cached (see below). A hit therefore has a ``None`` result
            # and reserves no qubits.
            # NOTE(review): the key uses SSA values, not their runtime values; this
            # assumes each input has a single runtime value per emission.
            frame.entries[stmt.result] = None
            frame.circuit.append(cached_circuit_op)
            return ()

        region = stmt.callee.callable_region
        if len(region.blocks) > 1:
            raise EmitError(
                "Subroutine with more than a single block encountered. This is not supported!"
            )

        with emit.new_frame(stmt.callee.code, has_parent_access=True) as sub_frame:
            sub_frame.qubit_index = frame.qubit_index
            sub_frame.qubits = frame.qubits

            # "self" is just an empty circuit; the remaining arguments are the
            # caller's values for the call's inputs.
            method_self = emit.void
            args = [frame.get(arg_) for arg_ in stmt.inputs]
            emit.run_ssacfg_region(sub_frame, region, args=(method_self, *args))
            sub_circuit = sub_frame.circuit

            # Multiple returns via control flow are not supported, so the first
            # func.Return of the single block provides the call's value.
            return_stmt = next(
                (
                    body_stmt
                    for body_stmt in region.blocks[0].stmts
                    if isinstance(body_stmt, func.Return)
                ),
                None,
            )
            return_value = None
            if return_stmt is not None:
                return_value = sub_frame.get(return_stmt.value)
                frame.entries[stmt.result] = return_value

            # Reserve the qubits allocated by the callee so later ``qubit.new``
            # statements in the caller do not hand out the same qubits again.
            allocated_qubits = sub_frame.qubit_index != frame.qubit_index
            frame.qubit_index = sub_frame.qubit_index

            circuit_op = cirq.CircuitOperation(
                sub_circuit.freeze(), use_repetition_ids=False
            )

        # A cache hit skips re-emitting the callee, so only cache calls whose
        # result is None and that allocated no qubits. Any other call has to be
        # re-emitted to obtain its return value and fresh qubits.
        if return_value is None and not allocated_qubits:
            emit._cached_circuit_operations[stmt_hash] = circuit_op
        frame.circuit.append(circuit_op)
        return ()
@@ -0,0 +1,125 @@
1
+ import math
2
+
3
+ import cirq
4
+ import numpy as np
5
+ from kirin.interp import MethodTable, impl
6
+
7
+ from ... import op
8
+ from .runtime import (
9
+ SnRuntime,
10
+ SpRuntime,
11
+ U3Runtime,
12
+ KronRuntime,
13
+ MultRuntime,
14
+ ScaleRuntime,
15
+ AdjointRuntime,
16
+ ControlRuntime,
17
+ UnitaryRuntime,
18
+ HermitianRuntime,
19
+ ProjectorRuntime,
20
+ OperatorRuntimeABC,
21
+ PauliStringRuntime,
22
+ )
23
+ from .emit_circuit import EmitCirq, EmitCirqFrame
24
+
25
+
26
@op.dialect.register(key="emit.cirq")
class EmitCirqOpMethods(MethodTable):
    """Emit squin ``op`` statements as operator runtimes for the cirq emitter.

    Each implementation returns a one-element tuple holding an
    ``OperatorRuntimeABC``. The cirq operations themselves are only produced
    once the runtime is applied to concrete qubits (by the qubit method table).
    """

    @impl(op.stmts.X)
    @impl(op.stmts.Y)
    @impl(op.stmts.Z)
    @impl(op.stmts.H)
    def hermitian(
        self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
    ):
        # The statement name, upper-cased, is the name of the matching cirq gate.
        cirq_op = getattr(cirq, stmt.name.upper())
        return (HermitianRuntime(cirq_op),)

    @impl(op.stmts.S)
    @impl(op.stmts.T)
    def unitary(
        self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
    ):
        cirq_op = getattr(cirq, stmt.name.upper())
        return (UnitaryRuntime(cirq_op),)

    @impl(op.stmts.P0)
    @impl(op.stmts.P1)
    def projector(
        self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.P0 | op.stmts.P1
    ):
        # target_state is True for P1 and False for P0.
        return (ProjectorRuntime(isinstance(stmt, op.stmts.P1)),)

    @impl(op.stmts.Sn)
    def sn(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sn):
        return (SnRuntime(),)

    @impl(op.stmts.Sp)
    def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp):
        return (SpRuntime(),)

    @impl(op.stmts.Identity)
    def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity):
        runtime = HermitianRuntime(cirq.IdentityGate(num_qubits=stmt.sites))
        return (runtime,)

    @impl(op.stmts.Control)
    def control(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Control):
        target: OperatorRuntimeABC = frame.get(stmt.op)
        return (ControlRuntime(target, stmt.n_controls),)

    @impl(op.stmts.Kron)
    def kron(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Kron):
        lhs = frame.get(stmt.lhs)
        rhs = frame.get(stmt.rhs)
        return (KronRuntime(lhs, rhs),)

    @impl(op.stmts.Mult)
    def mult(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Mult):
        lhs = frame.get(stmt.lhs)
        rhs = frame.get(stmt.rhs)
        return (MultRuntime(lhs, rhs),)

    @impl(op.stmts.Adjoint)
    def adjoint(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Adjoint):
        op_ = frame.get(stmt.op)
        return (AdjointRuntime(op_),)

    @impl(op.stmts.Scale)
    def scale(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Scale):
        op_ = frame.get(stmt.op)
        factor = frame.get(stmt.factor)
        return (ScaleRuntime(operator=op_, factor=factor),)

    @impl(op.stmts.U3)
    def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.U3):
        theta = frame.get(stmt.theta)
        phi = frame.get(stmt.phi)
        lam = frame.get(stmt.lam)
        return (U3Runtime(theta=theta, phi=phi, lam=lam),)

    @impl(op.stmts.PhaseOp)
    def phaseop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PhaseOp):
        theta = frame.get(stmt.theta)
        # Global phase: exp(i * theta) times the single-qubit identity.
        op_ = HermitianRuntime(cirq.IdentityGate(num_qubits=1))
        return (ScaleRuntime(operator=op_, factor=np.exp(1j * theta)),)

    @impl(op.stmts.ShiftOp)
    def shiftop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ShiftOp):
        theta = frame.get(stmt.theta)

        # ShiftOp(theta) is the phase shift diag(1, exp(i * theta)).
        # NOTE(review): taken from squin's ShiftOp definition — confirm.
        # U3(pi, theta, 0) would be [[0, -1], [exp(i * theta), 0]], which is not
        # diagonal. cirq.ZPowGate(exponent=t) is exactly diag(1, exp(i * pi * t)),
        # so the phase also stays correct when the operator is controlled, and
        # UnitaryRuntime handles the adjoint by inverting the gate.
        return (UnitaryRuntime(cirq.ZPowGate(exponent=theta / math.pi)),)

    @impl(op.stmts.Reset)
    def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset):
        return (HermitianRuntime(cirq.ResetChannel()),)

    @impl(op.stmts.PauliString)
    def pauli_string(
        self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString
    ):
        return (PauliStringRuntime(stmt.string),)
@@ -0,0 +1,60 @@
1
+ import cirq
2
+ from kirin.interp import MethodTable, impl
3
+
4
+ from ... import qubit
5
+ from .op import OperatorRuntimeABC
6
+ from .emit_circuit import EmitCirq, EmitCirqFrame
7
+
8
+
9
@qubit.dialect.register(key="emit.cirq")
class EmitCirqQubitMethods(MethodTable):
    """Emit squin ``qubit`` statements (allocation, gates, measurement) into cirq."""

    @impl(qubit.New)
    def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
        """Allocate the next ``n_qubits`` qubits of the frame.

        Qubits come from ``frame.qubits`` when it was provided; otherwise
        ``cirq.LineQubit``s are created at the matching positions.
        """
        count = frame.get(stmt.n_qubits)
        start = frame.qubit_index
        positions = range(start, start + count)

        if frame.qubits is None:
            allocated = [cirq.LineQubit(position) for position in positions]
        else:
            allocated = [frame.qubits[position] for position in positions]

        frame.qubit_index = start + count
        return (allocated,)

    @impl(qubit.Apply)
    def apply(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Apply):
        """Apply an operator runtime to the given qubits, one operation at a time."""
        runtime: OperatorRuntimeABC = frame.get(stmt.operator)
        targets = frame.get(stmt.qubits)
        for operation in runtime.apply(targets):
            frame.circuit.append(operation)
        return ()

    @impl(qubit.Broadcast)
    def broadcast(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Broadcast):
        """Apply an operator to each qubit individually, collected into one Moment."""
        runtime = frame.get(stmt.operator)
        targets = frame.get(stmt.qubits)

        moment_ops = [
            operation for target in targets for operation in runtime.apply([target])
        ]

        frame.circuit.append(cirq.Moment(moment_ops))
        return ()

    @impl(qubit.MeasureQubit)
    def measure_qubit(
        self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubit
    ):
        """Append a measurement of a single qubit."""
        target = frame.get(stmt.qubit)
        frame.circuit.append(cirq.measure(target))
        return ()

    @impl(qubit.MeasureQubitList)
    def measure_qubit_list(
        self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubitList
    ):
        """Append a single measurement over a list of qubits."""
        targets = frame.get(stmt.qubits)
        frame.circuit.append(cirq.measure(targets))
        return ()
@@ -0,0 +1,234 @@
1
+ import math
2
+ from typing import Sequence
3
+ from numbers import Number
4
+ from dataclasses import dataclass
5
+
6
+ import cirq
7
+
8
+
9
@dataclass
class OperatorRuntimeABC:
    """Base class for the emit-time representation of a squin operator.

    Subclasses implement ``num_qubits`` and ``unsafe_apply``. ``apply`` checks
    the qubit count before delegating to ``unsafe_apply``, which returns the
    cirq operations to append, in order.
    """

    def num_qubits(self) -> int:
        """Number of qubits the operator acts on (overridden by subclasses)."""

    def check_qubits(self, qubits: Sequence[cirq.Qid]):
        """Assert that ``qubits`` has exactly ``num_qubits()`` entries."""
        assert len(qubits) == self.num_qubits()

    def apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        """Validate ``qubits`` and return the operations for this operator."""
        self.check_qubits(qubits)
        operations = self.unsafe_apply(qubits, adjoint=adjoint)
        return operations

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        """Return the operations without validating ``qubits`` (overridden by subclasses)."""
25
+
26
+
27
@dataclass
class UnsafeOperatorRuntimeABC(OperatorRuntimeABC):
    """Operator runtime that skips the qubit-count check in ``apply``."""

    def check_qubits(self, qubits: Sequence[cirq.Qid]):
        # Deliberately a no-op: a mismatched qubit count is left for cirq to reject.
        return None
32
+
33
+
34
@dataclass
class BasicOpRuntime(UnsafeOperatorRuntimeABC):
    """Operator runtime that wraps a single ``cirq.Gate``."""

    gate: cirq.Gate

    def num_qubits(self) -> int:
        """Number of qubits of the wrapped gate."""
        wrapped = self.gate
        return wrapped.num_qubits()
40
+
41
+
42
@dataclass
class UnitaryRuntime(BasicOpRuntime):
    """Gate runtime whose adjoint is obtained by inverting the gate operation."""

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # Raising the operation to the power -1 gives its inverse, i.e. its adjoint.
        power = -1 if adjoint else 1
        operation = self.gate(*qubits)
        return [operation**power]
49
+
50
+
51
@dataclass
class HermitianRuntime(BasicOpRuntime):
    """Gate runtime for gates that are their own adjoint."""

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # ``adjoint`` is intentionally ignored: the same gate is emitted either way.
        operation = self.gate(*qubits)
        return [operation]
57
+
58
+
59
@dataclass
class ProjectorRuntime(UnsafeOperatorRuntimeABC):
    """Single-qubit projector onto |0> (``target_state`` False) or |1> (True)."""

    target_state: bool

    def num_qubits(self) -> int:
        return 1

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # |0><0| = (1 + Z) / 2 and |1><1| = (1 - Z) / 2. Projectors are Hermitian,
        # so ``adjoint`` is ignored. This doesn't scale well, but works.
        # NOTE(review): the result is a linear combination of Pauli terms rather
        # than a single cirq.Operation; confirm that appending it to a circuit
        # behaves as intended.
        sign = -1 if self.target_state else 1
        projector = (1 + sign * cirq.Z(*qubits)) / 2
        return [projector]
73
+
74
+
75
@dataclass
class SpRuntime(UnsafeOperatorRuntimeABC):
    """Single-qubit operator (X - iY) / 2; its adjoint is ``SnRuntime``."""

    def num_qubits(self) -> int:
        return 1

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        if not adjoint:
            # NOTE(review): this is a sum of Pauli terms, not a single cirq.Operation.
            return [(cirq.X(*qubits) - 1j * cirq.Y(*qubits)) / 2]  # type: ignore -- we're not dealing with cirq's type issues

        # ((X - iY) / 2)^dagger == (X + iY) / 2
        return SnRuntime().unsafe_apply(qubits, adjoint=False)
87
+
88
+
89
@dataclass
class SnRuntime(UnsafeOperatorRuntimeABC):
    """Single-qubit operator (X + iY) / 2; its adjoint is ``SpRuntime``."""

    def num_qubits(self) -> int:
        return 1

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        if not adjoint:
            # NOTE(review): this is a sum of Pauli terms, not a single cirq.Operation.
            return [(cirq.X(*qubits) + 1j * cirq.Y(*qubits)) / 2]  # type: ignore -- we're not dealing with cirq's type issues

        # ((X + iY) / 2)^dagger == (X - iY) / 2
        return SpRuntime().unsafe_apply(qubits, adjoint=False)
101
+
102
+
103
@dataclass
class MultRuntime(OperatorRuntimeABC):
    """Operator product ``lhs * rhs`` acting on a shared set of qubits."""

    lhs: OperatorRuntimeABC
    rhs: OperatorRuntimeABC

    def num_qubits(self) -> int:
        """Shared qubit count; both factors must agree."""
        lhs_count = self.lhs.num_qubits()
        assert lhs_count == self.rhs.num_qubits()
        return lhs_count

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        rhs_ops = self.rhs.unsafe_apply(qubits, adjoint=adjoint)
        lhs_ops = self.lhs.unsafe_apply(qubits, adjoint=adjoint)

        # lhs * rhs acts with rhs first. (lhs * rhs)^dagger == rhs^dagger lhs^dagger,
        # which acts with lhs^dagger first.
        first, second = (lhs_ops, rhs_ops) if adjoint else (rhs_ops, lhs_ops)
        return first + second
123
+
124
+
125
@dataclass
class KronRuntime(OperatorRuntimeABC):
    """Tensor product ``lhs (x) rhs``: lhs takes the leading qubits, rhs the rest."""

    lhs: OperatorRuntimeABC
    rhs: OperatorRuntimeABC

    def num_qubits(self) -> int:
        return self.lhs.num_qubits() + self.rhs.num_qubits()

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        split = self.lhs.num_qubits()
        lhs_qubits, rhs_qubits = qubits[:split], qubits[split:]

        operations = self.lhs.unsafe_apply(lhs_qubits, adjoint=adjoint)
        operations.extend(self.rhs.unsafe_apply(rhs_qubits, adjoint=adjoint))
        return operations
140
+
141
+
142
@dataclass
class ControlRuntime(OperatorRuntimeABC):
    """Operator controlled on ``n_controls`` qubits.

    The first ``n_controls`` qubits are the controls. The remaining
    ``operator.num_qubits()`` qubits are the targets of the wrapped operator.
    """

    operator: OperatorRuntimeABC
    n_controls: int

    def num_qubits(self) -> int:
        return self.n_controls + self.operator.num_qubits()

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # Split at n_controls, not at len(qubits) - n_controls. The latter only
        # coincides when there are as many controls as targets; e.g. for a
        # 2-control, 1-target gate it would hand 2 qubits to the 1-qubit target.
        controls = qubits[: self.n_controls]
        targets = qubits[self.n_controls :]
        target_ops = self.operator.unsafe_apply(targets, adjoint=adjoint)
        return [target_op.controlled_by(*controls) for target_op in target_ops]
157
+
158
+
159
@dataclass
class AdjointRuntime(OperatorRuntimeABC):
    """Adjoint (conjugate transpose) of the wrapped operator."""

    operator: OperatorRuntimeABC

    def num_qubits(self) -> int:
        return self.operator.num_qubits()

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # Flip the incoming flag so nested adjoints cancel,
        # e.g. adjoint(adjoint(op)) == op.
        return self.operator.unsafe_apply(qubits, adjoint=not adjoint)
172
+
173
+
174
@dataclass
class U3Runtime(UnsafeOperatorRuntimeABC):
    """Single-qubit ``U3(theta, phi, lam)`` rotation.

    Emitted, in time order, as Rz(lam), Rx(pi/2), Rz(theta), Rx(-pi/2), Rz(phi).
    NOTE(review): cirq's Rz/Rx carry a global phase, so this sequence presumably
    matches U3 only up to a global phase. That phase becomes relative when the
    operator is wrapped in ``ControlRuntime`` — confirm this is acceptable.
    """

    theta: float
    phi: float
    lam: float

    def num_qubits(self) -> int:
        return 1

    def angles(self, adjoint: bool) -> tuple[float, float, float]:
        """Return (theta, phi, lam), adjusted for the adjoint when requested."""
        if not adjoint:
            return self.theta, self.phi, self.lam
        # adjoint(U(theta, phi, lam)) == U(-theta, -lam, -phi)
        return -self.theta, -self.lam, -self.phi

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        theta, phi, lam = self.angles(adjoint=adjoint)
        half_pi = math.pi / 2

        rotations = (
            cirq.Rz(rads=lam),
            cirq.Rx(rads=half_pi),
            cirq.Rz(rads=theta),
            cirq.Rx(rads=-half_pi),
            cirq.Rz(rads=phi),
        )
        return [rotation(*qubits) for rotation in rotations]
204
+
205
+
206
@dataclass
class ScaleRuntime(OperatorRuntimeABC):
    """Operator multiplied by a scalar ``factor``.

    The factor is applied to the first emitted operation only. For the adjoint,
    the factor is complex-conjugated, since (c * A)^dagger == conj(c) * A^dagger.
    """

    factor: Number
    operator: OperatorRuntimeABC

    def num_qubits(self) -> int:
        return self.operator.num_qubits()

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        cirq_ops = self.operator.unsafe_apply(qubits=qubits, adjoint=adjoint)
        # Without conjugation, e.g. adjoint(PhaseOp(theta)) would keep exp(+i*theta)
        # instead of becoming exp(-i*theta).
        factor = self.factor.conjugate() if adjoint else self.factor  # type: ignore[attr-defined]
        return [factor * cirq_ops[0]] + cirq_ops[1:]  # type: ignore
219
+
220
+
221
@dataclass
class PauliStringRuntime(OperatorRuntimeABC):
    """Pauli string given as one label per qubit (e.g. ``"XZY"``)."""

    string: str

    def num_qubits(self) -> int:
        return len(self.string)

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # Labels are paired with qubits positionally; ``adjoint`` is not used.
        label_by_qubit = dict(zip(qubits, self.string))
        return [cirq.PauliString(label_by_qubit)]