bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/analysis.py +18 -20
  7. bloqade/analysis/measure_id/impls.py +31 -29
  8. bloqade/annotate/__init__.py +6 -0
  9. bloqade/annotate/_dialect.py +3 -0
  10. bloqade/annotate/_interface.py +22 -0
  11. bloqade/annotate/stmts.py +29 -0
  12. bloqade/annotate/types.py +13 -0
  13. bloqade/cirq_utils/__init__.py +4 -2
  14. bloqade/cirq_utils/emit/__init__.py +3 -0
  15. bloqade/cirq_utils/emit/base.py +246 -0
  16. bloqade/cirq_utils/emit/gate.py +104 -0
  17. bloqade/cirq_utils/emit/noise.py +90 -0
  18. bloqade/cirq_utils/emit/qubit.py +35 -0
  19. bloqade/cirq_utils/lowering.py +660 -0
  20. bloqade/cirq_utils/noise/__init__.py +0 -2
  21. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  22. bloqade/cirq_utils/noise/model.py +151 -191
  23. bloqade/cirq_utils/noise/transform.py +2 -2
  24. bloqade/cirq_utils/parallelize.py +9 -6
  25. bloqade/gemini/__init__.py +1 -0
  26. bloqade/gemini/analysis/__init__.py +3 -0
  27. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  28. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  29. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  30. bloqade/gemini/groups.py +67 -0
  31. bloqade/native/__init__.py +23 -0
  32. bloqade/native/_prelude.py +45 -0
  33. bloqade/native/dialects/__init__.py +0 -0
  34. bloqade/native/dialects/gate/__init__.py +2 -0
  35. bloqade/native/dialects/gate/_dialect.py +3 -0
  36. bloqade/native/dialects/gate/_interface.py +32 -0
  37. bloqade/native/dialects/gate/stmts.py +31 -0
  38. bloqade/native/stdlib/__init__.py +0 -0
  39. bloqade/native/stdlib/broadcast.py +246 -0
  40. bloqade/native/stdlib/simple.py +220 -0
  41. bloqade/native/upstream/__init__.py +4 -0
  42. bloqade/native/upstream/squin2native.py +79 -0
  43. bloqade/pyqrack/__init__.py +2 -2
  44. bloqade/pyqrack/base.py +7 -1
  45. bloqade/pyqrack/device.py +192 -18
  46. bloqade/pyqrack/native.py +49 -0
  47. bloqade/pyqrack/reg.py +6 -6
  48. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  49. bloqade/pyqrack/squin/gate/gate.py +136 -0
  50. bloqade/pyqrack/squin/noise/native.py +120 -54
  51. bloqade/pyqrack/squin/qubit.py +39 -36
  52. bloqade/pyqrack/target.py +5 -4
  53. bloqade/pyqrack/task.py +114 -7
  54. bloqade/qasm2/_qasm_loading.py +3 -3
  55. bloqade/qasm2/dialects/core/address.py +21 -12
  56. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  57. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  58. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  59. bloqade/qasm2/dialects/noise/model.py +2 -1
  60. bloqade/qasm2/emit/base.py +16 -11
  61. bloqade/qasm2/emit/gate.py +11 -8
  62. bloqade/qasm2/emit/main.py +103 -3
  63. bloqade/qasm2/emit/target.py +9 -5
  64. bloqade/qasm2/groups.py +3 -2
  65. bloqade/qasm2/parse/lowering.py +0 -1
  66. bloqade/qasm2/passes/fold.py +14 -73
  67. bloqade/qasm2/passes/glob.py +2 -2
  68. bloqade/qasm2/passes/noise.py +1 -1
  69. bloqade/qasm2/passes/parallel.py +7 -5
  70. bloqade/qasm2/rewrite/__init__.py +0 -1
  71. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  72. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  73. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  74. bloqade/qasm2/rewrite/register.py +2 -2
  75. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  76. bloqade/qbraid/lowering.py +1 -0
  77. bloqade/qbraid/schema.py +2 -2
  78. bloqade/qubit/__init__.py +12 -0
  79. bloqade/qubit/_dialect.py +3 -0
  80. bloqade/qubit/_interface.py +49 -0
  81. bloqade/qubit/_prelude.py +45 -0
  82. bloqade/qubit/analysis/__init__.py +1 -0
  83. bloqade/qubit/analysis/address_impl.py +40 -0
  84. bloqade/qubit/stdlib/__init__.py +2 -0
  85. bloqade/qubit/stdlib/_new.py +34 -0
  86. bloqade/qubit/stdlib/broadcast.py +62 -0
  87. bloqade/qubit/stdlib/simple.py +59 -0
  88. bloqade/qubit/stmts.py +60 -0
  89. bloqade/rewrite/passes/__init__.py +6 -0
  90. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  91. bloqade/rewrite/passes/callgraph.py +116 -0
  92. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  93. bloqade/rewrite/rules/split_ifs.py +18 -1
  94. bloqade/squin/__init__.py +47 -14
  95. bloqade/squin/analysis/__init__.py +0 -1
  96. bloqade/squin/analysis/schedule.py +10 -11
  97. bloqade/squin/gate/__init__.py +2 -0
  98. bloqade/squin/gate/_dialect.py +3 -0
  99. bloqade/squin/gate/_interface.py +98 -0
  100. bloqade/squin/gate/stmts.py +125 -0
  101. bloqade/squin/groups.py +5 -22
  102. bloqade/squin/noise/__init__.py +1 -10
  103. bloqade/squin/noise/_dialect.py +1 -1
  104. bloqade/squin/noise/_interface.py +45 -0
  105. bloqade/squin/noise/stmts.py +66 -28
  106. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  107. bloqade/squin/rewrite/__init__.py +0 -2
  108. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  109. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  110. bloqade/squin/stdlib/__init__.py +0 -0
  111. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  112. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  113. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  114. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  115. bloqade/squin/stdlib/simple/__init__.py +33 -0
  116. bloqade/squin/stdlib/simple/gate.py +242 -0
  117. bloqade/squin/stdlib/simple/noise.py +126 -0
  118. bloqade/stim/__init__.py +1 -0
  119. bloqade/stim/_wrappers.py +6 -0
  120. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  121. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  122. bloqade/stim/dialects/gate/emit.py +9 -10
  123. bloqade/stim/dialects/noise/emit.py +17 -13
  124. bloqade/stim/dialects/noise/stmts.py +5 -3
  125. bloqade/stim/emit/__init__.py +1 -0
  126. bloqade/stim/emit/impls.py +16 -0
  127. bloqade/stim/emit/stim_str.py +48 -31
  128. bloqade/stim/groups.py +12 -2
  129. bloqade/stim/parse/lowering.py +14 -17
  130. bloqade/stim/passes/__init__.py +0 -2
  131. bloqade/stim/passes/flatten.py +26 -0
  132. bloqade/stim/passes/simplify_ifs.py +6 -1
  133. bloqade/stim/passes/squin_to_stim.py +9 -84
  134. bloqade/stim/rewrite/__init__.py +2 -4
  135. bloqade/stim/rewrite/get_record_util.py +24 -0
  136. bloqade/stim/rewrite/ifs_to_stim.py +24 -25
  137. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  138. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  139. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  140. bloqade/stim/rewrite/squin_measure.py +9 -18
  141. bloqade/stim/rewrite/squin_noise.py +134 -108
  142. bloqade/stim/rewrite/util.py +5 -192
  143. bloqade/test_utils.py +1 -1
  144. bloqade/types.py +10 -0
  145. bloqade/validation/__init__.py +2 -0
  146. bloqade/validation/analysis/__init__.py +5 -0
  147. bloqade/validation/analysis/analysis.py +41 -0
  148. bloqade/validation/analysis/lattice.py +58 -0
  149. bloqade/validation/kernel_validation.py +77 -0
  150. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  151. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  152. bloqade/pyqrack/squin/op.py +0 -180
  153. bloqade/pyqrack/squin/runtime.py +0 -535
  154. bloqade/pyqrack/squin/wire.py +0 -51
  155. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  156. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  157. bloqade/squin/_typeinfer.py +0 -20
  158. bloqade/squin/analysis/address_impl.py +0 -71
  159. bloqade/squin/analysis/nsites/__init__.py +0 -9
  160. bloqade/squin/analysis/nsites/analysis.py +0 -50
  161. bloqade/squin/analysis/nsites/impls.py +0 -92
  162. bloqade/squin/analysis/nsites/lattice.py +0 -49
  163. bloqade/squin/cirq/__init__.py +0 -280
  164. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  165. bloqade/squin/cirq/emit/noise.py +0 -49
  166. bloqade/squin/cirq/emit/op.py +0 -125
  167. bloqade/squin/cirq/emit/qubit.py +0 -60
  168. bloqade/squin/cirq/emit/runtime.py +0 -242
  169. bloqade/squin/cirq/lowering.py +0 -440
  170. bloqade/squin/lowering.py +0 -54
  171. bloqade/squin/noise/_wrapper.py +0 -40
  172. bloqade/squin/noise/rewrite.py +0 -111
  173. bloqade/squin/op/__init__.py +0 -41
  174. bloqade/squin/op/_dialect.py +0 -3
  175. bloqade/squin/op/_wrapper.py +0 -121
  176. bloqade/squin/op/number.py +0 -5
  177. bloqade/squin/op/rewrite.py +0 -46
  178. bloqade/squin/op/stdlib.py +0 -62
  179. bloqade/squin/op/stmts.py +0 -276
  180. bloqade/squin/op/traits.py +0 -43
  181. bloqade/squin/op/types.py +0 -26
  182. bloqade/squin/qubit.py +0 -184
  183. bloqade/squin/rewrite/canonicalize.py +0 -60
  184. bloqade/squin/rewrite/desugar.py +0 -124
  185. bloqade/squin/types.py +0 -8
  186. bloqade/squin/wire.py +0 -201
  187. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  188. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  189. bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
  190. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  191. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -21,16 +21,10 @@ class ParallelToUOpRule(abc.RewriteRule):
21
21
 
22
22
  def get_qubit_ssa(self, ilist_ref: ir.SSAValue) -> Optional[List[ir.SSAValue]]:
23
23
  addr = self.address_analysis.get(ilist_ref)
24
- if not isinstance(addr, address.AddressTuple):
24
+ if not isinstance(addr, address.AddressReg):
25
25
  return None
26
26
 
27
- ids = []
28
- for ele in addr.data:
29
- if not isinstance(ele, address.AddressQubit):
30
- return None
31
-
32
- ids.append(ele.data)
33
-
27
+ ids = addr.data
34
28
  return [self.id_map[ele] for ele in ids]
35
29
 
36
30
  def rewrite_cz(self, node: ir.Statement):
@@ -2,7 +2,7 @@ from kirin import ir
2
2
  from kirin.dialects import py
3
3
  from kirin.rewrite.abc import RewriteRule, RewriteResult
4
4
 
5
- from bloqade.qasm2.dialects import core
5
+ from bloqade.qasm2.dialects import core, expr
6
6
 
7
7
 
8
8
  class RaiseRegisterRule(RewriteRule):
@@ -26,7 +26,7 @@ class RaiseRegisterRule(RewriteRule):
26
26
  n_qubits_ref = node.n_qubits
27
27
 
28
28
  n_qubits = n_qubits_ref.owner
29
- if isinstance(n_qubits, py.Constant):
29
+ if isinstance(n_qubits, py.Constant | expr.ConstInt):
30
30
  # case where the n_qubits comes from a constant
31
31
  new_n_qubits = n_qubits.from_stmt(n_qubits)
32
32
  new_n_qubits.insert_before(first_stmt)
@@ -8,7 +8,7 @@ from kirin.rewrite.abc import RewriteRule, RewriteResult
8
8
  from kirin.analysis.const import lattice
9
9
 
10
10
  from bloqade.analysis import address
11
- from bloqade.qasm2.dialects import uop, core, parallel
11
+ from bloqade.qasm2.dialects import uop, core, expr, parallel
12
12
  from bloqade.squin.analysis.schedule import StmtDag
13
13
 
14
14
 
@@ -66,7 +66,7 @@ class SimpleMergePolicy(MergePolicyABC):
66
66
  assert isinstance(hint1, lattice.Result) and isinstance(
67
67
  hint2, lattice.Result
68
68
  )
69
- return hint1.is_equal(hint2)
69
+ return hint1.is_structurally_equal(hint2)
70
70
  else:
71
71
  return False
72
72
 
@@ -194,6 +194,8 @@ class SimpleMergePolicy(MergePolicyABC):
194
194
  new_qubits.append(new_qubit.result)
195
195
  case core.QRegGet(
196
196
  reg=reg, idx=ir.ResultValue(stmt=py.Constant() as idx)
197
+ ) | core.QRegGet(
198
+ reg=reg, idx=ir.ResultValue(stmt=expr.ConstInt() as idx)
197
199
  ):
198
200
  (new_idx := idx.from_stmt(idx)).insert_before(node)
199
201
  (
@@ -320,5 +320,6 @@ class Lowering:
320
320
  self.block_list.append(const_pi)
321
321
  turns = self.lower_number(2 * value)
322
322
  mul = qasm2.expr.Mul(const_pi.result, turns)
323
+ mul.result.type = types.Float
323
324
  self.block_list.append(mul)
324
325
  return mul.result
bloqade/qbraid/schema.py CHANGED
@@ -238,13 +238,13 @@ class NoiseModel(BaseModel, Generic[ErrorModelType], extra="forbid"):
238
238
  str: The decompiled circuit from hardware execution.
239
239
 
240
240
  """
241
- from bloqade.noise import native
242
241
  from bloqade.qasm2.emit import QASM2
243
242
  from bloqade.qasm2.passes import glob, parallel
243
+ from bloqade.qasm2.rewrite.noise import remove_noise
244
244
 
245
245
  mt = self.lower_noise_model("method")
246
246
 
247
- native.RemoveNoisePass(mt.dialects)(mt)
247
+ remove_noise.RemoveNoisePass(mt.dialects)(mt)
248
248
  parallel.ParallelToUOp(mt.dialects)(mt)
249
249
  glob.GlobalToUOP(mt.dialects)(mt)
250
250
  return QASM2(qelib1=True).emit_str(mt)
@@ -0,0 +1,12 @@
1
+ from bloqade.types import Qubit as Qubit, QubitType as QubitType
2
+
3
+ from . import stmts as stmts, analysis as analysis
4
+ from .stdlib import new as new, qalloc as qalloc, broadcast as broadcast
5
+ from ._dialect import dialect as dialect
6
+ from ._prelude import kernel as kernel
7
+ from .stdlib.simple import (
8
+ reset as reset,
9
+ measure as measure,
10
+ get_qubit_id as get_qubit_id,
11
+ get_measurement_id as get_measurement_id,
12
+ )
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
+ dialect = ir.Dialect("qubit")
@@ -0,0 +1,49 @@
1
+ from typing import Any, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+ from kirin.lowering import wraps
5
+
6
+ from bloqade.types import Qubit, MeasurementResult
7
+
8
+ from .stmts import New, Reset, Measure, QubitId, MeasurementId
9
+
10
+
11
+ @wraps(New)
12
+ def new() -> Qubit:
13
+ """Create a new qubit.
14
+
15
+ Returns:
16
+ Qubit: A new qubit.
17
+ """
18
+ ...
19
+
20
+
21
+ N = TypeVar("N", bound=int)
22
+
23
+
24
+ @wraps(Measure)
25
+ def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]:
26
+ """Measure a list of qubits.
27
+
28
+ Args:
29
+ qubits (IList[Qubit, N]): The list of qubits to measure.
30
+
31
+ Returns:
32
+ IList[MeasurementResult, N]: The list containing the results of the measurements.
33
+ A MeasurementResult can represent both 0 and 1, but also atoms that are lost.
34
+ """
35
+ ...
36
+
37
+
38
+ @wraps(QubitId)
39
+ def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]: ...
40
+
41
+
42
+ @wraps(MeasurementId)
43
+ def get_measurement_id(
44
+ measurements: ilist.IList[MeasurementResult, N],
45
+ ) -> ilist.IList[int, N]: ...
46
+
47
+
48
+ @wraps(Reset)
49
+ def reset(qubits: ilist.IList[Qubit, Any]) -> None: ...
@@ -0,0 +1,45 @@
1
+ from typing import Annotated
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Default
5
+ from kirin.prelude import structural_no_opt
6
+ from typing_extensions import Doc
7
+
8
+ from . import _dialect as qubit
9
+
10
+
11
+ @ir.dialect_group(structural_no_opt.union([qubit]))
12
+ def kernel(self):
13
+ """Compile to a qubit kernel"""
14
+
15
+ def run_pass(
16
+ mt,
17
+ *,
18
+ verify: Annotated[
19
+ bool, Doc("run `verify` before running passes, default is `True`")
20
+ ] = True,
21
+ typeinfer: Annotated[
22
+ bool,
23
+ Doc(
24
+ "run type inference and apply the inferred type to IR, default `False`"
25
+ ),
26
+ ] = False,
27
+ fold: Annotated[bool, Doc("run folding passes, default is `True`")] = True,
28
+ aggressive: Annotated[
29
+ bool, Doc("run aggressive folding passes if `fold=True`")
30
+ ] = False,
31
+ no_raise: Annotated[
32
+ bool, Doc("do not raise exception during analysis, default is `True`")
33
+ ] = True,
34
+ ) -> None:
35
+ default_pass = Default(
36
+ self,
37
+ verify=verify,
38
+ fold=fold,
39
+ aggressive=aggressive,
40
+ typeinfer=typeinfer,
41
+ no_raise=no_raise,
42
+ )
43
+ default_pass.fixpoint(mt)
44
+
45
+ return run_pass
@@ -0,0 +1 @@
1
+ from . import address_impl as address_impl
@@ -0,0 +1,40 @@
1
+ from kirin import interp
2
+ from kirin.analysis import ForwardFrame
3
+
4
+ from bloqade.analysis.address.lattice import (
5
+ Address,
6
+ AddressQubit,
7
+ )
8
+ from bloqade.analysis.address.analysis import AddressAnalysis
9
+
10
+ from .. import stmts
11
+ from .._dialect import dialect
12
+
13
+ # Address lattice elements we can work with:
14
+ ## NotQubit (bottom), AnyAddress (top)
15
+
16
+ ## AddressTuple -> data: tuple[Address, ...]
17
+ ### Recursive type, could contain itself or other variants
18
+ ### This pops up in cases where you can have an IList/Tuple
19
+ ### That contains elements that could be other Address types
20
+
21
+ ## AddressReg -> data: Sequence[int]
22
+ ### specific to creation of a register of qubits
23
+
24
+ ## AddressQubit -> data: int
25
+ ### Base qubit address type
26
+
27
+
28
+ @dialect.register(key="qubit.address")
29
+ class SquinQubitMethodTable(interp.MethodTable):
30
+
31
+ @interp.impl(stmts.New)
32
+ def new_qubit(
33
+ self,
34
+ interp_: AddressAnalysis,
35
+ frame: ForwardFrame[Address],
36
+ stmt: stmts.New,
37
+ ):
38
+ addr = AddressQubit(interp_.next_address)
39
+ interp_.next_address += 1
40
+ return (addr,)
@@ -0,0 +1,2 @@
1
+ from . import simple as simple, broadcast as broadcast
2
+ from ._new import new as new, qalloc as qalloc
@@ -0,0 +1,34 @@
1
+ from typing import Any
2
+
3
+ from kirin.dialects import ilist
4
+
5
+ from .. import _interface as qubit
6
+ from .._prelude import kernel
7
+
8
+
9
+ @kernel(typeinfer=True)
10
+ def new() -> qubit.Qubit:
11
+ """Allocate a single new qubit
12
+
13
+ Returns:
14
+ (Qubit): The newly allocated qubit.
15
+ """
16
+ return qubit.new()
17
+
18
+
19
+ # NOTE: this is a special case, that doesn't use the usual simple / broadcast semantics.
20
+ @kernel(typeinfer=True)
21
+ def qalloc(n_qubits: int) -> ilist.IList[qubit.Qubit, Any]:
22
+ """Allocate a new list of qubits.
23
+
24
+ Args:
25
+ n_qubits(int): The number of qubits to create.
26
+
27
+ Returns:
28
+ (ilist.IList[Qubit, n_qubits]) A list of qubits.
29
+ """
30
+
31
+ def _new(qid: int) -> qubit.Qubit:
32
+ return qubit.new()
33
+
34
+ return ilist.map(_new, ilist.range(n_qubits))
@@ -0,0 +1,62 @@
1
+ from typing import Any, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import Qubit, MeasurementResult
6
+
7
+ from .. import _interface as _qubit
8
+ from .._prelude import kernel
9
+
10
+ N = TypeVar("N", bound=int)
11
+
12
+
13
+ @kernel
14
+ def reset(qubits: ilist.IList[Qubit, Any]) -> None:
15
+ """
16
+ Reset a list of qubits to the zero state.
17
+
18
+ Args:
19
+ qubits (IList[Qubit, Any]): The list of qubits to reset.
20
+ """
21
+ _qubit.reset(qubits)
22
+
23
+
24
+ @kernel
25
+ def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]:
26
+ """Measure a list of qubits.
27
+
28
+ Args:
29
+ qubits (IList[Qubit, N]): The list of qubits to measure.
30
+
31
+ Returns:
32
+ IList[MeasurementResult, N]: The list containing the results of the measurements.
33
+ A MeasurementResult can represent both 0 and 1 as well as atom loss.
34
+ """
35
+ return _qubit.measure(qubits)
36
+
37
+
38
+ @kernel
39
+ def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]:
40
+ """Get the global, unique ID of each qubit in the list.
41
+
42
+ Args:
43
+ qubits (IList[Qubit, N]): The list of qubits of which you want the ID.
44
+
45
+ Returns:
46
+ qubit_ids (IList[int, N]): The list of global, unique IDs of the qubits.
47
+ """
48
+ return _qubit.get_qubit_id(qubits)
49
+
50
+
51
+ @kernel
52
+ def get_measurement_id(
53
+ measurements: ilist.IList[MeasurementResult, N],
54
+ ) -> ilist.IList[int, N]:
55
+ """Get the global, unique ID of each of the measurement results in the list.
56
+
57
+ Args:
58
+ measurements (IList[MeasurementResult, N]): The previously taken measurement of which you want to know the ID.
59
+ Returns:
60
+ measurement_ids (IList[int, N]): The list of global, unique IDs of the measurements.
61
+ """
62
+ return _qubit.get_measurement_id(measurements)
@@ -0,0 +1,59 @@
1
+ from kirin.dialects import ilist
2
+
3
+ from bloqade.types import Qubit, MeasurementResult
4
+
5
+ from . import broadcast
6
+ from .._prelude import kernel
7
+
8
+
9
+ @kernel
10
+ def reset(qubit: Qubit) -> None:
11
+ """
12
+ Reset a qubit to the zero state.
13
+
14
+ Args:
15
+ qubit (Qubit): The list qubit to reset.
16
+ """
17
+ return broadcast.reset(ilist.IList([qubit]))
18
+
19
+
20
+ @kernel
21
+ def measure(qubit: Qubit) -> MeasurementResult:
22
+ """Measure a qubit.
23
+
24
+ Args:
25
+ qubit (Qubit): The qubit to measure.
26
+
27
+ Returns:
28
+ MeasurementResult: The result of the measurement.
29
+ A MeasurementResult can represent both 0 and 1, but also atoms that are lost.
30
+ """
31
+ measurement_results = broadcast.measure(ilist.IList([qubit]))
32
+ return measurement_results[0]
33
+
34
+
35
+ @kernel
36
+ def get_qubit_id(qubit: Qubit) -> int:
37
+ """Get the global, unique ID of the qubit.
38
+
39
+ Args:
40
+ qubit (Qubit): The qubit of which you want the ID.
41
+
42
+ Returns:
43
+ qubit_id (int): The global, unique ID of the qubit.
44
+ """
45
+ ids = broadcast.get_qubit_id(ilist.IList([qubit]))
46
+ return ids[0]
47
+
48
+
49
+ @kernel
50
+ def get_measurement_id(measurement: MeasurementResult) -> int:
51
+ """Get the global, unique ID of the measurement result.
52
+
53
+ Args:
54
+ measurement (MeasurementResult): The previously taken measurement of which you want to know the ID.
55
+ Returns:
56
+ measurement_id (int): The global, unique ID of the measurement.
57
+ """
58
+ ids = broadcast.get_measurement_id(ilist.IList([measurement]))
59
+ return ids[0]
bloqade/qubit/stmts.py ADDED
@@ -0,0 +1,60 @@
1
+ from kirin import ir, types, interp, lowering
2
+ from kirin.decl import info, statement
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import QubitType, MeasurementResultType
6
+
7
+ from ._dialect import dialect
8
+
9
+
10
+ @statement(dialect=dialect)
11
+ class New(ir.Statement):
12
+ traits = frozenset({lowering.FromPythonCall()})
13
+ result: ir.ResultValue = info.result(QubitType)
14
+
15
+
16
+ Len = types.TypeVar("Len", bound=types.Int)
17
+
18
+
19
+ @statement(dialect=dialect)
20
+ class Measure(ir.Statement):
21
+ traits = frozenset({lowering.FromPythonCall()})
22
+ qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
23
+ result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType, Len])
24
+
25
+
26
+ @statement(dialect=dialect)
27
+ class QubitId(ir.Statement):
28
+ traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
29
+ qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
30
+ result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len])
31
+
32
+
33
+ @statement(dialect=dialect)
34
+ class MeasurementId(ir.Statement):
35
+ traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
36
+ measurements: ir.SSAValue = info.argument(
37
+ ilist.IListType[MeasurementResultType, Len]
38
+ )
39
+ result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len])
40
+
41
+
42
+ @statement(dialect=dialect)
43
+ class Reset(ir.Statement):
44
+ traits = frozenset({lowering.FromPythonCall()})
45
+ qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
46
+
47
+
48
+ # TODO: investigate why this is needed to get type inference to be correct.
49
+ @dialect.register(key="typeinfer")
50
+ class __TypeInfer(interp.MethodTable):
51
+ @interp.impl(Measure)
52
+ def measure_list(self, _interp, frame: interp.AbstractFrame, stmt: Measure):
53
+ qubit_type = frame.get(stmt.qubits)
54
+
55
+ if isinstance(qubit_type, types.Generic):
56
+ len_type = qubit_type.vars[1]
57
+ else:
58
+ len_type = types.Any
59
+
60
+ return (ilist.IListType[MeasurementResultType, len_type],)
@@ -1 +1,7 @@
1
+ from .callgraph import (
2
+ CallGraphPass as CallGraphPass,
3
+ ReplaceMethods as ReplaceMethods,
4
+ UpdateDialectsOnCallGraph as UpdateDialectsOnCallGraph,
5
+ )
6
+ from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll
1
7
  from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
@@ -0,0 +1,103 @@
1
+ from typing import Callable
2
+ from dataclasses import field, dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.passes import Pass, HintConst, TypeInfer
6
+ from kirin.rewrite import (
7
+ Walk,
8
+ Chain,
9
+ Inline,
10
+ Fixpoint,
11
+ Call2Invoke,
12
+ ConstantFold,
13
+ CFGCompactify,
14
+ InlineGetItem,
15
+ InlineGetField,
16
+ DeadCodeElimination,
17
+ CommonSubexpressionElimination,
18
+ )
19
+ from kirin.dialects import scf, ilist
20
+ from kirin.ir.method import Method
21
+ from kirin.rewrite.abc import RewriteResult
22
+ from kirin.passes.aggressive import UnrollScf
23
+
24
+ from .canonicalize_ilist import CanonicalizeIList
25
+
26
+
27
+ @dataclass
28
+ class Fold(Pass):
29
+ hint_const: HintConst = field(init=False)
30
+
31
+ def __post_init__(self):
32
+ self.hint_const = HintConst(self.dialects, no_raise=self.no_raise)
33
+
34
+ def unsafe_run(self, mt: Method) -> RewriteResult:
35
+ result = RewriteResult()
36
+ result = self.hint_const.unsafe_run(mt).join(result)
37
+ rule = Chain(
38
+ ConstantFold(),
39
+ Call2Invoke(),
40
+ InlineGetField(),
41
+ InlineGetItem(),
42
+ ilist.rewrite.InlineGetItem(),
43
+ ilist.rewrite.FlattenAdd(),
44
+ ilist.rewrite.HintLen(),
45
+ DeadCodeElimination(),
46
+ )
47
+ result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
48
+
49
+ return result
50
+
51
+
52
+ @dataclass
53
+ class AggressiveUnroll(Pass):
54
+ """A pass to unroll structured control flow"""
55
+
56
+ additional_inline_heuristic: Callable[[ir.Statement], bool] = lambda node: True
57
+
58
+ fold: Fold = field(init=False)
59
+ typeinfer: TypeInfer = field(init=False)
60
+ scf_unroll: UnrollScf = field(init=False)
61
+ canonicalize_ilist: CanonicalizeIList = field(init=False)
62
+
63
+ def __post_init__(self):
64
+ self.fold = Fold(self.dialects, no_raise=self.no_raise)
65
+ self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
66
+ self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)
67
+ self.canonicalize_ilist = CanonicalizeIList(
68
+ self.dialects, no_raise=self.no_raise
69
+ )
70
+
71
+ def unsafe_run(self, mt: Method) -> RewriteResult:
72
+ result = RewriteResult()
73
+ result = self.fold.unsafe_run(mt).join(result)
74
+ result = self.scf_unroll.unsafe_run(mt).join(result)
75
+ self.typeinfer.unsafe_run(
76
+ mt
77
+ ) # Do not join the result of typeinfer or fixpoint will waste time
78
+ result = (
79
+ Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
80
+ .rewrite(mt.code)
81
+ .join(result)
82
+ )
83
+ result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
84
+ result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
85
+ result = self.canonicalize_ilist.fixpoint(mt).join(result)
86
+ rule = Chain(
87
+ CommonSubexpressionElimination(),
88
+ DeadCodeElimination(),
89
+ )
90
+ result = Walk(rule).rewrite(mt.code).join(result)
91
+
92
+ return result
93
+
94
+ def inline_heuristic(self, node: ir.Statement) -> bool:
95
+ """The heuristic to decide whether to inline a function call or not.
96
+ inside loops and if-else, only inline simple functions, i.e.
97
+ functions with a single block
98
+ """
99
+ return not isinstance(
100
+ node.parent_stmt, (scf.For, scf.IfElse)
101
+ ) and self.additional_inline_heuristic(
102
+ node
103
+ ) # always inline calls outside of loops and if-else