bloqade-circuit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (153)
  1. bloqade/analysis/__init__.py +0 -0
  2. bloqade/analysis/address/__init__.py +11 -0
  3. bloqade/analysis/address/analysis.py +60 -0
  4. bloqade/analysis/address/impls.py +228 -0
  5. bloqade/analysis/address/lattice.py +85 -0
  6. bloqade/noise/__init__.py +1 -0
  7. bloqade/noise/native/__init__.py +20 -0
  8. bloqade/noise/native/_dialect.py +3 -0
  9. bloqade/noise/native/_wrappers.py +34 -0
  10. bloqade/noise/native/model.py +347 -0
  11. bloqade/noise/native/rewrite.py +35 -0
  12. bloqade/noise/native/stmts.py +46 -0
  13. bloqade/pyqrack/__init__.py +18 -0
  14. bloqade/pyqrack/base.py +131 -0
  15. bloqade/pyqrack/noise/__init__.py +0 -0
  16. bloqade/pyqrack/noise/native.py +100 -0
  17. bloqade/pyqrack/qasm2/__init__.py +0 -0
  18. bloqade/pyqrack/qasm2/core.py +79 -0
  19. bloqade/pyqrack/qasm2/parallel.py +46 -0
  20. bloqade/pyqrack/qasm2/uop.py +247 -0
  21. bloqade/pyqrack/reg.py +109 -0
  22. bloqade/pyqrack/target.py +112 -0
  23. bloqade/qasm2/__init__.py +19 -0
  24. bloqade/qasm2/_wrappers.py +674 -0
  25. bloqade/qasm2/dialects/__init__.py +10 -0
  26. bloqade/qasm2/dialects/core/__init__.py +3 -0
  27. bloqade/qasm2/dialects/core/_dialect.py +3 -0
  28. bloqade/qasm2/dialects/core/_emit.py +68 -0
  29. bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
  30. bloqade/qasm2/dialects/core/address.py +38 -0
  31. bloqade/qasm2/dialects/core/stmts.py +94 -0
  32. bloqade/qasm2/dialects/expr/__init__.py +3 -0
  33. bloqade/qasm2/dialects/expr/_dialect.py +3 -0
  34. bloqade/qasm2/dialects/expr/_emit.py +103 -0
  35. bloqade/qasm2/dialects/expr/_from_python.py +86 -0
  36. bloqade/qasm2/dialects/expr/_interp.py +75 -0
  37. bloqade/qasm2/dialects/expr/stmts.py +262 -0
  38. bloqade/qasm2/dialects/glob.py +45 -0
  39. bloqade/qasm2/dialects/indexing.py +64 -0
  40. bloqade/qasm2/dialects/inline.py +76 -0
  41. bloqade/qasm2/dialects/noise.py +16 -0
  42. bloqade/qasm2/dialects/parallel.py +110 -0
  43. bloqade/qasm2/dialects/uop/__init__.py +4 -0
  44. bloqade/qasm2/dialects/uop/_dialect.py +3 -0
  45. bloqade/qasm2/dialects/uop/_emit.py +211 -0
  46. bloqade/qasm2/dialects/uop/schedule.py +89 -0
  47. bloqade/qasm2/dialects/uop/stmts.py +325 -0
  48. bloqade/qasm2/emit/__init__.py +1 -0
  49. bloqade/qasm2/emit/base.py +72 -0
  50. bloqade/qasm2/emit/gate.py +102 -0
  51. bloqade/qasm2/emit/main.py +106 -0
  52. bloqade/qasm2/emit/target.py +165 -0
  53. bloqade/qasm2/glob.py +24 -0
  54. bloqade/qasm2/groups.py +120 -0
  55. bloqade/qasm2/parallel.py +48 -0
  56. bloqade/qasm2/parse/__init__.py +37 -0
  57. bloqade/qasm2/parse/ast.py +235 -0
  58. bloqade/qasm2/parse/build.py +289 -0
  59. bloqade/qasm2/parse/lowering.py +553 -0
  60. bloqade/qasm2/parse/parser.py +5 -0
  61. bloqade/qasm2/parse/print.py +293 -0
  62. bloqade/qasm2/parse/qasm2.lark +75 -0
  63. bloqade/qasm2/parse/visitor.py +16 -0
  64. bloqade/qasm2/parse/visitor.pyi +39 -0
  65. bloqade/qasm2/passes/__init__.py +5 -0
  66. bloqade/qasm2/passes/fold.py +94 -0
  67. bloqade/qasm2/passes/glob.py +119 -0
  68. bloqade/qasm2/passes/noise.py +61 -0
  69. bloqade/qasm2/passes/parallel.py +176 -0
  70. bloqade/qasm2/passes/py2qasm.py +63 -0
  71. bloqade/qasm2/passes/qasm2py.py +61 -0
  72. bloqade/qasm2/rewrite/__init__.py +12 -0
  73. bloqade/qasm2/rewrite/desugar.py +28 -0
  74. bloqade/qasm2/rewrite/glob.py +103 -0
  75. bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
  76. bloqade/qasm2/rewrite/native_gates.py +447 -0
  77. bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
  78. bloqade/qasm2/rewrite/register.py +45 -0
  79. bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
  80. bloqade/qasm2/types.py +39 -0
  81. bloqade/qbraid/__init__.py +2 -0
  82. bloqade/qbraid/lowering.py +324 -0
  83. bloqade/qbraid/schema.py +252 -0
  84. bloqade/qbraid/simulation_result.py +99 -0
  85. bloqade/qbraid/target.py +86 -0
  86. bloqade/squin/__init__.py +2 -0
  87. bloqade/squin/analysis/__init__.py +0 -0
  88. bloqade/squin/analysis/nsites/__init__.py +8 -0
  89. bloqade/squin/analysis/nsites/analysis.py +52 -0
  90. bloqade/squin/analysis/nsites/impls.py +69 -0
  91. bloqade/squin/analysis/nsites/lattice.py +49 -0
  92. bloqade/squin/analysis/schedule.py +244 -0
  93. bloqade/squin/groups.py +38 -0
  94. bloqade/squin/op/__init__.py +132 -0
  95. bloqade/squin/op/_dialect.py +3 -0
  96. bloqade/squin/op/complex.py +6 -0
  97. bloqade/squin/op/stmts.py +220 -0
  98. bloqade/squin/op/traits.py +43 -0
  99. bloqade/squin/op/types.py +10 -0
  100. bloqade/squin/qubit.py +118 -0
  101. bloqade/squin/wire.py +103 -0
  102. bloqade/stim/__init__.py +6 -0
  103. bloqade/stim/_wrappers.py +186 -0
  104. bloqade/stim/dialects/__init__.py +5 -0
  105. bloqade/stim/dialects/aux/__init__.py +11 -0
  106. bloqade/stim/dialects/aux/_dialect.py +3 -0
  107. bloqade/stim/dialects/aux/emit.py +102 -0
  108. bloqade/stim/dialects/aux/interp.py +39 -0
  109. bloqade/stim/dialects/aux/lowering.py +40 -0
  110. bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
  111. bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
  112. bloqade/stim/dialects/aux/stmts/const.py +95 -0
  113. bloqade/stim/dialects/aux/types.py +19 -0
  114. bloqade/stim/dialects/collapse/__init__.py +3 -0
  115. bloqade/stim/dialects/collapse/_dialect.py +3 -0
  116. bloqade/stim/dialects/collapse/emit.py +68 -0
  117. bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
  118. bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
  119. bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
  120. bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
  121. bloqade/stim/dialects/gate/__init__.py +3 -0
  122. bloqade/stim/dialects/gate/_dialect.py +3 -0
  123. bloqade/stim/dialects/gate/emit.py +87 -0
  124. bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
  125. bloqade/stim/dialects/gate/stmts/base.py +31 -0
  126. bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
  127. bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
  128. bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
  129. bloqade/stim/dialects/gate/stmts/pp.py +15 -0
  130. bloqade/stim/dialects/noise/__init__.py +3 -0
  131. bloqade/stim/dialects/noise/_dialect.py +3 -0
  132. bloqade/stim/dialects/noise/emit.py +66 -0
  133. bloqade/stim/dialects/noise/stmts.py +77 -0
  134. bloqade/stim/emit/__init__.py +1 -0
  135. bloqade/stim/emit/stim.py +54 -0
  136. bloqade/stim/groups.py +26 -0
  137. bloqade/test_utils.py +35 -0
  138. bloqade/types.py +24 -0
  139. bloqade/visual/__init__.py +1 -0
  140. bloqade/visual/animation/__init__.py +0 -0
  141. bloqade/visual/animation/animate.py +267 -0
  142. bloqade/visual/animation/base.py +346 -0
  143. bloqade/visual/animation/gate_event.py +24 -0
  144. bloqade/visual/animation/runtime/__init__.py +0 -0
  145. bloqade/visual/animation/runtime/aod.py +36 -0
  146. bloqade/visual/animation/runtime/atoms.py +55 -0
  147. bloqade/visual/animation/runtime/ppoly.py +50 -0
  148. bloqade/visual/animation/runtime/qpustate.py +119 -0
  149. bloqade/visual/animation/runtime/utils.py +43 -0
  150. bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
  151. bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
  152. bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
  153. bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,53 @@
1
+ from kirin.decl import statement
2
+
3
+ from .base import SingleQubitGate
4
+ from .._dialect import dialect
5
+
6
+
7
+ # Pauli Gates
8
+ # -----------------------------------
9
@statement(dialect=dialect)
class Identity(SingleQubitGate):
    """Single-qubit identity gate (instruction name ``I``)."""

    name = "I"
12
+
13
+
14
@statement(dialect=dialect)
class X(SingleQubitGate):
    """Single-qubit Pauli X gate (instruction name ``X``)."""

    name = "X"
17
+
18
+
19
@statement(dialect=dialect)
class Y(SingleQubitGate):
    """Single-qubit Pauli Y gate (instruction name ``Y``)."""

    name = "Y"
22
+
23
+
24
@statement(dialect=dialect)
class Z(SingleQubitGate):
    """Single-qubit Pauli Z gate (instruction name ``Z``)."""

    name = "Z"
27
+
28
+
29
+ # Single Qubit Clifford Gates
30
+ # ---------------------------------------
31
@statement(dialect=dialect)
class H(SingleQubitGate):
    """Single-qubit Hadamard gate (instruction name ``H``)."""

    name = "H"
34
+
35
+
36
@statement(dialect=dialect)
class S(SingleQubitGate):
    """Single-qubit S (phase) gate (instruction name ``S``)."""

    name = "S"
39
+
40
+
41
@statement(dialect=dialect)
class SqrtX(SingleQubitGate):
    """Single-qubit square-root-of-X gate (instruction name ``SQRT_X``)."""

    name = "SQRT_X"
44
+
45
+
46
@statement(dialect=dialect)
class SqrtY(SingleQubitGate):
    """Single-qubit square-root-of-Y gate (instruction name ``SQRT_Y``)."""

    name = "SQRT_Y"
49
+
50
+
51
@statement(dialect=dialect)
class SqrtZ(SingleQubitGate):
    """Single-qubit square-root-of-Z gate (instruction name ``SQRT_Z``)."""

    name = "SQRT_Z"
@@ -0,0 +1,11 @@
1
+ from kirin.decl import statement
2
+
3
+ from .base import TwoQubitGate
4
+ from .._dialect import dialect
5
+
6
+
7
+ # Two Qubit Clifford Gates
8
+ # ---------------------------------------
9
@statement(dialect=dialect)
class Swap(TwoQubitGate):
    """Two-qubit SWAP gate (instruction name ``SWAP``)."""

    name = "SWAP"
@@ -0,0 +1,21 @@
1
+ from kirin.decl import statement
2
+
3
+ from .base import ControlledTwoQubitGate
4
+ from .._dialect import dialect
5
+
6
+
7
+ # Two Qubit Clifford Gates
8
+ # ---------------------------------------
9
@statement(dialect=dialect)
class CX(ControlledTwoQubitGate):
    """Controlled-X (CNOT) gate (instruction name ``CX``)."""

    name = "CX"
12
+
13
+
14
@statement(dialect=dialect)
class CY(ControlledTwoQubitGate):
    """Controlled-Y gate (instruction name ``CY``)."""

    name = "CY"
17
+
18
+
19
@statement(dialect=dialect)
class CZ(ControlledTwoQubitGate):
    """Controlled-Z gate (instruction name ``CZ``)."""

    name = "CZ"
@@ -0,0 +1,15 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+
4
+ from .._dialect import dialect
5
+ from ...aux.types import PauliStringType
6
+
7
+
8
+ # Generalized Pauli-product gates
9
+ # ---------------------------------------
10
@statement(dialect=dialect)
class SPP(ir.Statement):
    """Generalized Pauli-product gate (instruction name ``SPP``).

    Attributes:
        dagger: Compile-time flag selecting the daggered variant; ``False``
            by default.
        targets: Variadic SSA arguments typed as Pauli strings.
    """

    name = "SPP"
    traits = frozenset({lowering.FromPythonCall()})
    dagger: bool = info.attribute(types.Bool, default=False)
    targets: tuple[ir.SSAValue, ...] = info.argument(PauliStringType)
@@ -0,0 +1,3 @@
1
+ from .emit import EmitStimNoiseMethods as EmitStimNoiseMethods
2
+ from .stmts import * # noqa F403
3
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# Kirin dialect named "stim.noise"; the noise-channel statements and their
# "emit.stim" method table in this package are registered against it.
dialect = ir.Dialect("stim.noise")
@@ -0,0 +1,66 @@
1
+ from kirin.emit import EmitStrFrame
2
+ from kirin.interp import MethodTable, impl
3
+
4
+ from bloqade.stim.emit.stim import EmitStimMain
5
+
6
+ from . import stmts
7
+ from ._dialect import dialect
8
+
9
+
10
@dialect.register(key="emit.stim")
class EmitStimNoiseMethods(MethodTable):
    """``emit.stim`` implementations for the Stim noise-channel statements.

    Every handler writes one instruction line of the form
    ``NAME(params) target0 target1 ...`` through ``emit.writeln`` and returns
    an empty tuple, since the noise statements declare no results.
    """

    # Statement ``name`` -> Stim instruction for the channels that take a
    # single probability argument ``p``.
    single_p_error_map: dict[str, str] = {
        stmts.Depolarize1.name: "DEPOLARIZE1",
        stmts.Depolarize2.name: "DEPOLARIZE2",
        stmts.XError.name: "X_ERROR",
        stmts.YError.name: "Y_ERROR",
        stmts.ZError.name: "Z_ERROR",
    }

    @impl(stmts.XError)
    @impl(stmts.YError)
    @impl(stmts.ZError)
    @impl(stmts.Depolarize1)
    @impl(stmts.Depolarize2)
    def single_p_error(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Depolarize1
    ):
        """Emit ``NAME(p) targets...`` for any single-probability channel."""
        # ``stmt`` may be any of the five statement types registered above;
        # they all expose ``p`` and ``targets``.
        targets: tuple[str, ...] = frame.get_values(stmt.targets)
        p: str = frame.get(stmt.p)
        name = self.single_p_error_map[stmt.name]
        res = f"{name}({p}) " + " ".join(targets)
        emit.writeln(frame, res)

        return ()

    @impl(stmts.PauliChannel1)
    def pauli_channel1(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PauliChannel1
    ):
        """Emit ``PAULI_CHANNEL_1(px,py,pz) targets...``."""
        targets: tuple[str, ...] = frame.get_values(stmt.targets)
        px: str = frame.get(stmt.px)
        py: str = frame.get(stmt.py)
        pz: str = frame.get(stmt.pz)
        res = f"PAULI_CHANNEL_1({px},{py},{pz}) " + " ".join(targets)
        emit.writeln(frame, res)

        return ()

    @impl(stmts.PauliChannel2)
    def pauli_channel2(
        self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PauliChannel2
    ):
        """Emit ``PAULI_CHANNEL_2(p_IX, ..., p_ZZ) targets...``."""
        targets: tuple[str, ...] = frame.get_values(stmt.targets)
        # Read the 15 probabilities by field name, in Stim's order
        # (IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ).
        # Slicing ``stmt.args[:15]`` would depend on the positional layout of
        # the statement's arguments.
        prob: tuple[str, ...] = frame.get_values(
            (
                stmt.pix,
                stmt.piy,
                stmt.piz,
                stmt.pxi,
                stmt.pxx,
                stmt.pxy,
                stmt.pxz,
                stmt.pyi,
                stmt.pyx,
                stmt.pyy,
                stmt.pyz,
                stmt.pzi,
                stmt.pzx,
                stmt.pzy,
                stmt.pzz,
            )
        )
        prob_str: str = ", ".join(prob)

        res = f"PAULI_CHANNEL_2({prob_str}) " + " ".join(targets)
        emit.writeln(frame, res)

        return ()
@@ -0,0 +1,77 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+
4
+ from ._dialect import dialect
5
+
6
+
7
@statement(dialect=dialect)
class Depolarize1(ir.Statement):
    """Single-qubit depolarizing channel with probability argument ``p``.

    ``targets`` is a variadic list of integer target indices. The
    ``emit.stim`` table renders this as ``DEPOLARIZE1(p) targets...``.
    """

    name = "Depolarize1"
    traits = frozenset({lowering.FromPythonCall()})
    p: ir.SSAValue = info.argument(types.Float)
    targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
13
+
14
+
15
@statement(dialect=dialect)
class Depolarize2(ir.Statement):
    """Two-qubit depolarizing channel with probability argument ``p``.

    ``targets`` is a variadic list of integer target indices. The
    ``emit.stim`` table renders this as ``DEPOLARIZE2(p) targets...``.
    """

    name = "Depolarize2"
    traits = frozenset({lowering.FromPythonCall()})
    p: ir.SSAValue = info.argument(types.Float)
    targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
21
+
22
+
23
@statement(dialect=dialect)
class PauliChannel1(ir.Statement):
    """Single-qubit Pauli channel with X, Y and Z error probabilities.

    The arguments are ``px``, ``py`` and ``pz`` (floats), followed by a variadic
    list of integer ``targets``.
    """

    name = "PauliChannel1"
    traits = frozenset({lowering.FromPythonCall()})
    px: ir.SSAValue = info.argument(types.Float)
    py: ir.SSAValue = info.argument(types.Float)
    pz: ir.SSAValue = info.argument(types.Float)
    targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
31
+
32
+
33
@statement(dialect=dialect)
class PauliChannel2(ir.Statement):
    """Two-qubit Pauli channel with one probability per non-identity Pauli pair.

    The 15 float arguments are declared in the order IX, IY, IZ, XI, XX, XY,
    XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ (``p<first><second>``, where ``i`` is
    identity). They are followed by a variadic list of integer ``targets``.
    """

    name = "PauliChannel2"
    # TODO: custom lowering to provide syntactic sugar for the 15 probabilities.
    traits = frozenset({lowering.FromPythonCall()})
    pix: ir.SSAValue = info.argument(types.Float)
    piy: ir.SSAValue = info.argument(types.Float)
    piz: ir.SSAValue = info.argument(types.Float)
    pxi: ir.SSAValue = info.argument(types.Float)
    pxx: ir.SSAValue = info.argument(types.Float)
    pxy: ir.SSAValue = info.argument(types.Float)
    pxz: ir.SSAValue = info.argument(types.Float)
    pyi: ir.SSAValue = info.argument(types.Float)
    pyx: ir.SSAValue = info.argument(types.Float)
    pyy: ir.SSAValue = info.argument(types.Float)
    pyz: ir.SSAValue = info.argument(types.Float)
    pzi: ir.SSAValue = info.argument(types.Float)
    pzx: ir.SSAValue = info.argument(types.Float)
    pzy: ir.SSAValue = info.argument(types.Float)
    pzz: ir.SSAValue = info.argument(types.Float)
    targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
54
+
55
+
56
@statement(dialect=dialect)
class XError(ir.Statement):
    """Pauli-X error channel with probability argument ``p``.

    ``targets`` is a variadic list of integer target indices.
    """

    name = "X_ERROR"
    traits = frozenset({lowering.FromPythonCall()})
    p: ir.SSAValue = info.argument(types.Float)
    targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
62
+
63
+
64
@statement(dialect=dialect)
class YError(ir.Statement):
    """Pauli-Y error channel with probability argument ``p``.

    ``targets`` is a variadic list of integer target indices.
    """

    name = "Y_ERROR"
    traits = frozenset({lowering.FromPythonCall()})
    p: ir.SSAValue = info.argument(types.Float)
    targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
70
+
71
+
72
@statement(dialect=dialect)
class ZError(ir.Statement):
    """Pauli-Z error channel with probability argument ``p``.

    ``targets`` is a variadic list of integer target indices.
    """

    name = "Z_ERROR"
    traits = frozenset({lowering.FromPythonCall()})
    p: ir.SSAValue = info.argument(types.Float)
    targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
@@ -0,0 +1 @@
1
+ from .stim import FuncEmit as FuncEmit, EmitStimMain as EmitStimMain
@@ -0,0 +1,54 @@
1
+ from io import StringIO
2
+ from typing import IO, TypeVar
3
+ from dataclasses import field, dataclass
4
+
5
+ from kirin import ir, interp
6
+ from kirin.emit import EmitStr, EmitStrFrame
7
+ from kirin.dialects import func
8
+
9
# Type variable bound to ``typing.IO``; it is not referenced elsewhere in this
# module.
IO_t = TypeVar("IO_t", bound=IO)
10
+
11
+
12
def _default_dialect_group() -> ir.DialectGroup:
    """Return the stim ``main`` dialect group used as the emitter's default.

    The import is deferred to call time because ``..groups`` imports the stim
    dialects, and the noise dialect's emit table imports this module. A
    top-level import would therefore be circular.
    """
    from ..groups import main as stim_main_group

    return stim_main_group
16
+
17
+
18
@dataclass
class EmitStimMain(EmitStr):
    """String emitter that renders a stim-dialect program as Stim text.

    Output accumulates in ``file``, an in-memory ``StringIO``, and is read back
    with ``get_output``.
    """

    keys = ["emit.stim"]
    dialects: ir.DialectGroup = field(default_factory=_default_dialect_group)
    file: StringIO = field(default_factory=StringIO)

    def initialize(self):
        """Reset the base emitter state and discard previously written text."""
        super().initialize()
        buffer = self.file
        buffer.truncate(0)
        buffer.seek(0)
        return self

    def eval_stmt_fallback(
        self, frame: EmitStrFrame, stmt: ir.Statement
    ) -> tuple[str, ...]:
        # Presumably reached when no "emit.stim" implementation handles
        # ``stmt``; the statement then evaluates to its own name.
        return (stmt.name,)

    def emit_block(self, frame: EmitStrFrame, block: ir.Block) -> str | None:
        """Evaluate the statements of ``block`` in order.

        Results are bound in ``frame`` only for statements whose evaluation
        returned a tuple. Always returns ``None``.
        """
        for statement in block.stmts:
            outcome = self.eval_stmt(frame, statement)
            if not isinstance(outcome, tuple):
                continue
            frame.set_values(statement.results, outcome)
        return None

    def get_output(self) -> str:
        """Return the full text written to ``file`` so far."""
        buffer = self.file
        buffer.seek(0)
        return buffer.read()
45
+
46
+
47
@func.dialect.register(key="emit.stim")
class FuncEmit(interp.MethodTable):
    """``emit.stim`` method table for the kirin ``func`` dialect."""

    @interp.impl(func.Function)
    def emit_func(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: func.Function):
        """Emit the function body and produce no results.

        Statements inside the body write their own lines (for example via
        ``emit.writeln``), so the region's return value is discarded.
        """
        emit.run_ssacfg_region(frame, stmt.body)
        return ()
bloqade/stim/groups.py ADDED
@@ -0,0 +1,26 @@
1
+ from kirin import ir
2
+ from kirin.passes import Fold, TypeInfer
3
+ from kirin.dialects import func, lowering
4
+
5
+ from .dialects import aux, gate, noise, collapse
6
+
7
+
8
+ @ir.dialect_group([noise, gate, aux, collapse, func, lowering.func, lowering.call])
9
+ def main(self):
10
+ typeinfer_pass = TypeInfer(self)
11
+ fold_pass = Fold(self)
12
+
13
+ def run_pass(
14
+ mt: ir.Method,
15
+ *,
16
+ typeinfer: bool = False,
17
+ fold: bool = True,
18
+ ) -> None:
19
+
20
+ if typeinfer:
21
+ typeinfer_pass(mt)
22
+
23
+ if fold:
24
+ fold_pass(mt)
25
+
26
+ return run_pass
bloqade/test_utils.py ADDED
@@ -0,0 +1,35 @@
1
+ import io
2
+ import difflib
3
+
4
+ from kirin import ir, print as pprint
5
+ from rich.console import Console
6
+
7
+
8
def print_diff(node: pprint.Printable, expected_node: pprint.Printable):
    """Print a line diff between the rendered forms of two printable nodes.

    Lines prefixed with ``- `` appear only in ``expected_node``'s rendering,
    and lines prefixed with ``+ `` appear only in ``node``'s.
    """

    def _render(printable: pprint.Printable) -> str:
        # Record rich output into an in-memory console instead of a terminal.
        console = Console(record=True, file=io.StringIO())
        printable.print(console=console)
        return console.export_text()

    actual_text = _render(node)
    expected_text = _render(expected_node)

    differ = difflib.Differ()
    lines = differ.compare(expected_text.splitlines(), actual_text.splitlines())
    print("\n".join(lines))
24
+
25
+
26
def assert_nodes(node: ir.IRNode, expected_node: ir.IRNode):
    """Assert that two IR nodes are structurally equal.

    On mismatch, a diff of the two rendered nodes is printed before raising.

    Raises:
        AssertionError: if ``node.is_equal(expected_node)`` is false.
    """
    # An explicit check rather than an ``assert`` statement keeps the
    # comparison active under ``python -O``, which strips asserts.
    if not node.is_equal(expected_node):
        print_diff(node, expected_node)
        raise AssertionError(
            "IR node does not match the expected node (diff printed above)"
        )
32
+
33
+
34
def assert_methods(mt: ir.Method, expected_mt: ir.Method):
    """Assert that two kirin methods have structurally equal ``code`` bodies."""
    actual_code, expected_code = mt.code, expected_mt.code
    assert_nodes(actual_code, expected_code)
bloqade/types.py ADDED
@@ -0,0 +1,24 @@
1
+ """Bloqade types.
2
+
3
+ This module defines the basic types used in Bloqade eDSLs.
4
+ """
5
+
6
+ from abc import ABC
7
+
8
+ from kirin import types
9
+
10
+
11
class Qubit(ABC):
    """Base runtime representation of a qubit.

    More specific qubit types subclass this, for example a reference to an
    element of a quantum register in a register dialect.
    """
21
+
22
+
23
# Kirin type corresponding to the Python ``Qubit`` class (wrapped with
# ``types.PyClass``).
QubitType = types.PyClass(Qubit)
"""Kirin type for a qubit."""
@@ -0,0 +1 @@
1
+ from . import animation as animation
File without changes
@@ -0,0 +1,267 @@
1
+ import bisect
2
+ import functools
3
+ from typing import Optional
4
+
5
+ import tqdm
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib.animation import FuncAnimation
9
+
10
+ from .base import FieldOfView, GatePainter, quera_color_code
11
+ from .runtime.qpustate import QPUStateABC
12
+
13
+
14
def _frame_times(
    gate_timings,
    t_start: float,
    t_stop: float,
    tstep_mv: float,
    tstep_gate: float,
) -> np.ndarray:
    """Build the QPU timestamps at which animation frames are drawn.

    Time inside gate windows is sampled with step ``tstep_gate``, and time
    between gates with step ``tstep_mv``. Only ``[t_start, t_stop]`` is covered.

    Args:
        gate_timings: Iterable of ``(start_time, duration)`` pairs, as returned
            by ``state.get_gate_events_timing()`` (presumably sorted by start
            time; TODO confirm).
        t_start: First timestamp of the animated window.
        t_stop: Last timestamp of the animated window (``>= t_start``).
        tstep_mv: Sampling step outside gate windows.
        tstep_gate: Sampling step inside gate windows.

    Returns:
        A 1-D array of timestamps. It is never empty: if no sample fits, it
        falls back to ``[t_start]``.
    """
    chunks = []
    curr_t = t_start
    for t_gate_start, duration in gate_timings:
        t_gate_end = t_gate_start + duration
        # Skip gates that already finished before the current time, and gates
        # that begin after the animated window.
        if t_gate_end < curr_t or t_gate_start >= t_stop:
            continue

        if t_gate_start > curr_t:
            # Movement segment leading up to the gate.
            chunks.append(
                np.linspace(
                    curr_t, t_gate_start, int((t_gate_start - curr_t) / tstep_mv)
                )
            )
            curr_t = t_gate_start

        # Gate segment, truncated at the end of the window.
        t_gate_end = min(t_gate_end, t_stop)
        chunks.append(
            np.linspace(curr_t, t_gate_end, int((t_gate_end - curr_t) / tstep_gate))
        )
        curr_t = t_gate_end

    # Trailing movement segment up to the end of the window.
    chunks.append(np.linspace(curr_t, t_stop, int((t_stop - curr_t) / tstep_mv)))

    times = np.concatenate(chunks)
    if times.size == 0:
        # The window is too short for any sample; keep a single frame so the
        # animation is never empty.
        times = np.array([t_start], dtype=float)
    return times


def animate_qpu_state(
    state: QPUStateABC,
    display_fov: Optional[FieldOfView] = None,
    dilation_rate: float = 0.05,
    fps: int = 30,
    gate_display_dilation: float = 1.0,
    fig_args: Optional[dict] = None,
    save_mpeg: bool = False,
    filename: str = "vqpu_animation",
    start_block: int = 0,
    n_blocks: int | None = None,
):
    """Generate an animation from the QPU state.

    Args:
        state (QPUStateABC): The QPU state to animate.
        display_fov (Optional[FieldOfView], optional): The field of view to display.
            Defaults to None, in which case the QPU's field of view is used.
        dilation_rate (float, optional): The rate at which to dilate the time. Defaults to 0.05.
        fps (int, optional): The frames per second. Defaults to 30.
        gate_display_dilation (float, optional): The rate at which to dilate the gate display.
            Defaults to 1.0.
        fig_args (dict, optional): Extra keyword arguments forwarded (via
            ``matplotlib.pyplot.subplot_mosaic``) to ``matplotlib.pyplot.figure``;
            they override the default ``figsize``. Defaults to None (no extra arguments).
        save_mpeg (bool, optional): Whether to save the animation as an mpeg. Defaults to False.
        filename (str, optional): The filename (without extension) to save the mpeg as.
            Defaults to "vqpu_animation".
        start_block (int, optional): The block to start the animation at. Defaults to 0.
        n_blocks (int | None, optional): The number of blocks to animate, starting at
            `start_block`. It is clamped to the number of available blocks. Defaults to None,
            in which case all blocks from `start_block` on are animated.

    Returns:
        The ``FuncAnimation`` object. Returns ``None`` when ``save_mpeg`` is
        True, in which case the animation is written to ``f"{filename}.mp4"``.

    Raises:
        ValueError: If ``start_block`` is out of range or ``n_blocks`` is negative.
    """
    if fig_args is None:
        fig_args = {}

    qpu_fov = state.qpu_fov

    if display_fov is None:
        display_fov = qpu_fov

    n_total_blocks = len(state.block_durations)
    if start_block >= n_total_blocks or start_block < 0:
        raise ValueError("Start block index is out of range")

    if n_blocks is None:
        n_blocks = n_total_blocks - start_block

    if n_blocks < 0:
        raise ValueError("Number of block to animate must be non-negative")

    slm_sites = state.get_slm_sites()

    # Scale the figure to different screens, and so that the SLM sites occupy
    # the same "area" on screen regardless of how many there are.
    nsites = max([4, len(slm_sites)])
    scale = (
        np.sqrt(44.0 / nsites) * 2.0 * plt.rcParams["figure.dpi"] / 100
    )  # scale the size of the figure

    # figure:
    new_fig_args = {"figsize": (14, 8), **fig_args}
    fig, mpl_axs = plt.subplot_mosaic(
        mosaic=[["Reg", "Info"], ["Reg", "Gate"], ["Reg", "Gate"]],
        gridspec_kw={"width_ratios": [3, 1]},
        **new_fig_args,
    )

    mpl_axs["Reg"].set_xlim(left=display_fov.xmin, right=display_fov.xmax)
    mpl_axs["Reg"].set_ylim(bottom=display_fov.ymin, top=display_fov.ymax)
    mpl_axs["Reg"].set(xlabel="x (um)", ylabel="y (um)")
    mpl_axs["Reg"].set_aspect("equal")

    # slm:
    slm_plt_arg = {
        "facecolors": "none",
        "edgecolors": "k",
        "linestyle": "-",
        "s": 80 * scale,
        "alpha": 0.3,
        "linewidth": 0.5 * np.sqrt(scale),
    }
    mpl_axs["Reg"].scatter(
        x=slm_sites[:, 0], y=slm_sites[:, 1], **slm_plt_arg
    )  # this is statically generated, so it will be the background

    # atoms:
    reg_plt_arg = {
        "s": 65 * scale,
        "marker": "o",
        "facecolors": quera_color_code.purple,
        "alpha": 1.0,
    }
    reg_panel = mpl_axs["Reg"]
    reg_scat = reg_panel.scatter([], [], **reg_plt_arg)

    # gates:
    gp = GatePainter(mpl_ax=reg_panel, qpu_fov=qpu_fov, scale=scale)

    annotate_args = {
        "fontsize": 6 * np.sqrt(scale),
        "ha": "center",
        "va": "center",
        "alpha": 1.0,
        "color": quera_color_code.yellow,
        "weight": "bold",
    }
    # One label per atom present at t=0. ``update`` later indexes this list by
    # the atom index yielded by ``get_atoms_position``.
    reg_annot_list = [
        reg_panel.annotate(f"{i}", atom_position, **annotate_args)
        for i, atom_position in state.get_atoms_position(time=0.0, include_lost=False)
    ]

    # AODs:
    aod_plot_args = {
        "s": 260 * scale,
        "marker": "+",
        "alpha": 0.7,
        "facecolors": quera_color_code.red,
        "zorder": -100,
        "linewidth": np.sqrt(scale),
    }
    aod_scat = reg_panel.scatter(x=[], y=[], **aod_plot_args)

    # Very large markers draw the AOD row/column lines across the whole panel.
    aod_h_args = {
        "s": 1e20,
        "marker": "|",
        "alpha": 1.0,
        "color": "#FFE8E9",
        "zorder": -101,
        "linewidth": 0.5 * np.sqrt(scale),
    }
    aod_h_scat = reg_panel.scatter(x=[], y=[], **aod_h_args)
    aod_v_args = {
        "s": 1e20,
        "marker": "_",
        "alpha": 1.0,
        "color": "#FFE8E9",
        "zorder": -101,
        "linewidth": 0.5 * np.sqrt(scale),
    }
    aod_v_scat = reg_panel.scatter(x=[], y=[], **aod_v_args)

    ## Info Panel
    info_text = mpl_axs["Info"].text(x=0.05, y=0.5, s="")
    mpl_axs["Info"].set_xticks([])
    mpl_axs["Info"].set_yticks([])
    mpl_axs["Info"].grid(False)

    ## Event Panel:
    log_text = mpl_axs["Gate"].text(x=0.05, y=0.0, s="", size=6)
    mpl_axs["Gate"].set_xticks([])
    mpl_axs["Gate"].set_yticks([])
    mpl_axs["Gate"].grid(False)

    # QPU time advanced per frame, outside and inside gate windows.
    tstep_mv = 1.0 / (fps * dilation_rate)
    tstep_gate = 1.0 / (fps * dilation_rate * gate_display_dilation)
    blk_t_end = np.cumsum(state.block_durations)

    # Animate only blocks [start_block, start_block + n_blocks), clamped to
    # the blocks that exist.
    end_block = min(start_block + n_blocks, n_total_blocks)
    t_start = 0 if start_block == 0 else blk_t_end[start_block - 1]
    t_stop = blk_t_end[end_block - 1] if end_block > start_block else t_start
    times = _frame_times(
        state.get_gate_events_timing(), t_start, t_stop, tstep_mv, tstep_gate
    )

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.1)

    def _update_annotate(loc, idx, annotate_artist):
        # Move the label slightly below the atom; return the atom's own position.
        new_loc = (loc[0], loc[1] - 0.06)
        annotate_artist.set_position(new_loc)
        txt = f"{idx}"
        annotate_artist.set_text(txt)
        return loc

    def update(
        frame: int, state: QPUStateABC, times: np.ndarray, blk_t_end: np.ndarray
    ):
        t_now = times[frame]

        # Block containing t_now: first block whose cumulative end is >= t_now.
        blk_id = bisect.bisect_left(blk_t_end, t_now)
        lbl = f"Block: [{blk_id}]\n"
        lbl += f"Block dur: {state.block_durations[blk_id]:.2f} us\n"
        lbl += f"Total elapsed time: {t_now:.2f} us"
        info_text.set_text(lbl)

        # update atoms location and annotation
        post = np.array(
            [
                _update_annotate(
                    atom_position,
                    i,
                    reg_annot_list[i],
                )
                for i, atom_position in state.get_atoms_position(
                    t_now, include_lost=False
                )
            ]
        )
        post = post if post.size > 0 else np.array([(None, None)])
        reg_scat.set_offsets(post)

        # update log event panels
        lost_events = state.get_atoms_lost_info(t_now)

        # Query the gate events once and materialize them: they feed both the
        # log text and the gate painter.
        gate_events = list(state.get_gate_events(t_now))
        gate_events_log = [
            f"Gate: {gate.cls_name} @ {t:.6f} (us)\n" for t, gate in gate_events
        ]
        log_text.set_text("".join(lost_events) + "".join(gate_events_log))

        gate_artists = gp.process_gates([gate for _, gate in gate_events])

        # update AODs
        # NOTE(review): `or` assumes sample_aod_traps returns a list-like value
        # with well-defined truthiness (not a multi-element numpy array); confirm.
        post = state.sample_aod_traps(t_now) or [(None, None)]
        aod_scat.set_offsets(post)
        aod_v_scat.set_offsets(post)
        aod_h_scat.set_offsets(post)

        return (
            [reg_scat, info_text, log_text, aod_scat, aod_v_scat, aod_h_scat]
            + reg_annot_list
            + gate_artists
        )

    ani = FuncAnimation(
        fig=fig,
        func=functools.partial(update, state=state, times=times, blk_t_end=blk_t_end),
        frames=len(times),
        # ``interval`` is the real-time delay between frames in milliseconds.
        interval=1000.0 / fps,
        blit=True,
        repeat=False,
    )
    if save_mpeg:
        # The context manager closes the progress bar even if saving fails.
        with tqdm.tqdm(total=len(times)) as pbar:

            def p_call_back(i, total_n):
                pbar.update()

            ani.save(
                f"{filename}.mp4",
                writer="ffmpeg",
                fps=fps,
                progress_callback=p_call_back,
            )
        return None

    return ani