bloqade-circuit 0.6.4__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/analysis.py +18 -20
  7. bloqade/analysis/measure_id/impls.py +31 -29
  8. bloqade/annotate/__init__.py +6 -0
  9. bloqade/annotate/_dialect.py +3 -0
  10. bloqade/annotate/_interface.py +22 -0
  11. bloqade/annotate/stmts.py +29 -0
  12. bloqade/annotate/types.py +13 -0
  13. bloqade/cirq_utils/__init__.py +4 -2
  14. bloqade/cirq_utils/emit/__init__.py +3 -0
  15. bloqade/cirq_utils/emit/base.py +246 -0
  16. bloqade/cirq_utils/emit/gate.py +104 -0
  17. bloqade/cirq_utils/emit/noise.py +90 -0
  18. bloqade/cirq_utils/emit/qubit.py +35 -0
  19. bloqade/cirq_utils/lowering.py +660 -0
  20. bloqade/cirq_utils/noise/__init__.py +0 -2
  21. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  22. bloqade/cirq_utils/noise/model.py +151 -191
  23. bloqade/cirq_utils/noise/transform.py +2 -2
  24. bloqade/cirq_utils/parallelize.py +9 -6
  25. bloqade/gemini/__init__.py +1 -0
  26. bloqade/gemini/analysis/__init__.py +3 -0
  27. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  28. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  29. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  30. bloqade/gemini/groups.py +67 -0
  31. bloqade/native/__init__.py +23 -0
  32. bloqade/native/_prelude.py +45 -0
  33. bloqade/native/dialects/__init__.py +0 -0
  34. bloqade/native/dialects/gate/__init__.py +2 -0
  35. bloqade/native/dialects/gate/_dialect.py +3 -0
  36. bloqade/native/dialects/gate/_interface.py +32 -0
  37. bloqade/native/dialects/gate/stmts.py +31 -0
  38. bloqade/native/stdlib/__init__.py +0 -0
  39. bloqade/native/stdlib/broadcast.py +246 -0
  40. bloqade/native/stdlib/simple.py +220 -0
  41. bloqade/native/upstream/__init__.py +4 -0
  42. bloqade/native/upstream/squin2native.py +79 -0
  43. bloqade/pyqrack/__init__.py +2 -2
  44. bloqade/pyqrack/base.py +7 -1
  45. bloqade/pyqrack/device.py +192 -18
  46. bloqade/pyqrack/native.py +49 -0
  47. bloqade/pyqrack/reg.py +6 -6
  48. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  49. bloqade/pyqrack/squin/gate/gate.py +136 -0
  50. bloqade/pyqrack/squin/noise/native.py +120 -54
  51. bloqade/pyqrack/squin/qubit.py +39 -36
  52. bloqade/pyqrack/target.py +5 -4
  53. bloqade/pyqrack/task.py +114 -7
  54. bloqade/qasm2/_qasm_loading.py +3 -3
  55. bloqade/qasm2/dialects/core/address.py +21 -12
  56. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  57. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  58. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  59. bloqade/qasm2/dialects/noise/model.py +2 -1
  60. bloqade/qasm2/emit/base.py +16 -11
  61. bloqade/qasm2/emit/gate.py +11 -8
  62. bloqade/qasm2/emit/main.py +103 -3
  63. bloqade/qasm2/emit/target.py +9 -5
  64. bloqade/qasm2/groups.py +3 -2
  65. bloqade/qasm2/parse/lowering.py +0 -1
  66. bloqade/qasm2/passes/fold.py +14 -73
  67. bloqade/qasm2/passes/glob.py +2 -2
  68. bloqade/qasm2/passes/noise.py +1 -1
  69. bloqade/qasm2/passes/parallel.py +7 -5
  70. bloqade/qasm2/rewrite/__init__.py +0 -1
  71. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  72. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  73. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  74. bloqade/qasm2/rewrite/register.py +2 -2
  75. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  76. bloqade/qbraid/lowering.py +1 -0
  77. bloqade/qbraid/schema.py +2 -2
  78. bloqade/qubit/__init__.py +12 -0
  79. bloqade/qubit/_dialect.py +3 -0
  80. bloqade/qubit/_interface.py +49 -0
  81. bloqade/qubit/_prelude.py +45 -0
  82. bloqade/qubit/analysis/__init__.py +1 -0
  83. bloqade/qubit/analysis/address_impl.py +40 -0
  84. bloqade/qubit/stdlib/__init__.py +2 -0
  85. bloqade/qubit/stdlib/_new.py +34 -0
  86. bloqade/qubit/stdlib/broadcast.py +62 -0
  87. bloqade/qubit/stdlib/simple.py +59 -0
  88. bloqade/qubit/stmts.py +60 -0
  89. bloqade/rewrite/passes/__init__.py +6 -0
  90. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  91. bloqade/rewrite/passes/callgraph.py +116 -0
  92. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  93. bloqade/rewrite/rules/split_ifs.py +18 -1
  94. bloqade/squin/__init__.py +47 -14
  95. bloqade/squin/analysis/__init__.py +0 -1
  96. bloqade/squin/analysis/schedule.py +10 -11
  97. bloqade/squin/gate/__init__.py +2 -0
  98. bloqade/squin/gate/_dialect.py +3 -0
  99. bloqade/squin/gate/_interface.py +98 -0
  100. bloqade/squin/gate/stmts.py +125 -0
  101. bloqade/squin/groups.py +5 -22
  102. bloqade/squin/noise/__init__.py +1 -10
  103. bloqade/squin/noise/_dialect.py +1 -1
  104. bloqade/squin/noise/_interface.py +45 -0
  105. bloqade/squin/noise/stmts.py +66 -28
  106. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  107. bloqade/squin/rewrite/__init__.py +0 -2
  108. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  109. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  110. bloqade/squin/stdlib/__init__.py +0 -0
  111. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  112. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  113. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  114. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  115. bloqade/squin/stdlib/simple/__init__.py +33 -0
  116. bloqade/squin/stdlib/simple/gate.py +242 -0
  117. bloqade/squin/stdlib/simple/noise.py +126 -0
  118. bloqade/stim/__init__.py +1 -0
  119. bloqade/stim/_wrappers.py +6 -0
  120. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  121. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  122. bloqade/stim/dialects/gate/emit.py +9 -10
  123. bloqade/stim/dialects/noise/emit.py +17 -13
  124. bloqade/stim/dialects/noise/stmts.py +5 -3
  125. bloqade/stim/emit/__init__.py +1 -0
  126. bloqade/stim/emit/impls.py +16 -0
  127. bloqade/stim/emit/stim_str.py +48 -31
  128. bloqade/stim/groups.py +12 -2
  129. bloqade/stim/parse/lowering.py +14 -17
  130. bloqade/stim/passes/__init__.py +0 -2
  131. bloqade/stim/passes/flatten.py +26 -0
  132. bloqade/stim/passes/simplify_ifs.py +6 -1
  133. bloqade/stim/passes/squin_to_stim.py +9 -84
  134. bloqade/stim/rewrite/__init__.py +2 -4
  135. bloqade/stim/rewrite/get_record_util.py +24 -0
  136. bloqade/stim/rewrite/ifs_to_stim.py +24 -25
  137. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  138. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  139. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  140. bloqade/stim/rewrite/squin_measure.py +9 -18
  141. bloqade/stim/rewrite/squin_noise.py +134 -108
  142. bloqade/stim/rewrite/util.py +5 -192
  143. bloqade/test_utils.py +1 -1
  144. bloqade/types.py +10 -0
  145. bloqade/validation/__init__.py +2 -0
  146. bloqade/validation/analysis/__init__.py +5 -0
  147. bloqade/validation/analysis/analysis.py +41 -0
  148. bloqade/validation/analysis/lattice.py +58 -0
  149. bloqade/validation/kernel_validation.py +77 -0
  150. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  151. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  152. bloqade/pyqrack/squin/op.py +0 -180
  153. bloqade/pyqrack/squin/runtime.py +0 -535
  154. bloqade/pyqrack/squin/wire.py +0 -51
  155. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  156. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  157. bloqade/squin/_typeinfer.py +0 -20
  158. bloqade/squin/analysis/address_impl.py +0 -71
  159. bloqade/squin/analysis/nsites/__init__.py +0 -9
  160. bloqade/squin/analysis/nsites/analysis.py +0 -50
  161. bloqade/squin/analysis/nsites/impls.py +0 -92
  162. bloqade/squin/analysis/nsites/lattice.py +0 -49
  163. bloqade/squin/cirq/__init__.py +0 -280
  164. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  165. bloqade/squin/cirq/emit/noise.py +0 -49
  166. bloqade/squin/cirq/emit/op.py +0 -125
  167. bloqade/squin/cirq/emit/qubit.py +0 -60
  168. bloqade/squin/cirq/emit/runtime.py +0 -242
  169. bloqade/squin/cirq/lowering.py +0 -440
  170. bloqade/squin/lowering.py +0 -54
  171. bloqade/squin/noise/_wrapper.py +0 -40
  172. bloqade/squin/noise/rewrite.py +0 -111
  173. bloqade/squin/op/__init__.py +0 -41
  174. bloqade/squin/op/_dialect.py +0 -3
  175. bloqade/squin/op/_wrapper.py +0 -121
  176. bloqade/squin/op/number.py +0 -5
  177. bloqade/squin/op/rewrite.py +0 -46
  178. bloqade/squin/op/stdlib.py +0 -62
  179. bloqade/squin/op/stmts.py +0 -276
  180. bloqade/squin/op/traits.py +0 -43
  181. bloqade/squin/op/types.py +0 -26
  182. bloqade/squin/qubit.py +0 -184
  183. bloqade/squin/rewrite/canonicalize.py +0 -60
  184. bloqade/squin/rewrite/desugar.py +0 -124
  185. bloqade/squin/types.py +0 -8
  186. bloqade/squin/wire.py +0 -201
  187. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  188. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  189. bloqade_circuit-0.6.4.dist-info/RECORD +0 -234
  190. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  191. {bloqade_circuit-0.6.4.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,116 @@
1
+ from dataclasses import field, dataclass
2
+
3
+ from kirin import ir, passes, rewrite
4
+ from kirin.analysis import CallGraph
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+ from kirin.dialects.func.stmts import Invoke
7
+
8
+
@dataclass
class ReplaceMethods(RewriteRule):
    """Retarget `func.Invoke` statements according to a method mapping.

    Every `Invoke` whose callee is a key of `new_symbols` (and maps to a
    different method) is replaced by a new `Invoke` with the same inputs and
    purity that calls the mapped method instead.
    """

    # Mapping from the method currently invoked to the method that should be invoked.
    new_symbols: dict[ir.Method, ir.Method]

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, Invoke):
            return RewriteResult()

        new_callee = self.new_symbols.get(node.callee)
        # Identity mappings (e.g. the entry method mapped to itself) need no
        # rewrite: rebuilding the statement would only report a spurious change.
        if new_callee is None or new_callee is node.callee:
            return RewriteResult()

        node.replace_by(
            Invoke(
                inputs=node.inputs,
                callee=new_callee,
                purity=node.purity,
            )
        )

        return RewriteResult(has_done_something=True)
29
+
30
+
@dataclass
class UpdateDialectsOnCallGraph(passes.Pass):
    """Update all dialects on the call graph to the dialect group given to this pass.

    Usage:
        pass_ = UpdateDialectsOnCallGraph(dialects=new_dialects)
        pass_(some_method)

    Note: This pass does not update the dialects of the input method, but copies
    all other methods invoked within it (via ``similar(self.dialects)``) before
    updating their dialects. The returned result also reflects the changes made
    by the fold pass that runs on each method.
    """

    # Fold pass built from this pass's dialects; run on every method after retargeting.
    fold_pass: passes.Fold = field(init=False)

    def __post_init__(self):
        self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        cg = CallGraph(mt)

        # Flatten the method collections stored in `cg.defs` in linear time
        # (`sum(..., ())` re-copies the accumulator tuple on every step).
        all_methods = set().union(*cg.defs.values())

        # Keep the entry method itself; clone every other method into the new
        # dialect group so the originals are left untouched.
        mt_map: dict[ir.Method, ir.Method] = {
            original_mt: (
                original_mt if original_mt is mt else original_mt.similar(self.dialects)
            )
            for original_mt in all_methods
        }

        result = RewriteResult()
        for new_mt in mt_map.values():
            # Point every Invoke at the cloned callees.
            result = (
                rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result)
            )
            # NOTE(review): this also folds `mt`, which keeps its original
            # dialects, using the new dialect group — confirm that is intended.
            result = self.fold_pass(new_mt).join(result)

        return result
71
+
72
+
@dataclass
class CallGraphPass(passes.Pass):
    """Apply a rewrite rule to every method reachable in the call graph.

    Usage:
        rule = Walk(SomeRewriteRule())
        pass_ = CallGraphPass(rule=rule, dialects=...)
        pass_(some_method)

    The input method is rewritten in place. Every other method listed in the
    call graph is first cloned with ``similar()`` and the rule is applied to the
    clone. When the rule reports a change anywhere, all Invoke statements are
    retargeted to the clones and each method is run through the fold pass.
    """

    rule: RewriteRule
    """The rule to apply to each function in the call graph."""

    # Fold pass built from this pass's dialects; run after retargeting.
    fold_pass: passes.Fold = field(init=False)

    def __post_init__(self):
        self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        outcome = RewriteResult()
        copies: dict[ir.Method, ir.Method] = {}

        for callee in set(CallGraph(mt).edges):
            # Only the entry method is modified in place; callees are cloned first.
            target = callee if callee is mt else callee.similar()
            outcome = self.rule.rewrite(target.code).join(outcome)
            copies[callee] = target

        if not outcome.has_done_something:
            return outcome

        for target in copies.values():
            rewrite.Walk(ReplaceMethods(copies)).rewrite(target.code)
            self.fold_pass(target)

        return outcome
@@ -1,28 +1,34 @@
1
- from dataclasses import dataclass
1
+ from dataclasses import field, dataclass
2
2
 
3
- from kirin import ir
4
- from kirin.passes import Pass
3
+ from kirin import ir, passes
5
4
  from kirin.rewrite import (
6
5
  Walk,
7
6
  Chain,
8
7
  Fixpoint,
8
+ DeadCodeElimination,
9
9
  )
10
- from kirin.analysis import const
11
-
12
- from ..rules.flatten_ilist import FlattenAddOpIList
13
- from ..rules.inline_getitem_ilist import InlineGetItemFromIList
10
+ from kirin.dialects.ilist import rewrite
14
11
 
15
12
 
@dataclass
class CanonicalizeIList(passes.Pass):
    """Canonicalize IList usage in a method.

    Runs the ilist rewrites ``InlineGetItem``, ``FlattenAdd`` and ``HintLen``
    together with ``DeadCodeElimination`` to a fixpoint, then runs the fold
    pass. The returned result combines both stages.
    """

    # Fold pass built from this pass's dialects; runs after the rewrite fixpoint.
    fold_pass: passes.Fold = field(init=False)

    def __post_init__(self):
        self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

    def unsafe_run(self, mt: ir.Method):
        rules = Chain(
            rewrite.InlineGetItem(),
            rewrite.FlattenAdd(),
            rewrite.HintLen(),
            DeadCodeElimination(),
        )
        canonicalized = Fixpoint(Walk(rules)).rewrite(mt.code)

        return self.fold_pass(mt).join(canonicalized)
@@ -46,9 +46,13 @@ class SplitIfStmts(RewriteRule):
46
46
  if not isinstance(node, scf.IfElse):
47
47
  return RewriteResult()
48
48
 
49
+ # NOTE: only empty else bodies are allowed in valid QASM2
50
+ if not self._has_empty_else(node):
51
+ return RewriteResult()
52
+
49
53
  *stmts, yield_or_return = node.then_body.stmts()
50
54
 
51
- if len(stmts) == 1:
55
+ if len(stmts) <= 1:
52
56
  return RewriteResult()
53
57
 
54
58
  is_yield = isinstance(yield_or_return, scf.Yield)
@@ -71,3 +75,16 @@ class SplitIfStmts(RewriteRule):
71
75
  node.delete()
72
76
 
73
77
  return RewriteResult(has_done_something=True)
78
+
def _has_empty_else(self, node: scf.IfElse) -> bool:
    """Return True when the else branch of ``node`` does nothing.

    An else body counts as empty when it has no statements, or when its only
    statement is an ``scf.Yield`` that yields no values.
    """
    else_stmts = list(node.else_body.stmts())
    if not else_stmts:
        return True

    if len(else_stmts) != 1:
        return False

    (only_stmt,) = else_stmts
    return isinstance(only_stmt, scf.Yield) and len(only_stmt.values) == 0
bloqade/squin/__init__.py CHANGED
@@ -1,19 +1,52 @@
1
1
  from . import (
2
- op as op,
3
- wire as wire,
2
+ gate as gate,
4
3
  noise as noise,
5
- qubit as qubit,
6
4
  analysis as analysis,
7
- lowering as lowering,
8
- _typeinfer as _typeinfer,
9
5
  )
10
- from .groups import wired as wired, kernel as kernel
6
+ from .. import qubit as qubit, annotate as annotate
7
+ from ..qubit import (
8
+ reset as reset,
9
+ qalloc as qalloc,
10
+ measure as measure,
11
+ get_qubit_id as get_qubit_id,
12
+ get_measurement_id as get_measurement_id,
13
+ )
14
+ from .groups import kernel as kernel
15
+ from ..annotate import set_detector as set_detector, set_observable as set_observable
16
+ from .stdlib.simple import (
17
+ h as h,
18
+ s as s,
19
+ t as t,
20
+ x as x,
21
+ y as y,
22
+ z as z,
23
+ cx as cx,
24
+ cy as cy,
25
+ cz as cz,
26
+ rx as rx,
27
+ ry as ry,
28
+ rz as rz,
29
+ u3 as u3,
30
+ s_adj as s_adj,
31
+ shift as shift,
32
+ t_adj as t_adj,
33
+ sqrt_x as sqrt_x,
34
+ sqrt_y as sqrt_y,
35
+ sqrt_z as sqrt_z,
36
+ bit_flip as bit_flip,
37
+ depolarize as depolarize,
38
+ qubit_loss as qubit_loss,
39
+ sqrt_x_adj as sqrt_x_adj,
40
+ sqrt_y_adj as sqrt_y_adj,
41
+ sqrt_z_adj as sqrt_z_adj,
42
+ depolarize2 as depolarize2,
43
+ correlated_qubit_loss as correlated_qubit_loss,
44
+ two_qubit_pauli_channel as two_qubit_pauli_channel,
45
+ single_qubit_pauli_channel as single_qubit_pauli_channel,
46
+ )
11
47
 
12
- try:
13
- # NOTE: make sure optional cirq dependency is installed
14
- import cirq as cirq_package # noqa: F401
15
- except ImportError:
16
- pass
17
- else:
18
- from . import cirq as cirq
19
- from .cirq import load_circuit as load_circuit
48
+ # NOTE: it's important to keep these imports here since they import squin.kernel
49
+ # we skip isort here
50
+ from .stdlib import ( # isort: skip
51
+ broadcast as broadcast,
52
+ )
@@ -1 +0,0 @@
1
- from . import address_impl as address_impl
@@ -185,18 +185,17 @@ class DagScheduleAnalysis(Forward[GateSchedule]):
185
185
  self.stmt_dag = StmtDag()
186
186
  self.use_def = {}
187
187
 
188
- def run_method(self, method: ir.Method, args: tuple[GateSchedule, ...]):
189
- # NOTE: we do not support dynamic calls here, thus no need to propagate method object
190
- return self.run_callable(method.code, (self.lattice.bottom(),) + args)
def method_self(self, method: ir.Method) -> GateSchedule:
    """Lattice value bound to the method object itself.

    Dynamic calls are not supported by this analysis, so the method object
    carries no schedule information and is always the lattice bottom.
    """
    bottom = self.lattice.bottom()
    return bottom
191
190
 
192
- def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
193
- if stmt.has_trait(ir.IsTerminator):
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
    """Handle statements that have no dedicated implementation.

    For a terminator, the DAG of its parent block is pushed via
    ``push_current_dag``. Every result of ``node`` is mapped to the lattice top.
    """
    if node.has_trait(ir.IsTerminator):
        block = node.parent_block
        assert block is not None, "Terminator statement has no parent block"
        self.push_current_dag(block)

    return tuple(self.lattice.top() for _ in node.results)
200
199
 
201
200
  def _update_dag(self, stmt: ir.Statement, addr: address.Address):
202
201
  if isinstance(addr, address.AddressQubit):
@@ -210,8 +209,8 @@ class DagScheduleAnalysis(Forward[GateSchedule]):
210
209
  if old_stmt is not None:
211
210
  self.stmt_dag.add_edge(old_stmt, stmt)
212
211
  self.use_def[idx] = stmt
213
- elif isinstance(addr, address.AddressTuple):
214
- for sub_addr in addr.data:
212
+ elif isinstance(addr, address.AddressReg):
213
+ for sub_addr in addr.qubits:
215
214
  self._update_dag(stmt, sub_addr)
216
215
 
217
216
  def update_dag(self, stmt: ir.Statement, args: Sequence[ir.SSAValue]):
@@ -226,7 +225,7 @@ class DagScheduleAnalysis(Forward[GateSchedule]):
226
225
  if args is None:
227
226
  args = tuple(self.lattice.top() for _ in mt.args)
228
227
 
229
- self.run(mt, args, kwargs)
228
+ self.run(mt)
230
229
  return self.stmt_dags
231
230
 
232
231
 
@@ -0,0 +1,2 @@
1
+ from . import stmts as stmts
2
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
# Kirin dialect named "squin.gate"; the gate statements are registered with it.
dialect = ir.Dialect(name="squin.gate")
@@ -0,0 +1,98 @@
1
+ from typing import Any, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+ from kirin.lowering import wraps
5
+
6
+ from bloqade.types import Qubit
7
+
8
+ from .stmts import (
9
+ CX,
10
+ CY,
11
+ CZ,
12
+ U3,
13
+ H,
14
+ S,
15
+ T,
16
+ X,
17
+ Y,
18
+ Z,
19
+ Rx,
20
+ Ry,
21
+ Rz,
22
+ SqrtX,
23
+ SqrtY,
24
+ )
25
+
26
+
@wraps(X)
def x(qubits: ilist.IList[Qubit, Any]) -> None:
    """Wrapper that lowers to an ``X`` statement on ``qubits``."""


@wraps(Y)
def y(qubits: ilist.IList[Qubit, Any]) -> None:
    """Wrapper that lowers to a ``Y`` statement on ``qubits``."""


@wraps(Z)
def z(qubits: ilist.IList[Qubit, Any]) -> None:
    """Wrapper that lowers to a ``Z`` statement on ``qubits``."""


@wraps(H)
def h(qubits: ilist.IList[Qubit, Any]) -> None:
    """Wrapper that lowers to an ``H`` statement on ``qubits``."""


# NOTE: ``adjoint`` defaults to False to match the ``adjoint`` attribute default
# declared on the wrapped statements, so callers may omit it.
@wraps(T)
def t(qubits: ilist.IList[Qubit, Any], *, adjoint: bool = False) -> None:
    """Wrapper that lowers to a ``T`` statement on ``qubits`` (``adjoint`` is forwarded)."""


@wraps(S)
def s(qubits: ilist.IList[Qubit, Any], *, adjoint: bool = False) -> None:
    """Wrapper that lowers to an ``S`` statement on ``qubits`` (``adjoint`` is forwarded)."""


@wraps(SqrtX)
def sqrt_x(qubits: ilist.IList[Qubit, Any], *, adjoint: bool = False) -> None:
    """Wrapper that lowers to a ``SqrtX`` statement on ``qubits`` (``adjoint`` is forwarded)."""


@wraps(SqrtY)
def sqrt_y(qubits: ilist.IList[Qubit, Any], *, adjoint: bool = False) -> None:
    """Wrapper that lowers to a ``SqrtY`` statement on ``qubits`` (``adjoint`` is forwarded)."""


@wraps(Rx)
def rx(angle: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Wrapper that lowers to an ``Rx`` statement with ``angle`` on ``qubits``."""


@wraps(Ry)
def ry(angle: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Wrapper that lowers to an ``Ry`` statement with ``angle`` on ``qubits``."""


@wraps(Rz)
def rz(angle: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Wrapper that lowers to an ``Rz`` statement with ``angle`` on ``qubits``."""
69
+
70
+
# Shared length variable: controls and targets must have the same length.
Len = TypeVar("Len", bound=int)


@wraps(CX)
def cx(
    controls: ilist.IList[Qubit, Len],
    targets: ilist.IList[Qubit, Len],
) -> None:
    """Wrapper that lowers to a ``CX`` statement pairing ``controls`` with ``targets``."""


@wraps(CY)
def cy(
    controls: ilist.IList[Qubit, Len],
    targets: ilist.IList[Qubit, Len],
) -> None:
    """Wrapper that lowers to a ``CY`` statement pairing ``controls`` with ``targets``."""


@wraps(CZ)
def cz(
    controls: ilist.IList[Qubit, Len],
    targets: ilist.IList[Qubit, Len],
) -> None:
    """Wrapper that lowers to a ``CZ`` statement pairing ``controls`` with ``targets``."""


@wraps(U3)
def u3(
    theta: float, phi: float, lam: float, qubits: ilist.IList[Qubit, Any]
) -> None:
    """Wrapper that lowers to a ``U3`` statement with angles ``theta``, ``phi``, ``lam``."""
@@ -0,0 +1,125 @@
1
+ from kirin import ir, types, lowering
2
+ from kirin.decl import info, statement
3
+ from kirin.dialects import ilist
4
+
5
+ from bloqade.types import QubitType
6
+
7
+ from ._dialect import dialect
8
+
9
+
@statement
class Gate(ir.Statement):
    """Common base of all gate statements.

    It declares no fields; it exists so that ``isinstance(stmt, Gate)`` can
    identify gate statements elsewhere.
    """


@statement
class SingleQubitGate(Gate):
    """Base for gates applied to a list of qubits, lowered from a Python call."""

    traits = frozenset({lowering.FromPythonCall()})
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


@statement(dialect=dialect)
class X(SingleQubitGate):
    """X gate statement."""


@statement(dialect=dialect)
class Y(SingleQubitGate):
    """Y gate statement."""


@statement(dialect=dialect)
class Z(SingleQubitGate):
    """Z gate statement."""


@statement(dialect=dialect)
class H(SingleQubitGate):
    """H gate statement."""


@statement
class SingleQubitNonHermitianGate(SingleQubitGate):
    """Base for single-qubit gates that carry an ``adjoint`` flag (default False)."""

    adjoint: bool = info.attribute(default=False)


@statement(dialect=dialect)
class T(SingleQubitNonHermitianGate):
    """T gate statement."""


@statement(dialect=dialect)
class S(SingleQubitNonHermitianGate):
    """S gate statement."""


@statement(dialect=dialect)
class SqrtX(SingleQubitNonHermitianGate):
    """SqrtX gate statement."""


@statement(dialect=dialect)
class SqrtY(SingleQubitNonHermitianGate):
    """SqrtY gate statement."""
65
+
66
+
@statement
class RotationGate(Gate):
    """Base for rotations taking an ``angle`` followed by the target ``qubits``.

    It deliberately does not inherit from ``SingleQubitGate`` so that the
    generated wrapper takes ``angle`` rather than ``qubits`` as its first argument.
    """

    traits = frozenset({lowering.FromPythonCall()})
    angle: ir.SSAValue = info.argument(types.Float)
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


@statement(dialect=dialect)
class Rx(RotationGate):
    """Rx rotation statement."""


@statement(dialect=dialect)
class Ry(RotationGate):
    """Ry rotation statement."""


@statement(dialect=dialect)
class Rz(RotationGate):
    """Rz rotation statement."""
88
+
89
+
# Shared length variable tying `controls` and `targets` to the same length.
N = types.TypeVar("N", bound=types.Int)


@statement
class ControlledGate(Gate):
    """Base for two-qubit gates pairing ``controls`` with equally long ``targets``."""

    traits = frozenset({lowering.FromPythonCall()})
    controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
    targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])


@statement(dialect=dialect)
class CX(ControlledGate):
    """Controlled-X statement, with statement name ``cx``."""

    name = "cx"


@statement(dialect=dialect)
class CY(ControlledGate):
    """Controlled-Y statement, with statement name ``cy``."""

    name = "cy"


@statement(dialect=dialect)
class CZ(ControlledGate):
    """Controlled-Z statement, with statement name ``cz``."""

    name = "cz"


@statement(dialect=dialect)
class U3(Gate):
    """U3 statement with angles ``theta``, ``phi``, ``lam`` applied to ``qubits``.

    It does not inherit from ``SingleQubitGate`` so that the generated wrapper
    takes the angles before ``qubits``.
    """

    traits = frozenset({lowering.FromPythonCall()})
    theta: ir.SSAValue = info.argument(types.Float)
    phi: ir.SSAValue = info.argument(types.Float)
    lam: ir.SSAValue = info.argument(types.Float)
    qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
bloqade/squin/groups.py CHANGED
@@ -1,31 +1,24 @@
1
1
  from kirin import ir, passes
2
2
  from kirin.prelude import structural_no_opt
3
- from kirin.rewrite import Walk, Chain
4
- from kirin.dialects import ilist
3
+ from kirin.dialects import debug, ilist
5
4
 
6
- from . import op, wire, noise, qubit
7
- from .op.rewrite import PyMultToSquinMult
8
- from .rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule
5
+ from . import gate, noise
6
+ from .. import qubit, annotate
9
7
 
10
8
 
11
- @ir.dialect_group(structural_no_opt.union([op, qubit, noise]))
9
+ @ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug, annotate]))
12
10
  def kernel(self):
13
11
  fold_pass = passes.Fold(self)
14
12
  typeinfer_pass = passes.TypeInfer(self)
15
13
  ilist_desugar_pass = ilist.IListDesugar(self)
16
- desugar_pass = Walk(Chain(MeasureDesugarRule(), ApplyDesugarRule()))
17
- py_mult_to_mult_pass = PyMultToSquinMult(self)
18
14
 
19
15
  def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
20
16
  method.verify()
21
17
  if fold:
22
18
  fold_pass.fixpoint(method)
23
19
 
24
- py_mult_to_mult_pass(method)
25
-
26
20
  if typeinfer:
27
- typeinfer_pass(method)
28
- desugar_pass.rewrite(method.code)
21
+ typeinfer_pass(method) # infer types before desugaring
29
22
 
30
23
  ilist_desugar_pass(method)
31
24
 
@@ -34,13 +27,3 @@ def kernel(self):
34
27
  method.verify_type()
35
28
 
36
29
  return run_pass
37
-
38
-
39
- @ir.dialect_group(structural_no_opt.union([op, wire, noise]))
40
- def wired(self):
41
- py_mult_to_mult_pass = PyMultToSquinMult(self)
42
-
43
- def run_pass(method):
44
- py_mult_to_mult_pass(method)
45
-
46
- return run_pass
@@ -1,11 +1,2 @@
1
- from . import stmts as stmts
1
+ from . import stmts as stmts, _interface as _interface
2
2
  from ._dialect import dialect as dialect
3
- from ._wrapper import (
4
- pp_error as pp_error,
5
- depolarize as depolarize,
6
- qubit_loss as qubit_loss,
7
- depolarize2 as depolarize2,
8
- pauli_error as pauli_error,
9
- two_qubit_pauli_channel as two_qubit_pauli_channel,
10
- single_qubit_pauli_channel as single_qubit_pauli_channel,
11
- )
@@ -1,3 +1,3 @@
1
1
  from kirin import ir
2
2
 
3
- dialect = ir.Dialect(name="squin.noise")
# Kirin dialect named "squin.noise".
dialect = ir.Dialect(name="squin.noise")
@@ -0,0 +1,45 @@
1
+ from typing import Any, Literal, TypeVar
2
+
3
+ from kirin.dialects import ilist
4
+ from kirin.lowering import wraps
5
+
6
+ from bloqade.types import Qubit
7
+
8
+ from . import stmts
9
+
10
+
@wraps(stmts.Depolarize)
def depolarize(p: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Wrapper that lowers to ``stmts.Depolarize`` with parameter ``p`` on ``qubits``."""


# Shared length variable: paired qubit lists must have the same length.
N = TypeVar("N", bound=int)


@wraps(stmts.Depolarize2)
def depolarize2(
    p: float, controls: ilist.IList[Qubit, N], targets: ilist.IList[Qubit, N]
) -> None:
    """Wrapper that lowers to ``stmts.Depolarize2`` on paired ``controls``/``targets``."""


@wraps(stmts.SingleQubitPauliChannel)
def single_qubit_pauli_channel(
    px: float, py: float, pz: float, qubits: ilist.IList[Qubit, Any]
) -> None:
    """Wrapper that lowers to ``stmts.SingleQubitPauliChannel`` with ``px``, ``py``, ``pz``."""


@wraps(stmts.TwoQubitPauliChannel)
def two_qubit_pauli_channel(
    probabilities: ilist.IList[float, Literal[15]],
    controls: ilist.IList[Qubit, N],
    targets: ilist.IList[Qubit, N],
) -> None:
    """Wrapper that lowers to ``stmts.TwoQubitPauliChannel``.

    ``probabilities`` is annotated as a list of exactly 15 floats; ``controls``
    and ``targets`` share the length variable ``N``.
    """


@wraps(stmts.QubitLoss)
def qubit_loss(p: float, qubits: ilist.IList[Qubit, Any]) -> None:
    """Wrapper that lowers to ``stmts.QubitLoss`` with parameter ``p`` on ``qubits``."""


@wraps(stmts.CorrelatedQubitLoss)
def correlated_qubit_loss(
    p: float, qubits: ilist.IList[ilist.IList[Qubit, N], Any]
) -> None:
    """Wrapper that lowers to ``stmts.CorrelatedQubitLoss``.

    ``qubits`` is a list of qubit groups that all share the length variable ``N``.
    """
+ ) -> None: ...