bloqade-circuit 0.7.12__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (136)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +119 -29
  3. bloqade/analysis/address/impls.py +290 -87
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +2 -2
  6. bloqade/analysis/measure_id/impls.py +3 -27
  7. bloqade/cirq_utils/__init__.py +3 -1
  8. bloqade/cirq_utils/emit/__init__.py +3 -0
  9. bloqade/cirq_utils/emit/base.py +243 -0
  10. bloqade/cirq_utils/emit/gate.py +104 -0
  11. bloqade/cirq_utils/emit/noise.py +90 -0
  12. bloqade/cirq_utils/emit/qubit.py +35 -0
  13. bloqade/cirq_utils/lowering.py +664 -0
  14. bloqade/native/__init__.py +0 -1
  15. bloqade/native/_prelude.py +3 -3
  16. bloqade/native/dialects/gate/__init__.py +2 -0
  17. bloqade/native/dialects/gate/_dialect.py +3 -0
  18. bloqade/native/dialects/{gates → gate}/_interface.py +5 -5
  19. bloqade/native/dialects/{gates → gate}/stmts.py +5 -5
  20. bloqade/native/stdlib/broadcast.py +19 -19
  21. bloqade/native/stdlib/simple.py +14 -13
  22. bloqade/native/upstream/__init__.py +5 -0
  23. bloqade/native/upstream/squin2native.py +136 -0
  24. bloqade/pyqrack/__init__.py +1 -2
  25. bloqade/pyqrack/device.py +6 -17
  26. bloqade/pyqrack/native.py +17 -17
  27. bloqade/pyqrack/reg.py +1 -6
  28. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  29. bloqade/pyqrack/squin/gate/gate.py +136 -0
  30. bloqade/pyqrack/squin/noise/native.py +120 -54
  31. bloqade/pyqrack/squin/qubit.py +25 -41
  32. bloqade/pyqrack/target.py +2 -2
  33. bloqade/qasm2/dialects/core/address.py +21 -12
  34. bloqade/qasm2/dialects/noise/fidelity.py +2 -6
  35. bloqade/qasm2/dialects/noise/model.py +2 -1
  36. bloqade/qasm2/passes/parallel.py +3 -1
  37. bloqade/qasm2/rewrite/__init__.py +0 -1
  38. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  39. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  40. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  41. bloqade/qubit/__init__.py +12 -0
  42. bloqade/qubit/_dialect.py +3 -0
  43. bloqade/qubit/_interface.py +49 -0
  44. bloqade/qubit/_prelude.py +45 -0
  45. bloqade/qubit/analysis/__init__.py +1 -0
  46. bloqade/qubit/analysis/address_impl.py +40 -0
  47. bloqade/qubit/stdlib/__init__.py +2 -0
  48. bloqade/qubit/stdlib/_new.py +34 -0
  49. bloqade/qubit/stdlib/broadcast.py +62 -0
  50. bloqade/qubit/stdlib/simple.py +59 -0
  51. bloqade/qubit/stmts.py +60 -0
  52. bloqade/rewrite/passes/aggressive_unroll.py +2 -1
  53. bloqade/squin/__init__.py +44 -17
  54. bloqade/squin/analysis/__init__.py +0 -1
  55. bloqade/squin/analysis/schedule.py +2 -2
  56. bloqade/squin/gate/__init__.py +2 -0
  57. bloqade/squin/gate/_dialect.py +3 -0
  58. bloqade/squin/gate/_interface.py +98 -0
  59. bloqade/squin/gate/stmts.py +119 -0
  60. bloqade/squin/groups.py +4 -21
  61. bloqade/squin/noise/__init__.py +1 -9
  62. bloqade/squin/noise/_dialect.py +1 -1
  63. bloqade/squin/noise/_interface.py +45 -0
  64. bloqade/squin/noise/stmts.py +65 -29
  65. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  66. bloqade/squin/rewrite/__init__.py +0 -2
  67. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  68. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  69. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  70. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  71. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  72. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  73. bloqade/squin/stdlib/simple/__init__.py +33 -0
  74. bloqade/squin/stdlib/simple/gate.py +242 -0
  75. bloqade/squin/stdlib/simple/noise.py +126 -0
  76. bloqade/stim/__init__.py +1 -0
  77. bloqade/stim/_wrappers.py +6 -0
  78. bloqade/stim/dialects/noise/emit.py +6 -1
  79. bloqade/stim/dialects/noise/stmts.py +5 -3
  80. bloqade/stim/emit/stim_str.py +2 -0
  81. bloqade/stim/parse/lowering.py +12 -17
  82. bloqade/stim/passes/__init__.py +0 -1
  83. bloqade/stim/passes/flatten.py +26 -0
  84. bloqade/stim/passes/simplify_ifs.py +6 -1
  85. bloqade/stim/passes/squin_to_stim.py +4 -70
  86. bloqade/stim/rewrite/__init__.py +0 -4
  87. bloqade/stim/rewrite/ifs_to_stim.py +23 -29
  88. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  89. bloqade/stim/rewrite/squin_measure.py +9 -18
  90. bloqade/stim/rewrite/squin_noise.py +132 -108
  91. bloqade/stim/rewrite/util.py +5 -204
  92. bloqade/types.py +10 -0
  93. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/METADATA +2 -2
  94. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/RECORD +96 -100
  95. bloqade/native/dialects/gates/__init__.py +0 -3
  96. bloqade/native/dialects/gates/_dialect.py +0 -3
  97. bloqade/pyqrack/squin/op.py +0 -180
  98. bloqade/pyqrack/squin/runtime.py +0 -543
  99. bloqade/pyqrack/squin/wire.py +0 -51
  100. bloqade/squin/_typeinfer.py +0 -20
  101. bloqade/squin/analysis/address_impl.py +0 -71
  102. bloqade/squin/analysis/nsites/__init__.py +0 -9
  103. bloqade/squin/analysis/nsites/analysis.py +0 -50
  104. bloqade/squin/analysis/nsites/impls.py +0 -99
  105. bloqade/squin/analysis/nsites/lattice.py +0 -49
  106. bloqade/squin/cirq/__init__.py +0 -306
  107. bloqade/squin/cirq/emit/emit_circuit.py +0 -129
  108. bloqade/squin/cirq/emit/noise.py +0 -49
  109. bloqade/squin/cirq/emit/op.py +0 -176
  110. bloqade/squin/cirq/emit/qubit.py +0 -58
  111. bloqade/squin/cirq/emit/runtime.py +0 -242
  112. bloqade/squin/cirq/lowering.py +0 -439
  113. bloqade/squin/lowering.py +0 -80
  114. bloqade/squin/noise/_wrapper.py +0 -36
  115. bloqade/squin/noise/rewrite.py +0 -129
  116. bloqade/squin/op/__init__.py +0 -41
  117. bloqade/squin/op/_dialect.py +0 -3
  118. bloqade/squin/op/_wrapper.py +0 -121
  119. bloqade/squin/op/number.py +0 -5
  120. bloqade/squin/op/rewrite.py +0 -46
  121. bloqade/squin/op/stdlib.py +0 -62
  122. bloqade/squin/op/stmts.py +0 -300
  123. bloqade/squin/op/traits.py +0 -43
  124. bloqade/squin/op/types.py +0 -128
  125. bloqade/squin/parallel.py +0 -200
  126. bloqade/squin/qubit.py +0 -194
  127. bloqade/squin/rewrite/canonicalize.py +0 -60
  128. bloqade/squin/rewrite/desugar.py +0 -102
  129. bloqade/squin/stdlib/channel.py +0 -86
  130. bloqade/squin/stdlib/gate.py +0 -201
  131. bloqade/squin/types.py +0 -8
  132. bloqade/squin/wire.py +0 -201
  133. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  134. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  135. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/WHEEL +0 -0
  136. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,45 @@
1
+ from typing import Annotated
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Default
5
+ from kirin.prelude import structural_no_opt
6
+ from typing_extensions import Doc
7
+
8
+ from . import _dialect as qubit
9
+
10
+
11
@ir.dialect_group(structural_no_opt.union([qubit]))
def kernel(self):
    """Compile to a qubit kernel.

    The group combines the structural (no-opt) dialects with the ``qubit``
    dialect. The returned ``run_pass`` runs kirin's ``Default`` pipeline on a
    method until a fixpoint is reached.
    """

    def run_pass(
        mt,
        *,
        verify: Annotated[
            bool, Doc("run `verify` before running passes, default is `True`")
        ] = True,
        typeinfer: Annotated[
            bool,
            Doc(
                "run type inference and apply the inferred type to IR, default `False`"
            ),
        ] = False,
        fold: Annotated[bool, Doc("run folding passes, default is `True`")] = True,
        aggressive: Annotated[
            bool, Doc("run aggressive folding passes if `fold=True`")
        ] = False,
        no_raise: Annotated[
            bool, Doc("do not raise exception during analysis, default is `True`")
        ] = True,
    ) -> None:
        # Every option is forwarded unchanged to the Default pipeline, which is
        # then iterated to a fixpoint on ``mt``.
        Default(
            self,
            verify=verify,
            fold=fold,
            aggressive=aggressive,
            typeinfer=typeinfer,
            no_raise=no_raise,
        ).fixpoint(mt)

    return run_pass
@@ -0,0 +1 @@
1
+ from . import address_impl as address_impl
@@ -0,0 +1,40 @@
1
+ from kirin import interp
2
+ from kirin.analysis import ForwardFrame
3
+
4
+ from bloqade.analysis.address.lattice import (
5
+ Address,
6
+ AddressQubit,
7
+ )
8
+ from bloqade.analysis.address.analysis import AddressAnalysis
9
+
10
+ from .. import stmts
11
+ from .._dialect import dialect
12
+
13
+ # Address lattice elements we can work with:
14
+ ## NotQubit (bottom), AnyAddress (top)
15
+
16
+ ## AddressTuple -> data: tuple[Address, ...]
17
+ ### Recursive type, could contain itself or other variants
18
+ ### This pops up in cases where you can have an IList/Tuple
19
+ ### That contains elements that could be other Address types
20
+
21
+ ## AddressReg -> data: Sequence[int]
22
+ ### specific to creation of a register of qubits
23
+
24
+ ## AddressQubit -> data: int
25
+ ### Base qubit address type
26
+
27
+
28
@dialect.register(key="qubit.address")
class SquinQubitMethodTable(interp.MethodTable):
    """Address-analysis rules for statements of the ``qubit`` dialect."""

    @interp.impl(stmts.New)
    def new_qubit(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: stmts.New,
    ):
        """Give a freshly allocated qubit the analysis' next free address.

        The current value of ``interp_.next_address`` becomes the qubit's
        address. The counter is then advanced by one so the next ``New``
        statement receives a different address.
        """
        allocated = interp_.next_address
        interp_.next_address = allocated + 1
        return (AddressQubit(allocated),)
@@ -0,0 +1,2 @@
1
+ from . import simple as simple, broadcast as broadcast
2
+ from ._new import new as new, qalloc as qalloc
@@ -0,0 +1,34 @@
1
+ from typing import Any
2
+
3
+ from kirin.dialects import ilist
4
+
5
+ from .. import _interface as qubit
6
+ from .._prelude import kernel
7
+
8
+
9
@kernel(typeinfer=True)
def new() -> qubit.Qubit:
    """Allocate one fresh qubit.

    Returns:
        (Qubit): The qubit that was just allocated.
    """
    fresh = qubit.new()
    return fresh
17
+
18
+
19
# NOTE: this is a special case, that doesn't use the usual simple / broadcast semantics.
@kernel(typeinfer=True)
def qalloc(n_qubits: int) -> ilist.IList[qubit.Qubit, Any]:
    """Allocate a list of fresh qubits.

    Args:
        n_qubits (int): How many qubits to allocate.

    Returns:
        (ilist.IList[Qubit, n_qubits]): The newly allocated qubits.
    """

    def _alloc_one(index: int) -> qubit.Qubit:
        # ``index`` only drives the map over ``range``; each call allocates anew.
        return qubit.new()

    return ilist.map(_alloc_one, ilist.range(n_qubits))
@@ -0,0 +1,62 @@
1
+ from typing import Any, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import Qubit, MeasurementResult
6
+
7
+ from .. import _interface as _qubit
8
+ from .._prelude import kernel
9
+
10
N = TypeVar("N", bound=int)


@kernel
def reset(qubits: ilist.IList[Qubit, Any]) -> None:
    """Reset every qubit of ``qubits`` to the zero state.

    Args:
        qubits (IList[Qubit, Any]): The qubits to reset.
    """
    _qubit.reset(qubits)
22
+
23
+
24
@kernel
def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]:
    """Measure each qubit of ``qubits``.

    Args:
        qubits (IList[Qubit, N]): The qubits to measure.

    Returns:
        IList[MeasurementResult, N]: One result per input qubit. A
            MeasurementResult can encode 0, 1, or atom loss.
    """
    outcomes = _qubit.measure(qubits)
    return outcomes
36
+
37
+
38
@kernel
def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]:
    """Look up the global, unique ID of every qubit in ``qubits``.

    Args:
        qubits (IList[Qubit, N]): The qubits whose IDs are requested.

    Returns:
        qubit_ids (IList[int, N]): The global, unique ID of each qubit.
    """
    ids = _qubit.get_qubit_id(qubits)
    return ids
49
+
50
+
51
@kernel
def get_measurement_id(
    measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[int, N]:
    """Look up the global, unique ID of every measurement result in the list.

    Args:
        measurements (IList[MeasurementResult, N]): Previously obtained
            measurement results whose IDs are requested.

    Returns:
        measurement_ids (IList[int, N]): The global, unique ID of each result.
    """
    ids = _qubit.get_measurement_id(measurements)
    return ids
@@ -0,0 +1,59 @@
1
+ from kirin.dialects import ilist
2
+
3
+ from bloqade.types import Qubit, MeasurementResult
4
+
5
+ from . import broadcast
6
+ from .._prelude import kernel
7
+
8
+
9
@kernel
def reset(qubit: Qubit) -> None:
    """
    Reset a qubit to the zero state.

    Args:
        qubit (Qubit): The qubit to reset.
    """
    # Delegate to the broadcast version with a one-element list; like
    # ``broadcast.reset`` itself, this kernel returns nothing.
    broadcast.reset(ilist.IList([qubit]))
18
+
19
+
20
@kernel
def measure(qubit: Qubit) -> MeasurementResult:
    """Measure a single qubit.

    Args:
        qubit (Qubit): The qubit to measure.

    Returns:
        MeasurementResult: The outcome, which can encode 0, 1, or a lost atom.
    """
    # Wrap the qubit in a one-element list, measure it, and unwrap the result.
    wrapped = ilist.IList([qubit])
    outcomes = broadcast.measure(wrapped)
    return outcomes[0]
33
+
34
+
35
@kernel
def get_qubit_id(qubit: Qubit) -> int:
    """Look up the global, unique ID of one qubit.

    Args:
        qubit (Qubit): The qubit whose ID is requested.

    Returns:
        qubit_id (int): The global, unique ID of the qubit.
    """
    wrapped = ilist.IList([qubit])
    ids = broadcast.get_qubit_id(wrapped)
    return ids[0]
47
+
48
+
49
@kernel
def get_measurement_id(measurement: MeasurementResult) -> int:
    """Look up the global, unique ID of one measurement result.

    Args:
        measurement (MeasurementResult): A previously obtained measurement
            result whose ID is requested.

    Returns:
        measurement_id (int): The global, unique ID of the measurement.
    """
    wrapped = ilist.IList([measurement])
    ids = broadcast.get_measurement_id(wrapped)
    return ids[0]
bloqade/qubit/stmts.py ADDED
@@ -0,0 +1,60 @@
1
+ from kirin import ir, types, interp, lowering
2
+ from kirin.decl import info, statement
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import QubitType, MeasurementResultType
6
+
7
+ from ._dialect import dialect
8
+
9
+
10
@statement(dialect=dialect)
class New(ir.Statement):
    """Allocate a qubit; takes no arguments and yields one ``QubitType`` result."""

    traits = frozenset({lowering.FromPythonCall()})
    result: ir.ResultValue = info.result(QubitType)
14
+
15
+
16
Len = types.TypeVar("Len", bound=types.Int)


@statement(dialect=dialect)
class Measure(ir.Statement):
    """Measure a list of qubits, producing a list of measurement results."""

    traits = frozenset({lowering.FromPythonCall()})
    # ``Len`` appears in both the argument and result types, so the declared
    # output list has the same length parameter as the input list.
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
    result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType, Len])
24
+
25
+
26
@statement(dialect=dialect)
class QubitId(ir.Statement):
    """Map a list of qubits to a same-length list of integer IDs (marked pure)."""

    traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
    result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len])
31
+
32
+
33
@statement(dialect=dialect)
class MeasurementId(ir.Statement):
    """Map measurement results to a same-length list of integer IDs (marked pure)."""

    traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
    measurements: ir.SSAValue = info.argument(
        ilist.IListType[MeasurementResultType, Len]
    )
    result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len])
40
+
41
+
42
@statement(dialect=dialect)
class Reset(ir.Statement):
    """Reset a list of qubits of any length; produces no result."""

    traits = frozenset({lowering.FromPythonCall()})
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
46
+
47
+
48
# TODO: investigate why this is needed to get type inference to be correct.
@dialect.register(key="typeinfer")
class __TypeInfer(interp.MethodTable):
    """Type-inference rule for ``Measure`` that propagates the input length."""

    @interp.impl(Measure)
    def measure_list(self, _interp, frame: interp.AbstractFrame, stmt: Measure):
        # When the inferred input type is generic (e.g. IList[Qubit, L]), reuse
        # its second type parameter as the result length; otherwise the length
        # is unknown.
        input_type = frame.get(stmt.qubits)
        len_type = (
            input_type.vars[1] if isinstance(input_type, types.Generic) else types.Any
        )
        return (ilist.IListType[MeasurementResultType, len_type],)
@@ -38,6 +38,7 @@ class Fold(Pass):
38
38
  InlineGetField(),
39
39
  InlineGetItem(),
40
40
  ilist.rewrite.InlineGetItem(),
41
+ ilist.rewrite.FlattenAdd(),
41
42
  ilist.rewrite.HintLen(),
42
43
  )
43
44
  result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
@@ -68,7 +69,7 @@ class AggressiveUnroll(Pass):
68
69
  .rewrite(mt.code)
69
70
  .join(result)
70
71
  )
71
- result = self.typeinfer.unsafe_run(mt).join(result)
72
+ self.typeinfer.unsafe_run(mt)
72
73
  result = self.fold.unsafe_run(mt).join(result)
73
74
  result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
74
75
  result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
bloqade/squin/__init__.py CHANGED
@@ -1,24 +1,51 @@
1
1
  from . import (
2
- op as op,
3
- wire as wire,
2
+ gate as gate,
4
3
  noise as noise,
5
- qubit as qubit,
6
4
  analysis as analysis,
7
- lowering as lowering,
8
- _typeinfer as _typeinfer,
9
5
  )
10
- from .groups import wired as wired, kernel as kernel
6
+ from .. import qubit as qubit
7
+ from ..qubit import (
8
+ reset as reset,
9
+ qalloc as qalloc,
10
+ measure as measure,
11
+ get_qubit_id as get_qubit_id,
12
+ get_measurement_id as get_measurement_id,
13
+ )
14
+ from .groups import kernel as kernel
15
+ from .stdlib.simple import (
16
+ h as h,
17
+ s as s,
18
+ t as t,
19
+ x as x,
20
+ y as y,
21
+ z as z,
22
+ cx as cx,
23
+ cy as cy,
24
+ cz as cz,
25
+ rx as rx,
26
+ ry as ry,
27
+ rz as rz,
28
+ u3 as u3,
29
+ s_adj as s_adj,
30
+ shift as shift,
31
+ t_adj as t_adj,
32
+ sqrt_x as sqrt_x,
33
+ sqrt_y as sqrt_y,
34
+ sqrt_z as sqrt_z,
35
+ bit_flip as bit_flip,
36
+ depolarize as depolarize,
37
+ qubit_loss as qubit_loss,
38
+ sqrt_x_adj as sqrt_x_adj,
39
+ sqrt_y_adj as sqrt_y_adj,
40
+ sqrt_z_adj as sqrt_z_adj,
41
+ depolarize2 as depolarize2,
42
+ correlated_qubit_loss as correlated_qubit_loss,
43
+ two_qubit_pauli_channel as two_qubit_pauli_channel,
44
+ single_qubit_pauli_channel as single_qubit_pauli_channel,
45
+ )
11
46
 
12
47
  # NOTE: it's important to keep these imports here since they import squin.kernel
13
48
  # we skip isort here
14
- from . import parallel as parallel # isort: skip
15
- from .stdlib import gate as gate, channel as channel # isort: skip
16
-
17
- try:
18
- # NOTE: make sure optional cirq dependency is installed
19
- import cirq as cirq_package # noqa: F401
20
- except ImportError:
21
- pass
22
- else:
23
- from . import cirq as cirq
24
- from .cirq import load_circuit as load_circuit
49
+ from .stdlib import ( # isort: skip
50
+ broadcast as broadcast,
51
+ )
@@ -1 +0,0 @@
1
- from . import address_impl as address_impl
@@ -210,8 +210,8 @@ class DagScheduleAnalysis(Forward[GateSchedule]):
210
210
  if old_stmt is not None:
211
211
  self.stmt_dag.add_edge(old_stmt, stmt)
212
212
  self.use_def[idx] = stmt
213
- elif isinstance(addr, address.AddressTuple):
214
- for sub_addr in addr.data:
213
+ elif isinstance(addr, address.AddressReg):
214
+ for sub_addr in addr.qubits:
215
215
  self._update_dag(stmt, sub_addr)
216
216
 
217
217
  def update_dag(self, stmt: ir.Statement, args: Sequence[ir.SSAValue]):
@@ -0,0 +1,2 @@
1
+ from . import stmts as stmts
2
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# The ``squin.gate`` dialect; gate statements register themselves on it.
dialect = ir.Dialect(name="squin.gate")
@@ -0,0 +1,98 @@
1
+ from typing import Any, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+ from kirin.lowering import wraps
5
+
6
+ from bloqade.types import Qubit
7
+
8
+ from .stmts import (
9
+ CX,
10
+ CY,
11
+ CZ,
12
+ U3,
13
+ H,
14
+ S,
15
+ T,
16
+ X,
17
+ Y,
18
+ Z,
19
+ Rx,
20
+ Ry,
21
+ Rz,
22
+ SqrtX,
23
+ SqrtY,
24
+ )
25
+
26
+
27
@wraps(X)
def x(qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``X`` statement on the given list of qubits."""


@wraps(Y)
def y(qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Y`` statement on the given list of qubits."""


@wraps(Z)
def z(qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Z`` statement on the given list of qubits."""


@wraps(H)
def h(qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``H`` statement on the given list of qubits."""
41
+
42
+
43
@wraps(T)
def t(qubits: ilist.IList[Qubit, Any], *, adjoint: bool) -> None:
    """Lowered to the ``T`` statement; ``adjoint`` sets its adjoint attribute."""


@wraps(S)
def s(qubits: ilist.IList[Qubit, Any], *, adjoint: bool) -> None:
    """Lowered to the ``S`` statement; ``adjoint`` sets its adjoint attribute."""


@wraps(SqrtX)
def sqrt_x(qubits: ilist.IList[Qubit, Any], *, adjoint: bool) -> None:
    """Lowered to the ``SqrtX`` statement; ``adjoint`` sets its adjoint attribute."""


@wraps(SqrtY)
def sqrt_y(qubits: ilist.IList[Qubit, Any], *, adjoint: bool) -> None:
    """Lowered to the ``SqrtY`` statement; ``adjoint`` sets its adjoint attribute."""
57
+
58
+
59
@wraps(Rx)
def rx(angle: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Rx`` statement with ``angle`` applied to ``qubits``."""


@wraps(Ry)
def ry(angle: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Ry`` statement with ``angle`` applied to ``qubits``."""


@wraps(Rz)
def rz(angle: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Lowered to the ``Rz`` statement with ``angle`` applied to ``qubits``."""
69
+
70
+
71
Len = TypeVar("Len", bound=int)


@wraps(CX)
def cx(
    controls: ilist.IList[Qubit, Len],
    targets: ilist.IList[Qubit, Len],
) -> None:
    """Lowered to ``CX``; ``controls`` and ``targets`` share the length ``Len``."""


@wraps(CY)
def cy(
    controls: ilist.IList[Qubit, Len],
    targets: ilist.IList[Qubit, Len],
) -> None:
    """Lowered to ``CY``; ``controls`` and ``targets`` share the length ``Len``."""


@wraps(CZ)
def cz(
    controls: ilist.IList[Qubit, Len],
    targets: ilist.IList[Qubit, Len],
) -> None:
    """Lowered to ``CZ``; ``controls`` and ``targets`` share the length ``Len``."""
93
+
94
+
95
@wraps(U3)
def u3(
    theta: float, phi: float, lam: float, qubits: ilist.IList[Qubit, Any]
) -> None:
    """Lowered to the ``U3`` statement with angles ``theta``, ``phi``, ``lam``."""
@@ -0,0 +1,119 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import QubitType
6
+
7
+ from ._dialect import dialect
8
+
9
+
10
@statement
class SingleQubitGate(ir.Statement):
    """Base statement whose only argument is a list of qubits.

    Not registered to a dialect itself; the concrete subclasses below are.
    """

    traits = frozenset({lowering.FromPythonCall()})
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


@statement(dialect=dialect)
class X(SingleQubitGate):
    """X gate on a list of qubits."""


@statement(dialect=dialect)
class Y(SingleQubitGate):
    """Y gate on a list of qubits."""


@statement(dialect=dialect)
class Z(SingleQubitGate):
    """Z gate on a list of qubits."""


@statement(dialect=dialect)
class H(SingleQubitGate):
    """H gate on a list of qubits."""
34
+
35
+
36
@statement
class SingleQubitNonHermitianGate(SingleQubitGate):
    """Single-qubit-list gate carrying an ``adjoint`` flag (default ``False``)."""

    adjoint: bool = info.attribute(default=False)


@statement(dialect=dialect)
class T(SingleQubitNonHermitianGate):
    """T gate (or its adjoint) on a list of qubits."""


@statement(dialect=dialect)
class S(SingleQubitNonHermitianGate):
    """S gate (or its adjoint) on a list of qubits."""


@statement(dialect=dialect)
class SqrtX(SingleQubitNonHermitianGate):
    """SqrtX gate (or its adjoint) on a list of qubits."""


@statement(dialect=dialect)
class SqrtY(SingleQubitNonHermitianGate):
    """SqrtY gate (or its adjoint) on a list of qubits."""
59
+
60
+
61
@statement
class RotationGate(ir.Statement):
    """Base statement taking a float ``angle`` followed by a list of qubits."""

    # NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
    traits = frozenset({lowering.FromPythonCall()})
    angle: ir.SSAValue = info.argument(types.Float)
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


@statement(dialect=dialect)
class Rx(RotationGate):
    """Rx rotation by ``angle`` on a list of qubits."""


@statement(dialect=dialect)
class Ry(RotationGate):
    """Ry rotation by ``angle`` on a list of qubits."""


@statement(dialect=dialect)
class Rz(RotationGate):
    """Rz rotation by ``angle`` on a list of qubits."""
82
+
83
+
84
N = types.TypeVar("N", bound=types.Int)


@statement
class ControlledGate(ir.Statement):
    """Base statement for two-qubit controlled gates.

    ``controls`` and ``targets`` are typed with the same length variable ``N``.
    """

    traits = frozenset({lowering.FromPythonCall()})
    controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
    targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])


@statement(dialect=dialect)
class CX(ControlledGate):
    """Controlled-X between paired ``controls`` and ``targets``."""

    name = "cx"


@statement(dialect=dialect)
class CY(ControlledGate):
    """Controlled-Y between paired ``controls`` and ``targets``."""

    name = "cy"


@statement(dialect=dialect)
class CZ(ControlledGate):
    """Controlled-Z between paired ``controls`` and ``targets``."""

    name = "cz"
110
+
111
+
112
@statement(dialect=dialect)
class U3(ir.Statement):
    """U3 gate with angles ``theta``, ``phi``, ``lam`` on a list of qubits."""

    # NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
    traits = frozenset({lowering.FromPythonCall()})
    theta: ir.SSAValue = info.argument(types.Float)
    phi: ir.SSAValue = info.argument(types.Float)
    lam: ir.SSAValue = info.argument(types.Float)
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
bloqade/squin/groups.py CHANGED
@@ -1,31 +1,24 @@
1
1
  from kirin import ir, passes
2
2
  from kirin.prelude import structural_no_opt
3
- from kirin.rewrite import Walk, Chain
4
3
  from kirin.dialects import debug, ilist
5
4
 
6
- from . import op, wire, noise, qubit
7
- from .op.rewrite import PyMultToSquinMult
8
- from .rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule
5
+ from . import gate, noise
6
+ from .. import qubit
9
7
 
10
8
 
11
- @ir.dialect_group(structural_no_opt.union([op, qubit, noise, debug]))
9
+ @ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug]))
12
10
  def kernel(self):
13
11
  fold_pass = passes.Fold(self)
14
12
  typeinfer_pass = passes.TypeInfer(self)
15
13
  ilist_desugar_pass = ilist.IListDesugar(self)
16
- desugar_pass = Walk(Chain(MeasureDesugarRule(), ApplyDesugarRule()))
17
- py_mult_to_mult_pass = PyMultToSquinMult(self)
18
14
 
19
15
  def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
20
16
  method.verify()
21
17
  if fold:
22
18
  fold_pass.fixpoint(method)
23
19
 
24
- py_mult_to_mult_pass(method)
25
-
26
20
  if typeinfer:
27
- typeinfer_pass(method)
28
- desugar_pass.rewrite(method.code)
21
+ typeinfer_pass(method) # infer types before desugaring
29
22
 
30
23
  ilist_desugar_pass(method)
31
24
 
@@ -34,13 +27,3 @@ def kernel(self):
34
27
  method.verify_type()
35
28
 
36
29
  return run_pass
37
-
38
-
39
- @ir.dialect_group(structural_no_opt.union([op, wire, noise]))
40
- def wired(self):
41
- py_mult_to_mult_pass = PyMultToSquinMult(self)
42
-
43
- def run_pass(method):
44
- py_mult_to_mult_pass(method)
45
-
46
- return run_pass