bloqade-circuit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (153)
  1. bloqade/analysis/__init__.py +0 -0
  2. bloqade/analysis/address/__init__.py +11 -0
  3. bloqade/analysis/address/analysis.py +60 -0
  4. bloqade/analysis/address/impls.py +228 -0
  5. bloqade/analysis/address/lattice.py +85 -0
  6. bloqade/noise/__init__.py +1 -0
  7. bloqade/noise/native/__init__.py +20 -0
  8. bloqade/noise/native/_dialect.py +3 -0
  9. bloqade/noise/native/_wrappers.py +34 -0
  10. bloqade/noise/native/model.py +347 -0
  11. bloqade/noise/native/rewrite.py +35 -0
  12. bloqade/noise/native/stmts.py +46 -0
  13. bloqade/pyqrack/__init__.py +18 -0
  14. bloqade/pyqrack/base.py +131 -0
  15. bloqade/pyqrack/noise/__init__.py +0 -0
  16. bloqade/pyqrack/noise/native.py +100 -0
  17. bloqade/pyqrack/qasm2/__init__.py +0 -0
  18. bloqade/pyqrack/qasm2/core.py +79 -0
  19. bloqade/pyqrack/qasm2/parallel.py +46 -0
  20. bloqade/pyqrack/qasm2/uop.py +247 -0
  21. bloqade/pyqrack/reg.py +109 -0
  22. bloqade/pyqrack/target.py +112 -0
  23. bloqade/qasm2/__init__.py +19 -0
  24. bloqade/qasm2/_wrappers.py +674 -0
  25. bloqade/qasm2/dialects/__init__.py +10 -0
  26. bloqade/qasm2/dialects/core/__init__.py +3 -0
  27. bloqade/qasm2/dialects/core/_dialect.py +3 -0
  28. bloqade/qasm2/dialects/core/_emit.py +68 -0
  29. bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
  30. bloqade/qasm2/dialects/core/address.py +38 -0
  31. bloqade/qasm2/dialects/core/stmts.py +94 -0
  32. bloqade/qasm2/dialects/expr/__init__.py +3 -0
  33. bloqade/qasm2/dialects/expr/_dialect.py +3 -0
  34. bloqade/qasm2/dialects/expr/_emit.py +103 -0
  35. bloqade/qasm2/dialects/expr/_from_python.py +86 -0
  36. bloqade/qasm2/dialects/expr/_interp.py +75 -0
  37. bloqade/qasm2/dialects/expr/stmts.py +262 -0
  38. bloqade/qasm2/dialects/glob.py +45 -0
  39. bloqade/qasm2/dialects/indexing.py +64 -0
  40. bloqade/qasm2/dialects/inline.py +76 -0
  41. bloqade/qasm2/dialects/noise.py +16 -0
  42. bloqade/qasm2/dialects/parallel.py +110 -0
  43. bloqade/qasm2/dialects/uop/__init__.py +4 -0
  44. bloqade/qasm2/dialects/uop/_dialect.py +3 -0
  45. bloqade/qasm2/dialects/uop/_emit.py +211 -0
  46. bloqade/qasm2/dialects/uop/schedule.py +89 -0
  47. bloqade/qasm2/dialects/uop/stmts.py +325 -0
  48. bloqade/qasm2/emit/__init__.py +1 -0
  49. bloqade/qasm2/emit/base.py +72 -0
  50. bloqade/qasm2/emit/gate.py +102 -0
  51. bloqade/qasm2/emit/main.py +106 -0
  52. bloqade/qasm2/emit/target.py +165 -0
  53. bloqade/qasm2/glob.py +24 -0
  54. bloqade/qasm2/groups.py +120 -0
  55. bloqade/qasm2/parallel.py +48 -0
  56. bloqade/qasm2/parse/__init__.py +37 -0
  57. bloqade/qasm2/parse/ast.py +235 -0
  58. bloqade/qasm2/parse/build.py +289 -0
  59. bloqade/qasm2/parse/lowering.py +553 -0
  60. bloqade/qasm2/parse/parser.py +5 -0
  61. bloqade/qasm2/parse/print.py +293 -0
  62. bloqade/qasm2/parse/qasm2.lark +75 -0
  63. bloqade/qasm2/parse/visitor.py +16 -0
  64. bloqade/qasm2/parse/visitor.pyi +39 -0
  65. bloqade/qasm2/passes/__init__.py +5 -0
  66. bloqade/qasm2/passes/fold.py +94 -0
  67. bloqade/qasm2/passes/glob.py +119 -0
  68. bloqade/qasm2/passes/noise.py +61 -0
  69. bloqade/qasm2/passes/parallel.py +176 -0
  70. bloqade/qasm2/passes/py2qasm.py +63 -0
  71. bloqade/qasm2/passes/qasm2py.py +61 -0
  72. bloqade/qasm2/rewrite/__init__.py +12 -0
  73. bloqade/qasm2/rewrite/desugar.py +28 -0
  74. bloqade/qasm2/rewrite/glob.py +103 -0
  75. bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
  76. bloqade/qasm2/rewrite/native_gates.py +447 -0
  77. bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
  78. bloqade/qasm2/rewrite/register.py +45 -0
  79. bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
  80. bloqade/qasm2/types.py +39 -0
  81. bloqade/qbraid/__init__.py +2 -0
  82. bloqade/qbraid/lowering.py +324 -0
  83. bloqade/qbraid/schema.py +252 -0
  84. bloqade/qbraid/simulation_result.py +99 -0
  85. bloqade/qbraid/target.py +86 -0
  86. bloqade/squin/__init__.py +2 -0
  87. bloqade/squin/analysis/__init__.py +0 -0
  88. bloqade/squin/analysis/nsites/__init__.py +8 -0
  89. bloqade/squin/analysis/nsites/analysis.py +52 -0
  90. bloqade/squin/analysis/nsites/impls.py +69 -0
  91. bloqade/squin/analysis/nsites/lattice.py +49 -0
  92. bloqade/squin/analysis/schedule.py +244 -0
  93. bloqade/squin/groups.py +38 -0
  94. bloqade/squin/op/__init__.py +132 -0
  95. bloqade/squin/op/_dialect.py +3 -0
  96. bloqade/squin/op/complex.py +6 -0
  97. bloqade/squin/op/stmts.py +220 -0
  98. bloqade/squin/op/traits.py +43 -0
  99. bloqade/squin/op/types.py +10 -0
  100. bloqade/squin/qubit.py +118 -0
  101. bloqade/squin/wire.py +103 -0
  102. bloqade/stim/__init__.py +6 -0
  103. bloqade/stim/_wrappers.py +186 -0
  104. bloqade/stim/dialects/__init__.py +5 -0
  105. bloqade/stim/dialects/aux/__init__.py +11 -0
  106. bloqade/stim/dialects/aux/_dialect.py +3 -0
  107. bloqade/stim/dialects/aux/emit.py +102 -0
  108. bloqade/stim/dialects/aux/interp.py +39 -0
  109. bloqade/stim/dialects/aux/lowering.py +40 -0
  110. bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
  111. bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
  112. bloqade/stim/dialects/aux/stmts/const.py +95 -0
  113. bloqade/stim/dialects/aux/types.py +19 -0
  114. bloqade/stim/dialects/collapse/__init__.py +3 -0
  115. bloqade/stim/dialects/collapse/_dialect.py +3 -0
  116. bloqade/stim/dialects/collapse/emit.py +68 -0
  117. bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
  118. bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
  119. bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
  120. bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
  121. bloqade/stim/dialects/gate/__init__.py +3 -0
  122. bloqade/stim/dialects/gate/_dialect.py +3 -0
  123. bloqade/stim/dialects/gate/emit.py +87 -0
  124. bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
  125. bloqade/stim/dialects/gate/stmts/base.py +31 -0
  126. bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
  127. bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
  128. bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
  129. bloqade/stim/dialects/gate/stmts/pp.py +15 -0
  130. bloqade/stim/dialects/noise/__init__.py +3 -0
  131. bloqade/stim/dialects/noise/_dialect.py +3 -0
  132. bloqade/stim/dialects/noise/emit.py +66 -0
  133. bloqade/stim/dialects/noise/stmts.py +77 -0
  134. bloqade/stim/emit/__init__.py +1 -0
  135. bloqade/stim/emit/stim.py +54 -0
  136. bloqade/stim/groups.py +26 -0
  137. bloqade/test_utils.py +35 -0
  138. bloqade/types.py +24 -0
  139. bloqade/visual/__init__.py +1 -0
  140. bloqade/visual/animation/__init__.py +0 -0
  141. bloqade/visual/animation/animate.py +267 -0
  142. bloqade/visual/animation/base.py +346 -0
  143. bloqade/visual/animation/gate_event.py +24 -0
  144. bloqade/visual/animation/runtime/__init__.py +0 -0
  145. bloqade/visual/animation/runtime/aod.py +36 -0
  146. bloqade/visual/animation/runtime/atoms.py +55 -0
  147. bloqade/visual/animation/runtime/ppoly.py +50 -0
  148. bloqade/visual/animation/runtime/qpustate.py +119 -0
  149. bloqade/visual/animation/runtime/utils.py +43 -0
  150. bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
  151. bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
  152. bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
  153. bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,83 @@
1
+ from typing import Dict, List, Optional
2
+ from dataclasses import dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.rewrite import abc, result
6
+
7
+ from bloqade.analysis import address
8
+ from bloqade.qasm2.dialects import uop, parallel
9
+
10
+
11
@dataclass
class ParallelToUOpRule(abc.RewriteRule):
    """Expand parallel-dialect gates into sequences of single `uop` gates.

    A parallel statement is rewritten only when every qubit list it uses can be
    resolved, through `address_analysis`, to an `AddressTuple` of `AddressQubit`
    entries whose ids are all present in `id_map`. Otherwise the statement is
    left untouched.
    """

    # maps a qubit address id (`AddressQubit.data`) to the SSA value of that qubit
    id_map: Dict[int, ir.SSAValue]
    # address-analysis results keyed by SSA value
    address_analysis: Dict[ir.SSAValue, address.Address]

    def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
        """Dispatch parallel-dialect statements to `rewrite_<name>`; ignore others."""
        if type(node) in parallel.dialect.stmts:
            return getattr(self, f"rewrite_{node.name}")(node)

        return result.RewriteResult()

    def get_qubit_ssa(self, ilist_ref: ir.SSAValue) -> Optional[List[ir.SSAValue]]:
        """Resolve a qubit-list SSA value to the SSA values of its individual qubits.

        Returns:
            The per-qubit SSA values in list order, or None when the address of
            `ilist_ref` is not an `AddressTuple`, contains a non-qubit element,
            or refers to a qubit id that has no entry in `id_map`.
        """
        addr = self.address_analysis.get(ilist_ref)
        if not isinstance(addr, address.AddressTuple):
            return None

        qubit_ssas: List[ir.SSAValue] = []
        for ele in addr.data:
            if not isinstance(ele, address.AddressQubit):
                return None

            qubit_ssa = self.id_map.get(ele.data)
            if qubit_ssa is None:
                # unknown qubit id: decline the rewrite instead of raising KeyError
                return None
            qubit_ssas.append(qubit_ssa)

        return qubit_ssas

    def rewrite_cz(self, node: ir.Statement):
        """Replace a `parallel.CZ` by one `uop.CZ` per (ctrl, qarg) pair, in order."""
        assert isinstance(node, parallel.CZ)

        ctrls = self.get_qubit_ssa(node.ctrls)
        qargs = self.get_qubit_ssa(node.qargs)

        if ctrls is None or qargs is None or len(ctrls) != len(qargs):
            # unresolved addresses, or mismatched lists that `zip` would silently
            # truncate (dropping gates): leave the statement as it is
            return result.RewriteResult()

        for ctrl, qarg in zip(ctrls, qargs):
            uop.CZ(ctrl, qarg).insert_before(node)

        node.delete()

        return result.RewriteResult(has_done_something=True)

    def rewrite_u(self, node: ir.Statement):
        """Replace a `parallel.UGate` by one `uop.UGate` per qubit, in `qargs` order."""
        assert isinstance(node, parallel.UGate)

        qargs = self.get_qubit_ssa(node.qargs)

        if qargs is None:
            return result.RewriteResult()

        for qarg in qargs:
            # insert_before (not insert_after) keeps the emitted gates in the same
            # order as `qargs`, consistent with `rewrite_cz`
            new_node = uop.UGate(qarg, theta=node.theta, phi=node.phi, lam=node.lam)
            new_node.insert_before(node)

        node.delete()

        return result.RewriteResult(has_done_something=True)

    def rewrite_rz(self, node: ir.Statement):
        """Replace a `parallel.RZ` by one `uop.RZ` per qubit, in `qargs` order."""
        assert isinstance(node, parallel.RZ)

        qargs = self.get_qubit_ssa(node.qargs)

        if qargs is None:
            return result.RewriteResult()

        for qarg in qargs:
            # insert_before keeps the emitted gates in `qargs` order
            uop.RZ(qarg, theta=node.theta).insert_before(node)

        node.delete()

        return result.RewriteResult(has_done_something=True)
@@ -0,0 +1,45 @@
1
+ from kirin import ir
2
+ from kirin.dialects import py
3
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
4
+
5
+ from bloqade.qasm2.dialects import core
6
+
7
+
8
+ class RaiseRegisterRule(RewriteRule):
9
+ """This rule puts all registers at the top of the block.
10
+
11
+ This is required for the UOpToParallel rules to work correctly
12
+ to handle cases where a register is defined in between two statements
13
+ that can be parallelized.
14
+
15
+ """
16
+
17
+ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
18
+ if not isinstance(node, core.QRegNew):
19
+ return RewriteResult()
20
+
21
+ if node.parent_block is None or node.parent_block.first_stmt is None:
22
+ return RewriteResult()
23
+
24
+ first_stmt = node.parent_block.first_stmt
25
+
26
+ n_qubits_ref = node.n_qubits
27
+
28
+ n_qubits = n_qubits_ref.owner
29
+ if isinstance(n_qubits, py.Constant):
30
+ # case where the n_qubits comes from a constant
31
+ new_n_qubits = n_qubits.from_stmt(n_qubits)
32
+ new_n_qubits.insert_before(first_stmt)
33
+ new_n_qubits_ref = new_n_qubits.result
34
+
35
+ elif isinstance(n_qubits, ir.BlockArgument):
36
+ # case where the n_qubits comes from a block argument
37
+ new_n_qubits_ref = n_qubits
38
+ else:
39
+ return RewriteResult()
40
+
41
+ new_qreg_stmt = core.QRegNew(n_qubits=new_n_qubits_ref)
42
+ new_qreg_stmt.insert_before(first_stmt)
43
+ node.result.replace_by(new_qreg_stmt.result)
44
+ node.delete()
45
+ return RewriteResult(has_done_something=True)
@@ -0,0 +1,395 @@
1
+ import abc
2
+ from typing import Dict, List, Tuple, Iterable
3
+ from dataclasses import field, dataclass
4
+
5
+ from kirin import ir
6
+ from kirin.rewrite import abc as rewrite_abc
7
+ from kirin.dialects import py, ilist
8
+ from kirin.analysis.const import lattice
9
+
10
+ from bloqade.analysis import address
11
+ from bloqade.qasm2.dialects import uop, core, parallel
12
+ from bloqade.squin.analysis.schedule import StmtDag
13
+
14
+
15
class MergePolicyABC(abc.ABC):
    """Interface for policies that group gate statements and merge each group."""

    @abc.abstractmethod
    def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
        """Apply the policy to a single statement and report what was rewritten."""

    @classmethod
    @abc.abstractmethod
    def can_merge(cls, stmt1: ir.Statement, stmt2: ir.Statement) -> bool:
        """Tell whether `stmt1` and `stmt2` may belong to the same merge group."""

    @classmethod
    @abc.abstractmethod
    def merge_gates(
        cls, gate_stmts: Iterable[ir.Statement]
    ) -> List[List[ir.Statement]]:
        """Partition `gate_stmts` into groups of mutually mergeable statements."""

    @classmethod
    @abc.abstractmethod
    def from_analysis(
        cls, dag: StmtDag, address_analysis: Dict[ir.SSAValue, address.Address]
    ) -> "MergePolicyABC":
        """Build a policy instance from a statement DAG and address analysis."""
38
+
39
+
40
+ @dataclass
41
+ class SimpleMergePolicy(MergePolicyABC):
42
+ """General merge policy for merging gates based on their type and arguments.
43
+
44
+ Base class to implement a merge policy for CZ, U and RZ gates, To completed the policy implement the
45
+ `merge_gates` class method. This will take an iterable of statements and return a list
46
+ of groups of statements that can be merged together. There are two mix-in classes
47
+ that can be used to implement the `merge_gates` method. The `GreedyMixin` will merge
48
+ gates together greedily, while the `OptimalMixIn` will merge gates together optimally.
49
+
50
+ """
51
+
52
+ address_analysis: Dict[ir.SSAValue, address.Address]
53
+ """Mapping from SSA values to their address analysis results. Needed for rewrites"""
54
+ merge_groups: List[List[ir.Statement]]
55
+ """List of groups of statements that can be merged together"""
56
+ group_numbers: Dict[ir.Statement, int]
57
+ """Mapping from statements to their group number"""
58
+ group_has_merged: Dict[int, bool] = field(default_factory=dict)
59
+ """Mapping from group number to whether the group has been merged"""
60
+
61
+ @staticmethod
62
+ def same_id_checker(ssa1: ir.SSAValue, ssa2: ir.SSAValue):
63
+ if ssa1 is ssa2:
64
+ return True
65
+ elif (hint1 := ssa1.hints.get("const")) and (hint2 := ssa2.hints.get("const")):
66
+ assert isinstance(hint1, lattice.Result) and isinstance(
67
+ hint2, lattice.Result
68
+ )
69
+ return hint1.is_equal(hint2)
70
+ else:
71
+ return False
72
+
73
+ @classmethod
74
+ def check_equiv_args(
75
+ cls,
76
+ args1: Iterable[ir.SSAValue],
77
+ args2: Iterable[ir.SSAValue],
78
+ ):
79
+ try:
80
+ return all(
81
+ cls.same_id_checker(ssa1, ssa2)
82
+ for ssa1, ssa2 in zip(args1, args2, strict=True)
83
+ )
84
+ except ValueError:
85
+ return False
86
+
87
+ @classmethod
88
+ def can_merge(cls, stmt1: ir.Statement, stmt2: ir.Statement) -> bool:
89
+ match stmt1, stmt2:
90
+ case (
91
+ (uop.UGate(), uop.UGate())
92
+ | (uop.RZ(), uop.RZ())
93
+ | (parallel.UGate(), parallel.UGate())
94
+ | (parallel.UGate(), uop.UGate())
95
+ | (uop.UGate(), parallel.UGate())
96
+ | (uop.UGate(), parallel.UGate())
97
+ | (uop.UGate(), parallel.UGate())
98
+ | (parallel.RZ(), parallel.RZ())
99
+ | (uop.RZ(), parallel.RZ())
100
+ | (parallel.RZ(), uop.RZ())
101
+ ):
102
+ return cls.check_equiv_args(stmt1.args[1:], stmt2.args[1:])
103
+ case (
104
+ (parallel.CZ(), parallel.CZ())
105
+ | (parallel.CZ(), uop.CZ())
106
+ | (uop.CZ(), parallel.CZ())
107
+ | (uop.CZ(), uop.CZ())
108
+ | (uop.Barrier(), uop.Barrier())
109
+ ):
110
+ return True
111
+
112
+ case _:
113
+ return False
114
+
115
+ @classmethod
116
+ def from_analysis(
117
+ cls,
118
+ dag: StmtDag,
119
+ address_analysis: Dict[ir.SSAValue, address.Address],
120
+ ):
121
+
122
+ merge_groups = []
123
+ group_numbers = {}
124
+
125
+ for group in dag.topological_groups():
126
+ gate_groups = cls.merge_gates(map(dag.stmts.__getitem__, group))
127
+ gate_groups_iter = (group for group in gate_groups if len(group) > 1)
128
+
129
+ for gate_group in gate_groups_iter:
130
+ group_number = len(merge_groups)
131
+ merge_groups.append(gate_group)
132
+ for stmt in gate_group:
133
+ group_numbers[stmt] = group_number
134
+
135
+ for group in merge_groups:
136
+ group.sort(key=lambda stmt: dag.stmt_index[stmt])
137
+
138
+ return cls(
139
+ address_analysis=address_analysis,
140
+ merge_groups=merge_groups,
141
+ group_numbers=group_numbers,
142
+ )
143
+
144
+ def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
145
+
146
+ if node not in self.group_numbers:
147
+ return rewrite_abc.RewriteResult()
148
+
149
+ group_number = self.group_numbers[node]
150
+ group = self.merge_groups[group_number]
151
+ if node is group[0]:
152
+ result = getattr(self, f"rewrite_group_{node.name}")(node, group)
153
+
154
+ self.group_has_merged[group_number] = result.has_done_something
155
+ return result
156
+
157
+ if self.group_has_merged[group_number]:
158
+ node.delete()
159
+
160
+ return rewrite_abc.RewriteResult(
161
+ has_done_something=self.group_has_merged[group_number]
162
+ )
163
+
164
+ def move_and_collect_qubit_list(
165
+ self, qargs: List[ir.SSAValue], node: ir.Statement
166
+ ) -> Tuple[ir.SSAValue, ...] | None:
167
+
168
+ qubits: List[ir.SSAValue] = []
169
+ # collect references to qubits
170
+ for qarg in qargs:
171
+ addr = self.address_analysis[qarg]
172
+
173
+ if isinstance(addr, address.AddressQubit):
174
+ qubits.append(qarg)
175
+
176
+ elif isinstance(addr, address.AddressTuple):
177
+ assert isinstance(qarg, ir.ResultValue)
178
+ assert isinstance(qarg.stmt, ilist.New)
179
+ qubits.extend(qarg.stmt.values)
180
+ else:
181
+ # give up if we cannot determine the address
182
+ return None
183
+
184
+ new_qubits = []
185
+
186
+ # the registers must be moved to the top of the block
187
+ # before this pass can be applied
188
+ for qubit_ref in qubits:
189
+ qubit = qubit_ref.owner
190
+ match qubit:
191
+ case ir.BlockArgument(): # do not need to move the qubit
192
+ new_qubits.append(qubit)
193
+ case core.QRegGet(reg=reg, idx=ir.BlockArgument() as idx):
194
+ new_qubit = core.QRegGet(reg=reg, idx=idx)
195
+ new_qubit.insert_before(node)
196
+ new_qubits.append(new_qubit.result)
197
+ case core.QRegGet(
198
+ reg=reg, idx=ir.ResultValue(stmt=py.Constant() as idx)
199
+ ):
200
+ (new_idx := idx.from_stmt(idx)).insert_before(node)
201
+ (
202
+ new_qubit := core.QRegGet(reg=reg, idx=new_idx.result)
203
+ ).insert_before(node)
204
+ new_qubits.append(new_qubit.result)
205
+ case _:
206
+ return None
207
+
208
+ return tuple(new_qubits)
209
+
210
+ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]):
211
+ ctrls = []
212
+ qargs = []
213
+
214
+ for stmt in group:
215
+ if isinstance(stmt, uop.CZ):
216
+ ctrls.append(stmt.ctrl)
217
+ qargs.append(stmt.qarg)
218
+ elif isinstance(stmt, parallel.CZ):
219
+ ctrls.append(stmt.ctrls)
220
+ qargs.append(stmt.qargs)
221
+ else:
222
+ return rewrite_abc.RewriteResult(has_done_something=False)
223
+
224
+ ctrls_values = self.move_and_collect_qubit_list(ctrls, node)
225
+ qargs_values = self.move_and_collect_qubit_list(qargs, node)
226
+
227
+ if ctrls_values is None or qargs_values is None:
228
+ # give up if we cannot determine the address or cannot move the qubits
229
+ return rewrite_abc.RewriteResult(has_done_something=False)
230
+
231
+ new_ctrls = ilist.New(values=ctrls_values)
232
+ new_qargs = ilist.New(values=qargs_values)
233
+ new_gate = parallel.CZ(ctrls=new_ctrls.result, qargs=new_qargs.result)
234
+
235
+ new_ctrls.insert_before(node)
236
+ new_qargs.insert_before(node)
237
+ new_gate.insert_before(node)
238
+
239
+ node.delete()
240
+
241
+ return rewrite_abc.RewriteResult(has_done_something=True)
242
+
243
+ def rewrite_group_U(self, node: ir.Statement, group: List[ir.Statement]):
244
+ return self.rewrite_group_u(node, group)
245
+
246
+ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]):
247
+ qargs = []
248
+
249
+ for stmt in group:
250
+ if isinstance(stmt, uop.UGate):
251
+ qargs.append(stmt.qarg)
252
+ elif isinstance(stmt, parallel.UGate):
253
+ qargs.append(stmt.qargs)
254
+ else:
255
+ return rewrite_abc.RewriteResult(has_done_something=False)
256
+
257
+ assert isinstance(node, (uop.UGate, parallel.UGate))
258
+ qargs_values = self.move_and_collect_qubit_list(qargs, node)
259
+
260
+ if qargs_values is None:
261
+ return rewrite_abc.RewriteResult(has_done_something=False)
262
+
263
+ new_qargs = ilist.New(values=qargs_values)
264
+ new_gate = parallel.UGate(
265
+ qargs=new_qargs.result,
266
+ theta=node.theta,
267
+ phi=node.phi,
268
+ lam=node.lam,
269
+ )
270
+ new_qargs.insert_before(node)
271
+ new_gate.insert_before(node)
272
+ node.delete()
273
+
274
+ return rewrite_abc.RewriteResult(has_done_something=True)
275
+
276
+ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
277
+ qargs = []
278
+
279
+ for stmt in group:
280
+ if isinstance(stmt, uop.RZ):
281
+ qargs.append(stmt.qarg)
282
+ elif isinstance(stmt, parallel.RZ):
283
+ qargs.append(stmt.qargs)
284
+ else:
285
+ return rewrite_abc.RewriteResult(has_done_something=False)
286
+
287
+ assert isinstance(node, (uop.RZ, parallel.RZ))
288
+
289
+ qargs_values = self.move_and_collect_qubit_list(qargs, node)
290
+
291
+ if qargs_values is None:
292
+ return rewrite_abc.RewriteResult(has_done_something=False)
293
+
294
+ new_qargs = ilist.New(values=qargs_values)
295
+ new_gate = parallel.RZ(
296
+ qargs=new_qargs.result,
297
+ theta=node.theta,
298
+ )
299
+ new_qargs.insert_before(node)
300
+ new_gate.insert_before(node)
301
+ node.delete()
302
+
303
+ return rewrite_abc.RewriteResult(has_done_something=True)
304
+
305
+ def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
306
+ qargs = []
307
+ for stmt in group:
308
+ qargs.extend(stmt.qargs)
309
+
310
+ qargs_values = self.move_and_collect_qubit_list(qargs, node)
311
+
312
+ if qargs_values is None:
313
+ return rewrite_abc.RewriteResult(has_done_something=False)
314
+
315
+ new_node = uop.Barrier(qargs=qargs_values)
316
+ new_node.insert_before(node)
317
+ node.delete()
318
+
319
+ return rewrite_abc.RewriteResult(has_done_something=True)
320
+
321
+
322
class GreedyMixin(MergePolicyABC):
    """Merge policy that greedily merges gates together.

    `merge_gates` walks the statements in order. It appends each statement to
    the most recent group when `can_merge` accepts it against that group's last
    statement; otherwise it starts a new group. It makes a single pass with one
    `can_merge` call per statement, so it is O(n) in the number of statements
    (given constant-cost `can_merge`).
    """

    @classmethod
    def merge_gates(
        cls, gate_stmts: Iterable[ir.Statement]
    ) -> List[List[ir.Statement]]:
        """Group consecutive mergeable statements; return [] for empty input."""
        iterable = iter(gate_stmts)
        first = next(iterable, None)
        if first is None:
            return []

        groups: List[List[ir.Statement]] = [[first]]

        # Continue the *same* iterator. Re-iterating `gate_stmts` would restart
        # from the first statement when it is a list, visiting it twice.
        for stmt in iterable:
            if cls.can_merge(groups[-1][-1], stmt):
                groups[-1].append(stmt)
            else:
                groups.append([stmt])

        return groups
345
+
346
+
347
class OptimalMixIn(MergePolicyABC):
    """Merge policy that tries every existing group before opening a new one.

    For each statement, `merge_gates` scans the groups built so far in creation
    order. It appends the statement to the first group whose last statement
    `can_merge` accepts. If no group accepts it, the statement starts a new
    group. In the worst case every statement is compared with every group, i.e.
    O(n^2) `can_merge` calls for n statements.

    """

    @classmethod
    def merge_gates(
        cls, gate_stmts: Iterable[ir.Statement]
    ) -> List[List[ir.Statement]]:
        merged: List[List[ir.Statement]] = []
        for gate in gate_stmts:
            for candidate in merged:
                if cls.can_merge(candidate[-1], gate):
                    candidate.append(gate)
                    break
            else:
                # no existing group accepted this gate
                merged.append([gate])

        return merged
375
+
376
+
377
@dataclass
class SimpleGreedyMergePolicy(GreedyMixin, SimpleMergePolicy):
    """`SimpleMergePolicy` whose `merge_gates` comes from `GreedyMixin`."""
380
+
381
+
382
@dataclass
class SimpleOptimalMergePolicy(OptimalMixIn, SimpleMergePolicy):
    """`SimpleMergePolicy` whose `merge_gates` comes from `OptimalMixIn`."""
385
+
386
+
387
@dataclass
class UOpToParallelRule(rewrite_abc.RewriteRule):
    """Delegate each statement to the merge policy registered for its block.

    Statements whose parent block has no registered policy are left untouched.
    """

    merge_rewriters: Dict[ir.Block | None, MergePolicyABC]

    def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
        block = node.parent_block
        if block not in self.merge_rewriters:
            return rewrite_abc.RewriteResult()

        policy = self.merge_rewriters[block]
        return policy(node)
bloqade/qasm2/types.py ADDED
@@ -0,0 +1,39 @@
1
+ from kirin import types
2
+
3
+ from bloqade.types import Qubit as Qubit, QubitType as QubitType
4
+
5
+
6
class Bit:
    """Runtime representation of a bit.

    Note:
        Base class for more specific bit types, for example a reference to an
        element of a classical register in some quantum register dialects.
    """
16
+
17
+
18
class QReg:
    """Runtime representation of a quantum register.

    Indexing is only meaningful inside a kernel; from plain Python it raises.
    """

    def __getitem__(self, index) -> Qubit:
        message = "cannot call __getitem__ outside of a kernel"
        raise NotImplementedError(message)
23
+
24
+
25
class CReg:
    """Runtime representation of a classical register.

    Indexing is only meaningful inside a kernel; from plain Python it raises.
    """

    def __getitem__(self, index) -> Bit:
        message = "cannot call __getitem__ outside of a kernel"
        raise NotImplementedError(message)
30
+
31
+
32
# Kirin types wrapping this module's runtime classes (`types.PyClass` around
# `Bit`, `QReg` and `CReg`).
BitType = types.PyClass(Bit)
"""Kirin type for a classical bit."""

QRegType = types.PyClass(QReg)
"""Kirin type for a quantum register."""

CRegType = types.PyClass(CReg)
"""Kirin type for a classical register."""
@@ -0,0 +1,2 @@
1
+ from .target import qBraid as qBraid
2
+ from .lowering import Lowering as Lowering