bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/analysis.py +18 -20
  7. bloqade/analysis/measure_id/impls.py +31 -29
  8. bloqade/annotate/__init__.py +6 -0
  9. bloqade/annotate/_dialect.py +3 -0
  10. bloqade/annotate/_interface.py +22 -0
  11. bloqade/annotate/stmts.py +29 -0
  12. bloqade/annotate/types.py +13 -0
  13. bloqade/cirq_utils/__init__.py +4 -2
  14. bloqade/cirq_utils/emit/__init__.py +3 -0
  15. bloqade/cirq_utils/emit/base.py +246 -0
  16. bloqade/cirq_utils/emit/gate.py +104 -0
  17. bloqade/cirq_utils/emit/noise.py +90 -0
  18. bloqade/cirq_utils/emit/qubit.py +35 -0
  19. bloqade/cirq_utils/lowering.py +660 -0
  20. bloqade/cirq_utils/noise/__init__.py +0 -2
  21. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  22. bloqade/cirq_utils/noise/model.py +151 -191
  23. bloqade/cirq_utils/noise/transform.py +2 -2
  24. bloqade/cirq_utils/parallelize.py +9 -6
  25. bloqade/gemini/__init__.py +1 -0
  26. bloqade/gemini/analysis/__init__.py +3 -0
  27. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  28. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  29. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  30. bloqade/gemini/groups.py +67 -0
  31. bloqade/native/__init__.py +23 -0
  32. bloqade/native/_prelude.py +45 -0
  33. bloqade/native/dialects/__init__.py +0 -0
  34. bloqade/native/dialects/gate/__init__.py +2 -0
  35. bloqade/native/dialects/gate/_dialect.py +3 -0
  36. bloqade/native/dialects/gate/_interface.py +32 -0
  37. bloqade/native/dialects/gate/stmts.py +31 -0
  38. bloqade/native/stdlib/__init__.py +0 -0
  39. bloqade/native/stdlib/broadcast.py +246 -0
  40. bloqade/native/stdlib/simple.py +220 -0
  41. bloqade/native/upstream/__init__.py +4 -0
  42. bloqade/native/upstream/squin2native.py +79 -0
  43. bloqade/pyqrack/__init__.py +2 -2
  44. bloqade/pyqrack/base.py +7 -1
  45. bloqade/pyqrack/device.py +192 -18
  46. bloqade/pyqrack/native.py +49 -0
  47. bloqade/pyqrack/reg.py +6 -6
  48. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  49. bloqade/pyqrack/squin/gate/gate.py +136 -0
  50. bloqade/pyqrack/squin/noise/native.py +120 -54
  51. bloqade/pyqrack/squin/qubit.py +39 -36
  52. bloqade/pyqrack/target.py +5 -4
  53. bloqade/pyqrack/task.py +114 -7
  54. bloqade/qasm2/_qasm_loading.py +3 -3
  55. bloqade/qasm2/dialects/core/address.py +21 -12
  56. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  57. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  58. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  59. bloqade/qasm2/dialects/noise/model.py +2 -1
  60. bloqade/qasm2/emit/base.py +16 -11
  61. bloqade/qasm2/emit/gate.py +11 -8
  62. bloqade/qasm2/emit/main.py +103 -3
  63. bloqade/qasm2/emit/target.py +9 -5
  64. bloqade/qasm2/groups.py +3 -2
  65. bloqade/qasm2/parse/lowering.py +0 -1
  66. bloqade/qasm2/passes/fold.py +14 -73
  67. bloqade/qasm2/passes/glob.py +2 -2
  68. bloqade/qasm2/passes/noise.py +1 -1
  69. bloqade/qasm2/passes/parallel.py +7 -5
  70. bloqade/qasm2/rewrite/__init__.py +0 -1
  71. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  72. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  73. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  74. bloqade/qasm2/rewrite/register.py +2 -2
  75. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  76. bloqade/qbraid/lowering.py +1 -0
  77. bloqade/qbraid/schema.py +2 -2
  78. bloqade/qubit/__init__.py +12 -0
  79. bloqade/qubit/_dialect.py +3 -0
  80. bloqade/qubit/_interface.py +49 -0
  81. bloqade/qubit/_prelude.py +45 -0
  82. bloqade/qubit/analysis/__init__.py +1 -0
  83. bloqade/qubit/analysis/address_impl.py +40 -0
  84. bloqade/qubit/stdlib/__init__.py +2 -0
  85. bloqade/qubit/stdlib/_new.py +34 -0
  86. bloqade/qubit/stdlib/broadcast.py +62 -0
  87. bloqade/qubit/stdlib/simple.py +59 -0
  88. bloqade/qubit/stmts.py +60 -0
  89. bloqade/rewrite/passes/__init__.py +6 -0
  90. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  91. bloqade/rewrite/passes/callgraph.py +116 -0
  92. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  93. bloqade/rewrite/rules/split_ifs.py +18 -1
  94. bloqade/squin/__init__.py +47 -14
  95. bloqade/squin/analysis/__init__.py +0 -1
  96. bloqade/squin/analysis/schedule.py +10 -11
  97. bloqade/squin/gate/__init__.py +2 -0
  98. bloqade/squin/gate/_dialect.py +3 -0
  99. bloqade/squin/gate/_interface.py +98 -0
  100. bloqade/squin/gate/stmts.py +125 -0
  101. bloqade/squin/groups.py +5 -22
  102. bloqade/squin/noise/__init__.py +1 -10
  103. bloqade/squin/noise/_dialect.py +1 -1
  104. bloqade/squin/noise/_interface.py +45 -0
  105. bloqade/squin/noise/stmts.py +66 -28
  106. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  107. bloqade/squin/rewrite/__init__.py +0 -2
  108. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  109. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  110. bloqade/squin/stdlib/__init__.py +0 -0
  111. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  112. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  113. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  114. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  115. bloqade/squin/stdlib/simple/__init__.py +33 -0
  116. bloqade/squin/stdlib/simple/gate.py +242 -0
  117. bloqade/squin/stdlib/simple/noise.py +126 -0
  118. bloqade/stim/__init__.py +1 -0
  119. bloqade/stim/_wrappers.py +6 -0
  120. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  121. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  122. bloqade/stim/dialects/gate/emit.py +9 -10
  123. bloqade/stim/dialects/noise/emit.py +17 -13
  124. bloqade/stim/dialects/noise/stmts.py +5 -3
  125. bloqade/stim/emit/__init__.py +1 -0
  126. bloqade/stim/emit/impls.py +16 -0
  127. bloqade/stim/emit/stim_str.py +48 -31
  128. bloqade/stim/groups.py +12 -2
  129. bloqade/stim/parse/lowering.py +14 -17
  130. bloqade/stim/passes/__init__.py +0 -2
  131. bloqade/stim/passes/flatten.py +26 -0
  132. bloqade/stim/passes/simplify_ifs.py +6 -1
  133. bloqade/stim/passes/squin_to_stim.py +9 -84
  134. bloqade/stim/rewrite/__init__.py +2 -4
  135. bloqade/stim/rewrite/get_record_util.py +24 -0
  136. bloqade/stim/rewrite/ifs_to_stim.py +24 -25
  137. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  138. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  139. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  140. bloqade/stim/rewrite/squin_measure.py +9 -18
  141. bloqade/stim/rewrite/squin_noise.py +134 -108
  142. bloqade/stim/rewrite/util.py +5 -192
  143. bloqade/test_utils.py +1 -1
  144. bloqade/types.py +10 -0
  145. bloqade/validation/__init__.py +2 -0
  146. bloqade/validation/analysis/__init__.py +5 -0
  147. bloqade/validation/analysis/analysis.py +41 -0
  148. bloqade/validation/analysis/lattice.py +58 -0
  149. bloqade/validation/kernel_validation.py +77 -0
  150. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  151. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  152. bloqade/pyqrack/squin/op.py +0 -180
  153. bloqade/pyqrack/squin/runtime.py +0 -535
  154. bloqade/pyqrack/squin/wire.py +0 -51
  155. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  156. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  157. bloqade/squin/_typeinfer.py +0 -20
  158. bloqade/squin/analysis/address_impl.py +0 -71
  159. bloqade/squin/analysis/nsites/__init__.py +0 -9
  160. bloqade/squin/analysis/nsites/analysis.py +0 -50
  161. bloqade/squin/analysis/nsites/impls.py +0 -92
  162. bloqade/squin/analysis/nsites/lattice.py +0 -49
  163. bloqade/squin/cirq/__init__.py +0 -280
  164. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  165. bloqade/squin/cirq/emit/noise.py +0 -49
  166. bloqade/squin/cirq/emit/op.py +0 -125
  167. bloqade/squin/cirq/emit/qubit.py +0 -60
  168. bloqade/squin/cirq/emit/runtime.py +0 -242
  169. bloqade/squin/cirq/lowering.py +0 -440
  170. bloqade/squin/lowering.py +0 -54
  171. bloqade/squin/noise/_wrapper.py +0 -40
  172. bloqade/squin/noise/rewrite.py +0 -111
  173. bloqade/squin/op/__init__.py +0 -41
  174. bloqade/squin/op/_dialect.py +0 -3
  175. bloqade/squin/op/_wrapper.py +0 -121
  176. bloqade/squin/op/number.py +0 -5
  177. bloqade/squin/op/rewrite.py +0 -46
  178. bloqade/squin/op/stdlib.py +0 -62
  179. bloqade/squin/op/stmts.py +0 -276
  180. bloqade/squin/op/traits.py +0 -43
  181. bloqade/squin/op/types.py +0 -26
  182. bloqade/squin/qubit.py +0 -184
  183. bloqade/squin/rewrite/canonicalize.py +0 -60
  184. bloqade/squin/rewrite/desugar.py +0 -124
  185. bloqade/squin/types.py +0 -8
  186. bloqade/squin/wire.py +0 -201
  187. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  188. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  189. bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
  190. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  191. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,79 @@
1
+ from itertools import chain
2
+
3
+ from kirin import ir, rewrite
4
+ from kirin.dialects import py, func
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+ from kirin.analysis.callgraph import CallGraph
7
+
8
+ from bloqade.native import kernel, broadcast
9
+ from bloqade.squin.gate import stmts, dialect as gate_dialect
10
+ from bloqade.rewrite.passes import CallGraphPass, UpdateDialectsOnCallGraph
11
+
12
+
13
class GateRule(RewriteRule):
    """Rewrite squin gate statements into calls to native broadcast kernels.

    ``SQUIN_MAPPING`` maps each supported squin gate statement type to a tuple of
    native methods. The first entry is the gate itself; for gates that carry an
    ``adjoint`` flag (non-Hermitian single-qubit gates), the second entry is the
    adjoint variant.
    """

    SQUIN_MAPPING: dict[type[ir.Statement], tuple[ir.Method, ...]] = {
        stmts.X: (broadcast.x,),
        stmts.Y: (broadcast.y,),
        stmts.Z: (broadcast.z,),
        stmts.H: (broadcast.h,),
        stmts.S: (broadcast.s, broadcast.s_adj),
        stmts.T: (broadcast.t, broadcast.t_adj),
        stmts.SqrtX: (broadcast.sqrt_x, broadcast.sqrt_x_adj),
        stmts.SqrtY: (broadcast.sqrt_y, broadcast.sqrt_y_adj),
        stmts.Rx: (broadcast.rx,),
        stmts.Ry: (broadcast.ry,),
        stmts.Rz: (broadcast.rz,),
        stmts.CX: (broadcast.cx,),
        stmts.CY: (broadcast.cy,),
        stmts.CZ: (broadcast.cz,),
        stmts.U3: (broadcast.u3,),
    }

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Replace a mapped squin gate with a ``func.Call`` to its native kernel.

        Statements whose type is not in ``SQUIN_MAPPING`` are left untouched.
        """
        candidates = self.SQUIN_MAPPING.get(type(node))
        if candidates is None:
            return RewriteResult()

        wants_adjoint = (
            isinstance(node, stmts.SingleQubitNonHermitianGate) and node.adjoint
        )
        target = candidates[1] if wants_adjoint else candidates[0]

        # Emit a func.Call rather than an invoke statement: per the original
        # note, the call-graph pass looks for invoke statements.
        callee = py.Constant(target)
        callee.insert_before(node)
        node.replace_by(func.Call(callee.result, tuple(node.args), kwargs=()))

        return RewriteResult(has_done_something=True)
46
+
47
+
48
class SquinToNative:
    """A Target that converts Squin gates to native gates."""

    def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method:
        """Convert Squin gates to native gates.

        Args:
            mt (ir.Method): The method to convert.
            no_raise (bool, optional): Whether to suppress errors. Defaults to True.

        Returns:
            ir.Method: The converted method.
        """
        # Collect every dialect used anywhere in the original call graph and
        # add the native kernel dialects required by the rewritten calls.
        original_graph = CallGraph(mt)
        callee_dialects = chain.from_iterable(
            method.dialects.data
            for methods in original_graph.defs.values()
            for method in methods
        )
        dialects = mt.dialects.union(callee_dialects).union(kernel)

        out = mt.similar(dialects)
        UpdateDialectsOnCallGraph(dialects, no_raise=no_raise)(out)
        lower_gates = CallGraphPass(
            dialects, rewrite.Walk(GateRule()), no_raise=no_raise
        )
        lower_gates(out)

        # Squin gates should be gone now: drop the gate dialect from the result
        # and from every kernel in the new call graph, then verify each kernel.
        out.dialects.discard(gate_dialect)
        rewritten_graph = CallGraph(out)
        for method in rewritten_graph.edges.keys():
            method.dialects.discard(gate_dialect)
            method.verify()

        return out
@@ -3,7 +3,6 @@ from .reg import (
3
3
  CRegister as CRegister,
4
4
  QubitState as QubitState,
5
5
  Measurement as Measurement,
6
- PyQrackWire as PyQrackWire,
7
6
  PyQrackQubit as PyQrackQubit,
8
7
  )
9
8
  from .base import (
@@ -16,9 +15,10 @@ from .task import PyQrackSimulatorTask as PyQrackSimulatorTask
16
15
  # NOTE: The following import is for registering the method tables
17
16
  from .noise import native as native
18
17
  from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
19
- from .squin import op as op, noise as noise, qubit as qubit
18
+ from .squin import gate as gate, noise as noise, qubit as qubit
20
19
  from .device import (
21
20
  StackMemorySimulator as StackMemorySimulator,
22
21
  DynamicMemorySimulator as DynamicMemorySimulator,
23
22
  )
23
+ from .native import NativeMethods as NativeMethods
24
24
  from .target import PyQrack as PyQrack
bloqade/pyqrack/base.py CHANGED
@@ -48,7 +48,7 @@ def _default_pyqrack_args() -> PyQrackOptions:
48
48
  isSchmidtDecomposeMulti=True,
49
49
  isSchmidtDecompose=True,
50
50
  isStabilizerHybrid=False,
51
- isBinaryDecisionTree=True,
51
+ isBinaryDecisionTree=False,
52
52
  isPaged=True,
53
53
  isCpuGpuHybrid=True,
54
54
  isOpenCL=True,
@@ -146,7 +146,13 @@ class PyQrackInterpreter(Interpreter, typing.Generic[MemoryType]):
146
146
  loss_m_result: Measurement = field(default=Measurement.One, kw_only=True)
147
147
  """The value of a measurement result when a qubit is lost."""
148
148
 
149
+ global_measurement_id: int = field(init=False, default=0)
150
+
149
151
  def initialize(self) -> Self:
150
152
  super().initialize()
151
153
  self.memory.reset() # reset allocated qubits
152
154
  return self
155
+
156
def set_global_measurement_id(self, m: Measurement):
    """Stamp ``m`` with the current global measurement id, then advance the counter."""
    next_id = self.global_measurement_id
    m.measurement_id = next_id
    self.global_measurement_id = next_id + 1
bloqade/pyqrack/device.py CHANGED
@@ -1,11 +1,10 @@
1
- from typing import Any, TypeVar, ParamSpec
1
+ from typing import Any, TypeVar, ParamSpec, NamedTuple
2
2
  from dataclasses import field, dataclass
3
3
 
4
4
  import numpy as np
5
5
  from kirin import ir
6
- from kirin.passes import fold
6
+ from kirin.dialects.ilist import IList
7
7
 
8
- from bloqade.squin import noise as squin_noise
9
8
  from pyqrack.pauli import Pauli
10
9
  from bloqade.device import AbstractSimulatorDevice
11
10
  from bloqade.pyqrack.reg import Measurement, PyQrackQubit
@@ -18,14 +17,153 @@ from bloqade.pyqrack.base import (
18
17
  _default_pyqrack_args,
19
18
  )
20
19
  from bloqade.pyqrack.task import PyQrackSimulatorTask
21
- from bloqade.squin.noise.rewrite import RewriteNoiseStmts
22
- from bloqade.analysis.address.lattice import AnyAddress
20
+ from pyqrack.qrack_simulator import QrackSimulator
21
+ from bloqade.analysis.address.lattice import UnknownReg, UnknownQubit
23
22
  from bloqade.analysis.address.analysis import AddressAnalysis
24
23
 
25
24
  RetType = TypeVar("RetType")
26
25
  Params = ParamSpec("Params")
27
26
 
28
27
 
28
class QuantumState(NamedTuple):
    """
    A representation of a quantum state as a density matrix, where the density matrix is
    rho = sum_i eigenvalues[i] |eigenvectors[:,i]><eigenvectors[:,i]|.

    This representation is efficient for low-rank density matrices by only storing
    the non-zero eigenvalues and corresponding eigenvectors of the density matrix.
    For example, a pure state has only one non-zero eigenvalue equal to 1.0.

    Endianness and qubit ordering of the state vector is consistent with Cirq
    (big-endian): the FIRST qubit is the most significant (leftmost) bit of the
    basis-state index.
    eigenvectors[0,0] corresponds to the amplitude of the |00..000> element of the zeroth eigenvector;
    eigenvectors[1,0] corresponds to the amplitude of the |00..001> element of the zeroth eigenvector;
    eigenvectors[3,0] corresponds to the amplitude of the |00..011> element of the zeroth eigenvector;
    eigenvectors[-1,0] corresponds to the amplitude of the |11..111> element of the zeroth eigenvector.
    A flip of the FIRST bit |00..000><10..000| corresponds to applying a PauliX gate to the FIRST qubit.
    A flip of the LAST bit |00..000><00..001| corresponds to applying a PauliX gate to the LAST qubit.

    Attributes:
        eigenvalues (1d np.ndarray):
            The non-zero eigenvalues of the density matrix.
        eigenvectors (2d np.ndarray):
            The corresponding eigenvectors of the density matrix,
            where eigenvectors[:,i] is the i-th eigenvector.
    Methods:
        Not implemented yet; every method below raises NotImplementedError,
        pending https://github.com/QuEraComputing/bloqade-circuit/issues/447
    """

    eigenvalues: np.ndarray
    eigenvectors: np.ndarray

    def canonicalize(self, tol: float = 1e-12) -> "QuantumState":
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def __add__(self, other: "QuantumState") -> "QuantumState":
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def __mul__(self, scalar: float) -> "QuantumState":
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    @property
    def dense(self) -> np.ndarray[tuple[int, int], np.complexfloating]:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def __matmul__(self, right: "cirq.Circuit") -> "QuantumState":  # noqa: F821
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def expect(self, operator: Any) -> float:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def probability(self) -> np.ndarray[tuple[int], np.floating]:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def von_neumann_entropy(self) -> float:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    @property
    def qubit_basis(self) -> "list[PyQrackQubit]":
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def reduced_density_matrix(
        self, qubits: "list[PyQrackQubit]", tol: float = 1e-12
    ) -> "QuantumState":
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )

    def overlap(self, other: "QuantumState") -> complex:
        raise NotImplementedError(
            "https://github.com/QuEraComputing/bloqade-circuit/issues/447"
        )
116
+
117
+
118
def _pyqrack_reduced_density_matrix(
    inds: tuple[int, ...], sim_reg: QrackSimulator, tol: float = 1e-12
) -> QuantumState:
    """
    Extract the reduced density matrix representing the state of a list
    of qubits from a PyQrack simulator register.

    Inputs:
        inds: Register indices of the qubits to keep, in the order they should
            appear in the result (inds[0] becomes the most significant bit,
            matching Cirq's ordering).
        sim_reg: The PyQrack simulator register to extract the reduced density matrix from.
        tol: Singular values of the kept|traced-out bipartition with magnitude
            <= tol are discarded. Note that this threshold is applied to the
            singular values, i.e. to the square roots of the eigenvalues.
    Outputs:
        A QuantumState holding the non-negligible eigenvalues and the
        corresponding eigenvectors (as columns) of the reduced density matrix.

    Raises:
        ValueError: if ``inds`` is empty, contains duplicates, or contains
            indices outside ``[0, num_qubits)``.
    """
    num_qubits = sim_reg.num_qubits()
    kept = tuple(inds)

    if not kept:
        raise ValueError("At least one qubit index is required.")
    if len(set(kept)) != len(kept):
        raise ValueError("Qubits must be unique.")
    if min(kept) < 0:
        raise ValueError(f"Qubit indices {inds} must be non-negative.")
    if max(kept) > num_qubits - 1:
        raise ValueError(
            f"Qubit indices {inds} exceed the number of qubits in the register {num_qubits}."
        )

    # The remaining qubits are traced out; their order does not affect the result.
    traced_out = tuple(sorted(set(range(num_qubits)).difference(kept)))

    # PyQrack orders the ket opposite to Cirq: after a C-order reshape to
    # (2,)*N, qubit q lives on tensor axis N - 1 - q. Permute the axes so the
    # kept qubits come first, in the requested order.
    axes = tuple(num_qubits - 1 - q for q in kept + traced_out)

    statevector = np.array(sim_reg.out_ket())
    tensor = np.transpose(np.reshape(statevector, (2,) * num_qubits), axes)

    # Reshape into a (2^k, 2^(N-k)) matrix over the kept|traced-out bipartition.
    # Its left singular vectors and squared singular values are the
    # eigenvectors and eigenvalues of the reduced density matrix.
    bipartite = np.reshape(tensor, (2 ** len(kept), 2 ** len(traced_out)))
    left_vecs, sing_vals, _ = np.linalg.svd(bipartite, full_matrices=False)

    # Drop negligible singular values.
    significant = np.abs(sing_vals) > tol
    return QuantumState(
        eigenvalues=sing_vals[significant] ** 2,
        eigenvectors=left_vecs[:, significant],
    )
165
+
166
+
29
167
  @dataclass
30
168
  class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
31
169
  """PyQrack simulation device base class."""
@@ -50,23 +188,14 @@ class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
50
188
  kwargs: dict[str, Any],
51
189
  memory: MemoryType,
52
190
  ) -> PyQrackSimulatorTask[Params, RetType, MemoryType]:
53
-
54
- if squin_noise in mt.dialects:
55
- # NOTE: rewrite noise statements
56
- mt_ = mt.similar(mt.dialects)
57
- RewriteNoiseStmts(mt_.dialects)(mt_)
58
- fold.Fold(mt_.dialects)(mt_)
59
- else:
60
- mt_ = mt
61
-
62
191
  interp = PyQrackInterpreter(
63
- mt_.dialects,
192
+ mt.dialects,
64
193
  memory=memory,
65
194
  rng_state=self.rng_state,
66
195
  loss_m_result=self.loss_m_result,
67
196
  )
68
197
  return PyQrackSimulatorTask(
69
- kernel=mt_, args=args, kwargs=kwargs, pyqrack_interp=interp
198
+ kernel=mt, args=args, kwargs=kwargs, pyqrack_interp=interp
70
199
  )
71
200
 
72
201
  def state_vector(
@@ -112,6 +241,51 @@ class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
112
241
 
113
242
  return sim_reg.pauli_expectation(qubit_ids, pauli)
114
243
 
244
@staticmethod
def quantum_state(
    qubits: list[PyQrackQubit] | IList[PyQrackQubit, Any], tol: float = 1e-12
) -> "QuantumState":
    """
    Extract the reduced density matrix representing the state of a list
    of qubits from a PyQRack simulator register.

    Inputs:
        qubits: A list of PyQRack qubits to extract the reduced density matrix for
        tol: The tolerance for density matrix eigenvalues to be considered non-zero.
    Outputs:
        A QuantumState with the eigenvalues and eigenvectors of the reduced
        density matrix. An empty qubit list yields empty arrays.

    Raises:
        ValueError: if the qubits do not all share one simulator register.
    """
    if len(qubits) == 0:
        return QuantumState(eigenvalues=np.zeros(0), eigenvectors=np.zeros((0, 0)))

    shared_reg = qubits[0].sim_reg
    if any(qubit.sim_reg is not shared_reg for qubit in qubits):
        raise ValueError("All qubits must be from the same simulator register.")

    addresses: tuple[int, ...] = tuple(qubit.addr for qubit in qubits)
    return _pyqrack_reduced_density_matrix(addresses, shared_reg, tol)

@classmethod
def reduced_density_matrix(
    cls, qubits: list[PyQrackQubit] | IList[PyQrackQubit, Any], tol: float = 1e-12
) -> np.ndarray:
    """
    Extract the reduced density matrix representing the state of a list
    of qubits from a PyQRack simulator register.

    Inputs:
        qubits: A list of PyQRack qubits to extract the reduced density matrix for
        tol: The tolerance for density matrix eigenvalues to be considered non-zero.
    Outputs:
        A dense 2^n x 2^n numpy array representing the reduced density matrix.
    """
    state = cls.quantum_state(qubits, tol)
    vecs, vals = state.eigenvectors, state.eigenvalues
    # rho[a, b] = sum_x vecs[a, x] * vals[x] * conj(vecs[b, x])
    return np.einsum("ax,x,bx->ab", vecs, vals, vecs.conj())
288
+
115
289
 
116
290
  @dataclass
117
291
  class StackMemorySimulator(PyQrackSimulatorBase):
@@ -179,9 +353,9 @@ class StackMemorySimulator(PyQrackSimulatorBase):
179
353
  kwargs = {}
180
354
 
181
355
  address_analysis = AddressAnalysis(dialects=kernel.dialects)
182
- frame, _ = address_analysis.run_analysis(kernel)
356
+ frame, _ = address_analysis.run(kernel)
183
357
  if self.min_qubits == 0 and any(
184
- isinstance(a, AnyAddress) for a in frame.entries.values()
358
+ isinstance(a, (UnknownQubit, UnknownReg)) for a in frame.entries.values()
185
359
  ):
186
360
  raise ValueError(
187
361
  "All addresses must be resolved. Or set min_qubits to a positive integer."
@@ -0,0 +1,49 @@
1
+ import math
2
+ from typing import Any
3
+
4
+ from kirin import interp
5
+ from kirin.dialects import ilist
6
+
7
+ from pyqrack import Pauli
8
+ from bloqade.pyqrack import PyQrackQubit
9
+ from bloqade.pyqrack.base import PyQrackInterpreter
10
+ from bloqade.native.dialects.gate import stmts
11
+
12
+
13
@stmts.dialect.register(key="pyqrack")
class NativeMethods(interp.MethodTable):
    """PyQrack implementations of the native gate dialect.

    Angles are multiplied by 2*pi before being passed to PyQrack, i.e. they are
    treated as fractions of a full turn. Qubits that are not active are skipped.
    """

    @interp.impl(stmts.CZ)
    def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.CZ):
        """Apply CZ pairwise to ``zip(controls, targets)``.

        Raises:
            RuntimeError: if the number of controls and targets differ.
        """
        controls = frame.get_casted(stmt.controls, ilist.IList[PyQrackQubit, Any])
        targets = frame.get_casted(stmt.targets, ilist.IList[PyQrackQubit, Any])

        # zip() would silently drop unmatched qubits; fail loudly instead.
        if len(controls) != len(targets):
            raise RuntimeError(
                f"Found {len(controls)} controls but {len(targets)} targets when trying to evaluate {stmt}."
            )

        for ctrl, trgt in zip(controls, targets):
            if ctrl.is_active() and trgt.is_active():
                ctrl.sim_reg.mcz([ctrl.addr], trgt.addr)

        return ()

    @interp.impl(stmts.R)
    def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.R):
        """Rotate each active qubit by ``rotation_angle`` about an XY-plane axis.

        Implemented as a Z rotation by ``axis_angle``, an X rotation by
        ``rotation_angle``, and a Z rotation by ``-axis_angle``.
        NOTE(review): the direction of the resulting axis depends on the sign
        convention of PyQrack's ``r``. Confirm it matches the native gate spec.
        """
        qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
        rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)
        axis_angle = 2 * math.pi * frame.get_casted(stmt.axis_angle, float)
        for qubit in qubits:
            if qubit.is_active():
                qubit.sim_reg.r(Pauli.PauliZ, axis_angle, qubit.addr)
                qubit.sim_reg.r(Pauli.PauliX, rotation_angle, qubit.addr)
                qubit.sim_reg.r(Pauli.PauliZ, -axis_angle, qubit.addr)

        return ()

    @interp.impl(stmts.Rz)
    def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.Rz):
        """Rotate each active qubit about Z by ``rotation_angle`` (turns)."""
        qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
        rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)

        for qubit in qubits:
            if qubit.is_active():
                qubit.sim_reg.r(Pauli.PauliZ, rotation_angle, qubit.addr)

        return ()
bloqade/pyqrack/reg.py CHANGED
@@ -2,15 +2,20 @@ import enum
2
2
  from typing import TYPE_CHECKING
3
3
  from dataclasses import dataclass
4
4
 
5
+ from bloqade.types import MeasurementResult
5
6
  from bloqade.qasm2.types import Qubit
6
7
 
7
8
  if TYPE_CHECKING:
8
9
  from pyqrack import QrackSimulator
9
10
 
10
11
 
11
- class Measurement(enum.IntEnum):
12
+ class Measurement(MeasurementResult, enum.IntEnum):
12
13
  """Enumeration of measurement results."""
13
14
 
15
def __init__(self, *_member_value: object, measurement_id: int = 0) -> None:
    """Initialise the measurement id of an enum member.

    The enum machinery calls ``__init__`` with each member's value (0, 1, ...)
    as a positional argument. That value is ignored here, so ``measurement_id``
    starts at its default of 0 instead of silently mirroring the member's
    integer value.

    NOTE(review): enum members are singletons, so ``measurement_id`` is shared
    by every occurrence of e.g. ``Measurement.One``; ids assigned later overwrite
    earlier ones. Confirm this is acceptable for callers that track ids.
    """
    super().__init__()
    self.measurement_id = measurement_id
18
+
14
19
  Zero = 0
15
20
  One = 1
16
21
  Lost = enum.auto()
@@ -70,8 +75,3 @@ class PyQrackQubit(Qubit):
70
75
  def drop(self):
71
76
  """Drop the qubit in-place."""
72
77
  self.state = QubitState.Lost
73
-
74
-
75
- @dataclass
76
- class PyQrackWire:
77
- qubit: PyQrackQubit
@@ -0,0 +1 @@
1
+ from . import gate as gate
@@ -0,0 +1,136 @@
1
+ import math
2
+ from typing import Any
3
+
4
+ from kirin import interp
5
+ from kirin.dialects import ilist
6
+
7
+ from bloqade.squin import gate
8
+ from pyqrack.pauli import Pauli
9
+ from bloqade.pyqrack.reg import PyQrackQubit
10
+ from bloqade.pyqrack.target import PyQrackInterpreter
11
+ from bloqade.squin.gate.stmts import (
12
+ CX,
13
+ CY,
14
+ CZ,
15
+ U3,
16
+ H,
17
+ S,
18
+ T,
19
+ X,
20
+ Y,
21
+ Z,
22
+ Rx,
23
+ Ry,
24
+ Rz,
25
+ SqrtX,
26
+ SqrtY,
27
+ )
28
+
29
+
30
+ @gate.dialect.register(key="pyqrack")
31
+ class PyQrackMethods(interp.MethodTable):
32
+
33
+ @interp.impl(X)
34
+ @interp.impl(Y)
35
+ @interp.impl(Z)
36
+ @interp.impl(H)
37
+ def single_qubit_gate(
38
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: X | Y | Z | H
39
+ ):
40
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
41
+ method_name = stmt.name.lower()
42
+ for qbit in qubits:
43
+ if qbit.is_active():
44
+ getattr(qbit.sim_reg, method_name)(qbit.addr)
45
+
46
+ @interp.impl(T)
47
+ @interp.impl(S)
48
+ def single_qubit_nh_gate(
49
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: S | T
50
+ ):
51
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
52
+
53
+ method_name = stmt.name.lower()
54
+ if stmt.adjoint:
55
+ method_name = "adj" + method_name
56
+
57
+ for qbit in qubits:
58
+ if qbit.is_active():
59
+ getattr(qbit.sim_reg, method_name)(qbit.addr)
60
+ qbit.sim_reg.r
61
+
62
+ @interp.impl(SqrtX)
63
+ @interp.impl(SqrtY)
64
+ def sqrt_x(
65
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: SqrtX | SqrtY
66
+ ):
67
+ angle = math.pi / 2
68
+
69
+ if isinstance(stmt, SqrtX):
70
+ axis = Pauli.PauliX
71
+ else:
72
+ angle *= -1
73
+ axis = Pauli.PauliY
74
+
75
+ if stmt.adjoint:
76
+ angle *= -1
77
+
78
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
79
+ for qbit in qubits:
80
+ if qbit.is_active():
81
+ qbit.sim_reg.r(axis, angle, qbit.addr)
82
+
83
+ @interp.impl(Rx)
84
+ @interp.impl(Ry)
85
+ @interp.impl(Rz)
86
+ def rot(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: Rx | Ry | Rz):
87
+ match stmt:
88
+ case Rx():
89
+ axis = Pauli.PauliX
90
+ case Ry():
91
+ axis = Pauli.PauliY
92
+ case Rz():
93
+ axis = Pauli.PauliZ
94
+
95
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
96
+
97
+ # NOTE: convert turns to radians
98
+ angle = frame.get(stmt.angle) * 2 * math.pi
99
+
100
+ for qbit in qubits:
101
+ if qbit.is_active():
102
+ qbit.sim_reg.r(axis, angle, qbit.addr)
103
+
104
+ @interp.impl(CX)
105
+ @interp.impl(CY)
106
+ @interp.impl(CZ)
107
+ def control(
108
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: CX | CY | CZ
109
+ ):
110
+ controls: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.controls)
111
+ targets: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.targets)
112
+
113
+ if len(controls) != len(targets):
114
+ raise RuntimeError(
115
+ f"Found {len(controls)} controls but {len(targets)} targets when trying to evaluate {stmt}."
116
+ )
117
+
118
+ # NOTE: pyqrack convention "multi-control-x"
119
+ method_name = "m" + stmt.name.lower()
120
+
121
+ for control, target in zip(controls, targets):
122
+ if control.is_active() and target.is_active():
123
+ getattr(control.sim_reg, method_name)([control.addr], target.addr)
124
+
125
+ @interp.impl(U3)
126
+ def u3(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: U3):
127
+ theta = frame.get(stmt.theta) * 2 * math.pi
128
+ phi = frame.get(stmt.phi) * 2 * math.pi
129
+ lam = frame.get(stmt.lam) * 2 * math.pi
130
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
131
+
132
+ for qbit in qubits:
133
+ if not qbit.is_active():
134
+ continue
135
+
136
+ qbit.sim_reg.u(qbit.addr, theta, phi, lam)