bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/__init__.py +4 -1
  7. bloqade/analysis/measure_id/analysis.py +29 -20
  8. bloqade/analysis/measure_id/impls.py +72 -31
  9. bloqade/annotate/__init__.py +6 -0
  10. bloqade/annotate/_dialect.py +3 -0
  11. bloqade/annotate/_interface.py +22 -0
  12. bloqade/annotate/stmts.py +29 -0
  13. bloqade/annotate/types.py +13 -0
  14. bloqade/cirq_utils/__init__.py +4 -2
  15. bloqade/cirq_utils/emit/__init__.py +3 -0
  16. bloqade/cirq_utils/emit/base.py +246 -0
  17. bloqade/cirq_utils/emit/gate.py +104 -0
  18. bloqade/cirq_utils/emit/noise.py +90 -0
  19. bloqade/cirq_utils/emit/qubit.py +35 -0
  20. bloqade/cirq_utils/lowering.py +660 -0
  21. bloqade/cirq_utils/noise/__init__.py +0 -2
  22. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  23. bloqade/cirq_utils/noise/model.py +151 -191
  24. bloqade/cirq_utils/noise/transform.py +2 -2
  25. bloqade/cirq_utils/parallelize.py +9 -6
  26. bloqade/gemini/__init__.py +1 -0
  27. bloqade/gemini/analysis/__init__.py +3 -0
  28. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  29. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  30. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  31. bloqade/gemini/groups.py +67 -0
  32. bloqade/native/__init__.py +23 -0
  33. bloqade/native/_prelude.py +45 -0
  34. bloqade/native/dialects/__init__.py +0 -0
  35. bloqade/native/dialects/gate/__init__.py +2 -0
  36. bloqade/native/dialects/gate/_dialect.py +3 -0
  37. bloqade/native/dialects/gate/_interface.py +32 -0
  38. bloqade/native/dialects/gate/stmts.py +31 -0
  39. bloqade/native/stdlib/__init__.py +0 -0
  40. bloqade/native/stdlib/broadcast.py +246 -0
  41. bloqade/native/stdlib/simple.py +220 -0
  42. bloqade/native/upstream/__init__.py +4 -0
  43. bloqade/native/upstream/squin2native.py +79 -0
  44. bloqade/pyqrack/__init__.py +2 -2
  45. bloqade/pyqrack/base.py +7 -1
  46. bloqade/pyqrack/device.py +190 -4
  47. bloqade/pyqrack/native.py +49 -0
  48. bloqade/pyqrack/reg.py +6 -6
  49. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  50. bloqade/pyqrack/squin/gate/gate.py +136 -0
  51. bloqade/pyqrack/squin/noise/native.py +120 -54
  52. bloqade/pyqrack/squin/qubit.py +39 -36
  53. bloqade/pyqrack/target.py +5 -4
  54. bloqade/pyqrack/task.py +114 -7
  55. bloqade/qasm2/_qasm_loading.py +3 -3
  56. bloqade/qasm2/dialects/core/address.py +21 -12
  57. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  58. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  59. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  60. bloqade/qasm2/dialects/noise/model.py +2 -1
  61. bloqade/qasm2/emit/base.py +16 -11
  62. bloqade/qasm2/emit/gate.py +11 -8
  63. bloqade/qasm2/emit/main.py +103 -3
  64. bloqade/qasm2/emit/target.py +9 -5
  65. bloqade/qasm2/groups.py +3 -2
  66. bloqade/qasm2/parse/lowering.py +0 -1
  67. bloqade/qasm2/passes/fold.py +14 -73
  68. bloqade/qasm2/passes/glob.py +2 -2
  69. bloqade/qasm2/passes/noise.py +1 -1
  70. bloqade/qasm2/passes/parallel.py +7 -5
  71. bloqade/qasm2/rewrite/__init__.py +0 -1
  72. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  73. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  74. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  75. bloqade/qasm2/rewrite/register.py +2 -2
  76. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  77. bloqade/qbraid/lowering.py +1 -0
  78. bloqade/qbraid/schema.py +2 -2
  79. bloqade/qubit/__init__.py +12 -0
  80. bloqade/qubit/_dialect.py +3 -0
  81. bloqade/qubit/_interface.py +49 -0
  82. bloqade/qubit/_prelude.py +45 -0
  83. bloqade/qubit/analysis/__init__.py +1 -0
  84. bloqade/qubit/analysis/address_impl.py +40 -0
  85. bloqade/qubit/stdlib/__init__.py +2 -0
  86. bloqade/qubit/stdlib/_new.py +34 -0
  87. bloqade/qubit/stdlib/broadcast.py +62 -0
  88. bloqade/qubit/stdlib/simple.py +59 -0
  89. bloqade/qubit/stmts.py +60 -0
  90. bloqade/rewrite/passes/__init__.py +6 -0
  91. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  92. bloqade/rewrite/passes/callgraph.py +116 -0
  93. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  94. bloqade/rewrite/rules/split_ifs.py +18 -1
  95. bloqade/squin/__init__.py +47 -14
  96. bloqade/squin/analysis/__init__.py +0 -1
  97. bloqade/squin/analysis/schedule.py +10 -11
  98. bloqade/squin/gate/__init__.py +2 -0
  99. bloqade/squin/gate/_dialect.py +3 -0
  100. bloqade/squin/gate/_interface.py +98 -0
  101. bloqade/squin/gate/stmts.py +125 -0
  102. bloqade/squin/groups.py +5 -22
  103. bloqade/squin/noise/__init__.py +1 -10
  104. bloqade/squin/noise/_dialect.py +1 -1
  105. bloqade/squin/noise/_interface.py +45 -0
  106. bloqade/squin/noise/stmts.py +66 -28
  107. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  108. bloqade/squin/rewrite/__init__.py +0 -2
  109. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  110. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  111. bloqade/squin/stdlib/__init__.py +0 -0
  112. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  113. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  114. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  115. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  116. bloqade/squin/stdlib/simple/__init__.py +33 -0
  117. bloqade/squin/stdlib/simple/gate.py +242 -0
  118. bloqade/squin/stdlib/simple/noise.py +126 -0
  119. bloqade/stim/__init__.py +1 -0
  120. bloqade/stim/_wrappers.py +6 -0
  121. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  122. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  123. bloqade/stim/dialects/gate/emit.py +9 -10
  124. bloqade/stim/dialects/noise/emit.py +17 -13
  125. bloqade/stim/dialects/noise/stmts.py +5 -3
  126. bloqade/stim/emit/__init__.py +1 -0
  127. bloqade/stim/emit/impls.py +16 -0
  128. bloqade/stim/emit/stim_str.py +48 -31
  129. bloqade/stim/groups.py +12 -2
  130. bloqade/stim/parse/lowering.py +14 -17
  131. bloqade/stim/passes/__init__.py +3 -1
  132. bloqade/stim/passes/flatten.py +26 -0
  133. bloqade/stim/passes/simplify_ifs.py +16 -2
  134. bloqade/stim/passes/squin_to_stim.py +18 -60
  135. bloqade/stim/rewrite/__init__.py +3 -4
  136. bloqade/stim/rewrite/get_record_util.py +24 -0
  137. bloqade/stim/rewrite/ifs_to_stim.py +29 -31
  138. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  139. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  140. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  141. bloqade/stim/rewrite/squin_measure.py +11 -79
  142. bloqade/stim/rewrite/squin_noise.py +134 -108
  143. bloqade/stim/rewrite/util.py +5 -192
  144. bloqade/test_utils.py +1 -1
  145. bloqade/types.py +10 -0
  146. bloqade/validation/__init__.py +2 -0
  147. bloqade/validation/analysis/__init__.py +5 -0
  148. bloqade/validation/analysis/analysis.py +41 -0
  149. bloqade/validation/analysis/lattice.py +58 -0
  150. bloqade/validation/kernel_validation.py +77 -0
  151. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  152. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  153. bloqade/pyqrack/squin/op.py +0 -166
  154. bloqade/pyqrack/squin/runtime.py +0 -535
  155. bloqade/pyqrack/squin/wire.py +0 -51
  156. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  157. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  158. bloqade/squin/_typeinfer.py +0 -20
  159. bloqade/squin/analysis/address_impl.py +0 -71
  160. bloqade/squin/analysis/nsites/__init__.py +0 -9
  161. bloqade/squin/analysis/nsites/analysis.py +0 -50
  162. bloqade/squin/analysis/nsites/impls.py +0 -92
  163. bloqade/squin/analysis/nsites/lattice.py +0 -49
  164. bloqade/squin/cirq/__init__.py +0 -265
  165. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  166. bloqade/squin/cirq/emit/noise.py +0 -49
  167. bloqade/squin/cirq/emit/op.py +0 -125
  168. bloqade/squin/cirq/emit/qubit.py +0 -60
  169. bloqade/squin/cirq/emit/runtime.py +0 -242
  170. bloqade/squin/cirq/lowering.py +0 -440
  171. bloqade/squin/lowering.py +0 -54
  172. bloqade/squin/noise/_wrapper.py +0 -40
  173. bloqade/squin/noise/rewrite.py +0 -111
  174. bloqade/squin/op/__init__.py +0 -41
  175. bloqade/squin/op/_dialect.py +0 -3
  176. bloqade/squin/op/_wrapper.py +0 -121
  177. bloqade/squin/op/number.py +0 -5
  178. bloqade/squin/op/rewrite.py +0 -46
  179. bloqade/squin/op/stdlib.py +0 -62
  180. bloqade/squin/op/stmts.py +0 -276
  181. bloqade/squin/op/traits.py +0 -43
  182. bloqade/squin/op/types.py +0 -26
  183. bloqade/squin/qubit.py +0 -184
  184. bloqade/squin/rewrite/canonicalize.py +0 -60
  185. bloqade/squin/rewrite/desugar.py +0 -124
  186. bloqade/squin/types.py +0 -8
  187. bloqade/squin/wire.py +0 -201
  188. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  189. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  190. bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
  191. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  192. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,5 @@
1
1
  import abc
2
+ from typing import Sequence
2
3
  from dataclasses import field, dataclass
3
4
 
4
5
 
@@ -161,7 +162,7 @@ class MoveNoiseModelABC(abc.ABC):
161
162
 
162
163
  @abc.abstractmethod
163
164
  def parallel_cz_errors(
164
- self, ctrls: list[int], qargs: list[int], rest: list[int]
165
+ self, ctrls: Sequence[int], qargs: Sequence[int], rest: Sequence[int]
165
166
  ) -> dict[tuple[float, float, float, float], list[int]]:
166
167
  """Takes a set of ctrls and qargs and returns a noise model for all qubits."""
167
168
  pass
@@ -2,8 +2,9 @@ from abc import ABC
2
2
  from typing import Generic, TypeVar, overload
3
3
  from dataclasses import field, dataclass
4
4
 
5
- from kirin import ir, idtable
6
- from kirin.emit import EmitABC, EmitError, EmitFrame
5
+ from kirin import ir, interp, idtable
6
+ from kirin.emit import EmitABC, EmitFrame
7
+ from kirin.worklist import WorkList
7
8
  from typing_extensions import Self
8
9
 
9
10
  from bloqade.qasm2.parse import ast
@@ -15,6 +16,9 @@ EmitNode = TypeVar("EmitNode", bound=ast.Node)
15
16
  @dataclass
16
17
  class EmitQASM2Frame(EmitFrame[ast.Node | None], Generic[StmtType]):
17
18
  body: list[StmtType] = field(default_factory=list)
19
+ worklist: WorkList[interp.Successor] = field(default_factory=WorkList)
20
+ block_ref: dict[ir.Block, ast.Node | None] = field(default_factory=dict)
21
+ _indent: int = 0
18
22
 
19
23
 
20
24
  @dataclass
@@ -37,18 +41,13 @@ class EmitQASM2Base(
37
41
  return self
38
42
 
39
43
  def initialize_frame(
40
- self, code: ir.Statement, *, has_parent_access: bool = False
44
+ self, node: ir.Statement, *, has_parent_access: bool = False
41
45
  ) -> EmitQASM2Frame[StmtType]:
42
- return EmitQASM2Frame(code, has_parent_access=has_parent_access)
43
-
44
- def run_method(
45
- self, method: ir.Method, args: tuple[ast.Node | None, ...]
46
- ) -> tuple[EmitQASM2Frame[StmtType], ast.Node | None]:
47
- return self.run_callable(method.code, (ast.Name(method.sym_name),) + args)
46
+ return EmitQASM2Frame(node, has_parent_access=has_parent_access)
48
47
 
49
48
  def emit_block(self, frame: EmitQASM2Frame, block: ir.Block) -> ast.Node | None:
50
49
  for stmt in block.stmts:
51
- result = self.eval_stmt(frame, stmt)
50
+ result = self.frame_eval(frame, stmt)
52
51
  if isinstance(result, tuple):
53
52
  frame.set_values(stmt.results, result)
54
53
  return None
@@ -70,5 +69,11 @@ class EmitQASM2Base(
70
69
  node: ast.Node | None,
71
70
  ) -> A | B:
72
71
  if not isinstance(node, typ):
73
- raise EmitError(f"expected {typ}, got {type(node)}")
72
+ raise TypeError(f"expected {typ}, got {type(node)}")
74
73
  return node
74
+
75
+ def reset(self):
76
+ pass
77
+
78
+ def eval_fallback(self, frame: EmitQASM2Frame, node: ir.Statement):
79
+ return tuple(None for _ in range(len(node.results)))
@@ -3,11 +3,12 @@ from dataclasses import field, dataclass
3
3
  from kirin import ir, types, interp
4
4
  from kirin.dialects import py, func, ilist
5
5
  from kirin.ir.dialect import Dialect as Dialect
6
+ from typing_extensions import Self
6
7
 
7
8
  from bloqade.types import QubitType
8
9
  from bloqade.qasm2.parse import ast
9
10
 
10
- from .base import EmitError, EmitQASM2Base, EmitQASM2Frame
11
+ from .base import EmitQASM2Base, EmitQASM2Frame
11
12
 
12
13
 
13
14
  def _default_dialect_group():
@@ -18,9 +19,13 @@ def _default_dialect_group():
18
19
 
19
20
  @dataclass
20
21
  class EmitQASM2Gate(EmitQASM2Base[ast.UOp | ast.Barrier, ast.Gate]):
21
- keys = ["emit.qasm2.gate"]
22
+ keys = ("emit.qasm2.gate",)
22
23
  dialects: ir.DialectGroup = field(default_factory=_default_dialect_group)
23
24
 
25
+ def initialize(self) -> Self:
26
+ super().initialize()
27
+ return self
28
+
24
29
 
25
30
  @ilist.dialect.register(key="emit.qasm2.gate")
26
31
  class Ilist(interp.MethodTable):
@@ -45,7 +50,7 @@ class Func(interp.MethodTable):
45
50
 
46
51
  @interp.impl(func.Call)
47
52
  def emit_call(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: func.Call):
48
- raise EmitError("cannot emit dynamic call")
53
+ raise RuntimeError("cannot emit dynamic call")
49
54
 
50
55
  @interp.impl(func.Invoke)
51
56
  def emit_invoke(
@@ -55,7 +60,7 @@ class Func(interp.MethodTable):
55
60
  if len(stmt.results) == 1 and stmt.results[0].type.is_subseteq(types.NoneType):
56
61
  ret = (None,)
57
62
  elif len(stmt.results) > 0:
58
- raise EmitError(
63
+ raise RuntimeError(
59
64
  "cannot emit invoke with results, this "
60
65
  "is not compatible QASM2 gate routine"
61
66
  " (consider pass qreg/creg by argument)"
@@ -67,10 +72,9 @@ class Func(interp.MethodTable):
67
72
  qparams.append(frame.get(arg))
68
73
  else:
69
74
  cparams.append(frame.get(arg))
70
-
71
75
  frame.body.append(
72
76
  ast.Instruction(
73
- name=ast.Name(stmt.callee.sym_name),
77
+ name=ast.Name(stmt.callee.__getattribute__("sym_name")),
74
78
  params=cparams,
75
79
  qargs=qparams,
76
80
  )
@@ -80,9 +84,8 @@ class Func(interp.MethodTable):
80
84
  @interp.impl(func.Lambda)
81
85
  @interp.impl(func.GetField)
82
86
  def emit_err(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
83
- raise EmitError(f"illegal statement {stmt.name} for QASM2 gate routine")
87
+ raise RuntimeError(f"illegal statement {stmt.name} for QASM2 gate routine")
84
88
 
85
89
  @interp.impl(func.Return)
86
- @interp.impl(func.ConstantNone)
87
90
  def ignore(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
88
91
  return ()
@@ -1,8 +1,10 @@
1
+ from typing import List, cast
1
2
  from dataclasses import dataclass
2
3
 
3
4
  from kirin import ir, interp
4
5
  from kirin.dialects import cf, scf, func
5
6
  from kirin.ir.dialect import Dialect as Dialect
7
+ from typing_extensions import Self
6
8
 
7
9
  from bloqade.qasm2.parse import ast
8
10
  from bloqade.qasm2.dialects.uop import SingleQubitGate, TwoQubitCtrlGate
@@ -14,26 +16,124 @@ from ..dialects.core.stmts import Reset, Measure
14
16
 
15
17
  @dataclass
16
18
  class EmitQASM2Main(EmitQASM2Base[ast.Statement, ast.MainProgram]):
17
- keys = ["emit.qasm2.main", "emit.qasm2.gate"]
19
+ keys = ("emit.qasm2.main", "emit.qasm2.gate")
18
20
  dialects: ir.DialectGroup
19
21
 
22
+ def initialize(self) -> Self:
23
+ super().initialize()
24
+ return self
25
+
26
+ def eval_fallback(self, frame: EmitQASM2Frame, node: ir.Statement):
27
+ return tuple(None for _ in range(len(node.results)))
28
+
20
29
 
21
30
  @func.dialect.register(key="emit.qasm2.main")
22
31
  class Func(interp.MethodTable):
32
+ @interp.impl(func.Invoke)
33
+ def invoke(self, emit: EmitQASM2Main, frame: EmitQASM2Frame, node: func.Invoke):
34
+ name = emit.callables.get(node.callee.code)
35
+ if name is None:
36
+ name = emit.callables.add(node.callee.code)
37
+ emit.callable_to_emit.append(node.callee.code)
38
+
39
+ if isinstance(node.callee.code, GateFunction):
40
+ c_params: list[ast.Expr] = []
41
+ q_args: list[ast.Bit | ast.Name] = []
42
+
43
+ for arg in node.args:
44
+ val = frame.get(arg)
45
+ if val is None:
46
+ raise interp.InterpreterError(f"missing mapping for arg {arg}")
47
+ if isinstance(val, (ast.Bit, ast.Name)):
48
+ q_args.append(val)
49
+ elif isinstance(val, ast.Expr):
50
+ c_params.append(val)
51
+
52
+ instr = ast.Instruction(
53
+ name=ast.Name(name) if isinstance(name, str) else name,
54
+ params=c_params,
55
+ qargs=q_args,
56
+ )
57
+ frame.body.append(instr)
58
+ return ()
59
+
60
+ callee_name_node = ast.Name(name) if isinstance(name, str) else name
61
+ args = tuple(frame.get_values(node.args))
62
+ _, call_expr = emit.call(node.callee.code, callee_name_node, *args)
63
+ if call_expr is not None:
64
+ frame.body.append(call_expr)
65
+ return ()
23
66
 
24
67
  @interp.impl(func.Function)
25
68
  def emit_func(
26
69
  self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: func.Function
27
70
  ):
28
71
  from bloqade.qasm2.dialects import glob, parallel
72
+ from bloqade.qasm2.emit.gate import EmitQASM2Gate
73
+
74
+ if isinstance(stmt, GateFunction):
75
+ return ()
76
+
77
+ func_name = emit.callables.get(stmt)
78
+ if func_name is None:
79
+ func_name = emit.callables.add(stmt)
80
+
81
+ for block in stmt.body.blocks:
82
+ frame.current_block = block
83
+ for s in block.stmts:
84
+ frame.current_stmt = s
85
+ stmt_results = emit.frame_eval(frame, s)
86
+ if isinstance(stmt_results, tuple):
87
+ if len(stmt_results) != 0:
88
+ frame.set_values(s._results, stmt_results)
89
+ continue
90
+
91
+ gate_defs: list[ast.Gate] = []
92
+
93
+ gate_emitter = EmitQASM2Gate(dialects=emit.dialects).initialize()
94
+ gate_emitter.callables = emit.callables
95
+
96
+ while emit.callable_to_emit:
97
+ callable_node = emit.callable_to_emit.pop()
98
+ if callable_node is None:
99
+ break
100
+
101
+ if isinstance(callable_node, GateFunction):
102
+ with gate_emitter.eval_context():
103
+ with gate_emitter.new_frame(
104
+ callable_node, has_parent_access=False
105
+ ) as gate_frame:
106
+ gate_result = gate_emitter.frame_eval(gate_frame, callable_node)
107
+ gate_obj = None
108
+ if isinstance(gate_result, tuple) and len(gate_result) > 0:
109
+ maybe = gate_result[0]
110
+ if isinstance(maybe, ast.Gate):
111
+ gate_obj = maybe
112
+
113
+ if gate_obj is None:
114
+ name = emit.callables.get(
115
+ callable_node
116
+ ) or emit.callables.add(callable_node)
117
+ prefix = getattr(emit.callables, "prefix", "") or ""
118
+ emit_name = (
119
+ name[len(prefix) :]
120
+ if prefix and name.startswith(prefix)
121
+ else name
122
+ )
123
+ gate_obj = ast.Gate(
124
+ name=emit_name, cparams=[], qparams=[], body=[]
125
+ )
126
+
127
+ gate_defs.append(gate_obj)
29
128
 
30
- emit.run_ssacfg_region(frame, stmt.body, ())
31
129
  if emit.dialects.data.intersection((parallel.dialect, glob.dialect)):
32
130
  header = ast.Kirin([dialect.name for dialect in emit.dialects])
33
131
  else:
34
132
  header = ast.OPENQASM(ast.Version(2, 0))
35
133
 
36
- emit.output = ast.MainProgram(header=header, statements=frame.body)
134
+ full_body = gate_defs + frame.body
135
+ stmt_list = cast(List[ast.Statement], full_body)
136
+ emit.output = ast.MainProgram(header=header, statements=stmt_list)
37
137
  return ()
38
138
 
39
139
 
@@ -115,8 +115,8 @@ class QASM2:
115
115
  ParallelToUOp(dialects=entry.dialects)(entry)
116
116
 
117
117
  Py2QASM(entry.dialects)(entry)
118
- target_main = EmitQASM2Main(self.main_target)
119
- target_main.run(entry, ())
118
+ target_main = EmitQASM2Main(self.main_target).initialize()
119
+ target_main.run(entry)
120
120
 
121
121
  main_program = target_main.output
122
122
  assert main_program is not None, f"failed to emit {entry.sym_name}"
@@ -127,9 +127,13 @@ class QASM2:
127
127
 
128
128
  if self.custom_gate:
129
129
  cg = CallGraph(entry)
130
- target_gate = EmitQASM2Gate(self.gate_target)
130
+ target_gate = EmitQASM2Gate(self.gate_target).initialize()
131
131
 
132
- for _, fn in cg.defs.items():
132
+ for _, fns in cg.defs.items():
133
+ if len(fns) != 1:
134
+ raise ValueError("Incorrect callgraph")
135
+
136
+ (fn,) = fns
133
137
  if fn is entry:
134
138
  continue
135
139
 
@@ -146,7 +150,7 @@ class QASM2:
146
150
 
147
151
  Py2QASM(fn.dialects)(fn)
148
152
 
149
- target_gate.run(fn, tuple(ast.Name(name) for name in fn.arg_names[1:]))
153
+ target_gate.run(fn)
150
154
  assert target_gate.output is not None, f"failed to emit {fn.sym_name}"
151
155
  extra.append(target_gate.output)
152
156
 
bloqade/qasm2/groups.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from kirin import ir, passes
2
2
  from kirin.prelude import structural_no_opt
3
- from kirin.dialects import scf, func, ilist, lowering
3
+ from kirin.dialects import scf, func, ilist, ssacfg, lowering
4
4
 
5
5
  from bloqade.qasm2.dialects import (
6
6
  uop,
@@ -15,7 +15,7 @@ from bloqade.qasm2.dialects import (
15
15
  from bloqade.qasm2.rewrite.desugar import IndexingDesugarPass
16
16
 
17
17
 
18
- @ir.dialect_group([uop, func, expr, lowering.func, lowering.call])
18
+ @ir.dialect_group([uop, func, expr, lowering.func, lowering.call, ssacfg])
19
19
  def gate(self):
20
20
  fold_pass = passes.Fold(self)
21
21
  typeinfer_pass = passes.TypeInfer(self)
@@ -58,6 +58,7 @@ def gate(self):
58
58
  func,
59
59
  lowering.func,
60
60
  lowering.call,
61
+ ssacfg,
61
62
  ]
62
63
  )
63
64
  def main(self):
@@ -450,7 +450,6 @@ class QASM2(lowering.LoweringABC[ast.Node]):
450
450
  func.Invoke(
451
451
  callee=value,
452
452
  inputs=tuple(params + qargs),
453
- kwargs=tuple(),
454
453
  )
455
454
  )
456
455
 
@@ -1,27 +1,12 @@
1
1
  from dataclasses import field, dataclass
2
2
 
3
3
  from kirin import ir
4
- from kirin.passes import Pass, TypeInfer
5
- from kirin.rewrite import (
6
- Walk,
7
- Chain,
8
- Inline,
9
- Fixpoint,
10
- WrapConst,
11
- Call2Invoke,
12
- ConstantFold,
13
- CFGCompactify,
14
- InlineGetItem,
15
- InlineGetField,
16
- DeadCodeElimination,
17
- CommonSubexpressionElimination,
18
- )
19
- from kirin.analysis import const
20
- from kirin.dialects import scf, ilist
4
+ from kirin.passes import Pass
21
5
  from kirin.ir.method import Method
22
6
  from kirin.rewrite.abc import RewriteResult
23
7
 
24
8
  from bloqade.qasm2.dialects import expr
9
+ from bloqade.rewrite.passes import AggressiveUnroll
25
10
 
26
11
  from .unroll_if import UnrollIfs
27
12
 
@@ -30,71 +15,27 @@ from .unroll_if import UnrollIfs
30
15
  class QASM2Fold(Pass):
31
16
  """Fold pass for qasm2.extended"""
32
17
 
33
- constprop: const.Propagate = field(init=False)
34
18
  inline_gate_subroutine: bool = True
35
19
  unroll_ifs: bool = True
20
+ aggressive_unroll: AggressiveUnroll = field(init=False)
36
21
 
37
22
  def __post_init__(self):
38
- self.constprop = const.Propagate(self.dialects)
39
- self.typeinfer = TypeInfer(self.dialects)
23
+ def inline_simple(node: ir.Statement):
24
+ if isinstance(node, expr.GateFunction):
25
+ return self.inline_gate_subroutine
40
26
 
41
- def unsafe_run(self, mt: Method) -> RewriteResult:
42
- result = RewriteResult()
43
- frame, _ = self.constprop.run_analysis(mt)
44
- result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
45
- rule = Chain(
46
- ConstantFold(),
47
- Call2Invoke(),
48
- InlineGetField(),
49
- InlineGetItem(),
50
- DeadCodeElimination(),
51
- CommonSubexpressionElimination(),
52
- )
53
- result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
27
+ return True
54
28
 
55
- result = (
56
- Walk(
57
- Chain(
58
- scf.unroll.PickIfElse(),
59
- scf.unroll.ForLoop(),
60
- scf.trim.UnusedYield(),
61
- )
62
- )
63
- .rewrite(mt.code)
64
- .join(result)
29
+ self.aggressive_unroll = AggressiveUnroll(
30
+ self.dialects, inline_simple, no_raise=self.no_raise
65
31
  )
66
32
 
67
- if self.unroll_ifs:
68
- UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
69
-
70
- # run typeinfer again after unroll etc. because we now insert
71
- # a lot of new nodes, which might have more precise types
72
- self.typeinfer.unsafe_run(mt)
73
- result = (
74
- Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
75
- .rewrite(mt.code)
76
- .join(result)
77
- )
78
-
79
- def inline_simple(node: ir.Statement):
80
- if isinstance(node, expr.GateFunction):
81
- return self.inline_gate_subroutine
33
+ def unsafe_run(self, mt: Method) -> RewriteResult:
34
+ result = RewriteResult()
82
35
 
83
- if not isinstance(node.parent_stmt, (scf.For, scf.IfElse)):
84
- return True # always inline calls outside of loops and if-else
36
+ if self.unroll_ifs:
37
+ result = UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
85
38
 
86
- # inside loops and if-else, only inline simple functions, i.e. functions with a single block
87
- if (trait := node.get_trait(ir.CallableStmtInterface)) is None:
88
- return False # not a callable, don't inline to be safe
89
- region = trait.get_callable_region(node)
90
- return len(region.blocks) == 1
39
+ result = self.aggressive_unroll.unsafe_run(mt).join(result)
91
40
 
92
- result = (
93
- Walk(
94
- Inline(inline_simple),
95
- )
96
- .rewrite(mt.code)
97
- .join(result)
98
- )
99
- result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
100
41
  return result
@@ -51,7 +51,7 @@ class GlobalToUOP(Pass):
51
51
  """
52
52
 
53
53
  def generate_rule(self, mt: ir.Method) -> GlobalToUOpRule:
54
- frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
54
+ frame, _ = address.AddressAnalysis(mt.dialects).run(mt)
55
55
  return GlobalToUOpRule(frame.entries)
56
56
 
57
57
  def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
@@ -105,7 +105,7 @@ class GlobalToParallel(Pass):
105
105
  """
106
106
 
107
107
  def generate_rule(self, mt: ir.Method) -> GlobalToParallelRule:
108
- frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
108
+ frame, _ = address.AddressAnalysis(mt.dialects).run(mt)
109
109
  return GlobalToParallelRule(frame.entries)
110
110
 
111
111
  def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
@@ -55,7 +55,7 @@ class NoisePass(Pass):
55
55
  self.address_analysis = address.AddressAnalysis(self.dialects)
56
56
 
57
57
  def get_qubit_values(self, mt: ir.Method):
58
- frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
58
+ frame, _ = self.address_analysis.run(mt)
59
59
  qubit_ssa_values = {}
60
60
  # Traverse statements in block order to fine the first SSA value for each qubit
61
61
  for block in mt.callable_region.blocks:
@@ -28,7 +28,6 @@ from bloqade.qasm2.rewrite import (
28
28
  UOpToParallelRule,
29
29
  ParallelToGlobalRule,
30
30
  SimpleOptimalMergePolicy,
31
- RydbergGateSetRewriteRule,
32
31
  )
33
32
  from bloqade.squin.analysis import schedule
34
33
 
@@ -64,7 +63,7 @@ class ParallelToUOp(Pass):
64
63
  """
65
64
 
66
65
  def generate_rule(self, mt: ir.Method) -> ParallelToUOpRule:
67
- frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
66
+ frame, _ = address.AddressAnalysis(mt.dialects).run(mt)
68
67
 
69
68
  id_map = {}
70
69
 
@@ -151,16 +150,19 @@ class UOpToParallel(Pass):
151
150
  return result
152
151
 
153
152
  if self.rewrite_to_native_first:
153
+ # NOTE: this import also imports cirq, so we do it locally here
154
+ from bloqade.qasm2.rewrite.native_gates import RydbergGateSetRewriteRule
155
+
154
156
  result = (
155
157
  Fixpoint(Walk(RydbergGateSetRewriteRule(self.dialects)))
156
158
  .rewrite(mt.code)
157
159
  .join(result)
158
160
  )
159
161
 
160
- frame, _ = self.constprop.run_analysis(mt)
162
+ frame, _ = self.constprop.run(mt)
161
163
  result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
162
164
 
163
- frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
165
+ frame, _ = address.AddressAnalysis(mt.dialects).run(mt)
164
166
  dags = schedule.DagScheduleAnalysis(
165
167
  mt.dialects, address_analysis=frame.entries
166
168
  ).get_dags(mt)
@@ -191,7 +193,7 @@ class ParallelToGlobal(Pass):
191
193
 
192
194
  def generate_rule(self, mt: ir.Method) -> ParallelToGlobalRule:
193
195
  address_analysis = address.AddressAnalysis(mt.dialects)
194
- frame, _ = address_analysis.run_analysis(mt)
196
+ frame, _ = address_analysis.run(mt)
195
197
  return ParallelToGlobalRule(frame.entries)
196
198
 
197
199
  def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
@@ -3,7 +3,6 @@ from .glob import (
3
3
  GlobalToParallelRule as GlobalToParallelRule,
4
4
  )
5
5
  from .register import RaiseRegisterRule as RaiseRegisterRule
6
- from .native_gates import RydbergGateSetRewriteRule as RydbergGateSetRewriteRule
7
6
  from .parallel_to_uop import ParallelToUOpRule as ParallelToUOpRule
8
7
  from .uop_to_parallel import (
9
8
  MergePolicyABC as MergePolicyABC,
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Tuple, cast
1
+ from typing import Dict, List, Tuple
2
2
  from dataclasses import field, dataclass
3
3
 
4
4
  from kirin import ir
@@ -55,7 +55,7 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
55
55
 
56
56
  def rewrite_global_single_qubit_gate(self, node: glob.UGate):
57
57
  addrs = self.address_analysis[node.registers]
58
- if not isinstance(addrs, address.AddressTuple):
58
+ if not isinstance(addrs, address.PartialIList):
59
59
  return rewrite_abc.RewriteResult()
60
60
 
61
61
  qargs = []
@@ -74,10 +74,7 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
74
74
 
75
75
  def rewrite_parallel_single_qubit_gate(self, node: parallel.RZ | parallel.UGate):
76
76
  addrs = self.address_analysis[node.qargs]
77
- if not isinstance(addrs, address.AddressTuple):
78
- return rewrite_abc.RewriteResult()
79
-
80
- if not all(isinstance(addr, address.AddressQubit) for addr in addrs.data):
77
+ if not isinstance(addrs, address.AddressReg):
81
78
  return rewrite_abc.RewriteResult()
82
79
 
83
80
  assert isinstance(node.qargs, ir.ResultValue)
@@ -178,18 +175,11 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
178
175
  qargs = self.address_analysis[node.qargs]
179
176
 
180
177
  has_done_something = False
181
- if (
182
- isinstance(ctrls, address.AddressTuple)
183
- and all(isinstance(addr, address.AddressQubit) for addr in ctrls.data)
184
- and isinstance(qargs, address.AddressTuple)
185
- and all(isinstance(addr, address.AddressQubit) for addr in qargs.data)
178
+ if isinstance(ctrls, address.AddressReg) and isinstance(
179
+ qargs, address.AddressReg
186
180
  ):
187
- ctrl_qubits = list(
188
- map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data)
189
- )
190
- qarg_qubits = list(
191
- map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data)
192
- )
181
+ ctrl_qubits = tuple(ctrls.data)
182
+ qarg_qubits = tuple(qargs.data)
193
183
  rest = sorted(
194
184
  set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits)
195
185
  )
@@ -3,7 +3,6 @@ from dataclasses import dataclass
3
3
 
4
4
  from kirin import ir
5
5
  from kirin.rewrite import abc
6
- from kirin.analysis import const
7
6
  from kirin.dialects import ilist
8
7
 
9
8
  from bloqade.analysis import address
@@ -20,28 +19,24 @@ class ParallelToGlobalRule(abc.RewriteRule):
20
19
  return abc.RewriteResult()
21
20
 
22
21
  qargs = node.qargs
23
- qarg_addresses = self.address_analysis.get(qargs, None)
22
+ qargs_address = self.address_analysis.get(qargs, address.Unknown())
24
23
 
25
- if isinstance(qarg_addresses, address.AddressReg):
26
- # NOTE: we only have an AddressReg if it's an entire register, definitely rewrite that
27
- return self._rewrite_parallel_to_glob(node)
28
-
29
- if not isinstance(qarg_addresses, address.AddressTuple):
24
+ if not isinstance(qargs_address, address.AddressReg):
30
25
  return abc.RewriteResult()
31
26
 
32
- idxs, qreg = self._find_qreg(qargs.owner, set())
27
+ qregs = self._get_all_qreg(qargs.owner)
33
28
 
34
- if qreg is None:
35
- # NOTE: no unique register found
29
+ if len(qregs) != 1:
36
30
  return abc.RewriteResult()
37
31
 
38
- if not isinstance(hint := qreg.n_qubits.hints.get("const"), const.Value):
39
- # NOTE: non-constant number of qubits
32
+ qreg = next(iter(qregs))
33
+
34
+ qreg_address = self.address_analysis.get(qreg, address.Unknown())
35
+
36
+ if not isinstance(qreg_address, address.AddressReg):
40
37
  return abc.RewriteResult()
41
38
 
42
- n = hint.data
43
- if len(idxs) != n:
44
- # NOTE: not all qubits of the register are there
39
+ if set(qargs_address.data) != set(qreg_address.data):
45
40
  return abc.RewriteResult()
46
41
 
47
42
  return self._rewrite_parallel_to_glob(node)
@@ -53,6 +48,24 @@ class ParallelToGlobalRule(abc.RewriteRule):
53
48
  node.replace_by(global_u)
54
49
  return abc.RewriteResult(has_done_something=True)
55
50
 
51
+ @staticmethod
52
+ def _get_all_qreg(owner: ir.Statement | ir.Block):
53
+ stack = [owner]
54
+ qregs: set[ir.SSAValue] = set()
55
+ while stack:
56
+ current = stack.pop()
57
+
58
+ if isinstance(current, core.stmts.QRegGet):
59
+ stack.append(current.reg.owner)
60
+ elif isinstance(current, ilist.New):
61
+ for val in current.values:
62
+ stack.append(val.owner)
63
+
64
+ elif isinstance(current, core.QRegNew):
65
+ qregs.add(current.result)
66
+
67
+ return qregs
68
+
56
69
  @staticmethod
57
70
  def _find_qreg(
58
71
  qargs_owner: ir.Statement | ir.Block, idxs: set