bloqade-circuit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (153)
  1. bloqade/analysis/__init__.py +0 -0
  2. bloqade/analysis/address/__init__.py +11 -0
  3. bloqade/analysis/address/analysis.py +60 -0
  4. bloqade/analysis/address/impls.py +228 -0
  5. bloqade/analysis/address/lattice.py +85 -0
  6. bloqade/noise/__init__.py +1 -0
  7. bloqade/noise/native/__init__.py +20 -0
  8. bloqade/noise/native/_dialect.py +3 -0
  9. bloqade/noise/native/_wrappers.py +34 -0
  10. bloqade/noise/native/model.py +347 -0
  11. bloqade/noise/native/rewrite.py +35 -0
  12. bloqade/noise/native/stmts.py +46 -0
  13. bloqade/pyqrack/__init__.py +18 -0
  14. bloqade/pyqrack/base.py +131 -0
  15. bloqade/pyqrack/noise/__init__.py +0 -0
  16. bloqade/pyqrack/noise/native.py +100 -0
  17. bloqade/pyqrack/qasm2/__init__.py +0 -0
  18. bloqade/pyqrack/qasm2/core.py +79 -0
  19. bloqade/pyqrack/qasm2/parallel.py +46 -0
  20. bloqade/pyqrack/qasm2/uop.py +247 -0
  21. bloqade/pyqrack/reg.py +109 -0
  22. bloqade/pyqrack/target.py +112 -0
  23. bloqade/qasm2/__init__.py +19 -0
  24. bloqade/qasm2/_wrappers.py +674 -0
  25. bloqade/qasm2/dialects/__init__.py +10 -0
  26. bloqade/qasm2/dialects/core/__init__.py +3 -0
  27. bloqade/qasm2/dialects/core/_dialect.py +3 -0
  28. bloqade/qasm2/dialects/core/_emit.py +68 -0
  29. bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
  30. bloqade/qasm2/dialects/core/address.py +38 -0
  31. bloqade/qasm2/dialects/core/stmts.py +94 -0
  32. bloqade/qasm2/dialects/expr/__init__.py +3 -0
  33. bloqade/qasm2/dialects/expr/_dialect.py +3 -0
  34. bloqade/qasm2/dialects/expr/_emit.py +103 -0
  35. bloqade/qasm2/dialects/expr/_from_python.py +86 -0
  36. bloqade/qasm2/dialects/expr/_interp.py +75 -0
  37. bloqade/qasm2/dialects/expr/stmts.py +262 -0
  38. bloqade/qasm2/dialects/glob.py +45 -0
  39. bloqade/qasm2/dialects/indexing.py +64 -0
  40. bloqade/qasm2/dialects/inline.py +76 -0
  41. bloqade/qasm2/dialects/noise.py +16 -0
  42. bloqade/qasm2/dialects/parallel.py +110 -0
  43. bloqade/qasm2/dialects/uop/__init__.py +4 -0
  44. bloqade/qasm2/dialects/uop/_dialect.py +3 -0
  45. bloqade/qasm2/dialects/uop/_emit.py +211 -0
  46. bloqade/qasm2/dialects/uop/schedule.py +89 -0
  47. bloqade/qasm2/dialects/uop/stmts.py +325 -0
  48. bloqade/qasm2/emit/__init__.py +1 -0
  49. bloqade/qasm2/emit/base.py +72 -0
  50. bloqade/qasm2/emit/gate.py +102 -0
  51. bloqade/qasm2/emit/main.py +106 -0
  52. bloqade/qasm2/emit/target.py +165 -0
  53. bloqade/qasm2/glob.py +24 -0
  54. bloqade/qasm2/groups.py +120 -0
  55. bloqade/qasm2/parallel.py +48 -0
  56. bloqade/qasm2/parse/__init__.py +37 -0
  57. bloqade/qasm2/parse/ast.py +235 -0
  58. bloqade/qasm2/parse/build.py +289 -0
  59. bloqade/qasm2/parse/lowering.py +553 -0
  60. bloqade/qasm2/parse/parser.py +5 -0
  61. bloqade/qasm2/parse/print.py +293 -0
  62. bloqade/qasm2/parse/qasm2.lark +75 -0
  63. bloqade/qasm2/parse/visitor.py +16 -0
  64. bloqade/qasm2/parse/visitor.pyi +39 -0
  65. bloqade/qasm2/passes/__init__.py +5 -0
  66. bloqade/qasm2/passes/fold.py +94 -0
  67. bloqade/qasm2/passes/glob.py +119 -0
  68. bloqade/qasm2/passes/noise.py +61 -0
  69. bloqade/qasm2/passes/parallel.py +176 -0
  70. bloqade/qasm2/passes/py2qasm.py +63 -0
  71. bloqade/qasm2/passes/qasm2py.py +61 -0
  72. bloqade/qasm2/rewrite/__init__.py +12 -0
  73. bloqade/qasm2/rewrite/desugar.py +28 -0
  74. bloqade/qasm2/rewrite/glob.py +103 -0
  75. bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
  76. bloqade/qasm2/rewrite/native_gates.py +447 -0
  77. bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
  78. bloqade/qasm2/rewrite/register.py +45 -0
  79. bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
  80. bloqade/qasm2/types.py +39 -0
  81. bloqade/qbraid/__init__.py +2 -0
  82. bloqade/qbraid/lowering.py +324 -0
  83. bloqade/qbraid/schema.py +252 -0
  84. bloqade/qbraid/simulation_result.py +99 -0
  85. bloqade/qbraid/target.py +86 -0
  86. bloqade/squin/__init__.py +2 -0
  87. bloqade/squin/analysis/__init__.py +0 -0
  88. bloqade/squin/analysis/nsites/__init__.py +8 -0
  89. bloqade/squin/analysis/nsites/analysis.py +52 -0
  90. bloqade/squin/analysis/nsites/impls.py +69 -0
  91. bloqade/squin/analysis/nsites/lattice.py +49 -0
  92. bloqade/squin/analysis/schedule.py +244 -0
  93. bloqade/squin/groups.py +38 -0
  94. bloqade/squin/op/__init__.py +132 -0
  95. bloqade/squin/op/_dialect.py +3 -0
  96. bloqade/squin/op/complex.py +6 -0
  97. bloqade/squin/op/stmts.py +220 -0
  98. bloqade/squin/op/traits.py +43 -0
  99. bloqade/squin/op/types.py +10 -0
  100. bloqade/squin/qubit.py +118 -0
  101. bloqade/squin/wire.py +103 -0
  102. bloqade/stim/__init__.py +6 -0
  103. bloqade/stim/_wrappers.py +186 -0
  104. bloqade/stim/dialects/__init__.py +5 -0
  105. bloqade/stim/dialects/aux/__init__.py +11 -0
  106. bloqade/stim/dialects/aux/_dialect.py +3 -0
  107. bloqade/stim/dialects/aux/emit.py +102 -0
  108. bloqade/stim/dialects/aux/interp.py +39 -0
  109. bloqade/stim/dialects/aux/lowering.py +40 -0
  110. bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
  111. bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
  112. bloqade/stim/dialects/aux/stmts/const.py +95 -0
  113. bloqade/stim/dialects/aux/types.py +19 -0
  114. bloqade/stim/dialects/collapse/__init__.py +3 -0
  115. bloqade/stim/dialects/collapse/_dialect.py +3 -0
  116. bloqade/stim/dialects/collapse/emit.py +68 -0
  117. bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
  118. bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
  119. bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
  120. bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
  121. bloqade/stim/dialects/gate/__init__.py +3 -0
  122. bloqade/stim/dialects/gate/_dialect.py +3 -0
  123. bloqade/stim/dialects/gate/emit.py +87 -0
  124. bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
  125. bloqade/stim/dialects/gate/stmts/base.py +31 -0
  126. bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
  127. bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
  128. bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
  129. bloqade/stim/dialects/gate/stmts/pp.py +15 -0
  130. bloqade/stim/dialects/noise/__init__.py +3 -0
  131. bloqade/stim/dialects/noise/_dialect.py +3 -0
  132. bloqade/stim/dialects/noise/emit.py +66 -0
  133. bloqade/stim/dialects/noise/stmts.py +77 -0
  134. bloqade/stim/emit/__init__.py +1 -0
  135. bloqade/stim/emit/stim.py +54 -0
  136. bloqade/stim/groups.py +26 -0
  137. bloqade/test_utils.py +35 -0
  138. bloqade/types.py +24 -0
  139. bloqade/visual/__init__.py +1 -0
  140. bloqade/visual/animation/__init__.py +0 -0
  141. bloqade/visual/animation/animate.py +267 -0
  142. bloqade/visual/animation/base.py +346 -0
  143. bloqade/visual/animation/gate_event.py +24 -0
  144. bloqade/visual/animation/runtime/__init__.py +0 -0
  145. bloqade/visual/animation/runtime/aod.py +36 -0
  146. bloqade/visual/animation/runtime/atoms.py +55 -0
  147. bloqade/visual/animation/runtime/ppoly.py +50 -0
  148. bloqade/visual/animation/runtime/qpustate.py +119 -0
  149. bloqade/visual/animation/runtime/utils.py +43 -0
  150. bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
  151. bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
  152. bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
  153. bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,293 @@
1
+ from dataclasses import field, dataclass
2
+
3
+ from kirin import print
4
+
5
+ from .ast import (
6
+ OPENQASM,
7
+ Pi,
8
+ Bit,
9
+ Cmp,
10
+ Call,
11
+ CReg,
12
+ Gate,
13
+ Name,
14
+ QReg,
15
+ BinOp,
16
+ Kirin,
17
+ Reset,
18
+ UGate,
19
+ CXGate,
20
+ IfStmt,
21
+ Number,
22
+ Opaque,
23
+ Barrier,
24
+ Comment,
25
+ Include,
26
+ Measure,
27
+ UnaryOp,
28
+ GlobUGate,
29
+ ParaCZGate,
30
+ ParaRZGate,
31
+ ParaU3Gate,
32
+ Instruction,
33
+ MainProgram,
34
+ NoisePAULI1,
35
+ ParallelQArgs,
36
+ )
37
+ from .visitor import Visitor
38
+
39
+
40
@dataclass
class ColorScheme:
    """Rich color names for each syntactic category of printed QASM2.

    The attribute names match the ``style=`` keys that ``Printer`` passes to
    ``plain_print`` ("comment", "keyword", "symbol", "string", "number").
    NOTE(review): ``irrational`` is not referenced in this file; confirm where
    it is consumed.
    """

    # OPENQASM / KIRIN headers and `//` comments
    comment: str = "bright_black"
    # reserved words and operators
    keyword: str = "red"
    # identifiers (gate, register and argument names)
    symbol: str = "cyan"
    # quoted strings such as include paths
    string: str = "yellow"
    # numeric literals, bit indices and `pi`
    number: str = "green"
    irrational: str = "magenta"
48
+
49
+
50
+ @dataclass
51
+ class PrintState:
52
+ indent: int = 0
53
+ result_width: int = 0
54
+ rich_style: str | None = None
55
+ rich_highlight: bool | None = False
56
+ indent_marks: list[int] = field(default_factory=list)
57
+ messages: list[str] = field(default_factory=list)
58
+
59
+
60
class Printer(print.Printer, Visitor[None]):
    """Render a QASM2 AST back to (optionally colorized) QASM2 source text.

    The output helpers (``plain_print``, ``print_seq``, ``print_newline``,
    ``print_indent``, ``indent``) come from kirin's ``print.Printer``. Dispatch
    to the ``visit_*`` methods below comes from ``Visitor``. ``style=`` values
    are category names such as "keyword", "symbol", "number", "string" and
    "comment".

    Statement visitors do not print their own leading indentation. Gate and
    ``if`` bodies rely on ``print_newline`` to emit it. NOTE(review): this
    assumes kirin's ``print_newline`` prints the current indent; confirm
    against the kirin version in use.
    """

    def visit_MainProgram(self, node: MainProgram) -> None:
        """Print the header and then each top-level statement on its own line."""
        self.print_indent()
        self.visit(node.header)
        self.print_newline()
        for stmt in node.statements:
            self.visit(stmt)
            self.print_newline()

    def visit_OPENQASM(self, node: OPENQASM) -> None:
        self.plain_print(
            f"OPENQASM {node.version.major}.{node.version.minor}",
            style="comment",
        )
        self.plain_print(";")

    def visit_Kirin(self, node: Kirin) -> None:
        # dialect names are sorted so the header text is deterministic
        self.plain_print(
            "KIRIN " + "{" + ",".join(sorted(node.dialects)) + "}", style="comment"
        )
        self.plain_print(";")

    def visit_Include(self, node: Include) -> None:
        self.plain_print("include", style="keyword")
        self.plain_print(" ")
        self.plain_print('"', node.filename, '"', style="string")
        self.plain_print(";")

    def visit_Barrier(self, node: Barrier) -> None:
        # No print_indent() here. Like every other statement, the indentation
        # was already emitted by the caller's print_newline(); indenting again
        # double-indents barriers inside gate bodies.
        self.plain_print("barrier", style="keyword")
        self.plain_print(" ")
        self.print_seq(node.qargs, emit=self.visit)
        self.plain_print(";")

    def visit_Instruction(self, node: Instruction) -> None:
        """Print a named gate application: ``name (params) qargs;``."""
        self.visit_Name(node.name)
        self.plain_print(" ")
        if node.params:
            self.print_seq(
                node.params, delim=", ", prefix="(", suffix=") ", emit=self.visit
            )
        self.print_seq(node.qargs, emit=self.visit)
        self.plain_print(";")

    def visit_Comment(self, node: Comment) -> None:
        self.plain_print("// ", node.text, style="comment")

    def visit_CReg(self, node: CReg) -> None:
        self.plain_print("creg", style="keyword")
        self.plain_print(f" {node.name}[{node.size}]")
        self.plain_print(";")

    def visit_QReg(self, node: QReg) -> None:
        self.plain_print("qreg", style="keyword")
        self.plain_print(f" {node.name}[{node.size}]")
        self.plain_print(";")

    def visit_CXGate(self, node: CXGate) -> None:
        self.plain_print("CX", style="keyword")
        self.plain_print(" ")
        self.visit(node.ctrl)
        self.plain_print(", ")
        self.visit(node.qarg)
        self.plain_print(";")

    def visit_UGate(self, node: UGate) -> None:
        self.plain_print("U", style="keyword")
        self.plain_print("(")
        self.visit(node.theta)
        self.plain_print(", ")
        self.visit(node.phi)
        self.plain_print(", ")
        self.visit(node.lam)
        self.plain_print(") ")
        self.visit(node.qarg)
        self.plain_print(";")

    def visit_Measure(self, node: Measure) -> None:
        self.plain_print("measure", style="keyword")
        self.plain_print(" ")
        self.visit(node.qarg)
        self.plain_print(" -> ")
        self.visit(node.carg)
        self.plain_print(";")

    def visit_Reset(self, node: Reset) -> None:
        # styled as a keyword, consistent with measure/barrier
        self.plain_print("reset", style="keyword")
        self.plain_print(" ")
        self.visit(node.qarg)
        self.plain_print(";")

    def visit_Opaque(self, node: Opaque) -> None:
        """Print ``opaque name(cparams) qparams;``.

        The grammar requires an identifier after ``opaque``, so the name must
        be printed for the output to parse again.
        """
        self.plain_print("opaque ", style="keyword")
        self.plain_print(node.name, style="symbol")
        if node.cparams:
            self.print_seq(
                node.cparams, delim=", ", prefix="(", suffix=")", emit=self.visit
            )

        if node.qparams:
            self.plain_print(" ")
            self.print_seq(node.qparams, delim=", ", emit=self.visit)
        self.plain_print(";")

    def visit_Gate(self, node: Gate) -> None:
        """Print a gate definition with its body indented one level."""
        self.plain_print("gate ", style="keyword")
        self.plain_print(node.name, style="symbol")
        # cparams/qparams are emitted with plain_print (not visit); presumably
        # they are plain identifier strings
        if node.cparams:
            self.print_seq(
                node.cparams, delim=", ", prefix="(", suffix=")", emit=self.plain_print
            )

        if node.qparams:
            self.plain_print(" ")
            self.print_seq(node.qparams, delim=", ", emit=self.plain_print)

        self.plain_print(" {")
        self._print_block_body(node.body)

    def visit_IfStmt(self, node: IfStmt) -> None:
        """Print ``if (cond) stmt`` inline, or a braced block for other body sizes."""
        self.plain_print("if", style="keyword")
        self.visit(node.cond)
        if len(node.body) == 1:  # inline if
            self.visit(node.body[0])
        else:
            self.plain_print("{")
            self._print_block_body(node.body)

    def _print_block_body(self, body) -> None:
        """Print ``body`` one statement per line, indented, then the closing ``}``."""
        with self.indent():
            self.print_newline()
            for idx, stmt in enumerate(body):
                self.visit(stmt)
                if idx < len(body) - 1:
                    self.print_newline()
        self.print_newline()
        self.plain_print("}")

    def visit_Cmp(self, node: Cmp) -> None:
        # surrounding spaces separate the condition from `if` and the body
        self.plain_print(" (")
        self.visit(node.lhs)
        self.plain_print(" == ", style="keyword")
        self.visit(node.rhs)
        self.plain_print(") ")

    def visit_Call(self, node: Call) -> None:
        self.plain_print(node.name)
        self.print_seq(node.args, delim=", ", prefix="(", suffix=")", emit=self.visit)

    def visit_BinOp(self, node: BinOp) -> None:
        # always parenthesized so precedence never has to be reconstructed
        self.plain_print("(")
        self.visit(node.lhs)
        self.plain_print(f" {node.op} ", style="keyword")
        self.visit(node.rhs)
        self.plain_print(")")

    def visit_UnaryOp(self, node: UnaryOp) -> None:
        self.plain_print(f"{node.op}", style="keyword")
        self.visit(node.operand)

    def visit_Bit(self, node: Bit) -> None:
        """Print ``name`` or ``name[addr]`` when an address is present."""
        self.visit_Name(node.name)
        if node.addr is not None:
            self.plain_print("[")
            self.plain_print(node.addr, style="number")
            self.plain_print("]")

    def visit_Number(self, node: Number) -> None:
        # styled like bit addresses and `pi`
        self.plain_print(node.value, style="number")

    def visit_Pi(self, node: Pi) -> None:
        self.plain_print("pi", style="number")

    def visit_Name(self, node: Name) -> None:
        self.plain_print(node.id, style="symbol")

    def visit_ParallelQArgs(self, node: ParallelQArgs) -> None:
        """Print ``{`` then one indented ``;``-terminated qarg group per line, then ``}``."""
        self.plain_print("{")
        with self.indent():
            for qargs in node.qargs:
                self.print_newline()
                self.print_seq(qargs, emit=self.visit)
                self.plain_print(";")
        self.print_newline()
        self.plain_print("}")

    def visit_ParaU3Gate(self, node: ParaU3Gate) -> None:
        self.plain_print("parallel.U", style="keyword")
        self.plain_print("(")
        self.visit(node.theta)
        self.plain_print(", ")
        self.visit(node.phi)
        self.plain_print(", ")
        self.visit(node.lam)
        self.plain_print(") ")
        self.visit_ParallelQArgs(node.qargs)

    def visit_ParaCZGate(self, node: ParaCZGate) -> None:
        self.plain_print("parallel.CZ ", style="keyword")
        self.visit_ParallelQArgs(node.qargs)

    def visit_ParaRZGate(self, node: ParaRZGate) -> None:
        self.plain_print("parallel.RZ", style="keyword")
        self.plain_print("(")
        self.visit(node.theta)
        self.plain_print(") ")
        self.visit_ParallelQArgs(node.qargs)

    def visit_GlobUGate(self, node: GlobUGate) -> None:
        self.plain_print("glob.U", style="keyword")
        self.plain_print("(")
        self.visit(node.theta)
        self.plain_print(", ")
        self.visit(node.phi)
        self.plain_print(", ")
        self.visit(node.lam)
        self.plain_print(") ")
        self.print_seq(node.registers, prefix="{", suffix="}", emit=self.visit)

    def visit_NoisePAULI1(self, node: NoisePAULI1) -> None:
        self.plain_print("noise.PAULI1", style="keyword")
        self.plain_print("(")
        self.visit(node.px)
        self.plain_print(", ")
        self.visit(node.py)
        self.plain_print(", ")
        self.visit(node.pz)
        self.plain_print(") ")
        self.visit(node.qarg)
        self.plain_print(";")
@@ -0,0 +1,75 @@
1
// Program: a header terminated by ";", followed by the statements.
mainprogram: header ";" [statement*]
// Header is either `OPENQASM <INT>.<INT>` or `KIRIN {dialect, ...}`, where each
// dialect is a dotted identifier.
?header: openqasm | kirin
openqasm: "OPENQASM" version
kirin: "KIRIN" "{" dialect ("," dialect)* "}"
dialect: IDENTIFIER ("." IDENTIFIER)*
version: INT "." INT
?statement: regdecl | gate | opaque | qop | ifstmt | barrier | include
?regdecl: qreg | creg
// stmts
include: "include" STRING ";"
// The condition only supports an equality test between two expressions.
ifstmt: "if" "(" expr "==" expr ")" ifbody
ifbody: qop | "{" qop* "}" // allow multiple qops
// Opaque gate declaration: name, optional parenthesised params, qubit arguments.
opaque: "opaque" IDENTIFIER ["(" [params] ")"] qubits ";"
barrier: "barrier" qubits ";"
qreg: "qreg" IDENTIFIER "[" INT "]" ";"
creg: "creg" IDENTIFIER "[" INT "]" ";"
// Gate definition; the body may contain only unitary ops and barriers.
gate: "gate" IDENTIFIER ["(" cparams? ")"] qparams "{" (uop | barrier)* "}"
cparams: IDENTIFIER ("," IDENTIFIER)*
qparams: IDENTIFIER ("," IDENTIFIER)*
20
+
21
// quantum ops
// A qop may appear as a top-level statement or inside an `if` body.
?qop: uop | measure | extension | reset
// `reset <bit>;`
reset: "reset" bit ";"
// `measure <bit> -> <bit>;`
measure: "measure" bit "->" bit ";"

// Unitary ops: a named gate application, the builtin U, or the builtin CX.
?uop: inst | ugate | cx_gate
// Named gate application: name, optional parenthesised params, qubit list.
inst: IDENTIFIER ["(" [params] ")"] qubits ";"
ugate: "U" "(" expr "," expr "," expr ")" bit ";"
cx_gate: "CX" bit "," bit ";"
params: expr ("," expr)*
31
+
32
// Extension statements (glob, noise and parallel) beyond the core OpenQASM 2.0 ops above.
?extension: parallel | glob | noise
?glob: glob_u_gate
// glob.U(theta, phi, lam) {reg, ...}: registers are named by bare identifier.
glob_u_gate: "glob" "." "U" "(" expr "," expr "," expr ")" global_body
global_body: "{" IDENTIFIER ("," IDENTIFIER)* "}"
?noise: noise_pauli1
// noise.PAULI1(px, py, pz) <bit>;
noise_pauli1: "noise" "." "PAULI1" "(" expr "," expr "," expr ")" bit ";"
?parallel: para_u_gate | para_rz_gate | para_cz_gate
para_u_gate: "parallel" "." "U" "(" expr "," expr "," expr ")" parallel_body
para_rz_gate: "parallel" "." "RZ" "(" expr ")" parallel_body
para_cz_gate: "parallel" "." "CZ" parallel_body
// Body of a parallel gate: zero or more ";"-terminated groups of bits.
parallel_body: "{" task_args* "}"
task_args: bit ("," bit)* ";"

// A bit is an identifier with an optional integer index. The `?` prefix inlines
// the un-indexed case, so it appears as the bare IDENTIFIER token.
?bit: IDENTIFIER ("[" INT "]")?
qubits: bit ("," bit)*
47
+
48
// Arithmetic expressions, from loosest to tightest binding:
//   binary + -   <   * /   <   unary + -   <   ^   <   call / atom
// `^` takes a `factor` on its right, so it is right-associative and accepts a
// signed exponent.
?expr: term ((PLUS|MINUS) term)*
?term: factor ((TIMES|DIVIDE) factor)*
?factor: (PLUS|MINUS) factor | power
?power: molecule (POW factor)?
// Function call: a molecule followed by a parenthesised, possibly empty argument list.
?molecule: molecule "(" [arglist] ")" -> call
         | atom
?atom: "(" expr ")" | IDENTIFIER | NUMBER | PI
arglist: argument ("," argument)*
?argument: expr
57
+
58
// NOTE: VERSION is not referenced by any rule (`version` is built from INT).
VERSION: "2.0"
PI: "pi"
PLUS: "+"
MINUS: "-"
TIMES: "*"
DIVIDE: "/"
POW: "^"
// Line comment: "//" up to, but not including, the end of the line. The newline
// itself is consumed by the ignored WS terminal. Not requiring a NEWLINE here
// lets a comment on the last line of a file without a trailing newline lex.
// The old form failed on such a comment: "/" was lexed as DIVIDE and parsing broke.
COMMENT: "//" /[^\n]*/
// NOTE: no longer referenced by COMMENT; kept in case other code relies on it.
NEWLINE: "\n"
67
+
68
+ %import common.INT
69
+ %import common.FLOAT
70
+ %import common.CNAME -> IDENTIFIER
71
+ %import common.NUMBER
72
+ %import common.ESCAPED_STRING -> STRING
73
+ %import common.WS
74
+ %ignore WS
75
+ %ignore COMMENT
@@ -0,0 +1,16 @@
1
+ from typing import Generic, TypeVar
2
+
3
+ from . import ast
4
+
5
T = TypeVar("T")


class Visitor(Generic[T]):
    """Base visitor over QASM2 AST nodes that dispatches on the node's class name.

    A node whose class is named ``Foo`` is handled by ``self.visit_Foo(node)``.
    When no such method exists, :meth:`generic_visit` is called instead.
    """

    def visit(self, node: ast.Node) -> T:
        """Route ``node`` to its ``visit_<ClassName>`` handler and return the result."""
        handler_name = "visit_" + node.__class__.__name__
        handler = getattr(self, handler_name, self.generic_visit)
        return handler(node)

    def generic_visit(self, node: ast.Node) -> T:
        """Fallback for node classes that have no dedicated handler.

        Raises:
            NotImplementedError: always, naming the missing ``visit_*`` method.
        """
        missing = "visit_" + node.__class__.__name__
        raise NotImplementedError(f"No {missing} method")
@@ -0,0 +1,39 @@
1
+ from typing import Generic, TypeVar
2
+
3
+ from . import ast
4
+
5
T = TypeVar("T")

class Visitor(Generic[T]):
    """Type stub for the class-name-dispatching QASM2 AST visitor."""

    def visit(self, node: ast.Node) -> T: ...
    def generic_visit(self, node: ast.Node) -> T: ...
    def visit_MainProgram(self, node: ast.MainProgram) -> T: ...
    def visit_OPENQASM(self, node: ast.OPENQASM) -> T: ...
    def visit_Kirin(self, node: ast.Kirin) -> T: ...
    def visit_QReg(self, node: ast.QReg) -> T: ...
    def visit_CReg(self, node: ast.CReg) -> T: ...
    def visit_Gate(self, node: ast.Gate) -> T: ...
    def visit_Opaque(self, node: ast.Opaque) -> T: ...
    def visit_IfStmt(self, node: ast.IfStmt) -> T: ...
    def visit_Cmp(self, node: ast.Cmp) -> T: ...
    def visit_Barrier(self, node: ast.Barrier) -> T: ...
    def visit_Include(self, node: ast.Include) -> T: ...
    # ast.Comment nodes are handled by the printer's visit_Comment, so the
    # stub declares it too
    def visit_Comment(self, node: ast.Comment) -> T: ...
    def visit_Measure(self, node: ast.Measure) -> T: ...
    def visit_Reset(self, node: ast.Reset) -> T: ...
    def visit_Instruction(self, node: ast.Instruction) -> T: ...
    def visit_UGate(self, node: ast.UGate) -> T: ...
    def visit_CXGate(self, node: ast.CXGate) -> T: ...
    def visit_Bit(self, node: ast.Bit) -> T: ...
    def visit_BinOp(self, node: ast.BinOp) -> T: ...
    def visit_UnaryOp(self, node: ast.UnaryOp) -> T: ...
    def visit_Call(self, node: ast.Call) -> T: ...
    def visit_Number(self, node: ast.Number) -> T: ...
    def visit_Pi(self, node: ast.Pi) -> T: ...
    def visit_Name(self, node: ast.Name) -> T: ...
    # extensions
    def visit_ParaU3Gate(self, node: ast.ParaU3Gate) -> T: ...
    def visit_ParaCZGate(self, node: ast.ParaCZGate) -> T: ...
    def visit_ParaRZGate(self, node: ast.ParaRZGate) -> T: ...
    def visit_ParallelQArgs(self, node: ast.ParallelQArgs) -> T: ...
    def visit_GlobUGate(self, node: ast.GlobUGate) -> T: ...
    def visit_NoisePAULI1(self, node: ast.NoisePAULI1) -> T: ...
@@ -0,0 +1,5 @@
1
+ from .fold import QASM2Fold as QASM2Fold
2
+ from .noise import NoisePass as NoisePass
3
+ from .py2qasm import Py2QASM as Py2QASM
4
+ from .qasm2py import QASM2Py as QASM2Py
5
+ from .parallel import UOpToParallel as UOpToParallel
@@ -0,0 +1,94 @@
1
+ from dataclasses import field, dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass, TypeInfer
5
+ from kirin.rewrite import (
6
+ Walk,
7
+ Chain,
8
+ Inline,
9
+ Fixpoint,
10
+ WrapConst,
11
+ Call2Invoke,
12
+ ConstantFold,
13
+ CFGCompactify,
14
+ InlineGetItem,
15
+ InlineGetField,
16
+ DeadCodeElimination,
17
+ CommonSubexpressionElimination,
18
+ )
19
+ from kirin.analysis import const
20
+ from kirin.dialects import scf, ilist
21
+ from kirin.ir.method import Method
22
+ from kirin.rewrite.abc import RewriteResult
23
+
24
+ from bloqade.qasm2.dialects import expr
25
+
26
+
27
@dataclass
class QASM2Fold(Pass):
    """Fold pass for qasm2.extended

    ``unsafe_run`` applies these stages in order:

    1. constant propagation, with the results attached via ``WrapConst``;
    2. a fixpoint of constant folding / call rewriting / getitem-getfield
       inlining / DCE / CSE;
    3. ``scf`` if-else picking, for-loop unrolling and unused-yield trimming;
    4. type inference, then ``ilist`` constant-list conversion and unrolling;
    5. call inlining (see ``inline_simple``);
    6. CFG compaction to a fixpoint.

    Attributes:
        constprop: constant-propagation analysis, created in ``__post_init__``.
        inline_gate_subroutine: whether callees that are ``expr.GateFunction``
            get inlined.
    """

    constprop: const.Propagate = field(init=False)
    inline_gate_subroutine: bool = True

    def __post_init__(self):
        # both helper analyses/passes share this pass's dialect group
        self.constprop = const.Propagate(self.dialects)
        self.typeinfer = TypeInfer(self.dialects)

    def unsafe_run(self, mt: Method) -> RewriteResult:
        """Fold ``mt`` in place and return the joined rewrite results."""
        outcome = RewriteResult()

        # stage 1: constant propagation -> constant hints on the IR
        frame, _ = self.constprop.run_analysis(mt)
        outcome = Walk(WrapConst(frame)).rewrite(mt.code).join(outcome)

        # stage 2: simplify until nothing changes
        simplify = Chain(
            ConstantFold(),
            Call2Invoke(),
            InlineGetField(),
            InlineGetItem(),
            DeadCodeElimination(),
            CommonSubexpressionElimination(),
        )
        outcome = Fixpoint(Walk(simplify)).rewrite(mt.code).join(outcome)

        # stage 3: structured control flow unrolling
        unroll_scf = Walk(
            Chain(
                scf.unroll.PickIfElse(),
                scf.unroll.ForLoop(),
                scf.trim.UnusedYield(),
            )
        )
        outcome = unroll_scf.rewrite(mt.code).join(outcome)

        # stage 4: unrolling inserts many new nodes, so re-run type inference to
        # give them more precise types. Its return value is not joined into
        # `outcome`.
        self.typeinfer.unsafe_run(mt)
        unroll_ilist = Walk(
            Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll())
        )
        outcome = unroll_ilist.rewrite(mt.code).join(outcome)

        # stage 5: inlining heuristic
        def inline_simple(node: ir.Statement):
            # gate subroutines: controlled solely by the flag
            if isinstance(node, expr.GateFunction):
                return self.inline_gate_subroutine

            # outside scf.For / scf.IfElse: always inline
            # NOTE(review): confirm which statement kirin's Inline hands to the
            # heuristic; if it is the callee's code, parent_stmt may be None.
            # In that case this branch always returns True.
            if not isinstance(node.parent_stmt, (scf.For, scf.IfElse)):
                return True

            # inside loops / if-else: only single-block callables
            trait = node.get_trait(ir.CallableStmtInterface)
            if trait is None:
                return False  # not a callable, don't inline to be safe
            return len(trait.get_callable_region(node).blocks) == 1

        outcome = Walk(Inline(inline_simple)).rewrite(mt.code).join(outcome)

        # stage 6: tidy up the control-flow graph
        return Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(outcome)
@@ -0,0 +1,119 @@
1
+ """
2
+ Passes that deal with global gates. As of now, only one rewrite pass exists
3
+ which converts global gates to single qubit gates.
4
+ """
5
+
6
+ from kirin import ir
7
+ from kirin.rewrite import cse, dce, walk, result
8
+ from kirin.passes.abc import Pass
9
+ from kirin.passes.fold import Fold
10
+ from kirin.rewrite.fixpoint import Fixpoint
11
+
12
+ from bloqade.analysis import address
13
+ from bloqade.qasm2.rewrite import GlobalToUOpRule, GlobalToParallelRule
14
+
15
+
16
class GlobalToUOP(Pass):
    """Pass to convert Global gates into single gates.

    This pass rewrites the global unitary gate from the `qasm2.glob` dialect into multiple
    single gates in the `qasm2.uop` dialect, bringing the program closer to
    conforming to standard QASM2 syntax.


    ## Usage Examples
    ```
    # Define kernel
    @qasm2.extended
    def main():
        q1 = qasm2.qreg(1)
        q2 = qasm2.qreg(2)

        theta = 1.3
        phi = 1.1
        lam = 1.2

        qasm2.glob.u(theta=theta, phi=phi, lam=lam, registers=[q1, q2])

    # Run rewrite
    GlobalToUOP(main.dialects)(main)
    ```

    The `qasm2.glob.u` statement has been rewritten to individual gates:

    ```
    qasm2.uop.u(q1[0], theta, phi, lam)
    qasm2.uop.u(q2[0], theta, phi, lam)
    qasm2.uop.u(q2[1], theta, phi, lam)
    ```
    """

    def generate_rule(self, mt: ir.Method) -> GlobalToUOpRule:
        """Build the rewrite rule from an address analysis of ``mt``."""
        frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
        return GlobalToUOpRule(frame.entries)

    def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
        """Rewrite global gates in ``mt`` in place.

        Returns:
            The join of every stage's ``RewriteResult``: the glob rewrite, DCE,
            the CSE fixpoint and the final Fold. A change made by any stage is
            reported. Previously each stage overwrote the previous result, so
            only the final Fold's outcome was returned.
        """
        # named `outcome` so it does not shadow the imported `result` module
        outcome = walk.Walk(self.generate_rule(mt)).rewrite(mt.code)

        outcome = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code).join(outcome)
        outcome = (
            Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination()))
            .rewrite(mt.code)
            .join(outcome)
        )

        # do fold again to get proper hint for inserted const
        return Fold(mt.dialects)(mt).join(outcome)
69
+
70
+
71
class GlobalToParallel(Pass):
    """Pass to convert Global gates into parallel gates.

    This pass rewrites the global unitary gate from the `qasm2.glob` dialect into
    parallel gates in the `qasm2.parallel` dialect.


    ## Usage Examples
    ```
    # Define kernel
    @qasm2.extended
    def main():
        q1 = qasm2.qreg(1)
        q2 = qasm2.qreg(2)

        theta = 1.3
        phi = 1.1
        lam = 1.2

        qasm2.glob.u(theta=theta, phi=phi, lam=lam, registers=[q1, q2])

    # Run rewrite
    GlobalToParallel(main.dialects)(main)
    ```

    The `qasm2.glob.u` statement has been rewritten to a single parallel gate:

    ```
    qasm2.parallel.u(theta=theta, phi=phi, lam=lam, qargs=[q1[0], q2[0], q2[1]])
    ```
    """

    def generate_rule(self, mt: ir.Method) -> GlobalToParallelRule:
        """Build the rewrite rule from an address analysis of ``mt``."""
        frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
        return GlobalToParallelRule(frame.entries)

    def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
        """Rewrite global gates in ``mt`` in place.

        Returns:
            The join of every stage's ``RewriteResult``: the glob rewrite, DCE,
            the CSE fixpoint and the final Fold. A change made by any stage is
            reported. Previously each stage overwrote the previous result, so
            only the final Fold's outcome was returned.
        """
        # named `outcome` so it does not shadow the imported `result` module
        outcome = walk.Walk(self.generate_rule(mt)).rewrite(mt.code)

        outcome = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code).join(outcome)
        outcome = (
            Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination()))
            .rewrite(mt.code)
            .join(outcome)
        )
        # do fold again to get proper hint
        return Fold(mt.dialects)(mt).join(outcome)
+ return result