bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/__init__.py +4 -1
  7. bloqade/analysis/measure_id/analysis.py +29 -20
  8. bloqade/analysis/measure_id/impls.py +72 -31
  9. bloqade/annotate/__init__.py +6 -0
  10. bloqade/annotate/_dialect.py +3 -0
  11. bloqade/annotate/_interface.py +22 -0
  12. bloqade/annotate/stmts.py +29 -0
  13. bloqade/annotate/types.py +13 -0
  14. bloqade/cirq_utils/__init__.py +4 -2
  15. bloqade/cirq_utils/emit/__init__.py +3 -0
  16. bloqade/cirq_utils/emit/base.py +246 -0
  17. bloqade/cirq_utils/emit/gate.py +104 -0
  18. bloqade/cirq_utils/emit/noise.py +90 -0
  19. bloqade/cirq_utils/emit/qubit.py +35 -0
  20. bloqade/cirq_utils/lowering.py +660 -0
  21. bloqade/cirq_utils/noise/__init__.py +0 -2
  22. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  23. bloqade/cirq_utils/noise/model.py +151 -191
  24. bloqade/cirq_utils/noise/transform.py +2 -2
  25. bloqade/cirq_utils/parallelize.py +9 -6
  26. bloqade/gemini/__init__.py +1 -0
  27. bloqade/gemini/analysis/__init__.py +3 -0
  28. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  29. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  30. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  31. bloqade/gemini/groups.py +67 -0
  32. bloqade/native/__init__.py +23 -0
  33. bloqade/native/_prelude.py +45 -0
  34. bloqade/native/dialects/__init__.py +0 -0
  35. bloqade/native/dialects/gate/__init__.py +2 -0
  36. bloqade/native/dialects/gate/_dialect.py +3 -0
  37. bloqade/native/dialects/gate/_interface.py +32 -0
  38. bloqade/native/dialects/gate/stmts.py +31 -0
  39. bloqade/native/stdlib/__init__.py +0 -0
  40. bloqade/native/stdlib/broadcast.py +246 -0
  41. bloqade/native/stdlib/simple.py +220 -0
  42. bloqade/native/upstream/__init__.py +4 -0
  43. bloqade/native/upstream/squin2native.py +79 -0
  44. bloqade/pyqrack/__init__.py +2 -2
  45. bloqade/pyqrack/base.py +7 -1
  46. bloqade/pyqrack/device.py +190 -4
  47. bloqade/pyqrack/native.py +49 -0
  48. bloqade/pyqrack/reg.py +6 -6
  49. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  50. bloqade/pyqrack/squin/gate/gate.py +136 -0
  51. bloqade/pyqrack/squin/noise/native.py +120 -54
  52. bloqade/pyqrack/squin/qubit.py +39 -36
  53. bloqade/pyqrack/target.py +5 -4
  54. bloqade/pyqrack/task.py +114 -7
  55. bloqade/qasm2/_qasm_loading.py +3 -3
  56. bloqade/qasm2/dialects/core/address.py +21 -12
  57. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  58. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  59. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  60. bloqade/qasm2/dialects/noise/model.py +2 -1
  61. bloqade/qasm2/emit/base.py +16 -11
  62. bloqade/qasm2/emit/gate.py +11 -8
  63. bloqade/qasm2/emit/main.py +103 -3
  64. bloqade/qasm2/emit/target.py +9 -5
  65. bloqade/qasm2/groups.py +3 -2
  66. bloqade/qasm2/parse/lowering.py +0 -1
  67. bloqade/qasm2/passes/fold.py +14 -73
  68. bloqade/qasm2/passes/glob.py +2 -2
  69. bloqade/qasm2/passes/noise.py +1 -1
  70. bloqade/qasm2/passes/parallel.py +7 -5
  71. bloqade/qasm2/rewrite/__init__.py +0 -1
  72. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  73. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  74. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  75. bloqade/qasm2/rewrite/register.py +2 -2
  76. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  77. bloqade/qbraid/lowering.py +1 -0
  78. bloqade/qbraid/schema.py +2 -2
  79. bloqade/qubit/__init__.py +12 -0
  80. bloqade/qubit/_dialect.py +3 -0
  81. bloqade/qubit/_interface.py +49 -0
  82. bloqade/qubit/_prelude.py +45 -0
  83. bloqade/qubit/analysis/__init__.py +1 -0
  84. bloqade/qubit/analysis/address_impl.py +40 -0
  85. bloqade/qubit/stdlib/__init__.py +2 -0
  86. bloqade/qubit/stdlib/_new.py +34 -0
  87. bloqade/qubit/stdlib/broadcast.py +62 -0
  88. bloqade/qubit/stdlib/simple.py +59 -0
  89. bloqade/qubit/stmts.py +60 -0
  90. bloqade/rewrite/passes/__init__.py +6 -0
  91. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  92. bloqade/rewrite/passes/callgraph.py +116 -0
  93. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  94. bloqade/rewrite/rules/split_ifs.py +18 -1
  95. bloqade/squin/__init__.py +47 -14
  96. bloqade/squin/analysis/__init__.py +0 -1
  97. bloqade/squin/analysis/schedule.py +10 -11
  98. bloqade/squin/gate/__init__.py +2 -0
  99. bloqade/squin/gate/_dialect.py +3 -0
  100. bloqade/squin/gate/_interface.py +98 -0
  101. bloqade/squin/gate/stmts.py +125 -0
  102. bloqade/squin/groups.py +5 -22
  103. bloqade/squin/noise/__init__.py +1 -10
  104. bloqade/squin/noise/_dialect.py +1 -1
  105. bloqade/squin/noise/_interface.py +45 -0
  106. bloqade/squin/noise/stmts.py +66 -28
  107. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  108. bloqade/squin/rewrite/__init__.py +0 -2
  109. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  110. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  111. bloqade/squin/stdlib/__init__.py +0 -0
  112. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  113. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  114. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  115. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  116. bloqade/squin/stdlib/simple/__init__.py +33 -0
  117. bloqade/squin/stdlib/simple/gate.py +242 -0
  118. bloqade/squin/stdlib/simple/noise.py +126 -0
  119. bloqade/stim/__init__.py +1 -0
  120. bloqade/stim/_wrappers.py +6 -0
  121. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  122. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  123. bloqade/stim/dialects/gate/emit.py +9 -10
  124. bloqade/stim/dialects/noise/emit.py +17 -13
  125. bloqade/stim/dialects/noise/stmts.py +5 -3
  126. bloqade/stim/emit/__init__.py +1 -0
  127. bloqade/stim/emit/impls.py +16 -0
  128. bloqade/stim/emit/stim_str.py +48 -31
  129. bloqade/stim/groups.py +12 -2
  130. bloqade/stim/parse/lowering.py +14 -17
  131. bloqade/stim/passes/__init__.py +3 -1
  132. bloqade/stim/passes/flatten.py +26 -0
  133. bloqade/stim/passes/simplify_ifs.py +16 -2
  134. bloqade/stim/passes/squin_to_stim.py +18 -60
  135. bloqade/stim/rewrite/__init__.py +3 -4
  136. bloqade/stim/rewrite/get_record_util.py +24 -0
  137. bloqade/stim/rewrite/ifs_to_stim.py +29 -31
  138. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  139. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  140. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  141. bloqade/stim/rewrite/squin_measure.py +11 -79
  142. bloqade/stim/rewrite/squin_noise.py +134 -108
  143. bloqade/stim/rewrite/util.py +5 -192
  144. bloqade/test_utils.py +1 -1
  145. bloqade/types.py +10 -0
  146. bloqade/validation/__init__.py +2 -0
  147. bloqade/validation/analysis/__init__.py +5 -0
  148. bloqade/validation/analysis/analysis.py +41 -0
  149. bloqade/validation/analysis/lattice.py +58 -0
  150. bloqade/validation/kernel_validation.py +77 -0
  151. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  152. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  153. bloqade/pyqrack/squin/op.py +0 -166
  154. bloqade/pyqrack/squin/runtime.py +0 -535
  155. bloqade/pyqrack/squin/wire.py +0 -51
  156. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  157. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  158. bloqade/squin/_typeinfer.py +0 -20
  159. bloqade/squin/analysis/address_impl.py +0 -71
  160. bloqade/squin/analysis/nsites/__init__.py +0 -9
  161. bloqade/squin/analysis/nsites/analysis.py +0 -50
  162. bloqade/squin/analysis/nsites/impls.py +0 -92
  163. bloqade/squin/analysis/nsites/lattice.py +0 -49
  164. bloqade/squin/cirq/__init__.py +0 -265
  165. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  166. bloqade/squin/cirq/emit/noise.py +0 -49
  167. bloqade/squin/cirq/emit/op.py +0 -125
  168. bloqade/squin/cirq/emit/qubit.py +0 -60
  169. bloqade/squin/cirq/emit/runtime.py +0 -242
  170. bloqade/squin/cirq/lowering.py +0 -440
  171. bloqade/squin/lowering.py +0 -54
  172. bloqade/squin/noise/_wrapper.py +0 -40
  173. bloqade/squin/noise/rewrite.py +0 -111
  174. bloqade/squin/op/__init__.py +0 -41
  175. bloqade/squin/op/_dialect.py +0 -3
  176. bloqade/squin/op/_wrapper.py +0 -121
  177. bloqade/squin/op/number.py +0 -5
  178. bloqade/squin/op/rewrite.py +0 -46
  179. bloqade/squin/op/stdlib.py +0 -62
  180. bloqade/squin/op/stmts.py +0 -276
  181. bloqade/squin/op/traits.py +0 -43
  182. bloqade/squin/op/types.py +0 -26
  183. bloqade/squin/qubit.py +0 -184
  184. bloqade/squin/rewrite/canonicalize.py +0 -60
  185. bloqade/squin/rewrite/desugar.py +0 -124
  186. bloqade/squin/types.py +0 -8
  187. bloqade/squin/wire.py +0 -201
  188. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  189. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  190. bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
  191. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  192. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,125 +0,0 @@
1
- import math
2
-
3
- import cirq
4
- import numpy as np
5
- from kirin.interp import MethodTable, impl
6
-
7
- from ... import op
8
- from .runtime import (
9
- SnRuntime,
10
- SpRuntime,
11
- U3Runtime,
12
- KronRuntime,
13
- MultRuntime,
14
- ScaleRuntime,
15
- AdjointRuntime,
16
- ControlRuntime,
17
- UnitaryRuntime,
18
- HermitianRuntime,
19
- ProjectorRuntime,
20
- OperatorRuntimeABC,
21
- PauliStringRuntime,
22
- )
23
- from .emit_circuit import EmitCirq, EmitCirqFrame
24
-
25
-
26
- @op.dialect.register(key="emit.cirq")
27
- class EmitCirqOpMethods(MethodTable):
28
-
29
- @impl(op.stmts.X)
30
- @impl(op.stmts.Y)
31
- @impl(op.stmts.Z)
32
- @impl(op.stmts.H)
33
- def hermitian(
34
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
35
- ):
36
- cirq_op = getattr(cirq, stmt.name.upper())
37
- return (HermitianRuntime(cirq_op),)
38
-
39
- @impl(op.stmts.S)
40
- @impl(op.stmts.T)
41
- def unitary(
42
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
43
- ):
44
- cirq_op = getattr(cirq, stmt.name.upper())
45
- return (UnitaryRuntime(cirq_op),)
46
-
47
- @impl(op.stmts.P0)
48
- @impl(op.stmts.P1)
49
- def projector(
50
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.P0 | op.stmts.P1
51
- ):
52
- return (ProjectorRuntime(isinstance(stmt, op.stmts.P1)),)
53
-
54
- @impl(op.stmts.Sn)
55
- def sn(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sn):
56
- return (SnRuntime(),)
57
-
58
- @impl(op.stmts.Sp)
59
- def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp):
60
- return (SpRuntime(),)
61
-
62
- @impl(op.stmts.Identity)
63
- def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity):
64
- op = HermitianRuntime(cirq.IdentityGate(num_qubits=stmt.sites))
65
- return (op,)
66
-
67
- @impl(op.stmts.Control)
68
- def control(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Control):
69
- op: OperatorRuntimeABC = frame.get(stmt.op)
70
- return (ControlRuntime(op, stmt.n_controls),)
71
-
72
- @impl(op.stmts.Kron)
73
- def kron(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Kron):
74
- lhs = frame.get(stmt.lhs)
75
- rhs = frame.get(stmt.rhs)
76
- op = KronRuntime(lhs, rhs)
77
- return (op,)
78
-
79
- @impl(op.stmts.Mult)
80
- def mult(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Mult):
81
- lhs = frame.get(stmt.lhs)
82
- rhs = frame.get(stmt.rhs)
83
- op = MultRuntime(lhs, rhs)
84
- return (op,)
85
-
86
- @impl(op.stmts.Adjoint)
87
- def adjoint(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Adjoint):
88
- op_ = frame.get(stmt.op)
89
- return (AdjointRuntime(op_),)
90
-
91
- @impl(op.stmts.Scale)
92
- def scale(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Scale):
93
- op_ = frame.get(stmt.op)
94
- factor = frame.get(stmt.factor)
95
- return (ScaleRuntime(operator=op_, factor=factor),)
96
-
97
- @impl(op.stmts.U3)
98
- def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.U3):
99
- theta = frame.get(stmt.theta)
100
- phi = frame.get(stmt.phi)
101
- lam = frame.get(stmt.lam)
102
- return (U3Runtime(theta=theta, phi=phi, lam=lam),)
103
-
104
- @impl(op.stmts.PhaseOp)
105
- def phaseop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PhaseOp):
106
- theta = frame.get(stmt.theta)
107
- op_ = HermitianRuntime(cirq.IdentityGate(num_qubits=1))
108
- return (ScaleRuntime(operator=op_, factor=np.exp(1j * theta)),)
109
-
110
- @impl(op.stmts.ShiftOp)
111
- def shiftop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ShiftOp):
112
- theta = frame.get(stmt.theta)
113
-
114
- # NOTE: ShiftOp(theta) == U3(pi, theta, 0)
115
- return (U3Runtime(math.pi, theta, 0),)
116
-
117
- @impl(op.stmts.Reset)
118
- def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset):
119
- return (HermitianRuntime(cirq.ResetChannel()),)
120
-
121
- @impl(op.stmts.PauliString)
122
- def pauli_string(
123
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString
124
- ):
125
- return (PauliStringRuntime(stmt.string),)
@@ -1,60 +0,0 @@
1
- import cirq
2
- from kirin.interp import MethodTable, impl
3
-
4
- from ... import qubit
5
- from .op import OperatorRuntimeABC
6
- from .emit_circuit import EmitCirq, EmitCirqFrame
7
-
8
-
9
- @qubit.dialect.register(key="emit.cirq")
10
- class EmitCirqQubitMethods(MethodTable):
11
- @impl(qubit.New)
12
- def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
13
- n_qubits = frame.get(stmt.n_qubits)
14
-
15
- if frame.qubits is not None:
16
- cirq_qubits = [frame.qubits[i + frame.qubit_index] for i in range(n_qubits)]
17
- else:
18
- cirq_qubits = [
19
- cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits)
20
- ]
21
-
22
- frame.qubit_index += n_qubits
23
- return (cirq_qubits,)
24
-
25
- @impl(qubit.Apply)
26
- def apply(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Apply):
27
- op: OperatorRuntimeABC = frame.get(stmt.operator)
28
- qbits = frame.get(stmt.qubits)
29
- operations = op.apply(qbits)
30
- for operation in operations:
31
- frame.circuit.append(operation)
32
- return ()
33
-
34
- @impl(qubit.Broadcast)
35
- def broadcast(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Broadcast):
36
- op = frame.get(stmt.operator)
37
- qbits = frame.get(stmt.qubits)
38
-
39
- cirq_ops = []
40
- for qbit in qbits:
41
- cirq_ops.extend(op.apply([qbit]))
42
-
43
- frame.circuit.append(cirq.Moment(cirq_ops))
44
- return ()
45
-
46
- @impl(qubit.MeasureQubit)
47
- def measure_qubit(
48
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubit
49
- ):
50
- qbit = frame.get(stmt.qubit)
51
- frame.circuit.append(cirq.measure(qbit))
52
- return ()
53
-
54
- @impl(qubit.MeasureQubitList)
55
- def measure_qubit_list(
56
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubitList
57
- ):
58
- qbits = frame.get(stmt.qubits)
59
- frame.circuit.append(cirq.measure(qbits))
60
- return ()
@@ -1,242 +0,0 @@
1
- import math
2
- from typing import Sequence
3
- from numbers import Number
4
- from dataclasses import dataclass
5
-
6
- import cirq
7
-
8
-
9
- @dataclass
10
- class OperatorRuntimeABC:
11
- def num_qubits(self) -> int: ...
12
-
13
- def check_qubits(self, qubits: Sequence[cirq.Qid]):
14
- assert self.num_qubits() == len(qubits)
15
-
16
- def apply(
17
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
18
- ) -> list[cirq.Operation]:
19
- self.check_qubits(qubits)
20
- return self.unsafe_apply(qubits, adjoint=adjoint)
21
-
22
- def unsafe_apply(
23
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
24
- ) -> list[cirq.Operation]:
25
- raise NotImplementedError(
26
- f"Apply method needs to be implemented in {self.__class__.__name__}"
27
- )
28
-
29
-
30
- @dataclass
31
- class UnsafeOperatorRuntimeABC(OperatorRuntimeABC):
32
- def check_qubits(self, qubits: Sequence[cirq.Qid]):
33
- # NOTE: let's let cirq check this one
34
- pass
35
-
36
-
37
- @dataclass
38
- class BasicOpRuntime(UnsafeOperatorRuntimeABC):
39
- gate: cirq.Gate
40
-
41
- def num_qubits(self) -> int:
42
- return self.gate.num_qubits()
43
-
44
- def unsafe_apply(
45
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
46
- ) -> list[cirq.Operation]:
47
- return [self.gate(*qubits)]
48
-
49
-
50
- @dataclass
51
- class UnitaryRuntime(BasicOpRuntime):
52
- def unsafe_apply(
53
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
54
- ) -> list[cirq.Operation]:
55
- exponent = (-1) ** adjoint
56
- return [self.gate(*qubits) ** exponent]
57
-
58
-
59
- @dataclass
60
- class HermitianRuntime(BasicOpRuntime):
61
- def unsafe_apply(
62
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
63
- ) -> list[cirq.Operation]:
64
- return [self.gate(*qubits)]
65
-
66
-
67
- @dataclass
68
- class ProjectorRuntime(UnsafeOperatorRuntimeABC):
69
- target_state: bool
70
-
71
- def num_qubits(self) -> int:
72
- return 1
73
-
74
- def unsafe_apply(
75
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
76
- ) -> list[cirq.Operation]:
77
- # NOTE: this doesn't scale well, but works
78
- sign = (-1) ** self.target_state
79
- p = (1 + sign * cirq.Z(*qubits)) / 2
80
- return [p]
81
-
82
-
83
- @dataclass
84
- class SpRuntime(UnsafeOperatorRuntimeABC):
85
- def num_qubits(self) -> int:
86
- return 1
87
-
88
- def unsafe_apply(
89
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
90
- ) -> list[cirq.Operation]:
91
- if adjoint:
92
- return SnRuntime().unsafe_apply(qubits, adjoint=False)
93
-
94
- return [(cirq.X(*qubits) - 1j * cirq.Y(*qubits)) / 2] # type: ignore -- we're not dealing with cirq's type issues
95
-
96
-
97
- @dataclass
98
- class SnRuntime(UnsafeOperatorRuntimeABC):
99
- def num_qubits(self) -> int:
100
- return 1
101
-
102
- def unsafe_apply(
103
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
104
- ) -> list[cirq.Operation]:
105
- if adjoint:
106
- return SpRuntime().unsafe_apply(qubits, adjoint=False)
107
-
108
- return [(cirq.X(*qubits) + 1j * cirq.Y(*qubits)) / 2] # type: ignore -- we're not dealing with cirq's type issues
109
-
110
-
111
- @dataclass
112
- class MultRuntime(OperatorRuntimeABC):
113
- lhs: OperatorRuntimeABC
114
- rhs: OperatorRuntimeABC
115
-
116
- def num_qubits(self) -> int:
117
- n = self.lhs.num_qubits()
118
- assert n == self.rhs.num_qubits()
119
- return n
120
-
121
- def unsafe_apply(
122
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
123
- ) -> list[cirq.Operation]:
124
- rhs = self.rhs.unsafe_apply(qubits, adjoint=adjoint)
125
- lhs = self.lhs.unsafe_apply(qubits, adjoint=adjoint)
126
-
127
- if adjoint:
128
- return lhs + rhs
129
- else:
130
- return rhs + lhs
131
-
132
-
133
- @dataclass
134
- class KronRuntime(OperatorRuntimeABC):
135
- lhs: OperatorRuntimeABC
136
- rhs: OperatorRuntimeABC
137
-
138
- def num_qubits(self) -> int:
139
- return self.lhs.num_qubits() + self.rhs.num_qubits()
140
-
141
- def unsafe_apply(
142
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
143
- ) -> list[cirq.Operation]:
144
- n = self.lhs.num_qubits()
145
- cirq_ops = self.lhs.unsafe_apply(qubits[:n], adjoint=adjoint)
146
- cirq_ops.extend(self.rhs.unsafe_apply(qubits[n:], adjoint=adjoint))
147
- return cirq_ops
148
-
149
-
150
- @dataclass
151
- class ControlRuntime(OperatorRuntimeABC):
152
- operator: OperatorRuntimeABC
153
- n_controls: int
154
-
155
- def num_qubits(self) -> int:
156
- return self.n_controls + self.operator.num_qubits()
157
-
158
- def unsafe_apply(
159
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
160
- ) -> list[cirq.Operation]:
161
- m = len(qubits) - self.n_controls
162
- cirq_ops = self.operator.unsafe_apply(qubits[m:], adjoint=adjoint)
163
- controlled_ops = [cirq_op.controlled_by(*qubits[:m]) for cirq_op in cirq_ops]
164
- return controlled_ops
165
-
166
-
167
- @dataclass
168
- class AdjointRuntime(OperatorRuntimeABC):
169
- operator: OperatorRuntimeABC
170
-
171
- def num_qubits(self) -> int:
172
- return self.operator.num_qubits()
173
-
174
- def unsafe_apply(
175
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
176
- ) -> list[cirq.Operation]:
177
- # NOTE: to account for e.g. adjoint(adjoint(op))
178
- passed_on_adjoint = not adjoint
179
- return self.operator.unsafe_apply(qubits, adjoint=passed_on_adjoint)
180
-
181
-
182
- @dataclass
183
- class U3Runtime(UnsafeOperatorRuntimeABC):
184
- theta: float
185
- phi: float
186
- lam: float
187
-
188
- def num_qubits(self) -> int:
189
- return 1
190
-
191
- def angles(self, adjoint: bool) -> tuple[float, float, float]:
192
- if adjoint:
193
- # NOTE: adjoint(U(theta, phi, lam)) == U(-theta, -lam, -phi)
194
- return -self.theta, -self.lam, -self.phi
195
- else:
196
- return self.theta, self.phi, self.lam
197
-
198
- def unsafe_apply(
199
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
200
- ) -> list[cirq.Operation]:
201
- theta, phi, lam = self.angles(adjoint=adjoint)
202
-
203
- ops = [
204
- cirq.Rz(rads=lam)(*qubits),
205
- cirq.Rx(rads=math.pi / 2)(*qubits),
206
- cirq.Rz(rads=theta)(*qubits),
207
- cirq.Rx(rads=-math.pi / 2)(*qubits),
208
- cirq.Rz(rads=phi)(*qubits),
209
- ]
210
-
211
- return ops
212
-
213
-
214
- @dataclass
215
- class ScaleRuntime(OperatorRuntimeABC):
216
- factor: Number
217
- operator: OperatorRuntimeABC
218
-
219
- def num_qubits(self) -> int:
220
- return self.operator.num_qubits()
221
-
222
- def unsafe_apply(
223
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
224
- ) -> list[cirq.Operation]:
225
- cirq_ops = self.operator.unsafe_apply(qubits=qubits, adjoint=adjoint)
226
- return [self.factor * cirq_ops[0]] + cirq_ops[1:] # type: ignore
227
-
228
-
229
- @dataclass
230
- class PauliStringRuntime(OperatorRuntimeABC):
231
- string: str
232
-
233
- def num_qubits(self) -> int:
234
- return len(self.string)
235
-
236
- def unsafe_apply(
237
- self, qubits: Sequence[cirq.Qid], adjoint: bool = False
238
- ) -> list[cirq.Operation]:
239
- pauli_mapping = {
240
- qbit: pauli_label for (qbit, pauli_label) in zip(qubits, self.string)
241
- }
242
- return [cirq.PauliString(pauli_mapping)]