bloqade-circuit 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic; see the package's registry page for more details.

Files changed (42)
  1. bloqade/cirq_utils/__init__.py +7 -0
  2. bloqade/cirq_utils/lineprog.py +295 -0
  3. bloqade/cirq_utils/parallelize.py +400 -0
  4. bloqade/pyqrack/squin/op.py +7 -2
  5. bloqade/pyqrack/squin/runtime.py +4 -2
  6. bloqade/qasm2/dialects/expr/stmts.py +2 -20
  7. bloqade/qasm2/parse/lowering.py +1 -0
  8. bloqade/qasm2/passes/parallel.py +18 -0
  9. bloqade/qasm2/rewrite/__init__.py +1 -0
  10. bloqade/qasm2/rewrite/parallel_to_glob.py +82 -0
  11. bloqade/squin/__init__.py +1 -0
  12. bloqade/squin/_typeinfer.py +20 -0
  13. bloqade/squin/analysis/nsites/impls.py +6 -1
  14. bloqade/squin/cirq/__init__.py +74 -9
  15. bloqade/squin/cirq/emit/noise.py +49 -0
  16. bloqade/squin/cirq/emit/runtime.py +9 -1
  17. bloqade/squin/cirq/lowering.py +46 -27
  18. bloqade/squin/noise/_wrapper.py +9 -2
  19. bloqade/squin/noise/rewrite.py +3 -3
  20. bloqade/squin/op/__init__.py +1 -0
  21. bloqade/squin/op/_wrapper.py +4 -0
  22. bloqade/squin/op/stmts.py +20 -2
  23. bloqade/squin/qubit.py +8 -5
  24. bloqade/squin/rewrite/__init__.py +1 -0
  25. bloqade/squin/rewrite/canonicalize.py +60 -0
  26. bloqade/squin/rewrite/desugar.py +52 -5
  27. bloqade/squin/types.py +8 -0
  28. bloqade/squin/wire.py +91 -5
  29. bloqade/stim/__init__.py +1 -0
  30. bloqade/stim/_wrappers.py +4 -0
  31. bloqade/stim/dialects/noise/emit.py +1 -0
  32. bloqade/stim/dialects/noise/stmts.py +5 -0
  33. bloqade/stim/passes/squin_to_stim.py +16 -1
  34. bloqade/stim/rewrite/__init__.py +1 -0
  35. bloqade/stim/rewrite/qubit_to_stim.py +10 -6
  36. bloqade/stim/rewrite/squin_noise.py +120 -0
  37. bloqade/stim/rewrite/util.py +44 -9
  38. bloqade/stim/rewrite/wire_to_stim.py +8 -3
  39. {bloqade_circuit-0.4.4.dist-info → bloqade_circuit-0.5.0.dist-info}/METADATA +4 -2
  40. {bloqade_circuit-0.4.4.dist-info → bloqade_circuit-0.5.0.dist-info}/RECORD +42 -33
  41. {bloqade_circuit-0.4.4.dist-info → bloqade_circuit-0.5.0.dist-info}/WHEEL +0 -0
  42. {bloqade_circuit-0.4.4.dist-info → bloqade_circuit-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -96,10 +96,15 @@ class PyQrackMethods(interp.MethodTable):
96
96
  return (PhaseOpRuntime(theta, global_=global_),)
97
97
 
98
98
@interp.impl(op.stmts.Reset)
@interp.impl(op.stmts.ResetToOne)
def reset(
    self,
    interp: PyQrackInterpreter,
    frame: interp.Frame,
    stmt: op.stmts.Reset | op.stmts.ResetToOne,
) -> tuple[OperatorRuntimeABC]:
    """Build the runtime object for a `Reset` or `ResetToOne` statement.

    The returned `ResetRuntime` carries the target basis state: `True` for
    `ResetToOne` (reset to |1>) and `False` for `Reset` (reset to |0>).
    """
    to_one = isinstance(stmt, op.stmts.ResetToOne)
    runtime = ResetRuntime(target_state=to_one)
    return (runtime,)
103
108
 
104
109
  @interp.impl(op.stmts.X)
105
110
  @interp.impl(op.stmts.Y)
@@ -43,7 +43,9 @@ class OperatorRuntimeABC:
43
43
 
44
44
  @dataclass(frozen=True)
45
45
  class ResetRuntime(OperatorRuntimeABC):
46
- """Reset the qubit to |0> state"""
46
+ """Reset the qubit to the target state"""
47
+
48
+ target_state: bool
47
49
 
48
50
  @property
49
51
  def n_sites(self) -> int:
@@ -55,7 +57,7 @@ class ResetRuntime(OperatorRuntimeABC):
55
57
  continue
56
58
 
57
59
  res: bool = qubit.sim_reg.m(qubit.addr)
58
- if res:
60
+ if res != self.target_state:
59
61
  qubit.sim_reg.x(qubit.addr)
60
62
 
61
63
 
@@ -1,34 +1,16 @@
1
1
  from kirin import ir, types, lowering
2
2
  from kirin.decl import info, statement
3
+ from kirin.dialects import func
3
4
  from kirin.print.printer import Printer
4
- from kirin.dialects.func.attrs import Signature
5
5
 
6
6
  from ._dialect import dialect
7
7
 
8
8
 
9
- class GateFuncOpCallableInterface(ir.CallableStmtInterface["GateFunction"]):
10
-
11
- @classmethod
12
- def get_callable_region(cls, stmt: "GateFunction") -> ir.Region:
13
- return stmt.body
14
-
15
-
16
9
  @statement(dialect=dialect)
17
- class GateFunction(ir.Statement):
10
+ class GateFunction(func.Function):
18
11
  """Special Function for qasm2 gate subroutine."""
19
12
 
20
13
  name = "gate.func"
21
- traits = frozenset(
22
- {
23
- ir.IsolatedFromAbove(),
24
- ir.SymbolOpInterface(),
25
- ir.HasSignature(),
26
- GateFuncOpCallableInterface(),
27
- }
28
- )
29
- sym_name: str = info.attribute()
30
- signature: Signature = info.attribute()
31
- body: ir.Region = info.region(multi=True)
32
14
 
33
15
  def print_impl(self, printer: Printer) -> None:
34
16
  with printer.rich(style="red"):
@@ -36,6 +36,7 @@ class QASM2(lowering.LoweringABC[ast.Node]):
36
36
  file=file,
37
37
  lineno_offset=lineno_offset,
38
38
  col_offset=col_offset,
39
+ compactify=compactify,
39
40
  )
40
41
 
41
42
  return frame.curr_region
@@ -26,6 +26,7 @@ from bloqade.qasm2.rewrite import (
26
26
  ParallelToUOpRule,
27
27
  RaiseRegisterRule,
28
28
  UOpToParallelRule,
29
+ ParallelToGlobalRule,
29
30
  SimpleOptimalMergePolicy,
30
31
  RydbergGateSetRewriteRule,
31
32
  )
@@ -183,3 +184,20 @@ class UOpToParallel(Pass):
183
184
  CommonSubexpressionElimination(),
184
185
  )
185
186
  return Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
187
+
188
+
189
@dataclass
class ParallelToGlobal(Pass):
    """Rewrite `parallel.UGate` statements that cover a whole register into `glob.UGate`.

    The rewrite rule is parameterised by an address analysis of the method; after
    rewriting, dead code elimination removes statements the rewrite left unused.
    """

    def generate_rule(self, mt: ir.Method) -> ParallelToGlobalRule:
        # The rule needs to know which SSA values resolve to which qubits/registers.
        frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
        return ParallelToGlobalRule(frame.entries)

    def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
        rewrite_result = Walk(self.generate_rule(mt)).rewrite(mt.code)
        # Clean up anything made dead by the rewrite.
        cleanup_result = Walk(DeadCodeElimination()).rewrite(mt.code)
        return cleanup_result.join(rewrite_result)
@@ -11,5 +11,6 @@ from .uop_to_parallel import (
11
11
  SimpleGreedyMergePolicy as SimpleGreedyMergePolicy,
12
12
  SimpleOptimalMergePolicy as SimpleOptimalMergePolicy,
13
13
  )
14
+ from .parallel_to_glob import ParallelToGlobalRule as ParallelToGlobalRule
14
15
  from .noise.remove_noise import RemoveNoisePass as RemoveNoisePass
15
16
  from .noise.heuristic_noise import NoiseRewriteRule as NoiseRewriteRule
@@ -0,0 +1,82 @@
1
+ from typing import Dict
2
+ from dataclasses import dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.rewrite import abc
6
+ from kirin.analysis import const
7
+ from kirin.dialects import ilist
8
+
9
+ from bloqade.analysis import address
10
+
11
+ from ..dialects import core, glob, parallel
12
+
13
+
14
@dataclass
class ParallelToGlobalRule(abc.RewriteRule):
    """Rewrite a `parallel.UGate` into a `glob.UGate` when it covers a whole register.

    A parallel U gate is replaced when either
      * the address analysis resolves its `qargs` to an `AddressReg`, or
      * its `qargs` resolve to an `AddressTuple` built (through `ilist.New`) from
        `QRegGet` statements that all read the same `QRegNew`, whose size is a known
        constant, and that reference exactly that many distinct register slots.

    Attributes:
        address_analysis: Map from SSA values to addresses, as produced by
            `address.AddressAnalysis` (the pass passes in `frame.entries`).
    """

    address_analysis: Dict[ir.SSAValue, address.Address]

    def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
        if not isinstance(node, parallel.UGate):
            return abc.RewriteResult()

        qargs = node.qargs
        qarg_addresses = self.address_analysis.get(qargs, None)

        if isinstance(qarg_addresses, address.AddressReg):
            # NOTE: we only have an AddressReg if it's an entire register, definitely rewrite that
            return self._rewrite_parallel_to_glob(node)

        if not isinstance(qarg_addresses, address.AddressTuple):
            return abc.RewriteResult()

        idxs, qreg = self._find_qreg(qargs.owner, set())

        if qreg is None:
            # NOTE: no unique register found, or the same qubit is used more than once
            return abc.RewriteResult()

        if not isinstance(hint := qreg.n_qubits.hints.get("const"), const.Value):
            # NOTE: non-constant number of qubits
            return abc.RewriteResult()

        n = hint.data
        if len(idxs) != n:
            # NOTE: not all qubits of the register are there
            return abc.RewriteResult()

        return self._rewrite_parallel_to_glob(node)

    @staticmethod
    def _rewrite_parallel_to_glob(node: parallel.UGate) -> abc.RewriteResult:
        """Replace `node` in place by a `glob.UGate` with the same angles."""
        theta, phi, lam = node.theta, node.phi, node.lam
        # NOTE(review): the global gate receives the parallel gate's qubit list as-is;
        # confirm this matches the argument type `glob.UGate` expects.
        global_u = glob.UGate(node.qargs, theta=theta, phi=phi, lam=lam)
        node.replace_by(global_u)
        return abc.RewriteResult(has_done_something=True)

    @staticmethod
    def _index_key(idx: ir.SSAValue):
        """Return a key identifying the register slot a `QRegGet` index refers to.

        The constant integer value is used when it is known, so the same slot reached
        through different SSA values is recognised as the same qubit. Otherwise the
        SSA value itself is the key.
        """
        hint = idx.hints.get("const")
        if isinstance(hint, const.Value) and isinstance(hint.data, int):
            return hint.data
        return idx

    @staticmethod
    def _find_qreg(
        qargs_owner: ir.Statement | ir.Block, idxs: set
    ) -> tuple[set, core.stmts.QRegNew | None]:
        """Trace qubit arguments back to a single `QRegNew` and record their slots.

        Args:
            qargs_owner: The statement (or block) that produced the qubit value(s).
            idxs: Accumulator of register-slot keys seen so far (see `_index_key`);
                updated in place.

        Returns:
            The updated key set, and the unique `QRegNew` every qubit comes from. The
            register is `None` if there is no unique register, if the owner is not a
            `QRegGet`/`ilist.New`, or if the same slot appears more than once.
        """
        if isinstance(qargs_owner, core.stmts.QRegGet):
            key = ParallelToGlobalRule._index_key(qargs_owner.idx)
            if key in idxs:
                # NOTE: the same qubit appears twice, so counting distinct slots
                # would not prove that the whole register is covered
                return idxs, None
            idxs.add(key)

            qreg = qargs_owner.reg.owner
            if not isinstance(qreg, core.stmts.QRegNew):
                # NOTE: this could potentially be casted
                qreg = None
            return idxs, qreg

        if isinstance(qargs_owner, ilist.New):
            vals = qargs_owner.values
            if len(vals) == 0:
                return idxs, None

            idxs, first_qreg = ParallelToGlobalRule._find_qreg(vals[0].owner, idxs)
            if first_qreg is None:
                return idxs, None

            for val in vals[1:]:
                idxs, qreg = ParallelToGlobalRule._find_qreg(val.owner, idxs)
                # identity check: all qubits must come from the very same statement
                if qreg is not first_qreg:
                    return idxs, None

            return idxs, first_qreg

        return idxs, None
bloqade/squin/__init__.py CHANGED
@@ -4,6 +4,7 @@ from . import (
4
4
  noise as noise,
5
5
  qubit as qubit,
6
6
  lowering as lowering,
7
+ _typeinfer as _typeinfer,
7
8
  )
8
9
  from .groups import wired as wired, kernel as kernel
9
10
 
@@ -0,0 +1,20 @@
1
+ from kirin import types, interp
2
+ from kirin.analysis import TypeInference, const
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade import squin
6
+
7
+
8
@squin.qubit.dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rules for the squin `qubit` dialect."""

    @interp.impl(squin.qubit.New)
    def _call(self, interp: TypeInference, frame: interp.Frame, stmt: squin.qubit.New):
        """Infer `qubit.New` as a list of qubits.

        The list length is a `Literal` when `n_qubits` carries a constant `int` hint,
        and `Any` otherwise.
        """
        # based on Xiu-zhe (Roger) Luo's get_const_value function
        hint = stmt.n_qubits.hints.get("const")
        length = types.Any
        if isinstance(hint, const.Value) and isinstance(hint.data, int):
            length = types.Literal(hint.data)
        return (ilist.IListType[squin.qubit.QubitType, length],)
@@ -1,5 +1,5 @@
1
1
  from kirin import interp
2
- from kirin.dialects import scf
2
+ from kirin.dialects import scf, func
3
3
  from kirin.dialects.scf.typeinfer import TypeInfer as ScfTypeInfer
4
4
 
5
5
  from bloqade.squin import op, wire
@@ -85,3 +85,8 @@ class SquinOp(interp.MethodTable):
85
85
@scf.dialect.register(key="op.nsites")
class ScfSquinOp(ScfTypeInfer):
    """Reuse scf's type-inference method table for the `op.nsites` analysis."""
88
+
89
+
90
@func.dialect.register(key="op.nsites")
class FuncSquinOp(func.typeinfer.TypeInfer):
    """Reuse func's type-inference method table for the `op.nsites` analysis."""
@@ -9,8 +9,9 @@ from . import lowering as lowering
9
9
  from .. import kernel
10
10
 
11
11
  # NOTE: just to register methods
12
- from .emit import op as op, qubit as qubit
12
+ from .emit import op as op, noise as noise, qubit as qubit
13
13
  from .lowering import Squin
14
+ from ..noise.rewrite import RewriteNoiseStmts
14
15
  from .emit.emit_circuit import EmitCirq
15
16
 
16
17
 
@@ -18,6 +19,9 @@ def load_circuit(
18
19
  circuit: cirq.Circuit,
19
20
  kernel_name: str = "main",
20
21
  dialects: ir.DialectGroup = kernel,
22
+ register_as_argument: bool = False,
23
+ return_register: bool = False,
24
+ register_argument_name: str = "q",
21
25
  globals: dict[str, Any] | None = None,
22
26
  file: str | None = None,
23
27
  lineno_offset: int = 0,
@@ -32,13 +36,23 @@ def load_circuit(
32
36
  Keyword Args:
33
37
  kernel_name (str): The name of the kernel to load. Defaults to "main".
34
38
  dialects (ir.DialectGroup | None): The dialects to use. Defaults to `squin.kernel`.
39
+ register_as_argument (bool): Determine whether the resulting kernel function should accept
40
+ a single `ilist.IList[Qubit, Any]` argument that is a list of qubits used within the
41
+ function. This allows you to compose kernel functions generated from circuits.
42
+ Defaults to `False`.
43
+ return_register (bool): Determine whether the resulting kernel functionr returns a
44
+ single value of type `ilist.IList[Qubit, Any]` that is the list of qubits used
45
+ in the kernel function. Useful when you want to compose multiple kernel functions
46
+ generated from circuits. Defaults to `False`.
47
+ register_argument_name (str): The name of the argument that represents the qubit register.
48
+ Only used when `register_as_argument=True`. Defaults to "q".
35
49
  globals (dict[str, Any] | None): The global variables to use. Defaults to None.
36
50
  file (str | None): The file name for error reporting. Defaults to None.
37
51
  lineno_offset (int): The line number offset for error reporting. Defaults to 0.
38
52
  col_offset (int): The column number offset for error reporting. Defaults to 0.
39
53
  compactify (bool): Whether to compactify the output. Defaults to True.
40
54
 
41
- Example:
55
+ ## Usage Examples:
42
56
 
43
57
  ```python
44
58
  # from cirq's "hello qubit" example
@@ -60,6 +74,30 @@ def load_circuit(
60
74
  # print the resulting IR
61
75
  main.print()
62
76
  ```
77
+
78
+ You can also compose kernel functions generated from circuits by passing in
79
+ and / or returning the respective quantum registers:
80
+
81
+ ```python
82
+ q = cirq.LineQubit.range(2)
83
+ circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q))
84
+
85
+ get_entangled_qubits = squin.cirq.load_circuit(
86
+ circuit, return_register=True, kernel_name="get_entangled_qubits"
87
+ )
88
+ get_entangled_qubits.print()
89
+
90
+ entangle_qubits = squin.cirq.load_circuit(
91
+ circuit, register_as_argument=True, kernel_name="entangle_qubits"
92
+ )
93
+
94
+ @squin.kernel
95
+ def main():
96
+ qreg = get_entangled_qubits()
97
+ qreg2 = squin.qubit.new(1)
98
+ entangle_qubits([qreg[1], qreg2[0]])
99
+ return squin.qubit.measure(qreg2)
100
+ ```
63
101
  """
64
102
 
65
103
  target = Squin(dialects=dialects, circuit=circuit)
@@ -71,16 +109,38 @@ def load_circuit(
71
109
  lineno_offset=lineno_offset,
72
110
  col_offset=col_offset,
73
111
  compactify=compactify,
112
+ register_as_argument=register_as_argument,
113
+ register_argument_name=register_argument_name,
74
114
  )
75
115
 
76
- # NOTE: no return value
77
- return_value = func.ConstantNone()
78
- body.blocks[0].stmts.append(return_value)
79
- body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value))
116
+ if return_register:
117
+ return_value = target.qreg
118
+ else:
119
+ return_value = func.ConstantNone()
120
+ body.blocks[0].stmts.append(return_value)
121
+
122
+ return_node = func.Return(value_or_stmt=return_value)
123
+ body.blocks[0].stmts.append(return_node)
124
+
125
+ self_arg_name = kernel_name + "_self"
126
+ arg_names = [self_arg_name]
127
+ if register_as_argument:
128
+ args = (target.qreg.type,)
129
+ arg_names.append(register_argument_name)
130
+ else:
131
+ args = ()
132
+
133
+ # NOTE: add _self as argument; need to know signature before so do it after lowering
134
+ signature = func.Signature(args, return_node.value.type)
135
+ body.blocks[0].args.insert_from(
136
+ 0,
137
+ types.Generic(ir.Method, types.Tuple.where(signature.inputs), signature.output),
138
+ self_arg_name,
139
+ )
80
140
 
81
141
  code = func.Function(
82
142
  sym_name=kernel_name,
83
- signature=func.Signature((), types.NoneType),
143
+ signature=signature,
84
144
  body=body,
85
145
  )
86
146
 
@@ -88,7 +148,7 @@ def load_circuit(
88
148
  mod=None,
89
149
  py_func=None,
90
150
  sym_name=kernel_name,
91
- arg_names=[],
151
+ arg_names=arg_names,
92
152
  dialects=dialects,
93
153
  code=code,
94
154
  )
@@ -176,7 +236,12 @@ def emit_circuit(
176
236
  )
177
237
 
178
238
  emitter = EmitCirq(qubits=qubits)
179
- return emitter.run(mt, args=())
239
+
240
+ # Rewrite noise statements
241
+ mt_ = mt.similar(mt.dialects)
242
+ RewriteNoiseStmts(mt_.dialects)(mt_)
243
+
244
+ return emitter.run(mt_, args=())
180
245
 
181
246
 
182
247
  def dump_circuit(mt: ir.Method, qubits: Sequence[cirq.Qid] | None = None, **kwargs):
@@ -0,0 +1,49 @@
1
+ import cirq
2
+ from kirin.emit import EmitError
3
+ from kirin.interp import MethodTable, impl
4
+
5
+ from ... import noise
6
+ from .runtime import (
7
+ KronRuntime,
8
+ BasicOpRuntime,
9
+ OperatorRuntimeABC,
10
+ PauliStringRuntime,
11
+ )
12
+ from .emit_circuit import EmitCirq, EmitCirqFrame
13
+
14
+
15
+ @noise.dialect.register(key="emit.cirq")
16
+ class EmitCirqNoiseMethods(MethodTable):
17
+
18
+ @impl(noise.stmts.StochasticUnitaryChannel)
19
+ def stochastic_unitary_channel(
20
+ self,
21
+ emit: EmitCirq,
22
+ frame: EmitCirqFrame,
23
+ stmt: noise.stmts.StochasticUnitaryChannel,
24
+ ):
25
+ ops = frame.get(stmt.operators)
26
+ ps = frame.get(stmt.probabilities)
27
+
28
+ error_probabilities = {self._op_to_key(op_): p for op_, p in zip(ops, ps)}
29
+ cirq_op = cirq.asymmetric_depolarize(error_probabilities=error_probabilities)
30
+ return (BasicOpRuntime(cirq_op),)
31
+
32
+ @staticmethod
33
+ def _op_to_key(operator: OperatorRuntimeABC) -> str:
34
+ match operator:
35
+ case KronRuntime():
36
+ key_lhs = EmitCirqNoiseMethods._op_to_key(operator.lhs)
37
+ key_rhs = EmitCirqNoiseMethods._op_to_key(operator.rhs)
38
+ return key_lhs + key_rhs
39
+
40
+ case BasicOpRuntime():
41
+ return str(operator.gate)
42
+
43
+ case PauliStringRuntime():
44
+ return operator.string
45
+
46
+ case _:
47
+ raise EmitError(
48
+ f"Unexpected operator runtime in StochasticUnitaryChannel of type {type(operator).__name__} encountered!"
49
+ )
@@ -21,7 +21,10 @@ class OperatorRuntimeABC:
21
21
 
22
22
def unsafe_apply(
    self, qubits: Sequence[cirq.Qid], adjoint: bool = False
) -> list[cirq.Operation]:
    """Return the cirq operations that apply this operator to `qubits`.

    Subclasses must override this; the base implementation always raises.

    Raises:
        NotImplementedError: always, naming the concrete runtime class.
    """
    runtime_name = self.__class__.__name__
    message = f"Apply method needs to be implemented in {runtime_name}"
    raise NotImplementedError(message)
25
28
 
26
29
 
27
30
  @dataclass
@@ -38,6 +41,11 @@ class BasicOpRuntime(UnsafeOperatorRuntimeABC):
38
41
  def num_qubits(self) -> int:
39
42
  return self.gate.num_qubits()
40
43
 
44
def unsafe_apply(
    self, qubits: Sequence[cirq.Qid], adjoint: bool = False
) -> list[cirq.Operation]:
    """Apply the wrapped cirq gate directly to `qubits` as a single operation.

    NOTE(review): `adjoint` is not consulted here, so the gate is applied as-is
    even when an adjoint is requested. Confirm this is intended for every gate this
    runtime wraps (it also wraps noise channels, e.g. `asymmetric_depolarize`).
    """
    operation = self.gate(*qubits)
    return [operation]
48
+
41
49
 
42
50
  @dataclass
43
51
  class UnitaryRuntime(BasicOpRuntime):
@@ -3,7 +3,7 @@ from typing import Any
3
3
  from dataclasses import field, dataclass
4
4
 
5
5
  import cirq
6
- from kirin import ir, lowering
6
+ from kirin import ir, types, lowering
7
7
  from kirin.rewrite import Walk, CFGCompactify
8
8
  from kirin.dialects import py, scf, ilist
9
9
 
@@ -25,27 +25,26 @@ class Squin(lowering.LoweringABC[CirqNode]):
25
25
  """Lower a cirq.Circuit object to a squin kernel"""
26
26
 
27
27
  circuit: cirq.Circuit
28
- qreg: qubit.New = field(init=False)
28
+ qreg: ir.SSAValue = field(init=False)
29
29
  qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict)
30
30
  next_qreg_index: int = field(init=False, default=0)
31
31
 
32
def __post_init__(self):
    # Fix every qubit's register slot up front, in sorted order, so the register
    # size and indices are known before any statement is lowered.
    # TODO: sort by cirq ordering
    ordered_qubits = sorted(self.circuit.all_qubits())
    self.qreg_index = {qid: slot for slot, qid in enumerate(ordered_qubits)}

def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid):
    """Push `self.qreg[<slot of qid>]` into the current frame and return the qubit value."""
    frame = state.current_frame
    slot_const = frame.push(py.Constant(self.qreg_index[qid]))
    getitem = frame.push(py.GetItem(self.qreg, slot_const.result))
    return getitem.result
43
42
 
44
43
  def lower_qubit_getindices(
45
44
  self, state: lowering.State[CirqNode], qids: list[cirq.Qid]
46
45
  ):
47
46
  qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids]
48
- qbits_stmt = ilist.New(values=qbits_getitem)
47
+ qbits_stmt = ilist.New(values=qbits_getitem, elem_type=qubit.QubitType)
49
48
  qbits_result = state.current_frame.get(qbits_stmt.name)
50
49
 
51
50
  if qbits_result is not None:
@@ -64,6 +63,8 @@ class Squin(lowering.LoweringABC[CirqNode]):
64
63
  lineno_offset: int = 0,
65
64
  col_offset: int = 0,
66
65
  compactify: bool = True,
66
+ register_as_argument: bool = False,
67
+ register_argument_name: str = "q",
67
68
  ) -> ir.Region:
68
69
 
69
70
  state = lowering.State(
@@ -73,16 +74,21 @@ class Squin(lowering.LoweringABC[CirqNode]):
73
74
  col_offset=col_offset,
74
75
  )
75
76
 
76
- with state.frame(
77
- [stmt],
78
- globals=globals,
79
- finalize_next=False,
80
- ) as frame:
81
- # NOTE: create a global register of qubits first
82
- # TODO: can there be a circuit without qubits?
83
- n_qubits = cirq.num_qubits(self.circuit)
84
- n = frame.push(py.Constant(n_qubits))
85
- self.qreg = frame.push(qubit.New(n_qubits=n.result))
77
+ with state.frame([stmt], globals=globals, finalize_next=False) as frame:
78
+
79
+ # NOTE: need a register of qubits before lowering statements
80
+ if register_as_argument:
81
+ # NOTE: register as argument to the kernel; we have freedom of choice for the name here
82
+ frame.curr_block.args.append_from(
83
+ ilist.IListType[qubit.QubitType, types.Any],
84
+ name=register_argument_name,
85
+ )
86
+ self.qreg = frame.curr_block.args[0]
87
+ else:
88
+ # NOTE: create a new register of appropriate size
89
+ n_qubits = len(self.qreg_index)
90
+ n = frame.push(py.Constant(n_qubits))
91
+ self.qreg = frame.push(qubit.New(n_qubits=n.result)).result
86
92
 
87
93
  self.visit(state, stmt)
88
94
 
@@ -362,11 +368,24 @@ class Squin(lowering.LoweringABC[CirqNode]):
362
368
  state: lowering.State[CirqNode],
363
369
  node: cirq.GeneralizedAmplitudeDampingChannel,
364
370
  ):
365
- raise NotImplementedError("TODO: needs a new operator statement")
366
- # p = state.current_frame.push(py.Constant(node.p))
367
- # gamma = state.current_frame.push(py.Constant(node.gamma))
371
+ p = state.current_frame.push(py.Constant(node.p)).result
372
+ gamma = state.current_frame.push(py.Constant(node.gamma)).result
368
373
 
369
- # p1 =
374
+ # NOTE: cirq has a weird convention here: if p == 1, we have AmplitudeDampingChannel,
375
+ # which basically means p is the probability of the environment being in the vacuum state
376
+ prob0 = state.current_frame.push(py.binop.Mult(p, gamma)).result
377
+ one_ = state.current_frame.push(py.Constant(1)).result
378
+ p_minus_1 = state.current_frame.push(py.binop.Sub(one_, p)).result
379
+ prob1 = state.current_frame.push(py.binop.Mult(p_minus_1, gamma)).result
370
380
 
371
- # x = state.current_frame.push(op.stmts.X())
372
- # noise_channel1 = noise.stmts.PauliError(basis=x.result, p=)
381
+ r0 = state.current_frame.push(op.stmts.Reset()).result
382
+ r1 = state.current_frame.push(op.stmts.ResetToOne()).result
383
+
384
+ probs = state.current_frame.push(ilist.New(values=(prob0, prob1))).result
385
+ ops = state.current_frame.push(ilist.New(values=(r0, r1))).result
386
+
387
+ noise_channel = state.current_frame.push(
388
+ noise.stmts.StochasticUnitaryChannel(probabilities=probs, operators=ops)
389
+ )
390
+
391
+ return noise_channel
@@ -1,3 +1,6 @@
1
+ from typing import Literal
2
+
3
+ from kirin.dialects import ilist
1
4
  from kirin.lowering import wraps
2
5
 
3
6
  from bloqade.squin.op.types import Op
@@ -18,11 +21,15 @@ def depolarize(p: float) -> Op: ...
18
21
 
19
22
 
20
23
@wraps(stmts.SingleQubitPauliChannel)
def single_qubit_pauli_channel(
    params: ilist.IList[float, Literal[3]] | list[float] | tuple[float, float, float],
) -> Op:
    """Single-qubit Pauli noise channel with three parameters; lowered to `SingleQubitPauliChannel`."""
    ...
22
27
 
23
28
 
24
29
@wraps(stmts.TwoQubitPauliChannel)
def two_qubit_pauli_channel(
    params: ilist.IList[float, Literal[15]] | list[float] | tuple[float, ...],
) -> Op:
    """Two-qubit Pauli noise channel; lowered to `TwoQubitPauliChannel`.

    Takes 15 parameters, presumably one per non-identity two-qubit Pauli product.
    """
    ...
26
33
 
27
34
 
28
35
  @wraps(stmts.QubitLoss)
@@ -58,12 +58,12 @@ class _RewriteNoiseStmts(RewriteRule):
58
58
  def rewrite_two_qubit_pauli_channel(
59
59
  self, node: TwoQubitPauliChannel
60
60
  ) -> RewriteResult:
61
- paulis = (X(), Y(), Z(), Identity(sites=1))
61
+ paulis = (Identity(sites=1), X(), Y(), Z())
62
62
  for op in paulis:
63
63
  op.insert_before(node)
64
64
 
65
- # NOTE: collect list so we can skip the last entry, which will be two identities
66
- combinations = list(itertools.product(paulis, repeat=2))[:-1]
65
+ # NOTE: collect list so we can skip the first entry, which will be two identities
66
+ combinations = list(itertools.product(paulis, repeat=2))[1:]
67
67
  operators: list[ir.SSAValue] = []
68
68
  for pauli_1, pauli_2 in combinations:
69
69
  op = Kron(pauli_1.result, pauli_2.result)
@@ -37,4 +37,5 @@ from ._wrapper import (
37
37
  control as control,
38
38
  identity as identity,
39
39
  pauli_string as pauli_string,
40
+ reset_to_one as reset_to_one,
40
41
  )
@@ -41,6 +41,10 @@ def control(op: types.Op, *, n_controls: int) -> types.Op:
41
41
  def reset() -> types.Op: ...
42
42
 
43
43
 
44
@wraps(stmts.ResetToOne)
def reset_to_one() -> types.Op:
    """Operator that resets a qubit to |1>; lowered to `stmts.ResetToOne`."""
    ...
46
+
47
+
44
48
@wraps(stmts.Identity)
def identity(*, sites: int) -> types.Op:
    """Identity operator spanning `sites` sites; lowered to `stmts.Identity`."""
    ...
46
50