bloqade-circuit 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (61)
  1. bloqade/analysis/address/impls.py +21 -68
  2. bloqade/analysis/measure_id/__init__.py +2 -0
  3. bloqade/analysis/measure_id/analysis.py +45 -0
  4. bloqade/analysis/measure_id/impls.py +155 -0
  5. bloqade/analysis/measure_id/lattice.py +82 -0
  6. bloqade/cirq_utils/__init__.py +7 -0
  7. bloqade/cirq_utils/lineprog.py +295 -0
  8. bloqade/cirq_utils/parallelize.py +400 -0
  9. bloqade/pyqrack/squin/op.py +7 -2
  10. bloqade/pyqrack/squin/runtime.py +4 -2
  11. bloqade/qasm2/dialects/expr/stmts.py +2 -20
  12. bloqade/qasm2/parse/lowering.py +1 -0
  13. bloqade/qasm2/passes/parallel.py +18 -0
  14. bloqade/qasm2/passes/unroll_if.py +9 -2
  15. bloqade/qasm2/rewrite/__init__.py +1 -0
  16. bloqade/qasm2/rewrite/parallel_to_glob.py +82 -0
  17. bloqade/rewrite/__init__.py +0 -0
  18. bloqade/rewrite/passes/__init__.py +1 -0
  19. bloqade/rewrite/passes/canonicalize_ilist.py +28 -0
  20. bloqade/rewrite/rules/__init__.py +1 -0
  21. bloqade/rewrite/rules/flatten_ilist.py +51 -0
  22. bloqade/rewrite/rules/inline_getitem_ilist.py +31 -0
  23. bloqade/{qasm2/rewrite → rewrite/rules}/split_ifs.py +15 -8
  24. bloqade/squin/__init__.py +2 -0
  25. bloqade/squin/_typeinfer.py +20 -0
  26. bloqade/squin/analysis/__init__.py +1 -0
  27. bloqade/squin/analysis/address_impl.py +71 -0
  28. bloqade/squin/analysis/nsites/impls.py +6 -1
  29. bloqade/squin/cirq/lowering.py +19 -6
  30. bloqade/squin/noise/stmts.py +1 -1
  31. bloqade/squin/op/__init__.py +1 -0
  32. bloqade/squin/op/_wrapper.py +4 -0
  33. bloqade/squin/op/stmts.py +20 -2
  34. bloqade/squin/qubit.py +8 -5
  35. bloqade/squin/rewrite/__init__.py +1 -0
  36. bloqade/squin/rewrite/canonicalize.py +60 -0
  37. bloqade/squin/rewrite/desugar.py +52 -5
  38. bloqade/squin/types.py +8 -0
  39. bloqade/squin/wire.py +91 -5
  40. bloqade/stim/__init__.py +1 -0
  41. bloqade/stim/_wrappers.py +4 -0
  42. bloqade/stim/dialects/auxiliary/interp.py +0 -10
  43. bloqade/stim/dialects/auxiliary/stmts/annotate.py +1 -1
  44. bloqade/stim/dialects/noise/emit.py +1 -0
  45. bloqade/stim/dialects/noise/stmts.py +5 -0
  46. bloqade/stim/passes/__init__.py +1 -1
  47. bloqade/stim/passes/simplify_ifs.py +32 -0
  48. bloqade/stim/passes/squin_to_stim.py +109 -26
  49. bloqade/stim/rewrite/__init__.py +1 -0
  50. bloqade/stim/rewrite/ifs_to_stim.py +203 -0
  51. bloqade/stim/rewrite/qubit_to_stim.py +13 -6
  52. bloqade/stim/rewrite/squin_measure.py +68 -5
  53. bloqade/stim/rewrite/squin_noise.py +120 -0
  54. bloqade/stim/rewrite/util.py +40 -9
  55. bloqade/stim/rewrite/wire_to_stim.py +8 -3
  56. bloqade/stim/upstream/__init__.py +1 -0
  57. bloqade/stim/upstream/from_squin.py +10 -0
  58. {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/METADATA +4 -2
  59. {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/RECORD +61 -38
  60. {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/WHEEL +0 -0
  61. {bloqade_circuit-0.4.5.dist-info → bloqade_circuit-0.5.1.dist-info}/licenses/LICENSE +0 -0
@@ -96,10 +96,15 @@ class PyQrackMethods(interp.MethodTable):
96
96
  return (PhaseOpRuntime(theta, global_=global_),)
97
97
 
98
98
  @interp.impl(op.stmts.Reset)
99
+ @interp.impl(op.stmts.ResetToOne)
99
100
  def reset(
100
- self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Reset
101
+ self,
102
+ interp: PyQrackInterpreter,
103
+ frame: interp.Frame,
104
+ stmt: op.stmts.Reset | op.stmts.ResetToOne,
101
105
  ) -> tuple[OperatorRuntimeABC]:
102
- return (ResetRuntime(),)
106
+ target_state = isinstance(stmt, op.stmts.ResetToOne)
107
+ return (ResetRuntime(target_state=target_state),)
103
108
 
104
109
  @interp.impl(op.stmts.X)
105
110
  @interp.impl(op.stmts.Y)
@@ -43,7 +43,9 @@ class OperatorRuntimeABC:
43
43
 
44
44
  @dataclass(frozen=True)
45
45
  class ResetRuntime(OperatorRuntimeABC):
46
- """Reset the qubit to |0> state"""
46
+ """Reset the qubit to the target state"""
47
+
48
+ target_state: bool
47
49
 
48
50
  @property
49
51
  def n_sites(self) -> int:
@@ -55,7 +57,7 @@ class ResetRuntime(OperatorRuntimeABC):
55
57
  continue
56
58
 
57
59
  res: bool = qubit.sim_reg.m(qubit.addr)
58
- if res:
60
+ if res != self.target_state:
59
61
  qubit.sim_reg.x(qubit.addr)
60
62
 
61
63
 
@@ -1,34 +1,16 @@
1
1
  from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
+ from kirin.dialects import func
3
4
  from kirin.print.printer import Printer
4
- from kirin.dialects.func.attrs import Signature
5
5
 
6
6
  from ._dialect import dialect
7
7
 
8
8
 
9
- class GateFuncOpCallableInterface(ir.CallableStmtInterface["GateFunction"]):
10
-
11
- @classmethod
12
- def get_callable_region(cls, stmt: "GateFunction") -> ir.Region:
13
- return stmt.body
14
-
15
-
16
9
  @statement(dialect=dialect)
17
- class GateFunction(ir.Statement):
10
+ class GateFunction(func.Function):
18
11
  """Special Function for qasm2 gate subroutine."""
19
12
 
20
13
  name = "gate.func"
21
- traits = frozenset(
22
- {
23
- ir.IsolatedFromAbove(),
24
- ir.SymbolOpInterface(),
25
- ir.HasSignature(),
26
- GateFuncOpCallableInterface(),
27
- }
28
- )
29
- sym_name: str = info.attribute()
30
- signature: Signature = info.attribute()
31
- body: ir.Region = info.region(multi=True)
32
14
 
33
15
  def print_impl(self, printer: Printer) -> None:
34
16
  with printer.rich(style="red"):
@@ -36,6 +36,7 @@ class QASM2(lowering.LoweringABC[ast.Node]):
36
36
  file=file,
37
37
  lineno_offset=lineno_offset,
38
38
  col_offset=col_offset,
39
+ compactify=compactify,
39
40
  )
40
41
 
41
42
  return frame.curr_region
@@ -26,6 +26,7 @@ from bloqade.qasm2.rewrite import (
26
26
  ParallelToUOpRule,
27
27
  RaiseRegisterRule,
28
28
  UOpToParallelRule,
29
+ ParallelToGlobalRule,
29
30
  SimpleOptimalMergePolicy,
30
31
  RydbergGateSetRewriteRule,
31
32
  )
@@ -183,3 +184,20 @@ class UOpToParallel(Pass):
183
184
  CommonSubexpressionElimination(),
184
185
  )
185
186
  return Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
187
+
188
+
189
@dataclass
class ParallelToGlobal(Pass):
    """Rewrite ``parallel.UGate`` statements into ``glob.UGate`` where possible.

    Address analysis runs once per method. Its results drive a single walk of
    ``ParallelToGlobalRule``, and dead-code elimination then removes whatever
    the rewrite left unused.
    """

    def generate_rule(self, mt: ir.Method) -> ParallelToGlobalRule:
        """Run address analysis on ``mt`` and build the rewrite rule from its frame."""
        frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
        return ParallelToGlobalRule(frame.entries)

    def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
        """Apply the parallel-to-global rewrite, then dead-code elimination."""
        rewrite_result = Walk(self.generate_rule(mt)).rewrite(mt.code)
        dce_result = Walk(DeadCodeElimination()).rewrite(mt.code)
        return dce_result.join(rewrite_result)
@@ -7,15 +7,22 @@ from kirin.rewrite import (
7
7
  ConstantFold,
8
8
  CommonSubexpressionElimination,
9
9
  )
10
+ from kirin.dialects import scf, func
10
11
 
11
- from ..rewrite.split_ifs import LiftThenBody, SplitIfStmts
12
+ from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
13
+
14
+ from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
15
+ from ..dialects.core.stmts import Reset, Measure
16
+
17
# Statement types that may remain inside the *then* body of an ``scf.IfElse``
# so that the result is valid QASM2.
AllowedThenType = (SingleQubitGate, TwoQubitCtrlGate, Measure, Reset)
# Statement types LiftThenBody must leave in place: the allowed statements
# above plus the yield/return/invoke statements of the body.
DontLiftType = (*AllowedThenType, scf.Yield, func.Return, func.Invoke)
12
19
 
13
20
 
14
21
  class UnrollIfs(Pass):
15
22
  """This pass lifts statements that are not UOP out of the if body and then splits whatever is left into multiple if statements so you obtain valid QASM2"""
16
23
 
17
24
  def unsafe_run(self, mt: ir.Method):
18
- result = Walk(LiftThenBody()).rewrite(mt.code)
25
+ result = Walk(LiftThenBody(exclude_stmts=DontLiftType)).rewrite(mt.code)
19
26
  result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)
20
27
  result = (
21
28
  Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
@@ -11,5 +11,6 @@ from .uop_to_parallel import (
11
11
  SimpleGreedyMergePolicy as SimpleGreedyMergePolicy,
12
12
  SimpleOptimalMergePolicy as SimpleOptimalMergePolicy,
13
13
  )
14
+ from .parallel_to_glob import ParallelToGlobalRule as ParallelToGlobalRule
14
15
  from .noise.remove_noise import RemoveNoisePass as RemoveNoisePass
15
16
  from .noise.heuristic_noise import NoiseRewriteRule as NoiseRewriteRule
@@ -0,0 +1,82 @@
1
+ from typing import Dict
2
+ from dataclasses import dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.rewrite import abc
6
+ from kirin.analysis import const
7
+ from kirin.dialects import ilist
8
+
9
+ from bloqade.analysis import address
10
+
11
+ from ..dialects import core, glob, parallel
12
+
13
+
14
@dataclass
class ParallelToGlobalRule(abc.RewriteRule):
    """Rewrite a ``parallel.UGate`` into a ``glob.UGate`` when it covers a whole register.

    A parallel U gate is rewritten when its qubit arguments either

    * are resolved by the address analysis to an ``AddressReg``, or
    * come from an ``ilist.New`` of ``QRegGet`` statements that all read from the
      same ``QRegNew`` register, whose constant integer indices are exactly
      ``0, 1, ..., n_qubits - 1``, each used once.

    Anything else is left untouched; declining to rewrite is always safe.

    Attributes:
        address_analysis: The ``entries`` of an ``address.AddressAnalysis`` frame,
            mapping SSA values to address lattice elements.
    """

    address_analysis: Dict[ir.SSAValue, address.Address]

    def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
        if not isinstance(node, parallel.UGate):
            return abc.RewriteResult()

        qargs = node.qargs
        qarg_addresses = self.address_analysis.get(qargs, None)

        if isinstance(qarg_addresses, address.AddressReg):
            # NOTE: we only have an AddressReg if it's an entire register, definitely rewrite that
            return self._rewrite_parallel_to_glob(node)

        if not isinstance(qarg_addresses, address.AddressTuple):
            return abc.RewriteResult()

        idxs, qreg = self._find_qreg(qargs.owner, set())

        if qreg is None:
            # NOTE: no unique register found, or some index is not a known constant
            return abc.RewriteResult()

        hint = qreg.n_qubits.hints.get("const")
        if not isinstance(hint, const.Value) or not isinstance(hint.data, int):
            # NOTE: non-constant number of qubits
            return abc.RewriteResult()

        n = hint.data
        # The gate must touch every qubit of the register exactly once.
        # - Comparing the number of qargs with n rules out duplicates such as
        #   [q[0], q[0]] on a 2-qubit register, which a set of indices would hide.
        # - Comparing the indices with range(n) rules out indices outside the register.
        if len(qarg_addresses.data) != n or idxs != set(range(n)):
            return abc.RewriteResult()

        return self._rewrite_parallel_to_glob(node)

    @staticmethod
    def _rewrite_parallel_to_glob(node: parallel.UGate) -> abc.RewriteResult:
        """Replace ``node`` with a ``glob.UGate`` that has the same angles."""
        theta, phi, lam = node.theta, node.phi, node.lam
        # NOTE(review): the qubit-list operand is forwarded unchanged as the first
        # operand of glob.UGate. Confirm that glob.UGate accepts it there; it may
        # expect a list of registers instead.
        global_u = glob.UGate(node.qargs, theta=theta, phi=phi, lam=lam)
        node.replace_by(global_u)
        return abc.RewriteResult(has_done_something=True)

    @staticmethod
    def _find_qreg(
        qargs_owner: ir.Statement | ir.Block, idxs: set
    ) -> tuple[set, core.stmts.QRegNew | None]:
        """Trace qubit arguments back to the register they are read from.

        Adds the constant integer index of every ``QRegGet`` reached to ``idxs``
        (mutated in place and returned).

        Returns:
            A pair ``(idxs, qreg)``. ``qreg`` is the unique ``QRegNew`` those
            ``QRegGet`` statements read from. It is ``None`` if there is no
            single such register or if any index is not a constant integer.
        """
        if isinstance(qargs_owner, core.stmts.QRegGet):
            idx_hint = qargs_owner.idx.hints.get("const")
            if not isinstance(idx_hint, const.Value) or not isinstance(
                idx_hint.data, int
            ):
                # NOTE: a non-constant index could refer to any qubit
                return idxs, None
            idxs.add(idx_hint.data)
            qreg = qargs_owner.reg.owner
            if not isinstance(qreg, core.stmts.QRegNew):
                # NOTE: this could potentially be casted
                qreg = None
            return idxs, qreg

        if isinstance(qargs_owner, ilist.New):
            vals = qargs_owner.values
            if len(vals) == 0:
                return idxs, None

            idxs, first_qreg = ParallelToGlobalRule._find_qreg(vals[0].owner, idxs)
            for val in vals[1:]:
                idxs, qreg = ParallelToGlobalRule._find_qreg(val.owner, idxs)
                # Identity, not equality: two distinct registers must never be
                # treated as the same one.
                if qreg is not first_qreg:
                    return idxs, None

            return idxs, first_qreg

        return idxs, None
File without changes
@@ -0,0 +1 @@
1
+ from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
@@ -0,0 +1,28 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.passes import Pass
5
+ from kirin.rewrite import (
6
+ Walk,
7
+ Chain,
8
+ Fixpoint,
9
+ )
10
+ from kirin.analysis import const
11
+
12
+ from ..rules.flatten_ilist import FlattenAddOpIList
13
+ from ..rules.inline_getitem_ilist import InlineGetItemFromIList
14
+
15
+
16
@dataclass
class CanonicalizeIList(Pass):
    """Simplify code that builds and indexes immutable lists.

    Constant propagation runs once up front. Its results let
    ``InlineGetItemFromIList`` replace constant-index lookups into ``ilist.New``
    values. That rule and ``FlattenAddOpIList`` are then applied together until
    nothing changes.
    """

    def unsafe_run(self, mt: ir.Method):
        frame, _ = const.Propagate(dialects=mt.dialects).run_analysis(mt)
        inline_getitem = Walk(InlineGetItemFromIList(constprop_result=frame.entries))
        flatten_add = Walk(FlattenAddOpIList())
        return Fixpoint(Chain(inline_getitem, flatten_add)).rewrite(mt.code)
@@ -0,0 +1 @@
1
+ from .split_ifs import LiftThenBody as LiftThenBody, SplitIfStmts as SplitIfStmts
@@ -0,0 +1,51 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.dialects import py, ilist
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+
7
+
8
@dataclass
class FlattenAddOpIList(RewriteRule):
    """Fold ``lhs + rhs`` into a single ``ilist.New`` when both operands are IList literals.

    An operand qualifies if it is either
    - produced by ``ilist.New`` (its element SSA values are spliced in), or
    - a ``py.Constant`` holding an empty ``ilist.IList`` (it contributes nothing).

    If either operand is anything else, the statement is left untouched.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, py.binop.Add):
            return RewriteResult()

        lhs_values = self._literal_values(node.lhs)
        if lhs_values is None:
            return RewriteResult()

        rhs_values = self._literal_values(node.rhs)
        if rhs_values is None:
            return RewriteResult()

        node.replace_by(ilist.New(values=lhs_values + rhs_values))
        return RewriteResult(has_done_something=True)

    @staticmethod
    def _literal_values(value: ir.SSAValue) -> tuple[ir.SSAValue, ...] | None:
        """Return the element SSA values of an IList literal, or None if ``value`` is not one."""
        owner = value.owner
        if isinstance(owner, ilist.New):
            return tuple(owner.values)
        if (
            isinstance(owner, py.Constant)
            and isinstance(data := owner.value.unwrap(), ilist.IList)
            and len(data.data) == 0
        ):
            return ()
        return None
@@ -0,0 +1,31 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.analysis import const
5
+ from kirin.dialects import py, ilist
6
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
7
+
8
+
9
@dataclass
class InlineGetItemFromIList(RewriteRule):
    """Replace ``lst[i]`` with the i-th element SSA value of an ``ilist.New``.

    The rewrite applies only when ``lst`` is produced by ``ilist.New`` and
    constant propagation resolved ``i`` to an integer within bounds. Negative
    indices follow Python semantics.

    Attributes:
        constprop_result: The ``entries`` of a ``const.Propagate`` frame, used to
            look up the constant value of the index.
    """

    constprop_result: dict[ir.SSAValue, const.Result]

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, py.indexing.GetItem):
            return RewriteResult()

        if not isinstance(node.obj.owner, ilist.New):
            return RewriteResult()

        if not isinstance(
            index_value := self.constprop_result.get(node.index), const.Value
        ):
            return RewriteResult()

        index = index_value.data
        values = node.obj.owner.values

        # Only a plain integer selects a single element. A constant slice would
        # produce a tuple of SSA values, which cannot replace one result.
        if not isinstance(index, int):
            return RewriteResult()

        # Leave out-of-range accesses alone so the error surfaces when the
        # program runs, instead of crashing this rewrite with an IndexError.
        if not -len(values) <= index < len(values):
            return RewriteResult()

        node.result.replace_by(values[index])

        return RewriteResult(
            has_done_something=True,
        )
@@ -1,18 +1,23 @@
1
+ from dataclasses import field, dataclass
2
+
1
3
  from kirin import ir
2
4
  from kirin.dialects import scf, func
3
5
  from kirin.rewrite.abc import RewriteRule, RewriteResult
4
6
 
5
- from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
6
- from ..dialects.core.stmts import Reset, Measure
7
7
 
8
- # TODO: unify with PR #248
9
- AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset
8
+ @dataclass
9
+ class LiftThenBody(RewriteRule):
10
+ """
11
+ Lifts anything that's not in the `exclude_stmts` in the *then* body
12
+
10
13
 
11
- DontLiftType = AllowedThenType | scf.Yield | func.Return | func.Invoke
14
+ Args:
15
+ exclude_stmts: A tuple of statement types that should not be lifted from the then body.
16
+ Defaults to an empty tuple, meaning all statements are lifted.
12
17
 
18
+ """
13
19
 
14
- class LiftThenBody(RewriteRule):
15
- """Lifts anything that's not a UOP or a yield/return out of the then body"""
20
+ exclude_stmts: tuple[type[ir.Statement], ...] = field(default_factory=tuple)
16
21
 
17
22
  def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
18
23
  if not isinstance(node, scf.IfElse):
@@ -20,7 +25,9 @@ class LiftThenBody(RewriteRule):
20
25
 
21
26
  then_stmts = node.then_body.stmts()
22
27
 
23
- lift_stmts = [stmt for stmt in then_stmts if not isinstance(stmt, DontLiftType)]
28
+ lift_stmts = [
29
+ stmt for stmt in then_stmts if not isinstance(stmt, self.exclude_stmts)
30
+ ]
24
31
 
25
32
  if len(lift_stmts) == 0:
26
33
  return RewriteResult()
bloqade/squin/__init__.py CHANGED
@@ -3,7 +3,9 @@ from . import (
3
3
  wire as wire,
4
4
  noise as noise,
5
5
  qubit as qubit,
6
+ analysis as analysis,
6
7
  lowering as lowering,
8
+ _typeinfer as _typeinfer,
7
9
  )
8
10
  from .groups import wired as wired, kernel as kernel
9
11
 
@@ -0,0 +1,20 @@
1
+ from kirin import types, interp
2
+ from kirin.analysis import TypeInference, const
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade import squin
6
+
7
+
8
@squin.qubit.dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type inference for ``squin.qubit.New``.

    The result is an IList of qubits. Its length parameter is ``types.Literal(n)``
    when ``n_qubits`` carries an integer constant hint, and ``types.Any``
    otherwise.
    """

    @interp.impl(squin.qubit.New)
    def _call(self, interp: TypeInference, frame: interp.Frame, stmt: squin.qubit.New):
        # based on Xiu-zhe (Roger) Luo's get_const_value function
        hint = stmt.n_qubits.hints.get("const")
        length = types.Any
        if isinstance(hint, const.Value) and isinstance(hint.data, int):
            length = types.Literal(hint.data)
        return (ilist.IListType[squin.qubit.QubitType, length],)
@@ -0,0 +1 @@
1
+ from . import address_impl as address_impl
@@ -0,0 +1,71 @@
1
+ from kirin import interp
2
+ from kirin.analysis import ForwardFrame
3
+
4
+ from bloqade.analysis.address.lattice import (
5
+ Address,
6
+ AddressReg,
7
+ AddressWire,
8
+ AddressQubit,
9
+ )
10
+ from bloqade.analysis.address.analysis import AddressAnalysis
11
+
12
+ from .. import wire, qubit
13
+
14
+ # Address lattice elements we can work with:
15
+ ## NotQubit (bottom), AnyAddress (top)
16
+
17
+ ## AddressTuple -> data: tuple[Address, ...]
18
+ ### Recursive type, could contain itself or other variants
19
+ ### This pops up in cases where you can have an IList/Tuple
20
+ ### That contains elements that could be other Address types
21
+
22
+ ## AddressReg -> data: Sequence[int]
23
+ ### specific to creation of a register of qubits
24
+
25
+ ## AddressQubit -> data: int
26
+ ### Base qubit address type
27
+
28
+
29
@wire.dialect.register(key="qubit.address")
class SquinWireMethodTable(interp.MethodTable):
    """Address-analysis rules for the squin ``wire`` dialect."""

    @interp.impl(wire.Unwrap)
    def unwrap(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: wire.Unwrap,
    ):
        """Give the unwrapped wire an ``AddressWire`` that points back to its qubit.

        If the source qubit's address is not an ``AddressQubit``, the result is
        ``Address.top()``.
        """
        source = frame.get(stmt.qubit)
        if not isinstance(source, AddressQubit):
            return (Address.top(),)
        return (AddressWire(origin_qubit=source),)

    @interp.impl(wire.Apply)
    def apply(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: wire.Apply,
    ):
        """Return the addresses of the input wires unchanged as the statement's results."""
        return frame.get_values(stmt.inputs)
55
+
56
+
57
@qubit.dialect.register(key="qubit.address")
class SquinQubitMethodTable(interp.MethodTable):
    """Address-analysis rules for the squin ``qubit`` dialect."""

    # This can be treated like a QRegNew impl
    @interp.impl(qubit.New)
    def new(
        self,
        interp_: AddressAnalysis,
        frame: ForwardFrame[Address],
        stmt: qubit.New,
    ):
        """Allocate a contiguous block of fresh addresses for ``qubit.New``.

        The qubit count is read with ``interp_.get_const_value(int, ...)``. It
        presumably has to be a compile-time constant; confirm against
        ``AddressAnalysis``. The new register covers
        ``[next_address, next_address + n_qubits)``, and the analysis's
        ``next_address`` counter is advanced past it.
        """
        count = interp_.get_const_value(int, stmt.n_qubits)
        start = interp_.next_address
        interp_.next_address = start + count
        return (AddressReg(range(start, start + count)),)
@@ -1,5 +1,5 @@
1
1
  from kirin import interp
2
- from kirin.dialects import scf
2
+ from kirin.dialects import scf, func
3
3
  from kirin.dialects.scf.typeinfer import TypeInfer as ScfTypeInfer
4
4
 
5
5
  from bloqade.squin import op, wire
@@ -85,3 +85,8 @@ class SquinOp(interp.MethodTable):
@scf.dialect.register(key="op.nsites")
class ScfSquinOp(ScfTypeInfer):
    """Reuse scf's type-inference method table for the ``op.nsites`` analysis."""


@func.dialect.register(key="op.nsites")
class FuncSquinOp(func.typeinfer.TypeInfer):
    """Reuse func's type-inference method table for the ``op.nsites`` analysis."""
@@ -368,11 +368,24 @@ class Squin(lowering.LoweringABC[CirqNode]):
368
368
  state: lowering.State[CirqNode],
369
369
  node: cirq.GeneralizedAmplitudeDampingChannel,
370
370
  ):
371
- raise NotImplementedError("TODO: needs a new operator statement")
372
- # p = state.current_frame.push(py.Constant(node.p))
373
- # gamma = state.current_frame.push(py.Constant(node.gamma))
371
+ p = state.current_frame.push(py.Constant(node.p)).result
372
+ gamma = state.current_frame.push(py.Constant(node.gamma)).result
374
373
 
375
- # p1 =
374
+ # NOTE: cirq has a weird convention here: if p == 1, we have AmplitudeDampingChannel,
375
+ # which basically means p is the probability of the environment being in the vacuum state
376
+ prob0 = state.current_frame.push(py.binop.Mult(p, gamma)).result
377
+ one_ = state.current_frame.push(py.Constant(1)).result
378
+ p_minus_1 = state.current_frame.push(py.binop.Sub(one_, p)).result
379
+ prob1 = state.current_frame.push(py.binop.Mult(p_minus_1, gamma)).result
376
380
 
377
- # x = state.current_frame.push(op.stmts.X())
378
- # noise_channel1 = noise.stmts.PauliError(basis=x.result, p=)
381
+ r0 = state.current_frame.push(op.stmts.Reset()).result
382
+ r1 = state.current_frame.push(op.stmts.ResetToOne()).result
383
+
384
+ probs = state.current_frame.push(ilist.New(values=(prob0, prob1))).result
385
+ ops = state.current_frame.push(ilist.New(values=(r0, r1))).result
386
+
387
+ noise_channel = state.current_frame.push(
388
+ noise.stmts.StochasticUnitaryChannel(probabilities=probs, operators=ops)
389
+ )
390
+
391
+ return noise_channel
@@ -10,7 +10,7 @@ from ..op.types import NumOperators
10
10
 
@statement
class NoiseChannel(ir.Statement):
    """Statement class for squin noise channels, declared without a dialect.

    It is pure and produces a single ``OpType`` result. It is presumably the
    common parent of the concrete channel statements (e.g.
    ``StochasticUnitaryChannel``); verify against the subclasses.
    """

    traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
    result: ir.ResultValue = info.result(OpType)
15
15
 
16
16
 
@@ -37,4 +37,5 @@ from ._wrapper import (
37
37
  control as control,
38
38
  identity as identity,
39
39
  pauli_string as pauli_string,
40
+ reset_to_one as reset_to_one,
40
41
  )
@@ -41,6 +41,10 @@ def control(op: types.Op, *, n_controls: int) -> types.Op:
41
41
  def reset() -> types.Op: ...
42
42
 
43
43
 
44
@wraps(stmts.ResetToOne)
def reset_to_one() -> types.Op:
    """Operator that resets a qubit to |1>; lowers to the ``ResetToOne`` statement."""


@wraps(stmts.Identity)
def identity(*, sites: int) -> types.Op:
    """Identity operator; ``sites`` is forwarded to the ``Identity`` statement."""
46
50
 
bloqade/squin/op/stmts.py CHANGED
@@ -98,6 +98,15 @@ class ConstantUnitary(ConstantOp):
98
98
 
99
99
  @statement(dialect=dialect)
100
100
  class U3(PrimitiveOp):
101
+ """
102
+ The rotation operator U3(theta, phi, lam).
103
+ Note that we use the convention from the QASM2 specification, namely
104
+
105
+ $$
106
+ U_3(\theta, \phi, \lambda) = R_z(\phi) R_y(\theta) R_z(\lambda)
107
+ $$
108
+ """
109
+
101
110
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
102
111
  theta: ir.SSAValue = info.argument(types.Float)
103
112
  phi: ir.SSAValue = info.argument(types.Float)
@@ -110,7 +119,7 @@ class PhaseOp(PrimitiveOp):
110
119
  A phase operator.
111
120
 
112
121
  $$
113
- PhaseOp(theta) = e^{i \theta} I
122
+ PhaseOp(\theta) = e^{i \theta} I
114
123
  $$
115
124
  """
116
125
 
@@ -124,7 +133,7 @@ class ShiftOp(PrimitiveOp):
124
133
  A phase shift operator.
125
134
 
126
135
  $$
127
- Shift(theta) = \\begin{bmatrix} 1 & 0 \\\\ 0 & e^{i \\theta} \\end{bmatrix}
136
+ Shift(\theta) = \\begin{bmatrix} 1 & 0 \\\\ 0 & e^{i \\theta} \\end{bmatrix}
128
137
  $$
129
138
  """
130
139
 
@@ -141,6 +150,15 @@ class Reset(PrimitiveOp):
141
150
  traits = frozenset({ir.Pure(), lowering.FromPythonCall(), FixedSites(1)})
142
151
 
143
152
 
153
@statement(dialect=dialect)
class ResetToOne(PrimitiveOp):
    """
    Reset a qubit to the |1> state.

    This is the counterpart of ``Reset``, which targets |0>. It exists mainly to
    express cirq's ``GeneralizedAmplitudeDampingChannel``.
    """

    traits = frozenset({ir.Pure(), lowering.FromPythonCall(), FixedSites(1)})


@statement
class CliffordOp(ConstantUnitary):
    """Subclass of ``ConstantUnitary`` for Clifford operators; adds no fields or traits."""
bloqade/squin/qubit.py CHANGED
@@ -17,6 +17,7 @@ from kirin.lowering import wraps
17
17
  from bloqade.types import Qubit, QubitType
18
18
  from bloqade.squin.op.types import Op, OpType
19
19
 
20
+ from .types import MeasurementResult, MeasurementResultType
20
21
  from .lowering import ApplyAnyCallLowering
21
22
 
22
23
  dialect = ir.Dialect("squin.qubit")
@@ -65,8 +66,8 @@ class MeasureQubit(ir.Statement):
65
66
  name = "measure.qubit"
66
67
 
67
68
  traits = frozenset({lowering.FromPythonCall()})
68
- qubit: ir.SSAValue = info.argument(ilist.IListType[QubitType])
69
- result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
69
+ qubit: ir.SSAValue = info.argument(QubitType)
70
+ result: ir.ResultValue = info.result(MeasurementResultType)
70
71
 
71
72
 
72
73
  @statement(dialect=dialect)
@@ -75,7 +76,7 @@ class MeasureQubitList(ir.Statement):
75
76
 
76
77
  traits = frozenset({lowering.FromPythonCall()})
77
78
  qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
78
- result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
79
+ result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType])
79
80
 
80
81
 
81
82
  # NOTE: no dependent types in Python, so we have to mark it Any...
@@ -131,9 +132,11 @@ def apply(operator: Op, *qubits) -> None: ...
131
132
 
132
133
 
@overload
def measure(input: Qubit) -> MeasurementResult:
    """Measure a single qubit and return its measurement result."""


@overload
def measure(
    input: ilist.IList[Qubit, Any] | list[Qubit],
) -> ilist.IList[MeasurementResult, Any]:
    """Measure a list of qubits and return an IList of measurement results."""
137
140
 
138
141
 
139
142
  @wraps(MeasureAny)
@@ -4,4 +4,5 @@ from .wrap_analysis import (
4
4
  WrapOpSiteAnalysis as WrapOpSiteAnalysis,
5
5
  WrapAddressAnalysis as WrapAddressAnalysis,
6
6
  )
7
+ from .U3_to_clifford import SquinU3ToClifford as SquinU3ToClifford
7
8
  from .remove_dangling_qubits import RemoveDeadRegister as RemoveDeadRegister