bloqade-circuit 0.7.13__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (136)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +119 -29
  3. bloqade/analysis/address/impls.py +290 -87
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +2 -2
  6. bloqade/analysis/measure_id/impls.py +3 -27
  7. bloqade/cirq_utils/__init__.py +3 -1
  8. bloqade/cirq_utils/emit/__init__.py +3 -0
  9. bloqade/cirq_utils/emit/base.py +243 -0
  10. bloqade/cirq_utils/emit/gate.py +104 -0
  11. bloqade/cirq_utils/emit/noise.py +90 -0
  12. bloqade/cirq_utils/emit/qubit.py +35 -0
  13. bloqade/cirq_utils/lowering.py +664 -0
  14. bloqade/native/__init__.py +0 -1
  15. bloqade/native/_prelude.py +3 -3
  16. bloqade/native/dialects/gate/__init__.py +2 -0
  17. bloqade/native/dialects/gate/_dialect.py +3 -0
  18. bloqade/native/dialects/{gates → gate}/_interface.py +5 -5
  19. bloqade/native/dialects/{gates → gate}/stmts.py +5 -5
  20. bloqade/native/stdlib/broadcast.py +19 -19
  21. bloqade/native/stdlib/simple.py +14 -13
  22. bloqade/native/upstream/__init__.py +5 -0
  23. bloqade/native/upstream/squin2native.py +136 -0
  24. bloqade/pyqrack/__init__.py +1 -2
  25. bloqade/pyqrack/device.py +6 -17
  26. bloqade/pyqrack/native.py +17 -17
  27. bloqade/pyqrack/reg.py +1 -6
  28. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  29. bloqade/pyqrack/squin/gate/gate.py +136 -0
  30. bloqade/pyqrack/squin/noise/native.py +120 -54
  31. bloqade/pyqrack/squin/qubit.py +25 -41
  32. bloqade/pyqrack/target.py +2 -2
  33. bloqade/qasm2/dialects/core/address.py +21 -12
  34. bloqade/qasm2/dialects/noise/fidelity.py +2 -6
  35. bloqade/qasm2/dialects/noise/model.py +2 -1
  36. bloqade/qasm2/passes/parallel.py +3 -1
  37. bloqade/qasm2/rewrite/__init__.py +0 -1
  38. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  39. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  40. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  41. bloqade/qubit/__init__.py +12 -0
  42. bloqade/qubit/_dialect.py +3 -0
  43. bloqade/qubit/_interface.py +49 -0
  44. bloqade/qubit/_prelude.py +45 -0
  45. bloqade/qubit/analysis/__init__.py +1 -0
  46. bloqade/qubit/analysis/address_impl.py +40 -0
  47. bloqade/qubit/stdlib/__init__.py +2 -0
  48. bloqade/qubit/stdlib/_new.py +34 -0
  49. bloqade/qubit/stdlib/broadcast.py +62 -0
  50. bloqade/qubit/stdlib/simple.py +59 -0
  51. bloqade/qubit/stmts.py +60 -0
  52. bloqade/rewrite/passes/aggressive_unroll.py +2 -1
  53. bloqade/squin/__init__.py +44 -17
  54. bloqade/squin/analysis/__init__.py +0 -1
  55. bloqade/squin/analysis/schedule.py +2 -2
  56. bloqade/squin/gate/__init__.py +2 -0
  57. bloqade/squin/gate/_dialect.py +3 -0
  58. bloqade/squin/gate/_interface.py +98 -0
  59. bloqade/squin/gate/stmts.py +119 -0
  60. bloqade/squin/groups.py +4 -21
  61. bloqade/squin/noise/__init__.py +1 -9
  62. bloqade/squin/noise/_dialect.py +1 -1
  63. bloqade/squin/noise/_interface.py +45 -0
  64. bloqade/squin/noise/stmts.py +65 -29
  65. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  66. bloqade/squin/rewrite/__init__.py +0 -2
  67. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  68. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  69. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  70. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  71. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  72. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  73. bloqade/squin/stdlib/simple/__init__.py +33 -0
  74. bloqade/squin/stdlib/simple/gate.py +242 -0
  75. bloqade/squin/stdlib/simple/noise.py +126 -0
  76. bloqade/stim/__init__.py +1 -0
  77. bloqade/stim/_wrappers.py +6 -0
  78. bloqade/stim/dialects/noise/emit.py +6 -1
  79. bloqade/stim/dialects/noise/stmts.py +5 -3
  80. bloqade/stim/emit/stim_str.py +2 -0
  81. bloqade/stim/parse/lowering.py +12 -17
  82. bloqade/stim/passes/__init__.py +0 -1
  83. bloqade/stim/passes/flatten.py +26 -0
  84. bloqade/stim/passes/simplify_ifs.py +6 -1
  85. bloqade/stim/passes/squin_to_stim.py +4 -70
  86. bloqade/stim/rewrite/__init__.py +0 -4
  87. bloqade/stim/rewrite/ifs_to_stim.py +23 -29
  88. bloqade/stim/rewrite/qubit_to_stim.py +96 -51
  89. bloqade/stim/rewrite/squin_measure.py +9 -18
  90. bloqade/stim/rewrite/squin_noise.py +132 -108
  91. bloqade/stim/rewrite/util.py +5 -204
  92. bloqade/types.py +10 -0
  93. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/METADATA +2 -2
  94. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/RECORD +96 -100
  95. bloqade/native/dialects/gates/__init__.py +0 -3
  96. bloqade/native/dialects/gates/_dialect.py +0 -3
  97. bloqade/pyqrack/squin/op.py +0 -180
  98. bloqade/pyqrack/squin/runtime.py +0 -543
  99. bloqade/pyqrack/squin/wire.py +0 -51
  100. bloqade/squin/_typeinfer.py +0 -20
  101. bloqade/squin/analysis/address_impl.py +0 -71
  102. bloqade/squin/analysis/nsites/__init__.py +0 -9
  103. bloqade/squin/analysis/nsites/analysis.py +0 -50
  104. bloqade/squin/analysis/nsites/impls.py +0 -99
  105. bloqade/squin/analysis/nsites/lattice.py +0 -49
  106. bloqade/squin/cirq/__init__.py +0 -306
  107. bloqade/squin/cirq/emit/emit_circuit.py +0 -129
  108. bloqade/squin/cirq/emit/noise.py +0 -49
  109. bloqade/squin/cirq/emit/op.py +0 -176
  110. bloqade/squin/cirq/emit/qubit.py +0 -58
  111. bloqade/squin/cirq/emit/runtime.py +0 -242
  112. bloqade/squin/cirq/lowering.py +0 -439
  113. bloqade/squin/lowering.py +0 -80
  114. bloqade/squin/noise/_wrapper.py +0 -36
  115. bloqade/squin/noise/rewrite.py +0 -129
  116. bloqade/squin/op/__init__.py +0 -41
  117. bloqade/squin/op/_dialect.py +0 -3
  118. bloqade/squin/op/_wrapper.py +0 -121
  119. bloqade/squin/op/number.py +0 -5
  120. bloqade/squin/op/rewrite.py +0 -46
  121. bloqade/squin/op/stdlib.py +0 -62
  122. bloqade/squin/op/stmts.py +0 -300
  123. bloqade/squin/op/traits.py +0 -43
  124. bloqade/squin/op/types.py +0 -128
  125. bloqade/squin/parallel.py +0 -200
  126. bloqade/squin/qubit.py +0 -194
  127. bloqade/squin/rewrite/canonicalize.py +0 -60
  128. bloqade/squin/rewrite/desugar.py +0 -102
  129. bloqade/squin/stdlib/channel.py +0 -86
  130. bloqade/squin/stdlib/gate.py +0 -201
  131. bloqade/squin/types.py +0 -8
  132. bloqade/squin/wire.py +0 -201
  133. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  134. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  135. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/WHEEL +0 -0
  136. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,129 +0,0 @@
1
- from typing import Sequence
2
- from dataclasses import field, dataclass
3
-
4
- import cirq
5
- from kirin import ir, interp
6
- from kirin.emit import EmitABC, EmitError, EmitFrame
7
- from kirin.interp import MethodTable, impl
8
- from kirin.dialects import func
9
- from typing_extensions import Self
10
-
11
- from ... import kernel
12
-
13
-
14
@dataclass
class EmitCirqFrame(EmitFrame):
    """Emission frame that carries the cirq circuit under construction.

    On top of kirin's ``EmitFrame`` it tracks where the next qubit should be
    taken from, so qubit allocations inside the frame map to distinct cirq qubits.
    """

    # Offset of the next qubit to hand out; advanced whenever qubits are allocated.
    qubit_index: int = 0
    # Optional pre-chosen cirq qubits to allocate from (None -> LineQubits are created).
    qubits: Sequence[cirq.Qid] | None = None
    # Circuit that the operations emitted in this frame are appended to.
    circuit: cirq.Circuit = field(default_factory=cirq.Circuit)
19
-
20
-
21
def _default_kernel():
    """Return the package-level ``kernel`` dialect group.

    Used as the ``default_factory`` of ``EmitCirq.dialects`` so the default is
    looked up when an emitter is constructed rather than stored at class creation.
    """
    default_group = kernel
    return default_group
23
-
24
-
25
@dataclass
class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
    """Emitter that walks a kernel's IR and produces a ``cirq.Circuit``.

    Attributes:
        dialects: Dialect group the emitter understands (defaults to ``kernel``).
        qubits: Optional cirq qubits to allocate from instead of ``LineQubit``s.
    """

    keys = ["emit.cirq", "main"]
    dialects: ir.DialectGroup = field(default_factory=_default_kernel)
    # Shared empty circuit used as the "nothing to return" value.
    void = cirq.Circuit()
    qubits: Sequence[cirq.Qid] | None = None
    # Cache of already-emitted subroutine calls (populated by the func.Invoke rule).
    _cached_circuit_operations: dict[int, cirq.CircuitOperation] = field(
        init=False, default_factory=dict
    )

    def initialize(self) -> Self:
        return super().initialize()

    def initialize_frame(
        self, code: ir.Statement, *, has_parent_access: bool = False
    ) -> EmitCirqFrame:
        """Create a fresh frame that allocates from this emitter's ``qubits``."""
        new_frame = EmitCirqFrame(
            code,
            has_parent_access=has_parent_access,
            qubits=self.qubits,
        )
        return new_frame

    def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
        """Emit ``method`` by running its code object with ``args``."""
        code = method.code
        return self.run_callable(code, args)

    def run_callable_region(
        self,
        frame: EmitCirqFrame,
        code: ir.Statement,
        region: ir.Region,
        args: tuple,
    ):
        """Bind ``args`` to the entry block and evaluate ``code``.

        Returns ``void`` for an empty result tuple and the sole element for a
        one-element tuple.

        Raises:
            interp.InterpreterError: for any other result shape.
        """
        if len(region.blocks) > 0:
            entry_args = list(region.blocks[0].args)
            # NOTE: the first block argument is the method's "self"; it is not bound here
            frame.set_values(entry_args[1:], args)

        outcome = self.eval_stmt(frame, code)
        if isinstance(outcome, tuple):
            n_results = len(outcome)
            if n_results == 0:
                return self.void
            if n_results == 1:
                return outcome[0]
        raise interp.InterpreterError(f"Unexpected results {outcome}")

    def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
        """Evaluate every statement of ``block`` and return the frame's circuit."""
        for statement in block.stmts:
            produced = self.eval_stmt(frame, statement)
            # Only tuple results are bound to the statement's SSA results.
            if isinstance(produced, tuple):
                frame.set_values(statement.results, produced)

        return frame.circuit
75
-
76
-
77
@func.dialect.register(key="emit.cirq")
class FuncEmit(MethodTable):
    """Cirq emission rules for kirin's ``func`` dialect (function bodies and calls)."""

    @impl(func.Function)
    def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
        # Emit the whole body into the current frame; its circuit is the result.
        emit.run_ssacfg_region(frame, stmt.body, ())
        return (frame.circuit,)

    @staticmethod
    def _freeze(value):
        """Return a hashable stand-in for ``value`` (lists/tuples become tuples, recursively)."""
        if isinstance(value, (list, tuple)):
            return tuple(FuncEmit._freeze(item) for item in value)
        return value

    @staticmethod
    def _cache_key(callee, args: tuple, qubit_index: int):
        """Build the subroutine-cache key, or return ``None`` if it cannot be hashed.

        The key is made from the *runtime* argument values (e.g. the concrete
        cirq qubits) plus the current qubit-allocation offset. The previous key
        used the SSA values of the call site, so one call site reached through
        different callers (e.g. a helper invoked twice with different qubits)
        wrongly reused the circuit emitted for the first caller's qubits.
        """
        key = (callee, FuncEmit._freeze(args), qubit_index)
        try:
            hash(key)
        except TypeError:
            # e.g. numpy arrays among the arguments: skip caching for this call
            return None
        return key

    @impl(func.Invoke)
    def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
        """Emit a subroutine call as a single ``cirq.CircuitOperation``.

        The callee's result value (if it has a ``func.Return``) is bound to the
        invoke's result. The caller's qubit offset is advanced past any qubits
        the callee allocated.

        Raises:
            EmitError: if the callee has more than one block.
        """
        ret = stmt.result
        args = tuple(frame.get(arg_) for arg_ in stmt.inputs)
        cache_key = self._cache_key(stmt.callee, args, frame.qubit_index)

        if cache_key is not None:
            # NOTE: cache entries are (circuit_op, has_return, return_value, qubit_index_after)
            cached = emit._cached_circuit_operations.get(cache_key)
            if cached is not None:
                circuit_op, has_return, return_value, qubit_index_after = cached
                # NOTE: a cache hit must still bind the call's result and allocation state
                if has_return:
                    frame.entries[ret] = return_value
                frame.qubit_index = qubit_index_after
                frame.circuit.append(circuit_op)
                return ()

        region = stmt.callee.callable_region
        if len(region.blocks) > 1:
            raise EmitError(
                "Subroutine with more than a single block encountered. This is not supported!"
            )

        with emit.new_frame(stmt.callee.code, has_parent_access=True) as sub_frame:
            sub_frame.qubit_index = frame.qubit_index
            sub_frame.qubits = frame.qubits

            # NOTE: "self" is passed as an empty circuit, followed by the call arguments
            emit.run_ssacfg_region(sub_frame, region, args=(emit.void, *args))
            sub_circuit = sub_frame.circuit

            # NOTE: multiple returns via control flow are not supported, so the first
            # Return statement of the single block is the one that matters
            return_stmt = next(
                (s for s in region.blocks[0].stmts if isinstance(s, func.Return)),
                None,
            )
            has_return = return_stmt is not None
            return_value = sub_frame.get(return_stmt.value) if has_return else None
            qubit_index_after = sub_frame.qubit_index

        if has_return:
            frame.entries[ret] = return_value
        # Qubits allocated inside the callee must not be handed out again by the caller.
        frame.qubit_index = qubit_index_after

        circuit_op = cirq.CircuitOperation(sub_circuit.freeze(), use_repetition_ids=False)
        if cache_key is not None:
            emit._cached_circuit_operations[cache_key] = (
                circuit_op,
                has_return,
                return_value,
                qubit_index_after,
            )
        frame.circuit.append(circuit_op)
        return ()
@@ -1,49 +0,0 @@
1
- import cirq
2
- from kirin.emit import EmitError
3
- from kirin.interp import MethodTable, impl
4
-
5
- from ... import noise
6
- from .runtime import (
7
- KronRuntime,
8
- BasicOpRuntime,
9
- OperatorRuntimeABC,
10
- PauliStringRuntime,
11
- )
12
- from .emit_circuit import EmitCirq, EmitCirqFrame
13
-
14
-
15
@noise.dialect.register(key="emit.cirq")
class EmitCirqNoiseMethods(MethodTable):
    """Cirq emission rules for the squin ``noise`` dialect."""

    @impl(noise.stmts.StochasticUnitaryChannel)
    def stochastic_unitary_channel(
        self,
        emit: EmitCirq,
        frame: EmitCirqFrame,
        stmt: noise.stmts.StochasticUnitaryChannel,
    ):
        """Lower a stochastic unitary channel to ``cirq.asymmetric_depolarize``."""
        operators = frame.get(stmt.operators)
        probabilities = frame.get(stmt.probabilities)

        # Map each operator's string label (e.g. "X" or "XZ") to its probability.
        error_probabilities = {}
        for operator_rt, probability in zip(operators, probabilities):
            error_probabilities[self._op_to_key(operator_rt)] = probability

        channel = cirq.asymmetric_depolarize(error_probabilities=error_probabilities)
        return (BasicOpRuntime(channel),)

    @staticmethod
    def _op_to_key(operator: OperatorRuntimeABC) -> str:
        """Return the string label used as a key for ``asymmetric_depolarize``.

        Kron products concatenate their factors' labels, basic gates use
        ``str(gate)``, and Pauli strings use their string directly.

        Raises:
            EmitError: for any other operator runtime.
        """
        if isinstance(operator, KronRuntime):
            left_key = EmitCirqNoiseMethods._op_to_key(operator.lhs)
            right_key = EmitCirqNoiseMethods._op_to_key(operator.rhs)
            return left_key + right_key

        if isinstance(operator, BasicOpRuntime):
            return str(operator.gate)

        if isinstance(operator, PauliStringRuntime):
            return operator.string

        raise EmitError(
            f"Unexpected operator runtime in StochasticUnitaryChannel of type {type(operator).__name__} encountered!"
        )
@@ -1,176 +0,0 @@
1
- import math
2
-
3
- import cirq
4
- import numpy as np
5
- from kirin.emit import EmitError
6
- from kirin.interp import MethodTable, impl
7
-
8
- from ... import op
9
- from .runtime import (
10
- SnRuntime,
11
- SpRuntime,
12
- U3Runtime,
13
- KronRuntime,
14
- MultRuntime,
15
- ScaleRuntime,
16
- AdjointRuntime,
17
- BasicOpRuntime,
18
- ControlRuntime,
19
- UnitaryRuntime,
20
- HermitianRuntime,
21
- ProjectorRuntime,
22
- OperatorRuntimeABC,
23
- PauliStringRuntime,
24
- )
25
- from .emit_circuit import EmitCirq, EmitCirqFrame
26
-
27
-
28
- @op.dialect.register(key="emit.cirq")
29
- class EmitCirqOpMethods(MethodTable):
30
-
31
- @impl(op.stmts.X)
32
- @impl(op.stmts.Y)
33
- @impl(op.stmts.Z)
34
- @impl(op.stmts.H)
35
- def hermitian(
36
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
37
- ):
38
- cirq_op = getattr(cirq, stmt.name.upper())
39
- return (HermitianRuntime(cirq_op),)
40
-
41
- @impl(op.stmts.S)
42
- @impl(op.stmts.T)
43
- def unitary(
44
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
45
- ):
46
- cirq_op = getattr(cirq, stmt.name.upper())
47
- return (UnitaryRuntime(cirq_op),)
48
-
49
- @impl(op.stmts.P0)
50
- @impl(op.stmts.P1)
51
- def projector(
52
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.P0 | op.stmts.P1
53
- ):
54
- return (ProjectorRuntime(isinstance(stmt, op.stmts.P1)),)
55
-
56
- @impl(op.stmts.Sn)
57
- def sn(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sn):
58
- return (SnRuntime(),)
59
-
60
- @impl(op.stmts.Sp)
61
- def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp):
62
- return (SpRuntime(),)
63
-
64
- @impl(op.stmts.Identity)
65
- def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity):
66
- op = HermitianRuntime(cirq.IdentityGate(num_qubits=stmt.sites))
67
- return (op,)
68
-
69
- @impl(op.stmts.Control)
70
- def control(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Control):
71
- op: OperatorRuntimeABC = frame.get(stmt.op)
72
- return (ControlRuntime(op, stmt.n_controls),)
73
-
74
- @impl(op.stmts.Kron)
75
- def kron(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Kron):
76
- lhs = frame.get(stmt.lhs)
77
- rhs = frame.get(stmt.rhs)
78
- op = KronRuntime(lhs, rhs)
79
- return (op,)
80
-
81
- @impl(op.stmts.Mult)
82
- def mult(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Mult):
83
- lhs = frame.get(stmt.lhs)
84
- rhs = frame.get(stmt.rhs)
85
- op = MultRuntime(lhs, rhs)
86
- return (op,)
87
-
88
- @impl(op.stmts.Adjoint)
89
- def adjoint(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Adjoint):
90
- op_ = frame.get(stmt.op)
91
- return (AdjointRuntime(op_),)
92
-
93
- @impl(op.stmts.Scale)
94
- def scale(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Scale):
95
- op_ = frame.get(stmt.op)
96
- factor = frame.get(stmt.factor)
97
- return (ScaleRuntime(operator=op_, factor=factor),)
98
-
99
- @impl(op.stmts.U3)
100
- def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.U3):
101
- theta = frame.get(stmt.theta)
102
- phi = frame.get(stmt.phi)
103
- lam = frame.get(stmt.lam)
104
- return (U3Runtime(theta=theta, phi=phi, lam=lam),)
105
-
106
- @impl(op.stmts.PhaseOp)
107
- def phaseop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PhaseOp):
108
- theta = frame.get(stmt.theta)
109
- op_ = HermitianRuntime(cirq.IdentityGate(num_qubits=1))
110
- return (ScaleRuntime(operator=op_, factor=np.exp(1j * theta)),)
111
-
112
- @impl(op.stmts.ShiftOp)
113
- def shiftop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ShiftOp):
114
- theta = frame.get(stmt.theta)
115
-
116
- # NOTE: ShiftOp(theta) == U3(pi, theta, 0)
117
- return (U3Runtime(math.pi, theta, 0),)
118
-
119
- @impl(op.stmts.Reset)
120
- def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset):
121
- return (BasicOpRuntime(cirq.ResetChannel()),)
122
-
123
- @impl(op.stmts.PauliString)
124
- def pauli_string(
125
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString
126
- ):
127
- return (PauliStringRuntime(stmt.string),)
128
-
129
- @impl(op.stmts.Rot)
130
- def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot):
131
- axis: OperatorRuntimeABC = frame.get(stmt.axis)
132
-
133
- if not isinstance(axis, HermitianRuntime):
134
- raise EmitError(
135
- f"Circuit emission only supported for Pauli operators! Got axis {axis}"
136
- )
137
-
138
- angle = frame.get(stmt.angle)
139
-
140
- match axis.gate:
141
- case cirq.X:
142
- gate = cirq.Rx(rads=angle)
143
- case cirq.Y:
144
- gate = cirq.Ry(rads=angle)
145
- case cirq.Z:
146
- gate = cirq.Rz(rads=angle)
147
- case _:
148
- raise EmitError(
149
- f"Circuit emission only supported for Pauli operators! Got axis {axis.gate}"
150
- )
151
-
152
- return (HermitianRuntime(gate=gate),)
153
-
154
- @impl(op.stmts.ResetToOne)
155
- def reset_to_one(
156
- self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ResetToOne
157
- ):
158
- # NOTE: just apply a reset to 0 and flip in sequence (we re-use the multiplication runtime since it does exactly that)
159
- gate1 = cirq.ResetChannel()
160
- gate2 = cirq.X
161
-
162
- rt1 = BasicOpRuntime(gate1)
163
- rt2 = HermitianRuntime(gate2)
164
-
165
- # NOTE: mind the order: rhs is applied first
166
- return (MultRuntime(rt2, rt1),)
167
-
168
- @impl(op.stmts.SqrtX)
169
- def sqrt_x(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtX):
170
- cirq_op = cirq.XPowGate(exponent=0.5)
171
- return (UnitaryRuntime(cirq_op),)
172
-
173
- @impl(op.stmts.SqrtY)
174
- def sqrt_y(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtY):
175
- cirq_op = cirq.YPowGate(exponent=0.5)
176
- return (UnitaryRuntime(cirq_op),)
@@ -1,58 +0,0 @@
1
- import cirq
2
- from kirin.interp import MethodTable, impl
3
-
4
- from ... import qubit
5
- from .op import OperatorRuntimeABC
6
- from .emit_circuit import EmitCirq, EmitCirqFrame
7
-
8
-
9
@qubit.dialect.register(key="emit.cirq")
class EmitCirqQubitMethods(MethodTable):
    """Cirq emission rules for the squin ``qubit`` dialect."""

    @impl(qubit.New)
    def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
        """Allocate ``n_qubits`` cirq qubits starting at the frame's current offset.

        Qubits come from ``frame.qubits`` when it is set; otherwise ``LineQubit``s
        are created.
        """
        count = frame.get(stmt.n_qubits)
        offset = frame.qubit_index

        if frame.qubits is None:
            allocated = [cirq.LineQubit(offset + k) for k in range(count)]
        else:
            allocated = [frame.qubits[offset + k] for k in range(count)]

        frame.qubit_index += count
        return (allocated,)

    @impl(qubit.Apply)
    def apply(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Apply):
        """Apply an operator runtime to the given qubits and append the result."""
        operator_rt: OperatorRuntimeABC = frame.get(stmt.operator)
        targets = [frame.get(q) for q in stmt.qubits]
        for cirq_operation in operator_rt.apply(targets):
            frame.circuit.append(cirq_operation)
        return ()

    @impl(qubit.Broadcast)
    def broadcast(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Broadcast):
        """Apply the operator element-wise across the zipped qubit lists."""
        operator_rt = frame.get(stmt.operator)
        register_lists = [frame.get(q) for q in stmt.qubits]

        for targets in zip(*register_lists):
            frame.circuit.append(operator_rt.apply(targets))

        return ()

    @impl(qubit.MeasureQubit)
    def measure_qubit(
        self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubit
    ):
        """Append a measurement of a single qubit; the result value is ``emit.void``."""
        target = frame.get(stmt.qubit)
        frame.circuit.append(cirq.measure(target))
        return (emit.void,)

    @impl(qubit.MeasureQubitList)
    def measure_qubit_list(
        self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubitList
    ):
        """Append one measurement over a list of qubits; the result value is ``emit.void``."""
        targets = frame.get(stmt.qubits)
        frame.circuit.append(cirq.measure(targets))
        return (emit.void,)
@@ -1,242 +0,0 @@
1
- import math
2
- from typing import Sequence
3
- from numbers import Number
4
- from dataclasses import dataclass
5
-
6
- import cirq
7
-
8
-
9
@dataclass
class OperatorRuntimeABC:
    """Base class for emit-time operator values.

    A runtime describes a (possibly composite) operator and knows how to turn
    itself into a list of ``cirq.Operation`` objects acting on given qubits.
    """

    def num_qubits(self) -> int:
        """Number of qubits the operator acts on (overridden by subclasses; base returns None)."""

    def check_qubits(self, qubits: Sequence[cirq.Qid]):
        # The number of supplied qubits must equal the operator's arity.
        assert self.num_qubits() == len(qubits)

    def apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        """Validate ``qubits`` and emit the operator (or its adjoint) on them."""
        self.check_qubits(qubits)
        emitted = self.unsafe_apply(qubits, adjoint=adjoint)
        return emitted

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        """Emit without validation; every concrete runtime must override this."""
        raise NotImplementedError(
            f"Apply method needs to be implemented in {self.__class__.__name__}"
        )
28
-
29
-
30
@dataclass
class UnsafeOperatorRuntimeABC(OperatorRuntimeABC):
    """Runtime base class that skips the local qubit-count check."""

    def check_qubits(self, qubits: Sequence[cirq.Qid]):
        # NOTE: validation is deliberately left to cirq when the gate is applied
        return None
35
-
36
-
37
@dataclass
class BasicOpRuntime(UnsafeOperatorRuntimeABC):
    """Wraps a single cirq gate (or channel); the adjoint flag is ignored."""

    gate: cirq.Gate

    def num_qubits(self) -> int:
        return self.gate.num_qubits()

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        operation = self.gate(*qubits)
        return [operation]
48
-
49
-
50
@dataclass
class UnitaryRuntime(BasicOpRuntime):
    """Unitary gate: the adjoint is emitted as the operation raised to the power -1."""

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        power = -1 if adjoint else 1
        operation = self.gate(*qubits)
        return [operation**power]
57
-
58
-
59
@dataclass
class HermitianRuntime(BasicOpRuntime):
    """Hermitian gate: it is its own adjoint, so the adjoint flag is ignored."""

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        operation = self.gate(*qubits)
        return [operation]
65
-
66
-
67
@dataclass
class ProjectorRuntime(UnsafeOperatorRuntimeABC):
    """Single-qubit projector onto |0> (target_state False) or |1> (target_state True)."""

    target_state: bool

    def num_qubits(self) -> int:
        return 1

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # NOTE: this doesn't scale well, but works: P = (1 +/- Z) / 2
        # (projectors are Hermitian, so the adjoint flag is irrelevant)
        sign = -1 if self.target_state else 1
        projector = (1 + sign * cirq.Z(*qubits)) / 2
        return [projector]
81
-
82
-
83
@dataclass
class SpRuntime(UnsafeOperatorRuntimeABC):
    """Ladder operator emitted as (X - iY) / 2; its adjoint is ``SnRuntime``."""

    def num_qubits(self) -> int:
        return 1

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        if not adjoint:
            return [(cirq.X(*qubits) - 1j * cirq.Y(*qubits)) / 2]  # type: ignore -- we're not dealing with cirq's type issues

        return SnRuntime().unsafe_apply(qubits, adjoint=False)
95
-
96
-
97
@dataclass
class SnRuntime(UnsafeOperatorRuntimeABC):
    """Ladder operator emitted as (X + iY) / 2; its adjoint is ``SpRuntime``."""

    def num_qubits(self) -> int:
        return 1

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        if not adjoint:
            return [(cirq.X(*qubits) + 1j * cirq.Y(*qubits)) / 2]  # type: ignore -- we're not dealing with cirq's type issues

        return SpRuntime().unsafe_apply(qubits, adjoint=False)
109
-
110
-
111
@dataclass
class MultRuntime(OperatorRuntimeABC):
    """Operator product ``lhs * rhs`` (``rhs`` acts first) on the same qubits."""

    lhs: OperatorRuntimeABC
    rhs: OperatorRuntimeABC

    def num_qubits(self) -> int:
        arity = self.lhs.num_qubits()
        # Both factors must act on the same number of qubits.
        assert arity == self.rhs.num_qubits()
        return arity

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        rhs_ops = self.rhs.unsafe_apply(qubits, adjoint=adjoint)
        lhs_ops = self.lhs.unsafe_apply(qubits, adjoint=adjoint)

        # (A B)^dagger = B^dagger A^dagger, so the adjoint reverses the order.
        first, second = (lhs_ops, rhs_ops) if adjoint else (rhs_ops, lhs_ops)
        return first + second
131
-
132
-
133
@dataclass
class KronRuntime(OperatorRuntimeABC):
    """Tensor product: ``lhs`` acts on the leading qubits, ``rhs`` on the rest."""

    lhs: OperatorRuntimeABC
    rhs: OperatorRuntimeABC

    def num_qubits(self) -> int:
        return self.lhs.num_qubits() + self.rhs.num_qubits()

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        split = self.lhs.num_qubits()
        left_ops = self.lhs.unsafe_apply(qubits[:split], adjoint=adjoint)
        right_ops = self.rhs.unsafe_apply(qubits[split:], adjoint=adjoint)
        return [*left_ops, *right_ops]
148
-
149
-
150
@dataclass
class ControlRuntime(OperatorRuntimeABC):
    """Controlled operator: the first ``n_controls`` qubits control ``operator``.

    Qubit layout is ``[control_0, ..., control_{n-1}, target_0, ...]``.
    """

    operator: OperatorRuntimeABC
    n_controls: int

    def num_qubits(self) -> int:
        return self.n_controls + self.operator.num_qubits()

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # Controls come first, targets after. The previous slicing used
        # qubits[len - n_controls:] as targets and qubits[:len - n_controls] as
        # controls, which is only right when #controls == #targets (e.g. a
        # Toffoli got a 2-qubit target for a 1-qubit gate).
        controls = qubits[: self.n_controls]
        targets = qubits[self.n_controls :]
        cirq_ops = self.operator.unsafe_apply(targets, adjoint=adjoint)
        return [cirq_op.controlled_by(*controls) for cirq_op in cirq_ops]
165
-
166
-
167
@dataclass
class AdjointRuntime(OperatorRuntimeABC):
    """Adjoint of ``operator``; implemented by toggling the adjoint flag."""

    operator: OperatorRuntimeABC

    def num_qubits(self) -> int:
        return self.operator.num_qubits()

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # Toggling (rather than forcing True) makes adjoint(adjoint(op)) == op.
        return self.operator.unsafe_apply(qubits, adjoint=not adjoint)
180
-
181
-
182
@dataclass
class U3Runtime(UnsafeOperatorRuntimeABC):
    """Single-qubit U3(theta, phi, lam) gate.

    Emitted in time order as Rz(lam), Rx(pi/2), Rz(theta), Rx(-pi/2), Rz(phi).
    The middle three rotations compose to Ry(theta), giving Rz(phi) Ry(theta)
    Rz(lam), i.e. U3 up to a global phase.
    """

    theta: float
    phi: float
    lam: float

    def num_qubits(self) -> int:
        return 1

    def angles(self, adjoint: bool) -> tuple[float, float, float]:
        """Return ``(theta, phi, lam)``, or the angles of the adjoint gate."""
        if not adjoint:
            return self.theta, self.phi, self.lam
        # NOTE: adjoint(U(theta, phi, lam)) == U(-theta, -lam, -phi)
        return -self.theta, -self.lam, -self.phi

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        theta, phi, lam = self.angles(adjoint=adjoint)
        quarter_turn = math.pi / 2

        rotations = (
            cirq.Rz(rads=lam),
            cirq.Rx(rads=quarter_turn),
            cirq.Rz(rads=theta),
            cirq.Rx(rads=-quarter_turn),
            cirq.Rz(rads=phi),
        )
        return [rotation(*qubits) for rotation in rotations]
212
-
213
-
214
@dataclass
class ScaleRuntime(OperatorRuntimeABC):
    """Operator scaled by a scalar ``factor``.

    The factor is folded into the first emitted operation.
    """

    factor: Number
    operator: OperatorRuntimeABC

    def num_qubits(self) -> int:
        return self.operator.num_qubits()

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        cirq_ops = self.operator.unsafe_apply(qubits=qubits, adjoint=adjoint)
        # (c * A)^dagger = conj(c) * A^dagger. Without the conjugation, e.g. the
        # adjoint of a global phase exp(i*theta) would keep exp(i*theta).
        factor = self.factor.conjugate() if adjoint else self.factor  # type: ignore
        return [factor * cirq_ops[0]] + cirq_ops[1:]  # type: ignore
227
-
228
-
229
@dataclass
class PauliStringRuntime(OperatorRuntimeABC):
    """Pauli string such as ``"XZY"``; character i acts on qubit i."""

    string: str

    def num_qubits(self) -> int:
        return len(self.string)

    def unsafe_apply(
        self, qubits: Sequence[cirq.Qid], adjoint: bool = False
    ) -> list[cirq.Operation]:
        # A unit-coefficient Pauli string is Hermitian, so adjoint is ignored.
        pauli_mapping = dict(zip(qubits, self.string))
        return [cirq.PauliString(pauli_mapping)]