bloqade-circuit 0.2.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit has been flagged as potentially problematic; see the registry's release page for more details.

Files changed (101)
  1. bloqade/analysis/address/impls.py +3 -2
  2. bloqade/pyqrack/device.py +1 -3
  3. bloqade/pyqrack/noise/native.py +8 -8
  4. bloqade/pyqrack/qasm2/core.py +4 -1
  5. bloqade/pyqrack/squin/op.py +7 -0
  6. bloqade/pyqrack/squin/qubit.py +5 -27
  7. bloqade/pyqrack/squin/runtime.py +18 -0
  8. bloqade/pyqrack/squin/wire.py +4 -22
  9. bloqade/pyqrack/task.py +13 -5
  10. bloqade/qasm2/__init__.py +1 -0
  11. bloqade/qasm2/_qasm_loading.py +151 -0
  12. bloqade/qasm2/dialects/core/__init__.py +9 -1
  13. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  14. bloqade/{noise/native → qasm2/dialects/noise}/__init__.py +1 -7
  15. bloqade/qasm2/dialects/noise/_dialect.py +3 -0
  16. bloqade/{noise → qasm2/dialects/noise}/fidelity.py +4 -4
  17. bloqade/qasm2/dialects/noise/model.py +278 -0
  18. bloqade/{noise/native → qasm2/dialects/noise}/stmts.py +1 -1
  19. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  20. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  21. bloqade/qasm2/emit/impls/__init__.py +1 -0
  22. bloqade/qasm2/emit/impls/noise.py +89 -0
  23. bloqade/qasm2/emit/main.py +23 -4
  24. bloqade/qasm2/emit/target.py +19 -4
  25. bloqade/qasm2/noise.py +67 -0
  26. bloqade/qasm2/parse/__init__.py +7 -4
  27. bloqade/qasm2/parse/lowering.py +20 -130
  28. bloqade/qasm2/parse/qasm2.lark +1 -1
  29. bloqade/qasm2/passes/__init__.py +1 -0
  30. bloqade/qasm2/passes/fold.py +6 -0
  31. bloqade/qasm2/passes/glob.py +12 -8
  32. bloqade/qasm2/passes/noise.py +27 -16
  33. bloqade/qasm2/passes/parallel.py +9 -0
  34. bloqade/qasm2/passes/unroll_if.py +25 -0
  35. bloqade/qasm2/rewrite/__init__.py +3 -0
  36. bloqade/qasm2/rewrite/desugar.py +3 -2
  37. bloqade/qasm2/rewrite/native_gates.py +67 -4
  38. bloqade/qasm2/rewrite/noise/__init__.py +0 -0
  39. bloqade/qasm2/rewrite/{heuristic_noise.py → noise/heuristic_noise.py} +32 -62
  40. bloqade/{noise/native/rewrite.py → qasm2/rewrite/noise/remove_noise.py} +2 -2
  41. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  42. bloqade/qbraid/lowering.py +8 -8
  43. bloqade/squin/__init__.py +7 -1
  44. bloqade/squin/analysis/nsites/__init__.py +1 -0
  45. bloqade/squin/analysis/nsites/impls.py +16 -1
  46. bloqade/squin/groups.py +4 -4
  47. bloqade/squin/lowering.py +27 -0
  48. bloqade/squin/noise/__init__.py +7 -26
  49. bloqade/squin/noise/_wrapper.py +25 -0
  50. bloqade/squin/op/__init__.py +34 -159
  51. bloqade/squin/op/_wrapper.py +105 -0
  52. bloqade/squin/op/stdlib.py +62 -0
  53. bloqade/squin/op/stmts.py +10 -0
  54. bloqade/squin/passes/__init__.py +1 -0
  55. bloqade/squin/passes/stim.py +68 -0
  56. bloqade/squin/qubit.py +32 -37
  57. bloqade/squin/rewrite/__init__.py +11 -0
  58. bloqade/squin/rewrite/desugar.py +65 -0
  59. bloqade/squin/rewrite/qubit_to_stim.py +61 -0
  60. bloqade/squin/rewrite/squin_measure.py +73 -0
  61. bloqade/squin/rewrite/stim_rewrite_util.py +153 -0
  62. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  63. bloqade/squin/rewrite/wire_to_stim.py +52 -0
  64. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  65. bloqade/squin/wire.py +5 -22
  66. bloqade/stim/__init__.py +40 -5
  67. bloqade/stim/_wrappers.py +18 -12
  68. bloqade/stim/dialects/__init__.py +1 -5
  69. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +13 -1
  70. bloqade/stim/dialects/{aux → auxiliary}/emit.py +18 -3
  71. bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +1 -0
  72. bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +8 -0
  73. bloqade/stim/dialects/collapse/__init__.py +13 -2
  74. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +4 -2
  75. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  76. bloqade/stim/dialects/gate/__init__.py +16 -1
  77. bloqade/stim/dialects/gate/emit.py +10 -3
  78. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  79. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  80. bloqade/stim/dialects/noise/emit.py +33 -2
  81. bloqade/stim/dialects/noise/stmts.py +29 -0
  82. bloqade/stim/emit/__init__.py +1 -1
  83. bloqade/stim/groups.py +4 -2
  84. bloqade/stim/parse/__init__.py +1 -0
  85. bloqade/stim/parse/lowering.py +686 -0
  86. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/METADATA +5 -3
  87. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/RECORD +95 -77
  88. bloqade/noise/__init__.py +0 -2
  89. bloqade/noise/native/_dialect.py +0 -3
  90. bloqade/noise/native/_wrappers.py +0 -34
  91. bloqade/noise/native/model.py +0 -346
  92. bloqade/qasm2/dialects/noise.py +0 -16
  93. bloqade/squin/rewrite/measure_desugar.py +0 -33
  94. /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  95. /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  96. /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  97. /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  98. /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  99. /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  100. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/WHEEL +0 -0
  101. {bloqade_circuit-0.2.3.dist-info → bloqade_circuit-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,3 @@
1
- import os
2
- import pathlib
3
1
  from typing import Any
4
2
  from dataclasses import field, dataclass
5
3
 
@@ -19,131 +17,39 @@ class QASM2(lowering.LoweringABC[ast.Node]):
19
17
  hint_show_lineno: bool = field(default=True, kw_only=True)
20
18
  stacktrace: bool = field(default=True, kw_only=True)
21
19
 
22
- def loads(
20
+ def run(
23
21
  self,
24
- source: str,
25
- kernel_name: str,
22
+ stmt: ast.Node,
26
23
  *,
27
- returns: str | None = None,
24
+ source: str | None = None,
28
25
  globals: dict[str, Any] | None = None,
29
26
  file: str | None = None,
30
27
  lineno_offset: int = 0,
31
28
  col_offset: int = 0,
32
29
  compactify: bool = True,
33
- ) -> ir.Method:
34
- from ..parse import loads
35
-
36
- # TODO: add source info
37
- stmt = loads(source)
30
+ ) -> ir.Region:
38
31
 
39
- state = lowering.State(
40
- self,
32
+ frame = self.get_frame(
33
+ stmt,
34
+ source=source,
35
+ globals=globals,
41
36
  file=file,
42
37
  lineno_offset=lineno_offset,
43
38
  col_offset=col_offset,
44
39
  )
45
- with state.frame(
46
- [stmt],
47
- globals=globals,
48
- finalize_next=False,
49
- ) as frame:
50
- try:
51
- self.visit(state, stmt)
52
- # append return statement with the return values
53
- if returns is not None:
54
- return_value = frame.get(returns)
55
- if return_value is None:
56
- raise lowering.BuildError(f"Cannot find return value {returns}")
57
- else:
58
- return_value = func.ConstantNone()
59
- frame.push(return_value)
60
-
61
- return_node = frame.push(func.Return(value_or_stmt=return_value))
62
-
63
- except lowering.BuildError as e:
64
- hint = state.error_hint(
65
- e,
66
- max_lines=self.max_lines,
67
- indent=self.hint_indent,
68
- show_lineno=self.hint_show_lineno,
69
- )
70
- if self.stacktrace:
71
- raise Exception(
72
- f"{e.args[0]}\n\n{hint}",
73
- *e.args[1:],
74
- ) from e
75
- else:
76
- e.args = (hint,)
77
- raise e
78
-
79
- region = frame.curr_region
80
-
81
- if compactify:
82
- from kirin.rewrite import Walk, CFGCompactify
83
-
84
- Walk(CFGCompactify()).rewrite(region)
85
-
86
- code = func.Function(
87
- sym_name=kernel_name,
88
- signature=func.Signature((), return_node.value.type),
89
- body=region,
90
- )
91
40
 
92
- return ir.Method(
93
- mod=None,
94
- py_func=None,
95
- sym_name=kernel_name,
96
- arg_names=[],
97
- dialects=self.dialects,
98
- code=code,
99
- )
41
+ return frame.curr_region
100
42
 
101
- def loadfile(
102
- self,
103
- file: str | pathlib.Path,
104
- *,
105
- kernel_name: str | None = None,
106
- returns: str | None = None,
107
- globals: dict[str, Any] | None = None,
108
- lineno_offset: int = 0,
109
- col_offset: int = 0,
110
- compactify: bool = True,
111
- ) -> ir.Method:
112
- if isinstance(file, str):
113
- file = pathlib.Path(*os.path.split(file))
114
-
115
- if not file.is_file() or not file.name.endswith(".qasm"):
116
- raise ValueError("File must be a .qasm file")
117
-
118
- kernel_name = (
119
- file.name.replace(".qasm", "") if kernel_name is None else kernel_name
120
- )
121
-
122
- with file.open("r") as f:
123
- source = f.read()
124
-
125
- return self.loads(
126
- source,
127
- kernel_name,
128
- returns=returns,
129
- globals=globals,
130
- file=str(file),
131
- lineno_offset=lineno_offset,
132
- col_offset=col_offset,
133
- compactify=compactify,
134
- )
135
-
136
- def run(
43
+ def get_frame(
137
44
  self,
138
45
  stmt: ast.Node,
139
- *,
140
46
  source: str | None = None,
141
47
  globals: dict[str, Any] | None = None,
142
48
  file: str | None = None,
143
49
  lineno_offset: int = 0,
144
50
  col_offset: int = 0,
145
51
  compactify: bool = True,
146
- ) -> ir.Region:
52
+ ) -> lowering.Frame:
147
53
  # TODO: add source info
148
54
  state = lowering.State(
149
55
  self,
@@ -154,32 +60,16 @@ class QASM2(lowering.LoweringABC[ast.Node]):
154
60
  with state.frame(
155
61
  [stmt],
156
62
  globals=globals,
63
+ finalize_next=False,
157
64
  ) as frame:
158
- try:
159
- self.visit(state, stmt)
160
- except lowering.BuildError as e:
161
- hint = state.error_hint(
162
- e,
163
- max_lines=self.max_lines,
164
- indent=self.hint_indent,
165
- show_lineno=self.hint_show_lineno,
166
- )
167
- if self.stacktrace:
168
- raise Exception(
169
- f"{e.args[0]}\n\n{hint}",
170
- *e.args[1:],
171
- ) from e
172
- else:
173
- e.args = (hint,)
174
- raise e
175
-
176
- region = frame.curr_region
177
-
178
- if compactify:
179
- from kirin.rewrite import Walk, CFGCompactify
180
-
181
- Walk(CFGCompactify()).rewrite(region)
182
- return region
65
+ self.visit(state, stmt)
66
+
67
+ if compactify:
68
+ from kirin.rewrite import Walk, CFGCompactify
69
+
70
+ Walk(CFGCompactify()).rewrite(frame.curr_region)
71
+
72
+ return frame
183
73
 
184
74
  def visit(self, state: lowering.State[ast.Node], node: ast.Node) -> lowering.Result:
185
75
  name = node.__class__.__name__
@@ -9,7 +9,7 @@ version: INT "." INT
9
9
  // stmts
10
10
  include: "include" STRING ";"
11
11
  ifstmt: "if" "(" expr "==" expr ")" ifbody
12
- ifbody: qop | "{" qop* "}" // allow multiple qops
12
+ ifbody: qop
13
13
  opaque: "opaque" IDENTIFIER ["(" [params] ")"] qubits ";"
14
14
  barrier: "barrier" qubits ";"
15
15
  qreg: "qreg" IDENTIFIER "[" INT "]" ";"
@@ -3,3 +3,4 @@ from .noise import NoisePass as NoisePass
3
3
  from .py2qasm import Py2QASM as Py2QASM
4
4
  from .qasm2py import QASM2Py as QASM2Py
5
5
  from .parallel import UOpToParallel as UOpToParallel
6
+ from .unroll_if import UnrollIfs as UnrollIfs
@@ -23,6 +23,8 @@ from kirin.rewrite.abc import RewriteResult
23
23
 
24
24
  from bloqade.qasm2.dialects import expr
25
25
 
26
+ from .unroll_if import UnrollIfs
27
+
26
28
 
27
29
  @dataclass
28
30
  class QASM2Fold(Pass):
@@ -30,6 +32,7 @@ class QASM2Fold(Pass):
30
32
 
31
33
  constprop: const.Propagate = field(init=False)
32
34
  inline_gate_subroutine: bool = True
35
+ unroll_ifs: bool = True
33
36
 
34
37
  def __post_init__(self):
35
38
  self.constprop = const.Propagate(self.dialects)
@@ -61,6 +64,9 @@ class QASM2Fold(Pass):
61
64
  .join(result)
62
65
  )
63
66
 
67
+ if self.unroll_ifs:
68
+ UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
69
+
64
70
  # run typeinfer again after unroll etc. because we now insert
65
71
  # a lot of new nodes, which might have more precise types
66
72
  self.typeinfer.unsafe_run(mt)
@@ -58,13 +58,15 @@ class GlobalToUOP(Pass):
58
58
  rewriter = walk.Walk(self.generate_rule(mt))
59
59
  result = rewriter.rewrite(mt.code)
60
60
 
61
- result = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code)
62
- result = Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination())).rewrite(
63
- mt.code
61
+ result = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code).join(result)
62
+ result = (
63
+ Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination()))
64
+ .rewrite(mt.code)
65
+ .join(result)
64
66
  )
65
67
 
66
68
  # do fold again to get proper hint for inserted const
67
- result = Fold(mt.dialects)(mt)
69
+ result = Fold(mt.dialects)(mt).join(result)
68
70
  return result
69
71
 
70
72
 
@@ -110,10 +112,12 @@ class GlobalToParallel(Pass):
110
112
  rewriter = walk.Walk(self.generate_rule(mt))
111
113
  result = rewriter.rewrite(mt.code)
112
114
 
113
- result = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code)
114
- result = Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination())).rewrite(
115
- mt.code
115
+ result = walk.Walk(dce.DeadCodeElimination()).rewrite(mt.code).join(result)
116
+ result = (
117
+ Fixpoint(walk.Walk(rule=cse.CommonSubexpressionElimination()))
118
+ .rewrite(mt.code)
119
+ .join(result)
116
120
  )
117
121
  # do fold again to get proper hint
118
- result = Fold(mt.dialects)(mt)
122
+ result = Fold(mt.dialects)(mt).join(result)
119
123
  return result
@@ -8,10 +8,10 @@ from kirin.rewrite import (
8
8
  DeadCodeElimination,
9
9
  )
10
10
 
11
- from bloqade.noise import native
11
+ from bloqade.qasm2 import noise
12
12
  from bloqade.analysis import address
13
+ from bloqade.qasm2.rewrite import NoiseRewriteRule
13
14
  from bloqade.qasm2.passes.lift_qubits import LiftQubits
14
- from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
15
15
 
16
16
 
17
17
  @dataclass
@@ -25,12 +25,9 @@ class NoisePass(Pass):
25
25
 
26
26
  ```
27
27
  from bloqade import qasm2
28
- from bloqade.noise import native
29
- from bloqade.qasm2.passes.noise import NoisePass
28
+ from bloqade.qasm2.passes import NoisePass
30
29
 
31
- noise_main = qasm2.extended.add(native.dialect)
32
-
33
- @noise_main
30
+ @qasm2.extended
34
31
  def main():
35
32
  q = qasm2.qreg(2)
36
33
  qasm2.h(q[0])
@@ -51,31 +48,45 @@ class NoisePass(Pass):
51
48
 
52
49
  """
53
50
 
54
- noise_model: native.MoveNoiseModelABC = field(
55
- default_factory=native.TwoRowZoneModel
56
- )
57
- gate_noise_params: native.GateNoiseParams = field(
58
- default_factory=native.GateNoiseParams
59
- )
51
+ noise_model: noise.MoveNoiseModelABC = field(default_factory=noise.TwoRowZoneModel)
60
52
  address_analysis: address.AddressAnalysis = field(init=False)
61
53
 
62
54
  def __post_init__(self):
63
55
  self.address_analysis = address.AddressAnalysis(self.dialects)
64
56
 
57
+ def get_qubit_values(self, mt: ir.Method):  # run address analysis; return (qubit index -> first SSA value, SSA value -> address map)
58
+ frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
59
+ qubit_ssa_values = {}
60
+ # Traverse statements in block order to find the first SSA value for each qubit
61
+ for block in mt.callable_region.blocks:  # NOTE(review): only top-level blocks are scanned; nested regions are not visited — confirm intended
62
+ for stmt in block.stmts:
63
+ if len(stmt.results) != 1:  # skip statements that do not have exactly one result
64
+ continue
65
+
66
+ addr = frame.entries.get(result := stmt.results[0])
67
+ if (
68
+ isinstance(addr, address.AddressQubit)
69
+ and (index := addr.data) not in qubit_ssa_values  # keep only the first SSA value seen for each qubit index
70
+ ):
71
+ qubit_ssa_values[index] = result
72
+
73
+ return qubit_ssa_values, frame.entries  # also return the raw SSA-value -> address map from the analysis frame
74
+
65
75
  def unsafe_run(self, mt: ir.Method):
66
76
  result = LiftQubits(self.dialects).unsafe_run(mt)
67
- frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
77
+ qubit_ssa_value, address_analysis = self.get_qubit_values(mt)  # first SSA value per qubit index, plus the full SSA -> address map
68
78
  result = (
69
79
  Walk(
70
80
  NoiseRewriteRule(
71
- address_analysis=frame.entries,
81
+ qubit_ssa_value=qubit_ssa_value,
82
+ address_analysis=address_analysis,
72
83
  noise_model=self.noise_model,
73
- gate_noise_params=self.gate_noise_params,
74
84
  ),
75
85
  reverse=True,
76
86
  )
77
87
  .rewrite(mt.code)
78
88
  .join(result)
79
89
  )
90
+
80
91
  result = Fixpoint(Walk(DeadCodeElimination())).rewrite(mt.code).join(result)
81
92
  return result
@@ -27,6 +27,7 @@ from bloqade.qasm2.rewrite import (
27
27
  RaiseRegisterRule,
28
28
  UOpToParallelRule,
29
29
  SimpleOptimalMergePolicy,
30
+ RydbergGateSetRewriteRule,
30
31
  )
31
32
  from bloqade.squin.analysis import schedule
32
33
 
@@ -135,6 +136,7 @@ class UOpToParallel(Pass):
135
136
  """
136
137
 
137
138
  merge_policy_type: Type[MergePolicyABC] = SimpleOptimalMergePolicy
139
+ rewrite_to_native_first: bool = False
138
140
  constprop: const.Propagate = field(init=False)
139
141
 
140
142
  def __post_init__(self):
@@ -147,6 +149,13 @@ class UOpToParallel(Pass):
147
149
  if not result.has_done_something:
148
150
  return result
149
151
 
152
+ if self.rewrite_to_native_first:
153
+ result = (
154
+ Fixpoint(Walk(RydbergGateSetRewriteRule(self.dialects)))
155
+ .rewrite(mt.code)
156
+ .join(result)
157
+ )
158
+
150
159
  frame, _ = self.constprop.run_analysis(mt)
151
160
  result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
152
161
 
@@ -0,0 +1,25 @@
1
+ from kirin import ir
2
+ from kirin.passes import Pass
3
+ from kirin.rewrite import (
4
+ Walk,
5
+ Chain,
6
+ Fixpoint,
7
+ ConstantFold,
8
+ CommonSubexpressionElimination,
9
+ )
10
+
11
+ from ..rewrite.split_ifs import LiftThenBody, SplitIfStmts
12
+
13
+
14
+ class UnrollIfs(Pass):
15
+ """Lift non-UOp statements out of `if` bodies, then split whatever remains into multiple `if` statements so the result is valid QASM2."""
16
+
17
+ def unsafe_run(self, mt: ir.Method):
18
+ result = Walk(LiftThenBody()).rewrite(mt.code)  # step 1: hoist statements that are not UOps out of the then-body
19
+ result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)  # step 2: split the remaining body into multiple if statements
20
+ result = (
21
+ Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))  # step 3: constant-fold + CSE until no further change
22
+ .rewrite(mt.code)
23
+ .join(result)
24
+ )
25
+ return result
@@ -3,6 +3,7 @@ from .glob import (
3
3
  GlobalToParallelRule as GlobalToParallelRule,
4
4
  )
5
5
  from .register import RaiseRegisterRule as RaiseRegisterRule
6
+ from .native_gates import RydbergGateSetRewriteRule as RydbergGateSetRewriteRule
6
7
  from .parallel_to_uop import ParallelToUOpRule as ParallelToUOpRule
7
8
  from .uop_to_parallel import (
8
9
  MergePolicyABC as MergePolicyABC,
@@ -10,3 +11,5 @@ from .uop_to_parallel import (
10
11
  SimpleGreedyMergePolicy as SimpleGreedyMergePolicy,
11
12
  SimpleOptimalMergePolicy as SimpleOptimalMergePolicy,
12
13
  )
14
+ from .noise.remove_noise import RemoveNoisePass as RemoveNoisePass
15
+ from .noise.heuristic_noise import NoiseRewriteRule as NoiseRewriteRule
@@ -5,16 +5,17 @@ from kirin.passes import Pass
5
5
  from kirin.rewrite import abc, walk
6
6
  from kirin.dialects import py
7
7
 
8
+ from bloqade.qasm2 import types
8
9
  from bloqade.qasm2.dialects import core
9
10
 
10
11
 
11
12
  class IndexingDesugarRule(abc.RewriteRule):
12
13
  def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
13
14
  if isinstance(node, py.indexing.GetItem):
14
- if node.obj.type.is_subseteq(core.QRegType):
15
+ if node.obj.type.is_subseteq(types.QRegType):
15
16
  node.replace_by(core.QRegGet(reg=node.obj, idx=node.index))
16
17
  return abc.RewriteResult(has_done_something=True)
17
- elif node.obj.type.is_subseteq(core.CRegType):
18
+ elif node.obj.type.is_subseteq(types.CRegType):
18
19
  node.replace_by(core.CRegGet(reg=node.obj, idx=node.index))
19
20
  return abc.RewriteResult(has_done_something=True)
20
21
 
@@ -173,6 +173,53 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
173
173
  cirq.XPowGate(exponent=0.5).on(self.cached_qubits[0]), node
174
174
  )
175
175
 
176
+ def rewrite_ccx(self, node: uop.CCX) -> abc.RewriteResult:
177
+ # from https://algassert.com/quirk#circuit=%7B%22cols%22:%5B%5B%22QFT3%22%5D,%5B%22inputA3%22,1,1,%22+=A3%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22%E2%80%A2%22,%22X%22%5D,%5B1,1,1,%22%E2%80%A6%22,%22%E2%80%A6%22,%22%E2%80%A6%22%5D,%5B1,1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E-%C2%BC%22%5D,%5B1,1,1,%22%E2%80%A2%22,1,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E%C2%BC%22%5D,%5B1,1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E-%C2%BC%22%5D,%5B1,1,1,%22Z%5E%C2%BC%22,%22Z%5E%C2%BC%22%5D,%5B1,1,1,1,%22H%22%5D,%5B1,1,1,%22%E2%80%A2%22,1,%22Z%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,%22X%5E-%C2%BC%22,%22X%5E%C2%BC%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,%22H%22%5D%5D%7D
178
+
179
+ # x^(1/4)
180
+ lam1, theta1, phi1 = map(
181
+ self.const_float,
182
+ map(around, (1.5707963267948966, 0.7853981633974483, -1.5707963267948966)),
183
+ )
184
+ lam1.insert_before(node)
185
+ theta1.insert_before(node)
186
+ phi1.insert_before(node)
187
+
188
+ lam1 = lam1.result
189
+ theta1 = theta1.result
190
+ phi1 = phi1.result
191
+
192
+ # x^(-1/4)
193
+ lam2, theta2, phi2 = map(
194
+ self.const_float,
195
+ map(around, (4.71238898038469, 0.7853981633974483, 1.5707963267948966)),
196
+ )
197
+ lam2.insert_before(node)
198
+ theta2.insert_before(node)
199
+ phi2.insert_before(node)
200
+ lam2 = lam2.result
201
+ theta2 = theta2.result
202
+ phi2 = phi2.result
203
+
204
+ uop.CZ(ctrl=node.ctrl1, qarg=node.qarg).insert_before(node)
205
+ uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
206
+ uop.CZ(ctrl=node.ctrl2, qarg=node.qarg).insert_before(node)
207
+ uop.UGate(node.qarg, theta1, phi1, lam1).insert_before(node)
208
+ uop.CZ(ctrl=node.ctrl1, qarg=node.qarg).insert_before(node)
209
+ uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
210
+ uop.T(node.ctrl1).insert_before(node)
211
+ uop.T(node.ctrl2).insert_before(node)
212
+ uop.H(node.ctrl1).insert_before(node)
213
+ uop.CZ(ctrl=node.ctrl2, qarg=node.qarg).insert_before(node)
214
+ uop.CZ(ctrl=node.ctrl2, qarg=node.ctrl1).insert_before(node)
215
+ uop.UGate(node.ctrl1, theta2, phi2, lam2).insert_before(node)
216
+ uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
217
+ uop.CZ(ctrl=node.ctrl2, qarg=node.ctrl1).insert_before(node)
218
+ uop.H(node.ctrl1).insert_before(node)
219
+ node.delete() # delete the original CCX gate
220
+
221
+ return abc.RewriteResult(has_done_something=True)
222
+
176
223
  def rewrite_sxdg(self, node: uop.SXdag) -> abc.RewriteResult:
177
224
  return self._rewrite_1q_gates(
178
225
  cirq.XPowGate(exponent=-0.5).on(self.cached_qubits[0]), node
@@ -394,9 +441,12 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
394
441
  new_gate_stmts = self._generate_1q_gate_stmts(cirq_gate, node.qarg)
395
442
  return self._rewrite_gate_stmts(new_gate_stmts, node)
396
443
 
397
- def _generate_2q_ctrl_gate_stmts(
444
+ def _generate_multi_ctrl_gate_stmts(
398
445
  self, cirq_gate: cirq.Operation, qubits_ssa: List[ir.SSAValue]
399
446
  ) -> list[ir.Statement]:
447
+ qubit_to_ssa_map = {
448
+ q: ssa for q, ssa in zip(self.cached_qubits[: len(qubits_ssa)], qubits_ssa)
449
+ }
400
450
  target_gates = self.gateset.decompose_to_target_gateset(cirq_gate, 0)
401
451
  new_stmts = []
402
452
  for new_gate in target_gates:
@@ -412,7 +462,7 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
412
462
  new_stmts.append(phi2_stmt)
413
463
  new_stmts.append(
414
464
  uop.UGate(
415
- qarg=qubits_ssa[new_gate.qubits[0].x],
465
+ qarg=qubit_to_ssa_map[new_gate.qubits[0]],
416
466
  theta=phi0_stmt.result,
417
467
  phi=phi1_stmt.result,
418
468
  lam=phi2_stmt.result,
@@ -420,18 +470,31 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
420
470
  )
421
471
  else:
422
472
  # 2q
423
- new_stmts.append(uop.CZ(ctrl=qubits_ssa[0], qarg=qubits_ssa[1]))
473
+ new_stmts.append(
474
+ uop.CZ(
475
+ ctrl=qubit_to_ssa_map[new_gate.qubits[0]],
476
+ qarg=qubit_to_ssa_map[new_gate.qubits[1]],
477
+ )
478
+ )
424
479
 
425
480
  return new_stmts
426
481
 
427
482
  def _rewrite_2q_ctrl_gates(
428
483
  self, cirq_gate: cirq.Operation, node: uop.TwoQubitCtrlGate
429
484
  ) -> abc.RewriteResult:
430
- new_gate_stmts = self._generate_2q_ctrl_gate_stmts(
485
+ new_gate_stmts = self._generate_multi_ctrl_gate_stmts(  # operands passed in cached-qubit order: [ctrl, qarg]
431
486
  cirq_gate, [node.ctrl, node.qarg]
432
487
  )
433
488
  return self._rewrite_gate_stmts(new_gate_stmts, node)
434
489
 
490
+ def _rewrite_3q_ctrl_gates(  # decompose a 3-qubit controlled gate via cirq and hand the stmts to _rewrite_gate_stmts
491
+ self, cirq_gate: cirq.Operation, node: uop.CCX
492
+ ) -> abc.RewriteResult:  # NOTE(review): rewrite_ccx hand-writes its decomposition instead of calling this; confirm where this helper is used
493
+ new_gate_stmts = self._generate_multi_ctrl_gate_stmts(
494
+ cirq_gate, [node.ctrl1, node.ctrl2, node.qarg]  # operand order matches the first three cached qubits
495
+ )
496
+ return self._rewrite_gate_stmts(new_gate_stmts, node)
497
+
435
498
  def _rewrite_gate_stmts(
436
499
  self, new_gate_stmts: list[ir.Statement], node: ir.Statement
437
500
  ):
File without changes