bloqade-circuit 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit has been flagged as potentially problematic; consult the registry page for details.

Files changed (80)
  1. bloqade/analysis/address/impls.py +14 -0
  2. bloqade/analysis/fidelity/analysis.py +27 -2
  3. bloqade/noise/fidelity.py +3 -3
  4. bloqade/noise/native/_dialect.py +1 -1
  5. bloqade/noise/native/_wrappers.py +35 -6
  6. bloqade/noise/native/stmts.py +1 -1
  7. bloqade/pyqrack/device.py +109 -21
  8. bloqade/pyqrack/qasm2/core.py +4 -1
  9. bloqade/pyqrack/squin/qubit.py +16 -9
  10. bloqade/pyqrack/squin/wire.py +22 -4
  11. bloqade/pyqrack/task.py +13 -5
  12. bloqade/qasm2/__init__.py +1 -0
  13. bloqade/qasm2/_qasm_loading.py +151 -0
  14. bloqade/qasm2/dialects/core/__init__.py +9 -1
  15. bloqade/qasm2/dialects/expr/__init__.py +18 -1
  16. bloqade/qasm2/dialects/noise.py +33 -1
  17. bloqade/qasm2/dialects/uop/__init__.py +39 -3
  18. bloqade/qasm2/dialects/uop/schedule.py +1 -1
  19. bloqade/qasm2/emit/impls/__init__.py +1 -0
  20. bloqade/qasm2/emit/impls/noise_native.py +89 -0
  21. bloqade/qasm2/emit/main.py +21 -0
  22. bloqade/qasm2/emit/target.py +20 -5
  23. bloqade/qasm2/groups.py +2 -0
  24. bloqade/qasm2/parse/__init__.py +7 -4
  25. bloqade/qasm2/parse/lowering.py +20 -130
  26. bloqade/qasm2/parse/qasm2.lark +1 -1
  27. bloqade/qasm2/passes/__init__.py +1 -0
  28. bloqade/qasm2/passes/fold.py +6 -0
  29. bloqade/qasm2/passes/noise.py +50 -2
  30. bloqade/qasm2/passes/parallel.py +9 -0
  31. bloqade/qasm2/passes/unroll_if.py +25 -0
  32. bloqade/qasm2/rewrite/__init__.py +1 -0
  33. bloqade/qasm2/rewrite/desugar.py +3 -2
  34. bloqade/qasm2/rewrite/heuristic_noise.py +1 -9
  35. bloqade/qasm2/rewrite/native_gates.py +67 -4
  36. bloqade/qasm2/rewrite/split_ifs.py +66 -0
  37. bloqade/squin/analysis/nsites/__init__.py +1 -0
  38. bloqade/squin/analysis/nsites/impls.py +25 -1
  39. bloqade/squin/noise/__init__.py +7 -26
  40. bloqade/squin/noise/_wrapper.py +25 -0
  41. bloqade/squin/op/__init__.py +33 -159
  42. bloqade/squin/op/_wrapper.py +101 -0
  43. bloqade/squin/op/stdlib.py +62 -0
  44. bloqade/squin/passes/__init__.py +1 -0
  45. bloqade/squin/passes/stim.py +68 -0
  46. bloqade/squin/rewrite/__init__.py +11 -0
  47. bloqade/squin/rewrite/qubit_to_stim.py +84 -0
  48. bloqade/squin/rewrite/squin_measure.py +98 -0
  49. bloqade/squin/rewrite/stim_rewrite_util.py +158 -0
  50. bloqade/squin/rewrite/wire_identity_elimination.py +24 -0
  51. bloqade/squin/rewrite/wire_to_stim.py +73 -0
  52. bloqade/squin/rewrite/wrap_analysis.py +72 -0
  53. bloqade/squin/wire.py +1 -13
  54. bloqade/stim/__init__.py +39 -5
  55. bloqade/stim/_wrappers.py +14 -12
  56. bloqade/stim/dialects/__init__.py +1 -5
  57. bloqade/stim/dialects/{aux → auxiliary}/__init__.py +12 -1
  58. bloqade/stim/dialects/{aux → auxiliary}/emit.py +1 -1
  59. bloqade/stim/dialects/collapse/__init__.py +13 -2
  60. bloqade/stim/dialects/collapse/{emit.py → emit_str.py} +1 -1
  61. bloqade/stim/dialects/collapse/stmts/pp_measure.py +1 -1
  62. bloqade/stim/dialects/gate/__init__.py +16 -1
  63. bloqade/stim/dialects/gate/emit.py +1 -1
  64. bloqade/stim/dialects/gate/stmts/base.py +1 -1
  65. bloqade/stim/dialects/gate/stmts/pp.py +1 -1
  66. bloqade/stim/dialects/noise/emit.py +1 -1
  67. bloqade/stim/emit/__init__.py +1 -1
  68. bloqade/stim/groups.py +4 -2
  69. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/METADATA +3 -3
  70. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/RECORD +80 -64
  71. /bloqade/stim/dialects/{aux → auxiliary}/_dialect.py +0 -0
  72. /bloqade/stim/dialects/{aux → auxiliary}/interp.py +0 -0
  73. /bloqade/stim/dialects/{aux → auxiliary}/lowering.py +0 -0
  74. /bloqade/stim/dialects/{aux → auxiliary}/stmts/__init__.py +0 -0
  75. /bloqade/stim/dialects/{aux → auxiliary}/stmts/annotate.py +0 -0
  76. /bloqade/stim/dialects/{aux → auxiliary}/stmts/const.py +0 -0
  77. /bloqade/stim/dialects/{aux → auxiliary}/types.py +0 -0
  78. /bloqade/stim/emit/{stim.py → stim_str.py} +0 -0
  79. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/WHEEL +0 -0
  80. {bloqade_circuit-0.2.2.dist-info → bloqade_circuit-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,25 @@
1
+ from kirin import ir
2
+ from kirin.passes import Pass
3
+ from kirin.rewrite import (
4
+ Walk,
5
+ Chain,
6
+ Fixpoint,
7
+ ConstantFold,
8
+ CommonSubexpressionElimination,
9
+ )
10
+
11
+ from ..rewrite.split_ifs import LiftThenBody, SplitIfStmts
12
+
13
+
14
class UnrollIfs(Pass):
    """Rewrite ``scf.IfElse`` statements into a shape that QASM2 can express.

    Three stages run over ``mt.code`` in order:

    1. ``LiftThenBody`` hoists statements that may not stay inside a then-body
       to just before the if-statement.
    2. ``SplitIfStmts`` splits what remains into one if-statement per statement.
    3. ``ConstantFold`` + ``CommonSubexpressionElimination`` are walked to a
       fixpoint (presumably to tidy up the IR left by the first two stages).
    """

    def unsafe_run(self, mt: ir.Method):
        hoisted = Walk(LiftThenBody()).rewrite(mt.code)
        split = Walk(SplitIfStmts()).rewrite(mt.code)

        cleanup = Fixpoint(
            Walk(Chain(ConstantFold(), CommonSubexpressionElimination()))
        )
        cleaned = cleanup.rewrite(mt.code)

        # Combine the per-stage results via RewriteResult.join.
        return cleaned.join(split.join(hoisted))
@@ -3,6 +3,7 @@ from .glob import (
3
3
  GlobalToParallelRule as GlobalToParallelRule,
4
4
  )
5
5
  from .register import RaiseRegisterRule as RaiseRegisterRule
6
+ from .native_gates import RydbergGateSetRewriteRule as RydbergGateSetRewriteRule
6
7
  from .parallel_to_uop import ParallelToUOpRule as ParallelToUOpRule
7
8
  from .uop_to_parallel import (
8
9
  MergePolicyABC as MergePolicyABC,
@@ -5,16 +5,17 @@ from kirin.passes import Pass
5
5
  from kirin.rewrite import abc, walk
6
6
  from kirin.dialects import py
7
7
 
8
+ from bloqade.qasm2 import types
8
9
  from bloqade.qasm2.dialects import core
9
10
 
10
11
 
11
12
  class IndexingDesugarRule(abc.RewriteRule):
12
13
  def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
13
14
  if isinstance(node, py.indexing.GetItem):
14
- if node.obj.type.is_subseteq(core.QRegType):
15
+ if node.obj.type.is_subseteq(types.QRegType):
15
16
  node.replace_by(core.QRegGet(reg=node.obj, idx=node.index))
16
17
  return abc.RewriteResult(has_done_something=True)
17
- elif node.obj.type.is_subseteq(core.CRegType):
18
+ elif node.obj.type.is_subseteq(types.CRegType):
18
19
  node.replace_by(core.CRegGet(reg=node.obj, idx=node.index))
19
20
  return abc.RewriteResult(has_done_something=True)
20
21
 
@@ -18,6 +18,7 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
18
18
  """
19
19
 
20
20
  address_analysis: Dict[ir.SSAValue, address.Address]
21
+ qubit_ssa_value: Dict[int, ir.SSAValue]
21
22
  gate_noise_params: native.GateNoiseParams = field(
22
23
  default_factory=native.GateNoiseParams
23
24
  )
@@ -25,15 +26,6 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
25
26
  default_factory=native.TwoRowZoneModel
26
27
  )
27
28
 
28
- def __post_init__(self):
29
- self.qubit_ssa_value: Dict[int, ir.SSAValue] = {}
30
- for ssa_value, addr in self.address_analysis.items():
31
- if (
32
- isinstance(addr, address.AddressQubit)
33
- and ssa_value not in self.qubit_ssa_value
34
- ):
35
- self.qubit_ssa_value[addr.data] = ssa_value
36
-
37
29
  def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
38
30
  if isinstance(node, uop.SingleQubitGate):
39
31
  return self.rewrite_single_qubit_gate(node)
@@ -173,6 +173,53 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
173
173
  cirq.XPowGate(exponent=0.5).on(self.cached_qubits[0]), node
174
174
  )
175
175
 
176
def rewrite_ccx(self, node: uop.CCX) -> abc.RewriteResult:
    """Rewrite a ``CCX`` (Toffoli) gate into CZ, T, H and U gates.

    The angle constants and the replacement gates are inserted before
    ``node``, after which ``node`` is deleted. The gate sequence follows the
    Quirk circuit linked below.

    Args:
        node: The ``CCX`` statement to rewrite. ``ctrl1`` and ``ctrl2`` are
            its controls and ``qarg`` is its target.

    Returns:
        ``RewriteResult(has_done_something=True)``.
    """
    # from https://algassert.com/quirk#circuit=%7B%22cols%22:%5B%5B%22QFT3%22%5D,%5B%22inputA3%22,1,1,%22+=A3%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22%E2%80%A2%22,%22X%22%5D,%5B1,1,1,%22%E2%80%A6%22,%22%E2%80%A6%22,%22%E2%80%A6%22%5D,%5B1,1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E-%C2%BC%22%5D,%5B1,1,1,%22%E2%80%A2%22,1,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E%C2%BC%22%5D,%5B1,1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E-%C2%BC%22%5D,%5B1,1,1,%22Z%5E%C2%BC%22,%22Z%5E%C2%BC%22%5D,%5B1,1,1,1,%22H%22%5D,%5B1,1,1,%22%E2%80%A2%22,1,%22Z%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,%22X%5E-%C2%BC%22,%22X%5E%C2%BC%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,%22H%22%5D%5D%7D

    # x^(1/4): U(theta=pi/4, phi=-pi/2, lam=pi/2), i.e. Rx(pi/4) up to a global
    # phase, assuming the OpenQASM U = Rz(phi) Ry(theta) Rz(lam) convention.
    lam1, theta1, phi1 = map(
        self.const_float,
        map(around, (1.5707963267948966, 0.7853981633974483, -1.5707963267948966)),
    )
    lam1.insert_before(node)
    theta1.insert_before(node)
    phi1.insert_before(node)

    lam1 = lam1.result
    theta1 = theta1.result
    phi1 = phi1.result

    # x^(-1/4): U(theta=pi/4, phi=pi/2, lam=3*pi/2), i.e. Rx(-pi/4) up to a
    # global phase under the same convention.
    lam2, theta2, phi2 = map(
        self.const_float,
        map(around, (4.71238898038469, 0.7853981633974483, 1.5707963267948966)),
    )
    lam2.insert_before(node)
    theta2.insert_before(node)
    phi2.insert_before(node)
    lam2 = lam2.result
    theta2 = theta2.result
    phi2 = phi2.result

    uop.CZ(ctrl=node.ctrl1, qarg=node.qarg).insert_before(node)
    uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
    uop.CZ(ctrl=node.ctrl2, qarg=node.qarg).insert_before(node)
    uop.UGate(node.qarg, theta1, phi1, lam1).insert_before(node)
    uop.CZ(ctrl=node.ctrl1, qarg=node.qarg).insert_before(node)
    uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
    uop.T(node.ctrl1).insert_before(node)
    uop.T(node.ctrl2).insert_before(node)
    uop.H(node.ctrl1).insert_before(node)
    uop.CZ(ctrl=node.ctrl2, qarg=node.qarg).insert_before(node)
    uop.CZ(ctrl=node.ctrl2, qarg=node.ctrl1).insert_before(node)
    # In this column of the Quirk circuit ctrl1 gets X^-1/4 but the target
    # gets X^+1/4. Applying x^(-1/4) to the target here does not reproduce CCX.
    uop.UGate(node.ctrl1, theta2, phi2, lam2).insert_before(node)
    uop.UGate(node.qarg, theta1, phi1, lam1).insert_before(node)
    uop.CZ(ctrl=node.ctrl2, qarg=node.ctrl1).insert_before(node)
    uop.H(node.ctrl1).insert_before(node)
    node.delete()  # delete the original CCX gate

    return abc.RewriteResult(has_done_something=True)
222
+
176
223
  def rewrite_sxdg(self, node: uop.SXdag) -> abc.RewriteResult:
177
224
  return self._rewrite_1q_gates(
178
225
  cirq.XPowGate(exponent=-0.5).on(self.cached_qubits[0]), node
@@ -394,9 +441,12 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
394
441
  new_gate_stmts = self._generate_1q_gate_stmts(cirq_gate, node.qarg)
395
442
  return self._rewrite_gate_stmts(new_gate_stmts, node)
396
443
 
397
- def _generate_2q_ctrl_gate_stmts(
444
+ def _generate_multi_ctrl_gate_stmts(
398
445
  self, cirq_gate: cirq.Operation, qubits_ssa: List[ir.SSAValue]
399
446
  ) -> list[ir.Statement]:
447
+ qubit_to_ssa_map = {
448
+ q: ssa for q, ssa in zip(self.cached_qubits[: len(qubits_ssa)], qubits_ssa)
449
+ }
400
450
  target_gates = self.gateset.decompose_to_target_gateset(cirq_gate, 0)
401
451
  new_stmts = []
402
452
  for new_gate in target_gates:
@@ -412,7 +462,7 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
412
462
  new_stmts.append(phi2_stmt)
413
463
  new_stmts.append(
414
464
  uop.UGate(
415
- qarg=qubits_ssa[new_gate.qubits[0].x],
465
+ qarg=qubit_to_ssa_map[new_gate.qubits[0]],
416
466
  theta=phi0_stmt.result,
417
467
  phi=phi1_stmt.result,
418
468
  lam=phi2_stmt.result,
@@ -420,18 +470,31 @@ class RydbergGateSetRewriteRule(abc.RewriteRule):
420
470
  )
421
471
  else:
422
472
  # 2q
423
- new_stmts.append(uop.CZ(ctrl=qubits_ssa[0], qarg=qubits_ssa[1]))
473
+ new_stmts.append(
474
+ uop.CZ(
475
+ ctrl=qubit_to_ssa_map[new_gate.qubits[0]],
476
+ qarg=qubit_to_ssa_map[new_gate.qubits[1]],
477
+ )
478
+ )
424
479
 
425
480
  return new_stmts
426
481
 
def _rewrite_2q_ctrl_gates(
    self, cirq_gate: cirq.Operation, node: uop.TwoQubitCtrlGate
) -> abc.RewriteResult:
    """Replace the two-qubit controlled gate ``node`` with generated statements.

    ``cirq_gate`` is decomposed by ``_generate_multi_ctrl_gate_stmts`` using the
    SSA qubits ``[node.ctrl, node.qarg]`` (control first, then target). The
    resulting statements are then spliced in by ``_rewrite_gate_stmts``.
    """
    operand_qubits = [node.ctrl, node.qarg]
    replacement_stmts = self._generate_multi_ctrl_gate_stmts(
        cirq_gate, operand_qubits
    )
    return self._rewrite_gate_stmts(replacement_stmts, node)
434
489
 
490
def _rewrite_3q_ctrl_gates(
    self, cirq_gate: cirq.Operation, node: uop.CCX
) -> abc.RewriteResult:
    """Replace the doubly-controlled gate ``node`` with generated statements.

    ``cirq_gate`` is decomposed by ``_generate_multi_ctrl_gate_stmts`` using the
    SSA qubits ``[node.ctrl1, node.ctrl2, node.qarg]`` (both controls, then the
    target). The resulting statements are then spliced in by
    ``_rewrite_gate_stmts``.
    """
    operand_qubits = [node.ctrl1, node.ctrl2, node.qarg]
    replacement_stmts = self._generate_multi_ctrl_gate_stmts(
        cirq_gate, operand_qubits
    )
    return self._rewrite_gate_stmts(replacement_stmts, node)
497
+
435
498
  def _rewrite_gate_stmts(
436
499
  self, new_gate_stmts: list[ir.Statement], node: ir.Statement
437
500
  ):
@@ -0,0 +1,66 @@
1
+ from kirin import ir
2
+ from kirin.dialects import scf, func
3
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
4
+
5
+ from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
6
+ from ..dialects.core.stmts import Reset, Measure
7
+
8
# TODO: unify with PR #248
# Statement types that are allowed to remain inside an ``scf.IfElse`` then-body.
AllowedThenType = SingleQubitGate | TwoQubitCtrlGate | Measure | Reset

# Terminators and calls that LiftThenBody must also leave where they are.
_KEPT_CONTROL_TYPES = scf.Yield | func.Return | func.Invoke

# Everything LiftThenBody keeps inside the then-body; all other statements are hoisted.
DontLiftType = AllowedThenType | _KEPT_CONTROL_TYPES
12
+
13
+
14
class LiftThenBody(RewriteRule):
    """Hoist statements out of the then-body of an ``scf.IfElse``.

    Every then-body statement that is not an instance of ``DontLiftType``
    (the allowed body statements plus ``scf.Yield``, ``func.Return`` and
    ``func.Invoke``) is moved to just before the if-statement.

    NOTE(review): hoisted statements execute unconditionally. This is only
    sound for side-effect-free statements such as constants or register
    indexing. Gates outside ``AllowedThenType`` (presumably including
    ``uop.CCX``) would be hoisted as well; confirm this is intended.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, scf.IfElse):
            return RewriteResult()

        # Collect the statements first, so none is detached while the
        # region's statements are still being iterated.
        to_hoist = [
            candidate
            for candidate in node.then_body.stmts()
            if not isinstance(candidate, DontLiftType)
        ]
        if not to_hoist:
            return RewriteResult()

        for candidate in to_hoist:
            candidate.detach()
            candidate.insert_before(node)
        return RewriteResult(has_done_something=True)
33
+
34
+
35
class SplitIfStmts(RewriteRule):
    """Split an ``scf.IfElse`` then-body into one if-statement per statement.

    Each non-terminator statement of the then-body is moved into a fresh
    ``scf.IfElse`` that shares the original condition and gets its own clone
    of the original else-body. The new if-statements are inserted before the
    original, which is then deleted.

    NOTE(review): if the then-body holds only its terminator, the original
    if-statement (including its else-body) is deleted outright. A non-trivial
    else-body would also be duplicated once per split statement. Both are
    presumably fine for QASM2, which has no ``else``; confirm before reusing
    this rule elsewhere.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, scf.IfElse):
            return RewriteResult()

        # The last then-body statement is its terminator (yield or return).
        *body_stmts, terminator = node.then_body.stmts()

        # Exactly one statement before the terminator: nothing to split.
        if len(body_stmts) == 1:
            return RewriteResult()

        terminator_is_yield = isinstance(terminator, scf.Yield)

        for body_stmt in body_stmts:
            body_stmt.detach()

            new_terminator = scf.Yield() if terminator_is_yield else func.Return()

            single_block = ir.Block(
                (body_stmt, new_terminator), argtypes=(node.cond.type,)
            )
            single_region = ir.Region(single_block)
            cloned_else = node.else_body.clone()
            cloned_else.detach()
            split_if = scf.IfElse(
                cond=node.cond, then_body=single_region, else_body=cloned_else
            )

            split_if.insert_before(node)

        node.delete()

        return RewriteResult(has_done_something=True)
@@ -1,6 +1,7 @@
1
1
  # Need this for impl registration to work properly!
2
2
  from . import impls as impls
3
3
  from .lattice import (
4
+ Sites as Sites,
4
5
  NoSites as NoSites,
5
6
  AnySites as AnySites,
6
7
  NumberSites as NumberSites,
@@ -1,6 +1,6 @@
1
1
  from kirin import interp
2
2
 
3
- from bloqade.squin import op
3
+ from bloqade.squin import op, wire
4
4
 
5
5
  from .lattice import (
6
6
  NoSites,
@@ -9,6 +9,30 @@ from .lattice import (
9
9
  from .analysis import NSitesAnalysis
10
10
 
11
11
 
12
@wire.dialect.register(key="op.nsites")
class SquinWire(interp.MethodTable):
    """``NSitesAnalysis`` method table for statements of the squin ``wire`` dialect."""

    @interp.impl(wire.Apply)
    @interp.impl(wire.Broadcast)
    def apply(
        self,
        interp: NSitesAnalysis,
        frame: interp.Frame,
        stmt: wire.Apply | wire.Broadcast,
    ):
        """Forward the lattice value of each input wire, in order, as the results."""
        forwarded = [frame.get(wire_in) for wire_in in stmt.inputs]
        return tuple(forwarded)

    @interp.impl(wire.MeasureAndReset)
    def measure_and_reset(
        self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.MeasureAndReset
    ):
        """Return ``NoSites`` for both results of ``MeasureAndReset``.

        The statement produces a new wire and an integer, and neither of
        them carries any sites.
        """
        return NoSites(), NoSites()
34
+
35
+
12
36
  @op.dialect.register(key="op.nsites")
13
37
  class SquinOp(interp.MethodTable):
14
38
 
@@ -1,27 +1,8 @@
1
- # Put all the proper wrappers here
2
-
3
- from kirin.lowering import wraps as _wraps
4
-
5
- from bloqade.squin.op.types import Op
6
-
7
1
  from . import stmts as stmts
8
-
9
-
10
- @_wraps(stmts.PauliError)
11
- def pauli_error(basis: Op, p: float) -> Op: ...
12
-
13
-
14
- @_wraps(stmts.PPError)
15
- def pp_error(op: Op, p: float) -> Op: ...
16
-
17
-
18
- @_wraps(stmts.Depolarize)
19
- def depolarize(n_qubits: int, p: float) -> Op: ...
20
-
21
-
22
- @_wraps(stmts.PauliChannel)
23
- def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
24
-
25
-
26
- @_wraps(stmts.QubitLoss)
27
- def qubit_loss(p: float) -> Op: ...
2
from ._dialect import dialect as dialect

# Re-export every lowering wrapper defined in ``_wrapper``. ``pauli_error`` is
# defined there too and was previously importable from this package, so it
# must stay exported for backward compatibility.
from ._wrapper import (
    pp_error as pp_error,
    depolarize as depolarize,
    qubit_loss as qubit_loss,
    pauli_error as pauli_error,
    pauli_channel as pauli_channel,
)
@@ -0,0 +1,25 @@
1
+ from kirin.lowering import wraps
2
+
3
+ from bloqade.squin.op.types import Op
4
+
5
+ from . import stmts
6
+
7
+
8
@wraps(stmts.PauliError)
def pauli_error(basis: Op, p: float) -> Op:
    """Lower to ``stmts.PauliError`` with operator ``basis`` and ``p`` (presumably a probability)."""
10
+
11
+
12
@wraps(stmts.PPError)
def pp_error(op: Op, p: float) -> Op:
    """Lower to ``stmts.PPError`` with operator ``op`` and ``p`` (presumably a probability)."""
14
+
15
+
16
@wraps(stmts.Depolarize)
def depolarize(n_qubits: int, p: float) -> Op:
    """Lower to ``stmts.Depolarize`` over ``n_qubits`` qubits with parameter ``p``."""
18
+
19
+
20
@wraps(stmts.PauliChannel)
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op:
    """Lower to ``stmts.PauliChannel`` over ``n_qubits`` qubits with the given ``params``."""
22
+
23
+
24
@wraps(stmts.QubitLoss)
def qubit_loss(p: float) -> Op:
    """Lower to ``stmts.QubitLoss`` with parameter ``p`` (presumably a loss probability)."""
@@ -1,162 +1,36 @@
1
- from kirin import ir as _ir
2
- from kirin.prelude import structural_no_opt as _structural_no_opt
3
- from kirin.lowering import wraps as _wraps
4
-
5
1
  from . import stmts as stmts, types as types, rewrite as rewrite
2
+ from .stdlib import (
3
+ ch as ch,
4
+ cx as cx,
5
+ cy as cy,
6
+ cz as cz,
7
+ rx as rx,
8
+ ry as ry,
9
+ rz as rz,
10
+ cphase as cphase,
11
+ )
6
12
  from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
7
13
  from ._dialect import dialect as dialect
8
-
9
-
10
- @_wraps(stmts.Kron)
11
- def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
12
-
13
-
14
- @_wraps(stmts.Mult)
15
- def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ...
16
-
17
-
18
- @_wraps(stmts.Scale)
19
- def scale(op: types.Op, factor: complex) -> types.Op: ...
20
-
21
-
22
- @_wraps(stmts.Adjoint)
23
- def adjoint(op: types.Op) -> types.Op: ...
24
-
25
-
26
- @_wraps(stmts.Control)
27
- def control(op: types.Op, *, n_controls: int) -> types.Op:
28
- """
29
- Create a controlled operator.
30
-
31
- Note, that when considering atom loss, the operator will not be applied if
32
- any of the controls has been lost.
33
-
34
- Args:
35
- operator: The operator to apply under the control.
36
- n_controls: The number qubits to be used as control.
37
-
38
- Returns:
39
- Operator
40
- """
41
- ...
42
-
43
-
44
- @_wraps(stmts.Identity)
45
- def identity(*, sites: int) -> types.Op: ...
46
-
47
-
48
- @_wraps(stmts.Rot)
49
- def rot(axis: types.Op, angle: float) -> types.Op: ...
50
-
51
-
52
- @_wraps(stmts.ShiftOp)
53
- def shift(theta: float) -> types.Op: ...
54
-
55
-
56
- @_wraps(stmts.PhaseOp)
57
- def phase(theta: float) -> types.Op: ...
58
-
59
-
60
- @_wraps(stmts.X)
61
- def x() -> types.Op: ...
62
-
63
-
64
- @_wraps(stmts.Y)
65
- def y() -> types.Op: ...
66
-
67
-
68
- @_wraps(stmts.Z)
69
- def z() -> types.Op: ...
70
-
71
-
72
- @_wraps(stmts.H)
73
- def h() -> types.Op: ...
74
-
75
-
76
- @_wraps(stmts.S)
77
- def s() -> types.Op: ...
78
-
79
-
80
- @_wraps(stmts.T)
81
- def t() -> types.Op: ...
82
-
83
-
84
- @_wraps(stmts.P0)
85
- def p0() -> types.Op: ...
86
-
87
-
88
- @_wraps(stmts.P1)
89
- def p1() -> types.Op: ...
90
-
91
-
92
- @_wraps(stmts.Sn)
93
- def spin_n() -> types.Op: ...
94
-
95
-
96
- @_wraps(stmts.Sp)
97
- def spin_p() -> types.Op: ...
98
-
99
-
100
- @_wraps(stmts.U3)
101
- def u(theta: float, phi: float, lam: float) -> types.Op: ...
102
-
103
-
104
- @_wraps(stmts.PauliString)
105
- def pauli_string(*, string: str) -> types.Op: ...
106
-
107
-
108
- # stdlibs
109
- @_ir.dialect_group(_structural_no_opt.add(dialect))
110
- def op(self):
111
- def run_pass(method):
112
- pass
113
-
114
- return run_pass
115
-
116
-
117
- @op
118
- def rx(theta: float) -> types.Op:
119
- """Rotation X gate."""
120
- return rot(x(), theta)
121
-
122
-
123
- @op
124
- def ry(theta: float) -> types.Op:
125
- """Rotation Y gate."""
126
- return rot(y(), theta)
127
-
128
-
129
- @op
130
- def rz(theta: float) -> types.Op:
131
- """Rotation Z gate."""
132
- return rot(z(), theta)
133
-
134
-
135
- @op
136
- def cx() -> types.Op:
137
- """Controlled X gate."""
138
- return control(x(), n_controls=1)
139
-
140
-
141
- @op
142
- def cy() -> types.Op:
143
- """Controlled Y gate."""
144
- return control(y(), n_controls=1)
145
-
146
-
147
- @op
148
- def cz() -> types.Op:
149
- """Control Z gate."""
150
- return control(z(), n_controls=1)
151
-
152
-
153
- @op
154
- def ch() -> types.Op:
155
- """Control H gate."""
156
- return control(h(), n_controls=1)
157
-
158
-
159
- @op
160
- def cphase(theta: float) -> types.Op:
161
- """Control Phase gate."""
162
- return control(phase(theta), n_controls=1)
14
+ from ._wrapper import (
15
+ h as h,
16
+ s as s,
17
+ t as t,
18
+ u as u,
19
+ x as x,
20
+ y as y,
21
+ z as z,
22
+ p0 as p0,
23
+ p1 as p1,
24
+ rot as rot,
25
+ kron as kron,
26
+ mult as mult,
27
+ phase as phase,
28
+ scale as scale,
29
+ shift as shift,
30
+ spin_n as spin_n,
31
+ spin_p as spin_p,
32
+ adjoint as adjoint,
33
+ control as control,
34
+ identity as identity,
35
+ pauli_string as pauli_string,
36
+ )
@@ -0,0 +1,101 @@
1
+ from kirin.lowering import wraps
2
+
3
+ from . import stmts, types
4
+
5
+
6
@wraps(stmts.Kron)
def kron(lhs: types.Op, rhs: types.Op) -> types.Op:
    """Lower to ``stmts.Kron`` applied to the operators ``lhs`` and ``rhs``."""
8
+
9
+
10
@wraps(stmts.Mult)
def mult(lhs: types.Op, rhs: types.Op) -> types.Op:
    """Lower to ``stmts.Mult`` applied to the operators ``lhs`` and ``rhs``."""
12
+
13
+
14
@wraps(stmts.Scale)
def scale(op: types.Op, factor: complex) -> types.Op:
    """Lower to ``stmts.Scale`` with operator ``op`` and complex ``factor``."""
16
+
17
+
18
@wraps(stmts.Adjoint)
def adjoint(op: types.Op) -> types.Op:
    """Lower to ``stmts.Adjoint`` applied to the operator ``op``."""
20
+
21
+
22
@wraps(stmts.Control)
def control(op: types.Op, *, n_controls: int) -> types.Op:
    """
    Create a controlled operator.

    Note that, when considering atom loss, the operator will not be applied if
    any of the controls has been lost.

    Args:
        op: The operator to apply under the control.
        n_controls: The number of qubits to be used as controls.

    Returns:
        The controlled operator.
    """
    ...
38
+
39
+
40
@wraps(stmts.Identity)
def identity(*, sites: int) -> types.Op:
    """Lower to ``stmts.Identity``; ``sites`` must be passed by keyword."""
42
+
43
+
44
@wraps(stmts.Rot)
def rot(axis: types.Op, angle: float) -> types.Op:
    """Lower to ``stmts.Rot``: a rotation about the operator ``axis`` by ``angle``."""
46
+
47
+
48
@wraps(stmts.ShiftOp)
def shift(theta: float) -> types.Op:
    """Lower to ``stmts.ShiftOp`` parameterized by ``theta``."""
50
+
51
+
52
@wraps(stmts.PhaseOp)
def phase(theta: float) -> types.Op:
    """Lower to ``stmts.PhaseOp`` parameterized by ``theta``."""
54
+
55
+
56
@wraps(stmts.X)
def x() -> types.Op:
    """Lower to the ``stmts.X`` operator."""
58
+
59
+
60
@wraps(stmts.Y)
def y() -> types.Op:
    """Lower to the ``stmts.Y`` operator."""
62
+
63
+
64
@wraps(stmts.Z)
def z() -> types.Op:
    """Lower to the ``stmts.Z`` operator."""
66
+
67
+
68
@wraps(stmts.H)
def h() -> types.Op:
    """Lower to the ``stmts.H`` operator."""
70
+
71
+
72
@wraps(stmts.S)
def s() -> types.Op:
    """Lower to the ``stmts.S`` operator."""
74
+
75
+
76
@wraps(stmts.T)
def t() -> types.Op:
    """Lower to the ``stmts.T`` operator."""
78
+
79
+
80
@wraps(stmts.P0)
def p0() -> types.Op:
    """Lower to the ``stmts.P0`` operator."""
82
+
83
+
84
@wraps(stmts.P1)
def p1() -> types.Op:
    """Lower to the ``stmts.P1`` operator."""
86
+
87
+
88
@wraps(stmts.Sn)
def spin_n() -> types.Op:
    """Lower to the ``stmts.Sn`` operator."""
90
+
91
+
92
@wraps(stmts.Sp)
def spin_p() -> types.Op:
    """Lower to the ``stmts.Sp`` operator."""
94
+
95
+
96
@wraps(stmts.U3)
def u(theta: float, phi: float, lam: float) -> types.Op:
    """Lower to ``stmts.U3`` with the angles ``theta``, ``phi`` and ``lam``."""
98
+
99
+
100
@wraps(stmts.PauliString)
def pauli_string(*, string: str) -> types.Op:
    """Lower to ``stmts.PauliString``; ``string`` must be passed by keyword."""