bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/__init__.py +4 -1
  7. bloqade/analysis/measure_id/analysis.py +29 -20
  8. bloqade/analysis/measure_id/impls.py +72 -31
  9. bloqade/annotate/__init__.py +6 -0
  10. bloqade/annotate/_dialect.py +3 -0
  11. bloqade/annotate/_interface.py +22 -0
  12. bloqade/annotate/stmts.py +29 -0
  13. bloqade/annotate/types.py +13 -0
  14. bloqade/cirq_utils/__init__.py +4 -2
  15. bloqade/cirq_utils/emit/__init__.py +3 -0
  16. bloqade/cirq_utils/emit/base.py +246 -0
  17. bloqade/cirq_utils/emit/gate.py +104 -0
  18. bloqade/cirq_utils/emit/noise.py +90 -0
  19. bloqade/cirq_utils/emit/qubit.py +35 -0
  20. bloqade/cirq_utils/lowering.py +660 -0
  21. bloqade/cirq_utils/noise/__init__.py +0 -2
  22. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  23. bloqade/cirq_utils/noise/model.py +151 -191
  24. bloqade/cirq_utils/noise/transform.py +2 -2
  25. bloqade/cirq_utils/parallelize.py +9 -6
  26. bloqade/gemini/__init__.py +1 -0
  27. bloqade/gemini/analysis/__init__.py +3 -0
  28. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  29. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  30. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  31. bloqade/gemini/groups.py +67 -0
  32. bloqade/native/__init__.py +23 -0
  33. bloqade/native/_prelude.py +45 -0
  34. bloqade/native/dialects/__init__.py +0 -0
  35. bloqade/native/dialects/gate/__init__.py +2 -0
  36. bloqade/native/dialects/gate/_dialect.py +3 -0
  37. bloqade/native/dialects/gate/_interface.py +32 -0
  38. bloqade/native/dialects/gate/stmts.py +31 -0
  39. bloqade/native/stdlib/__init__.py +0 -0
  40. bloqade/native/stdlib/broadcast.py +246 -0
  41. bloqade/native/stdlib/simple.py +220 -0
  42. bloqade/native/upstream/__init__.py +4 -0
  43. bloqade/native/upstream/squin2native.py +79 -0
  44. bloqade/pyqrack/__init__.py +2 -2
  45. bloqade/pyqrack/base.py +7 -1
  46. bloqade/pyqrack/device.py +190 -4
  47. bloqade/pyqrack/native.py +49 -0
  48. bloqade/pyqrack/reg.py +6 -6
  49. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  50. bloqade/pyqrack/squin/gate/gate.py +136 -0
  51. bloqade/pyqrack/squin/noise/native.py +120 -54
  52. bloqade/pyqrack/squin/qubit.py +39 -36
  53. bloqade/pyqrack/target.py +5 -4
  54. bloqade/pyqrack/task.py +114 -7
  55. bloqade/qasm2/_qasm_loading.py +3 -3
  56. bloqade/qasm2/dialects/core/address.py +21 -12
  57. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  58. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  59. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  60. bloqade/qasm2/dialects/noise/model.py +2 -1
  61. bloqade/qasm2/emit/base.py +16 -11
  62. bloqade/qasm2/emit/gate.py +11 -8
  63. bloqade/qasm2/emit/main.py +103 -3
  64. bloqade/qasm2/emit/target.py +9 -5
  65. bloqade/qasm2/groups.py +3 -2
  66. bloqade/qasm2/parse/lowering.py +0 -1
  67. bloqade/qasm2/passes/fold.py +14 -73
  68. bloqade/qasm2/passes/glob.py +2 -2
  69. bloqade/qasm2/passes/noise.py +1 -1
  70. bloqade/qasm2/passes/parallel.py +7 -5
  71. bloqade/qasm2/rewrite/__init__.py +0 -1
  72. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  73. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  74. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  75. bloqade/qasm2/rewrite/register.py +2 -2
  76. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  77. bloqade/qbraid/lowering.py +1 -0
  78. bloqade/qbraid/schema.py +2 -2
  79. bloqade/qubit/__init__.py +12 -0
  80. bloqade/qubit/_dialect.py +3 -0
  81. bloqade/qubit/_interface.py +49 -0
  82. bloqade/qubit/_prelude.py +45 -0
  83. bloqade/qubit/analysis/__init__.py +1 -0
  84. bloqade/qubit/analysis/address_impl.py +40 -0
  85. bloqade/qubit/stdlib/__init__.py +2 -0
  86. bloqade/qubit/stdlib/_new.py +34 -0
  87. bloqade/qubit/stdlib/broadcast.py +62 -0
  88. bloqade/qubit/stdlib/simple.py +59 -0
  89. bloqade/qubit/stmts.py +60 -0
  90. bloqade/rewrite/passes/__init__.py +6 -0
  91. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  92. bloqade/rewrite/passes/callgraph.py +116 -0
  93. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  94. bloqade/rewrite/rules/split_ifs.py +18 -1
  95. bloqade/squin/__init__.py +47 -14
  96. bloqade/squin/analysis/__init__.py +0 -1
  97. bloqade/squin/analysis/schedule.py +10 -11
  98. bloqade/squin/gate/__init__.py +2 -0
  99. bloqade/squin/gate/_dialect.py +3 -0
  100. bloqade/squin/gate/_interface.py +98 -0
  101. bloqade/squin/gate/stmts.py +125 -0
  102. bloqade/squin/groups.py +5 -22
  103. bloqade/squin/noise/__init__.py +1 -10
  104. bloqade/squin/noise/_dialect.py +1 -1
  105. bloqade/squin/noise/_interface.py +45 -0
  106. bloqade/squin/noise/stmts.py +66 -28
  107. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  108. bloqade/squin/rewrite/__init__.py +0 -2
  109. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  110. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  111. bloqade/squin/stdlib/__init__.py +0 -0
  112. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  113. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  114. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  115. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  116. bloqade/squin/stdlib/simple/__init__.py +33 -0
  117. bloqade/squin/stdlib/simple/gate.py +242 -0
  118. bloqade/squin/stdlib/simple/noise.py +126 -0
  119. bloqade/stim/__init__.py +1 -0
  120. bloqade/stim/_wrappers.py +6 -0
  121. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  122. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  123. bloqade/stim/dialects/gate/emit.py +9 -10
  124. bloqade/stim/dialects/noise/emit.py +17 -13
  125. bloqade/stim/dialects/noise/stmts.py +5 -3
  126. bloqade/stim/emit/__init__.py +1 -0
  127. bloqade/stim/emit/impls.py +16 -0
  128. bloqade/stim/emit/stim_str.py +48 -31
  129. bloqade/stim/groups.py +12 -2
  130. bloqade/stim/parse/lowering.py +14 -17
  131. bloqade/stim/passes/__init__.py +3 -1
  132. bloqade/stim/passes/flatten.py +26 -0
  133. bloqade/stim/passes/simplify_ifs.py +16 -2
  134. bloqade/stim/passes/squin_to_stim.py +18 -60
  135. bloqade/stim/rewrite/__init__.py +3 -4
  136. bloqade/stim/rewrite/get_record_util.py +24 -0
  137. bloqade/stim/rewrite/ifs_to_stim.py +29 -31
  138. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  139. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  140. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  141. bloqade/stim/rewrite/squin_measure.py +11 -79
  142. bloqade/stim/rewrite/squin_noise.py +134 -108
  143. bloqade/stim/rewrite/util.py +5 -192
  144. bloqade/test_utils.py +1 -1
  145. bloqade/types.py +10 -0
  146. bloqade/validation/__init__.py +2 -0
  147. bloqade/validation/analysis/__init__.py +5 -0
  148. bloqade/validation/analysis/analysis.py +41 -0
  149. bloqade/validation/analysis/lattice.py +58 -0
  150. bloqade/validation/kernel_validation.py +77 -0
  151. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  152. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  153. bloqade/pyqrack/squin/op.py +0 -166
  154. bloqade/pyqrack/squin/runtime.py +0 -535
  155. bloqade/pyqrack/squin/wire.py +0 -51
  156. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  157. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  158. bloqade/squin/_typeinfer.py +0 -20
  159. bloqade/squin/analysis/address_impl.py +0 -71
  160. bloqade/squin/analysis/nsites/__init__.py +0 -9
  161. bloqade/squin/analysis/nsites/analysis.py +0 -50
  162. bloqade/squin/analysis/nsites/impls.py +0 -92
  163. bloqade/squin/analysis/nsites/lattice.py +0 -49
  164. bloqade/squin/cirq/__init__.py +0 -265
  165. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  166. bloqade/squin/cirq/emit/noise.py +0 -49
  167. bloqade/squin/cirq/emit/op.py +0 -125
  168. bloqade/squin/cirq/emit/qubit.py +0 -60
  169. bloqade/squin/cirq/emit/runtime.py +0 -242
  170. bloqade/squin/cirq/lowering.py +0 -440
  171. bloqade/squin/lowering.py +0 -54
  172. bloqade/squin/noise/_wrapper.py +0 -40
  173. bloqade/squin/noise/rewrite.py +0 -111
  174. bloqade/squin/op/__init__.py +0 -41
  175. bloqade/squin/op/_dialect.py +0 -3
  176. bloqade/squin/op/_wrapper.py +0 -121
  177. bloqade/squin/op/number.py +0 -5
  178. bloqade/squin/op/rewrite.py +0 -46
  179. bloqade/squin/op/stdlib.py +0 -62
  180. bloqade/squin/op/stmts.py +0 -276
  181. bloqade/squin/op/traits.py +0 -43
  182. bloqade/squin/op/types.py +0 -26
  183. bloqade/squin/qubit.py +0 -184
  184. bloqade/squin/rewrite/canonicalize.py +0 -60
  185. bloqade/squin/rewrite/desugar.py +0 -124
  186. bloqade/squin/types.py +0 -8
  187. bloqade/squin/wire.py +0 -201
  188. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  189. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  190. bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
  191. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  192. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,79 @@
1
+ from itertools import chain
2
+
3
+ from kirin import ir, rewrite
4
+ from kirin.dialects import py, func
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+ from kirin.analysis.callgraph import CallGraph
7
+
8
+ from bloqade.native import kernel, broadcast
9
+ from bloqade.squin.gate import stmts, dialect as gate_dialect
10
+ from bloqade.rewrite.passes import CallGraphPass, UpdateDialectsOnCallGraph
11
+
12
+
13
class GateRule(RewriteRule):
    """Rewrite squin gate statements into calls to native broadcast kernels."""

    # Squin statement type -> native kernel(s). For non-Hermitian single-qubit
    # gates the tuple holds (gate, adjoint gate); all others hold one kernel.
    SQUIN_MAPPING: dict[type[ir.Statement], tuple[ir.Method, ...]] = {
        stmts.X: (broadcast.x,),
        stmts.Y: (broadcast.y,),
        stmts.Z: (broadcast.z,),
        stmts.H: (broadcast.h,),
        stmts.S: (broadcast.s, broadcast.s_adj),
        stmts.T: (broadcast.t, broadcast.t_adj),
        stmts.SqrtX: (broadcast.sqrt_x, broadcast.sqrt_x_adj),
        stmts.SqrtY: (broadcast.sqrt_y, broadcast.sqrt_y_adj),
        stmts.Rx: (broadcast.rx,),
        stmts.Ry: (broadcast.ry,),
        stmts.Rz: (broadcast.rz,),
        stmts.CX: (broadcast.cx,),
        stmts.CY: (broadcast.cy,),
        stmts.CZ: (broadcast.cz,),
        stmts.U3: (broadcast.u3,),
    }

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Replace ``node`` by a ``func.Call`` to its native counterpart.

        Statements whose exact type is not a key of ``SQUIN_MAPPING`` are
        left untouched and an empty ``RewriteResult`` is returned.
        """
        candidates = self.SQUIN_MAPPING.get(type(node))
        if candidates is None:
            return RewriteResult()

        # Only non-Hermitian single-qubit gates carry an ``adjoint`` flag;
        # when it is set, pick the adjoint kernel stored in the second slot.
        wants_adjoint = (
            isinstance(node, stmts.SingleQubitNonHermitianGate) and node.adjoint
        )
        target = candidates[1] if wants_adjoint else candidates[0]

        # Emit a func.Call rather than a func.Invoke: the call-graph pass is
        # the one looking for invoke statements, so it is not done here.
        callee = py.Constant(target)
        callee.insert_before(node)
        node.replace_by(func.Call(callee.result, tuple(node.args), kwargs=()))

        return RewriteResult(has_done_something=True)
46
+
47
+
48
class SquinToNative:
    """A Target that converts Squin gates to native gates."""

    def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method:
        """Convert Squin gates to native gates.

        Args:
            mt (ir.Method): The method to convert.
            no_raise (bool, optional): Whether to suppress errors. Defaults to True.

        Returns:
            ir.Method: The converted method.
        """
        # Gather the dialects of every kernel reachable from ``mt`` so the
        # copy (and its callees) can keep expressing all of their statements,
        # then add the native ``kernel`` group on top.
        callgraph_before = CallGraph(mt)
        dialects_seen = chain.from_iterable(
            method_def.dialects.data
            for defs in callgraph_before.defs.values()
            for method_def in defs
        )
        target_dialects = mt.dialects.union(dialects_seen).union(kernel)

        out = mt.similar(target_dialects)
        UpdateDialectsOnCallGraph(target_dialects, no_raise=no_raise)(out)
        gate_rewrite = rewrite.Walk(GateRule())
        CallGraphPass(target_dialects, gate_rewrite, no_raise=no_raise)(out)

        # With every gate rewritten, the squin gate dialect is no longer
        # needed: drop it from the entry method and from every kernel in the
        # new call graph, verifying each kernel afterwards.
        out.dialects.discard(gate_dialect)
        for method_def in CallGraph(out).edges.keys():
            method_def.dialects.discard(gate_dialect)
            method_def.verify()

        return out
@@ -3,7 +3,6 @@ from .reg import (
3
3
  CRegister as CRegister,
4
4
  QubitState as QubitState,
5
5
  Measurement as Measurement,
6
- PyQrackWire as PyQrackWire,
7
6
  PyQrackQubit as PyQrackQubit,
8
7
  )
9
8
  from .base import (
@@ -16,9 +15,10 @@ from .task import PyQrackSimulatorTask as PyQrackSimulatorTask
16
15
  # NOTE: The following import is for registering the method tables
17
16
  from .noise import native as native
18
17
  from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
19
- from .squin import op as op, noise as noise, qubit as qubit
18
+ from .squin import gate as gate, noise as noise, qubit as qubit
20
19
  from .device import (
21
20
  StackMemorySimulator as StackMemorySimulator,
22
21
  DynamicMemorySimulator as DynamicMemorySimulator,
23
22
  )
23
+ from .native import NativeMethods as NativeMethods
24
24
  from .target import PyQrack as PyQrack
bloqade/pyqrack/base.py CHANGED
@@ -48,7 +48,7 @@ def _default_pyqrack_args() -> PyQrackOptions:
48
48
  isSchmidtDecomposeMulti=True,
49
49
  isSchmidtDecompose=True,
50
50
  isStabilizerHybrid=False,
51
- isBinaryDecisionTree=True,
51
+ isBinaryDecisionTree=False,
52
52
  isPaged=True,
53
53
  isCpuGpuHybrid=True,
54
54
  isOpenCL=True,
@@ -146,7 +146,13 @@ class PyQrackInterpreter(Interpreter, typing.Generic[MemoryType]):
146
146
  loss_m_result: Measurement = field(default=Measurement.One, kw_only=True)
147
147
  """The value of a measurement result when a qubit is lost."""
148
148
 
149
+ global_measurement_id: int = field(init=False, default=0)
150
+
149
151
  def initialize(self) -> Self:
150
152
  super().initialize()
151
153
  self.memory.reset() # reset allocated qubits
152
154
  return self
155
+
156
+ def set_global_measurement_id(self, m: Measurement):
157
+ m.measurement_id = self.global_measurement_id
158
+ self.global_measurement_id += 1
bloqade/pyqrack/device.py CHANGED
@@ -1,8 +1,9 @@
1
- from typing import Any, TypeVar, ParamSpec
1
+ from typing import Any, TypeVar, ParamSpec, NamedTuple
2
2
  from dataclasses import field, dataclass
3
3
 
4
4
  import numpy as np
5
5
  from kirin import ir
6
+ from kirin.dialects.ilist import IList
6
7
 
7
8
  from pyqrack.pauli import Pauli
8
9
  from bloqade.device import AbstractSimulatorDevice
@@ -16,13 +17,153 @@ from bloqade.pyqrack.base import (
16
17
  _default_pyqrack_args,
17
18
  )
18
19
  from bloqade.pyqrack.task import PyQrackSimulatorTask
19
- from bloqade.analysis.address.lattice import AnyAddress
20
+ from pyqrack.qrack_simulator import QrackSimulator
21
+ from bloqade.analysis.address.lattice import UnknownReg, UnknownQubit
20
22
  from bloqade.analysis.address.analysis import AddressAnalysis
21
23
 
22
24
  RetType = TypeVar("RetType")
23
25
  Params = ParamSpec("Params")
24
26
 
25
27
 
28
class QuantumState(NamedTuple):
    """
    A representation of a quantum state as a density matrix, where the density matrix is
    rho = sum_i eigenvalues[i] |eigenvectors[:,i]><eigenvectors[:,i]|.

    This representation is efficient for low-rank density matrices by only storing
    the non-zero eigenvalues and corresponding eigenvectors of the density matrix.
    For example, a pure state has only one non-zero eigenvalue equal to 1.0.

    Endianness and qubit ordering of the state vector is consistent with Cirq
    (big-endian: the first qubit is the most significant bit), where
    eigenvectors[0,0] corresponds to the amplitude of the |00..000> element of the zeroth eigenvector;
    eigenvectors[1,0] corresponds to the amplitude of the |00..001> element of the zeroth eigenvector;
    eigenvectors[3,0] corresponds to the amplitude of the |00..011> element of the zeroth eigenvector;
    eigenvectors[-1,0] corresponds to the amplitude of the |11..111> element of the zeroth eigenvector.
    A flip of the LAST bit |00..000><00..001| corresponds to applying a PauliX gate to the LAST qubit.
    A flip of the FIRST bit |00..000><10..000| corresponds to applying a PauliX gate to the FIRST qubit.

    Attributes:
        eigenvalues (1d np.ndarray):
            The non-zero eigenvalues of the density matrix.
        eigenvectors (2d np.ndarray):
            The corresponding eigenvectors of the density matrix,
            where eigenvectors[:,i] is the i-th eigenvector.
    Methods:
        Not Implemented, pending https://github.com/QuEraComputing/bloqade-circuit/issues/447
    """

    eigenvalues: np.ndarray
    eigenvectors: np.ndarray

    def canonicalize(self, tol: float = 1e-12) -> "QuantumState":
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def __add__(self, other: "QuantumState") -> "QuantumState":
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def __mul__(self, scalar: float) -> "QuantumState":
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    @property
    def dense(self) -> np.ndarray[tuple[int, int], np.complexfloating]:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def __matmul__(self, right: "cirq.Circuit") -> "QuantumState":  # noqa: F821
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def expect(self, operator: Any) -> float:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def probability(self) -> np.ndarray[tuple[int], np.floating]:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def von_neumann_entropy(self) -> float:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    @property
    def qubit_basis(self) -> list[PyQrackQubit]:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def reduced_density_matrix(
        self, qubits: list[PyQrackQubit], tol: float = 1e-12
    ) -> "QuantumState":
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def overlap(self, other: "QuantumState") -> complex:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )
116
+
117
+
118
def _pyqrack_reduced_density_matrix(
    inds: tuple[int, ...], sim_reg: QrackSimulator, tol: float = 1e-12
) -> QuantumState:
    """
    Extract the reduced density matrix representing the state of a list
    of qubits from a PyQRack simulator register.

    Inputs:
        inds: Integers labeling the simulator qubits to keep, in the order they
            should appear in the result (the first index becomes the most
            significant bit, consistent with Cirq). Any sequence is accepted.
        sim_reg: The PyQRack simulator register to extract the reduced density matrix from.
        tol: Singular values of the bipartitioned state vector whose magnitude does
            not exceed ``tol`` are discarded; the kept eigenvalues are their squares.
    Outputs:
        A QuantumState with the retained eigenvalues and the matching
        eigenvectors (as columns) of the reduced density matrix.
    Raises:
        ValueError: if ``inds`` is empty, contains duplicates, or contains an
            index outside ``range(N)`` for an ``N``-qubit register.
    """
    n_qubits = sim_reg.num_qubits()
    keep = tuple(inds)

    # Validate before touching the state vector.
    if not keep:
        raise ValueError("At least one qubit index is required.")
    if len(set(keep)) != len(keep):
        raise ValueError("Qubits must be unique.")
    if min(keep) < 0:
        raise ValueError(f"Qubit indices {inds} must be non-negative.")
    if max(keep) > n_qubits - 1:
        raise ValueError(
            f"Qubit indices {inds} exceed the number of qubits in the register {n_qubits}."
        )

    # Qubits that get traced out, in a deterministic ascending order.
    traced = tuple(sorted(set(range(n_qubits)).difference(keep)))

    # Fix PyQrack endianness to be consistent with Cirq: after reshaping the
    # ket into a rank-N tensor, register qubit q lives on axis N - 1 - q.
    axes = tuple(n_qubits - 1 - q for q in keep + traced)
    statevector = np.array(sim_reg.out_ket())
    tensor = np.reshape(statevector, (2,) * n_qubits)
    # Bring the kept qubits to the front, in the requested order.
    tensor = np.transpose(tensor, axes)
    # Bipartition: rows index the kept qubits, columns the traced-out ones.
    matrix = np.reshape(tensor, (2 ** len(keep), 2 ** len(traced)))
    # rho_keep = M @ M^dagger, so the left singular vectors of M are its
    # eigenvectors and the squared singular values are its eigenvalues.
    left_vectors, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)

    # Remove the negligible singular values.
    significant = np.abs(singular_values) > tol
    return QuantumState(
        eigenvalues=singular_values[significant] ** 2,
        eigenvectors=left_vectors[:, significant],
    )
165
+
166
+
26
167
  @dataclass
27
168
  class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
28
169
  """PyQrack simulation device base class."""
@@ -100,6 +241,51 @@ class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
100
241
 
101
242
  return sim_reg.pauli_expectation(qubit_ids, pauli)
102
243
 
244
+ @staticmethod
245
+ def quantum_state(
246
+ qubits: list[PyQrackQubit] | IList[PyQrackQubit, Any], tol: float = 1e-12
247
+ ) -> "QuantumState":
248
+ """
249
+ Extract the reduced density matrix representing the state of a list
250
+ of qubits from a PyQRack simulator register.
251
+
252
+ Inputs:
253
+ qubits: A list of PyQRack qubits to extract the reduced density matrix for
254
+ tol: The tolerance for density matrix eigenvalues to be considered non-zero.
255
+ Outputs:
256
+ An eigh result containing the eigenvalues and eigenvectors of the reduced density matrix.
257
+ """
258
+ if len(qubits) == 0:
259
+ return QuantumState(
260
+ eigenvalues=np.array([]), eigenvectors=np.array([]).reshape(0, 0)
261
+ )
262
+ sim_reg = qubits[0].sim_reg
263
+
264
+ if not all([x.sim_reg is sim_reg for x in qubits]):
265
+ raise ValueError("All qubits must be from the same simulator register.")
266
+ inds: tuple[int, ...] = tuple(qubit.addr for qubit in qubits)
267
+
268
+ return _pyqrack_reduced_density_matrix(inds, sim_reg, tol)
269
+
270
+ @classmethod
271
+ def reduced_density_matrix(
272
+ cls, qubits: list[PyQrackQubit] | IList[PyQrackQubit, Any], tol: float = 1e-12
273
+ ) -> np.ndarray:
274
+ """
275
+ Extract the reduced density matrix representing the state of a list
276
+ of qubits from a PyQRack simulator register.
277
+
278
+ Inputs:
279
+ qubits: A list of PyQRack qubits to extract the reduced density matrix for
280
+ tol: The tolerance for density matrix eigenvalues to be considered non-zero.
281
+ Outputs:
282
+ A dense 2^n x 2^n numpy array representing the reduced density matrix.
283
+ """
284
+ rdm = cls.quantum_state(qubits, tol)
285
+ return np.einsum(
286
+ "ax,x,bx", rdm.eigenvectors, rdm.eigenvalues, rdm.eigenvectors.conj()
287
+ )
288
+
103
289
 
104
290
  @dataclass
105
291
  class StackMemorySimulator(PyQrackSimulatorBase):
@@ -167,9 +353,9 @@ class StackMemorySimulator(PyQrackSimulatorBase):
167
353
  kwargs = {}
168
354
 
169
355
  address_analysis = AddressAnalysis(dialects=kernel.dialects)
170
- frame, _ = address_analysis.run_analysis(kernel)
356
+ frame, _ = address_analysis.run(kernel)
171
357
  if self.min_qubits == 0 and any(
172
- isinstance(a, AnyAddress) for a in frame.entries.values()
358
+ isinstance(a, (UnknownQubit, UnknownReg)) for a in frame.entries.values()
173
359
  ):
174
360
  raise ValueError(
175
361
  "All addresses must be resolved. Or set min_qubits to a positive integer."
@@ -0,0 +1,49 @@
1
+ import math
2
+ from typing import Any
3
+
4
+ from kirin import interp
5
+ from kirin.dialects import ilist
6
+
7
+ from pyqrack import Pauli
8
+ from bloqade.pyqrack import PyQrackQubit
9
+ from bloqade.pyqrack.base import PyQrackInterpreter
10
+ from bloqade.native.dialects.gate import stmts
11
+
12
+
13
@stmts.dialect.register(key="pyqrack")
class NativeMethods(interp.MethodTable):
    """PyQrack implementations of the native ``gate`` dialect statements.

    Angles are read from the frame and multiplied by ``2 * pi`` before being
    passed to the simulator. Qubits whose ``is_active()`` is false are skipped.
    """

    @interp.impl(stmts.CZ)
    def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.CZ):
        """Apply CZ to each pair ``(controls[i], targets[i])``.

        Raises:
            RuntimeError: if ``controls`` and ``targets`` differ in length.
        """
        controls = frame.get_casted(stmt.controls, ilist.IList[PyQrackQubit, Any])
        targets = frame.get_casted(stmt.targets, ilist.IList[PyQrackQubit, Any])

        # zip() would silently drop unpaired qubits; reject the mismatch,
        # as the squin gate CX/CY/CZ implementation does.
        if len(controls) != len(targets):
            raise RuntimeError(
                f"Found {len(controls)} controls but {len(targets)} targets when trying to evaluate {stmt}."
            )

        for ctrl, trgt in zip(controls, targets):
            if ctrl.is_active() and trgt.is_active():
                ctrl.sim_reg.mcz([ctrl.addr], trgt.addr)

        return ()

    @interp.impl(stmts.R)
    def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.R):
        """Apply the native R gate to every active qubit.

        Implemented as the sequence Z-rotation by ``axis_angle``, X-rotation by
        ``rotation_angle``, then Z-rotation by ``-axis_angle``.
        """
        qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
        rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)
        axis_angle = 2 * math.pi * frame.get_casted(stmt.axis_angle, float)
        for qubit in qubits:
            if qubit.is_active():
                qubit.sim_reg.r(Pauli.PauliZ, axis_angle, qubit.addr)
                qubit.sim_reg.r(Pauli.PauliX, rotation_angle, qubit.addr)
                qubit.sim_reg.r(Pauli.PauliZ, -axis_angle, qubit.addr)

        return ()

    @interp.impl(stmts.Rz)
    def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.Rz):
        """Apply a Z-rotation by ``rotation_angle`` to every active qubit."""
        qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
        rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)

        for qubit in qubits:
            if qubit.is_active():
                qubit.sim_reg.r(Pauli.PauliZ, rotation_angle, qubit.addr)

        return ()
bloqade/pyqrack/reg.py CHANGED
@@ -2,15 +2,20 @@ import enum
2
2
  from typing import TYPE_CHECKING
3
3
  from dataclasses import dataclass
4
4
 
5
+ from bloqade.types import MeasurementResult
5
6
  from bloqade.qasm2.types import Qubit
6
7
 
7
8
  if TYPE_CHECKING:
8
9
  from pyqrack import QrackSimulator
9
10
 
10
11
 
11
- class Measurement(enum.IntEnum):
12
+ class Measurement(MeasurementResult, enum.IntEnum):
12
13
  """Enumeration of measurement results."""
13
14
 
15
+ def __init__(self, measurement_id: int = 0) -> None:
16
+ super().__init__()
17
+ self.measurement_id = measurement_id
18
+
14
19
  Zero = 0
15
20
  One = 1
16
21
  Lost = enum.auto()
@@ -70,8 +75,3 @@ class PyQrackQubit(Qubit):
70
75
  def drop(self):
71
76
  """Drop the qubit in-place."""
72
77
  self.state = QubitState.Lost
73
-
74
-
75
- @dataclass
76
- class PyQrackWire:
77
- qubit: PyQrackQubit
@@ -0,0 +1 @@
1
+ from . import gate as gate
@@ -0,0 +1,136 @@
1
+ import math
2
+ from typing import Any
3
+
4
+ from kirin import interp
5
+ from kirin.dialects import ilist
6
+
7
+ from bloqade.squin import gate
8
+ from pyqrack.pauli import Pauli
9
+ from bloqade.pyqrack.reg import PyQrackQubit
10
+ from bloqade.pyqrack.target import PyQrackInterpreter
11
+ from bloqade.squin.gate.stmts import (
12
+ CX,
13
+ CY,
14
+ CZ,
15
+ U3,
16
+ H,
17
+ S,
18
+ T,
19
+ X,
20
+ Y,
21
+ Z,
22
+ Rx,
23
+ Ry,
24
+ Rz,
25
+ SqrtX,
26
+ SqrtY,
27
+ )
28
+
29
+
30
+ @gate.dialect.register(key="pyqrack")
31
+ class PyQrackMethods(interp.MethodTable):
32
+
33
+ @interp.impl(X)
34
+ @interp.impl(Y)
35
+ @interp.impl(Z)
36
+ @interp.impl(H)
37
+ def single_qubit_gate(
38
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: X | Y | Z | H
39
+ ):
40
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
41
+ method_name = stmt.name.lower()
42
+ for qbit in qubits:
43
+ if qbit.is_active():
44
+ getattr(qbit.sim_reg, method_name)(qbit.addr)
45
+
46
+ @interp.impl(T)
47
+ @interp.impl(S)
48
+ def single_qubit_nh_gate(
49
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: S | T
50
+ ):
51
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
52
+
53
+ method_name = stmt.name.lower()
54
+ if stmt.adjoint:
55
+ method_name = "adj" + method_name
56
+
57
+ for qbit in qubits:
58
+ if qbit.is_active():
59
+ getattr(qbit.sim_reg, method_name)(qbit.addr)
60
+ qbit.sim_reg.r
61
+
62
+ @interp.impl(SqrtX)
63
+ @interp.impl(SqrtY)
64
+ def sqrt_x(
65
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: SqrtX | SqrtY
66
+ ):
67
+ angle = math.pi / 2
68
+
69
+ if isinstance(stmt, SqrtX):
70
+ axis = Pauli.PauliX
71
+ else:
72
+ angle *= -1
73
+ axis = Pauli.PauliY
74
+
75
+ if stmt.adjoint:
76
+ angle *= -1
77
+
78
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
79
+ for qbit in qubits:
80
+ if qbit.is_active():
81
+ qbit.sim_reg.r(axis, angle, qbit.addr)
82
+
83
+ @interp.impl(Rx)
84
+ @interp.impl(Ry)
85
+ @interp.impl(Rz)
86
+ def rot(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: Rx | Ry | Rz):
87
+ match stmt:
88
+ case Rx():
89
+ axis = Pauli.PauliX
90
+ case Ry():
91
+ axis = Pauli.PauliY
92
+ case Rz():
93
+ axis = Pauli.PauliZ
94
+
95
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
96
+
97
+ # NOTE: convert turns to radians
98
+ angle = frame.get(stmt.angle) * 2 * math.pi
99
+
100
+ for qbit in qubits:
101
+ if qbit.is_active():
102
+ qbit.sim_reg.r(axis, angle, qbit.addr)
103
+
104
+ @interp.impl(CX)
105
+ @interp.impl(CY)
106
+ @interp.impl(CZ)
107
+ def control(
108
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: CX | CY | CZ
109
+ ):
110
+ controls: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.controls)
111
+ targets: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.targets)
112
+
113
+ if len(controls) != len(targets):
114
+ raise RuntimeError(
115
+ f"Found {len(controls)} controls but {len(targets)} targets when trying to evaluate {stmt}."
116
+ )
117
+
118
+ # NOTE: pyqrack convention "multi-control-x"
119
+ method_name = "m" + stmt.name.lower()
120
+
121
+ for control, target in zip(controls, targets):
122
+ if control.is_active() and target.is_active():
123
+ getattr(control.sim_reg, method_name)([control.addr], target.addr)
124
+
125
+ @interp.impl(U3)
126
+ def u3(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: U3):
127
+ theta = frame.get(stmt.theta) * 2 * math.pi
128
+ phi = frame.get(stmt.phi) * 2 * math.pi
129
+ lam = frame.get(stmt.lam) * 2 * math.pi
130
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
131
+
132
+ for qbit in qubits:
133
+ if not qbit.is_active():
134
+ continue
135
+
136
+ qbit.sim_reg.u(qbit.addr, theta, phi, lam)