bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/analysis.py +18 -20
  7. bloqade/analysis/measure_id/impls.py +31 -29
  8. bloqade/annotate/__init__.py +6 -0
  9. bloqade/annotate/_dialect.py +3 -0
  10. bloqade/annotate/_interface.py +22 -0
  11. bloqade/annotate/stmts.py +29 -0
  12. bloqade/annotate/types.py +13 -0
  13. bloqade/cirq_utils/__init__.py +4 -2
  14. bloqade/cirq_utils/emit/__init__.py +3 -0
  15. bloqade/cirq_utils/emit/base.py +246 -0
  16. bloqade/cirq_utils/emit/gate.py +104 -0
  17. bloqade/cirq_utils/emit/noise.py +90 -0
  18. bloqade/cirq_utils/emit/qubit.py +35 -0
  19. bloqade/cirq_utils/lowering.py +660 -0
  20. bloqade/cirq_utils/noise/__init__.py +0 -2
  21. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  22. bloqade/cirq_utils/noise/model.py +151 -191
  23. bloqade/cirq_utils/noise/transform.py +2 -2
  24. bloqade/cirq_utils/parallelize.py +9 -6
  25. bloqade/gemini/__init__.py +1 -0
  26. bloqade/gemini/analysis/__init__.py +3 -0
  27. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  28. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  29. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  30. bloqade/gemini/groups.py +67 -0
  31. bloqade/native/__init__.py +23 -0
  32. bloqade/native/_prelude.py +45 -0
  33. bloqade/native/dialects/__init__.py +0 -0
  34. bloqade/native/dialects/gate/__init__.py +2 -0
  35. bloqade/native/dialects/gate/_dialect.py +3 -0
  36. bloqade/native/dialects/gate/_interface.py +32 -0
  37. bloqade/native/dialects/gate/stmts.py +31 -0
  38. bloqade/native/stdlib/__init__.py +0 -0
  39. bloqade/native/stdlib/broadcast.py +246 -0
  40. bloqade/native/stdlib/simple.py +220 -0
  41. bloqade/native/upstream/__init__.py +4 -0
  42. bloqade/native/upstream/squin2native.py +79 -0
  43. bloqade/pyqrack/__init__.py +2 -2
  44. bloqade/pyqrack/base.py +7 -1
  45. bloqade/pyqrack/device.py +192 -18
  46. bloqade/pyqrack/native.py +49 -0
  47. bloqade/pyqrack/reg.py +6 -6
  48. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  49. bloqade/pyqrack/squin/gate/gate.py +136 -0
  50. bloqade/pyqrack/squin/noise/native.py +120 -54
  51. bloqade/pyqrack/squin/qubit.py +39 -36
  52. bloqade/pyqrack/target.py +5 -4
  53. bloqade/pyqrack/task.py +114 -7
  54. bloqade/qasm2/_qasm_loading.py +3 -3
  55. bloqade/qasm2/dialects/core/address.py +21 -12
  56. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  57. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  58. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  59. bloqade/qasm2/dialects/noise/model.py +2 -1
  60. bloqade/qasm2/emit/base.py +16 -11
  61. bloqade/qasm2/emit/gate.py +11 -8
  62. bloqade/qasm2/emit/main.py +103 -3
  63. bloqade/qasm2/emit/target.py +9 -5
  64. bloqade/qasm2/groups.py +3 -2
  65. bloqade/qasm2/parse/lowering.py +0 -1
  66. bloqade/qasm2/passes/fold.py +14 -73
  67. bloqade/qasm2/passes/glob.py +2 -2
  68. bloqade/qasm2/passes/noise.py +1 -1
  69. bloqade/qasm2/passes/parallel.py +7 -5
  70. bloqade/qasm2/rewrite/__init__.py +0 -1
  71. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  72. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  73. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  74. bloqade/qasm2/rewrite/register.py +2 -2
  75. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  76. bloqade/qbraid/lowering.py +1 -0
  77. bloqade/qbraid/schema.py +2 -2
  78. bloqade/qubit/__init__.py +12 -0
  79. bloqade/qubit/_dialect.py +3 -0
  80. bloqade/qubit/_interface.py +49 -0
  81. bloqade/qubit/_prelude.py +45 -0
  82. bloqade/qubit/analysis/__init__.py +1 -0
  83. bloqade/qubit/analysis/address_impl.py +40 -0
  84. bloqade/qubit/stdlib/__init__.py +2 -0
  85. bloqade/qubit/stdlib/_new.py +34 -0
  86. bloqade/qubit/stdlib/broadcast.py +62 -0
  87. bloqade/qubit/stdlib/simple.py +59 -0
  88. bloqade/qubit/stmts.py +60 -0
  89. bloqade/rewrite/passes/__init__.py +6 -0
  90. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  91. bloqade/rewrite/passes/callgraph.py +116 -0
  92. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  93. bloqade/rewrite/rules/split_ifs.py +18 -1
  94. bloqade/squin/__init__.py +47 -14
  95. bloqade/squin/analysis/__init__.py +0 -1
  96. bloqade/squin/analysis/schedule.py +10 -11
  97. bloqade/squin/gate/__init__.py +2 -0
  98. bloqade/squin/gate/_dialect.py +3 -0
  99. bloqade/squin/gate/_interface.py +98 -0
  100. bloqade/squin/gate/stmts.py +125 -0
  101. bloqade/squin/groups.py +5 -22
  102. bloqade/squin/noise/__init__.py +1 -10
  103. bloqade/squin/noise/_dialect.py +1 -1
  104. bloqade/squin/noise/_interface.py +45 -0
  105. bloqade/squin/noise/stmts.py +66 -28
  106. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  107. bloqade/squin/rewrite/__init__.py +0 -2
  108. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  109. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  110. bloqade/squin/stdlib/__init__.py +0 -0
  111. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  112. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  113. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  114. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  115. bloqade/squin/stdlib/simple/__init__.py +33 -0
  116. bloqade/squin/stdlib/simple/gate.py +242 -0
  117. bloqade/squin/stdlib/simple/noise.py +126 -0
  118. bloqade/stim/__init__.py +1 -0
  119. bloqade/stim/_wrappers.py +6 -0
  120. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  121. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  122. bloqade/stim/dialects/gate/emit.py +9 -10
  123. bloqade/stim/dialects/noise/emit.py +17 -13
  124. bloqade/stim/dialects/noise/stmts.py +5 -3
  125. bloqade/stim/emit/__init__.py +1 -0
  126. bloqade/stim/emit/impls.py +16 -0
  127. bloqade/stim/emit/stim_str.py +48 -31
  128. bloqade/stim/groups.py +12 -2
  129. bloqade/stim/parse/lowering.py +14 -17
  130. bloqade/stim/passes/__init__.py +0 -2
  131. bloqade/stim/passes/flatten.py +26 -0
  132. bloqade/stim/passes/simplify_ifs.py +6 -1
  133. bloqade/stim/passes/squin_to_stim.py +9 -84
  134. bloqade/stim/rewrite/__init__.py +2 -4
  135. bloqade/stim/rewrite/get_record_util.py +24 -0
  136. bloqade/stim/rewrite/ifs_to_stim.py +24 -25
  137. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  138. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  139. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  140. bloqade/stim/rewrite/squin_measure.py +9 -18
  141. bloqade/stim/rewrite/squin_noise.py +134 -108
  142. bloqade/stim/rewrite/util.py +5 -192
  143. bloqade/test_utils.py +1 -1
  144. bloqade/types.py +10 -0
  145. bloqade/validation/__init__.py +2 -0
  146. bloqade/validation/analysis/__init__.py +5 -0
  147. bloqade/validation/analysis/analysis.py +41 -0
  148. bloqade/validation/analysis/lattice.py +58 -0
  149. bloqade/validation/kernel_validation.py +77 -0
  150. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  151. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  152. bloqade/pyqrack/squin/op.py +0 -180
  153. bloqade/pyqrack/squin/runtime.py +0 -535
  154. bloqade/pyqrack/squin/wire.py +0 -51
  155. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  156. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  157. bloqade/squin/_typeinfer.py +0 -20
  158. bloqade/squin/analysis/address_impl.py +0 -71
  159. bloqade/squin/analysis/nsites/__init__.py +0 -9
  160. bloqade/squin/analysis/nsites/analysis.py +0 -50
  161. bloqade/squin/analysis/nsites/impls.py +0 -92
  162. bloqade/squin/analysis/nsites/lattice.py +0 -49
  163. bloqade/squin/cirq/__init__.py +0 -280
  164. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  165. bloqade/squin/cirq/emit/noise.py +0 -49
  166. bloqade/squin/cirq/emit/op.py +0 -125
  167. bloqade/squin/cirq/emit/qubit.py +0 -60
  168. bloqade/squin/cirq/emit/runtime.py +0 -242
  169. bloqade/squin/cirq/lowering.py +0 -440
  170. bloqade/squin/lowering.py +0 -54
  171. bloqade/squin/noise/_wrapper.py +0 -40
  172. bloqade/squin/noise/rewrite.py +0 -111
  173. bloqade/squin/op/__init__.py +0 -41
  174. bloqade/squin/op/_dialect.py +0 -3
  175. bloqade/squin/op/_wrapper.py +0 -121
  176. bloqade/squin/op/number.py +0 -5
  177. bloqade/squin/op/rewrite.py +0 -46
  178. bloqade/squin/op/stdlib.py +0 -62
  179. bloqade/squin/op/stmts.py +0 -276
  180. bloqade/squin/op/traits.py +0 -43
  181. bloqade/squin/op/types.py +0 -26
  182. bloqade/squin/qubit.py +0 -184
  183. bloqade/squin/rewrite/canonicalize.py +0 -60
  184. bloqade/squin/rewrite/desugar.py +0 -124
  185. bloqade/squin/types.py +0 -8
  186. bloqade/squin/wire.py +0 -201
  187. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  188. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  189. bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
  190. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  191. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -2,72 +2,110 @@ from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
3
  from kirin.dialects import ilist
4
4
 
5
- from bloqade.squin.op.types import OpType
5
+ from bloqade.types import QubitType
6
6
 
7
7
  from ._dialect import dialect
8
- from ..op.types import NumOperators
9
8
 
10
9
 
11
10
  @statement
12
11
  class NoiseChannel(ir.Statement):
13
- traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
14
- result: ir.ResultValue = info.result(OpType)
12
+ traits = frozenset({lowering.FromPythonCall()})
15
13
 
16
14
 
17
- @statement(dialect=dialect)
18
- class PauliError(NoiseChannel):
19
- basis: ir.SSAValue = info.argument(OpType)
20
- p: ir.SSAValue = info.argument(types.Float)
15
+ @statement
16
+ class SingleQubitNoiseChannel(NoiseChannel):
17
+ # NOTE: we are not adding e.g. qubits here, since inheriting then will
18
+ # change the order of the wrapper arguments
19
+ pass
20
+
21
+
22
+ @statement
23
+ class TwoQubitNoiseChannel(NoiseChannel):
24
+ pass
21
25
 
22
26
 
23
27
  @statement(dialect=dialect)
24
- class PPError(NoiseChannel):
28
+ class SingleQubitPauliChannel(SingleQubitNoiseChannel):
25
29
  """
26
- Pauli Product Error
30
+ This will apply one of the randomly chosen Pauli operators according to the
31
+ given probabilities (p_x, p_y, p_z).
27
32
  """
28
33
 
29
- op: ir.SSAValue = info.argument(OpType)
30
- p: ir.SSAValue = info.argument(types.Float)
34
+ px: ir.SSAValue = info.argument(types.Float)
35
+ py: ir.SSAValue = info.argument(types.Float)
36
+ pz: ir.SSAValue = info.argument(types.Float)
37
+ qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
38
+
39
+
40
+ N = types.TypeVar("N", bound=types.Int)
31
41
 
32
42
 
33
43
  @statement(dialect=dialect)
34
- class Depolarize(NoiseChannel):
44
+ class TwoQubitPauliChannel(TwoQubitNoiseChannel):
35
45
  """
36
- Apply depolarize error to single qubit
46
+ This will apply one of the randomly chosen Pauli products:
47
+
48
+ {IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ}
49
+
50
+ but the choice is weighed with the given probability.
51
+
52
+ NOTE: the given parameters are ordered as given in the list above!
37
53
  """
38
54
 
39
- p: ir.SSAValue = info.argument(types.Float)
55
+ probabilities: ir.SSAValue = info.argument(
56
+ ilist.IListType[QubitType, types.Literal(15)]
57
+ )
58
+ controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
59
+ targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
40
60
 
41
61
 
42
62
  @statement(dialect=dialect)
43
- class Depolarize2(NoiseChannel):
63
+ class Depolarize(SingleQubitNoiseChannel):
44
64
  """
45
- Apply correlated depolarize error to two qubit
65
+ Apply depolarize error to single qubit.
66
+
67
+ This randomly picks one of the three Pauli operators to apply. Each Pauli
68
+ operator has the probability `p / 3` to be selected. No operator is applied
69
+ with the probability `1 - p`.
46
70
  """
47
71
 
48
72
  p: ir.SSAValue = info.argument(types.Float)
73
+ qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
49
74
 
50
75
 
51
76
  @statement(dialect=dialect)
52
- class SingleQubitPauliChannel(NoiseChannel):
53
- params: ir.SSAValue = info.argument(ilist.IListType[types.Float, types.Literal(3)])
77
+ class Depolarize2(TwoQubitNoiseChannel):
78
+ """
79
+ Apply correlated depolarize error to two qubits
54
80
 
81
+ This will apply one of the randomly chosen Pauli products each with probability `p / 15`:
55
82
 
56
- @statement(dialect=dialect)
57
- class TwoQubitPauliChannel(NoiseChannel):
58
- params: ir.SSAValue = info.argument(ilist.IListType[types.Float, types.Literal(15)])
83
+ `{IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ}`
84
+ """
85
+
86
+ p: ir.SSAValue = info.argument(types.Float)
87
+ controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
88
+ targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
59
89
 
60
90
 
61
91
  @statement(dialect=dialect)
62
- class QubitLoss(NoiseChannel):
92
+ class QubitLoss(SingleQubitNoiseChannel):
93
+ """
94
+ Apply an atom loss with channel.
95
+ """
96
+
63
97
  # NOTE: qubit loss error (not supported by Stim)
64
98
  p: ir.SSAValue = info.argument(types.Float)
99
+ qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
65
100
 
66
101
 
67
102
  @statement(dialect=dialect)
68
- class StochasticUnitaryChannel(ir.Statement):
69
- operators: ir.SSAValue = info.argument(ilist.IListType[OpType, NumOperators])
70
- probabilities: ir.SSAValue = info.argument(
71
- ilist.IListType[types.Float, NumOperators]
103
+ class CorrelatedQubitLoss(NoiseChannel):
104
+ """
105
+ Apply a correlated atom loss channel.
106
+ """
107
+
108
+ p: ir.SSAValue = info.argument(types.Float)
109
+ qubits: ir.SSAValue = info.argument(
110
+ ilist.IListType[ilist.IListType[QubitType, N], types.Any]
72
111
  )
73
- result: ir.ResultValue = info.result(OpType)
@@ -1,47 +1,56 @@
1
1
  # create rewrite rule name SquinMeasureToStim using kirin
2
2
  import math
3
- from typing import List, Tuple, Callable
4
3
 
5
4
  import numpy as np
6
5
  from kirin import ir
7
6
  from kirin.dialects import py
8
7
  from kirin.rewrite.abc import RewriteRule, RewriteResult
9
8
 
10
- from bloqade.squin import op, qubit
9
+ from bloqade.squin import gate
11
10
 
12
11
 
13
- def sdag() -> list[ir.Statement]:
14
- return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)]
12
+ # Placeholder type, swap in an actual S statement with adjoint=True
13
+ # during the rewrite method
14
+ class Sdag(ir.Statement):
15
+ pass
16
+
17
+
18
+ class SqrtXdag(ir.Statement):
19
+ pass
20
+
21
+
22
+ class SqrtYdag(ir.Statement):
23
+ pass
15
24
 
16
25
 
17
26
  # (theta, phi, lam)
18
27
  U3_HALF_PI_ANGLE_TO_GATES: dict[
19
- tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]]
28
+ tuple[int, int, int], list[type[ir.Statement]] | list[None]
20
29
  ] = {
21
- (0, 0, 0): lambda: ([op.stmts.Identity(sites=1)],),
22
- (0, 0, 1): lambda: ([op.stmts.S()],),
23
- (0, 0, 2): lambda: ([op.stmts.Z()],),
24
- (0, 0, 3): lambda: (sdag(),),
25
- (1, 0, 0): lambda: ([op.stmts.SqrtY()],),
26
- (1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]),
27
- (1, 0, 2): lambda: ([op.stmts.H()],),
28
- (1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]),
29
- (1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]),
30
- (1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]),
31
- (1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]),
32
- (1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]),
33
- (1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]),
34
- (1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]),
35
- (1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]),
36
- (1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]),
37
- (1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()),
38
- (1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()),
39
- (1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()),
40
- (1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()),
41
- (2, 0, 0): lambda: ([op.stmts.Y()],),
42
- (2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]),
43
- (2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]),
44
- (2, 0, 3): lambda: (sdag(), [op.stmts.Y()]),
30
+ (0, 0, 0): [None],
31
+ (0, 0, 1): [gate.stmts.S],
32
+ (0, 0, 2): [gate.stmts.Z],
33
+ (0, 0, 3): [Sdag],
34
+ (1, 0, 0): [gate.stmts.SqrtY],
35
+ (1, 0, 1): [gate.stmts.S, gate.stmts.SqrtY],
36
+ (1, 0, 2): [gate.stmts.H],
37
+ (1, 0, 3): [Sdag, gate.stmts.SqrtY],
38
+ (1, 1, 0): [gate.stmts.S, SqrtXdag],
39
+ (1, 1, 1): [gate.stmts.Z, SqrtXdag],
40
+ (1, 1, 2): [Sdag, SqrtXdag],
41
+ (1, 1, 3): [SqrtXdag],
42
+ (1, 2, 0): [gate.stmts.Z, SqrtYdag],
43
+ (1, 2, 1): [Sdag, SqrtYdag],
44
+ (1, 2, 2): [SqrtYdag],
45
+ (1, 2, 3): [gate.stmts.S, SqrtYdag],
46
+ (1, 3, 0): [Sdag, gate.stmts.SqrtX],
47
+ (1, 3, 1): [gate.stmts.SqrtX],
48
+ (1, 3, 2): [gate.stmts.S, gate.stmts.SqrtX],
49
+ (1, 3, 3): [gate.stmts.Z, gate.stmts.SqrtX],
50
+ (2, 0, 0): [gate.stmts.Y],
51
+ (2, 0, 1): [gate.stmts.S, gate.stmts.Y],
52
+ (2, 0, 2): [gate.stmts.X],
53
+ (2, 0, 3): [Sdag, gate.stmts.Y],
45
54
  }
46
55
 
47
56
 
@@ -61,8 +70,8 @@ class SquinU3ToClifford(RewriteRule):
61
70
  """
62
71
 
63
72
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
64
- if isinstance(node, (qubit.Apply, qubit.Broadcast)):
65
- return self.rewrite_ApplyOrBroadcast_onU3(node)
73
+ if isinstance(node, gate.stmts.U3):
74
+ return self.rewrite_U3(node)
66
75
  else:
67
76
  return RewriteResult()
68
77
 
@@ -87,35 +96,39 @@ class SquinU3ToClifford(RewriteRule):
87
96
  else:
88
97
  return round((angle / math.tau) % 1 * 4) % 4
89
98
 
90
- def rewrite_ApplyOrBroadcast_onU3(
91
- self, node: qubit.Apply | qubit.Broadcast
92
- ) -> RewriteResult:
99
+ def rewrite_U3(self, node: gate.stmts.U3) -> RewriteResult:
93
100
  """
94
101
  Rewrite Apply and Broadcast nodes to their clifford equivalent statements.
95
102
  """
96
- if not isinstance(node.operator.owner, op.stmts.U3):
97
- return RewriteResult()
98
103
 
99
- gates = self.decompose_U3_gates(node.operator.owner)
104
+ gates = self.decompose_U3_gates(node)
100
105
 
101
106
  if len(gates) == 0:
102
107
  return RewriteResult()
103
108
 
104
- for stmt_list in gates:
105
- for gate_stmt in stmt_list[:-1]:
106
- gate_stmt.insert_before(node)
107
-
108
- oper = stmt_list[-1]
109
- oper.insert_before(node)
110
- new_node = node.__class__(operator=oper.result, qubits=node.qubits)
111
- new_node.insert_before(node)
109
+ # Get rid of the U3 gate altogether if it's identity
110
+ if len(gates) == 1 and gates[0] is None:
111
+ node.delete()
112
+ return RewriteResult(has_done_something=True)
113
+
114
+ for gate_stmt in gates:
115
+ if gate_stmt is Sdag:
116
+ new_stmt = gate.stmts.S(adjoint=True, qubits=node.qubits)
117
+ elif gate_stmt is SqrtXdag:
118
+ new_stmt = gate.stmts.SqrtX(adjoint=True, qubits=node.qubits)
119
+ elif gate_stmt is SqrtYdag:
120
+ new_stmt = gate.stmts.SqrtY(adjoint=True, qubits=node.qubits)
121
+ else:
122
+ new_stmt = gate_stmt(qubits=node.qubits)
123
+ new_stmt.insert_before(node)
112
124
 
113
125
  node.delete()
114
126
 
115
- # rewrite U3 to clifford gates
116
127
  return RewriteResult(has_done_something=True)
117
128
 
118
- def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...]:
129
+ def decompose_U3_gates(
130
+ self, node: gate.stmts.U3
131
+ ) -> list[type[ir.Statement]] | list[None]:
119
132
  """
120
133
  Rewrite U3 statements to clifford gates if possible.
121
134
  """
@@ -124,7 +137,13 @@ class SquinU3ToClifford(RewriteRule):
124
137
  lam = self.get_constant(node.lam)
125
138
 
126
139
  if theta is None or phi is None or lam is None:
127
- return ()
140
+ return []
141
+
142
+ # Angles will be in units of turns, we convert to radians
143
+ # to allow for the old logic to work
144
+ theta = theta * math.tau
145
+ phi = phi * math.tau
146
+ lam = lam * math.tau
128
147
 
129
148
  # For U3(2*pi*n, phi, lam) = U3(0, 0, lam + phi) which is a Z rotation.
130
149
  if np.isclose(np.mod(theta, math.tau), 0):
@@ -139,13 +158,13 @@ class SquinU3ToClifford(RewriteRule):
139
158
  lam_half_pi: int | None = self.resolve_angle(lam)
140
159
 
141
160
  if theta_half_pi is None or phi_half_pi is None or lam_half_pi is None:
142
- return ()
161
+ return []
143
162
 
144
163
  angles_key = (theta_half_pi, phi_half_pi, lam_half_pi)
145
164
  if angles_key not in U3_HALF_PI_ANGLE_TO_GATES:
146
165
  angles_key = equivalent_u3_para(*angles_key)
147
166
  if angles_key not in U3_HALF_PI_ANGLE_TO_GATES:
148
- return ()
167
+ return []
149
168
 
150
169
  gates_stmts = U3_HALF_PI_ANGLE_TO_GATES.get(angles_key)
151
170
 
@@ -154,4 +173,4 @@ class SquinU3ToClifford(RewriteRule):
154
173
  gates_stmts is not None
155
174
  ), "internal error, U3 gates not found for angles: {}".format(angles_key)
156
175
 
157
- return gates_stmts()
176
+ return gates_stmts
@@ -1,7 +1,5 @@
1
1
  from .wrap_analysis import (
2
- SitesAttribute as SitesAttribute,
3
2
  AddressAttribute as AddressAttribute,
4
- WrapOpSiteAnalysis as WrapOpSiteAnalysis,
5
3
  WrapAddressAnalysis as WrapAddressAnalysis,
6
4
  )
7
5
  from .U3_to_clifford import SquinU3ToClifford as SquinU3ToClifford
@@ -1,14 +1,14 @@
1
1
  from kirin import ir
2
2
  from kirin.rewrite.abc import RewriteRule, RewriteResult
3
3
 
4
- from bloqade.squin import qubit
4
+ from bloqade import qubit
5
5
 
6
6
 
7
7
  class RemoveDeadRegister(RewriteRule):
8
8
 
9
9
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
10
10
 
11
- if not isinstance(node, qubit.New):
11
+ if not isinstance(node, qubit.stmts.New):
12
12
  return RewriteResult()
13
13
 
14
14
  if bool(node.result.uses):
@@ -5,12 +5,11 @@ from kirin import ir
5
5
  from kirin.rewrite.abc import RewriteRule, RewriteResult
6
6
  from kirin.print.printer import Printer
7
7
 
8
- from bloqade.squin import op, wire
8
+ from bloqade import qubit
9
9
  from bloqade.analysis.address import Address
10
- from bloqade.squin.analysis.nsites import Sites
11
10
 
12
11
 
13
- @wire.dialect.register
12
+ @qubit.dialect.register
14
13
  @dataclass
15
14
  class AddressAttribute(ir.Attribute):
16
15
 
@@ -25,21 +24,6 @@ class AddressAttribute(ir.Attribute):
25
24
  printer.print(self.address)
26
25
 
27
26
 
28
- @op.dialect.register
29
- @dataclass
30
- class SitesAttribute(ir.Attribute):
31
-
32
- name = "Sites"
33
- sites: Sites
34
-
35
- def __hash__(self) -> int:
36
- return hash(self.sites)
37
-
38
- def print_impl(self, printer: Printer) -> None:
39
- # Can return to implementing this later
40
- printer.print(self.sites)
41
-
42
-
43
27
  @dataclass
44
28
  class WrapAnalysis(RewriteRule):
45
29
 
@@ -61,7 +45,8 @@ class WrapAddressAnalysis(WrapAnalysis):
61
45
  address_analysis: dict[ir.SSAValue, Address]
62
46
 
63
47
  def wrap(self, value: ir.SSAValue) -> bool:
64
- address_analysis_result = self.address_analysis[value]
48
+ if (address_analysis_result := self.address_analysis.get(value)) is None:
49
+ return False
65
50
 
66
51
  if value.hints.get("address") is not None:
67
52
  return False
@@ -69,19 +54,3 @@ class WrapAddressAnalysis(WrapAnalysis):
69
54
  value.hints["address"] = AddressAttribute(address_analysis_result)
70
55
 
71
56
  return True
72
-
73
-
74
- @dataclass
75
- class WrapOpSiteAnalysis(WrapAnalysis):
76
-
77
- op_site_analysis: dict[ir.SSAValue, Sites]
78
-
79
- def wrap(self, value: ir.SSAValue) -> bool:
80
- op_site_analysis_result = self.op_site_analysis[value]
81
-
82
- if value.hints.get("sites") is not None:
83
- return False
84
-
85
- value.hints["sites"] = SitesAttribute(op_site_analysis_result)
86
-
87
- return True
File without changes
@@ -0,0 +1,34 @@
1
+ from .gate import (
2
+ h as h,
3
+ s as s,
4
+ t as t,
5
+ x as x,
6
+ y as y,
7
+ z as z,
8
+ cx as cx,
9
+ cy as cy,
10
+ cz as cz,
11
+ rx as rx,
12
+ ry as ry,
13
+ rz as rz,
14
+ u3 as u3,
15
+ s_adj as s_adj,
16
+ shift as shift,
17
+ t_adj as t_adj,
18
+ sqrt_x as sqrt_x,
19
+ sqrt_y as sqrt_y,
20
+ sqrt_z as sqrt_z,
21
+ sqrt_x_adj as sqrt_x_adj,
22
+ sqrt_y_adj as sqrt_y_adj,
23
+ sqrt_z_adj as sqrt_z_adj,
24
+ )
25
+ from .noise import (
26
+ bit_flip as bit_flip,
27
+ depolarize as depolarize,
28
+ qubit_loss as qubit_loss,
29
+ depolarize2 as depolarize2,
30
+ correlated_qubit_loss as correlated_qubit_loss,
31
+ two_qubit_pauli_channel as two_qubit_pauli_channel,
32
+ single_qubit_pauli_channel as single_qubit_pauli_channel,
33
+ )
34
+ from ._qubit import reset as reset, measure as measure
@@ -0,0 +1,4 @@
1
+ from bloqade.qubit.stdlib.broadcast import (
2
+ reset as reset,
3
+ measure as measure,
4
+ )