bloqade-circuit 0.7.12__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (136):
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +119 -29
  3. bloqade/analysis/address/impls.py +290 -87
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +2 -2
  6. bloqade/analysis/measure_id/impls.py +3 -27
  7. bloqade/cirq_utils/__init__.py +3 -1
  8. bloqade/cirq_utils/emit/__init__.py +3 -0
  9. bloqade/cirq_utils/emit/base.py +243 -0
  10. bloqade/cirq_utils/emit/gate.py +104 -0
  11. bloqade/cirq_utils/emit/noise.py +90 -0
  12. bloqade/cirq_utils/emit/qubit.py +35 -0
  13. bloqade/cirq_utils/lowering.py +664 -0
  14. bloqade/native/__init__.py +0 -1
  15. bloqade/native/_prelude.py +3 -3
  16. bloqade/native/dialects/gate/__init__.py +2 -0
  17. bloqade/native/dialects/gate/_dialect.py +3 -0
  18. bloqade/native/dialects/{gates → gate}/_interface.py +5 -5
  19. bloqade/native/dialects/{gates → gate}/stmts.py +5 -5
  20. bloqade/native/stdlib/broadcast.py +19 -19
  21. bloqade/native/stdlib/simple.py +14 -13
  22. bloqade/native/upstream/__init__.py +5 -0
  23. bloqade/native/upstream/squin2native.py +136 -0
  24. bloqade/pyqrack/__init__.py +1 -2
  25. bloqade/pyqrack/device.py +6 -17
  26. bloqade/pyqrack/native.py +17 -17
  27. bloqade/pyqrack/reg.py +1 -6
  28. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  29. bloqade/pyqrack/squin/gate/gate.py +136 -0
  30. bloqade/pyqrack/squin/noise/native.py +120 -54
  31. bloqade/pyqrack/squin/qubit.py +25 -41
  32. bloqade/pyqrack/target.py +2 -2
  33. bloqade/qasm2/dialects/core/address.py +21 -12
  34. bloqade/qasm2/dialects/noise/fidelity.py +2 -6
  35. bloqade/qasm2/dialects/noise/model.py +2 -1
  36. bloqade/qasm2/passes/parallel.py +3 -1
  37. bloqade/qasm2/rewrite/__init__.py +0 -1
  38. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  39. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  40. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  41. bloqade/qubit/__init__.py +12 -0
  42. bloqade/qubit/_dialect.py +3 -0
  43. bloqade/qubit/_interface.py +49 -0
  44. bloqade/qubit/_prelude.py +45 -0
  45. bloqade/qubit/analysis/__init__.py +1 -0
  46. bloqade/qubit/analysis/address_impl.py +40 -0
  47. bloqade/qubit/stdlib/__init__.py +2 -0
  48. bloqade/qubit/stdlib/_new.py +34 -0
  49. bloqade/qubit/stdlib/broadcast.py +62 -0
  50. bloqade/qubit/stdlib/simple.py +59 -0
  51. bloqade/qubit/stmts.py +60 -0
  52. bloqade/rewrite/passes/aggressive_unroll.py +2 -1
  53. bloqade/squin/__init__.py +44 -17
  54. bloqade/squin/analysis/__init__.py +0 -1
  55. bloqade/squin/analysis/schedule.py +2 -2
  56. bloqade/squin/gate/__init__.py +2 -0
  57. bloqade/squin/gate/_dialect.py +3 -0
  58. bloqade/squin/gate/_interface.py +98 -0
  59. bloqade/squin/gate/stmts.py +119 -0
  60. bloqade/squin/groups.py +4 -21
  61. bloqade/squin/noise/__init__.py +1 -9
  62. bloqade/squin/noise/_dialect.py +1 -1
  63. bloqade/squin/noise/_interface.py +45 -0
  64. bloqade/squin/noise/stmts.py +65 -29
  65. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  66. bloqade/squin/rewrite/__init__.py +0 -2
  67. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  68. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  69. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  70. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  71. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  72. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  73. bloqade/squin/stdlib/simple/__init__.py +33 -0
  74. bloqade/squin/stdlib/simple/gate.py +242 -0
  75. bloqade/squin/stdlib/simple/noise.py +126 -0
  76. bloqade/stim/__init__.py +1 -0
  77. bloqade/stim/_wrappers.py +6 -0
  78. bloqade/stim/dialects/noise/emit.py +6 -1
  79. bloqade/stim/dialects/noise/stmts.py +5 -3
  80. bloqade/stim/emit/stim_str.py +2 -0
  81. bloqade/stim/parse/lowering.py +12 -17
  82. bloqade/stim/passes/__init__.py +0 -1
  83. bloqade/stim/passes/flatten.py +26 -0
  84. bloqade/stim/passes/simplify_ifs.py +6 -1
  85. bloqade/stim/passes/squin_to_stim.py +4 -70
  86. bloqade/stim/rewrite/__init__.py +0 -4
  87. bloqade/stim/rewrite/ifs_to_stim.py +23 -29
  88. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  89. bloqade/stim/rewrite/squin_measure.py +9 -18
  90. bloqade/stim/rewrite/squin_noise.py +132 -108
  91. bloqade/stim/rewrite/util.py +5 -204
  92. bloqade/types.py +10 -0
  93. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/METADATA +2 -2
  94. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/RECORD +96 -100
  95. bloqade/native/dialects/gates/__init__.py +0 -3
  96. bloqade/native/dialects/gates/_dialect.py +0 -3
  97. bloqade/pyqrack/squin/op.py +0 -180
  98. bloqade/pyqrack/squin/runtime.py +0 -543
  99. bloqade/pyqrack/squin/wire.py +0 -51
  100. bloqade/squin/_typeinfer.py +0 -20
  101. bloqade/squin/analysis/address_impl.py +0 -71
  102. bloqade/squin/analysis/nsites/__init__.py +0 -9
  103. bloqade/squin/analysis/nsites/analysis.py +0 -50
  104. bloqade/squin/analysis/nsites/impls.py +0 -99
  105. bloqade/squin/analysis/nsites/lattice.py +0 -49
  106. bloqade/squin/cirq/__init__.py +0 -306
  107. bloqade/squin/cirq/emit/emit_circuit.py +0 -129
  108. bloqade/squin/cirq/emit/noise.py +0 -49
  109. bloqade/squin/cirq/emit/op.py +0 -176
  110. bloqade/squin/cirq/emit/qubit.py +0 -58
  111. bloqade/squin/cirq/emit/runtime.py +0 -242
  112. bloqade/squin/cirq/lowering.py +0 -439
  113. bloqade/squin/lowering.py +0 -80
  114. bloqade/squin/noise/_wrapper.py +0 -36
  115. bloqade/squin/noise/rewrite.py +0 -129
  116. bloqade/squin/op/__init__.py +0 -41
  117. bloqade/squin/op/_dialect.py +0 -3
  118. bloqade/squin/op/_wrapper.py +0 -121
  119. bloqade/squin/op/number.py +0 -5
  120. bloqade/squin/op/rewrite.py +0 -46
  121. bloqade/squin/op/stdlib.py +0 -62
  122. bloqade/squin/op/stmts.py +0 -300
  123. bloqade/squin/op/traits.py +0 -43
  124. bloqade/squin/op/types.py +0 -128
  125. bloqade/squin/parallel.py +0 -200
  126. bloqade/squin/qubit.py +0 -194
  127. bloqade/squin/rewrite/canonicalize.py +0 -60
  128. bloqade/squin/rewrite/desugar.py +0 -102
  129. bloqade/squin/stdlib/channel.py +0 -86
  130. bloqade/squin/stdlib/gate.py +0 -201
  131. bloqade/squin/types.py +0 -8
  132. bloqade/squin/wire.py +0 -201
  133. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  134. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  135. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/WHEEL +0 -0
  136. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/licenses/LICENSE +0 -0
bloqade/pyqrack/squin/noise/native.py CHANGED
@@ -1,72 +1,138 @@
1
- import random
2
- import typing
3
- from dataclasses import dataclass
4
-
5
1
  from kirin import interp
6
- from kirin.dialects import ilist
7
2
 
8
- from bloqade.pyqrack import QubitState, PyQrackQubit, PyQrackInterpreter
9
- from bloqade.squin.noise.stmts import QubitLoss, StochasticUnitaryChannel
3
+ from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter
4
+ from bloqade.squin.noise.stmts import (
5
+ QubitLoss,
6
+ Depolarize,
7
+ Depolarize2,
8
+ CorrelatedQubitLoss,
9
+ TwoQubitPauliChannel,
10
+ SingleQubitPauliChannel,
11
+ )
10
12
  from bloqade.squin.noise._dialect import dialect as squin_noise_dialect
11
13
 
12
- from ..runtime import OperatorRuntimeABC
13
-
14
-
15
- @dataclass(frozen=True)
16
- class StochasticUnitaryChannelRuntime(OperatorRuntimeABC):
17
- operators: ilist.IList[OperatorRuntimeABC, typing.Any]
18
- probabilities: ilist.IList[float, typing.Any]
19
-
20
- @property
21
- def n_sites(self) -> int:
22
- n = self.operators[0].n_sites
23
- for op in self.operators[1:]:
24
- assert (
25
- op.n_sites == n
26
- ), "Encountered a stochastic unitary channel with operators of different size!"
27
- return n
28
-
29
- def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
30
- # NOTE: probabilities don't necessarily sum to 1; could be no noise event should occur
31
- p_no_op = 1 - sum(self.probabilities)
32
- if random.uniform(0.0, 1.0) < p_no_op:
33
- return
34
-
35
- selected_ops = random.choices(self.operators, weights=self.probabilities)
36
- for op in selected_ops:
37
- op.apply(*qubits, adjoint=adjoint)
38
-
39
14
 
40
- @dataclass(frozen=True)
41
- class QubitLossRuntime(OperatorRuntimeABC):
42
- p: float
43
-
44
- @property
45
- def n_sites(self) -> int:
46
- return 1
15
+ @squin_noise_dialect.register(key="pyqrack")
16
+ class PyQrackMethods(interp.MethodTable):
47
17
 
48
- def apply(self, qubit: PyQrackQubit, adjoint: bool = False) -> None:
49
- if random.uniform(0.0, 1.0) <= self.p:
50
- qubit.state = QubitState.Lost
18
+ single_pauli_choices = ("i", "x", "y", "z")
19
+ two_pauli_choices = (
20
+ "ii",
21
+ "ix",
22
+ "iy",
23
+ "iz",
24
+ "xi",
25
+ "xx",
26
+ "xy",
27
+ "xz",
28
+ "yi",
29
+ "yx",
30
+ "yy",
31
+ "yz",
32
+ "zi",
33
+ "zx",
34
+ "zy",
35
+ "zz",
36
+ )
37
+
38
+ @interp.impl(Depolarize)
39
+ def depolarize(
40
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: Depolarize
41
+ ):
42
+ p = frame.get(stmt.p)
43
+ ps = [p / 3.0] * 3
44
+ qubits = frame.get(stmt.qubits)
45
+ self.apply_single_qubit_pauli_error(interp, ps, qubits)
51
46
 
47
+ @interp.impl(Depolarize2)
48
+ def depolarize2(
49
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: Depolarize2
50
+ ):
51
+ p = frame.get(stmt.p)
52
+ ps = [p / 15.0] * 15
53
+ controls = frame.get(stmt.controls)
54
+ targets = frame.get(stmt.targets)
55
+ self.apply_two_qubit_pauli_error(interp, ps, controls, targets)
52
56
 
53
- @squin_noise_dialect.register(key="pyqrack")
54
- class PyQrackMethods(interp.MethodTable):
55
- @interp.impl(StochasticUnitaryChannel)
56
- def stochastic_unitary_channel(
57
+ @interp.impl(SingleQubitPauliChannel)
58
+ def single_qubit_pauli_channel(
57
59
  self,
58
60
  interp: PyQrackInterpreter,
59
61
  frame: interp.Frame,
60
- stmt: StochasticUnitaryChannel,
62
+ stmt: SingleQubitPauliChannel,
61
63
  ):
62
- operators = frame.get(stmt.operators)
63
- probabilities = frame.get(stmt.probabilities)
64
-
65
- return (StochasticUnitaryChannelRuntime(operators, probabilities),)
64
+ px = frame.get(stmt.px)
65
+ py = frame.get(stmt.py)
66
+ pz = frame.get(stmt.pz)
67
+ qubits = frame.get(stmt.qubits)
68
+ self.apply_single_qubit_pauli_error(interp, [px, py, pz], qubits)
69
+
70
+ @interp.impl(TwoQubitPauliChannel)
71
+ def two_qubit_pauli_channel(
72
+ self,
73
+ interp: PyQrackInterpreter,
74
+ frame: interp.Frame,
75
+ stmt: TwoQubitPauliChannel,
76
+ ):
77
+ ps = frame.get(stmt.probabilities)
78
+ controls = frame.get(stmt.controls)
79
+ targets = frame.get(stmt.targets)
80
+ self.apply_two_qubit_pauli_error(interp, ps, controls, targets)
66
81
 
67
82
  @interp.impl(QubitLoss)
68
83
  def qubit_loss(
69
84
  self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: QubitLoss
70
85
  ):
71
86
  p = frame.get(stmt.p)
72
- return (QubitLossRuntime(p),)
87
+ qubits: list[PyQrackQubit] = frame.get(stmt.qubits)
88
+ for qbit in qubits:
89
+ if interp.rng_state.uniform(0.0, 1.0) <= p:
90
+ qbit.drop()
91
+
92
+ @interp.impl(CorrelatedQubitLoss)
93
+ def correlated_qubit_loss(
94
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: CorrelatedQubitLoss
95
+ ):
96
+ p = frame.get(stmt.p)
97
+ qubits: list[list[PyQrackQubit]] = frame.get(stmt.qubits)
98
+ for qubit_group in qubits:
99
+ if interp.rng_state.uniform(0.0, 1.0) <= p:
100
+ for qbit in qubit_group:
101
+ qbit.drop()
102
+
103
+ def apply_single_qubit_pauli_error(
104
+ self,
105
+ interp: PyQrackInterpreter,
106
+ ps: list[float],
107
+ qubits: list[PyQrackQubit],
108
+ ):
109
+ pi = 1 - sum(ps)
110
+ probs = [pi] + ps
111
+
112
+ assert all(0 <= x <= 1 for x in probs), "Invalid Pauli error probabilities"
113
+
114
+ for qbit in qubits:
115
+ which = interp.rng_state.choice(self.single_pauli_choices, p=probs)
116
+ self.apply_pauli_error(which, qbit)
117
+
118
+ def apply_two_qubit_pauli_error(
119
+ self,
120
+ interp: PyQrackInterpreter,
121
+ ps: list[float],
122
+ controls: list[PyQrackQubit],
123
+ targets: list[PyQrackQubit],
124
+ ):
125
+ pii = 1 - sum(ps)
126
+ probs = [pii] + ps
127
+ assert all(0 <= x <= 1 for x in probs), "Invalid Pauli error probabilities"
128
+
129
+ for control, target in zip(controls, targets):
130
+ which = interp.rng_state.choice(self.two_pauli_choices, p=probs)
131
+ self.apply_pauli_error(which[0], control)
132
+ self.apply_pauli_error(which[1], target)
133
+
134
+ def apply_pauli_error(self, which: str, qbit: PyQrackQubit):
135
+ if not qbit.is_active() or which == "i":
136
+ return
137
+
138
+ getattr(qbit.sim_reg, which)(qbit.addr)
bloqade/pyqrack/squin/qubit.py CHANGED
@@ -3,41 +3,20 @@ from typing import Any
3
3
  from kirin import interp
4
4
  from kirin.dialects import ilist
5
5
 
6
- from bloqade.squin import qubit
6
+ from bloqade.qubit import stmts as qubit
7
7
  from bloqade.pyqrack.reg import QubitState, Measurement, PyQrackQubit
8
8
  from bloqade.pyqrack.base import PyQrackInterpreter
9
9
 
10
- from .runtime import OperatorRuntimeABC
11
-
12
10
 
13
11
  @qubit.dialect.register(key="pyqrack")
14
12
  class PyQrackMethods(interp.MethodTable):
15
13
  @interp.impl(qubit.New)
16
- def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New):
17
- n_qubits: int = frame.get(stmt.n_qubits)
18
- qreg = ilist.IList(
19
- [
20
- PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
21
- for i in interp.memory.allocate(n_qubits=n_qubits)
22
- ]
23
- )
24
- return (qreg,)
25
-
26
- @interp.impl(qubit.Apply)
27
- def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply):
28
- qubits: list[PyQrackQubit] = [frame.get(qbit) for qbit in stmt.qubits]
29
- operator: OperatorRuntimeABC = frame.get(stmt.operator)
30
- operator.apply(*qubits)
31
-
32
- @interp.impl(qubit.Broadcast)
33
- def broadcast(
34
- self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Broadcast
14
+ def new_qubit(
15
+ self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New
35
16
  ):
36
- operator: OperatorRuntimeABC = frame.get(stmt.operator)
37
- qubits: list[ilist.IList[PyQrackQubit, Any]] = [
38
- frame.get(qbit) for qbit in stmt.qubits
39
- ]
40
- operator.broadcast_apply(qubits)
17
+ (addr,) = interp.memory.allocate(1)
18
+ qb = PyQrackQubit(addr, interp.memory.sim_reg, QubitState.Active)
19
+ return (qb,)
41
20
 
42
21
  def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter):
43
22
  if qbit.is_active():
@@ -48,35 +27,40 @@ class PyQrackMethods(interp.MethodTable):
48
27
  interp.set_global_measurement_id(m)
49
28
  return m
50
29
 
51
- @interp.impl(qubit.MeasureQubitList)
30
+ @interp.impl(qubit.Measure)
52
31
  def measure_qubit_list(
53
32
  self,
54
33
  interp: PyQrackInterpreter,
55
34
  frame: interp.Frame,
56
- stmt: qubit.MeasureQubitList,
35
+ stmt: qubit.Measure,
57
36
  ):
58
37
  qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
59
38
  result = ilist.IList([self._measure_qubit(qbit, interp) for qbit in qubits])
60
39
  return (result,)
61
40
 
62
- @interp.impl(qubit.MeasureQubit)
63
- def measure_qubit(
64
- self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureQubit
65
- ):
66
- qbit: PyQrackQubit = frame.get(stmt.qubit)
67
- result = self._measure_qubit(qbit, interp)
68
- return (result,)
69
-
70
41
  @interp.impl(qubit.QubitId)
71
42
  def qubit_id(
72
43
  self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.QubitId
73
44
  ):
74
- qbit: PyQrackQubit = frame.get(stmt.qubit)
75
- return (qbit.addr,)
45
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
46
+ ids = ilist.IList([qbit.addr for qbit in qubits])
47
+ return (ids,)
76
48
 
77
49
  @interp.impl(qubit.MeasurementId)
78
50
  def measurement_id(
79
51
  self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasurementId
80
52
  ):
81
- measurement: Measurement = frame.get(stmt.measurement)
82
- return (measurement.measurement_id,)
53
+ measurements: ilist.IList[Measurement, Any] = frame.get(stmt.measurements)
54
+ ids = ilist.IList([measurement.measurement_id for measurement in measurements])
55
+ return (ids,)
56
+
57
+ @interp.impl(qubit.Reset)
58
+ def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset):
59
+ qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
60
+ for qbit in qubits:
61
+ if not qbit.is_active():
62
+ continue
63
+
64
+ m = qbit.sim_reg.m(qbit.addr)
65
+ if m == Measurement.One:
66
+ qbit.sim_reg.x(qbit.addr)
bloqade/pyqrack/target.py CHANGED
@@ -12,7 +12,7 @@ from bloqade.pyqrack.base import (
12
12
  PyQrackInterpreter,
13
13
  _default_pyqrack_args,
14
14
  )
15
- from bloqade.analysis.address import AnyAddress, AddressAnalysis
15
+ from bloqade.analysis.address import UnknownQubit, AddressAnalysis
16
16
 
17
17
  Params = ParamSpec("Params")
18
18
  RetType = TypeVar("RetType")
@@ -53,7 +53,7 @@ class PyQrack:
53
53
  address_analysis = AddressAnalysis(mt.dialects)
54
54
  frame, _ = address_analysis.run_analysis(mt)
55
55
  if self.min_qubits == 0 and any(
56
- isinstance(a, AnyAddress) for a in frame.entries.values()
56
+ isinstance(a, UnknownQubit) for a in frame.entries.values()
57
57
  ):
58
58
  raise ValueError(
59
59
  "All addresses must be resolved. Or set min_qubits to a positive integer."
bloqade/qasm2/dialects/core/address.py CHANGED
@@ -1,12 +1,14 @@
1
1
  from kirin import interp
2
+ from kirin.analysis import const
2
3
 
3
4
  from bloqade.analysis.address import (
4
5
  Address,
5
- NotQubit,
6
6
  AddressReg,
7
- AddressQubit,
7
+ ConstResult,
8
+ UnknownQubit,
8
9
  AddressAnalysis,
9
10
  )
11
+ from bloqade.analysis.address.lattice import UnknownReg
10
12
 
11
13
  from .stmts import QRegGet, QRegNew
12
14
  from ._dialect import dialect
@@ -22,17 +24,24 @@ class AddressMethodTable(interp.MethodTable):
22
24
  frame: interp.Frame[Address],
23
25
  stmt: QRegNew,
24
26
  ):
25
- n_qubits = interp.get_const_value(int, stmt.n_qubits)
26
- addr = AddressReg(range(interp.next_address, interp.next_address + n_qubits))
27
- interp.next_address += n_qubits
28
- return (addr,)
27
+ n_qubits = frame.get(stmt.n_qubits)
28
+ match n_qubits:
29
+ case ConstResult(const.Value(int() as n)):
30
+ addr = AddressReg(range(interp.next_address, interp.next_address + n))
31
+ interp.next_address += n
32
+ return (addr,)
33
+ case _:
34
+ return (UnknownReg(),)
29
35
 
30
36
  @interp.impl(QRegGet)
31
37
  def get(self, interp: AddressAnalysis, frame: interp.Frame[Address], stmt: QRegGet):
32
38
  addr = frame.get(stmt.reg)
33
- pos = interp.get_const_value(int, stmt.idx)
34
- if isinstance(addr, AddressReg):
35
- global_idx = addr.data[pos]
36
- return (AddressQubit(global_idx),)
37
- else: # this is not reachable
38
- return (NotQubit(),)
39
+ idx = frame.get(stmt.idx)
40
+
41
+ typ, values = interp.unpack_iterable(addr)
42
+ idx_value = interp.get_const_value(idx, int)
43
+
44
+ if typ is not None and idx_value is not None:
45
+ return (values[idx_value],)
46
+
47
+ return (UnknownQubit(),)
bloqade/qasm2/dialects/noise/fidelity.py CHANGED
@@ -1,7 +1,6 @@
1
1
  from kirin import interp
2
2
  from kirin.lattice import EmptyLattice
3
3
 
4
- from bloqade.analysis.address import AddressQubit, AddressTuple
5
4
  from bloqade.analysis.fidelity import FidelityAnalysis
6
5
 
7
6
  from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
@@ -42,10 +41,7 @@ class FidelityMethodTable(interp.MethodTable):
42
41
  stmt: AtomLossChannel,
43
42
  ):
44
43
  # NOTE: since AtomLossChannel acts on IList[Qubit], we know the assigned address is a tuple
45
- addresses: AddressTuple = interp.addr_frame.get(stmt.qargs)
46
-
44
+ addresses = interp.addr_frame.get(stmt.qargs)
47
45
  # NOTE: get the corresponding index and reduce survival probability accordingly
48
- for qbit_address in addresses.data:
49
- assert isinstance(qbit_address, AddressQubit)
50
- index = qbit_address.data
46
+ for index in addresses.data:
51
47
  interp._current_atom_survival_probability[index] *= 1 - stmt.prob
bloqade/qasm2/dialects/noise/model.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import abc
2
+ from typing import Sequence
2
3
  from dataclasses import field, dataclass
3
4
 
4
5
 
@@ -161,7 +162,7 @@ class MoveNoiseModelABC(abc.ABC):
161
162
 
162
163
  @abc.abstractmethod
163
164
  def parallel_cz_errors(
164
- self, ctrls: list[int], qargs: list[int], rest: list[int]
165
+ self, ctrls: Sequence[int], qargs: Sequence[int], rest: Sequence[int]
165
166
  ) -> dict[tuple[float, float, float, float], list[int]]:
166
167
  """Takes a set of ctrls and qargs and returns a noise model for all qubits."""
167
168
  pass
bloqade/qasm2/passes/parallel.py CHANGED
@@ -28,7 +28,6 @@ from bloqade.qasm2.rewrite import (
28
28
  UOpToParallelRule,
29
29
  ParallelToGlobalRule,
30
30
  SimpleOptimalMergePolicy,
31
- RydbergGateSetRewriteRule,
32
31
  )
33
32
  from bloqade.squin.analysis import schedule
34
33
 
@@ -151,6 +150,9 @@ class UOpToParallel(Pass):
151
150
  return result
152
151
 
153
152
  if self.rewrite_to_native_first:
153
+ # NOTE: this import also imports cirq, so we do it locally here
154
+ from bloqade.qasm2.rewrite.native_gates import RydbergGateSetRewriteRule
155
+
154
156
  result = (
155
157
  Fixpoint(Walk(RydbergGateSetRewriteRule(self.dialects)))
156
158
  .rewrite(mt.code)
bloqade/qasm2/rewrite/__init__.py CHANGED
@@ -3,7 +3,6 @@ from .glob import (
3
3
  GlobalToParallelRule as GlobalToParallelRule,
4
4
  )
5
5
  from .register import RaiseRegisterRule as RaiseRegisterRule
6
- from .native_gates import RydbergGateSetRewriteRule as RydbergGateSetRewriteRule
7
6
  from .parallel_to_uop import ParallelToUOpRule as ParallelToUOpRule
8
7
  from .uop_to_parallel import (
9
8
  MergePolicyABC as MergePolicyABC,
bloqade/qasm2/rewrite/noise/heuristic_noise.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Tuple, cast
1
+ from typing import Dict, List, Tuple
2
2
  from dataclasses import field, dataclass
3
3
 
4
4
  from kirin import ir
@@ -55,7 +55,7 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
55
55
 
56
56
  def rewrite_global_single_qubit_gate(self, node: glob.UGate):
57
57
  addrs = self.address_analysis[node.registers]
58
- if not isinstance(addrs, address.AddressTuple):
58
+ if not isinstance(addrs, address.PartialIList):
59
59
  return rewrite_abc.RewriteResult()
60
60
 
61
61
  qargs = []
@@ -74,10 +74,7 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
74
74
 
75
75
  def rewrite_parallel_single_qubit_gate(self, node: parallel.RZ | parallel.UGate):
76
76
  addrs = self.address_analysis[node.qargs]
77
- if not isinstance(addrs, address.AddressTuple):
78
- return rewrite_abc.RewriteResult()
79
-
80
- if not all(isinstance(addr, address.AddressQubit) for addr in addrs.data):
77
+ if not isinstance(addrs, address.AddressReg):
81
78
  return rewrite_abc.RewriteResult()
82
79
 
83
80
  assert isinstance(node.qargs, ir.ResultValue)
@@ -178,18 +175,11 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
178
175
  qargs = self.address_analysis[node.qargs]
179
176
 
180
177
  has_done_something = False
181
- if (
182
- isinstance(ctrls, address.AddressTuple)
183
- and all(isinstance(addr, address.AddressQubit) for addr in ctrls.data)
184
- and isinstance(qargs, address.AddressTuple)
185
- and all(isinstance(addr, address.AddressQubit) for addr in qargs.data)
178
+ if isinstance(ctrls, address.AddressReg) and isinstance(
179
+ qargs, address.AddressReg
186
180
  ):
187
- ctrl_qubits = list(
188
- map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data)
189
- )
190
- qarg_qubits = list(
191
- map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data)
192
- )
181
+ ctrl_qubits = tuple(ctrls.data)
182
+ qarg_qubits = tuple(qargs.data)
193
183
  rest = sorted(
194
184
  set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits)
195
185
  )
bloqade/qasm2/rewrite/parallel_to_glob.py CHANGED
@@ -3,7 +3,6 @@ from dataclasses import dataclass
3
3
 
4
4
  from kirin import ir
5
5
  from kirin.rewrite import abc
6
- from kirin.analysis import const
7
6
  from kirin.dialects import ilist
8
7
 
9
8
  from bloqade.analysis import address
@@ -20,28 +19,24 @@ class ParallelToGlobalRule(abc.RewriteRule):
20
19
  return abc.RewriteResult()
21
20
 
22
21
  qargs = node.qargs
23
- qarg_addresses = self.address_analysis.get(qargs, None)
22
+ qargs_address = self.address_analysis.get(qargs, address.Unknown())
24
23
 
25
- if isinstance(qarg_addresses, address.AddressReg):
26
- # NOTE: we only have an AddressReg if it's an entire register, definitely rewrite that
27
- return self._rewrite_parallel_to_glob(node)
28
-
29
- if not isinstance(qarg_addresses, address.AddressTuple):
24
+ if not isinstance(qargs_address, address.AddressReg):
30
25
  return abc.RewriteResult()
31
26
 
32
- idxs, qreg = self._find_qreg(qargs.owner, set())
27
+ qregs = self._get_all_qreg(qargs.owner)
33
28
 
34
- if qreg is None:
35
- # NOTE: no unique register found
29
+ if len(qregs) != 1:
36
30
  return abc.RewriteResult()
37
31
 
38
- if not isinstance(hint := qreg.n_qubits.hints.get("const"), const.Value):
39
- # NOTE: non-constant number of qubits
32
+ qreg = next(iter(qregs))
33
+
34
+ qreg_address = self.address_analysis.get(qreg, address.Unknown())
35
+
36
+ if not isinstance(qreg_address, address.AddressReg):
40
37
  return abc.RewriteResult()
41
38
 
42
- n = hint.data
43
- if len(idxs) != n:
44
- # NOTE: not all qubits of the register are there
39
+ if set(qargs_address.data) != set(qreg_address.data):
45
40
  return abc.RewriteResult()
46
41
 
47
42
  return self._rewrite_parallel_to_glob(node)
@@ -53,6 +48,24 @@ class ParallelToGlobalRule(abc.RewriteRule):
53
48
  node.replace_by(global_u)
54
49
  return abc.RewriteResult(has_done_something=True)
55
50
 
51
+ @staticmethod
52
+ def _get_all_qreg(owner: ir.Statement | ir.Block):
53
+ stack = [owner]
54
+ qregs: set[ir.SSAValue] = set()
55
+ while stack:
56
+ current = stack.pop()
57
+
58
+ if isinstance(current, core.stmts.QRegGet):
59
+ stack.append(current.reg.owner)
60
+ elif isinstance(current, ilist.New):
61
+ for val in current.values:
62
+ stack.append(val.owner)
63
+
64
+ elif isinstance(current, core.QRegNew):
65
+ qregs.add(current.result)
66
+
67
+ return qregs
68
+
56
69
  @staticmethod
57
70
  def _find_qreg(
58
71
  qargs_owner: ir.Statement | ir.Block, idxs: set
bloqade/qasm2/rewrite/parallel_to_uop.py CHANGED
@@ -21,16 +21,10 @@ class ParallelToUOpRule(abc.RewriteRule):
21
21
 
22
22
  def get_qubit_ssa(self, ilist_ref: ir.SSAValue) -> Optional[List[ir.SSAValue]]:
23
23
  addr = self.address_analysis.get(ilist_ref)
24
- if not isinstance(addr, address.AddressTuple):
24
+ if not isinstance(addr, address.AddressReg):
25
25
  return None
26
26
 
27
- ids = []
28
- for ele in addr.data:
29
- if not isinstance(ele, address.AddressQubit):
30
- return None
31
-
32
- ids.append(ele.data)
33
-
27
+ ids = addr.data
34
28
  return [self.id_map[ele] for ele in ids]
35
29
 
36
30
  def rewrite_cz(self, node: ir.Statement):
bloqade/qubit/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from bloqade.types import Qubit as Qubit, QubitType as QubitType
2
+
3
+ from . import stmts as stmts, analysis as analysis
4
+ from .stdlib import new as new, qalloc as qalloc, broadcast as broadcast
5
+ from ._dialect import dialect as dialect
6
+ from ._prelude import kernel as kernel
7
+ from .stdlib.simple import (
8
+ reset as reset,
9
+ measure as measure,
10
+ get_qubit_id as get_qubit_id,
11
+ get_measurement_id as get_measurement_id,
12
+ )
bloqade/qubit/_dialect.py ADDED
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
+ dialect = ir.Dialect("qubit")
bloqade/qubit/_interface.py ADDED
@@ -0,0 +1,49 @@
1
+ from typing import Any, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+ from kirin.lowering import wraps
5
+
6
+ from bloqade.types import Qubit, MeasurementResult
7
+
8
+ from .stmts import New, Reset, Measure, QubitId, MeasurementId
9
+
10
+
11
+ @wraps(New)
12
+ def new() -> Qubit:
13
+ """Create a new qubit.
14
+
15
+ Returns:
16
+ Qubit: A new qubit.
17
+ """
18
+ ...
19
+
20
+
21
+ N = TypeVar("N", bound=int)
22
+
23
+
24
+ @wraps(Measure)
25
+ def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]:
26
+ """Measure a list of qubits.
27
+
28
+ Args:
29
+ qubits (IList[Qubit, N]): The list of qubits to measure.
30
+
31
+ Returns:
32
+ IList[MeasurementResult, N]: The list containing the results of the measurements.
33
+ A MeasurementResult can represent both 0 and 1, but also atoms that are lost.
34
+ """
35
+ ...
36
+
37
+
38
+ @wraps(QubitId)
39
+ def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]: ...
40
+
41
+
42
+ @wraps(MeasurementId)
43
+ def get_measurement_id(
44
+ measurements: ilist.IList[MeasurementResult, N],
45
+ ) -> ilist.IList[int, N]: ...
46
+
47
+
48
+ @wraps(Reset)
49
+ def reset(qubits: ilist.IList[Qubit, Any]) -> None: ...