bloqade-circuit 0.7.12__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (136)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +119 -29
  3. bloqade/analysis/address/impls.py +290 -87
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +2 -2
  6. bloqade/analysis/measure_id/impls.py +3 -27
  7. bloqade/cirq_utils/__init__.py +3 -1
  8. bloqade/cirq_utils/emit/__init__.py +3 -0
  9. bloqade/cirq_utils/emit/base.py +243 -0
  10. bloqade/cirq_utils/emit/gate.py +104 -0
  11. bloqade/cirq_utils/emit/noise.py +90 -0
  12. bloqade/cirq_utils/emit/qubit.py +35 -0
  13. bloqade/cirq_utils/lowering.py +664 -0
  14. bloqade/native/__init__.py +0 -1
  15. bloqade/native/_prelude.py +3 -3
  16. bloqade/native/dialects/gate/__init__.py +2 -0
  17. bloqade/native/dialects/gate/_dialect.py +3 -0
  18. bloqade/native/dialects/{gates → gate}/_interface.py +5 -5
  19. bloqade/native/dialects/{gates → gate}/stmts.py +5 -5
  20. bloqade/native/stdlib/broadcast.py +19 -19
  21. bloqade/native/stdlib/simple.py +14 -13
  22. bloqade/native/upstream/__init__.py +5 -0
  23. bloqade/native/upstream/squin2native.py +136 -0
  24. bloqade/pyqrack/__init__.py +1 -2
  25. bloqade/pyqrack/device.py +6 -17
  26. bloqade/pyqrack/native.py +17 -17
  27. bloqade/pyqrack/reg.py +1 -6
  28. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  29. bloqade/pyqrack/squin/gate/gate.py +136 -0
  30. bloqade/pyqrack/squin/noise/native.py +120 -54
  31. bloqade/pyqrack/squin/qubit.py +25 -41
  32. bloqade/pyqrack/target.py +2 -2
  33. bloqade/qasm2/dialects/core/address.py +21 -12
  34. bloqade/qasm2/dialects/noise/fidelity.py +2 -6
  35. bloqade/qasm2/dialects/noise/model.py +2 -1
  36. bloqade/qasm2/passes/parallel.py +3 -1
  37. bloqade/qasm2/rewrite/__init__.py +0 -1
  38. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  39. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  40. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  41. bloqade/qubit/__init__.py +12 -0
  42. bloqade/qubit/_dialect.py +3 -0
  43. bloqade/qubit/_interface.py +49 -0
  44. bloqade/qubit/_prelude.py +45 -0
  45. bloqade/qubit/analysis/__init__.py +1 -0
  46. bloqade/qubit/analysis/address_impl.py +40 -0
  47. bloqade/qubit/stdlib/__init__.py +2 -0
  48. bloqade/qubit/stdlib/_new.py +34 -0
  49. bloqade/qubit/stdlib/broadcast.py +62 -0
  50. bloqade/qubit/stdlib/simple.py +59 -0
  51. bloqade/qubit/stmts.py +60 -0
  52. bloqade/rewrite/passes/aggressive_unroll.py +2 -1
  53. bloqade/squin/__init__.py +44 -17
  54. bloqade/squin/analysis/__init__.py +0 -1
  55. bloqade/squin/analysis/schedule.py +2 -2
  56. bloqade/squin/gate/__init__.py +2 -0
  57. bloqade/squin/gate/_dialect.py +3 -0
  58. bloqade/squin/gate/_interface.py +98 -0
  59. bloqade/squin/gate/stmts.py +119 -0
  60. bloqade/squin/groups.py +4 -21
  61. bloqade/squin/noise/__init__.py +1 -9
  62. bloqade/squin/noise/_dialect.py +1 -1
  63. bloqade/squin/noise/_interface.py +45 -0
  64. bloqade/squin/noise/stmts.py +65 -29
  65. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  66. bloqade/squin/rewrite/__init__.py +0 -2
  67. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  68. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  69. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  70. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  71. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  72. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  73. bloqade/squin/stdlib/simple/__init__.py +33 -0
  74. bloqade/squin/stdlib/simple/gate.py +242 -0
  75. bloqade/squin/stdlib/simple/noise.py +126 -0
  76. bloqade/stim/__init__.py +1 -0
  77. bloqade/stim/_wrappers.py +6 -0
  78. bloqade/stim/dialects/noise/emit.py +6 -1
  79. bloqade/stim/dialects/noise/stmts.py +5 -3
  80. bloqade/stim/emit/stim_str.py +2 -0
  81. bloqade/stim/parse/lowering.py +12 -17
  82. bloqade/stim/passes/__init__.py +0 -1
  83. bloqade/stim/passes/flatten.py +26 -0
  84. bloqade/stim/passes/simplify_ifs.py +6 -1
  85. bloqade/stim/passes/squin_to_stim.py +4 -70
  86. bloqade/stim/rewrite/__init__.py +0 -4
  87. bloqade/stim/rewrite/ifs_to_stim.py +23 -29
  88. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  89. bloqade/stim/rewrite/squin_measure.py +9 -18
  90. bloqade/stim/rewrite/squin_noise.py +132 -108
  91. bloqade/stim/rewrite/util.py +5 -204
  92. bloqade/types.py +10 -0
  93. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/METADATA +2 -2
  94. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/RECORD +96 -100
  95. bloqade/native/dialects/gates/__init__.py +0 -3
  96. bloqade/native/dialects/gates/_dialect.py +0 -3
  97. bloqade/pyqrack/squin/op.py +0 -180
  98. bloqade/pyqrack/squin/runtime.py +0 -543
  99. bloqade/pyqrack/squin/wire.py +0 -51
  100. bloqade/squin/_typeinfer.py +0 -20
  101. bloqade/squin/analysis/address_impl.py +0 -71
  102. bloqade/squin/analysis/nsites/__init__.py +0 -9
  103. bloqade/squin/analysis/nsites/analysis.py +0 -50
  104. bloqade/squin/analysis/nsites/impls.py +0 -99
  105. bloqade/squin/analysis/nsites/lattice.py +0 -49
  106. bloqade/squin/cirq/__init__.py +0 -306
  107. bloqade/squin/cirq/emit/emit_circuit.py +0 -129
  108. bloqade/squin/cirq/emit/noise.py +0 -49
  109. bloqade/squin/cirq/emit/op.py +0 -176
  110. bloqade/squin/cirq/emit/qubit.py +0 -58
  111. bloqade/squin/cirq/emit/runtime.py +0 -242
  112. bloqade/squin/cirq/lowering.py +0 -439
  113. bloqade/squin/lowering.py +0 -80
  114. bloqade/squin/noise/_wrapper.py +0 -36
  115. bloqade/squin/noise/rewrite.py +0 -129
  116. bloqade/squin/op/__init__.py +0 -41
  117. bloqade/squin/op/_dialect.py +0 -3
  118. bloqade/squin/op/_wrapper.py +0 -121
  119. bloqade/squin/op/number.py +0 -5
  120. bloqade/squin/op/rewrite.py +0 -46
  121. bloqade/squin/op/stdlib.py +0 -62
  122. bloqade/squin/op/stmts.py +0 -300
  123. bloqade/squin/op/traits.py +0 -43
  124. bloqade/squin/op/types.py +0 -128
  125. bloqade/squin/parallel.py +0 -200
  126. bloqade/squin/qubit.py +0 -194
  127. bloqade/squin/rewrite/canonicalize.py +0 -60
  128. bloqade/squin/rewrite/desugar.py +0 -102
  129. bloqade/squin/stdlib/channel.py +0 -86
  130. bloqade/squin/stdlib/gate.py +0 -201
  131. bloqade/squin/types.py +0 -8
  132. bloqade/squin/wire.py +0 -201
  133. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  134. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  135. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/WHEEL +0 -0
  136. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,18 @@
1
1
  from typing import Sequence, final
2
2
  from dataclasses import dataclass
3
3
 
4
+ from kirin import ir, types
4
5
  from kirin.lattice import (
5
6
  SingletonMeta,
6
7
  BoundedLattice,
7
8
  SimpleJoinMixin,
8
9
  SimpleMeetMixin,
9
10
  )
11
+ from kirin.analysis import const
12
+ from kirin.dialects import ilist
13
+ from kirin.ir.attrs.abc import LatticeAttributeMeta
14
+
15
+ from bloqade.types import QubitType
10
16
 
11
17
 
12
18
  @dataclass
@@ -18,54 +24,94 @@ class Address(
18
24
 
19
25
  @classmethod
20
26
  def bottom(cls) -> "Address":
21
- return NotQubit()
27
+ return Bottom()
22
28
 
23
29
  @classmethod
24
30
  def top(cls) -> "Address":
25
- return AnyAddress()
31
+ return Unknown()
32
+
33
+ @classmethod
34
+ def from_type(cls, typ: types.TypeAttribute):
35
+ if typ.is_subseteq(ilist.IListType[QubitType]):
36
+ return UnknownReg()
37
+ elif typ.is_subseteq(QubitType):
38
+ return UnknownQubit()
39
+ else:
40
+ return Unknown()
26
41
 
27
42
 
28
43
  @final
29
- @dataclass
30
- class NotQubit(Address, metaclass=SingletonMeta):
44
+ class Bottom(Address, metaclass=SingletonMeta):
45
+ """Error during interpretation"""
31
46
 
32
47
  def is_subseteq(self, other: Address) -> bool:
33
48
  return True
34
49
 
35
50
 
36
51
  @final
37
- @dataclass
38
- class AnyAddress(Address, metaclass=SingletonMeta):
52
+ class Unknown(Address, metaclass=SingletonMeta):
53
+ """Can't determine if it is an address or constant."""
39
54
 
40
55
  def is_subseteq(self, other: Address) -> bool:
41
- return isinstance(other, AnyAddress)
56
+ return isinstance(other, Unknown)
42
57
 
43
58
 
44
59
  @final
45
60
  @dataclass
46
- class AddressTuple(Address):
47
- data: tuple[Address, ...]
61
+ class ConstResult(Address):
62
+ """Stores a constant prop result in the lattice"""
63
+
64
+ result: const.Result
48
65
 
49
66
  def is_subseteq(self, other: Address) -> bool:
50
- if isinstance(other, AddressTuple):
51
- return all(a.is_subseteq(b) for a, b in zip(self.data, other.data))
52
- return False
67
+ return isinstance(other, ConstResult) and self.result.is_subseteq(other.result)
68
+
69
+
70
+ class QubitLike(Address):
71
+ def join(self, other: Address):
72
+ if isinstance(other, QubitLike):
73
+ return super().join(other)
74
+ return self.bottom()
75
+
76
+ def meet(self, other: Address):
77
+ if isinstance(other, QubitLike):
78
+ return super().meet(other)
79
+ return self.bottom()
53
80
 
54
81
 
55
82
  @final
56
- @dataclass
57
- class AddressReg(Address):
58
- data: Sequence[int]
83
+ class UnknownQubit(QubitLike, metaclass=SingletonMeta):
84
+ """A lattice element representing a single qubit with an unknown address."""
59
85
 
60
86
  def is_subseteq(self, other: Address) -> bool:
61
- if isinstance(other, AddressReg):
62
- return self.data == other.data
63
- return False
87
+ return isinstance(other, QubitLike)
88
+
89
+
90
+ class RegisterLike(Address):
91
+ def join(self, other: Address):
92
+ if isinstance(other, RegisterLike):
93
+ return super().join(other)
94
+ return self.bottom()
95
+
96
+ def meet(self, other: Address):
97
+ if isinstance(other, RegisterLike):
98
+ return super().meet(other)
99
+ return self.bottom()
100
+
101
+
102
+ @final
103
+ class UnknownReg(RegisterLike, metaclass=SingletonMeta):
104
+ """A lattice element representing a container of qubits with unknown indices."""
105
+
106
+ def is_subseteq(self, other: Address) -> bool:
107
+ return isinstance(other, RegisterLike)
64
108
 
65
109
 
66
110
  @final
67
111
  @dataclass
68
- class AddressQubit(Address):
112
+ class AddressQubit(QubitLike):
113
+ """A lattice element representing a single qubit with a known address."""
114
+
69
115
  data: int
70
116
 
71
117
  def is_subseteq(self, other: Address) -> bool:
@@ -76,10 +122,149 @@ class AddressQubit(Address):
76
122
 
77
123
  @final
78
124
  @dataclass
79
- class AddressWire(Address):
80
- origin_qubit: AddressQubit
125
+ class AddressReg(RegisterLike):
126
+ """A lattice element representing a container of qubits with known indices."""
127
+
128
+ data: Sequence[int]
81
129
 
82
130
  def is_subseteq(self, other: Address) -> bool:
83
- if isinstance(other, AddressWire):
84
- return self.origin_qubit == other.origin_qubit
85
- return False
131
+ return isinstance(other, AddressReg) and self.data == other.data
132
+
133
+ @property
134
+ def qubits(self) -> tuple[AddressQubit, ...]:
135
+ return tuple(AddressQubit(i) for i in self.data)
136
+
137
+
138
+ @final
139
+ @dataclass
140
+ class PartialLambda(Address):
141
+ """Represents a partially known lambda function"""
142
+
143
+ argnames: list[str]
144
+ code: ir.Statement
145
+ captured: tuple[Address, ...]
146
+
147
+ def join(self, other: Address) -> Address:
148
+ if other is other.bottom():
149
+ return self
150
+
151
+ if not isinstance(other, PartialLambda):
152
+ return self.top().join(other) # widen self
153
+
154
+ if self.code is not other.code:
155
+ return self.top() # lambda stmt is pure
156
+
157
+ if len(self.captured) != len(other.captured):
158
+ return self.bottom() # err
159
+
160
+ return PartialLambda(
161
+ self.argnames,
162
+ self.code,
163
+ tuple(x.join(y) for x, y in zip(self.captured, other.captured)),
164
+ )
165
+
166
+ def meet(self, other: Address) -> Address:
167
+ if not isinstance(other, PartialLambda):
168
+ return self.top().meet(other)
169
+
170
+ if self.code is not other.code:
171
+ return self.bottom()
172
+
173
+ if len(self.captured) != len(other.captured):
174
+ return self.top()
175
+
176
+ return PartialLambda(
177
+ self.argnames,
178
+ self.code,
179
+ tuple(x.meet(y) for x, y in zip(self.captured, other.captured)),
180
+ )
181
+
182
+ def is_subseteq(self, other: Address) -> bool:
183
+ return (
184
+ isinstance(other, PartialLambda)
185
+ and self.code is other.code
186
+ and self.argnames == other.argnames
187
+ and len(self.captured) == len(other.captured)
188
+ and all(
189
+ self_ele.is_subseteq(other_ele)
190
+ for self_ele, other_ele in zip(self.captured, other.captured)
191
+ )
192
+ )
193
+
194
+
195
+ @dataclass
196
+ class StaticContainer(Address):
197
+ """A lattice element representing the results of any static container, e. g. ilist or tuple."""
198
+
199
+ data: tuple[Address, ...]
200
+
201
+ @classmethod
202
+ def new(cls, data: tuple[Address, ...]):
203
+ return cls(data)
204
+
205
+ def join(self, other: "Address") -> "Address":
206
+ if isinstance(other, type(self)) and len(self.data) == len(other.data):
207
+ return self.new(tuple(x.join(y) for x, y in zip(self.data, other.data)))
208
+ return self.top()
209
+
210
+ def meet(self, other: "Address") -> "Address":
211
+ if isinstance(other, type(self)) and len(self.data) == len(other.data):
212
+ return self.new(tuple(x.meet(y) for x, y in zip(self.data, other.data)))
213
+ return self.bottom()
214
+
215
+ def is_subseteq(self, other: "Address") -> bool:
216
+ return (
217
+ isinstance(other, type(self))
218
+ and len(self.data) == len(other.data)
219
+ and all(x.is_subseteq(y) for x, y in zip(self.data, other.data))
220
+ )
221
+
222
+
223
+ class PartialIListMeta(LatticeAttributeMeta):
224
+ """This metaclass assures that PartialILists of ConstResults or AddressQubits are canonicalized
225
+ to a single ConstResult or AddressReg respectively.
226
+
227
+ because AddressReg is a specialization of PartialIList, being a container of pure qubit
228
+ addresses. For Operations that act in generic containers (e.g., ilist.ForEach),
229
+ AddressReg is treated as PartialIList but for other types of analysis it is often
230
+ useful to distinguish between a generic IList and a pure qubit address list.
231
+
232
+ Inside the method tables the `GetValuesMixin` implements a method that effectively
233
+ undoes this canonicalization.
234
+
235
+ """
236
+
237
+ def __call__(cls, data: tuple[Address, ...]):
238
+ # TODO: when constant prop has PartialIList, make sure to canonicalize here.
239
+ if types.is_tuple_of(data, ConstResult) and types.is_tuple_of(
240
+ all_constants := tuple(ele.result for ele in data), const.Value
241
+ ):
242
+ # all constants, create constant list
243
+ return ConstResult(
244
+ const.Value(ilist.IList([ele.data for ele in all_constants]))
245
+ )
246
+ elif types.is_tuple_of(data, AddressQubit):
247
+ # all qubits create qubit register
248
+ return AddressReg(tuple(ele.data for ele in data))
249
+ else:
250
+ return super().__call__(data)
251
+
252
+
253
+ @final
254
+ class PartialIList(StaticContainer, metaclass=PartialIListMeta):
255
+ """A lattice element representing a partially known ilist."""
256
+
257
+
258
+ class PartialTupleMeta(LatticeAttributeMeta):
259
+ """This metaclass assures that PartialTuples of ConstResults are canonicalized to a single ConstResult."""
260
+
261
+ def __call__(cls, data: tuple[Address, ...]):
262
+ if not types.is_tuple_of(data, ConstResult):
263
+ return super().__call__(data)
264
+
265
+ return ConstResult(const.PartialTuple(tuple(ele.result for ele in data)))
266
+
267
+
268
+ @final
269
+ class PartialTuple(StaticContainer, metaclass=PartialTupleMeta):
270
+ """A lattice element representing a partially known tuple."""
@@ -7,7 +7,7 @@ from kirin.analysis import Forward
7
7
  from kirin.interp.value import Successor
8
8
  from kirin.analysis.forward import ForwardFrame
9
9
 
10
- from ..address import AddressAnalysis
10
+ from ..address import Address, AddressAnalysis
11
11
 
12
12
 
13
13
  class FidelityAnalysis(Forward):
@@ -57,7 +57,7 @@ class FidelityAnalysis(Forward):
57
57
 
58
58
  _current_atom_survival_probability: list[float] = field(init=False)
59
59
 
60
- addr_frame: ForwardFrame = field(init=False)
60
+ addr_frame: ForwardFrame[Address] = field(init=False)
61
61
 
62
62
  def initialize(self):
63
63
  super().initialize()
@@ -2,7 +2,7 @@ from kirin import types as kirin_types, interp
2
2
  from kirin.analysis import const
3
3
  from kirin.dialects import py, scf, func, ilist
4
4
 
5
- from bloqade.squin import wire, qubit
5
+ from bloqade import qubit
6
6
 
7
7
  from .lattice import (
8
8
  AnyMeasureId,
@@ -21,22 +21,12 @@ from .analysis import MeasureIDFrame, MeasurementIDAnalysis
21
21
  @qubit.dialect.register(key="measure_id")
22
22
  class SquinQubit(interp.MethodTable):
23
23
 
24
- @interp.impl(qubit.MeasureQubit)
25
- def measure_qubit(
26
- self,
27
- interp: MeasurementIDAnalysis,
28
- frame: interp.Frame,
29
- stmt: qubit.MeasureQubit,
30
- ):
31
- interp.measure_count += 1
32
- return (MeasureIdBool(interp.measure_count),)
33
-
34
- @interp.impl(qubit.MeasureQubitList)
24
+ @interp.impl(qubit.stmts.Measure)
35
25
  def measure_qubit_list(
36
26
  self,
37
27
  interp: MeasurementIDAnalysis,
38
28
  frame: interp.Frame,
39
- stmt: qubit.MeasureQubitList,
29
+ stmt: qubit.stmts.Measure,
40
30
  ):
41
31
 
42
32
  # try to get the length of the list
@@ -56,20 +46,6 @@ class SquinQubit(interp.MethodTable):
56
46
  return (MeasureIdTuple(data=tuple(measure_id_bools)),)
57
47
 
58
48
 
59
- @wire.dialect.register(key="measure_id")
60
- class SquinWire(interp.MethodTable):
61
-
62
- @interp.impl(wire.Measure)
63
- def measure_qubit(
64
- self,
65
- interp: MeasurementIDAnalysis,
66
- frame: interp.Frame,
67
- stmt: wire.Measure,
68
- ):
69
- interp.measure_count += 1
70
- return (MeasureIdBool(interp.measure_count),)
71
-
72
-
73
49
  @ilist.dialect.register(key="measure_id")
74
50
  class IList(interp.MethodTable):
75
51
  @interp.impl(ilist.New)
@@ -1,4 +1,6 @@
1
- from . import noise as noise
1
+ from . import emit as emit, noise as noise, lowering as lowering
2
+ from .emit import emit_circuit as emit_circuit
3
+ from .lowering import load_circuit as load_circuit
2
4
  from .parallelize import (
3
5
  transpile as transpile,
4
6
  parallelize as parallelize,
@@ -0,0 +1,3 @@
1
+ # NOTE: just to register methods
2
+ from . import gate as gate, noise as noise, qubit as qubit
3
+ from .base import emit_circuit as emit_circuit
@@ -0,0 +1,243 @@
1
+ from typing import Sequence
2
+ from warnings import warn
3
+ from dataclasses import field, dataclass
4
+
5
+ import cirq
6
+ from kirin import ir, types, interp
7
+ from kirin.emit import EmitABC, EmitError, EmitFrame
8
+ from kirin.interp import MethodTable, impl
9
+ from kirin.dialects import py, func
10
+ from typing_extensions import Self
11
+
12
+ from bloqade.squin import kernel
13
+ from bloqade.rewrite.passes import AggressiveUnroll
14
+
15
+
16
+ def emit_circuit(
17
+ mt: ir.Method,
18
+ qubits: Sequence[cirq.Qid] | None = None,
19
+ circuit_qubits: Sequence[cirq.Qid] | None = None,
20
+ args: tuple = (),
21
+ ignore_returns: bool = False,
22
+ ) -> cirq.Circuit:
23
+ """Converts a squin.kernel method to a cirq.Circuit object.
24
+
25
+ Args:
26
+ mt (ir.Method): The kernel method from which to construct the circuit.
27
+
28
+ Keyword Args:
29
+ circuit_qubits (Sequence[cirq.Qid] | None):
30
+ A list of qubits to use as the qubits in the circuit. Defaults to None.
31
+ If this is None, then `cirq.LineQubit`s are inserted for every `squin.qalloc`
32
+ statement in the order they appear inside the kernel.
33
+ **Note**: If a list of qubits is provided, make sure that there is a sufficient
34
+ number of qubits for the resulting circuit.
35
+ args (tuple):
36
+ The arguments of the kernel function from which to emit a circuit.
37
+ ignore_returns (bool):
38
+ If `False`, emitting a circuit from a kernel that returns a value will error.
39
+ Set it to `True` in order to ignore the return value(s). Defaults to `False`.
40
+
41
+ ## Examples:
42
+
43
+ Here's a very basic example:
44
+
45
+ ```python
46
+ from bloqade import squin
47
+ from bloqade.cirq_utils import emit_circuit
48
+
49
+ @squin.kernel
50
+ def main():
51
+ q = squin.qalloc(2)
52
+ squin.h(q[0])
53
+ squin.cx(q[0], q[1])
54
+
55
+ circuit = emit_circuit(main)
56
+
57
+ print(circuit)
58
+ ```
59
+
60
+ You can also compose multiple kernels. Those are emitted as subcircuits within the "main" circuit.
61
+ Subkernels can accept arguments and return a value.
62
+
63
+ ```python
64
+ from bloqade import squin
65
+ from bloqade.cirq_utils import emit_circuit
66
+ from kirin.dialects import ilist
67
+ from typing import Literal
68
+ import cirq
69
+
70
+ @squin.kernel
71
+ def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):
72
+ squin.h(q[0])
73
+ squin.cx(q[0], q[1])
74
+
75
+ @squin.kernel
76
+ def main():
77
+ q = squin.qalloc(2)
78
+ q2 = squin.qalloc(3)
79
+ squin.cx(q[1], q2[2])
80
+
81
+
82
+ # custom list of qubits on grid
83
+ qubits = [cirq.GridQubit(i, i+1) for i in range(5)]
84
+
85
+ circuit = emit_circuit(main, circuit_qubits=qubits)
86
+ print(circuit)
87
+
88
+ ```
89
+
90
+ We also passed in a custom list of qubits above. This allows you to provide a custom geometry
91
+ and manipulate the qubits in other circuits directly written in cirq as well.
92
+ """
93
+
94
+ if circuit_qubits is None and qubits is not None:
95
+ circuit_qubits = qubits
96
+ warn(
97
+ "The keyword argument `qubits` is deprecated. Use `circuit_qubits` instead."
98
+ )
99
+
100
+ if (
101
+ not ignore_returns
102
+ and isinstance(mt.code, func.Function)
103
+ and not mt.code.signature.output.is_subseteq(types.NoneType)
104
+ ):
105
+ raise EmitError(
106
+ "The method you are trying to convert to a circuit has a return value, but returning from a circuit is not supported."
107
+ " Set `ignore_returns = True` in order to simply ignore the return values and emit a circuit."
108
+ )
109
+
110
+ if len(args) != len(mt.args):
111
+ raise ValueError(
112
+ f"The method from which you're trying to emit a circuit takes {len(mt.args)} as input, but you passed in {len(args)} via the `args` keyword!"
113
+ )
114
+
115
+ emitter = EmitCirq(qubits=circuit_qubits)
116
+
117
+ symbol_op_trait = mt.code.get_trait(ir.SymbolOpInterface)
118
+ if (symbol_op_trait := mt.code.get_trait(ir.SymbolOpInterface)) is None:
119
+ raise EmitError("The method is not a symbol, cannot emit circuit!")
120
+
121
+ sym_name = symbol_op_trait.get_sym_name(mt.code).unwrap()
122
+
123
+ if (signature_trait := mt.code.get_trait(ir.HasSignature)) is None:
124
+ raise EmitError(
125
+ f"The method {sym_name} does not have a signature, cannot emit circuit!"
126
+ )
127
+
128
+ signature = signature_trait.get_signature(mt.code)
129
+ new_signature = func.Signature(inputs=(), output=signature.output)
130
+
131
+ callable_region = mt.callable_region.clone()
132
+ entry_block = callable_region.blocks[0]
133
+ args_ssa = list(entry_block.args)
134
+ first_stmt = entry_block.first_stmt
135
+
136
+ assert first_stmt is not None, "Method has no statements!"
137
+ if len(args_ssa) - 1 != len(args):
138
+ raise EmitError(
139
+ f"The method {sym_name} takes {len(args_ssa) - 1} arguments, but you passed in {len(args)} via the `args` keyword!"
140
+ )
141
+
142
+ for arg, arg_ssa in zip(args, args_ssa[1:], strict=True):
143
+ (value := py.Constant(arg)).insert_before(first_stmt)
144
+ arg_ssa.replace_by(value.result)
145
+ entry_block.args.delete(arg_ssa)
146
+
147
+ new_func = func.Function(
148
+ sym_name=sym_name, body=callable_region, signature=new_signature
149
+ )
150
+ mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)
151
+
152
+ AggressiveUnroll(mt_.dialects).fixpoint(mt_)
153
+ return emitter.run(mt_, args=())
154
+
155
+
156
+ @dataclass
157
+ class EmitCirqFrame(EmitFrame):
158
+ qubit_index: int = 0
159
+ qubits: Sequence[cirq.Qid] | None = None
160
+ circuit: cirq.Circuit = field(default_factory=cirq.Circuit)
161
+
162
+
163
+ def _default_kernel():
164
+ return kernel
165
+
166
+
167
+ @dataclass
168
+ class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
169
+ keys = ["emit.cirq", "main"]
170
+ dialects: ir.DialectGroup = field(default_factory=_default_kernel)
171
+ void = cirq.Circuit()
172
+ qubits: Sequence[cirq.Qid] | None = None
173
+
174
+ def initialize(self) -> Self:
175
+ return super().initialize()
176
+
177
+ def initialize_frame(
178
+ self, code: ir.Statement, *, has_parent_access: bool = False
179
+ ) -> EmitCirqFrame:
180
+ return EmitCirqFrame(
181
+ code, has_parent_access=has_parent_access, qubits=self.qubits
182
+ )
183
+
184
+ def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
185
+ return self.run_callable(method.code, args)
186
+
187
+ def run_callable_region(
188
+ self,
189
+ frame: EmitCirqFrame,
190
+ code: ir.Statement,
191
+ region: ir.Region,
192
+ args: tuple,
193
+ ):
194
+ if len(region.blocks) > 0:
195
+ block_args = list(region.blocks[0].args)
196
+ # NOTE: skip self arg
197
+ frame.set_values(block_args[1:], args)
198
+
199
+ results = self.eval_stmt(frame, code)
200
+ if isinstance(results, tuple):
201
+ if len(results) == 0:
202
+ return self.void
203
+ elif len(results) == 1:
204
+ return results[0]
205
+ raise interp.InterpreterError(f"Unexpected results {results}")
206
+
207
+ def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
208
+ for stmt in block.stmts:
209
+ result = self.eval_stmt(frame, stmt)
210
+ if isinstance(result, tuple):
211
+ frame.set_values(stmt.results, result)
212
+
213
+ return frame.circuit
214
+
215
+
216
+ @func.dialect.register(key="emit.cirq")
217
+ class __FuncEmit(MethodTable):
218
+
219
+ @impl(func.Function)
220
+ def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
221
+ emit.run_ssacfg_region(frame, stmt.body, ())
222
+ return (frame.circuit,)
223
+
224
+ @impl(func.Invoke)
225
+ def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
226
+ raise EmitError(
227
+ "Function invokes should need to be inlined! "
228
+ "If you called the emit_circuit method, that should have happened, please report this issue."
229
+ )
230
+
231
+ @impl(func.Return)
232
+ def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return):
233
+ # NOTE: should only be hit if ignore_returns == True
234
+ return ()
235
+
236
+
237
+ @py.indexing.dialect.register(key="emit.cirq")
238
+ class __Concrete(interp.MethodTable):
239
+
240
+ @interp.impl(py.indexing.GetItem)
241
+ def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem):
242
+ # NOTE: no support for indexing into single statements in cirq
243
+ return ()