bloqade-circuit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. Click here for more details.

Files changed (153)
  1. bloqade/analysis/__init__.py +0 -0
  2. bloqade/analysis/address/__init__.py +11 -0
  3. bloqade/analysis/address/analysis.py +60 -0
  4. bloqade/analysis/address/impls.py +228 -0
  5. bloqade/analysis/address/lattice.py +85 -0
  6. bloqade/noise/__init__.py +1 -0
  7. bloqade/noise/native/__init__.py +20 -0
  8. bloqade/noise/native/_dialect.py +3 -0
  9. bloqade/noise/native/_wrappers.py +34 -0
  10. bloqade/noise/native/model.py +347 -0
  11. bloqade/noise/native/rewrite.py +35 -0
  12. bloqade/noise/native/stmts.py +46 -0
  13. bloqade/pyqrack/__init__.py +18 -0
  14. bloqade/pyqrack/base.py +131 -0
  15. bloqade/pyqrack/noise/__init__.py +0 -0
  16. bloqade/pyqrack/noise/native.py +100 -0
  17. bloqade/pyqrack/qasm2/__init__.py +0 -0
  18. bloqade/pyqrack/qasm2/core.py +79 -0
  19. bloqade/pyqrack/qasm2/parallel.py +46 -0
  20. bloqade/pyqrack/qasm2/uop.py +247 -0
  21. bloqade/pyqrack/reg.py +109 -0
  22. bloqade/pyqrack/target.py +112 -0
  23. bloqade/qasm2/__init__.py +19 -0
  24. bloqade/qasm2/_wrappers.py +674 -0
  25. bloqade/qasm2/dialects/__init__.py +10 -0
  26. bloqade/qasm2/dialects/core/__init__.py +3 -0
  27. bloqade/qasm2/dialects/core/_dialect.py +3 -0
  28. bloqade/qasm2/dialects/core/_emit.py +68 -0
  29. bloqade/qasm2/dialects/core/_typeinfer.py +23 -0
  30. bloqade/qasm2/dialects/core/address.py +38 -0
  31. bloqade/qasm2/dialects/core/stmts.py +94 -0
  32. bloqade/qasm2/dialects/expr/__init__.py +3 -0
  33. bloqade/qasm2/dialects/expr/_dialect.py +3 -0
  34. bloqade/qasm2/dialects/expr/_emit.py +103 -0
  35. bloqade/qasm2/dialects/expr/_from_python.py +86 -0
  36. bloqade/qasm2/dialects/expr/_interp.py +75 -0
  37. bloqade/qasm2/dialects/expr/stmts.py +262 -0
  38. bloqade/qasm2/dialects/glob.py +45 -0
  39. bloqade/qasm2/dialects/indexing.py +64 -0
  40. bloqade/qasm2/dialects/inline.py +76 -0
  41. bloqade/qasm2/dialects/noise.py +16 -0
  42. bloqade/qasm2/dialects/parallel.py +110 -0
  43. bloqade/qasm2/dialects/uop/__init__.py +4 -0
  44. bloqade/qasm2/dialects/uop/_dialect.py +3 -0
  45. bloqade/qasm2/dialects/uop/_emit.py +211 -0
  46. bloqade/qasm2/dialects/uop/schedule.py +89 -0
  47. bloqade/qasm2/dialects/uop/stmts.py +325 -0
  48. bloqade/qasm2/emit/__init__.py +1 -0
  49. bloqade/qasm2/emit/base.py +72 -0
  50. bloqade/qasm2/emit/gate.py +102 -0
  51. bloqade/qasm2/emit/main.py +106 -0
  52. bloqade/qasm2/emit/target.py +165 -0
  53. bloqade/qasm2/glob.py +24 -0
  54. bloqade/qasm2/groups.py +120 -0
  55. bloqade/qasm2/parallel.py +48 -0
  56. bloqade/qasm2/parse/__init__.py +37 -0
  57. bloqade/qasm2/parse/ast.py +235 -0
  58. bloqade/qasm2/parse/build.py +289 -0
  59. bloqade/qasm2/parse/lowering.py +553 -0
  60. bloqade/qasm2/parse/parser.py +5 -0
  61. bloqade/qasm2/parse/print.py +293 -0
  62. bloqade/qasm2/parse/qasm2.lark +75 -0
  63. bloqade/qasm2/parse/visitor.py +16 -0
  64. bloqade/qasm2/parse/visitor.pyi +39 -0
  65. bloqade/qasm2/passes/__init__.py +5 -0
  66. bloqade/qasm2/passes/fold.py +94 -0
  67. bloqade/qasm2/passes/glob.py +119 -0
  68. bloqade/qasm2/passes/noise.py +61 -0
  69. bloqade/qasm2/passes/parallel.py +176 -0
  70. bloqade/qasm2/passes/py2qasm.py +63 -0
  71. bloqade/qasm2/passes/qasm2py.py +61 -0
  72. bloqade/qasm2/rewrite/__init__.py +12 -0
  73. bloqade/qasm2/rewrite/desugar.py +28 -0
  74. bloqade/qasm2/rewrite/glob.py +103 -0
  75. bloqade/qasm2/rewrite/heuristic_noise.py +247 -0
  76. bloqade/qasm2/rewrite/native_gates.py +447 -0
  77. bloqade/qasm2/rewrite/parallel_to_uop.py +83 -0
  78. bloqade/qasm2/rewrite/register.py +45 -0
  79. bloqade/qasm2/rewrite/uop_to_parallel.py +395 -0
  80. bloqade/qasm2/types.py +39 -0
  81. bloqade/qbraid/__init__.py +2 -0
  82. bloqade/qbraid/lowering.py +324 -0
  83. bloqade/qbraid/schema.py +252 -0
  84. bloqade/qbraid/simulation_result.py +99 -0
  85. bloqade/qbraid/target.py +86 -0
  86. bloqade/squin/__init__.py +2 -0
  87. bloqade/squin/analysis/__init__.py +0 -0
  88. bloqade/squin/analysis/nsites/__init__.py +8 -0
  89. bloqade/squin/analysis/nsites/analysis.py +52 -0
  90. bloqade/squin/analysis/nsites/impls.py +69 -0
  91. bloqade/squin/analysis/nsites/lattice.py +49 -0
  92. bloqade/squin/analysis/schedule.py +244 -0
  93. bloqade/squin/groups.py +38 -0
  94. bloqade/squin/op/__init__.py +132 -0
  95. bloqade/squin/op/_dialect.py +3 -0
  96. bloqade/squin/op/complex.py +6 -0
  97. bloqade/squin/op/stmts.py +220 -0
  98. bloqade/squin/op/traits.py +43 -0
  99. bloqade/squin/op/types.py +10 -0
  100. bloqade/squin/qubit.py +118 -0
  101. bloqade/squin/wire.py +103 -0
  102. bloqade/stim/__init__.py +6 -0
  103. bloqade/stim/_wrappers.py +186 -0
  104. bloqade/stim/dialects/__init__.py +5 -0
  105. bloqade/stim/dialects/aux/__init__.py +11 -0
  106. bloqade/stim/dialects/aux/_dialect.py +3 -0
  107. bloqade/stim/dialects/aux/emit.py +102 -0
  108. bloqade/stim/dialects/aux/interp.py +39 -0
  109. bloqade/stim/dialects/aux/lowering.py +40 -0
  110. bloqade/stim/dialects/aux/stmts/__init__.py +14 -0
  111. bloqade/stim/dialects/aux/stmts/annotate.py +47 -0
  112. bloqade/stim/dialects/aux/stmts/const.py +95 -0
  113. bloqade/stim/dialects/aux/types.py +19 -0
  114. bloqade/stim/dialects/collapse/__init__.py +3 -0
  115. bloqade/stim/dialects/collapse/_dialect.py +3 -0
  116. bloqade/stim/dialects/collapse/emit.py +68 -0
  117. bloqade/stim/dialects/collapse/stmts/__init__.py +3 -0
  118. bloqade/stim/dialects/collapse/stmts/measure.py +45 -0
  119. bloqade/stim/dialects/collapse/stmts/pp_measure.py +14 -0
  120. bloqade/stim/dialects/collapse/stmts/reset.py +26 -0
  121. bloqade/stim/dialects/gate/__init__.py +3 -0
  122. bloqade/stim/dialects/gate/_dialect.py +3 -0
  123. bloqade/stim/dialects/gate/emit.py +87 -0
  124. bloqade/stim/dialects/gate/stmts/__init__.py +14 -0
  125. bloqade/stim/dialects/gate/stmts/base.py +31 -0
  126. bloqade/stim/dialects/gate/stmts/clifford_1q.py +53 -0
  127. bloqade/stim/dialects/gate/stmts/clifford_2q.py +11 -0
  128. bloqade/stim/dialects/gate/stmts/control_2q.py +21 -0
  129. bloqade/stim/dialects/gate/stmts/pp.py +15 -0
  130. bloqade/stim/dialects/noise/__init__.py +3 -0
  131. bloqade/stim/dialects/noise/_dialect.py +3 -0
  132. bloqade/stim/dialects/noise/emit.py +66 -0
  133. bloqade/stim/dialects/noise/stmts.py +77 -0
  134. bloqade/stim/emit/__init__.py +1 -0
  135. bloqade/stim/emit/stim.py +54 -0
  136. bloqade/stim/groups.py +26 -0
  137. bloqade/test_utils.py +35 -0
  138. bloqade/types.py +24 -0
  139. bloqade/visual/__init__.py +1 -0
  140. bloqade/visual/animation/__init__.py +0 -0
  141. bloqade/visual/animation/animate.py +267 -0
  142. bloqade/visual/animation/base.py +346 -0
  143. bloqade/visual/animation/gate_event.py +24 -0
  144. bloqade/visual/animation/runtime/__init__.py +0 -0
  145. bloqade/visual/animation/runtime/aod.py +36 -0
  146. bloqade/visual/animation/runtime/atoms.py +55 -0
  147. bloqade/visual/animation/runtime/ppoly.py +50 -0
  148. bloqade/visual/animation/runtime/qpustate.py +119 -0
  149. bloqade/visual/animation/runtime/utils.py +43 -0
  150. bloqade_circuit-0.1.0.dist-info/METADATA +70 -0
  151. bloqade_circuit-0.1.0.dist-info/RECORD +153 -0
  152. bloqade_circuit-0.1.0.dist-info/WHEEL +4 -0
  153. bloqade_circuit-0.1.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,86 @@
1
+ from typing import TYPE_CHECKING, Union, Optional
2
+
3
+ from kirin import ir
4
+
5
+ if TYPE_CHECKING:
6
+ from qbraid import QbraidProvider
7
+ from qbraid.runtime import QbraidJob
8
+
9
+ from bloqade.qasm2.emit import QASM2
10
+
11
+
12
class qBraid:
    """qBraid target for Bloqade kernels.

    Converts a Bloqade kernel into a QASM2 program (using the ``QASM2``
    emitter) and submits it to the ``"quera_qasm_simulator"`` device obtained
    from the injected qBraid provider. The returned ``QbraidJob`` can be used
    to query the status of the submission and to retrieve its results.
    """

    def __init__(
        self,
        *,
        allow_parallel: bool = False,
        allow_global: bool = False,
        provider: "QbraidProvider",  # inject externally for easier mocking
        qelib1: bool = True,
    ) -> None:
        """Configure the qBraid target.

        Args:
            allow_parallel (bool):
                Allow parallel gates in the emitted QASM2 program. Defaults
                to `False`. When `False`, parallel gates used by the kernel
                are rewritten into uop gates.
            allow_global (bool):
                Allow global gates in the emitted QASM2 program. Defaults to
                `False`. When `False`, global gates used by the kernel are
                rewritten into parallel gates. If both `allow_parallel` and
                `allow_global` are `False`, the kernel is rewritten to use
                uop gates only.
            provider (QbraidProvider):
                qBraid provider used to look up the QuEra simulator device
                that the program is submitted to.
            qelib1 (bool):
                Include the `include "qelib1.inc"` line in the QASM2 program
                submitted to qBraid. Defaults to `True`.
        """
        # Configuration is only stored here; it is consumed by `emit`.
        self.allow_parallel = allow_parallel
        self.allow_global = allow_global
        self.provider = provider
        self.qelib1 = qelib1

    def emit(
        self,
        method: ir.Method,
        shots: Optional[int] = None,
        tags: Optional[dict[str, str]] = None,
    ) -> Union["QbraidJob", list["QbraidJob"]]:
        """Submit a Bloqade kernel to the QuEra simulator on qBraid.

        Args:
            method (ir.Method):
                The kernel to submit.
            shots (Optional[int]):
                Number of times to run the kernel; forwarded unchanged to the
                device's `run`. Defaults to `None`.
            tags (Optional[dict[str, str]]):
                Tags to associate with the submitted job.

        Returns:
            Union[QbraidJob, list[QbraidJob]]:
                The value returned by the device's `run`: a handle for
                querying the submission status and obtaining results.
        """
        # Lower the kernel to QASM2 source text using this target's settings.
        program = QASM2(
            allow_parallel=self.allow_parallel,
            allow_global=self.allow_global,
            qelib1=self.qelib1,
        ).emit_str(method)

        device = self.provider.get_device("quera_qasm_simulator")
        return device.run(program, shots=shots, tags=tags)
@@ -0,0 +1,2 @@
1
+ from . import op as op, wire as wire, qubit as qubit
2
+ from .groups import wired as wired, kernel as kernel
File without changes
@@ -0,0 +1,8 @@
1
+ # Need this for impl registration to work properly!
2
+ from . import impls as impls
3
+ from .lattice import (
4
+ NoSites as NoSites,
5
+ AnySites as AnySites,
6
+ NumberSites as NumberSites,
7
+ )
8
+ from .analysis import NSitesAnalysis as NSitesAnalysis
@@ -0,0 +1,52 @@
1
+ # from typing import cast
2
+
3
+ from kirin import ir
4
+ from kirin.analysis import Forward
5
+ from kirin.analysis.forward import ForwardFrame
6
+
7
+ from bloqade.squin.op.types import OpType
8
+ from bloqade.squin.op.traits import HasSites, FixedSites
9
+
10
+ from .lattice import Sites, NoSites, NumberSites
11
+
12
+
13
class NSitesAnalysis(Forward[Sites]):
    """Forward analysis inferring how many sites each operator acts on.

    Statements with an implementation registered under the ``"op.nsites"``
    key are evaluated by that implementation. Otherwise the site count comes
    from the statement's ``HasSites`` or ``FixedSites`` trait, and
    ``NoSites`` is produced when neither trait is present.
    """

    keys = ["op.nsites"]
    lattice = Sites

    def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
        # Overrides the default dispatch so that trait-based site counts are
        # used whenever no registered implementation exists for `stmt`.
        impl = self.lookup_registry(frame, stmt)
        if impl is not None:
            return impl(self, frame, stmt)

        if stmt.has_trait(HasSites):
            n_sites = stmt.get_trait(HasSites).get_sites(stmt)
            return (NumberSites(sites=n_sites),)

        if stmt.has_trait(FixedSites):
            fixed = stmt.get_trait(FixedSites)
            return (NumberSites(sites=fixed.data),)

        return (NoSites(),)

    def eval_stmt_fallback(
        self, frame: ForwardFrame[Sites], stmt: ir.Statement
    ) -> tuple[Sites, ...]:
        """Result for statements without an implementation.

        Results typed as ``OpType`` get top (site count unknown); every other
        result gets bottom.
        """
        values = []
        for result in stmt.results:
            if result.type.is_subseteq(OpType):
                values.append(self.lattice.top())
            else:
                values.append(self.lattice.bottom())
        return tuple(values)

    def run_method(self, method: ir.Method, args: tuple[Sites, ...]):
        # NOTE: dynamic calls are not supported, so the method object itself
        # is not propagated; its `self` slot is seeded with bottom instead.
        self_value = self.lattice.bottom()
        return self.run_callable(method.code, (self_value,) + args)
@@ -0,0 +1,69 @@
1
+ from typing import cast
2
+
3
+ from kirin import ir, interp
4
+
5
+ from bloqade.squin import op
6
+
7
+ from .lattice import (
8
+ NoSites,
9
+ NumberSites,
10
+ )
11
+ from .analysis import NSitesAnalysis
12
+
13
+
14
@op.dialect.register(key="op.nsites")
class SquinOp(interp.MethodTable):
    """``"op.nsites"`` implementations for statements of the squin op dialect.

    Each implementation returns a one-element tuple holding the number of
    sites the resulting operator acts on, or ``NoSites()`` when that cannot
    be determined from the operands.
    """

    @interp.impl(op.stmts.Kron)
    def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
        """A Kronecker product acts on the sites of both factors combined."""
        lhs = frame.get(stmt.lhs)
        rhs = frame.get(stmt.rhs)
        if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites):
            return (NumberSites(sites=lhs.sites + rhs.sites),)
        return (NoSites(),)

    @interp.impl(op.stmts.Mult)
    def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult):
        """An operator product needs both factors on the same number of sites.

        Multiplying two ``n``-site operators yields an ``n``-site operator.
        This is a matrix product, not a tensor product, so the counts do not
        add up as they do for ``Kron``.
        """
        lhs = frame.get(stmt.lhs)
        rhs = frame.get(stmt.rhs)

        if not (isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites)):
            return (NoSites(),)

        # Mismatched sizes are reported as NoSites instead of raising.
        # Xiu-zhe (Roger) Luo pointed out that an explicit lattice element
        # gives better UX than an exception, at the cost of some added
        # complexity in the lattice.
        if lhs.sites != rhs.sites:
            return (NoSites(),)

        # Previously this returned lhs.sites + rhs.sites, which doubled the
        # site count (e.g. X * Z was reported as acting on 2 sites).
        return (NumberSites(sites=lhs.sites),)

    @interp.impl(op.stmts.Control)
    def control(
        self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control
    ):
        """A controlled operator acts on ``n_controls`` extra sites."""
        op_sites = frame.get(stmt.op)

        if isinstance(op_sites, NumberSites):
            n_controls_attr = stmt.get_attr_or_prop("n_controls")
            n_controls = cast(ir.PyAttr[int], n_controls_attr).data
            return (NumberSites(sites=op_sites.sites + n_controls),)
        return (NoSites(),)

    @interp.impl(op.stmts.Rot)
    def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
        """A rotation has the same site count as its ``axis`` operator."""
        return (frame.get(stmt.axis),)

    @interp.impl(op.stmts.Scale)
    def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
        """Scaling passes the site count of ``stmt.op`` through unchanged."""
        return (frame.get(stmt.op),)
@@ -0,0 +1,49 @@
1
+ from typing import final
2
+ from dataclasses import dataclass
3
+
4
+ from kirin.lattice import (
5
+ SingletonMeta,
6
+ BoundedLattice,
7
+ SimpleJoinMixin,
8
+ SimpleMeetMixin,
9
+ )
10
+
11
+
12
@dataclass
class Sites(
    SimpleJoinMixin["Sites"], SimpleMeetMixin["Sites"], BoundedLattice["Sites"]
):
    """Lattice describing the number of sites an operator acts on.

    ``NoSites`` is the bottom element and ``AnySites`` is the top element.
    ``NumberSites`` carries a concrete site count in between.
    """

    @classmethod
    def top(cls) -> "Sites":
        return AnySites()

    @classmethod
    def bottom(cls) -> "Sites":
        return NoSites()
23
+
24
+
25
@final
@dataclass
class NoSites(Sites, metaclass=SingletonMeta):
    """Bottom element of the ``Sites`` lattice (singleton)."""

    def is_subseteq(self, other: Sites) -> bool:
        # Bottom lies below every element, whatever `other` is.
        return True
31
+
32
+
33
@final
@dataclass
class AnySites(Sites, metaclass=SingletonMeta):
    """Top element of the ``Sites`` lattice: the site count is unknown."""

    def is_subseteq(self, other: Sites) -> bool:
        # Top is only below itself. The previous `isinstance(other, Sites)`
        # check made top <= every element, so the simple join/meet mixins
        # produced wrong results, e.g. AnySites().join(NumberSites(2))
        # returned NumberSites(2) and AnySites().meet(NoSites()) returned
        # AnySites().
        return isinstance(other, AnySites)
39
+
40
+
41
@final
@dataclass
class NumberSites(Sites):
    """A concrete site count.

    Attributes:
        sites: number of sites the operator acts on.
    """

    sites: int

    def is_subseteq(self, other: Sites) -> bool:
        # Every element lies below the top element AnySites. Without this
        # case, NumberSites(n).join(AnySites()) returned NumberSites(n).
        if isinstance(other, AnySites):
            return True
        if isinstance(other, NumberSites):
            return self.sites == other.sites
        return False
@@ -0,0 +1,244 @@
1
+ from typing import Any, Set, Dict, Iterable, Optional, final
2
+ from itertools import chain
3
+ from collections import OrderedDict
4
+ from dataclasses import field, dataclass
5
+ from collections.abc import Sequence
6
+
7
+ from kirin import ir, graph, interp, idtable
8
+ from kirin.lattice import (
9
+ SingletonMeta,
10
+ BoundedLattice,
11
+ SimpleJoinMixin,
12
+ SimpleMeetMixin,
13
+ )
14
+ from kirin.analysis import Forward, ForwardFrame
15
+ from kirin.dialects import func
16
+
17
+ from bloqade.analysis import address
18
+ from bloqade.qasm2.parse.print import Printer
19
+
20
+
21
@dataclass
class GateSchedule(
    SimpleJoinMixin["GateSchedule"],
    SimpleMeetMixin["GateSchedule"],
    BoundedLattice["GateSchedule"],
):
    """Two-point lattice used as the value domain of ``DagScheduleAnalysis``.

    ``NotQubit`` is the bottom element and ``Qubit`` is the top element.
    """

    @classmethod
    def top(cls) -> "GateSchedule":
        return Qubit()

    @classmethod
    def bottom(cls) -> "GateSchedule":
        return NotQubit()
35
+
36
+
37
@final
@dataclass
class NotQubit(GateSchedule, metaclass=SingletonMeta):
    """Bottom element of ``GateSchedule`` (singleton)."""

    def is_subseteq(self, other: GateSchedule) -> bool:
        # Bottom lies below every element.
        return True
43
+
44
+
45
@final
@dataclass
class Qubit(GateSchedule, metaclass=SingletonMeta):
    """Top element of ``GateSchedule`` (singleton)."""

    def is_subseteq(self, other: GateSchedule) -> bool:
        # Top is only below itself.
        is_top = isinstance(other, Qubit)
        return is_top
51
+
52
+
53
+ # Treat global gates as terminators for this analysis, e.g. split block in half.
54
+
55
+
56
@dataclass(slots=True)
class StmtDag(graph.Graph[ir.Statement]):
    """Dependency graph between statements.

    Nodes are statements, identified internally by the string ids handed out
    by ``id_table``. An edge ``src -> dst`` records that ``dst`` depends on
    ``src``.
    """

    # assigns a string id to every statement added to the graph
    id_table: idtable.IdTable[ir.Statement] = field(
        default_factory=lambda: idtable.IdTable()
    )
    # node id -> statement
    stmts: Dict[str, ir.Statement] = field(default_factory=OrderedDict)
    # node id -> ids of its successors (targets of outgoing edges)
    out_edges: Dict[str, Set[str]] = field(default_factory=OrderedDict)
    # node id -> ids of its predecessors (sources of incoming edges)
    inc_edges: Dict[str, Set[str]] = field(default_factory=OrderedDict)
    # statement -> position at which it was first added (0, 1, 2, ...)
    stmt_index: Dict[ir.Statement, int] = field(default_factory=OrderedDict)

    def update_index(self, node: ir.Statement):
        """Record the insertion position of ``node`` if it is new."""
        if node not in self.stmt_index:
            self.stmt_index[node] = len(self.stmt_index)

    def add_node(self, node: ir.Statement):
        """Add ``node`` to the graph and return its id.

        Re-adding a statement keeps its existing edges (the edge maps use
        ``setdefault``), assuming ``id_table`` returns a stable id for it.
        """
        node_id = self.id_table[node]
        self.stmts[node_id] = node
        self.update_index(node)
        self.out_edges.setdefault(node_id, set())
        self.inc_edges.setdefault(node_id, set())
        return node_id

    def add_edge(self, src: ir.Statement, dst: ir.Statement):
        """Add both statements (if needed) and an edge ``src -> dst``."""
        src_id = self.add_node(src)
        dst_id = self.add_node(dst)

        self.out_edges[src_id].add(dst_id)
        self.inc_edges[dst_id].add(src_id)

    def get_parents(self, node: ir.Statement) -> Iterable[ir.Statement]:
        return (
            self.stmts[node_id]
            for node_id in self.inc_edges.get(self.id_table[node], set())
        )

    def get_children(self, node: ir.Statement) -> Iterable[ir.Statement]:
        return (
            self.stmts[node_id]
            for node_id in self.out_edges.get(self.id_table[node], set())
        )

    def get_neighbors(self, node: ir.Statement) -> Iterable[ir.Statement]:
        return chain(self.get_parents(node), self.get_children(node))

    def get_nodes(self) -> Iterable[ir.Statement]:
        return self.stmts.values()

    def get_edges(self) -> Iterable[tuple[ir.Statement, ir.Statement]]:
        return (
            (self.stmts[src], self.stmts[dst])
            for src, dsts in self.out_edges.items()
            for dst in dsts
        )

    def print(
        self,
        printer: Optional["Printer"] = None,
        analysis: dict["ir.SSAValue", Any] | None = None,
    ) -> None:
        raise NotImplementedError

    def _order_key(self, node_id: str) -> int:
        """Sort key placing node ids in the order their statements were added."""
        return self.stmt_index[self.stmts[node_id]]

    def topological_groups(self):
        """Split the DAG into topological layers (Kahn's algorithm).

        Each yielded group contains nodes that have no dependencies on each
        other. Every dependency of a node in a group lies in an earlier group.
        Within a group, ids are ordered by statement insertion order
        (``stmt_index``), so the output does not depend on set iteration
        order or string hash seeds.

        Yields:
            list[str]: the node ids of one topological group.

        Raises:
            ValueError: if a cyclic dependency is detected. The error is
                raised as soon as no remaining node is free of dependencies,
                so no empty groups are yielded beforehand.
        """
        # Work on a copy so that the DAG itself is left untouched.
        remaining = {
            node_id: set(deps) for node_id, deps in self.inc_edges.items()
        }
        # Only nodes that just lost a predecessor can become ready, so each
        # round only rechecks those (initially: every node).
        candidates: Iterable[str] = list(remaining)

        while remaining:
            group = sorted(
                (node_id for node_id in candidates if not remaining[node_id]),
                key=self._order_key,
            )
            if not group:
                # Every remaining node still waits on another remaining node.
                raise ValueError("Cyclic dependency detected")
            yield group

            next_candidates: Set[str] = set()
            for node_id in group:
                del remaining[node_id]
                for child_id in self.out_edges[node_id]:
                    remaining[child_id].remove(node_id)
                    next_candidates.add(child_id)
            candidates = next_candidates
159
+
160
+
161
@dataclass
class DagScheduleAnalysis(Forward[GateSchedule]):
    """Build per-block dependency DAGs between statements that use qubits.

    A statement gets an edge from the most recent earlier statement that used
    any of the same qubit addresses, as reported by ``address_analysis``.
    When ``eval_stmt_fallback`` sees a terminator, the DAG built so far is
    stored in ``stmt_dags`` under the terminator's parent block, and a fresh
    DAG is started.
    """

    keys = ["qasm2.schedule.dag"]
    lattice = GateSchedule

    # result of a prior address analysis: SSA value -> address
    address_analysis: Dict[ir.SSAValue, address.Address]
    # qubit address -> last statement that used it in the current DAG
    use_def: Dict[int, ir.Statement] = field(init=False)
    # DAG currently being built
    stmt_dag: StmtDag = field(init=False)
    # finished DAGs keyed by the block whose terminator sealed them
    stmt_dags: Dict[ir.Block, StmtDag] = field(init=False)

    def initialize(self):
        self.use_def, self.stmt_dag, self.stmt_dags = {}, StmtDag(), {}
        return super().initialize()

    def push_current_dag(self, block: ir.Block):
        # Invoked on terminator statements: seal the current DAG for `block`.
        assert block not in self.stmt_dags, "Block already in stmt_dags"

        # Make sure every last-use statement appears as a node.
        for last_use in self.use_def.values():
            self.stmt_dag.add_node(last_use)

        self.stmt_dags[block] = self.stmt_dag
        self.stmt_dag = StmtDag()
        self.use_def = {}

    def run_method(self, method: ir.Method, args: tuple[GateSchedule, ...]):
        # NOTE: dynamic calls are not supported, so the method object itself
        # is not propagated; its `self` slot is seeded with bottom instead.
        self_value = self.lattice.bottom()
        return self.run_callable(method.code, (self_value,) + args)

    def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
        if stmt.has_trait(ir.IsTerminator):
            parent = stmt.parent_block
            assert parent is not None, "Terminator statement has no parent block"
            self.push_current_dag(parent)

        return tuple(self.lattice.top() for _ in stmt.results)

    def _update_dag(self, stmt: ir.Statement, addr: address.Address):
        # Collect the qubit addresses touched through `addr`; tuples recurse
        # into their elements, and any other address kind is ignored.
        if isinstance(addr, address.AddressQubit):
            qubits = (addr.data,)
        elif isinstance(addr, address.AddressReg):
            qubits = addr.data
        elif isinstance(addr, address.AddressTuple):
            for element in addr.data:
                self._update_dag(stmt, element)
            return
        else:
            return

        for qubit in qubits:
            previous = self.use_def.get(qubit)
            if previous is not None:
                # `stmt` must follow the previous user of this qubit.
                self.stmt_dag.add_edge(previous, stmt)
            self.use_def[qubit] = stmt

    def update_dag(self, stmt: ir.Statement, args: Sequence[ir.SSAValue]):
        """Add ``stmt`` to the current DAG with edges for each argument's qubits."""
        self.stmt_dag.add_node(stmt)

        for arg in args:
            # Arguments missing from the address analysis default to bottom.
            arg_addr = self.address_analysis.get(arg, address.Address.bottom())
            self._update_dag(stmt, arg_addr)

    def get_dags(self, mt: ir.Method, args=None, kwargs=None):
        """Run the analysis on ``mt`` and return the per-block DAGs."""
        if args is None:
            # Without explicit arguments, every argument starts at top.
            args = tuple(self.lattice.top() for _ in mt.args)

        self.run(mt, args, kwargs).expect()
        return self.stmt_dags
231
+
232
+
233
@func.dialect.register(key="qasm2.schedule.dag")
class FuncImpl(interp.MethodTable):
    """``"qasm2.schedule.dag"`` implementations for ``func`` call statements."""

    @interp.impl(func.Invoke)
    @interp.impl(func.Call)
    def invoke(
        self,
        interp: DagScheduleAnalysis,
        frame: ForwardFrame,
        stmt: func.Invoke | func.Call,
    ):
        # The call is recorded as one node depending on all of its inputs;
        # the callee body is not analyzed here.
        interp.update_dag(stmt, stmt.inputs)
        top = interp.lattice.top
        return tuple(top() for _ in stmt.results)
@@ -0,0 +1,38 @@
1
+ from kirin import ir, passes
2
+ from kirin.prelude import structural_no_opt
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.qasm2.rewrite.desugar import IndexingDesugarPass
6
+
7
+ from . import op, wire, qubit
8
+
9
+
10
+ @ir.dialect_group(structural_no_opt.union([op, qubit]))
11
+ def kernel(self):
12
+ fold_pass = passes.Fold(self)
13
+ typeinfer_pass = passes.TypeInfer(self)
14
+ ilist_desugar_pass = ilist.IListDesugar(self)
15
+ indexing_desugar_pass = IndexingDesugarPass(self)
16
+
17
+ def run_pass(method, *, fold=True, typeinfer=True):
18
+ method.verify()
19
+ if fold:
20
+ fold_pass.fixpoint(method)
21
+
22
+ if typeinfer:
23
+ typeinfer_pass(method)
24
+ ilist_desugar_pass(method)
25
+ indexing_desugar_pass(method)
26
+ if typeinfer:
27
+ typeinfer_pass(method) # fix types after desugaring
28
+ method.code.typecheck()
29
+
30
+ return run_pass
31
+
32
+
33
+ @ir.dialect_group(structural_no_opt.union([op, wire]))
34
+ def wired(self):
35
+ def run_pass(method):
36
+ pass
37
+
38
+ return run_pass
@@ -0,0 +1,132 @@
1
+ from kirin import ir as _ir
2
+ from kirin.prelude import structural_no_opt as _structural_no_opt
3
+ from kirin.lowering import wraps as _wraps
4
+
5
+ from . import stmts as stmts, types as types
6
+ from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
7
+ from ._dialect import dialect as dialect
8
+
9
+
10
# Primitive operators. Each stub is tied to a squin `op` statement by
# `kirin.lowering.wraps`; the `...` bodies are placeholders, not
# implementations.


@_wraps(stmts.Kron)
def kron(lhs: types.Op, rhs: types.Op, *, is_unitary: bool = False) -> types.Op:
    """Wraps ``stmts.Kron``: Kronecker product of ``lhs`` and ``rhs``."""
    ...


@_wraps(stmts.Adjoint)
def adjoint(op: types.Op, *, is_unitary: bool = False) -> types.Op:
    """Wraps ``stmts.Adjoint``: the adjoint of ``op``."""
    ...


@_wraps(stmts.Control)
def control(op: types.Op, *, n_controls: int, is_unitary: bool = False) -> types.Op:
    """Wraps ``stmts.Control``: ``op`` controlled by ``n_controls`` extra sites."""
    ...


@_wraps(stmts.Identity)
def identity(*, size: int) -> types.Op:
    """Wraps ``stmts.Identity``; ``size`` is presumably the number of sites."""
    ...


@_wraps(stmts.Rot)
def rot(axis: types.Op, angle: float) -> types.Op:
    """Wraps ``stmts.Rot``: rotation about the ``axis`` operator by ``angle``."""
    ...


@_wraps(stmts.ShiftOp)
def shift(theta: float) -> types.Op:
    """Wraps ``stmts.ShiftOp`` with parameter ``theta``."""
    ...


@_wraps(stmts.PhaseOp)
def phase(theta: float) -> types.Op:
    """Wraps ``stmts.PhaseOp``: phase gate with parameter ``theta``."""
    ...


@_wraps(stmts.X)
def x() -> types.Op:
    """Wraps ``stmts.X``: the X gate."""
    ...


@_wraps(stmts.Y)
def y() -> types.Op:
    """Wraps ``stmts.Y``: the Y gate."""
    ...


@_wraps(stmts.Z)
def z() -> types.Op:
    """Wraps ``stmts.Z``: the Z gate."""
    ...


@_wraps(stmts.H)
def h() -> types.Op:
    """Wraps ``stmts.H``: the H gate."""
    ...


@_wraps(stmts.S)
def s() -> types.Op:
    """Wraps ``stmts.S``: the S gate."""
    ...


@_wraps(stmts.T)
def t() -> types.Op:
    """Wraps ``stmts.T``: the T gate."""
    ...


@_wraps(stmts.P0)
def p0() -> types.Op:
    """Wraps ``stmts.P0``."""
    ...


@_wraps(stmts.P1)
def p1() -> types.Op:
    """Wraps ``stmts.P1``."""
    ...


@_wraps(stmts.Sn)
def spin_n() -> types.Op:
    """Wraps ``stmts.Sn``."""
    ...


@_wraps(stmts.Sp)
def spin_p() -> types.Op:
    """Wraps ``stmts.Sp``."""
    ...
76
+
77
+
78
# Standard-library operator kernels, compiled with the `op` dialect group below.
@_ir.dialect_group(_structural_no_opt.add(dialect))
def op(self):
    """Dialect group combining ``structural_no_opt`` with the squin op dialect.

    No passes are applied: the returned ``run_pass`` does nothing.
    """

    def run_pass(method):
        # Intentionally a no-op.
        return None

    return run_pass
85
+
86
+
87
# NOTE: these bodies are lowered to IR by kirin, so keep them as simple
# single-expression returns.


@op
def rx(theta: float) -> types.Op:
    """Rotation about the X axis by ``theta``: ``rot(x(), theta)``."""
    return rot(x(), theta)


@op
def ry(theta: float) -> types.Op:
    """Rotation about the Y axis by ``theta``: ``rot(y(), theta)``."""
    return rot(y(), theta)


@op
def rz(theta: float) -> types.Op:
    """Rotation about the Z axis by ``theta``: ``rot(z(), theta)``."""
    return rot(z(), theta)


@op
def cx() -> types.Op:
    """Controlled X gate: ``x()`` with one control site."""
    return control(x(), n_controls=1)


@op
def cy() -> types.Op:
    """Controlled Y gate: ``y()`` with one control site."""
    return control(y(), n_controls=1)


@op
def cz() -> types.Op:
    """Controlled Z gate: ``z()`` with one control site."""
    return control(z(), n_controls=1)


@op
def ch() -> types.Op:
    """Controlled H gate: ``h()`` with one control site."""
    return control(h(), n_controls=1)


@op
def cphase(theta: float) -> types.Op:
    """Controlled phase gate: ``phase(theta)`` with one control site."""
    return control(phase(theta), n_controls=1)
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# The "squin.op" dialect. Method tables register against it via
# `dialect.register(key=...)`, e.g. the "op.nsites" analysis implementations.
dialect = ir.Dialect("squin.op")
@@ -0,0 +1,6 @@
1
+ # Stopgap measure: the squin dialect needs a Complex type, but Kirin
2
+ # only provides one starting with 0.15.x, so it is defined locally here.
3
+
4
+ from kirin.ir.attrs.types import PyClass
5
+
6
# Kirin type attribute wrapping Python's built-in `complex` class.
Complex = PyClass(complex)