bloqade-circuit 0.7.13__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (136) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +119 -29
  3. bloqade/analysis/address/impls.py +290 -87
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +2 -2
  6. bloqade/analysis/measure_id/impls.py +3 -27
  7. bloqade/cirq_utils/__init__.py +3 -1
  8. bloqade/cirq_utils/emit/__init__.py +3 -0
  9. bloqade/cirq_utils/emit/base.py +243 -0
  10. bloqade/cirq_utils/emit/gate.py +104 -0
  11. bloqade/cirq_utils/emit/noise.py +90 -0
  12. bloqade/cirq_utils/emit/qubit.py +35 -0
  13. bloqade/cirq_utils/lowering.py +664 -0
  14. bloqade/native/__init__.py +0 -1
  15. bloqade/native/_prelude.py +3 -3
  16. bloqade/native/dialects/gate/__init__.py +2 -0
  17. bloqade/native/dialects/gate/_dialect.py +3 -0
  18. bloqade/native/dialects/{gates → gate}/_interface.py +5 -5
  19. bloqade/native/dialects/{gates → gate}/stmts.py +5 -5
  20. bloqade/native/stdlib/broadcast.py +19 -19
  21. bloqade/native/stdlib/simple.py +14 -13
  22. bloqade/native/upstream/__init__.py +5 -0
  23. bloqade/native/upstream/squin2native.py +136 -0
  24. bloqade/pyqrack/__init__.py +1 -2
  25. bloqade/pyqrack/device.py +6 -17
  26. bloqade/pyqrack/native.py +17 -17
  27. bloqade/pyqrack/reg.py +1 -6
  28. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  29. bloqade/pyqrack/squin/gate/gate.py +136 -0
  30. bloqade/pyqrack/squin/noise/native.py +120 -54
  31. bloqade/pyqrack/squin/qubit.py +25 -41
  32. bloqade/pyqrack/target.py +2 -2
  33. bloqade/qasm2/dialects/core/address.py +21 -12
  34. bloqade/qasm2/dialects/noise/fidelity.py +2 -6
  35. bloqade/qasm2/dialects/noise/model.py +2 -1
  36. bloqade/qasm2/passes/parallel.py +3 -1
  37. bloqade/qasm2/rewrite/__init__.py +0 -1
  38. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  39. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  40. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  41. bloqade/qubit/__init__.py +12 -0
  42. bloqade/qubit/_dialect.py +3 -0
  43. bloqade/qubit/_interface.py +49 -0
  44. bloqade/qubit/_prelude.py +45 -0
  45. bloqade/qubit/analysis/__init__.py +1 -0
  46. bloqade/qubit/analysis/address_impl.py +40 -0
  47. bloqade/qubit/stdlib/__init__.py +2 -0
  48. bloqade/qubit/stdlib/_new.py +34 -0
  49. bloqade/qubit/stdlib/broadcast.py +62 -0
  50. bloqade/qubit/stdlib/simple.py +59 -0
  51. bloqade/qubit/stmts.py +60 -0
  52. bloqade/rewrite/passes/aggressive_unroll.py +2 -1
  53. bloqade/squin/__init__.py +44 -17
  54. bloqade/squin/analysis/__init__.py +0 -1
  55. bloqade/squin/analysis/schedule.py +2 -2
  56. bloqade/squin/gate/__init__.py +2 -0
  57. bloqade/squin/gate/_dialect.py +3 -0
  58. bloqade/squin/gate/_interface.py +98 -0
  59. bloqade/squin/gate/stmts.py +119 -0
  60. bloqade/squin/groups.py +4 -21
  61. bloqade/squin/noise/__init__.py +1 -9
  62. bloqade/squin/noise/_dialect.py +1 -1
  63. bloqade/squin/noise/_interface.py +45 -0
  64. bloqade/squin/noise/stmts.py +65 -29
  65. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  66. bloqade/squin/rewrite/__init__.py +0 -2
  67. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  68. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  69. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  70. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  71. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  72. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  73. bloqade/squin/stdlib/simple/__init__.py +33 -0
  74. bloqade/squin/stdlib/simple/gate.py +242 -0
  75. bloqade/squin/stdlib/simple/noise.py +126 -0
  76. bloqade/stim/__init__.py +1 -0
  77. bloqade/stim/_wrappers.py +6 -0
  78. bloqade/stim/dialects/noise/emit.py +6 -1
  79. bloqade/stim/dialects/noise/stmts.py +5 -3
  80. bloqade/stim/emit/stim_str.py +2 -0
  81. bloqade/stim/parse/lowering.py +12 -17
  82. bloqade/stim/passes/__init__.py +0 -1
  83. bloqade/stim/passes/flatten.py +26 -0
  84. bloqade/stim/passes/simplify_ifs.py +6 -1
  85. bloqade/stim/passes/squin_to_stim.py +4 -70
  86. bloqade/stim/rewrite/__init__.py +0 -4
  87. bloqade/stim/rewrite/ifs_to_stim.py +23 -29
  88. bloqade/stim/rewrite/qubit_to_stim.py +96 -51
  89. bloqade/stim/rewrite/squin_measure.py +9 -18
  90. bloqade/stim/rewrite/squin_noise.py +132 -108
  91. bloqade/stim/rewrite/util.py +5 -204
  92. bloqade/types.py +10 -0
  93. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/METADATA +2 -2
  94. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/RECORD +96 -100
  95. bloqade/native/dialects/gates/__init__.py +0 -3
  96. bloqade/native/dialects/gates/_dialect.py +0 -3
  97. bloqade/pyqrack/squin/op.py +0 -180
  98. bloqade/pyqrack/squin/runtime.py +0 -543
  99. bloqade/pyqrack/squin/wire.py +0 -51
  100. bloqade/squin/_typeinfer.py +0 -20
  101. bloqade/squin/analysis/address_impl.py +0 -71
  102. bloqade/squin/analysis/nsites/__init__.py +0 -9
  103. bloqade/squin/analysis/nsites/analysis.py +0 -50
  104. bloqade/squin/analysis/nsites/impls.py +0 -99
  105. bloqade/squin/analysis/nsites/lattice.py +0 -49
  106. bloqade/squin/cirq/__init__.py +0 -306
  107. bloqade/squin/cirq/emit/emit_circuit.py +0 -129
  108. bloqade/squin/cirq/emit/noise.py +0 -49
  109. bloqade/squin/cirq/emit/op.py +0 -176
  110. bloqade/squin/cirq/emit/qubit.py +0 -58
  111. bloqade/squin/cirq/emit/runtime.py +0 -242
  112. bloqade/squin/cirq/lowering.py +0 -439
  113. bloqade/squin/lowering.py +0 -80
  114. bloqade/squin/noise/_wrapper.py +0 -36
  115. bloqade/squin/noise/rewrite.py +0 -129
  116. bloqade/squin/op/__init__.py +0 -41
  117. bloqade/squin/op/_dialect.py +0 -3
  118. bloqade/squin/op/_wrapper.py +0 -121
  119. bloqade/squin/op/number.py +0 -5
  120. bloqade/squin/op/rewrite.py +0 -46
  121. bloqade/squin/op/stdlib.py +0 -62
  122. bloqade/squin/op/stmts.py +0 -300
  123. bloqade/squin/op/traits.py +0 -43
  124. bloqade/squin/op/types.py +0 -128
  125. bloqade/squin/parallel.py +0 -200
  126. bloqade/squin/qubit.py +0 -194
  127. bloqade/squin/rewrite/canonicalize.py +0 -60
  128. bloqade/squin/rewrite/desugar.py +0 -102
  129. bloqade/squin/stdlib/channel.py +0 -86
  130. bloqade/squin/stdlib/gate.py +0 -201
  131. bloqade/squin/types.py +0 -8
  132. bloqade/squin/wire.py +0 -201
  133. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  134. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  135. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/WHEEL +0 -0
  136. {bloqade_circuit-0.7.13.dist-info → bloqade_circuit-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,7 @@ import typing
3
3
  from kirin import lowering
4
4
  from kirin.dialects import ilist
5
5
 
6
- from bloqade.squin import qubit
6
+ from bloqade import qubit
7
7
 
8
8
  from .stmts import CZ, R, Rz
9
9
 
@@ -12,21 +12,21 @@ Len = typing.TypeVar("Len")
12
12
 
13
13
  @lowering.wraps(CZ)
14
14
  def cz(
15
- ctrls: ilist.IList[qubit.Qubit, Len],
16
- qargs: ilist.IList[qubit.Qubit, Len],
15
+ controls: ilist.IList[qubit.Qubit, Len],
16
+ targets: ilist.IList[qubit.Qubit, Len],
17
17
  ): ...
18
18
 
19
19
 
20
20
  @lowering.wraps(R)
21
21
  def r(
22
- inputs: ilist.IList[qubit.Qubit, typing.Any],
23
22
  axis_angle: float,
24
23
  rotation_angle: float,
24
+ qubits: ilist.IList[qubit.Qubit, typing.Any],
25
25
  ): ...
26
26
 
27
27
 
28
28
  @lowering.wraps(Rz)
29
29
  def rz(
30
- inputs: ilist.IList[qubit.Qubit, typing.Any],
31
30
  rotation_angle: float,
31
+ qubits: ilist.IList[qubit.Qubit, typing.Any],
32
32
  ): ...
@@ -2,7 +2,7 @@ from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
  from kirin.dialects import ilist
4
4
 
5
- from bloqade.squin import qubit
5
+ from bloqade.types import QubitType
6
6
 
7
7
  from ._dialect import dialect
8
8
 
@@ -12,20 +12,20 @@ N = types.TypeVar("N")
12
12
  @statement(dialect=dialect)
13
13
  class CZ(ir.Statement):
14
14
  traits = frozenset({lowering.FromPythonCall()})
15
- ctrls: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N])
16
- qargs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N])
15
+ controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
16
+ targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
17
17
 
18
18
 
19
19
  @statement(dialect=dialect)
20
20
  class R(ir.Statement):
21
21
  traits = frozenset({lowering.FromPythonCall()})
22
- inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any])
23
22
  axis_angle: ir.SSAValue = info.argument(types.Float)
24
23
  rotation_angle: ir.SSAValue = info.argument(types.Float)
24
+ qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
25
25
 
26
26
 
27
27
  @statement(dialect=dialect)
28
28
  class Rz(ir.Statement):
29
29
  traits = frozenset({lowering.FromPythonCall()})
30
- inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any])
31
30
  rotation_angle: ir.SSAValue = info.argument(types.Float)
31
+ qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
@@ -3,9 +3,9 @@ from typing import Any, TypeVar
3
3
 
4
4
  from kirin.dialects import ilist
5
5
 
6
- from bloqade.squin import qubit
6
+ from bloqade import qubit
7
7
  from bloqade.native._prelude import kernel
8
- from bloqade.native.dialects.gates import _interface as native
8
+ from bloqade.native.dialects.gate import _interface as native
9
9
 
10
10
 
11
11
  @kernel
@@ -29,7 +29,7 @@ def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
29
29
  angle (float): Rotation angle in radians.
30
30
  qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
31
31
  """
32
- native.r(qubits, 0.0, _radian_to_turn(angle))
32
+ native.r(0.0, _radian_to_turn(angle), qubits)
33
33
 
34
34
 
35
35
  @kernel
@@ -70,7 +70,7 @@ def ry(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
70
70
  angle (float): Rotation angle in radians.
71
71
  qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
72
72
  """
73
- native.r(qubits, 0.25, _radian_to_turn(angle))
73
+ native.r(0.25, _radian_to_turn(angle), qubits)
74
74
 
75
75
 
76
76
  @kernel
@@ -111,7 +111,7 @@ def rz(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
111
111
  angle (float): Rotation angle in radians.
112
112
  qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
113
113
  """
114
- native.rz(qubits, _radian_to_turn(angle))
114
+ native.rz(_radian_to_turn(angle), qubits)
115
115
 
116
116
 
117
117
  @kernel
@@ -167,43 +167,43 @@ def t(qubits: ilist.IList[qubit.Qubit, Any]):
167
167
 
168
168
 
169
169
  @kernel
170
- def shift(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
171
- """Apply a phase shift to the |1> state on a group of qubits.
170
+ def t_adj(qubits: ilist.IList[qubit.Qubit, Any]):
171
+ """Apply the adjoint of the T gate on a group of qubits.
172
172
 
173
173
  Args:
174
- angle (float): Phase shift angle in radians.
175
174
  qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
176
175
  """
177
- rz(angle / 2.0, qubits)
176
+ rz(-math.pi / 4.0, qubits)
178
177
 
179
178
 
180
179
  @kernel
181
- def rot(phi: float, theta: float, omega: float, qubits: ilist.IList[qubit.Qubit, Any]):
182
- """Apply a general single-qubit rotation on a group of qubits.
180
+ def shift(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
181
+ """Apply a phase shift to the |1> state on a group of qubits.
183
182
 
184
183
  Args:
185
- phi (float): Z rotation before Y (radians).
186
- theta (float): Y rotation (radians).
187
- omega (float): Z rotation after Y (radians).
184
+ angle (float): Phase shift angle in radians.
188
185
  qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
189
186
  """
190
- rz(phi, qubits)
191
- ry(theta, qubits)
192
- rz(omega, qubits)
187
+ rz(angle / 2.0, qubits)
193
188
 
194
189
 
195
190
  @kernel
196
191
  def u3(theta: float, phi: float, lam: float, qubits: ilist.IList[qubit.Qubit, Any]):
197
192
  """Apply the U3 gate on a group of qubits.
198
193
 
194
+ The applied gate is represented by the unitary matrix given by:
195
+
196
+ $$ U3(\\theta, \\phi, \\lambda) = R_z(\\phi)R_y(\\theta)R_z(\\lambda) $$
197
+
199
198
  Args:
200
199
  theta (float): Rotation around Y axis (radians).
201
200
  phi (float): Global phase shift component (radians).
202
201
  lam (float): Z rotations in decomposition (radians).
203
202
  qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
204
203
  """
205
- rot(lam, theta, -lam, qubits)
206
- shift(phi + lam, qubits)
204
+ rz(lam, qubits)
205
+ ry(theta, qubits)
206
+ rz(phi, qubits)
207
207
 
208
208
 
209
209
  N = TypeVar("N")
@@ -1,6 +1,6 @@
1
1
  from kirin.dialects import ilist
2
2
 
3
- from bloqade.squin import qubit
3
+ from bloqade import qubit
4
4
 
5
5
  from . import broadcast
6
6
  from .._prelude import kernel
@@ -150,33 +150,34 @@ def t(qubit: qubit.Qubit):
150
150
 
151
151
 
152
152
  @kernel
153
- def shift(angle: float, qubit: qubit.Qubit):
154
- """Apply a phase shift on the |1> state of a single qubit.
153
+ def t_adj(qubit: qubit.Qubit):
154
+ """Apply the adjoint of the T gate on a single qubit.
155
155
 
156
156
  Args:
157
- angle (float): Shift angle in radians.
158
- qubit (qubit.Qubit): The qubit to apply the shift to.
157
+ qubit (qubit.Qubit): The qubit to apply the adjoint T gate to.
159
158
  """
160
- broadcast.shift(angle, ilist.IList([qubit]))
159
+ broadcast.t_adj(ilist.IList([qubit]))
161
160
 
162
161
 
163
162
  @kernel
164
- def rot(phi: float, theta: float, omega: float, qubit: qubit.Qubit):
165
- """Apply a general single-qubit rotation on a single qubit.
163
+ def shift(angle: float, qubit: qubit.Qubit):
164
+ """Apply a phase shift on the |1> state of a single qubit.
166
165
 
167
166
  Args:
168
- phi (float): Z rotation before Y (radians).
169
- theta (float): Y rotation (radians).
170
- omega (float): Z rotation after Y (radians).
171
- qubit (qubit.Qubit): The qubit to apply the rotation to.
167
+ angle (float): Shift angle in radians.
168
+ qubit (qubit.Qubit): The qubit to apply the shift to.
172
169
  """
173
- broadcast.rot(phi, theta, omega, ilist.IList([qubit]))
170
+ broadcast.shift(angle, ilist.IList([qubit]))
174
171
 
175
172
 
176
173
  @kernel
177
174
  def u3(theta: float, phi: float, lam: float, qubit: qubit.Qubit):
178
175
  """Apply the U3 gate on a single qubit.
179
176
 
177
+ The applied gate is represented by the unitary matrix given by:
178
+
179
+ $$ U3(\\theta, \\phi, \\lambda) = R_z(\\phi)R_y(\\theta)R_z(\\lambda) $$
180
+
180
181
  Args:
181
182
  theta (float): Rotation angle around the Y axis in radians.
182
183
  phi (float): Rotation angle around the Z axis in radians.
@@ -0,0 +1,5 @@
1
+ from .squin2native import (
2
+ GateRule as GateRule,
3
+ SquinToNative as SquinToNative,
4
+ SquinToNativePass as SquinToNativePass,
5
+ )
@@ -0,0 +1,136 @@
1
+ from itertools import chain
2
+ from dataclasses import field, dataclass
3
+
4
+ from kirin import ir, passes, rewrite
5
+ from kirin.dialects import py, func
6
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
7
+ from kirin.passes.callgraph import CallGraphPass, ReplaceMethods
8
+ from kirin.analysis.callgraph import CallGraph
9
+
10
+ from bloqade.native import kernel, broadcast
11
+ from bloqade.squin.gate import stmts, dialect as gate_dialect
12
+
13
+
14
class GateRule(RewriteRule):
    """Rewrite squin gate statements into calls to native broadcast kernels.

    ``SQUIN_MAPPING`` associates each supported squin gate statement type
    with a tuple of native kernels. The first entry is the kernel used for
    the gate itself; where a second entry is present it is the kernel used
    when the statement's ``adjoint`` flag is set.
    """

    SQUIN_MAPPING: dict[type[ir.Statement], tuple[ir.Method, ...]] = {
        stmts.X: (broadcast.x,),
        stmts.Y: (broadcast.y,),
        stmts.Z: (broadcast.z,),
        stmts.H: (broadcast.h,),
        stmts.S: (broadcast.s, broadcast.s_adj),
        stmts.T: (broadcast.t, broadcast.t_adj),
        stmts.SqrtX: (broadcast.sqrt_x, broadcast.sqrt_x_adj),
        stmts.SqrtY: (broadcast.sqrt_y, broadcast.sqrt_y_adj),
        stmts.Rx: (broadcast.rx,),
        stmts.Ry: (broadcast.ry,),
        stmts.Rz: (broadcast.rz,),
        stmts.CX: (broadcast.cx,),
        stmts.CY: (broadcast.cy,),
        stmts.CZ: (broadcast.cz,),
        stmts.U3: (broadcast.u3,),
    }

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Replace ``node`` with a call to its native kernel, if one is mapped.

        Args:
            node: The statement currently being visited.

        Returns:
            An empty ``RewriteResult`` when ``type(node)`` has no entry in
            ``SQUIN_MAPPING``; otherwise a result flagged with
            ``has_done_something=True``.
        """
        candidates = self.SQUIN_MAPPING.get(type(node))
        if candidates is None:
            return RewriteResult()

        # The adjoint kernel (index 1) is only consulted for statements that
        # are SingleQubitNonHermitianGate instances with ``adjoint`` set.
        use_adjoint = (
            isinstance(node, stmts.SingleQubitNonHermitianGate) and node.adjoint
        )
        target = candidates[1] if use_adjoint else candidates[0]

        # Emit a call on a constant callee instead of an invoke statement:
        # the callgraph pass will be looking for invoke statements.
        callee = py.Constant(target)
        callee.insert_before(node)
        call = func.Call(callee.result, tuple(node.args), kwargs=())
        node.replace_by(call)

        return RewriteResult(has_done_something=True)
47
+
48
+
49
@dataclass
class UpdateDialectsOnCallGraph(passes.Pass):
    """Update all dialects on the call graph to the dialects given to this pass.

    Usage:
        pass_ = UpdateDialectsOnCallGraph(dialects=new_dialects)
        pass_(some_method)

    Note: This pass does not update the dialects of the input method, but copies
    all other methods invoked within it before updating their dialects.

    """

    # Fold pass built in ``__post_init__`` from this pass's dialects and
    # ``no_raise`` setting; it is run on every method after replacement.
    fold_pass: passes.Fold = field(init=False)

    def __post_init__(self):
        self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        """Re-target every method reachable from ``mt`` to ``self.dialects``.

        Args:
            mt: The entry method. It is kept as-is in the method map; every
                other method found in the call graph is replaced by a copy
                created with ``similar(self.dialects)``.

        Returns:
            The joined ``RewriteResult`` of the method-replacement rewrites.
        """
        cg = CallGraph(mt)

        # Map each method in the call graph to its replacement. The entry
        # method maps to itself so that it is not copied.
        mt_map = {}
        for original_mt in set(chain.from_iterable(cg.defs.values())):
            if original_mt is mt:
                mt_map[original_mt] = original_mt
            else:
                mt_map[original_mt] = original_mt.similar(self.dialects)

        result = RewriteResult()

        for new_mt in mt_map.values():
            result = (
                rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result)
            )
            self.fold_pass(new_mt)

        return result
89
+
90
+
91
@dataclass
class SquinToNativePass(passes.Pass):
    """Pass that applies ``GateRule`` through a ``CallGraphPass``."""

    # Built in ``__post_init__``; wraps a ``rewrite.Walk`` over ``GateRule``.
    call_graph_pass: CallGraphPass = field(init=False)

    def __post_init__(self):
        self.call_graph_pass = CallGraphPass(
            self.dialects,
            rewrite.Walk(GateRule()),
            no_raise=self.no_raise,
        )

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        """Run the wrapped call-graph pass on ``mt`` and return its result."""
        return self.call_graph_pass.unsafe_run(mt)
104
+
105
+
106
class SquinToNative:
    """A Target that converts Squin gates to native gates."""

    def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method:
        """Convert Squin gates to native gates.

        The input method is not modified: the conversion runs on a copy
        created with ``mt.similar(...)``, and that copy is returned.

        Args:
            mt (ir.Method): The method to convert.
            no_raise (bool, optional): Whether to suppress errors. Defaults to True.

        Returns:
            ir.Method: The converted method.
        """
        # Gather the dialects of every kernel in the original call graph.
        callgraph_before = CallGraph(mt)
        dialects_in_graph = chain.from_iterable(
            ker.dialects.data
            for kers in callgraph_before.defs.values()
            for ker in kers
        )

        # Target dialect group: everything used so far, minus the squin gate
        # dialect, plus the native kernel dialects.
        merged = mt.dialects.union(dialects_in_graph)
        new_dialects = merged.discard(gate_dialect).union(kernel)

        out = mt.similar(new_dialects)
        UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out)
        SquinToNativePass(new_dialects, no_raise=no_raise)(out)

        # verify all kernels in the callgraph
        callgraph_after = CallGraph(out)
        for kers in callgraph_after.defs.values():
            for ker in kers:
                ker.verify()

        return out
@@ -3,7 +3,6 @@ from .reg import (
3
3
  CRegister as CRegister,
4
4
  QubitState as QubitState,
5
5
  Measurement as Measurement,
6
- PyQrackWire as PyQrackWire,
7
6
  PyQrackQubit as PyQrackQubit,
8
7
  )
9
8
  from .base import (
@@ -16,7 +15,7 @@ from .task import PyQrackSimulatorTask as PyQrackSimulatorTask
16
15
  # NOTE: The following import is for registering the method tables
17
16
  from .noise import native as native
18
17
  from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
19
- from .squin import op as op, noise as noise, qubit as qubit
18
+ from .squin import gate as gate, noise as noise, qubit as qubit
20
19
  from .device import (
21
20
  StackMemorySimulator as StackMemorySimulator,
22
21
  DynamicMemorySimulator as DynamicMemorySimulator,
bloqade/pyqrack/device.py CHANGED
@@ -3,10 +3,8 @@ from dataclasses import field, dataclass
3
3
 
4
4
  import numpy as np
5
5
  from kirin import ir
6
- from kirin.passes import fold
7
6
  from kirin.dialects.ilist import IList
8
7
 
9
- from bloqade.squin import noise as squin_noise
10
8
  from pyqrack.pauli import Pauli
11
9
  from bloqade.device import AbstractSimulatorDevice
12
10
  from bloqade.pyqrack.reg import Measurement, PyQrackQubit
@@ -20,8 +18,7 @@ from bloqade.pyqrack.base import (
20
18
  )
21
19
  from bloqade.pyqrack.task import PyQrackSimulatorTask
22
20
  from pyqrack.qrack_simulator import QrackSimulator
23
- from bloqade.squin.noise.rewrite import RewriteNoiseStmts
24
- from bloqade.analysis.address.lattice import AnyAddress
21
+ from bloqade.analysis.address.lattice import UnknownReg, UnknownQubit
25
22
  from bloqade.analysis.address.analysis import AddressAnalysis
26
23
 
27
24
  RetType = TypeVar("RetType")
@@ -33,7 +30,7 @@ class QuantumState(NamedTuple):
33
30
  A representation of a quantum state as a density matrix, where the density matrix is
34
31
  rho = sum_i eigenvalues[i] |eigenvectors[:,i]><eigenvectors[:,i]|.
35
32
 
36
- This reprsentation is efficient for low-rank density matrices by only storing
33
+ This representation is efficient for low-rank density matrices by only storing
37
34
  the non-zero eigenvalues and corresponding eigenvectors of the density matrix.
38
35
  For example, a pure state has only one non-zero eigenvalue equal to 1.0.
39
36
 
@@ -158,7 +155,7 @@ def _pyqrack_reduced_density_matrix(
158
155
  # The singular values and vectors are the eigenspace of the reduced density matrix
159
156
  s, v, d = np.linalg.svd(vec_svd, full_matrices=False)
160
157
 
161
- # Remove the negligable singular values
158
+ # Remove the negligible singular values
162
159
  nonzero_inds = np.where(np.abs(v) > tol)[0]
163
160
  s = s[:, nonzero_inds]
164
161
  v = v[nonzero_inds] ** 2
@@ -191,22 +188,14 @@ class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
191
188
  kwargs: dict[str, Any],
192
189
  memory: MemoryType,
193
190
  ) -> PyQrackSimulatorTask[Params, RetType, MemoryType]:
194
- if squin_noise in mt.dialects:
195
- # NOTE: rewrite noise statements
196
- mt_ = mt.similar(mt.dialects)
197
- RewriteNoiseStmts(mt_.dialects)(mt_)
198
- fold.Fold(mt_.dialects)(mt_)
199
- else:
200
- mt_ = mt
201
-
202
191
  interp = PyQrackInterpreter(
203
- mt_.dialects,
192
+ mt.dialects,
204
193
  memory=memory,
205
194
  rng_state=self.rng_state,
206
195
  loss_m_result=self.loss_m_result,
207
196
  )
208
197
  return PyQrackSimulatorTask(
209
- kernel=mt_, args=args, kwargs=kwargs, pyqrack_interp=interp
198
+ kernel=mt, args=args, kwargs=kwargs, pyqrack_interp=interp
210
199
  )
211
200
 
212
201
  def state_vector(
@@ -366,7 +355,7 @@ class StackMemorySimulator(PyQrackSimulatorBase):
366
355
  address_analysis = AddressAnalysis(dialects=kernel.dialects)
367
356
  frame, _ = address_analysis.run_analysis(kernel)
368
357
  if self.min_qubits == 0 and any(
369
- isinstance(a, AnyAddress) for a in frame.entries.values()
358
+ isinstance(a, (UnknownQubit, UnknownReg)) for a in frame.entries.values()
370
359
  ):
371
360
  raise ValueError(
372
361
  "All addresses must be resolved. Or set min_qubits to a positive integer."
bloqade/pyqrack/native.py CHANGED
@@ -7,29 +7,29 @@ from kirin.dialects import ilist
7
7
  from pyqrack import Pauli
8
8
  from bloqade.pyqrack import PyQrackQubit
9
9
  from bloqade.pyqrack.base import PyQrackInterpreter
10
- from bloqade.native.dialects import gates
10
+ from bloqade.native.dialects.gate import stmts
11
11
 
12
12
 
13
- @gates.dialect.register(key="pyqrack")
13
+ @stmts.dialect.register(key="pyqrack")
14
14
  class NativeMethods(interp.MethodTable):
15
15
 
16
- @interp.impl(gates.CZ)
17
- def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.CZ):
18
- ctrls = frame.get_casted(stmt.ctrls, ilist.IList[PyQrackQubit, Any])
19
- qargs = frame.get_casted(stmt.qargs, ilist.IList[PyQrackQubit, Any])
16
+ @interp.impl(stmts.CZ)
17
+ def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.CZ):
18
+ controls = frame.get_casted(stmt.controls, ilist.IList[PyQrackQubit, Any])
19
+ targets = frame.get_casted(stmt.targets, ilist.IList[PyQrackQubit, Any])
20
20
 
21
- for ctrl, qarg in zip(ctrls, qargs):
22
- if ctrl.is_active() and qarg.is_active():
23
- ctrl.sim_reg.mcz([ctrl.addr], qarg.addr)
21
+ for ctrl, trgt in zip(controls, targets):
22
+ if ctrl.is_active() and trgt.is_active():
23
+ ctrl.sim_reg.mcz([ctrl.addr], trgt.addr)
24
24
 
25
25
  return ()
26
26
 
27
- @interp.impl(gates.R)
28
- def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.R):
29
- inputs = frame.get_casted(stmt.inputs, ilist.IList[PyQrackQubit, Any])
27
+ @interp.impl(stmts.R)
28
+ def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.R):
29
+ qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
30
30
  rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)
31
31
  axis_angle = 2 * math.pi * frame.get_casted(stmt.axis_angle, float)
32
- for qubit in inputs:
32
+ for qubit in qubits:
33
33
  if qubit.is_active():
34
34
  qubit.sim_reg.r(Pauli.PauliZ, axis_angle, qubit.addr)
35
35
  qubit.sim_reg.r(Pauli.PauliX, rotation_angle, qubit.addr)
@@ -37,12 +37,12 @@ class NativeMethods(interp.MethodTable):
37
37
 
38
38
  return ()
39
39
 
40
- @interp.impl(gates.Rz)
41
- def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.Rz):
42
- inputs = frame.get_casted(stmt.inputs, ilist.IList[PyQrackQubit, Any])
40
+ @interp.impl(stmts.Rz)
41
+ def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.Rz):
42
+ qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
43
43
  rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)
44
44
 
45
- for qubit in inputs:
45
+ for qubit in qubits:
46
46
  if qubit.is_active():
47
47
  qubit.sim_reg.r(Pauli.PauliZ, rotation_angle, qubit.addr)
48
48
 
bloqade/pyqrack/reg.py CHANGED
@@ -2,8 +2,8 @@ import enum
2
2
  from typing import TYPE_CHECKING
3
3
  from dataclasses import dataclass
4
4
 
5
+ from bloqade.types import MeasurementResult
5
6
  from bloqade.qasm2.types import Qubit
6
- from bloqade.squin.types import MeasurementResult
7
7
 
8
8
  if TYPE_CHECKING:
9
9
  from pyqrack import QrackSimulator
@@ -75,8 +75,3 @@ class PyQrackQubit(Qubit):
75
75
  def drop(self):
76
76
  """Drop the qubit in-place."""
77
77
  self.state = QubitState.Lost
78
-
79
-
80
- @dataclass
81
- class PyQrackWire:
82
- qubit: PyQrackQubit
@@ -0,0 +1 @@
1
+ from . import gate as gate
@@ -0,0 +1,136 @@
1
+ import math
2
+ from typing import Any
3
+
4
+ from kirin import interp
5
+ from kirin.dialects import ilist
6
+
7
+ from bloqade.squin import gate
8
+ from pyqrack.pauli import Pauli
9
+ from bloqade.pyqrack.reg import PyQrackQubit
10
+ from bloqade.pyqrack.target import PyQrackInterpreter
11
+ from bloqade.squin.gate.stmts import (
12
+ CX,
13
+ CY,
14
+ CZ,
15
+ U3,
16
+ H,
17
+ S,
18
+ T,
19
+ X,
20
+ Y,
21
+ Z,
22
+ Rx,
23
+ Ry,
24
+ Rz,
25
+ SqrtX,
26
+ SqrtY,
27
+ )
28
+
29
+
30
+ @gate.dialect.register(key="pyqrack")
31
+ class PyQrackMethods(interp.MethodTable):
32
+
33
+ @interp.impl(X)
34
+ @interp.impl(Y)
35
+ @interp.impl(Z)
36
+ @interp.impl(H)
37
+ def single_qubit_gate(
38
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: X | Y | Z | H
39
+ ):
40
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
41
+ method_name = stmt.name.lower()
42
+ for qbit in qubits:
43
+ if qbit.is_active():
44
+ getattr(qbit.sim_reg, method_name)(qbit.addr)
45
+
46
+ @interp.impl(T)
47
+ @interp.impl(S)
48
+ def single_qubit_nh_gate(
49
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: S | T
50
+ ):
51
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
52
+
53
+ method_name = stmt.name.lower()
54
+ if stmt.adjoint:
55
+ method_name = "adj" + method_name
56
+
57
+ for qbit in qubits:
58
+ if qbit.is_active():
59
+ getattr(qbit.sim_reg, method_name)(qbit.addr)
60
+ qbit.sim_reg.r
61
+
62
+ @interp.impl(SqrtX)
63
+ @interp.impl(SqrtY)
64
+ def sqrt_x(
65
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: SqrtX | SqrtY
66
+ ):
67
+ angle = math.pi / 2
68
+
69
+ if isinstance(stmt, SqrtX):
70
+ axis = Pauli.PauliX
71
+ else:
72
+ angle *= -1
73
+ axis = Pauli.PauliY
74
+
75
+ if stmt.adjoint:
76
+ angle *= -1
77
+
78
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
79
+ for qbit in qubits:
80
+ if qbit.is_active():
81
+ qbit.sim_reg.r(axis, angle, qbit.addr)
82
+
83
+ @interp.impl(Rx)
84
+ @interp.impl(Ry)
85
+ @interp.impl(Rz)
86
+ def rot(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: Rx | Ry | Rz):
87
+ match stmt:
88
+ case Rx():
89
+ axis = Pauli.PauliX
90
+ case Ry():
91
+ axis = Pauli.PauliY
92
+ case Rz():
93
+ axis = Pauli.PauliZ
94
+
95
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
96
+
97
+ # NOTE: convert turns to radians
98
+ angle = frame.get(stmt.angle) * 2 * math.pi
99
+
100
+ for qbit in qubits:
101
+ if qbit.is_active():
102
+ qbit.sim_reg.r(axis, angle, qbit.addr)
103
+
104
+ @interp.impl(CX)
105
+ @interp.impl(CY)
106
+ @interp.impl(CZ)
107
+ def control(
108
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: CX | CY | CZ
109
+ ):
110
+ controls: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.controls)
111
+ targets: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.targets)
112
+
113
+ if len(controls) != len(targets):
114
+ raise RuntimeError(
115
+ f"Found {len(controls)} controls but {len(targets)} targets when trying to evaluate {stmt}."
116
+ )
117
+
118
+ # NOTE: pyqrack convention "multi-control-x"
119
+ method_name = "m" + stmt.name.lower()
120
+
121
+ for control, target in zip(controls, targets):
122
+ if control.is_active() and target.is_active():
123
+ getattr(control.sim_reg, method_name)([control.addr], target.addr)
124
+
125
+ @interp.impl(U3)
126
+ def u3(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: U3):
127
+ theta = frame.get(stmt.theta) * 2 * math.pi
128
+ phi = frame.get(stmt.phi) * 2 * math.pi
129
+ lam = frame.get(stmt.lam) * 2 * math.pi
130
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
131
+
132
+ for qbit in qubits:
133
+ if not qbit.is_active():
134
+ continue
135
+
136
+ qbit.sim_reg.u(qbit.addr, theta, phi, lam)