bloqade-circuit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (153)
  1. bloqade/analysis/__init__.py +0 -0
  2. bloqade/analysis/address/__init__.py +11 -0
  3. bloqade/analysis/address/analysis.py +60 -0
  4. bloqade/analysis/address/impls.py +228 -0
  5. bloqade/analysis/address/lattice.py +85 -0
  6. bloqade/noise/__init__.py +1 -0
  7. bloqade/noise/native/__init__.py +20 -0
  8. bloqade/noise/native/_dialect.py +3 -0
  9. bloqade/noise/native/_wrappers.py +34 -0
  10. bloqade/noise/native/model.py +347 -0
  11. bloqade/noise/native/rewrite.py +35 -0
  12. bloqade/noise/native/stmts.py +46 -0
  13. bloqade/pyqrack/__init__.py +18 -0
  14. bloqade/pyqrack/base.py +131 -0
  15. bloqade/pyqrack/noise/__init__.py +0 -0
  16. bloqade/pyqrack/noise/native.py +100 -0
  17. bloqade/pyqrack/qasm2/__init__.py +0 -0
  18. bloqade/pyqrack/qasm2/core.py +79 -0
  19. bloqade/pyqrack/qasm2/parallel.py +46 -0
  20. bloqade/pyqrack/qasm2/uop.py +247 -0
  21. bloqade/pyqrack/reg.py +109 -0
  22. bloqade/pyqrack/target.py +112 -0
  23. bloqade/qasm2/__init__.py +19 -0
  24. bloqade/qasm2/_wrappers.py +674 -0
  25. bloqade/qasm2/dialects/__init__.py +10 -0
  26. bloqade/qasm2/dialects/core/__init__.py +3 -0
  27. bloqade/qasm2/dialects/core/_dialect.py +3 -0
  28. bloqade/qasm2/dialects/core/_emit.py +68 -0
  29. bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
  30. bloqade/qasm2/dialects/core/address.py +38 -0
  31. bloqade/qasm2/dialects/core/stmts.py +94 -0
  32. bloqade/qasm2/dialects/expr/__init__.py +3 -0
  33. bloqade/qasm2/dialects/expr/_dialect.py +3 -0
  34. bloqade/qasm2/dialects/expr/_emit.py +103 -0
  35. bloqade/qasm2/dialects/expr/_from_python.py +86 -0
  36. bloqade/qasm2/dialects/expr/_interp.py +75 -0
  37. bloqade/qasm2/dialects/expr/stmts.py +262 -0
  38. bloqade/qasm2/dialects/glob.py +45 -0
  39. bloqade/qasm2/dialects/indexing.py +64 -0
  40. bloqade/qasm2/dialects/inline.py +76 -0
  41. bloqade/qasm2/dialects/noise.py +16 -0
  42. bloqade/qasm2/dialects/parallel.py +110 -0
  43. bloqade/qasm2/dialects/uop/__init__.py +4 -0
  44. bloqade/qasm2/dialects/uop/_dialect.py +3 -0
  45. bloqade/qasm2/dialects/uop/_emit.py +211 -0
  46. bloqade/qasm2/dialects/uop/schedule.py +89 -0
  47. bloqade/qasm2/dialects/uop/stmts.py +325 -0
  48. bloqade/qasm2/emit/__init__.py +1 -0
  49. bloqade/qasm2/emit/base.py +72 -0
  50. bloqade/qasm2/emit/gate.py +102 -0
  51. bloqade/qasm2/emit/main.py +106 -0
  52. bloqade/qasm2/emit/target.py +165 -0
  53. bloqade/qasm2/glob.py +24 -0
  54. bloqade/qasm2/groups.py +120 -0
  55. bloqade/qasm2/parallel.py +48 -0
  56. bloqade/qasm2/parse/__init__.py +37 -0
  57. bloqade/qasm2/parse/ast.py +235 -0
  58. bloqade/qasm2/parse/build.py +289 -0
  59. bloqade/qasm2/parse/lowering.py +553 -0
  60. bloqade/qasm2/parse/parser.py +5 -0
  61. bloqade/qasm2/parse/print.py +293 -0
  62. bloqade/qasm2/parse/qasm2.lark +75 -0
  63. bloqade/qasm2/parse/visitor.py +16 -0
  64. bloqade/qasm2/parse/visitor.pyi +39 -0
  65. bloqade/qasm2/passes/__init__.py +5 -0
  66. bloqade/qasm2/passes/fold.py +94 -0
  67. bloqade/qasm2/passes/glob.py +119 -0
  68. bloqade/qasm2/passes/noise.py +61 -0
  69. bloqade/qasm2/passes/parallel.py +176 -0
  70. bloqade/qasm2/passes/py2qasm.py +63 -0
  71. bloqade/qasm2/passes/qasm2py.py +61 -0
  72. bloqade/qasm2/rewrite/__init__.py +12 -0
  73. bloqade/qasm2/rewrite/desugar.py +28 -0
  74. bloqade/qasm2/rewrite/glob.py +103 -0
  75. bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
  76. bloqade/qasm2/rewrite/native_gates.py +447 -0
  77. bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
  78. bloqade/qasm2/rewrite/register.py +45 -0
  79. bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
  80. bloqade/qasm2/types.py +39 -0
  81. bloqade/qbraid/__init__.py +2 -0
  82. bloqade/qbraid/lowering.py +324 -0
  83. bloqade/qbraid/schema.py +252 -0
  84. bloqade/qbraid/simulation_result.py +99 -0
  85. bloqade/qbraid/target.py +86 -0
  86. bloqade/squin/__init__.py +2 -0
  87. bloqade/squin/analysis/__init__.py +0 -0
  88. bloqade/squin/analysis/nsites/__init__.py +8 -0
  89. bloqade/squin/analysis/nsites/analysis.py +52 -0
  90. bloqade/squin/analysis/nsites/impls.py +69 -0
  91. bloqade/squin/analysis/nsites/lattice.py +49 -0
  92. bloqade/squin/analysis/schedule.py +244 -0
  93. bloqade/squin/groups.py +38 -0
  94. bloqade/squin/op/__init__.py +132 -0
  95. bloqade/squin/op/_dialect.py +3 -0
  96. bloqade/squin/op/complex.py +6 -0
  97. bloqade/squin/op/stmts.py +220 -0
  98. bloqade/squin/op/traits.py +43 -0
  99. bloqade/squin/op/types.py +10 -0
  100. bloqade/squin/qubit.py +118 -0
  101. bloqade/squin/wire.py +103 -0
  102. bloqade/stim/__init__.py +6 -0
  103. bloqade/stim/_wrappers.py +186 -0
  104. bloqade/stim/dialects/__init__.py +5 -0
  105. bloqade/stim/dialects/aux/__init__.py +11 -0
  106. bloqade/stim/dialects/aux/_dialect.py +3 -0
  107. bloqade/stim/dialects/aux/emit.py +102 -0
  108. bloqade/stim/dialects/aux/interp.py +39 -0
  109. bloqade/stim/dialects/aux/lowering.py +40 -0
  110. bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
  111. bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
  112. bloqade/stim/dialects/aux/stmts/const.py +95 -0
  113. bloqade/stim/dialects/aux/types.py +19 -0
  114. bloqade/stim/dialects/collapse/__init__.py +3 -0
  115. bloqade/stim/dialects/collapse/_dialect.py +3 -0
  116. bloqade/stim/dialects/collapse/emit.py +68 -0
  117. bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
  118. bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
  119. bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
  120. bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
  121. bloqade/stim/dialects/gate/__init__.py +3 -0
  122. bloqade/stim/dialects/gate/_dialect.py +3 -0
  123. bloqade/stim/dialects/gate/emit.py +87 -0
  124. bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
  125. bloqade/stim/dialects/gate/stmts/base.py +31 -0
  126. bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
  127. bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
  128. bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
  129. bloqade/stim/dialects/gate/stmts/pp.py +15 -0
  130. bloqade/stim/dialects/noise/__init__.py +3 -0
  131. bloqade/stim/dialects/noise/_dialect.py +3 -0
  132. bloqade/stim/dialects/noise/emit.py +66 -0
  133. bloqade/stim/dialects/noise/stmts.py +77 -0
  134. bloqade/stim/emit/__init__.py +1 -0
  135. bloqade/stim/emit/stim.py +54 -0
  136. bloqade/stim/groups.py +26 -0
  137. bloqade/test_utils.py +35 -0
  138. bloqade/types.py +24 -0
  139. bloqade/visual/__init__.py +1 -0
  140. bloqade/visual/animation/__init__.py +0 -0
  141. bloqade/visual/animation/animate.py +267 -0
  142. bloqade/visual/animation/base.py +346 -0
  143. bloqade/visual/animation/gate_event.py +24 -0
  144. bloqade/visual/animation/runtime/__init__.py +0 -0
  145. bloqade/visual/animation/runtime/aod.py +36 -0
  146. bloqade/visual/animation/runtime/atoms.py +55 -0
  147. bloqade/visual/animation/runtime/ppoly.py +50 -0
  148. bloqade/visual/animation/runtime/qpustate.py +119 -0
  149. bloqade/visual/animation/runtime/utils.py +43 -0
  150. bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
  151. bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
  152. bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
  153. bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,102 @@
1
+ from kirin.emit import EmitStrFrame
2
+ from kirin.interp import MethodTable, impl
3
+
4
+ from bloqade.stim.emit.stim import EmitStimMain
5
+
6
+ from . import stmts
7
+ from ._dialect import dialect
8
+
9
+
10
@dialect.register(key="emit.stim")
class EmitStimAuxMethods(MethodTable):
    """Stim text emission for the ``stim.aux`` dialect.

    Value-producing statements (constants, negation, record references and
    Pauli strings) return their Stim source fragment as a 1-tuple so that
    later statements can splice it into an instruction. Annotation
    statements (``TICK``, ``DETECTOR``, ``OBSERVABLE_INCLUDE``) write a full
    line through ``emit.writeln`` and produce no values.
    """

    @impl(stmts.ConstInt)
    def const_int(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstInt
    ):
        """Emit an integer constant in plain decimal form."""
        out: str = f"{stmt.value}"
        return (out,)

    @impl(stmts.ConstFloat)
    def const_float(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstFloat
    ):
        """Emit a float constant with exactly 8 digits after the decimal point."""
        out: str = f"{stmt.value:.8f}"
        return (out,)

    @impl(stmts.ConstBool)
    def const_bool(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstBool
    ):
        """Emit a boolean as a target-inversion prefix.

        ``True`` becomes ``"!"`` and ``False`` the empty string, so the result
        can be prefixed directly onto a Pauli factor (see ``new_paulistr``).
        """
        out: str = "!" if stmt.value else ""
        return (out,)

    @impl(stmts.ConstStr)
    def const_str(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstStr
    ):
        """Emit a string constant verbatim (e.g. a Pauli letter)."""
        return (stmt.value,)

    @impl(stmts.Neg)
    def neg(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Neg):
        """Prefix the operand's emitted text with a minus sign."""
        operand: str = frame.get(stmt.operand)
        return ("-" + operand,)

    @impl(stmts.GetRecord)
    def get_rec(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.GetRecord):
        """Emit a measurement-record reference as ``rec[<index>]``."""
        # Named record_index rather than ``id`` to avoid shadowing the builtin.
        record_index: str = frame.get(stmt.id)
        out: str = f"rec[{record_index}]"
        return (out,)

    @impl(stmts.Tick)
    def tick(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Tick):
        """Write a ``TICK`` line."""
        emit.writeln(frame, "TICK")
        return ()

    @impl(stmts.Detector)
    def detector(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Detector):
        """Write ``DETECTOR(<coords, comma separated>) <targets...>``."""
        coords: tuple[str, ...] = frame.get_values(stmt.coord)
        targets: tuple[str, ...] = frame.get_values(stmt.targets)

        coord_str: str = ", ".join(coords)
        target_str: str = " ".join(targets)
        emit.writeln(frame, f"DETECTOR({coord_str}) {target_str}")
        return ()

    @impl(stmts.ObservableInclude)
    def obs_include(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ObservableInclude
    ):
        """Write ``OBSERVABLE_INCLUDE(<idx>) <targets...>``."""
        idx: str = frame.get(stmt.idx)
        targets: tuple[str, ...] = frame.get_values(stmt.targets)

        target_str: str = " ".join(targets)
        emit.writeln(frame, f"OBSERVABLE_INCLUDE({idx}) {target_str}")
        return ()

    @impl(stmts.NewPauliString)
    def new_paulistr(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.NewPauliString
    ):
        """Build a Pauli product such as ``!X0*Z1`` from parallel argument lists.

        Each factor is ``<flip prefix><Pauli letter><target>``; the factors are
        joined with ``*``.
        """
        string: tuple[str, ...] = frame.get_values(stmt.string)
        flipped: tuple[str, ...] = frame.get_values(stmt.flipped)
        targets: tuple[str, ...] = frame.get_values(stmt.targets)

        # zip stops at the shortest list; the three lists are expected to be
        # the same length.
        out = "*".join(
            f"{flp}{base}{tgt}" for flp, base, tgt in zip(flipped, string, targets)
        )
        return (out,)
@@ -0,0 +1,39 @@
1
+ from kirin import interp
2
+
3
+ from . import stmts
4
+ from .types import RecordResult
5
+ from ._dialect import dialect
6
+
7
+
8
@dialect.register
class StimAuxMethods(interp.MethodTable):
    """Concrete-interpreter implementations for the ``stim.aux`` dialect."""

    @interp.impl(stmts.ConstFloat)
    @interp.impl(stmts.ConstInt)
    @interp.impl(stmts.ConstBool)
    @interp.impl(stmts.ConstStr)
    def const(
        self,
        interpreter: interp.Interpreter,
        frame: interp.Frame,
        stmt: stmts.ConstFloat | stmts.ConstInt | stmts.ConstBool | stmts.ConstStr,
    ):
        """Every constant statement evaluates to its stored ``value``."""
        constant = stmt.value
        return (constant,)

    @interp.impl(stmts.Neg)
    def neg(
        self,
        interpreter: interp.Interpreter,
        frame: interp.Frame,
        stmt: stmts.Neg,
    ):
        """Arithmetic negation of the operand's runtime value."""
        operand_value = frame.get(stmt.operand)
        negated = -operand_value
        return (negated,)

    @interp.impl(stmts.GetRecord)
    def get_rec(
        self,
        interpreter: interp.Interpreter,
        frame: interp.Frame,
        stmt: stmts.GetRecord,
    ):
        """Wrap the requested record index in a ``RecordResult``."""
        record_index = frame.get(stmt.id)
        return (RecordResult(value=record_index),)
@@ -0,0 +1,40 @@
1
+ import ast
2
+
3
+ from kirin import lowering
4
+
5
+ from . import stmts
6
+ from ._dialect import dialect
7
+
8
+
9
@dialect.register
class StimAuxLowering(lowering.FromPythonAST):
    """Lower Python constants and unary minus into ``stim.aux`` statements."""

    def _const_stmt(
        self, state: lowering.State, value: int | float | str | bool
    ) -> stmts.ConstInt | stmts.ConstFloat | stmts.ConstStr | stmts.ConstBool:
        """Pick the constant statement matching the Python type of ``value``.

        Raises:
            lowering.BuildError: if ``value`` is not a bool, int, float or str.
        """
        # Order matters: bool is a subclass of int, so it must be tried first.
        candidates = (
            (bool, stmts.ConstBool),
            (int, stmts.ConstInt),
            (float, stmts.ConstFloat),
            (str, stmts.ConstStr),
        )
        for py_type, const_cls in candidates:
            if isinstance(value, py_type):
                return const_cls(value=value)
        raise lowering.BuildError(f"unsupported Stim constant type {type(value)}")

    def lower_Constant(self, state: lowering.State, node: ast.Constant):
        """Push the constant statement for a literal."""
        return state.current_frame.push(self._const_stmt(state, node.value))

    def lower_Expr(self, state: lowering.State, node: ast.Expr):
        """Expression statements simply lower their wrapped expression."""
        return state.parent.visit(state, node.value)

    def lower_UnaryOp(self, state: lowering.State, node: ast.UnaryOp):
        """Lower unary minus to ``Neg``; any other unary operator is rejected."""
        if not isinstance(node.op, ast.USub):
            raise lowering.BuildError(f"unsupported Stim unaryop {node.op}")
        operand = state.lower(node.operand).expect_one()
        return state.current_frame.push(stmts.Neg(operand=operand))
@@ -0,0 +1,14 @@
1
+ from .const import (
2
+ Neg as Neg,
3
+ ConstInt as ConstInt,
4
+ ConstStr as ConstStr,
5
+ ConstBool as ConstBool,
6
+ ConstFloat as ConstFloat,
7
+ )
8
+ from .annotate import (
9
+ Tick as Tick,
10
+ Detector as Detector,
11
+ GetRecord as GetRecord,
12
+ NewPauliString as NewPauliString,
13
+ ObservableInclude as ObservableInclude,
14
+ )
@@ -0,0 +1,47 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+
4
+ from ..types import RecordType, PauliStringType
5
+ from .._dialect import dialect
6
+
7
# Numeric argument type accepted for detector coordinates: an int or a float.
PyNum = types.Union(types.Int, types.Float)


@statement(dialect=dialect)
class GetRecord(ir.Statement):
    """Reference an entry of the measurement record (emitted as ``rec[...]``)."""

    name = "get_rec"
    traits = frozenset([lowering.FromPythonCall()])
    # Record index; in Stim, negative indices count back from the latest result.
    id: ir.SSAValue = info.argument(types.Int)
    result: ir.ResultValue = info.result(RecordType)


@statement(dialect=dialect)
class Detector(ir.Statement):
    """Declare a ``DETECTOR`` with coordinates over measurement records."""

    name = "detector"
    traits = frozenset([lowering.FromPythonCall()])
    # Coordinates (ints or floats) emitted inside the parentheses.
    coord: tuple[ir.SSAValue, ...] = info.argument(type=PyNum)
    # Measurement-record references the detector is defined over.
    targets: tuple[ir.SSAValue, ...] = info.argument(type=RecordType)


@statement(dialect=dialect)
class ObservableInclude(ir.Statement):
    """Add measurement records to a logical observable (``OBSERVABLE_INCLUDE``)."""

    name = "obs.include"
    traits = frozenset([lowering.FromPythonCall()])
    # Index of the observable being extended.
    idx: ir.SSAValue = info.argument(types.Int)
    targets: tuple[ir.SSAValue, ...] = info.argument(type=RecordType)


@statement(dialect=dialect)
class Tick(ir.Statement):
    """Time-step separator (``TICK``); takes no arguments."""

    name = "tick"
    traits = frozenset([lowering.FromPythonCall()])


@statement(dialect=dialect)
class NewPauliString(ir.Statement):
    """Build a Pauli product from parallel letter / flip / target lists."""

    name = "new_pauli_string"
    traits = frozenset([lowering.FromPythonCall()])
    # Pauli letter of each factor.
    string: tuple[ir.SSAValue, ...] = info.argument(type=types.String)
    # Whether each factor is inverted.
    flipped: tuple[ir.SSAValue, ...] = info.argument(type=types.Bool)
    # Qubit index of each factor.
    targets: tuple[ir.SSAValue, ...] = info.argument(type=types.Int)
    result: ir.ResultValue = info.result(PauliStringType)
@@ -0,0 +1,95 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+ from kirin.print import Printer
4
+
5
+ from .._dialect import dialect as dialect
6
+
7
+
8
def _print_constant(
    stmt: "ConstInt | ConstFloat | ConstBool | ConstStr", printer: Printer
) -> None:
    """Pretty-print a constant statement as ``<name> <repr(value)> : <type>``.

    Shared by the ``print_impl`` methods of the constant statements below.
    The trailing `` : <type>`` part is rendered in the ``comment`` style.
    """
    printer.print_name(stmt)
    printer.plain_print(" ")
    printer.plain_print(repr(stmt.value))
    with printer.rich(style="comment"):
        printer.plain_print(" : ")
        printer.print(stmt.result.type)


@statement(dialect=dialect)
class ConstInt(ir.Statement):
    """IR Statement representing a constant integer value."""

    name = "constant.int"
    traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()})
    value: int = info.attribute(types.Int)
    """value (int): The constant integer value."""
    result: ir.ResultValue = info.result(types.Int)
    """result (Int): The result value."""

    def print_impl(self, printer: Printer) -> None:
        _print_constant(self, printer)


@statement(dialect=dialect)
class ConstFloat(ir.Statement):
    """IR Statement representing a constant float value."""

    name = "constant.float"
    traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()})
    value: float = info.attribute(types.Float)
    """value (float): The constant float value."""
    result: ir.ResultValue = info.result(types.Float)
    """result (Float): The result value."""

    def print_impl(self, printer: Printer) -> None:
        _print_constant(self, printer)


@statement(dialect=dialect)
class ConstBool(ir.Statement):
    """IR Statement representing a constant boolean value."""

    name = "constant.bool"
    traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()})
    value: bool = info.attribute(types.Bool)
    """value (bool): The constant boolean value."""
    result: ir.ResultValue = info.result(types.Bool)
    """result (Bool): The result value."""

    def print_impl(self, printer: Printer) -> None:
        _print_constant(self, printer)


@statement(dialect=dialect)
class ConstStr(ir.Statement):
    """IR Statement representing a constant str value."""

    name = "constant.str"
    traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()})
    value: str = info.attribute(types.String)
    """value (str): The constant str value."""
    result: ir.ResultValue = info.result(types.String)
    """result (str): The result value."""

    def print_impl(self, printer: Printer) -> None:
        _print_constant(self, printer)


@statement(dialect=dialect)
class Neg(ir.Statement):
    """IR Statement representing a negation operation."""

    name = "neg"
    traits = frozenset({lowering.FromPythonCall()})
    # NOTE(review): the Python lowering creates Neg for every unary minus,
    # including float literals (e.g. negative Detector coordinates, which
    # accept floats), yet operand/result are declared Int. Confirm whether
    # Float operands should be allowed here.
    operand: ir.SSAValue = info.argument(types.Int)
    result: ir.ResultValue = info.result(types.Int)
@@ -0,0 +1,19 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import types
4
+
5
+
6
@dataclass()
class RecordResult:
    """Runtime value of a measurement-record reference (``rec[value]``)."""

    value: int


@dataclass()
class PauliString:
    """A Pauli product stored as three parallel tuples.

    Presumably mirrors ``NewPauliString``'s arguments: one Pauli letter,
    one inversion flag and one qubit index per factor.
    """

    string: tuple[str, ...]
    flipped: tuple[bool, ...]
    targets: tuple[int, ...]


# Kirin type attributes wrapping the classes above, for use as statement
# argument/result types.
RecordType = types.PyClass(RecordResult)
PauliStringType = types.PyClass(PauliString)
@@ -0,0 +1,3 @@
1
+ from .emit import EmitStimCollapseMethods as EmitStimCollapseMethods
2
+ from .stmts import * # noqa F403
3
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# Dialect for Stim's collapsing operations (measurements and resets); the
# collapse statements and emit method table register against this object.
dialect = ir.Dialect("stim.collapse")
@@ -0,0 +1,68 @@
1
+ from kirin.emit import EmitStrFrame
2
+ from kirin.interp import MethodTable, impl
3
+
4
+ from bloqade.stim.emit.stim import EmitStimMain
5
+
6
+ from . import stmts
7
+ from ._dialect import dialect
8
+ from .stmts.reset import Reset
9
+ from .stmts.measure import Measurement
10
+
11
+
12
@dialect.register(key="emit.stim")
class EmitStimCollapseMethods(MethodTable):
    """Stim text emission for measurement and reset statements."""

    # Statement name -> Stim mnemonic for (pair) measurements.
    meas_map: dict[str, str] = {
        stmts.MX.name: "MX",
        stmts.MY.name: "MY",
        stmts.MZ.name: "MZ",
        stmts.MXX.name: "MXX",
        stmts.MYY.name: "MYY",
        stmts.MZZ.name: "MZZ",
    }

    # Statement name -> Stim mnemonic for resets.
    reset_map: dict[str, str] = {
        stmts.RX.name: "RX",
        stmts.RY.name: "RY",
        stmts.RZ.name: "RZ",
    }

    @impl(stmts.MX)
    @impl(stmts.MY)
    @impl(stmts.MZ)
    @impl(stmts.MXX)
    @impl(stmts.MYY)
    @impl(stmts.MZZ)
    def get_measure(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Measurement):
        """Write ``<MNEMONIC>(<p>) <targets...>``."""
        noise_arg: str = frame.get(stmt.p)
        qubits: tuple[str, ...] = frame.get_values(stmt.targets)
        mnemonic = self.meas_map[stmt.name]
        emit.writeln(frame, f"{mnemonic}({noise_arg}) {' '.join(qubits)}")
        return ()

    @impl(stmts.RX)
    @impl(stmts.RY)
    @impl(stmts.RZ)
    def get_reset(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Reset):
        """Write ``<MNEMONIC> <targets...>``."""
        qubits: tuple[str, ...] = frame.get_values(stmt.targets)
        mnemonic = self.reset_map[stmt.name]
        emit.writeln(frame, f"{mnemonic} {' '.join(qubits)}")
        return ()

    @impl(stmts.PPMeasurement)
    def pp_measure(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PPMeasurement
    ):
        """Write ``MPP(<p>) <pauli products...>``."""
        noise_arg: str = frame.get(stmt.p)
        products: tuple[str, ...] = frame.get_values(stmt.targets)
        emit.writeln(frame, f"MPP({noise_arg}) {' '.join(products)}")
        return ()
@@ -0,0 +1,3 @@
1
+ from .reset import RX as RX, RY as RY, RZ as RZ
2
+ from .measure import MX as MX, MY as MY, MZ as MZ, MXX as MXX, MYY as MYY, MZZ as MZZ
3
+ from .pp_measure import PPMeasurement as PPMeasurement
@@ -0,0 +1,45 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+
4
+ from .._dialect import dialect
5
+
6
+
7
@statement
class Measurement(ir.Statement):
    """Common base of the measurement statements (not bound to a dialect).

    The concrete subclasses below register with the collapse dialect.
    """

    name = "measurement"
    traits = frozenset([lowering.FromPythonCall()])
    p: ir.SSAValue = info.argument(type=types.Float)
    """Measurement-noise probability: the chance the reported result is flipped (e.g. 0.01 means 1%)."""
    targets: tuple[ir.SSAValue, ...] = info.argument(type=types.Int)


# Single-qubit measurements.
@statement(dialect=dialect)
class MZ(Measurement):
    """Z-basis measurement (``MZ``)."""

    name = "MZ"


@statement(dialect=dialect)
class MY(Measurement):
    """Y-basis measurement (``MY``)."""

    name = "MY"


@statement(dialect=dialect)
class MX(Measurement):
    """X-basis measurement (``MX``)."""

    name = "MX"


# Pair measurements.
@statement(dialect=dialect)
class MZZ(Measurement):
    """Pair measurement of the ZZ parity (``MZZ``)."""

    name = "MZZ"


@statement(dialect=dialect)
class MYY(Measurement):
    """Pair measurement of the YY parity (``MYY``)."""

    name = "MYY"


@statement(dialect=dialect)
class MXX(Measurement):
    """Pair measurement of the XX parity (``MXX``)."""

    name = "MXX"
@@ -0,0 +1,14 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+
4
+ from .._dialect import dialect
5
+ from ...aux.types import PauliStringType
6
+
7
+
8
@statement(dialect=dialect)
class PPMeasurement(ir.Statement):
    """Pauli-product measurement (``MPP``) over Pauli-string targets."""

    name = "MPP"
    traits = frozenset([lowering.FromPythonCall()])
    p: ir.SSAValue = info.argument(type=types.Float)
    """Measurement-noise probability: the chance the reported result is flipped (e.g. 0.01 means 1%)."""
    targets: tuple[ir.SSAValue, ...] = info.argument(type=PauliStringType)
@@ -0,0 +1,26 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+
4
+ from .._dialect import dialect
5
+
6
+
7
@statement
class Reset(ir.Statement):
    """Common base of the reset statements (not bound to a dialect)."""

    name = "reset"
    traits = frozenset([lowering.FromPythonCall()])
    # Qubit indices to reset.
    targets: tuple[ir.SSAValue, ...] = info.argument(type=types.Int)


@statement(dialect=dialect)
class RZ(Reset):
    """Reset in the Z basis (``RZ``)."""

    name = "RZ"


@statement(dialect=dialect)
class RY(Reset):
    """Reset in the Y basis (``RY``)."""

    name = "RY"


@statement(dialect=dialect)
class RX(Reset):
    """Reset in the X basis (``RX``)."""

    name = "RX"
@@ -0,0 +1,3 @@
1
+ from .emit import EmitStimGateMethods as EmitStimGateMethods
2
+ from .stmts import * # noqa F403
3
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# Dialect for Stim gate statements; the gate emit method table registers
# against this object.
dialect = ir.Dialect("stim.gate")
@@ -0,0 +1,87 @@
1
+ from kirin.emit import EmitStrFrame
2
+ from kirin.interp import MethodTable, impl
3
+
4
+ from bloqade.stim.emit.stim import EmitStimMain
5
+
6
+ from . import stmts
7
+ from ._dialect import dialect
8
+ from .stmts.base import SingleQubitGate, ControlledTwoQubitGate
9
+
10
+
11
@dialect.register(key="emit.stim")
class EmitStimGateMethods(MethodTable):
    """Stim text emission for the ``stim.gate`` dialect.

    Each gate statement becomes one instruction line. Mnemonic maps hold a
    pair ``(normal, adjoint)``, indexed by ``int(stmt.dagger)``.
    """

    # NOTE(review): ``stmts.Identity`` is exported by the gate statements but
    # has no emit implementation here; confirm whether it should map to "I".
    gate_1q_map: dict[str, tuple[str, str]] = {
        stmts.X.name: ("X", "X"),
        stmts.Y.name: ("Y", "Y"),
        stmts.Z.name: ("Z", "Z"),
        stmts.H.name: ("H", "H"),
        stmts.S.name: ("S", "S_DAG"),
        stmts.SqrtX.name: ("SQRT_X", "SQRT_X_DAG"),
        stmts.SqrtY.name: ("SQRT_Y", "SQRT_Y_DAG"),
        stmts.SqrtZ.name: ("SQRT_Z", "SQRT_Z_DAG"),
    }

    @impl(stmts.X)
    @impl(stmts.Y)
    @impl(stmts.Z)
    @impl(stmts.S)
    @impl(stmts.H)
    @impl(stmts.SqrtX)
    @impl(stmts.SqrtY)
    @impl(stmts.SqrtZ)
    def single_qubit_gate(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: SingleQubitGate
    ):
        """Write ``<MNEMONIC> <targets...>`` for a single-qubit Clifford."""
        targets: tuple[str, ...] = frame.get_values(stmt.targets)
        res = f"{self.gate_1q_map[stmt.name][int(stmt.dagger)]} " + " ".join(targets)
        emit.writeln(frame, res)
        return ()

    # Non-controlled two-qubit gates.
    gate_2q_map: dict[str, tuple[str, str]] = {
        stmts.Swap.name: ("SWAP", "SWAP"),
    }

    @impl(stmts.Swap)
    def two_qubit_gate(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Swap):
        """Write ``<MNEMONIC> <targets...>`` for a non-controlled two-qubit gate."""
        targets: tuple[str, ...] = frame.get_values(stmt.targets)
        # Non-controlled gates are looked up in their own map, not the
        # controlled-gate map.
        res = f"{self.gate_2q_map[stmt.name][int(stmt.dagger)]} " + " ".join(targets)
        emit.writeln(frame, res)
        return ()

    # Controlled two-qubit gates. The SWAP entry is not used by the emitter
    # (SWAP goes through gate_2q_map) and is kept for compatibility.
    gate_ctrl_2q_map: dict[str, tuple[str, str]] = {
        stmts.CX.name: ("CX", "CX"),
        stmts.CY.name: ("CY", "CY"),
        stmts.CZ.name: ("CZ", "CZ"),
        stmts.Swap.name: ("SWAP", "SWAP"),
    }

    @impl(stmts.CX)
    @impl(stmts.CY)
    @impl(stmts.CZ)
    def ctrl_two_qubit_gate(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: ControlledTwoQubitGate
    ):
        """Write ``<MNEMONIC> c0 t0 c1 t1 ...``, interleaving controls and targets."""
        controls: tuple[str, ...] = frame.get_values(stmt.controls)
        targets: tuple[str, ...] = frame.get_values(stmt.targets)
        # zip stops at the shorter list; controls and targets should pair up.
        res = f"{self.gate_ctrl_2q_map[stmt.name][int(stmt.dagger)]} " + " ".join(
            f"{ctrl} {tgt}" for ctrl, tgt in zip(controls, targets)
        )
        emit.writeln(frame, res)
        return ()

    @impl(stmts.SPP)
    def spp(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.SPP):
        """Write ``SPP <pauli products...>``."""
        targets: tuple[str, ...] = frame.get_values(stmt.targets)
        res = "SPP " + " ".join(targets)
        emit.writeln(frame, res)
        return ()
@@ -0,0 +1,14 @@
1
+ from .pp import SPP as SPP
2
+ from .control_2q import CX as CX, CY as CY, CZ as CZ
3
+ from .clifford_1q import (
4
+ H as H,
5
+ S as S,
6
+ X as X,
7
+ Y as Y,
8
+ Z as Z,
9
+ SqrtX as SqrtX,
10
+ SqrtY as SqrtY,
11
+ SqrtZ as SqrtZ,
12
+ Identity as Identity,
13
+ )
14
+ from .clifford_2q import Swap as Swap
@@ -0,0 +1,31 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+
4
+ from bloqade.stim.dialects.aux import RecordType
5
+
6
+
7
@statement
class Gate(ir.Statement):
    """Common base of the Stim gate statements (not bound to a dialect)."""

    name = "stim_gate"
    traits = frozenset([lowering.FromPythonCall()])
    # Qubit indices the gate acts on.
    targets: tuple[ir.SSAValue, ...] = info.argument(type=types.Int)
    # When True, the emitter selects the gate's adjoint mnemonic.
    dagger: bool = info.attribute(default=False)


@statement
class SingleQubitGate(Gate):
    """Base for gates applied independently to each target qubit."""

    name = "single_qubit_gate"


@statement
class TwoQubitGate(Gate):
    """Base for non-controlled two-qubit gates."""

    name = "two_qubit_gate"


@statement
class ControlledTwoQubitGate(Gate):
    """Base for controlled two-qubit gates.

    Each control may be either a qubit index or a measurement record.
    """

    name = "controlled_two_qubit_gate"
    controls: tuple[ir.SSAValue, ...] = info.argument(
        type=types.Union(types.Int, RecordType)
    )
+ )