bloqade-circuit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (153) hide show
  1. bloqade/analysis/__init__.py +0 -0
  2. bloqade/analysis/address/__init__.py +11 -0
  3. bloqade/analysis/address/analysis.py +60 -0
  4. bloqade/analysis/address/impls.py +228 -0
  5. bloqade/analysis/address/lattice.py +85 -0
  6. bloqade/noise/__init__.py +1 -0
  7. bloqade/noise/native/__init__.py +20 -0
  8. bloqade/noise/native/_dialect.py +3 -0
  9. bloqade/noise/native/_wrappers.py +34 -0
  10. bloqade/noise/native/model.py +347 -0
  11. bloqade/noise/native/rewrite.py +35 -0
  12. bloqade/noise/native/stmts.py +46 -0
  13. bloqade/pyqrack/__init__.py +18 -0
  14. bloqade/pyqrack/base.py +131 -0
  15. bloqade/pyqrack/noise/__init__.py +0 -0
  16. bloqade/pyqrack/noise/native.py +100 -0
  17. bloqade/pyqrack/qasm2/__init__.py +0 -0
  18. bloqade/pyqrack/qasm2/core.py +79 -0
  19. bloqade/pyqrack/qasm2/parallel.py +46 -0
  20. bloqade/pyqrack/qasm2/uop.py +247 -0
  21. bloqade/pyqrack/reg.py +109 -0
  22. bloqade/pyqrack/target.py +112 -0
  23. bloqade/qasm2/__init__.py +19 -0
  24. bloqade/qasm2/_wrappers.py +674 -0
  25. bloqade/qasm2/dialects/__init__.py +10 -0
  26. bloqade/qasm2/dialects/core/__init__.py +3 -0
  27. bloqade/qasm2/dialects/core/_dialect.py +3 -0
  28. bloqade/qasm2/dialects/core/_emit.py +68 -0
  29. bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
  30. bloqade/qasm2/dialects/core/address.py +38 -0
  31. bloqade/qasm2/dialects/core/stmts.py +94 -0
  32. bloqade/qasm2/dialects/expr/__init__.py +3 -0
  33. bloqade/qasm2/dialects/expr/_dialect.py +3 -0
  34. bloqade/qasm2/dialects/expr/_emit.py +103 -0
  35. bloqade/qasm2/dialects/expr/_from_python.py +86 -0
  36. bloqade/qasm2/dialects/expr/_interp.py +75 -0
  37. bloqade/qasm2/dialects/expr/stmts.py +262 -0
  38. bloqade/qasm2/dialects/glob.py +45 -0
  39. bloqade/qasm2/dialects/indexing.py +64 -0
  40. bloqade/qasm2/dialects/inline.py +76 -0
  41. bloqade/qasm2/dialects/noise.py +16 -0
  42. bloqade/qasm2/dialects/parallel.py +110 -0
  43. bloqade/qasm2/dialects/uop/__init__.py +4 -0
  44. bloqade/qasm2/dialects/uop/_dialect.py +3 -0
  45. bloqade/qasm2/dialects/uop/_emit.py +211 -0
  46. bloqade/qasm2/dialects/uop/schedule.py +89 -0
  47. bloqade/qasm2/dialects/uop/stmts.py +325 -0
  48. bloqade/qasm2/emit/__init__.py +1 -0
  49. bloqade/qasm2/emit/base.py +72 -0
  50. bloqade/qasm2/emit/gate.py +102 -0
  51. bloqade/qasm2/emit/main.py +106 -0
  52. bloqade/qasm2/emit/target.py +165 -0
  53. bloqade/qasm2/glob.py +24 -0
  54. bloqade/qasm2/groups.py +120 -0
  55. bloqade/qasm2/parallel.py +48 -0
  56. bloqade/qasm2/parse/__init__.py +37 -0
  57. bloqade/qasm2/parse/ast.py +235 -0
  58. bloqade/qasm2/parse/build.py +289 -0
  59. bloqade/qasm2/parse/lowering.py +553 -0
  60. bloqade/qasm2/parse/parser.py +5 -0
  61. bloqade/qasm2/parse/print.py +293 -0
  62. bloqade/qasm2/parse/qasm2.lark +75 -0
  63. bloqade/qasm2/parse/visitor.py +16 -0
  64. bloqade/qasm2/parse/visitor.pyi +39 -0
  65. bloqade/qasm2/passes/__init__.py +5 -0
  66. bloqade/qasm2/passes/fold.py +94 -0
  67. bloqade/qasm2/passes/glob.py +119 -0
  68. bloqade/qasm2/passes/noise.py +61 -0
  69. bloqade/qasm2/passes/parallel.py +176 -0
  70. bloqade/qasm2/passes/py2qasm.py +63 -0
  71. bloqade/qasm2/passes/qasm2py.py +61 -0
  72. bloqade/qasm2/rewrite/__init__.py +12 -0
  73. bloqade/qasm2/rewrite/desugar.py +28 -0
  74. bloqade/qasm2/rewrite/glob.py +103 -0
  75. bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
  76. bloqade/qasm2/rewrite/native_gates.py +447 -0
  77. bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
  78. bloqade/qasm2/rewrite/register.py +45 -0
  79. bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
  80. bloqade/qasm2/types.py +39 -0
  81. bloqade/qbraid/__init__.py +2 -0
  82. bloqade/qbraid/lowering.py +324 -0
  83. bloqade/qbraid/schema.py +252 -0
  84. bloqade/qbraid/simulation_result.py +99 -0
  85. bloqade/qbraid/target.py +86 -0
  86. bloqade/squin/__init__.py +2 -0
  87. bloqade/squin/analysis/__init__.py +0 -0
  88. bloqade/squin/analysis/nsites/__init__.py +8 -0
  89. bloqade/squin/analysis/nsites/analysis.py +52 -0
  90. bloqade/squin/analysis/nsites/impls.py +69 -0
  91. bloqade/squin/analysis/nsites/lattice.py +49 -0
  92. bloqade/squin/analysis/schedule.py +244 -0
  93. bloqade/squin/groups.py +38 -0
  94. bloqade/squin/op/__init__.py +132 -0
  95. bloqade/squin/op/_dialect.py +3 -0
  96. bloqade/squin/op/complex.py +6 -0
  97. bloqade/squin/op/stmts.py +220 -0
  98. bloqade/squin/op/traits.py +43 -0
  99. bloqade/squin/op/types.py +10 -0
  100. bloqade/squin/qubit.py +118 -0
  101. bloqade/squin/wire.py +103 -0
  102. bloqade/stim/__init__.py +6 -0
  103. bloqade/stim/_wrappers.py +186 -0
  104. bloqade/stim/dialects/__init__.py +5 -0
  105. bloqade/stim/dialects/aux/__init__.py +11 -0
  106. bloqade/stim/dialects/aux/_dialect.py +3 -0
  107. bloqade/stim/dialects/aux/emit.py +102 -0
  108. bloqade/stim/dialects/aux/interp.py +39 -0
  109. bloqade/stim/dialects/aux/lowering.py +40 -0
  110. bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
  111. bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
  112. bloqade/stim/dialects/aux/stmts/const.py +95 -0
  113. bloqade/stim/dialects/aux/types.py +19 -0
  114. bloqade/stim/dialects/collapse/__init__.py +3 -0
  115. bloqade/stim/dialects/collapse/_dialect.py +3 -0
  116. bloqade/stim/dialects/collapse/emit.py +68 -0
  117. bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
  118. bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
  119. bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
  120. bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
  121. bloqade/stim/dialects/gate/__init__.py +3 -0
  122. bloqade/stim/dialects/gate/_dialect.py +3 -0
  123. bloqade/stim/dialects/gate/emit.py +87 -0
  124. bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
  125. bloqade/stim/dialects/gate/stmts/base.py +31 -0
  126. bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
  127. bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
  128. bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
  129. bloqade/stim/dialects/gate/stmts/pp.py +15 -0
  130. bloqade/stim/dialects/noise/__init__.py +3 -0
  131. bloqade/stim/dialects/noise/_dialect.py +3 -0
  132. bloqade/stim/dialects/noise/emit.py +66 -0
  133. bloqade/stim/dialects/noise/stmts.py +77 -0
  134. bloqade/stim/emit/__init__.py +1 -0
  135. bloqade/stim/emit/stim.py +54 -0
  136. bloqade/stim/groups.py +26 -0
  137. bloqade/test_utils.py +35 -0
  138. bloqade/types.py +24 -0
  139. bloqade/visual/__init__.py +1 -0
  140. bloqade/visual/animation/__init__.py +0 -0
  141. bloqade/visual/animation/animate.py +267 -0
  142. bloqade/visual/animation/base.py +346 -0
  143. bloqade/visual/animation/gate_event.py +24 -0
  144. bloqade/visual/animation/runtime/__init__.py +0 -0
  145. bloqade/visual/animation/runtime/aod.py +36 -0
  146. bloqade/visual/animation/runtime/atoms.py +55 -0
  147. bloqade/visual/animation/runtime/ppoly.py +50 -0
  148. bloqade/visual/animation/runtime/qpustate.py +119 -0
  149. bloqade/visual/animation/runtime/utils.py +43 -0
  150. bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
  151. bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
  152. bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
  153. bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,68 @@
1
+ from kirin import interp
2
+
3
+ from bloqade.qasm2.parse import ast
4
+ from bloqade.qasm2.emit.main import EmitQASM2Main, EmitQASM2Frame
5
+
6
+ from . import stmts
7
+ from ._dialect import dialect
8
+
9
+
10
@dialect.register(key="emit.qasm2.main")
class Core(interp.MethodTable):
    """Emit rules that turn core QASM2 statements into QASM2 AST nodes."""

    @interp.impl(stmts.CRegNew)
    def emit_creg_new(
        self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.CRegNew
    ):
        """Append a classical register declaration and return its name node."""
        size_node = emit.assert_node(ast.Number, frame.get(stmt.n_bits))
        # check if its int first, because Int.is_integer() is added for >=3.12
        assert isinstance(size_node.value, int), "expected integer"
        reg_name = emit.ssa_id[stmt.result]
        declaration = ast.CReg(name=reg_name, size=int(size_node.value))
        frame.body.append(declaration)
        return (ast.Name(reg_name),)

    @interp.impl(stmts.QRegNew)
    def emit_qreg_new(
        self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.QRegNew
    ):
        """Append a quantum register declaration and return its name node."""
        size_node = emit.assert_node(ast.Number, frame.get(stmt.n_qubits))
        assert isinstance(size_node.value, int), "expected integer"
        reg_name = emit.ssa_id[stmt.result]
        declaration = ast.QReg(name=reg_name, size=int(size_node.value))
        frame.body.append(declaration)
        return (ast.Name(reg_name),)

    @interp.impl(stmts.Reset)
    def emit_reset(self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.Reset):
        """Append a reset node for the emitted qubit argument."""
        target = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg))
        reset_node = ast.Reset(qarg=target)
        frame.body.append(reset_node)
        return ()

    @interp.impl(stmts.Measure)
    def emit_measure(
        self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.Measure
    ):
        """Append a measure node built from the quantum and classical arguments."""
        quantum_arg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg))
        classical_arg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.carg))
        measure_node = ast.Measure(qarg=quantum_arg, carg=classical_arg)
        frame.body.append(measure_node)
        return ()

    @interp.impl(stmts.CRegEq)
    def emit_creg_eq(
        self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.CRegEq
    ):
        """Return a comparison node; nothing is appended to the frame body."""
        left = emit.assert_node(ast.Expr, frame.get(stmt.lhs))
        right = emit.assert_node(ast.Expr, frame.get(stmt.rhs))
        comparison = ast.Cmp(lhs=left, rhs=right)
        return (comparison,)

    @interp.impl(stmts.CRegGet)
    @interp.impl(stmts.QRegGet)
    def emit_qreg_get(
        self,
        emit: EmitQASM2Main,
        frame: EmitQASM2Frame,
        stmt: stmts.QRegGet | stmts.CRegGet,
    ):
        """Return an indexed-register node; shared by both register kinds."""
        register = emit.assert_node(ast.Name, frame.get(stmt.reg))
        index_node = emit.assert_node(ast.Number, frame.get(stmt.idx))
        assert isinstance(index_node.value, int)
        position = int(index_node.value)
        return (ast.Bit(register, position),)
@@ -0,0 +1,23 @@
1
+ from kirin import types, interp
2
+ from kirin.analysis import TypeInference
3
+ from kirin.dialects import py
4
+
5
+ from bloqade.qasm2.types import CRegType, QRegType, QubitType
6
+
7
+ from ._dialect import dialect
8
+
9
+
10
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type inference rules for indexing QASM2 registers with ``reg[i]``."""

    @interp.impl(py.indexing.GetItem, QRegType, types.Int)
    def getitem_qreg(
        self, infer: TypeInference, frame: interp.Frame, node: py.indexing.GetItem
    ):
        """Indexing a quantum register with an integer yields a qubit."""
        return (QubitType,)

    @interp.impl(py.indexing.GetItem, CRegType, types.Int)
    def getitem_creg(
        self, infer: TypeInference, frame: interp.Frame, node: py.indexing.GetItem
    ):
        """Indexing a classical register with an integer yields a classical bit.

        This previously reported ``QubitType``, which disagrees with the
        ``creg.get`` statement whose result is declared as ``BitType``.
        """
        # Imported here because the module header only brings in the
        # register and qubit types.
        from bloqade.qasm2.types import BitType

        return (BitType,)
@@ -0,0 +1,38 @@
1
+ from kirin import interp
2
+
3
+ from bloqade.analysis.address import (
4
+ Address,
5
+ NotQubit,
6
+ AddressReg,
7
+ AddressQubit,
8
+ AddressAnalysis,
9
+ )
10
+
11
+ from .stmts import QRegGet, QRegNew
12
+ from ._dialect import dialect
13
+
14
+
15
@dialect.register(key="qubit.address")
class AddressMethodTable(interp.MethodTable):
    """Address-analysis rules for creating and indexing quantum registers."""

    @interp.impl(QRegNew)
    def new(
        self,
        interp: AddressAnalysis,
        frame: interp.Frame[Address],
        stmt: QRegNew,
    ):
        """Reserve a contiguous range of addresses for a new register."""
        size = interp.get_const_value(int, stmt.n_qubits)
        first = interp.next_address
        register_addr = AddressReg(range(first, first + size))
        # Advance the counter so the next register starts after this one.
        interp.next_address += size
        return (register_addr,)

    @interp.impl(QRegGet)
    def get(self, interp: AddressAnalysis, frame: interp.Frame[Address], stmt: QRegGet):
        """Resolve ``reg[idx]`` to the global address of that qubit."""
        register_addr = frame.get(stmt.reg)
        position = interp.get_const_value(int, stmt.idx)
        if not isinstance(register_addr, AddressReg):  # this is not reachable
            return (NotQubit(),)
        return (AddressQubit(register_addr.data[position]),)
@@ -0,0 +1,94 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+
4
+ from bloqade.qasm2.types import BitType, CRegType, QRegType, QubitType
5
+
6
+ from ._dialect import dialect
7
+
8
+
9
@statement(dialect=dialect)
class QRegNew(ir.Statement):
    """Create a new quantum register.

    Takes the register size as an integer-typed SSA value and produces a
    result of type ``QRegType``.
    """

    # Statement name within the dialect.
    name = "qreg.new"
    # NOTE(review): FromPythonCall presumably lets a Python call expression be
    # lowered to this statement -- confirm against the kirin lowering docs.
    traits = frozenset({lowering.FromPythonCall()})
    n_qubits: ir.SSAValue = info.argument(types.Int)
    """n_qubits: The number of qubits in the register."""
    result: ir.ResultValue = info.result(QRegType)
    """A new quantum register with n_qubits set to |0>."""
19
+
20
+
21
@statement(dialect=dialect)
class CRegNew(ir.Statement):
    """Create a new classical register.

    Takes the register size as an integer-typed SSA value and produces a
    result of type ``CRegType``.
    """

    # Statement name within the dialect.
    name = "creg.new"
    # NOTE(review): FromPythonCall presumably lets a Python call expression be
    # lowered to this statement -- confirm against the kirin lowering docs.
    traits = frozenset({lowering.FromPythonCall()})
    n_bits: ir.SSAValue = info.argument(types.Int)
    """n_bits (Int): The number of bits in the register."""
    result: ir.ResultValue = info.result(CRegType)
    """result (CReg): The new classical register with all bits set to 0."""
31
+
32
+
33
@statement(dialect=dialect)
class Reset(ir.Statement):
    """Reset a qubit to the |0> state.

    Declares a single qubit argument and no result value.
    """

    # Statement name within the dialect.
    name = "reset"
    # NOTE(review): FromPythonCall presumably lets a Python call expression be
    # lowered to this statement -- confirm against the kirin lowering docs.
    traits = frozenset({lowering.FromPythonCall()})
    qarg: ir.SSAValue = info.argument(QubitType)
    """qarg (Qubit): The qubit to reset."""
41
+
42
+
43
@statement(dialect=dialect)
class Measure(ir.Statement):
    """Measure a qubit and store the result in a bit.

    Declares a qubit argument and a classical bit argument; there is no
    result value, the outcome is written to ``carg``.
    """

    # Statement name within the dialect.
    name = "measure"
    # NOTE(review): FromPythonCall presumably lets a Python call expression be
    # lowered to this statement -- confirm against the kirin lowering docs.
    traits = frozenset({lowering.FromPythonCall()})
    qarg: ir.SSAValue = info.argument(QubitType)
    """qarg (Qubit): The qubit to measure."""
    carg: ir.SSAValue = info.argument(BitType)
    """carg (Bit): The bit to store the result in."""
53
+
54
+
55
@statement(dialect=dialect)
class CRegEq(ir.Statement):
    """Check if two classical registers are equal.

    Each operand is declared as an integer, a classical register, or a
    classical bit; the result is boolean-typed.
    """

    # Statement name within the dialect.
    name = "eq"
    # NOTE(review): ir.Pure presumably marks the statement as free of side
    # effects, and FromPythonCall as lowerable from a Python call -- confirm
    # against the kirin docs.
    traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
    lhs: ir.SSAValue = info.argument(types.Int | CRegType | BitType)
    """lhs (CReg): The first register."""
    rhs: ir.SSAValue = info.argument(types.Int | CRegType | BitType)
    """rhs (CReg): The second register."""
    result: ir.ResultValue = info.result(types.Bool)
    """result (bool): True if the registers are equal, False otherwise."""
67
+
68
+
69
@statement(dialect=dialect)
class QRegGet(ir.Statement):
    """Get a qubit from a quantum register.

    Takes a quantum register and an integer index and produces a
    qubit-typed result.
    """

    # Statement name within the dialect.
    name = "qreg.get"
    # NOTE(review): ir.Pure presumably marks the statement as free of side
    # effects, and FromPythonCall as lowerable from a Python call -- confirm
    # against the kirin docs.
    traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
    reg: ir.SSAValue = info.argument(QRegType)
    """reg (QReg): The quantum register."""
    idx: ir.SSAValue = info.argument(types.Int)
    """idx (Int): The index of the qubit in the register."""
    result: ir.ResultValue = info.result(QubitType)
    """result (Qubit): The qubit at position `idx`."""
81
+
82
+
83
@statement(dialect=dialect)
class CRegGet(ir.Statement):
    """Get a bit from a classical register.

    Takes a classical register and an integer index and produces a
    bit-typed result.
    """

    # Statement name within the dialect.
    name = "creg.get"
    # NOTE(review): ir.Pure presumably marks the statement as free of side
    # effects, and FromPythonCall as lowerable from a Python call -- confirm
    # against the kirin docs.
    traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
    reg: ir.SSAValue = info.argument(CRegType)
    """reg (CReg): The classical register."""
    idx: ir.SSAValue = info.argument(types.Int)
    """idx (Int): The index of the bit in the register."""
    result: ir.ResultValue = info.result(BitType)
    """result (Bit): The bit at position `idx`."""
@@ -0,0 +1,3 @@
1
+ from . import _emit as _emit, _interp as _interp, _from_python as _from_python
2
+ from .stmts import * # noqa: F403
3
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# Dialect object named "qasm2.expr"; sibling modules import it and register
# their method tables and lowering rules through `dialect.register`.
dialect = ir.Dialect("qasm2.expr")
@@ -0,0 +1,103 @@
1
+ from typing import Literal
2
+
3
+ from kirin import interp
4
+ from kirin.emit.exceptions import EmitError
5
+
6
+ from bloqade.qasm2.parse import ast
7
+ from bloqade.qasm2.types import QubitType
8
+ from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame
9
+
10
+ from . import stmts
11
+ from ._dialect import dialect
12
+
13
+
14
@dialect.register(key="emit.qasm2.gate")
class EmitExpr(interp.MethodTable):
    """Emit rules that turn ``qasm2.expr`` statements into QASM2 AST nodes."""

    @interp.impl(stmts.GateFunction)
    def emit_func(
        self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction
    ):
        """Emit a gate definition and store it on ``emit.output``.

        Block arguments whose type is a subset of ``QubitType`` become qubit
        parameters; all others become classical parameters.

        Raises:
            EmitError: if a block argument was not emitted as an ``ast.Name``.
        """
        cparams, qparams = [], []
        # The first block argument is skipped.
        # NOTE(review): presumably it is the function's own reference -- confirm.
        for arg in stmt.body.blocks[0].args[1:]:
            name = frame.get(arg)
            if not isinstance(name, ast.Name):
                raise EmitError("expected ast.Name")
            if arg.type.is_subseteq(QubitType):
                qparams.append(name.id)
            else:
                cparams.append(name.id)
        emit.run_ssacfg_region(frame, stmt.body)
        emit.output = ast.Gate(
            name=stmt.sym_name,
            cparams=cparams,
            qparams=qparams,
            body=frame.body,
        )
        return ()

    @interp.impl(stmts.ConstInt)
    @interp.impl(stmts.ConstFloat)
    def emit_const_int(
        self,
        emit: EmitQASM2Gate,
        frame: EmitQASM2Frame,
        stmt: stmts.ConstInt | stmts.ConstFloat,
    ):
        """Emit an integer or float constant as a number node."""
        return (ast.Number(stmt.value),)

    @interp.impl(stmts.ConstPI)
    def emit_const_pi(
        self,
        emit: EmitQASM2Gate,
        frame: EmitQASM2Frame,
        stmt: stmts.ConstPI,
    ):
        """Emit the pi constant node."""
        return (ast.Pi(),)

    @interp.impl(stmts.Neg)
    def emit_neg(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Neg):
        """Emit unary negation of the operand expression."""
        arg = emit.assert_node(ast.Expr, frame.get(stmt.value))
        return (ast.UnaryOp("-", arg),)

    @interp.impl(stmts.Sin)
    @interp.impl(stmts.Cos)
    @interp.impl(stmts.Tan)
    @interp.impl(stmts.Exp)
    @interp.impl(stmts.Log)
    @interp.impl(stmts.Sqrt)
    def emit_sin(
        self,
        emit: EmitQASM2Gate,
        frame: EmitQASM2Frame,
        stmt: (
            stmts.Sin | stmts.Cos | stmts.Tan | stmts.Exp | stmts.Log | stmts.Sqrt
        ),
    ):
        """Emit a one-argument function call.

        Despite its name, this handles every registered unary function; the
        callee name is taken from ``stmt.name``.
        """
        arg = emit.assert_node(ast.Expr, frame.get(stmt.value))
        return (ast.Call(stmt.name, [arg]),)

    def emit_binop(
        self,
        sym: Literal["+", "-", "*", "/", "^"],
        emit: EmitQASM2Gate,
        frame: EmitQASM2Frame,
        stmt: stmts.Add | stmts.Sub | stmts.Mul | stmts.Div | stmts.Pow,
    ):
        """Shared helper: emit ``lhs <sym> rhs`` for a binary statement."""
        lhs = emit.assert_node(ast.Expr, frame.get(stmt.lhs))
        rhs = emit.assert_node(ast.Expr, frame.get(stmt.rhs))
        return (ast.BinOp(sym, lhs, rhs),)

    @interp.impl(stmts.Add)
    def emit_add(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add):
        """Emit addition."""
        return self.emit_binop("+", emit, frame, stmt)

    @interp.impl(stmts.Sub)
    def emit_sub(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Sub):
        """Emit subtraction."""
        return self.emit_binop("-", emit, frame, stmt)

    @interp.impl(stmts.Mul)
    def emit_mul(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Mul):
        """Emit multiplication."""
        return self.emit_binop("*", emit, frame, stmt)

    @interp.impl(stmts.Div)
    def emit_div(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Div):
        """Emit division."""
        return self.emit_binop("/", emit, frame, stmt)

    @interp.impl(stmts.Pow)
    def emit_pow(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Pow):
        """Emit exponentiation, written with the ``^`` operator."""
        return self.emit_binop("^", emit, frame, stmt)
@@ -0,0 +1,86 @@
1
+ import ast
2
+
3
+ from kirin import ir, types, lowering
4
+
5
+ from . import stmts
6
+ from ._dialect import dialect
7
+
8
+
9
@dialect.register
class QASMUopLowering(lowering.FromPythonAST):
    """Lower Python names, assignments, constants and arithmetic to qasm2.expr."""

    def lower_Name(self, state: lowering.State, node: ast.Name):
        """Look up a loaded name in the current frame; stores and deletes fail."""
        identifier = node.id
        if isinstance(node.ctx, ast.Load):
            bound = state.current_frame.get(identifier)
            if bound is not None:
                return bound
            raise lowering.BuildError(f"{identifier} is not defined")
        if isinstance(node.ctx, ast.Store):
            raise lowering.BuildError("unhandled store operation")
        # Del
        raise lowering.BuildError("unhandled del operation")

    def lower_Assign(self, state: lowering.State, node: ast.Assign):
        """Bind the lowered right-hand side to a single plain name target."""
        # NOTE: QASM only expects one value on right hand side
        rhs = state.lower(node.value).expect_one()
        frame = state.current_frame
        targets = node.targets
        is_single_name = (
            len(targets) == 1
            and isinstance(targets[0], ast.Name)
            and isinstance(targets[0].ctx, ast.Store)
        )
        if not is_single_name:
            target_syntax = ", ".join(map(ast.unparse, targets))
            raise lowering.BuildError(f"unsupported target syntax {target_syntax}")
        lhs_name = targets[0].id
        rhs.name = lhs_name
        frame.defs[lhs_name] = rhs

    def lower_Expr(self, state: lowering.State, node: ast.Expr):
        """Delegate an expression statement to the visitor for its inner value."""
        inner = node.value
        return state.parent.visit(state, inner)

    def lower_Constant(self, state: lowering.State, node: ast.Constant):
        """Push an integer or float constant statement; other types fail."""
        literal = node.value
        if isinstance(literal, int):
            const_stmt = stmts.ConstInt(value=literal)
        elif isinstance(literal, float):
            const_stmt = stmts.ConstFloat(value=literal)
        else:
            raise lowering.BuildError(
                f"unsupported QASM 2.0 constant type {type(node.value)}"
            )
        return state.current_frame.push(const_stmt)

    def lower_BinOp(self, state: lowering.State, node: ast.BinOp):
        """Lower ``+ - * / **`` to the matching statement with a promoted type."""
        lhs = state.lower(node.left).expect_one()
        rhs = state.lower(node.right).expect_one()
        # Python operator node type -> statement type, checked in order.
        operator_table = (
            (ast.Add, stmts.Add),
            (ast.Sub, stmts.Sub),
            (ast.Mult, stmts.Mul),
            (ast.Div, stmts.Div),
            (ast.Pow, stmts.Pow),
        )
        for operator_type, statement_type in operator_table:
            if isinstance(node.op, operator_type):
                stmt = statement_type(lhs, rhs)
                break
        else:
            raise lowering.BuildError(f"unsupported QASM 2.0 binop {node.op}")
        stmt.result.type = self.__promote_binop_type(lhs, rhs)
        return state.current_frame.push(stmt)

    def __promote_binop_type(
        self, lhs: ir.SSAValue, rhs: ir.SSAValue
    ) -> types.TypeAttribute:
        """Return Float if either operand is float-typed, otherwise Int."""
        any_float = lhs.type.is_subseteq(types.Float) or rhs.type.is_subseteq(
            types.Float
        )
        return types.Float if any_float else types.Int

    def lower_UnaryOp(self, state: lowering.State, node: ast.UnaryOp):
        """Lower unary minus to a negation; unary plus passes its operand through."""
        operator = node.op
        if not isinstance(operator, (ast.USub, ast.UAdd)):
            raise lowering.BuildError(f"unsupported QASM 2.0 unaryop {node.op}")
        operand = state.lower(node.operand).expect_one()
        if isinstance(operator, ast.USub):
            return state.current_frame.push(stmts.Neg(operand))
        return operand
@@ -0,0 +1,75 @@
1
+ import math
2
+ from typing import Union
3
+
4
+ from kirin.interp import Frame, Interpreter, MethodTable, impl
5
+
6
+ from . import stmts
7
+ from ._dialect import dialect
8
+
9
+
10
@dialect.register
class Qasm2UopInterpreter(MethodTable):
    """Concrete evaluation rules for constants, arithmetic and math functions."""

    name = "qasm2.uop"
    dialect = dialect

    @impl(stmts.ConstFloat)
    @impl(stmts.ConstInt)
    def new_const(
        self,
        interp: Interpreter,
        frame: Frame,
        stmt: Union[stmts.ConstFloat, stmts.ConstInt],
    ):
        """Return the constant's stored value."""
        constant = stmt.value
        return (constant,)

    @impl(stmts.ConstPI)
    def new_const_pi(self, interp: Interpreter, frame: Frame, stmt: stmts.ConstPI):
        """Return pi as a float."""
        return (math.pi,)

    @impl(stmts.Add)
    def add(self, interp: Interpreter, frame: Frame, stmt: stmts.Add):
        """Return ``lhs + rhs``."""
        left = frame.get(stmt.lhs)
        right = frame.get(stmt.rhs)
        return (left + right,)

    @impl(stmts.Sub)
    def sub(self, interp: Interpreter, frame: Frame, stmt: stmts.Sub):
        """Return ``lhs - rhs``."""
        left = frame.get(stmt.lhs)
        right = frame.get(stmt.rhs)
        return (left - right,)

    @impl(stmts.Mul)
    def mul(self, interp: Interpreter, frame: Frame, stmt: stmts.Mul):
        """Return ``lhs * rhs``."""
        left = frame.get(stmt.lhs)
        right = frame.get(stmt.rhs)
        return (left * right,)

    @impl(stmts.Div)
    def div(self, interp: Interpreter, frame: Frame, stmt: stmts.Div):
        """Return ``lhs / rhs`` using true division."""
        left = frame.get(stmt.lhs)
        right = frame.get(stmt.rhs)
        return (left / right,)

    @impl(stmts.Pow)
    def pow(self, interp: Interpreter, frame: Frame, stmt: stmts.Pow):
        """Return ``lhs ** rhs``."""
        base = frame.get(stmt.lhs)
        exponent = frame.get(stmt.rhs)
        return (base**exponent,)

    @impl(stmts.Neg)
    def neg(self, interp: Interpreter, frame: Frame, stmt: stmts.Neg):
        """Return the negated operand."""
        operand = frame.get(stmt.value)
        return (-operand,)

    @impl(stmts.Sqrt)
    def sqrt(self, interp: Interpreter, frame: Frame, stmt: stmts.Sqrt):
        """Return ``math.sqrt`` of the operand."""
        operand = frame.get(stmt.value)
        return (math.sqrt(operand),)

    @impl(stmts.Sin)
    def sin(self, interp: Interpreter, frame: Frame, stmt: stmts.Sin):
        """Return ``math.sin`` of the operand."""
        operand = frame.get(stmt.value)
        return (math.sin(operand),)

    @impl(stmts.Cos)
    def cos(self, interp: Interpreter, frame: Frame, stmt: stmts.Cos):
        """Return ``math.cos`` of the operand."""
        operand = frame.get(stmt.value)
        return (math.cos(operand),)

    @impl(stmts.Tan)
    def tan(self, interp: Interpreter, frame: Frame, stmt: stmts.Tan):
        """Return ``math.tan`` of the operand."""
        operand = frame.get(stmt.value)
        return (math.tan(operand),)

    @impl(stmts.Log)
    def log(self, interp: Interpreter, frame: Frame, stmt: stmts.Log):
        """Return the natural logarithm of the operand."""
        operand = frame.get(stmt.value)
        return (math.log(operand),)

    @impl(stmts.Exp)
    def exp(self, interp: Interpreter, frame: Frame, stmt: stmts.Exp):
        """Return ``math.exp`` of the operand."""
        operand = frame.get(stmt.value)
        return (math.exp(operand),)
+ return (math.exp(frame.get(stmt.value)),)