bloqade-circuit 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (74)
  1. bloqade/analysis/address/impls.py +5 -9
  2. bloqade/analysis/address/lattice.py +1 -1
  3. bloqade/analysis/fidelity/__init__.py +1 -0
  4. bloqade/analysis/fidelity/analysis.py +69 -0
  5. bloqade/device.py +130 -0
  6. bloqade/noise/__init__.py +2 -1
  7. bloqade/noise/fidelity.py +51 -0
  8. bloqade/noise/native/model.py +1 -2
  9. bloqade/noise/native/rewrite.py +5 -5
  10. bloqade/noise/native/stmts.py +40 -11
  11. bloqade/pyqrack/__init__.py +8 -2
  12. bloqade/pyqrack/base.py +24 -3
  13. bloqade/pyqrack/device.py +166 -0
  14. bloqade/pyqrack/noise/native.py +1 -2
  15. bloqade/pyqrack/qasm2/core.py +31 -15
  16. bloqade/pyqrack/qasm2/glob.py +28 -0
  17. bloqade/pyqrack/qasm2/uop.py +9 -1
  18. bloqade/pyqrack/reg.py +17 -49
  19. bloqade/pyqrack/squin/__init__.py +0 -0
  20. bloqade/pyqrack/squin/op.py +154 -0
  21. bloqade/pyqrack/squin/qubit.py +85 -0
  22. bloqade/pyqrack/squin/runtime.py +515 -0
  23. bloqade/pyqrack/squin/wire.py +69 -0
  24. bloqade/pyqrack/target.py +9 -2
  25. bloqade/pyqrack/task.py +30 -0
  26. bloqade/qasm2/_wrappers.py +11 -1
  27. bloqade/qasm2/dialects/core/stmts.py +15 -4
  28. bloqade/qasm2/dialects/expr/_emit.py +9 -8
  29. bloqade/qasm2/emit/base.py +4 -2
  30. bloqade/qasm2/emit/gate.py +0 -14
  31. bloqade/qasm2/emit/main.py +19 -15
  32. bloqade/qasm2/emit/target.py +2 -6
  33. bloqade/qasm2/glob.py +1 -1
  34. bloqade/qasm2/parse/lowering.py +124 -1
  35. bloqade/qasm2/passes/glob.py +3 -3
  36. bloqade/qasm2/passes/lift_qubits.py +26 -0
  37. bloqade/qasm2/passes/noise.py +6 -14
  38. bloqade/qasm2/passes/parallel.py +3 -3
  39. bloqade/qasm2/passes/py2qasm.py +1 -2
  40. bloqade/qasm2/passes/qasm2py.py +1 -2
  41. bloqade/qasm2/rewrite/desugar.py +6 -6
  42. bloqade/qasm2/rewrite/glob.py +9 -9
  43. bloqade/qasm2/rewrite/heuristic_noise.py +30 -38
  44. bloqade/qasm2/rewrite/insert_qubits.py +34 -0
  45. bloqade/qasm2/rewrite/native_gates.py +54 -55
  46. bloqade/qasm2/rewrite/parallel_to_uop.py +9 -9
  47. bloqade/qasm2/rewrite/uop_to_parallel.py +20 -22
  48. bloqade/qasm2/types.py +3 -6
  49. bloqade/qbraid/schema.py +10 -12
  50. bloqade/squin/__init__.py +1 -1
  51. bloqade/squin/analysis/nsites/analysis.py +4 -6
  52. bloqade/squin/analysis/nsites/impls.py +2 -6
  53. bloqade/squin/analysis/schedule.py +1 -1
  54. bloqade/squin/groups.py +15 -7
  55. bloqade/squin/noise/__init__.py +27 -0
  56. bloqade/squin/noise/_dialect.py +3 -0
  57. bloqade/squin/noise/stmts.py +59 -0
  58. bloqade/squin/op/__init__.py +35 -5
  59. bloqade/squin/op/number.py +5 -0
  60. bloqade/squin/op/rewrite.py +46 -0
  61. bloqade/squin/op/stmts.py +23 -2
  62. bloqade/squin/op/types.py +14 -0
  63. bloqade/squin/qubit.py +79 -11
  64. bloqade/squin/rewrite/__init__.py +0 -0
  65. bloqade/squin/rewrite/measure_desugar.py +33 -0
  66. bloqade/squin/wire.py +31 -2
  67. bloqade/stim/emit/stim.py +1 -1
  68. bloqade/task.py +94 -0
  69. bloqade/visual/animation/base.py +25 -15
  70. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/METADATA +8 -2
  71. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/RECORD +73 -52
  72. bloqade/squin/op/complex.py +0 -6
  73. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/WHEEL +0 -0
  74. {bloqade_circuit-0.1.0.dist-info → bloqade_circuit-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -192,7 +192,10 @@ class SquinWireMethodTable(interp.MethodTable):
192
192
 
193
193
  origin_qubit = frame.get(stmt.qubit)
194
194
 
195
- return (AddressWire(origin_qubit=origin_qubit),)
195
+ if isinstance(origin_qubit, AddressQubit):
196
+ return (AddressWire(origin_qubit=origin_qubit),)
197
+ else:
198
+ return (Address.top(),)
196
199
 
197
200
  @interp.impl(squin.wire.Apply)
198
201
  def apply(
@@ -201,14 +204,7 @@ class SquinWireMethodTable(interp.MethodTable):
201
204
  frame: ForwardFrame[Address],
202
205
  stmt: squin.wire.Apply,
203
206
  ):
204
-
205
- origin_qubits = tuple(
206
- [frame.get(input_elem).origin_qubit for input_elem in stmt.inputs]
207
- )
208
- new_address_wires = tuple(
209
- [AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits]
210
- )
211
- return new_address_wires
207
+ return frame.get_values(stmt.inputs)
212
208
 
213
209
 
214
210
  @squin.qubit.dialect.register(key="qubit.address")
@@ -81,5 +81,5 @@ class AddressWire(Address):
81
81
 
def is_subseteq(self, other: Address) -> bool:
    # A wire address is only ever contained in another wire address, and only
    # when both wires originate from the same qubit.
    return isinstance(other, AddressWire) and self.origin_qubit == other.origin_qubit
@@ -0,0 +1 @@
1
+ from .analysis import FidelityAnalysis as FidelityAnalysis
@@ -0,0 +1,69 @@
1
+ from typing import Any
2
+ from dataclasses import field
3
+
4
+ from kirin import ir
5
+ from kirin.lattice import EmptyLattice
6
+ from kirin.analysis import Forward
7
+ from kirin.interp.value import Successor
8
+ from kirin.analysis.forward import ForwardFrame
9
+
10
+ from ..address import AddressAnalysis
11
+
12
+
13
class FidelityAnalysis(Forward):
    """
    This analysis pass estimates the fidelity of a program from the noise
    channels it contains.

    Noise statements are handled by method tables registered under the
    ``circuit.fidelity`` key (e.g. in ``bloqade.noise.fidelity``), which lower
    the running gate fidelity and per-atom survival probabilities tracked on
    this object. An address analysis is run first so that qubit SSA values can
    be mapped to register indices.
    """

    keys = ["circuit.fidelity"]
    lattice = EmptyLattice

    # The fidelity of the gate set described by the analysed program. It
    # reduces whenever a noise channel is encountered.
    # NOTE(review): unlike `atom_survival_probability` (re-created in
    # `_run_address_analysis`), this value is not reset between calls to
    # `run_analysis`, so re-running the same instance compounds it — confirm.
    gate_fidelity: float = 1.0

    # Fidelity accumulated by the noise rules; folded into `gate_fidelity` by
    # `posthook_succ`.
    _current_gate_fidelity: float = field(init=False)

    # The probabilities that each of the atoms in the register survive the
    # duration of the analysed program. The order of the list follows the
    # order they are in the register.
    atom_survival_probability: list[float] = field(init=False)

    # Per-atom survival accumulated by the noise rules; folded into
    # `atom_survival_probability` by `posthook_succ`.
    _current_atom_survival_probability: list[float] = field(init=False)

    # Frame of the address analysis run by `_run_address_analysis`; the noise
    # rules read qubit addresses from it.
    addr_frame: ForwardFrame = field(init=False)

    def initialize(self):
        """Reset the per-run accumulators.

        Relies on `atom_survival_probability` already having its final length,
        which `run_analysis` guarantees by running the address analysis first
        (presumably `Forward.run_analysis` calls this method — confirm).
        """
        super().initialize()
        self._current_gate_fidelity = 1.0
        self._current_atom_survival_probability = [
            1.0 for _ in range(len(self.atom_survival_probability))
        ]
        return self

    def posthook_succ(self, frame: ForwardFrame, succ: Successor):
        """Fold the accumulated fidelity and survival factors into the totals."""
        # NOTE(review): the `_current_*` accumulators are not reset here, so if
        # this hook runs for several successors, earlier noise is multiplied in
        # again — confirm this is the intended semantics.
        self.gate_fidelity *= self._current_gate_fidelity
        for i, _current_survival in enumerate(self._current_atom_survival_probability):
            self.atom_survival_probability[i] *= _current_survival

    def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
        # NOTE: default is to conserve fidelity, so do nothing here
        return

    def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
        # The method's own `self` argument is given the bottom lattice element.
        return self.run_callable(method.code, (self.lattice.bottom(),) + args)

    def run_analysis(
        self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
    ) -> tuple[ForwardFrame, Any]:
        """Run the address analysis, then the fidelity analysis, on `method`."""
        self._run_address_analysis(method, no_raise=no_raise)
        return super().run_analysis(method, args, no_raise=no_raise)

    def _run_address_analysis(self, method: ir.Method, no_raise: bool):
        """Store the address-analysis frame and size the survival list."""
        addr_analysis = AddressAnalysis(self.dialects)
        addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise)
        self.addr_frame = addr_frame

        # NOTE: make sure we have as many probabilities as we have addresses
        self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
bloqade/device.py ADDED
@@ -0,0 +1,130 @@
1
+ import abc
2
+ from typing import Any, Generic, TypeVar, ParamSpec
3
+
4
+ from kirin import ir
5
+
6
+ from bloqade.task import (
7
+ BatchFuture,
8
+ AbstractTask,
9
+ AbstractRemoteTask,
10
+ AbstractSimulatorTask,
11
+ DeviceTaskExpectMixin,
12
+ )
13
+
14
Params = ParamSpec("Params")
RetType = TypeVar("RetType")
ObsType = TypeVar("ObsType")


TaskType = TypeVar("TaskType", bound=AbstractTask)


class AbstractDevice(abc.ABC, Generic[TaskType]):
    """Abstract base class for devices. Defines the minimum interface for devices.

    Subclasses are parameterized by the task type they produce and must
    implement :meth:`task`.
    """

    @abc.abstractmethod
    def task(
        self,
        kernel: ir.Method[Params, RetType],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ) -> TaskType:
        """Creates a task that executes ``kernel`` on this device.

        Both remote and simulator devices implement this method, so the
        returned task is not necessarily a remote task.

        Args:
            kernel (ir.Method):
                The kernel method to execute.
            args (tuple[Any, ...]):
                Positional arguments to pass to the kernel method.
            kwargs (dict[str, Any] | None):
                Keyword arguments to pass to the kernel method; may be None.

        Returns:
            TaskType:
                A task of the device's task type.
        """
33
+
34
+
35
ExpectTaskType = TypeVar("ExpectTaskType", bound=DeviceTaskExpectMixin)


class ExpectationDeviceMixin(AbstractDevice[ExpectTaskType]):
    """Device mixin adding an ``expect`` shortcut for tasks that support expectations."""

    def expect(
        self,
        kernel: ir.Method[Params, RetType],
        observable: ir.Method[[RetType], ObsType],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
        *,
        shots: int = 1,
    ) -> ObsType:
        """Returns the expectation value of the given observable after running the task.

        Args:
            kernel: The kernel method to run.
            observable: Method applied to the kernel's result.
            args: Positional arguments to pass to the kernel method.
            kwargs: Keyword arguments to pass to the kernel method.
            shots: Number of shots, forwarded to the task's ``expect``.
        """
        device_task = self.task(kernel, args, kwargs)
        return device_task.expect(observable, shots)
50
+
51
+
52
RemoteTaskType = TypeVar("RemoteTaskType", bound=AbstractRemoteTask)


class AbstractRemoteDevice(AbstractDevice[RemoteTaskType]):
    """Abstract base class for remote devices.

    Provides a waiting entry point (:meth:`run`) and a future-returning entry
    point (:meth:`run_async`) on top of :meth:`AbstractDevice.task`.
    """

    def run(
        self,
        kernel: ir.Method[Params, RetType],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
        *,
        shots: int = 1,
        timeout: float | None = None,
    ) -> list[RetType]:
        """Runs the kernel and returns the result.

        Args:
            kernel (ir.Method):
                The kernel method to run.
            args (tuple[Any, ...]):
                Positional arguments to pass to the kernel method.
            kwargs (dict[str, Any] | None):
                Keyword arguments to pass to the kernel method.
            shots (int):
                The number of times to run the kernel method.
            timeout (float | None):
                Timeout in seconds for the asynchronous execution. If None, wait indefinitely.

        Returns:
            list[RetType]:
                The results of the kernel method (presumably one per shot).
        """
        return self.task(kernel, args, kwargs).run(shots=shots, timeout=timeout)

    def run_async(
        self,
        kernel: ir.Method[Params, RetType],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
        *,
        shots: int = 1,
    ) -> BatchFuture[RetType]:
        """Runs the kernel asynchronously and returns a Future object.

        Args:
            kernel (ir.Method):
                The kernel method to run.
            args (tuple[Any, ...]):
                Positional arguments to pass to the kernel method.
            kwargs (dict[str, Any] | None):
                Keyword arguments to pass to the kernel method.
            shots (int):
                The number of times to run the kernel method.

        Returns:
            BatchFuture[RetType]:
                The future for all executions of the kernel method.
        """
        return self.task(kernel, args, kwargs).run_async(shots=shots)
115
+
116
+
117
SimulatorTaskType = TypeVar("SimulatorTaskType", bound=AbstractSimulatorTask)


class AbstractSimulatorDevice(AbstractDevice[SimulatorTaskType]):
    """Abstract base class for simulator devices."""

    def run(
        self,
        kernel: ir.Method[Params, RetType],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ) -> RetType:
        """Runs the kernel and returns the result.

        Args:
            kernel: The kernel method to run.
            args: Positional arguments to pass to the kernel method.
            kwargs: Keyword arguments to pass to the kernel method.
        """
        simulator_task = self.task(kernel, args, kwargs)
        return simulator_task.run()
bloqade/noise/__init__.py CHANGED
@@ -1 +1,2 @@
1
- from . import native as native
1
+ # NOTE: just to register methods
2
+ from . import native as native, fidelity as fidelity
@@ -0,0 +1,51 @@
1
+ from kirin import interp
2
+ from kirin.lattice import EmptyLattice
3
+
4
+ from bloqade.analysis.fidelity import FidelityAnalysis
5
+
6
+ from .native import dialect as native
7
+ from .native.stmts import PauliChannel, CZPauliChannel, AtomLossChannel
8
+ from ..analysis.address import AddressQubit, AddressTuple
9
+
10
+
11
@native.register(key="circuit.fidelity")
class FidelityMethodTable(interp.MethodTable):
    """Fidelity-analysis rules for the native noise statements.

    The rules only update the running accumulators of the
    :class:`FidelityAnalysis` instance; they produce no lattice values.
    """

    @interp.impl(PauliChannel)
    @interp.impl(CZPauliChannel)
    def pauli_channel(
        self,
        interp: FidelityAnalysis,
        frame: interp.Frame[EmptyLattice],
        stmt: PauliChannel | CZPauliChannel,
    ):
        """Reduce the gate fidelity by the probability that a Pauli error occurs.

        ``stmt.probabilities`` holds one group of Pauli probabilities per set of
        qubits the channel acts on (one group for ``PauliChannel``; ctrl and
        qarg groups for ``CZPauliChannel``). The fidelity is the product over
        groups of the probability that no error occurs in that group.
        """
        fid = 1.0
        for ps in stmt.probabilities:
            # NOTE: fidelity is just the inverse probability of any noise to occur
            fid *= 1 - sum(ps)

        interp._current_gate_fidelity *= fid

    @interp.impl(AtomLossChannel)
    def atom_loss(
        self,
        interp: FidelityAnalysis,
        frame: interp.Frame[EmptyLattice],
        stmt: AtomLossChannel,
    ):
        """Reduce the survival probability of every atom the channel acts on."""
        # NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
        # NOTE(review): if the address analysis could not resolve `qargs`, the
        # value here may not be an AddressTuple and `.data` will fail — confirm.
        addresses: AddressTuple = interp.addr_frame.get(stmt.qargs)

        survival = 1 - stmt.prob
        # NOTE: get the corresponding index and reduce survival probability accordingly
        for qbit_address in addresses.data:
            assert isinstance(qbit_address, AddressQubit)
            interp._current_atom_survival_probability[qbit_address.data] *= survival
@@ -102,10 +102,9 @@ class MoveNoiseModelABC(abc.ABC):
102
102
  params: MoveNoiseParams = field(default_factory=MoveNoiseParams)
103
103
  """Parameters for calculating move noise."""
104
104
 
105
@abc.abstractmethod
def parallel_cz_errors(
    self, ctrls: List[int], qargs: List[int], rest: List[int]
) -> Dict[Tuple[float, float, float, float], List[int]]:
    """Compute the noise model for a layer of parallel CZ gates.

    Args:
        ctrls: Indices of the control qubits of the CZ gates.
        qargs: Indices of the target qubits of the CZ gates.
        rest: Indices of the remaining qubits (presumably those not taking
            part in any gate — confirm against implementations).

    Returns:
        A mapping from a tuple of four error probabilities to the list of
        qubit indices that receive that error; the meaning of each entry is
        defined by the concrete implementations.
    """
    ...
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
 
3
3
  from kirin import ir
4
- from kirin.rewrite import abc, dce, walk, result, fixpoint
4
+ from kirin.rewrite import abc, dce, walk, fixpoint
5
5
  from kirin.passes.abc import Pass
6
6
 
7
7
  from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
@@ -9,19 +9,19 @@ from ._dialect import dialect
9
9
 
10
10
 
class RemoveNoiseRewrite(abc.RewriteRule):
    """Rewrite rule that deletes every native noise-channel statement."""

    def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
        is_noise = isinstance(node, (AtomLossChannel, PauliChannel, CZPauliChannel))
        if not is_noise:
            return abc.RewriteResult()

        node.delete()
        return abc.RewriteResult(has_done_something=True)
18
18
 
19
19
 
20
20
  @dataclass
21
21
  class RemoveNoisePass(Pass):
22
22
  name = "remove-noise"
23
23
 
24
- def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
24
+ def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
25
25
  delete_walk = walk.Walk(RemoveNoiseRewrite())
26
26
  dce_walk = fixpoint.Fixpoint(walk.Walk(dce.DeadCodeElimination()))
27
27
 
@@ -1,3 +1,5 @@
1
+ from typing import Tuple
2
+
1
3
  from kirin import ir, types, lowering
2
4
  from kirin.decl import info, statement
3
5
  from kirin.dialects import ilist
@@ -7,25 +9,44 @@ from bloqade.qasm2.types import QubitType
7
9
  from ._dialect import dialect
8
10
 
9
11
 
@statement
class NativeNoiseStmt(ir.Statement):
    """Common base class of the native noise-channel statements.

    Subclasses expose their parameters through :attr:`probabilities`, and
    :meth:`check` validates them.
    """

    traits = frozenset({lowering.FromPythonCall()})

    @property
    def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
        """Groups of probabilities parameterizing the channel; override in subclasses."""
        raise NotImplementedError(f"Override the method in {type(self).__name__}")

    def check(self):
        # Every group's total, followed by each of its entries, must be a valid
        # probability.
        for group in self.probabilities:
            for value in (sum(group), *group):
                self.check_probability(value)

    def check_probability(self, p: float):
        # Written as a chained comparison so that NaN is rejected too.
        if 0 <= p <= 1:
            return
        raise ValueError(
            f"Invalid noise probability encountered in {type(self).__name__}: {p}"
        )
31
+
32
+
@statement(dialect=dialect)
class PauliChannel(NativeNoiseStmt):
    """Pauli noise channel parameterized by ``px``, ``py`` and ``pz``, acting on ``qargs``."""

    px: float = info.attribute(types.Float)
    py: float = info.attribute(types.Float)
    pz: float = info.attribute(types.Float)
    qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])

    @property
    def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
        """A single group holding ``(px, py, pz)``."""
        pauli_group = (self.px, self.py, self.pz)
        return (pauli_group,)


NumQubits = types.TypeVar("NumQubits")
22
46
 
23
47
 
24
48
  @statement(dialect=dialect)
25
- class CZPauliChannel(ir.Statement):
26
-
27
- traits = frozenset({lowering.FromPythonCall()})
28
-
49
+ class CZPauliChannel(NativeNoiseStmt):
29
50
  paired: bool = info.attribute(types.Bool)
30
51
  px_ctrl: float = info.attribute(types.Float)
31
52
  py_ctrl: float = info.attribute(types.Float)
@@ -36,11 +57,19 @@ class CZPauliChannel(ir.Statement):
36
57
  ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits])
37
58
  qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits])
38
59
 
60
+ @property
61
+ def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
62
+ return (
63
+ (self.px_ctrl, self.py_ctrl, self.pz_ctrl),
64
+ (self.px_qarg, self.py_qarg, self.pz_qarg),
65
+ )
39
66
 
@statement(dialect=dialect)
class AtomLossChannel(NativeNoiseStmt):
    """Atom-loss channel with loss probability ``prob``, acting on ``qargs``."""

    prob: float = info.attribute(types.Float)
    qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])

    @property
    def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
        """A single group holding only the loss probability."""
        loss_group = (self.prob,)
        return (loss_group,)
@@ -1,9 +1,9 @@
1
1
  from .reg import (
2
2
  CBitRef as CBitRef,
3
3
  CRegister as CRegister,
4
- PyQrackReg as PyQrackReg,
5
4
  QubitState as QubitState,
6
5
  Measurement as Measurement,
6
+ PyQrackWire as PyQrackWire,
7
7
  PyQrackQubit as PyQrackQubit,
8
8
  )
9
9
  from .base import (
@@ -11,8 +11,14 @@ from .base import (
11
11
  DynamicMemory as DynamicMemory,
12
12
  PyQrackInterpreter as PyQrackInterpreter,
13
13
  )
14
+ from .task import PyQrackSimulatorTask as PyQrackSimulatorTask
14
15
 
15
16
  # NOTE: The following import is for registering the method tables
16
17
  from .noise import native as native
17
- from .qasm2 import uop as uop, core as core, parallel as parallel
18
+ from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
19
+ from .squin import op as op, qubit as qubit
20
+ from .device import (
21
+ StackMemorySimulator as StackMemorySimulator,
22
+ DynamicMemorySimulator as DynamicMemorySimulator,
23
+ )
18
24
  from .target import PyQrack as PyQrack
bloqade/pyqrack/base.py CHANGED
@@ -26,13 +26,28 @@ class PyQrackOptions(typing.TypedDict):
26
26
  isOpenCL: bool
27
27
 
28
28
 
29
def _validate_pyqrack_options(options: PyQrackOptions) -> None:
    """Reject mutually exclusive PyQrack simulator options.

    Raises:
        ValueError: if two incompatible options are both enabled; only the
            first conflicting pair (in the order listed below) is reported.
    """
    conflicts = (
        (
            "isBinaryDecisionTree",
            "isStabilizerHybrid",
            "Cannot use both isBinaryDecisionTree and isStabilizerHybrid at the same time.",
        ),
        (
            "isTensorNetwork",
            "isBinaryDecisionTree",
            "Cannot use both isTensorNetwork and isBinaryDecisionTree at the same time.",
        ),
        (
            "isTensorNetwork",
            "isStabilizerHybrid",
            "Cannot use both isTensorNetwork and isStabilizerHybrid at the same time.",
        ),
    )
    for first_key, second_key, message in conflicts:
        if options[first_key] and options[second_key]:
            raise ValueError(message)
42
+
43
+
29
44
  def _default_pyqrack_args() -> PyQrackOptions:
30
45
  return PyQrackOptions(
31
46
  qubitCount=-1,
32
47
  isTensorNetwork=False,
33
48
  isSchmidtDecomposeMulti=True,
34
49
  isSchmidtDecompose=True,
35
- isStabilizerHybrid=True,
50
+ isStabilizerHybrid=False,
36
51
  isBinaryDecisionTree=True,
37
52
  isPaged=True,
38
53
  isCpuGpuHybrid=True,
@@ -45,6 +60,9 @@ class MemoryABC(abc.ABC):
45
60
  pyqrack_options: PyQrackOptions = field(default_factory=_default_pyqrack_args)
46
61
  sim_reg: "QrackSimulator" = field(init=False)
47
62
 
63
def __post_init__(self):
    """Validate the PyQrack options as soon as the memory object is built."""
    configured_options = self.pyqrack_options
    _validate_pyqrack_options(configured_options)
65
+
48
66
  @abc.abstractmethod
49
67
  def allocate(self, n_qubits: int) -> tuple[int, ...]:
50
68
  """Allocate `n_qubits` qubits and return their ids."""
@@ -115,10 +133,13 @@ class DynamicMemory(MemoryABC):
115
133
  return tuple(range(start, start + n_qubits))
116
134
 
117
135
 
136
# Type variable for the memory backend (any MemoryABC subclass) that a
# PyQrackInterpreter is parameterized by.
MemoryType = typing.TypeVar("MemoryType", bound=MemoryABC)
137
+
138
+
118
139
  @dataclass
119
- class PyQrackInterpreter(Interpreter):
140
+ class PyQrackInterpreter(Interpreter, typing.Generic[MemoryType]):
120
141
  keys = ["pyqrack", "main"]
121
- memory: MemoryABC = field(kw_only=True)
142
+ memory: MemoryType = field(kw_only=True)
122
143
  rng_state: np.random.Generator = field(
123
144
  default_factory=np.random.default_rng, kw_only=True
124
145
  )
@@ -0,0 +1,166 @@
1
+ from typing import Any, TypeVar, ParamSpec
2
+ from dataclasses import field, dataclass
3
+
4
+ import numpy as np
5
+ from kirin import ir
6
+
7
+ from pyqrack.pauli import Pauli
8
+ from bloqade.device import AbstractSimulatorDevice
9
+ from bloqade.pyqrack.reg import Measurement, PyQrackQubit
10
+ from bloqade.pyqrack.base import (
11
+ MemoryABC,
12
+ StackMemory,
13
+ DynamicMemory,
14
+ PyQrackOptions,
15
+ PyQrackInterpreter,
16
+ _default_pyqrack_args,
17
+ )
18
+ from bloqade.pyqrack.task import PyQrackSimulatorTask
19
+ from bloqade.analysis.address.lattice import AnyAddress
20
+ from bloqade.analysis.address.analysis import AddressAnalysis
21
+
22
RetType = TypeVar("RetType")
Params = ParamSpec("Params")


@dataclass
class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
    """Shared configuration and helpers for the PyQrack simulator devices.

    Attributes:
        options: PyQrack simulator options; keys that are not given are filled
            in from the defaults in ``__post_init__``.
        loss_m_result: Measurement value forwarded to the interpreter
            (presumably the result reported for lost atoms — confirm).
        rng_state: Random number generator forwarded to the interpreter.
    """

    options: PyQrackOptions = field(default_factory=_default_pyqrack_args)
    loss_m_result: Measurement = field(default=Measurement.One, kw_only=True)
    rng_state: np.random.Generator = field(
        default_factory=np.random.default_rng, kw_only=True
    )

    MemoryType = TypeVar("MemoryType", bound=MemoryABC)

    def __post_init__(self):
        # Overlay user-supplied options on the defaults so partial dicts work.
        self.options = PyQrackOptions({**_default_pyqrack_args(), **self.options})

    def new_task(
        self,
        mt: ir.Method[Params, RetType],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        memory: MemoryType,
    ) -> PyQrackSimulatorTask[Params, RetType, MemoryType]:
        """Build a task that interprets ``mt`` with a fresh PyQrack interpreter over ``memory``."""
        interp = PyQrackInterpreter(
            mt.dialects,
            memory=memory,
            rng_state=self.rng_state,
            loss_m_result=self.loss_m_result,
        )
        return PyQrackSimulatorTask(
            kernel=mt, args=args, kwargs=kwargs, pyqrack_interp=interp
        )

    def state_vector(
        self,
        kernel: ir.Method[Params, RetType],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[complex]:
        """Runs task and returns the state vector."""
        task = self.task(kernel, args, kwargs)
        task.run()
        return task.state.sim_reg.out_ket()

    @staticmethod
    def pauli_expectation(pauli: list[Pauli], qubits: list[PyQrackQubit]) -> float:
        """Returns the expectation value of the given Pauli operator given a list of Pauli operators and qubits.

        Args:
            pauli (list[Pauli]):
                List of Pauli operators to compute the expectation value for.
            qubits (list[PyQrackQubit]):
                List of qubits corresponding to the Pauli operators.

        Returns:
            float:
                The expectation value of the Pauli operator.

        Raises:
            ValueError: if the lengths differ, the qubits belong to different
                simulator registers, or a qubit appears more than once.
        """
        # NOTE(review): an empty Pauli string is the identity, whose
        # expectation value is 1.0; 0.0 is returned here — confirm intended.
        if len(pauli) == 0:
            return 0.0

        if len(pauli) != len(qubits):
            raise ValueError("Length of Pauli and qubits must match.")

        sim_reg = qubits[0].sim_reg

        if any(qubit.sim_reg is not sim_reg for qubit in qubits):
            raise ValueError("All qubits must belong to the same simulator register.")

        qubit_ids = [qubit.addr for qubit in qubits]

        if len(qubit_ids) != len(set(qubit_ids)):
            raise ValueError("Qubits must be unique.")

        # QrackSimulator.pauli_expectation(q, b) takes the qubit ids first and
        # the Pauli bases second; the reverse order would treat Pauli codes as
        # qubit ids.
        return sim_reg.pauli_expectation(qubit_ids, pauli)
100
+
101
+
102
@dataclass
class StackMemorySimulator(PyQrackSimulatorBase):
    """PyQrack simulator device with precalculated stack of qubits.

    The qubit count is taken from an address analysis of the kernel and is
    raised to at least ``min_qubits``.
    """

    min_qubits: int = field(default=0, kw_only=True)

    def task(
        self,
        kernel: ir.Method[Params, RetType],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ):
        """Create a simulator task backed by a fixed-size ``StackMemory``.

        Raises:
            ValueError: if some address cannot be resolved statically while
                ``min_qubits`` is 0.
        """
        kwargs = {} if kwargs is None else kwargs

        analysis = AddressAnalysis(dialects=kernel.dialects)
        analysis_frame, _ = analysis.run_analysis(kernel)
        # Without a user-supplied lower bound, every address must be resolved
        # (presumably so the stack is not allocated too small).
        if self.min_qubits == 0 and any(
            isinstance(entry, AnyAddress) for entry in analysis_frame.entries.values()
        ):
            raise ValueError(
                "All addresses must be resolved. Or set min_qubits to a positive integer."
            )

        qubit_total = max(analysis.qubit_count, self.min_qubits)
        sim_options = self.options.copy()
        sim_options["qubitCount"] = qubit_total
        return self.new_task(
            kernel, args, kwargs, StackMemory(sim_options, total=qubit_total)
        )
135
+
136
+
137
@dataclass
class DynamicMemorySimulator(PyQrackSimulatorBase):
    """PyQrack simulator device with dynamic qubit allocation."""

    def task(
        self,
        kernel: ir.Method[Params, RetType],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ):
        """Create a simulator task backed by a ``DynamicMemory`` built from a copy of the options."""
        return self.new_task(
            kernel,
            args,
            {} if kwargs is None else kwargs,
            DynamicMemory(self.options.copy()),
        )
152
+
153
+
154
def test():
    """Smoke test: run a trivial ``extended`` kernel on a DynamicMemorySimulator.

    Returns:
        Whatever the task's ``run()`` returns for the kernel (expected ``1``).
    """
    # NOTE(review): this looks like leftover debug code inside a library
    # module; consider moving it into the test suite.
    from bloqade.qasm2 import extended

    @extended
    def main():
        return 1

    res = DynamicMemorySimulator().task(main)
    return res.run()
@@ -93,8 +93,7 @@ class PyQrackMethods(interp.MethodTable):
93
93
 
94
94
  for qarg in active_qubits:
95
95
  if interp.rng_state.uniform() <= stmt.prob:
96
- sim_reg = qarg.ref.sim_reg
97
- sim_reg.force_m(qarg.addr, 0)
96
+ qarg.sim_reg.m(qarg.addr)
98
97
  qarg.drop()
99
98
 
100
99
  return ()