bloqade-circuit 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. (The original page linked to further details; that link was not preserved.)

Files changed (79):
  1. bloqade/analysis/address/impls.py +14 -0
  2. bloqade/noise/fidelity.py +3 -3
  3. bloqade/noise/native/_dialect.py +1 -1
  4. bloqade/noise/native/_wrappers.py +35 -6
  5. bloqade/noise/native/stmts.py +1 -1
  6. bloqade/pyqrack/device.py +1 -3
  7. bloqade/pyqrack/qasm2/core.py +4 -1
  8. bloqade/pyqrack/squin/qubit.py +16 -9
  9. bloqade/pyqrack/squin/wire.py +22 -4
  10. bloqade/pyqrack/task.py +13 -5
  11. bloqade/qasm2/__init__.py +1 -0
  12. bloqade/qasm2/_qasm_loading.py +151 -0
  13. bloqade/qasm2/dialects/core/__init__.py +9 -1
  14. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  15. bloqade/qasm2/dialects/noise.py +33 -1
  16. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  17. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  18. bloqade/qasm2/emit/impls/__init__.py +1 -0
  19. bloqade/qasm2/emit/impls/noise_native.py +89 -0
  20. bloqade/qasm2/emit/main.py +21 -0
  21. bloqade/qasm2/emit/target.py +20 -5
  22. bloqade/qasm2/groups.py +2 -0
  23. bloqade/qasm2/parse/__init__.py +7 -4
  24. bloqade/qasm2/parse/lowering.py +20 -130
  25. bloqade/qasm2/parse/qasm2.lark +1 -1
  26. bloqade/qasm2/passes/__init__.py +1 -0
  27. bloqade/qasm2/passes/fold.py +6 -0
  28. bloqade/qasm2/passes/noise.py +22 -2
  29. bloqade/qasm2/passes/parallel.py +9 -0
  30. bloqade/qasm2/passes/unroll_if.py +25 -0
  31. bloqade/qasm2/rewrite/__init__.py +1 -0
  32. bloqade/qasm2/rewrite/desugar.py +3 -2
  33. bloqade/qasm2/rewrite/heuristic_noise.py +1 -9
  34. bloqade/qasm2/rewrite/native_gates.py +67 -4
  35. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  36. bloqade/squin/analysis/nsites/__init__.py +1 -0
  37. bloqade/squin/analysis/nsites/impls.py +25 -1
  38. bloqade/squin/noise/__init__.py +7 -26
  39. bloqade/squin/noise/_wrapper.py +25 -0
  40. bloqade/squin/op/__init__.py +33 -159
  41. bloqade/squin/op/_wrapper.py +101 -0
  42. bloqade/squin/op/stdlib.py +62 -0
  43. bloqade/squin/passes/__init__.py +1 -0
  44. bloqade/squin/passes/stim.py +68 -0
  45. bloqade/squin/rewrite/__init__.py +11 -0
  46. bloqade/squin/rewrite/qubit_to_stim.py +84 -0
  47. bloqade/squin/rewrite/squin_measure.py +98 -0
  48. bloqade/squin/rewrite/stim_rewrite_util.py +158 -0
  49. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  50. bloqade/squin/rewrite/wire_to_stim.py +73 -0
  51. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  52. bloqade/squin/wire.py +1 -13
  53. bloqade/stim/__init__.py +39 -5
  54. bloqade/stim/_wrappers.py +14 -12
  55. bloqade/stim/dialects/__init__.py +1 -5
  56. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +12 -1
  57. bloqade/stim/dialects/{aux → auxiliary}/emit.py +1 -1
  58. bloqade/stim/dialects/collapse/__init__.py +13 -2
  59. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +1 -1
  60. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  61. bloqade/stim/dialects/gate/__init__.py +16 -1
  62. bloqade/stim/dialects/gate/emit.py +1 -1
  63. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  64. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  65. bloqade/stim/dialects/noise/emit.py +1 -1
  66. bloqade/stim/emit/__init__.py +1 -1
  67. bloqade/stim/groups.py +4 -2
  68. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/METADATA +3 -3
  69. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/RECORD +79 -63
  70. /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  71. /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  72. /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  73. /bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +0 -0
  74. /bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +0 -0
  75. /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  76. /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  77. /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  78. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/WHEEL +0 -0
  79. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -5,8 +5,11 @@ from kirin.dialects import cf, scf, func
5
5
  from kirin.ir.dialect import Dialect as Dialect
6
6
 
7
7
  from bloqade.qasm2.parse import ast
8
+ from bloqade.qasm2.dialects.uop import SingleQubitGate, TwoQubitCtrlGate
9
+ from bloqade.qasm2.dialects.expr import GateFunction
8
10
 
9
11
  from .base import EmitQASM2Base, EmitQASM2Frame
12
+ from ..dialects.core.stmts import Reset, Measure
10
13
 
11
14
 
12
15
  @dataclass
@@ -94,6 +97,24 @@ class Scf(interp.MethodTable):
94
97
 
95
98
  cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
96
99
 
100
+ # NOTE: we need exactly one of those in the then body in order to emit valid QASM2
101
+ AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset
102
+
103
+ then_stmts = stmt.then_body.blocks[0].stmts
104
+ uop_stmts = 0
105
+ for s in then_stmts:
106
+ if isinstance(s, AllowedThenType):
107
+ uop_stmts += 1
108
+ continue
109
+
110
+ if isinstance(s, func.Invoke):
111
+ uop_stmts += isinstance(s.callee.code, GateFunction)
112
+
113
+ if uop_stmts != 1:
114
+ raise interp.InterpreterError(
115
+ "Cannot lower if-statement: QASM2 only allows exactly one quantum operation in the body."
116
+ )
117
+
97
118
  with emit.new_frame(stmt) as then_frame:
98
119
  then_frame.entries.update(frame.entries)
99
120
  emit.emit_block(then_frame, stmt.then_body.blocks[0])
@@ -11,6 +11,7 @@ from bloqade.qasm2.passes.glob import GlobalToParallel
11
11
  from bloqade.qasm2.passes.py2qasm import Py2QASM
12
12
  from bloqade.qasm2.passes.parallel import ParallelToUOp
13
13
 
14
+ from . import impls as impls # register the tables
14
15
  from .gate import EmitQASM2Gate
15
16
  from .main import EmitQASM2Main
16
17
 
@@ -27,6 +28,8 @@ class QASM2:
27
28
  allow_parallel: bool = False,
28
29
  allow_global: bool = False,
29
30
  custom_gate: bool = True,
31
+ unroll_ifs: bool = True,
32
+ allow_noise: bool = True,
30
33
  ) -> None:
31
34
  """Initialize the QASM2 target.
32
35
 
@@ -43,13 +46,18 @@ class QASM2:
43
46
  qelib1 (bool):
44
47
  Include the `include "qelib1.inc"` line in the resulting QASM2 AST that's
45
48
  submitted to qBraid. Defaults to `True`.
49
+
46
50
  custom_gate (bool):
47
51
  Include the custom gate definitions in the resulting QASM2 AST. Defaults to `True`. If `False`, all the qasm2.gate will be inlined.
48
52
 
53
+ unroll_ifs (bool):
54
+ Unrolls if statements with multiple qasm2 statements in the body in order to produce valid qasm2 output, which only allows a single
55
+ operation in an if body. Defaults to `True`.
56
+
49
57
 
50
58
 
51
59
  """
52
- from bloqade import qasm2
60
+ from bloqade import noise, qasm2
53
61
 
54
62
  self.main_target = qasm2.main
55
63
  self.gate_target = qasm2.gate
@@ -58,6 +66,7 @@ class QASM2:
58
66
  self.custom_gate = custom_gate
59
67
  self.allow_parallel = allow_parallel
60
68
  self.allow_global = allow_global
69
+ self.unroll_ifs = unroll_ifs
61
70
 
62
71
  if allow_parallel:
63
72
  self.main_target = self.main_target.add(qasm2.dialects.parallel)
@@ -67,7 +76,11 @@ class QASM2:
67
76
  self.main_target = self.main_target.add(qasm2.dialects.glob)
68
77
  self.gate_target = self.gate_target.add(qasm2.dialects.glob)
69
78
 
70
- if allow_global or allow_parallel:
79
+ if allow_noise:
80
+ self.main_target = self.main_target.add(noise.native)
81
+ self.gate_target = self.gate_target.add(noise.native)
82
+
83
+ if allow_global or allow_parallel or allow_noise:
71
84
  self.main_target = self.main_target.add(ilist)
72
85
  self.gate_target = self.gate_target.add(ilist)
73
86
 
@@ -87,9 +100,11 @@ class QASM2:
87
100
 
88
101
  # make a cloned instance of kernel
89
102
  entry = entry.similar()
90
- QASM2Fold(entry.dialects, inline_gate_subroutine=not self.custom_gate).fixpoint(
91
- entry
92
- )
103
+ QASM2Fold(
104
+ entry.dialects,
105
+ inline_gate_subroutine=not self.custom_gate,
106
+ unroll_ifs=self.unroll_ifs,
107
+ ).fixpoint(entry)
93
108
 
94
109
  if not self.allow_global:
95
110
  # rewrite global to parallel
bloqade/qasm2/groups.py CHANGED
@@ -2,6 +2,7 @@ from kirin import ir, passes
2
2
  from kirin.prelude import structural_no_opt
3
3
  from kirin.dialects import scf, func, ilist, lowering
4
4
 
5
+ from bloqade.noise import native
5
6
  from bloqade.qasm2.dialects import (
6
7
  uop,
7
8
  core,
@@ -90,6 +91,7 @@ def main(self):
90
91
  noise,
91
92
  parallel,
92
93
  core,
94
+ native,
93
95
  ]
94
96
  )
95
97
  )
@@ -19,19 +19,22 @@ def loadfile(file: str | pathlib.Path):
19
19
  return loads(f.read())
20
20
 
21
21
 
22
- def pprint(node: ast.Node, *, console: Console | None = None):
22
+ def pprint(node: ast.Node, *, console: Console | None = None, no_color: bool = False):
23
23
  if console:
24
- return Printer(console).visit(node)
24
+ printer = Printer(console)
25
25
  else:
26
- Printer().visit(node)
26
+ printer = Printer()
27
+ printer.console.no_color = no_color
28
+ printer.visit(node)
27
29
 
28
30
 
29
- def spprint(node: ast.Node, *, console: Console | None = None):
31
+ def spprint(node: ast.Node, *, console: Console | None = None, no_color: bool = False):
30
32
  if console:
31
33
  printer = Printer(console)
32
34
  else:
33
35
  printer = Printer()
34
36
 
37
+ printer.console.no_color = no_color
35
38
  with printer.string_io() as stream:
36
39
  printer.visit(node)
37
40
  return stream.getvalue()
@@ -1,5 +1,3 @@
1
- import os
2
- import pathlib
3
1
  from typing import Any
4
2
  from dataclasses import field, dataclass
5
3
 
@@ -19,131 +17,39 @@ class QASM2(lowering.LoweringABC[ast.Node]):
19
17
  hint_show_lineno: bool = field(default=True, kw_only=True)
20
18
  stacktrace: bool = field(default=True, kw_only=True)
21
19
 
22
- def loads(
20
+ def run(
23
21
  self,
24
- source: str,
25
- kernel_name: str,
22
+ stmt: ast.Node,
26
23
  *,
27
- returns: str | None = None,
24
+ source: str | None = None,
28
25
  globals: dict[str, Any] | None = None,
29
26
  file: str | None = None,
30
27
  lineno_offset: int = 0,
31
28
  col_offset: int = 0,
32
29
  compactify: bool = True,
33
- ) -> ir.Method:
34
- from ..parse import loads
35
-
36
- # TODO: add source info
37
- stmt = loads(source)
30
+ ) -> ir.Region:
38
31
 
39
- state = lowering.State(
40
- self,
32
+ frame = self.get_frame(
33
+ stmt,
34
+ source=source,
35
+ globals=globals,
41
36
  file=file,
42
37
  lineno_offset=lineno_offset,
43
38
  col_offset=col_offset,
44
39
  )
45
- with state.frame(
46
- [stmt],
47
- globals=globals,
48
- finalize_next=False,
49
- ) as frame:
50
- try:
51
- self.visit(state, stmt)
52
- # append return statement with the return values
53
- if returns is not None:
54
- return_value = frame.get(returns)
55
- if return_value is None:
56
- raise lowering.BuildError(f"Cannot find return value {returns}")
57
- else:
58
- return_value = func.ConstantNone()
59
- frame.push(return_value)
60
-
61
- return_node = frame.push(func.Return(value_or_stmt=return_value))
62
-
63
- except lowering.BuildError as e:
64
- hint = state.error_hint(
65
- e,
66
- max_lines=self.max_lines,
67
- indent=self.hint_indent,
68
- show_lineno=self.hint_show_lineno,
69
- )
70
- if self.stacktrace:
71
- raise Exception(
72
- f"{e.args[0]}\n\n{hint}",
73
- *e.args[1:],
74
- ) from e
75
- else:
76
- e.args = (hint,)
77
- raise e
78
-
79
- region = frame.curr_region
80
-
81
- if compactify:
82
- from kirin.rewrite import Walk, CFGCompactify
83
-
84
- Walk(CFGCompactify()).rewrite(region)
85
-
86
- code = func.Function(
87
- sym_name=kernel_name,
88
- signature=func.Signature((), return_node.value.type),
89
- body=region,
90
- )
91
40
 
92
- return ir.Method(
93
- mod=None,
94
- py_func=None,
95
- sym_name=kernel_name,
96
- arg_names=[],
97
- dialects=self.dialects,
98
- code=code,
99
- )
41
+ return frame.curr_region
100
42
 
101
- def loadfile(
102
- self,
103
- file: str | pathlib.Path,
104
- *,
105
- kernel_name: str | None = None,
106
- returns: str | None = None,
107
- globals: dict[str, Any] | None = None,
108
- lineno_offset: int = 0,
109
- col_offset: int = 0,
110
- compactify: bool = True,
111
- ) -> ir.Method:
112
- if isinstance(file, str):
113
- file = pathlib.Path(*os.path.split(file))
114
-
115
- if not file.is_file() or not file.name.endswith(".qasm"):
116
- raise ValueError("File must be a .qasm file")
117
-
118
- kernel_name = (
119
- file.name.replace(".qasm", "") if kernel_name is None else kernel_name
120
- )
121
-
122
- with file.open("r") as f:
123
- source = f.read()
124
-
125
- return self.loads(
126
- source,
127
- kernel_name,
128
- returns=returns,
129
- globals=globals,
130
- file=str(file),
131
- lineno_offset=lineno_offset,
132
- col_offset=col_offset,
133
- compactify=compactify,
134
- )
135
-
136
- def run(
43
+ def get_frame(
137
44
  self,
138
45
  stmt: ast.Node,
139
- *,
140
46
  source: str | None = None,
141
47
  globals: dict[str, Any] | None = None,
142
48
  file: str | None = None,
143
49
  lineno_offset: int = 0,
144
50
  col_offset: int = 0,
145
51
  compactify: bool = True,
146
- ) -> ir.Region:
52
+ ) -> lowering.Frame:
147
53
  # TODO: add source info
148
54
  state = lowering.State(
149
55
  self,
@@ -154,32 +60,16 @@ class QASM2(lowering.LoweringABC[ast.Node]):
154
60
  with state.frame(
155
61
  [stmt],
156
62
  globals=globals,
63
+ finalize_next=False,
157
64
  ) as frame:
158
- try:
159
- self.visit(state, stmt)
160
- except lowering.BuildError as e:
161
- hint = state.error_hint(
162
- e,
163
- max_lines=self.max_lines,
164
- indent=self.hint_indent,
165
- show_lineno=self.hint_show_lineno,
166
- )
167
- if self.stacktrace:
168
- raise Exception(
169
- f"{e.args[0]}\n\n{hint}",
170
- *e.args[1:],
171
- ) from e
172
- else:
173
- e.args = (hint,)
174
- raise e
175
-
176
- region = frame.curr_region
177
-
178
- if compactify:
179
- from kirin.rewrite import Walk, CFGCompactify
180
-
181
- Walk(CFGCompactify()).rewrite(region)
182
- return region
65
+ self.visit(state, stmt)
66
+
67
+ if compactify:
68
+ from kirin.rewrite import Walk, CFGCompactify
69
+
70
+ Walk(CFGCompactify()).rewrite(frame.curr_region)
71
+
72
+ return frame
183
73
 
184
74
  def visit(self, state: lowering.State[ast.Node], node: ast.Node) -> lowering.Result:
185
75
  name = node.__class__.__name__
@@ -9,7 +9,7 @@ version: INT "." INT
9
9
  // stmts
10
10
  include: "include" STRING ";"
11
11
  ifstmt: "if" "(" expr "==" expr ")" ifbody
12
- ifbody: qop | "{" qop* "}" // allow multiple qops
12
+ ifbody: qop
13
13
  opaque: "opaque" IDENTIFIER ["(" [params] ")"] qubits ";"
14
14
  barrier: "barrier" qubits ";"
15
15
  qreg: "qreg" IDENTIFIER "[" INT "]" ";"
@@ -3,3 +3,4 @@ from .noise import NoisePass as NoisePass
3
3
  from .py2qasm import Py2QASM as Py2QASM
4
4
  from .qasm2py import QASM2Py as QASM2Py
5
5
  from .parallel import UOpToParallel as UOpToParallel
6
+ from .unroll_if import UnrollIfs as UnrollIfs
@@ -23,6 +23,8 @@ from kirin.rewrite.abc import RewriteResult
23
23
 
24
24
  from bloqade.qasm2.dialects import expr
25
25
 
26
+ from .unroll_if import UnrollIfs
27
+
26
28
 
27
29
  @dataclass
28
30
  class QASM2Fold(Pass):
@@ -30,6 +32,7 @@ class QASM2Fold(Pass):
30
32
 
31
33
  constprop: const.Propagate = field(init=False)
32
34
  inline_gate_subroutine: bool = True
35
+ unroll_ifs: bool = True
33
36
 
34
37
  def __post_init__(self):
35
38
  self.constprop = const.Propagate(self.dialects)
@@ -61,6 +64,9 @@ class QASM2Fold(Pass):
61
64
  .join(result)
62
65
  )
63
66
 
67
+ if self.unroll_ifs:
68
+ UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
69
+
64
70
  # run typeinfer again after unroll etc. because we now insert
65
71
  # a lot of new nodes, which might have more precise types
66
72
  self.typeinfer.unsafe_run(mt)
@@ -62,13 +62,32 @@ class NoisePass(Pass):
62
62
  def __post_init__(self):
63
63
  self.address_analysis = address.AddressAnalysis(self.dialects)
64
64
 
65
+ def get_qubit_values(self, mt: ir.Method):
66
+ frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
67
+ qubit_ssa_values = {}
68
+ # Traverse statements in block order to fine the first SSA value for each qubit
69
+ for block in mt.callable_region.blocks:
70
+ for stmt in block.stmts:
71
+ if len(stmt.results) != 1:
72
+ continue
73
+
74
+ addr = frame.entries.get(result := stmt.results[0])
75
+ if (
76
+ isinstance(addr, address.AddressQubit)
77
+ and (index := addr.data) not in qubit_ssa_values
78
+ ):
79
+ qubit_ssa_values[index] = result
80
+
81
+ return qubit_ssa_values, frame.entries
82
+
65
83
  def unsafe_run(self, mt: ir.Method):
66
84
  result = LiftQubits(self.dialects).unsafe_run(mt)
67
- frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
85
+ qubit_ssa_value, address_analysis = self.get_qubit_values(mt)
68
86
  result = (
69
87
  Walk(
70
88
  NoiseRewriteRule(
71
- address_analysis=frame.entries,
89
+ qubit_ssa_value=qubit_ssa_value,
90
+ address_analysis=address_analysis,
72
91
  noise_model=self.noise_model,
73
92
  gate_noise_params=self.gate_noise_params,
74
93
  ),
@@ -77,5 +96,6 @@ class NoisePass(Pass):
77
96
  .rewrite(mt.code)
78
97
  .join(result)
79
98
  )
99
+
80
100
  result = Fixpoint(Walk(DeadCodeElimination())).rewrite(mt.code).join(result)
81
101
  return result
@@ -27,6 +27,7 @@ from bloqade.qasm2.rewrite import (
27
27
  RaiseRegisterRule,
28
28
  UOpToParallelRule,
29
29
  SimpleOptimalMergePolicy,
30
+ RydbergGateSetRewriteRule,
30
31
  )
31
32
  from bloqade.squin.analysis import schedule
32
33
 
@@ -135,6 +136,7 @@ class UOpToParallel(Pass):
135
136
  """
136
137
 
137
138
  merge_policy_type: Type[MergePolicyABC] = SimpleOptimalMergePolicy
139
+ rewrite_to_native_first: bool = False
138
140
  constprop: const.Propagate = field(init=False)
139
141
 
140
142
  def __post_init__(self):
@@ -147,6 +149,13 @@ class UOpToParallel(Pass):
147
149
  if not result.has_done_something:
148
150
  return result
149
151
 
152
+ if self.rewrite_to_native_first:
153
+ result = (
154
+ Fixpoint(Walk(RydbergGateSetRewriteRule(self.dialects)))
155
+ .rewrite(mt.code)
156
+ .join(result)
157
+ )
158
+
150
159
  frame, _ = self.constprop.run_analysis(mt)
151
160
  result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
152
161
 
@@ -0,0 +1,25 @@
1
+ from kirin import ir
2
+ from kirin.passes import Pass
3
+ from kirin.rewrite import (
4
+ Walk,
5
+ Chain,
6
+ Fixpoint,
7
+ ConstantFold,
8
+ CommonSubexpressionElimination,
9
+ )
10
+
11
+ from ..rewrite.split_ifs import LiftThenBody, SplitIfStmts
12
+
13
+
14
+ class UnrollIfs(Pass):
15
+ """This pass lifts statements that are not UOP out of the if body and then splits whatever is left into multiple if statements so you obtain valid QASM2"""
16
+
17
+ def unsafe_run(self, mt: ir.Method):
18
+ result = Walk(LiftThenBody()).rewrite(mt.code)
19
+ result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)
20
+ result = (
21
+ Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
22
+ .rewrite(mt.code)
23
+ .join(result)
24
+ )
25
+ return result
@@ -3,6 +3,7 @@ from .glob import (
3
3
  GlobalToParallelRule as GlobalToParallelRule,
4
4
  )
5
5
  from .register import RaiseRegisterRule as RaiseRegisterRule
6
+ from .native_gates import RydbergGateSetRewriteRule as RydbergGateSetRewriteRule
6
7
  from .parallel_to_uop import ParallelToUOpRule as ParallelToUOpRule
7
8
  from .uop_to_parallel import (
8
9
  MergePolicyABC as MergePolicyABC,
@@ -5,16 +5,17 @@ from kirin.passes import Pass
5
5
  from kirin.rewrite import abc, walk
6
6
  from kirin.dialects import py
7
7
 
8
+ from bloqade.qasm2 import types
8
9
  from bloqade.qasm2.dialects import core
9
10
 
10
11
 
11
12
  class IndexingDesugarRule(abc.RewriteRule):
12
13
  def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
13
14
  if isinstance(node, py.indexing.GetItem):
14
- if node.obj.type.is_subseteq(core.QRegType):
15
+ if node.obj.type.is_subseteq(types.QRegType):
15
16
  node.replace_by(core.QRegGet(reg=node.obj, idx=node.index))
16
17
  return abc.RewriteResult(has_done_something=True)
17
- elif node.obj.type.is_subseteq(core.CRegType):
18
+ elif node.obj.type.is_subseteq(types.CRegType):
18
19
  node.replace_by(core.CRegGet(reg=node.obj, idx=node.index))
19
20
  return abc.RewriteResult(has_done_something=True)
20
21
 
@@ -18,6 +18,7 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
18
18
  """
19
19
 
20
20
  address_analysis: Dict[ir.SSAValue, address.Address]
21
+ qubit_ssa_value: Dict[int, ir.SSAValue]
21
22
  gate_noise_params: native.GateNoiseParams = field(
22
23
  default_factory=native.GateNoiseParams
23
24
  )
@@ -25,15 +26,6 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
25
26
  default_factory=native.TwoRowZoneModel
26
27
  )
27
28
 
28
- def __post_init__(self):
29
- self.qubit_ssa_value: Dict[int, ir.SSAValue] = {}
30
- for ssa_value, addr in self.address_analysis.items():
31
- if (
32
- isinstance(addr, address.AddressQubit)
33
- and ssa_value not in self.qubit_ssa_value
34
- ):
35
- self.qubit_ssa_value[addr.data] = ssa_value
36
-
37
29
  def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
38
30
  if isinstance(node, uop.SingleQubitGate):
39
31
  return self.rewrite_single_qubit_gate(node)
@@ -173,6 +173,53 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
173
173
  cirq.XPowGate(exponent=0.5).on(self.cached_qubits[0]), node
174
174
  )
175
175
 
176
+ def rewrite_ccx(self, node: uop.CCX) -> abc.RewriteResult:
177
+ # from https://algassert.com/quirk#circuit=%7B%22cols%22:%5B%5B%22QFT3%22%5D,%5B%22inputA3%22,1,1,%22+=A3%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22%E2%80%A2%22,%22X%22%5D,%5B1,1,1,%22%E2%80%A6%22,%22%E2%80%A6%22,%22%E2%80%A6%22%5D,%5B1,1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E-%C2%BC%22%5D,%5B1,1,1,%22%E2%80%A2%22,1,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E%C2%BC%22%5D,%5B1,1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E-%C2%BC%22%5D,%5B1,1,1,%22Z%5E%C2%BC%22,%22Z%5E%C2%BC%22%5D,%5B1,1,1,1,%22H%22%5D,%5B1,1,1,%22%E2%80%A2%22,1,%22Z%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,%22X%5E-%C2%BC%22,%22X%5E%C2%BC%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,%22H%22%5D%5D%7D
178
+
179
+ # x^(1/4)
180
+ lam1, theta1, phi1 = map(
181
+ self.const_float,
182
+ map(around, (1.5707963267948966, 0.7853981633974483, -1.5707963267948966)),
183
+ )
184
+ lam1.insert_before(node)
185
+ theta1.insert_before(node)
186
+ phi1.insert_before(node)
187
+
188
+ lam1 = lam1.result
189
+ theta1 = theta1.result
190
+ phi1 = phi1.result
191
+
192
+ # x^(-1/4)
193
+ lam2, theta2, phi2 = map(
194
+ self.const_float,
195
+ map(around, (4.71238898038469, 0.7853981633974483, 1.5707963267948966)),
196
+ )
197
+ lam2.insert_before(node)
198
+ theta2.insert_before(node)
199
+ phi2.insert_before(node)
200
+ lam2 = lam2.result
201
+ theta2 = theta2.result
202
+ phi2 = phi2.result
203
+
204
+ uop.CZ(ctrl=node.ctrl1, qarg=node.qarg).insert_before(node)
205
+ uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
206
+ uop.CZ(ctrl=node.ctrl2, qarg=node.qarg).insert_before(node)
207
+ uop.UGate(node.qarg, theta1, phi1, lam1).insert_before(node)
208
+ uop.CZ(ctrl=node.ctrl1, qarg=node.qarg).insert_before(node)
209
+ uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
210
+ uop.T(node.ctrl1).insert_before(node)
211
+ uop.T(node.ctrl2).insert_before(node)
212
+ uop.H(node.ctrl1).insert_before(node)
213
+ uop.CZ(ctrl=node.ctrl2, qarg=node.qarg).insert_before(node)
214
+ uop.CZ(ctrl=node.ctrl2, qarg=node.ctrl1).insert_before(node)
215
+ uop.UGate(node.ctrl1, theta2, phi2, lam2).insert_before(node)
216
+ uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
217
+ uop.CZ(ctrl=node.ctrl2, qarg=node.ctrl1).insert_before(node)
218
+ uop.H(node.ctrl1).insert_before(node)
219
+ node.delete() # delete the original CCX gate
220
+
221
+ return abc.RewriteResult(has_done_something=True)
222
+
176
223
  def rewrite_sxdg(self, node: uop.SXdag) -> abc.RewriteResult:
177
224
  return self._rewrite_1q_gates(
178
225
  cirq.XPowGate(exponent=-0.5).on(self.cached_qubits[0]), node
@@ -394,9 +441,12 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
394
441
  new_gate_stmts = self._generate_1q_gate_stmts(cirq_gate, node.qarg)
395
442
  return self._rewrite_gate_stmts(new_gate_stmts, node)
396
443
 
397
- def _generate_2q_ctrl_gate_stmts(
444
+ def _generate_multi_ctrl_gate_stmts(
398
445
  self, cirq_gate: cirq.Operation, qubits_ssa: List[ir.SSAValue]
399
446
  ) -> list[ir.Statement]:
447
+ qubit_to_ssa_map = {
448
+ q: ssa for q, ssa in zip(self.cached_qubits[: len(qubits_ssa)], qubits_ssa)
449
+ }
400
450
  target_gates = self.gateset.decompose_to_target_gateset(cirq_gate, 0)
401
451
  new_stmts = []
402
452
  for new_gate in target_gates:
@@ -412,7 +462,7 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
412
462
  new_stmts.append(phi2_stmt)
413
463
  new_stmts.append(
414
464
  uop.UGate(
415
- qarg=qubits_ssa[new_gate.qubits[0].x],
465
+ qarg=qubit_to_ssa_map[new_gate.qubits[0]],
416
466
  theta=phi0_stmt.result,
417
467
  phi=phi1_stmt.result,
418
468
  lam=phi2_stmt.result,
@@ -420,18 +470,31 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
420
470
  )
421
471
  else:
422
472
  # 2q
423
- new_stmts.append(uop.CZ(ctrl=qubits_ssa[0], qarg=qubits_ssa[1]))
473
+ new_stmts.append(
474
+ uop.CZ(
475
+ ctrl=qubit_to_ssa_map[new_gate.qubits[0]],
476
+ qarg=qubit_to_ssa_map[new_gate.qubits[1]],
477
+ )
478
+ )
424
479
 
425
480
  return new_stmts
426
481
 
427
482
  def _rewrite_2q_ctrl_gates(
428
483
  self, cirq_gate: cirq.Operation, node: uop.TwoQubitCtrlGate
429
484
  ) -> abc.RewriteResult:
430
- new_gate_stmts = self._generate_2q_ctrl_gate_stmts(
485
+ new_gate_stmts = self._generate_multi_ctrl_gate_stmts(
431
486
  cirq_gate, [node.ctrl, node.qarg]
432
487
  )
433
488
  return self._rewrite_gate_stmts(new_gate_stmts, node)
434
489
 
490
+ def _rewrite_3q_ctrl_gates(
491
+ self, cirq_gate: cirq.Operation, node: uop.CCX
492
+ ) -> abc.RewriteResult:
493
+ new_gate_stmts = self._generate_multi_ctrl_gate_stmts(
494
+ cirq_gate, [node.ctrl1, node.ctrl2, node.qarg]
495
+ )
496
+ return self._rewrite_gate_stmts(new_gate_stmts, node)
497
+
435
498
  def _rewrite_gate_stmts(
436
499
  self, new_gate_stmts: list[ir.Statement], node: ir.Statement
437
500
  ):