bloqade-circuit 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic; see the package registry's advisory page for details.

Files changed (80)
  1. bloqade/analysis/address/impls.py +14 -0
  2. bloqade/analysis/fidelity/analysis.py +27 -2
  3. bloqade/noise/fidelity.py +3 -3
  4. bloqade/noise/native/_dialect.py +1 -1
  5. bloqade/noise/native/_wrappers.py +35 -6
  6. bloqade/noise/native/stmts.py +1 -1
  7. bloqade/pyqrack/device.py +109 -21
  8. bloqade/pyqrack/qasm2/core.py +4 -1
  9. bloqade/pyqrack/squin/qubit.py +16 -9
  10. bloqade/pyqrack/squin/wire.py +22 -4
  11. bloqade/pyqrack/task.py +13 -5
  12. bloqade/qasm2/__init__.py +1 -0
  13. bloqade/qasm2/_qasm_loading.py +151 -0
  14. bloqade/qasm2/dialects/core/__init__.py +9 -1
  15. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  16. bloqade/qasm2/dialects/noise.py +33 -1
  17. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  18. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  19. bloqade/qasm2/emit/impls/__init__.py +1 -0
  20. bloqade/qasm2/emit/impls/noise_native.py +89 -0
  21. bloqade/qasm2/emit/main.py +21 -0
  22. bloqade/qasm2/emit/target.py +20 -5
  23. bloqade/qasm2/groups.py +2 -0
  24. bloqade/qasm2/parse/__init__.py +7 -4
  25. bloqade/qasm2/parse/lowering.py +20 -130
  26. bloqade/qasm2/parse/qasm2.lark +1 -1
  27. bloqade/qasm2/passes/__init__.py +1 -0
  28. bloqade/qasm2/passes/fold.py +6 -0
  29. bloqade/qasm2/passes/noise.py +50 -2
  30. bloqade/qasm2/passes/parallel.py +9 -0
  31. bloqade/qasm2/passes/unroll_if.py +25 -0
  32. bloqade/qasm2/rewrite/__init__.py +1 -0
  33. bloqade/qasm2/rewrite/desugar.py +3 -2
  34. bloqade/qasm2/rewrite/heuristic_noise.py +1 -9
  35. bloqade/qasm2/rewrite/native_gates.py +67 -4
  36. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  37. bloqade/squin/analysis/nsites/__init__.py +1 -0
  38. bloqade/squin/analysis/nsites/impls.py +25 -1
  39. bloqade/squin/noise/__init__.py +7 -26
  40. bloqade/squin/noise/_wrapper.py +25 -0
  41. bloqade/squin/op/__init__.py +33 -159
  42. bloqade/squin/op/_wrapper.py +101 -0
  43. bloqade/squin/op/stdlib.py +62 -0
  44. bloqade/squin/passes/__init__.py +1 -0
  45. bloqade/squin/passes/stim.py +68 -0
  46. bloqade/squin/rewrite/__init__.py +11 -0
  47. bloqade/squin/rewrite/qubit_to_stim.py +84 -0
  48. bloqade/squin/rewrite/squin_measure.py +98 -0
  49. bloqade/squin/rewrite/stim_rewrite_util.py +158 -0
  50. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  51. bloqade/squin/rewrite/wire_to_stim.py +73 -0
  52. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  53. bloqade/squin/wire.py +1 -13
  54. bloqade/stim/__init__.py +39 -5
  55. bloqade/stim/_wrappers.py +14 -12
  56. bloqade/stim/dialects/__init__.py +1 -5
  57. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +12 -1
  58. bloqade/stim/dialects/{aux → auxiliary}/emit.py +1 -1
  59. bloqade/stim/dialects/collapse/__init__.py +13 -2
  60. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +1 -1
  61. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  62. bloqade/stim/dialects/gate/__init__.py +16 -1
  63. bloqade/stim/dialects/gate/emit.py +1 -1
  64. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  65. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  66. bloqade/stim/dialects/noise/emit.py +1 -1
  67. bloqade/stim/emit/__init__.py +1 -1
  68. bloqade/stim/groups.py +4 -2
  69. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/METADATA +3 -3
  70. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/RECORD +80 -64
  71. /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  72. /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  73. /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  74. /bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +0 -0
  75. /bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +0 -0
  76. /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  77. /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  78. /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  79. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/WHEEL +0 -0
  80. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,9 @@
1
- from kirin import ir, types, lowering
1
+ from kirin import ir, types, interp, lowering
2
2
  from kirin.decl import info, statement
3
3
 
4
+ from bloqade.qasm2.parse import ast
4
5
  from bloqade.qasm2.types import QubitType
6
+ from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame
5
7
 
6
8
  dialect = ir.Dialect("qasm2.noise")
7
9
 
@@ -14,3 +16,33 @@ class Pauli1(ir.Statement):
14
16
  py: ir.SSAValue = info.argument(types.Float)
15
17
  pz: ir.SSAValue = info.argument(types.Float)
16
18
  qarg: ir.SSAValue = info.argument(QubitType)
19
+
20
+
21
+ @dialect.register(key="emit.qasm2.gate")
22
+ class NoiseEmit(interp.MethodTable):
23
+
24
+ @interp.impl(Pauli1)
25
+ def emit_pauli(
26
+ self,
27
+ emit: EmitQASM2Gate,
28
+ frame: EmitQASM2Frame,
29
+ stmt: Pauli1,
30
+ ):
31
+
32
+ px: ast.Number = frame.get(stmt.px)
33
+ py: ast.Number = frame.get(stmt.py)
34
+ pz: ast.Number = frame.get(stmt.pz)
35
+ qarg: ast.Bit | ast.Name = frame.get(stmt.qarg)
36
+
37
+ qarg_str = (
38
+ f"{qarg.name.id}[{qarg.addr}]"
39
+ if isinstance(qarg, ast.Bit)
40
+ else f"{qarg.id}"
41
+ )
42
+
43
+ frame.body.append(
44
+ ast.Comment(
45
+ text=f"noise.Pauli1({px.value}, {py.value}, {pz.value}) {qarg_str}]"
46
+ )
47
+ )
48
+ return ()
@@ -1,4 +1,40 @@
1
- from . import _emit as _emit, stmts as stmts
2
- from .stmts import * # noqa: F403
1
+ from . import _emit as _emit, stmts as stmts, schedule as schedule
2
+ from .stmts import (
3
+ CH as CH,
4
+ CU as CU,
5
+ CX as CX,
6
+ CY as CY,
7
+ CZ as CZ,
8
+ RX as RX,
9
+ RY as RY,
10
+ RZ as RZ,
11
+ SX as SX,
12
+ U1 as U1,
13
+ U2 as U2,
14
+ CCX as CCX,
15
+ CRX as CRX,
16
+ CRY as CRY,
17
+ CRZ as CRZ,
18
+ CSX as CSX,
19
+ CU1 as CU1,
20
+ CU3 as CU3,
21
+ RXX as RXX,
22
+ RZZ as RZZ,
23
+ H as H,
24
+ S as S,
25
+ T as T,
26
+ X as X,
27
+ Y as Y,
28
+ Z as Z,
29
+ Id as Id,
30
+ Sdag as Sdag,
31
+ Swap as Swap,
32
+ Tdag as Tdag,
33
+ CSwap as CSwap,
34
+ SXdag as SXdag,
35
+ UGate as UGate,
36
+ Barrier as Barrier,
37
+ SingleQubitGate as SingleQubitGate,
38
+ TwoQubitCtrlGate as TwoQubitCtrlGate,
39
+ )
3
40
  from ._dialect import dialect as dialect
4
- from .schedule import * # noqa: F403
@@ -8,7 +8,7 @@ from ._dialect import dialect
8
8
 
9
9
 
10
10
  @dialect.register(key="qasm2.schedule.dag")
11
- class UOp(interp.MethodTable):
11
+ class UOpSchedule(interp.MethodTable):
12
12
 
13
13
  @interp.impl(stmts.Id)
14
14
  @interp.impl(stmts.SXdag)
@@ -0,0 +1 @@
1
+ from . import noise_native as noise_native
@@ -0,0 +1,89 @@
1
+ from typing import Any
2
+
3
+ from kirin import interp
4
+ from kirin.dialects import ilist
5
+
6
+ from bloqade.noise import native
7
+ from bloqade.qasm2.parse import ast
8
+ from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame
9
+
10
+
11
+ @native.dialect.register(key="emit.qasm2.gate")
12
+ class NativeNoise(interp.MethodTable):
13
+
14
+ def _convert(self, node: ast.Bit | ast.Name) -> str:
15
+ if isinstance(node, ast.Bit):
16
+ return f"{node.name.id}[{node.addr}]"
17
+ else:
18
+ return f"{node.id}"
19
+
20
+ @interp.impl(native.CZPauliChannel)
21
+ def emit_czp(
22
+ self,
23
+ emit: EmitQASM2Gate,
24
+ frame: EmitQASM2Frame,
25
+ stmt: native.CZPauliChannel,
26
+ ):
27
+ paired: bool = stmt.paired
28
+ px_ctrl: float = stmt.px_ctrl
29
+ py_ctrl: float = stmt.py_ctrl
30
+ pz_ctrl: float = stmt.pz_ctrl
31
+ px_qarg: float = stmt.pz_qarg
32
+ py_qarg: float = stmt.py_qarg
33
+ pz_qarg: float = stmt.pz_qarg
34
+ ctrls: ilist.IList[ast.Bit, Any] = frame.get(stmt.ctrls)
35
+ qargs: ilist.IList[ast.Bit, Any] = frame.get(stmt.qargs)
36
+ frame.body.append(
37
+ ast.Comment(
38
+ text=f"native.CZPauliChannel(paired={paired}, p_ctrl=[x:{px_ctrl}, y:{py_ctrl}, z:{pz_ctrl}], p_qarg[x:{px_qarg}, y:{py_qarg}, z:{pz_qarg}])"
39
+ )
40
+ )
41
+ frame.body.append(
42
+ ast.Comment(
43
+ text=f" -: ctrls: {', '.join([self._convert(q) for q in ctrls])}"
44
+ )
45
+ )
46
+ frame.body.append(
47
+ ast.Comment(
48
+ text=f" -: qargs: {', '.join([self._convert(q) for q in qargs])}"
49
+ )
50
+ )
51
+ return ()
52
+
53
+ @interp.impl(native.AtomLossChannel)
54
+ def emit_loss(
55
+ self,
56
+ emit: EmitQASM2Gate,
57
+ frame: EmitQASM2Frame,
58
+ stmt: native.AtomLossChannel,
59
+ ):
60
+ prob: float = stmt.prob
61
+ qargs: ilist.IList[ast.Bit, Any] = frame.get(stmt.qargs)
62
+ frame.body.append(ast.Comment(text=f"native.Atomloss(p={prob})"))
63
+ frame.body.append(
64
+ ast.Comment(
65
+ text=f" -: qargs: {', '.join([self._convert(q) for q in qargs])}"
66
+ )
67
+ )
68
+ return ()
69
+
70
+ @interp.impl(native.PauliChannel)
71
+ def emit_pauli(
72
+ self,
73
+ emit: EmitQASM2Gate,
74
+ frame: EmitQASM2Frame,
75
+ stmt: native.PauliChannel,
76
+ ):
77
+ px: float = stmt.px
78
+ py: float = stmt.py
79
+ pz: float = stmt.pz
80
+ qargs: ilist.IList[ast.Bit, Any] = frame.get(stmt.qargs)
81
+ frame.body.append(
82
+ ast.Comment(text=f"native.PauliChannel(px={px}, py={py}, pz={pz})")
83
+ )
84
+ frame.body.append(
85
+ ast.Comment(
86
+ text=f" -: qargs: {', '.join([self._convert(q) for q in qargs])}"
87
+ )
88
+ )
89
+ return ()
@@ -5,8 +5,11 @@ from kirin.dialects import cf, scf, func
5
5
  from kirin.ir.dialect import Dialect as Dialect
6
6
 
7
7
  from bloqade.qasm2.parse import ast
8
+ from bloqade.qasm2.dialects.uop import SingleQubitGate, TwoQubitCtrlGate
9
+ from bloqade.qasm2.dialects.expr import GateFunction
8
10
 
9
11
  from .base import EmitQASM2Base, EmitQASM2Frame
12
+ from ..dialects.core.stmts import Reset, Measure
10
13
 
11
14
 
12
15
  @dataclass
@@ -94,6 +97,24 @@ class Scf(interp.MethodTable):
94
97
 
95
98
  cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
96
99
 
100
+ # NOTE: we need exactly one of those in the then body in order to emit valid QASM2
101
+ AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset
102
+
103
+ then_stmts = stmt.then_body.blocks[0].stmts
104
+ uop_stmts = 0
105
+ for s in then_stmts:
106
+ if isinstance(s, AllowedThenType):
107
+ uop_stmts += 1
108
+ continue
109
+
110
+ if isinstance(s, func.Invoke):
111
+ uop_stmts += isinstance(s.callee.code, GateFunction)
112
+
113
+ if uop_stmts != 1:
114
+ raise interp.InterpreterError(
115
+ "Cannot lower if-statement: QASM2 only allows exactly one quantum operation in the body."
116
+ )
117
+
97
118
  with emit.new_frame(stmt) as then_frame:
98
119
  then_frame.entries.update(frame.entries)
99
120
  emit.emit_block(then_frame, stmt.then_body.blocks[0])
@@ -11,6 +11,7 @@ from bloqade.qasm2.passes.glob import GlobalToParallel
11
11
  from bloqade.qasm2.passes.py2qasm import Py2QASM
12
12
  from bloqade.qasm2.passes.parallel import ParallelToUOp
13
13
 
14
+ from . import impls as impls # register the tables
14
15
  from .gate import EmitQASM2Gate
15
16
  from .main import EmitQASM2Main
16
17
 
@@ -27,6 +28,8 @@ class QASM2:
27
28
  allow_parallel: bool = False,
28
29
  allow_global: bool = False,
29
30
  custom_gate: bool = True,
31
+ unroll_ifs: bool = True,
32
+ allow_noise: bool = True,
30
33
  ) -> None:
31
34
  """Initialize the QASM2 target.
32
35
 
@@ -43,13 +46,18 @@ class QASM2:
43
46
  qelib1 (bool):
44
47
  Include the `include "qelib1.inc"` line in the resulting QASM2 AST that's
45
48
  submitted to qBraid. Defaults to `True`.
49
+
46
50
  custom_gate (bool):
47
51
  Include the custom gate definitions in the resulting QASM2 AST. Defaults to `True`. If `False`, all the qasm2.gate will be inlined.
48
52
 
53
+ unroll_ifs (bool):
54
+ Unrolls if statements with multiple qasm2 statements in the body in order to produce valid qasm2 output, which only allows a single
55
+ operation in an if body. Defaults to `True`.
56
+
49
57
 
50
58
 
51
59
  """
52
- from bloqade import qasm2
60
+ from bloqade import noise, qasm2
53
61
 
54
62
  self.main_target = qasm2.main
55
63
  self.gate_target = qasm2.gate
@@ -58,6 +66,7 @@ class QASM2:
58
66
  self.custom_gate = custom_gate
59
67
  self.allow_parallel = allow_parallel
60
68
  self.allow_global = allow_global
69
+ self.unroll_ifs = unroll_ifs
61
70
 
62
71
  if allow_parallel:
63
72
  self.main_target = self.main_target.add(qasm2.dialects.parallel)
@@ -67,7 +76,11 @@ class QASM2:
67
76
  self.main_target = self.main_target.add(qasm2.dialects.glob)
68
77
  self.gate_target = self.gate_target.add(qasm2.dialects.glob)
69
78
 
70
- if allow_global or allow_parallel:
79
+ if allow_noise:
80
+ self.main_target = self.main_target.add(noise.native)
81
+ self.gate_target = self.gate_target.add(noise.native)
82
+
83
+ if allow_global or allow_parallel or allow_noise:
71
84
  self.main_target = self.main_target.add(ilist)
72
85
  self.gate_target = self.gate_target.add(ilist)
73
86
 
@@ -87,9 +100,11 @@ class QASM2:
87
100
 
88
101
  # make a cloned instance of kernel
89
102
  entry = entry.similar()
90
- QASM2Fold(entry.dialects, inline_gate_subroutine=not self.custom_gate).fixpoint(
91
- entry
92
- )
103
+ QASM2Fold(
104
+ entry.dialects,
105
+ inline_gate_subroutine=not self.custom_gate,
106
+ unroll_ifs=self.unroll_ifs,
107
+ ).fixpoint(entry)
93
108
 
94
109
  if not self.allow_global:
95
110
  # rewrite global to parallel
bloqade/qasm2/groups.py CHANGED
@@ -2,6 +2,7 @@ from kirin import ir, passes
2
2
  from kirin.prelude import structural_no_opt
3
3
  from kirin.dialects import scf, func, ilist, lowering
4
4
 
5
+ from bloqade.noise import native
5
6
  from bloqade.qasm2.dialects import (
6
7
  uop,
7
8
  core,
@@ -90,6 +91,7 @@ def main(self):
90
91
  noise,
91
92
  parallel,
92
93
  core,
94
+ native,
93
95
  ]
94
96
  )
95
97
  )
@@ -19,19 +19,22 @@ def loadfile(file: str | pathlib.Path):
19
19
  return loads(f.read())
20
20
 
21
21
 
22
- def pprint(node: ast.Node, *, console: Console | None = None):
22
+ def pprint(node: ast.Node, *, console: Console | None = None, no_color: bool = False):
23
23
  if console:
24
- return Printer(console).visit(node)
24
+ printer = Printer(console)
25
25
  else:
26
- Printer().visit(node)
26
+ printer = Printer()
27
+ printer.console.no_color = no_color
28
+ printer.visit(node)
27
29
 
28
30
 
29
- def spprint(node: ast.Node, *, console: Console | None = None):
31
+ def spprint(node: ast.Node, *, console: Console | None = None, no_color: bool = False):
30
32
  if console:
31
33
  printer = Printer(console)
32
34
  else:
33
35
  printer = Printer()
34
36
 
37
+ printer.console.no_color = no_color
35
38
  with printer.string_io() as stream:
36
39
  printer.visit(node)
37
40
  return stream.getvalue()
@@ -1,5 +1,3 @@
1
- import os
2
- import pathlib
3
1
  from typing import Any
4
2
  from dataclasses import field, dataclass
5
3
 
@@ -19,131 +17,39 @@ class QASM2(lowering.LoweringABC[ast.Node]):
19
17
  hint_show_lineno: bool = field(default=True, kw_only=True)
20
18
  stacktrace: bool = field(default=True, kw_only=True)
21
19
 
22
- def loads(
20
+ def run(
23
21
  self,
24
- source: str,
25
- kernel_name: str,
22
+ stmt: ast.Node,
26
23
  *,
27
- returns: str | None = None,
24
+ source: str | None = None,
28
25
  globals: dict[str, Any] | None = None,
29
26
  file: str | None = None,
30
27
  lineno_offset: int = 0,
31
28
  col_offset: int = 0,
32
29
  compactify: bool = True,
33
- ) -> ir.Method:
34
- from ..parse import loads
35
-
36
- # TODO: add source info
37
- stmt = loads(source)
30
+ ) -> ir.Region:
38
31
 
39
- state = lowering.State(
40
- self,
32
+ frame = self.get_frame(
33
+ stmt,
34
+ source=source,
35
+ globals=globals,
41
36
  file=file,
42
37
  lineno_offset=lineno_offset,
43
38
  col_offset=col_offset,
44
39
  )
45
- with state.frame(
46
- [stmt],
47
- globals=globals,
48
- finalize_next=False,
49
- ) as frame:
50
- try:
51
- self.visit(state, stmt)
52
- # append return statement with the return values
53
- if returns is not None:
54
- return_value = frame.get(returns)
55
- if return_value is None:
56
- raise lowering.BuildError(f"Cannot find return value {returns}")
57
- else:
58
- return_value = func.ConstantNone()
59
- frame.push(return_value)
60
-
61
- return_node = frame.push(func.Return(value_or_stmt=return_value))
62
-
63
- except lowering.BuildError as e:
64
- hint = state.error_hint(
65
- e,
66
- max_lines=self.max_lines,
67
- indent=self.hint_indent,
68
- show_lineno=self.hint_show_lineno,
69
- )
70
- if self.stacktrace:
71
- raise Exception(
72
- f"{e.args[0]}\n\n{hint}",
73
- *e.args[1:],
74
- ) from e
75
- else:
76
- e.args = (hint,)
77
- raise e
78
-
79
- region = frame.curr_region
80
-
81
- if compactify:
82
- from kirin.rewrite import Walk, CFGCompactify
83
-
84
- Walk(CFGCompactify()).rewrite(region)
85
-
86
- code = func.Function(
87
- sym_name=kernel_name,
88
- signature=func.Signature((), return_node.value.type),
89
- body=region,
90
- )
91
40
 
92
- return ir.Method(
93
- mod=None,
94
- py_func=None,
95
- sym_name=kernel_name,
96
- arg_names=[],
97
- dialects=self.dialects,
98
- code=code,
99
- )
41
+ return frame.curr_region
100
42
 
101
- def loadfile(
102
- self,
103
- file: str | pathlib.Path,
104
- *,
105
- kernel_name: str | None = None,
106
- returns: str | None = None,
107
- globals: dict[str, Any] | None = None,
108
- lineno_offset: int = 0,
109
- col_offset: int = 0,
110
- compactify: bool = True,
111
- ) -> ir.Method:
112
- if isinstance(file, str):
113
- file = pathlib.Path(*os.path.split(file))
114
-
115
- if not file.is_file() or not file.name.endswith(".qasm"):
116
- raise ValueError("File must be a .qasm file")
117
-
118
- kernel_name = (
119
- file.name.replace(".qasm", "") if kernel_name is None else kernel_name
120
- )
121
-
122
- with file.open("r") as f:
123
- source = f.read()
124
-
125
- return self.loads(
126
- source,
127
- kernel_name,
128
- returns=returns,
129
- globals=globals,
130
- file=str(file),
131
- lineno_offset=lineno_offset,
132
- col_offset=col_offset,
133
- compactify=compactify,
134
- )
135
-
136
- def run(
43
+ def get_frame(
137
44
  self,
138
45
  stmt: ast.Node,
139
- *,
140
46
  source: str | None = None,
141
47
  globals: dict[str, Any] | None = None,
142
48
  file: str | None = None,
143
49
  lineno_offset: int = 0,
144
50
  col_offset: int = 0,
145
51
  compactify: bool = True,
146
- ) -> ir.Region:
52
+ ) -> lowering.Frame:
147
53
  # TODO: add source info
148
54
  state = lowering.State(
149
55
  self,
@@ -154,32 +60,16 @@ class QASM2(lowering.LoweringABC[ast.Node]):
154
60
  with state.frame(
155
61
  [stmt],
156
62
  globals=globals,
63
+ finalize_next=False,
157
64
  ) as frame:
158
- try:
159
- self.visit(state, stmt)
160
- except lowering.BuildError as e:
161
- hint = state.error_hint(
162
- e,
163
- max_lines=self.max_lines,
164
- indent=self.hint_indent,
165
- show_lineno=self.hint_show_lineno,
166
- )
167
- if self.stacktrace:
168
- raise Exception(
169
- f"{e.args[0]}\n\n{hint}",
170
- *e.args[1:],
171
- ) from e
172
- else:
173
- e.args = (hint,)
174
- raise e
175
-
176
- region = frame.curr_region
177
-
178
- if compactify:
179
- from kirin.rewrite import Walk, CFGCompactify
180
-
181
- Walk(CFGCompactify()).rewrite(region)
182
- return region
65
+ self.visit(state, stmt)
66
+
67
+ if compactify:
68
+ from kirin.rewrite import Walk, CFGCompactify
69
+
70
+ Walk(CFGCompactify()).rewrite(frame.curr_region)
71
+
72
+ return frame
183
73
 
184
74
  def visit(self, state: lowering.State[ast.Node], node: ast.Node) -> lowering.Result:
185
75
  name = node.__class__.__name__
@@ -9,7 +9,7 @@ version: INT "." INT
9
9
  // stmts
10
10
  include: "include" STRING ";"
11
11
  ifstmt: "if" "(" expr "==" expr ")" ifbody
12
- ifbody: qop | "{" qop* "}" // allow multiple qops
12
+ ifbody: qop
13
13
  opaque: "opaque" IDENTIFIER ["(" [params] ")"] qubits ";"
14
14
  barrier: "barrier" qubits ";"
15
15
  qreg: "qreg" IDENTIFIER "[" INT "]" ";"
@@ -3,3 +3,4 @@ from .noise import NoisePass as NoisePass
3
3
  from .py2qasm import Py2QASM as Py2QASM
4
4
  from .qasm2py import QASM2Py as QASM2Py
5
5
  from .parallel import UOpToParallel as UOpToParallel
6
+ from .unroll_if import UnrollIfs as UnrollIfs
@@ -23,6 +23,8 @@ from kirin.rewrite.abc import RewriteResult
23
23
 
24
24
  from bloqade.qasm2.dialects import expr
25
25
 
26
+ from .unroll_if import UnrollIfs
27
+
26
28
 
27
29
  @dataclass
28
30
  class QASM2Fold(Pass):
@@ -30,6 +32,7 @@ class QASM2Fold(Pass):
30
32
 
31
33
  constprop: const.Propagate = field(init=False)
32
34
  inline_gate_subroutine: bool = True
35
+ unroll_ifs: bool = True
33
36
 
34
37
  def __post_init__(self):
35
38
  self.constprop = const.Propagate(self.dialects)
@@ -61,6 +64,9 @@ class QASM2Fold(Pass):
61
64
  .join(result)
62
65
  )
63
66
 
67
+ if self.unroll_ifs:
68
+ UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
69
+
64
70
  # run typeinfer again after unroll etc. because we now insert
65
71
  # a lot of new nodes, which might have more precise types
66
72
  self.typeinfer.unsafe_run(mt)
@@ -21,6 +21,34 @@ class NoisePass(Pass):
21
21
  NOTE: This pass is not guaranteed to be supported long-term in bloqade. We will be
22
22
  moving towards a more general approach to noise modeling in the future.
23
23
 
24
+ ## Usage examples
25
+
26
+ ```
27
+ from bloqade import qasm2
28
+ from bloqade.noise import native
29
+ from bloqade.qasm2.passes.noise import NoisePass
30
+
31
+ noise_main = qasm2.extended.add(native.dialect)
32
+
33
+ @noise_main
34
+ def main():
35
+ q = qasm2.qreg(2)
36
+ qasm2.h(q[0])
37
+ qasm2.cx(q[0], q[1])
38
+ return q
39
+
40
+ # simple IR without any nosie
41
+ main.print()
42
+
43
+ noise_pass = NoisePass(noise_main)
44
+
45
+ # rewrite stuff in-place
46
+ noise_pass.unsafe_run(main)
47
+
48
+ # now, we do have noise channels in the IR
49
+ main.print()
50
+ ```
51
+
24
52
  """
25
53
 
26
54
  noise_model: native.MoveNoiseModelABC = field(
@@ -34,13 +62,32 @@ class NoisePass(Pass):
34
62
  def __post_init__(self):
35
63
  self.address_analysis = address.AddressAnalysis(self.dialects)
36
64
 
65
+ def get_qubit_values(self, mt: ir.Method):
66
+ frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
67
+ qubit_ssa_values = {}
68
+ # Traverse statements in block order to fine the first SSA value for each qubit
69
+ for block in mt.callable_region.blocks:
70
+ for stmt in block.stmts:
71
+ if len(stmt.results) != 1:
72
+ continue
73
+
74
+ addr = frame.entries.get(result := stmt.results[0])
75
+ if (
76
+ isinstance(addr, address.AddressQubit)
77
+ and (index := addr.data) not in qubit_ssa_values
78
+ ):
79
+ qubit_ssa_values[index] = result
80
+
81
+ return qubit_ssa_values, frame.entries
82
+
37
83
  def unsafe_run(self, mt: ir.Method):
38
84
  result = LiftQubits(self.dialects).unsafe_run(mt)
39
- frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
85
+ qubit_ssa_value, address_analysis = self.get_qubit_values(mt)
40
86
  result = (
41
87
  Walk(
42
88
  NoiseRewriteRule(
43
- address_analysis=frame.entries,
89
+ qubit_ssa_value=qubit_ssa_value,
90
+ address_analysis=address_analysis,
44
91
  noise_model=self.noise_model,
45
92
  gate_noise_params=self.gate_noise_params,
46
93
  ),
@@ -49,5 +96,6 @@ class NoisePass(Pass):
49
96
  .rewrite(mt.code)
50
97
  .join(result)
51
98
  )
99
+
52
100
  result = Fixpoint(Walk(DeadCodeElimination())).rewrite(mt.code).join(result)
53
101
  return result
@@ -27,6 +27,7 @@ from bloqade.qasm2.rewrite import (
27
27
  RaiseRegisterRule,
28
28
  UOpToParallelRule,
29
29
  SimpleOptimalMergePolicy,
30
+ RydbergGateSetRewriteRule,
30
31
  )
31
32
  from bloqade.squin.analysis import schedule
32
33
 
@@ -135,6 +136,7 @@ class UOpToParallel(Pass):
135
136
  """
136
137
 
137
138
  merge_policy_type: Type[MergePolicyABC] = SimpleOptimalMergePolicy
139
+ rewrite_to_native_first: bool = False
138
140
  constprop: const.Propagate = field(init=False)
139
141
 
140
142
  def __post_init__(self):
@@ -147,6 +149,13 @@ class UOpToParallel(Pass):
147
149
  if not result.has_done_something:
148
150
  return result
149
151
 
152
+ if self.rewrite_to_native_first:
153
+ result = (
154
+ Fixpoint(Walk(RydbergGateSetRewriteRule(self.dialects)))
155
+ .rewrite(mt.code)
156
+ .join(result)
157
+ )
158
+
150
159
  frame, _ = self.constprop.run_analysis(mt)
151
160
  result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
152
161