bloqade-circuit 0.6.2__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192):
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +123 -33
  3. bloqade/analysis/address/impls.py +293 -90
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +11 -23
  6. bloqade/analysis/measure_id/__init__.py +4 -1
  7. bloqade/analysis/measure_id/analysis.py +29 -20
  8. bloqade/analysis/measure_id/impls.py +72 -31
  9. bloqade/annotate/__init__.py +6 -0
  10. bloqade/annotate/_dialect.py +3 -0
  11. bloqade/annotate/_interface.py +22 -0
  12. bloqade/annotate/stmts.py +29 -0
  13. bloqade/annotate/types.py +13 -0
  14. bloqade/cirq_utils/__init__.py +4 -2
  15. bloqade/cirq_utils/emit/__init__.py +3 -0
  16. bloqade/cirq_utils/emit/base.py +246 -0
  17. bloqade/cirq_utils/emit/gate.py +104 -0
  18. bloqade/cirq_utils/emit/noise.py +90 -0
  19. bloqade/cirq_utils/emit/qubit.py +35 -0
  20. bloqade/cirq_utils/lowering.py +660 -0
  21. bloqade/cirq_utils/noise/__init__.py +0 -2
  22. bloqade/cirq_utils/noise/_two_zone_utils.py +7 -15
  23. bloqade/cirq_utils/noise/model.py +151 -191
  24. bloqade/cirq_utils/noise/transform.py +2 -2
  25. bloqade/cirq_utils/parallelize.py +9 -6
  26. bloqade/gemini/__init__.py +1 -0
  27. bloqade/gemini/analysis/__init__.py +3 -0
  28. bloqade/gemini/analysis/logical_validation/__init__.py +1 -0
  29. bloqade/gemini/analysis/logical_validation/analysis.py +17 -0
  30. bloqade/gemini/analysis/logical_validation/impls.py +101 -0
  31. bloqade/gemini/groups.py +67 -0
  32. bloqade/native/__init__.py +23 -0
  33. bloqade/native/_prelude.py +45 -0
  34. bloqade/native/dialects/__init__.py +0 -0
  35. bloqade/native/dialects/gate/__init__.py +2 -0
  36. bloqade/native/dialects/gate/_dialect.py +3 -0
  37. bloqade/native/dialects/gate/_interface.py +32 -0
  38. bloqade/native/dialects/gate/stmts.py +31 -0
  39. bloqade/native/stdlib/__init__.py +0 -0
  40. bloqade/native/stdlib/broadcast.py +246 -0
  41. bloqade/native/stdlib/simple.py +220 -0
  42. bloqade/native/upstream/__init__.py +4 -0
  43. bloqade/native/upstream/squin2native.py +79 -0
  44. bloqade/pyqrack/__init__.py +2 -2
  45. bloqade/pyqrack/base.py +7 -1
  46. bloqade/pyqrack/device.py +190 -4
  47. bloqade/pyqrack/native.py +49 -0
  48. bloqade/pyqrack/reg.py +6 -6
  49. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  50. bloqade/pyqrack/squin/gate/gate.py +136 -0
  51. bloqade/pyqrack/squin/noise/native.py +120 -54
  52. bloqade/pyqrack/squin/qubit.py +39 -36
  53. bloqade/pyqrack/target.py +5 -4
  54. bloqade/pyqrack/task.py +114 -7
  55. bloqade/qasm2/_qasm_loading.py +3 -3
  56. bloqade/qasm2/dialects/core/address.py +21 -12
  57. bloqade/qasm2/dialects/expr/_emit.py +19 -8
  58. bloqade/qasm2/dialects/expr/stmts.py +7 -7
  59. bloqade/qasm2/dialects/noise/fidelity.py +4 -8
  60. bloqade/qasm2/dialects/noise/model.py +2 -1
  61. bloqade/qasm2/emit/base.py +16 -11
  62. bloqade/qasm2/emit/gate.py +11 -8
  63. bloqade/qasm2/emit/main.py +103 -3
  64. bloqade/qasm2/emit/target.py +9 -5
  65. bloqade/qasm2/groups.py +3 -2
  66. bloqade/qasm2/parse/lowering.py +0 -1
  67. bloqade/qasm2/passes/fold.py +14 -73
  68. bloqade/qasm2/passes/glob.py +2 -2
  69. bloqade/qasm2/passes/noise.py +1 -1
  70. bloqade/qasm2/passes/parallel.py +7 -5
  71. bloqade/qasm2/rewrite/__init__.py +0 -1
  72. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  73. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  74. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  75. bloqade/qasm2/rewrite/register.py +2 -2
  76. bloqade/qasm2/rewrite/uop_to_parallel.py +4 -2
  77. bloqade/qbraid/lowering.py +1 -0
  78. bloqade/qbraid/schema.py +2 -2
  79. bloqade/qubit/__init__.py +12 -0
  80. bloqade/qubit/_dialect.py +3 -0
  81. bloqade/qubit/_interface.py +49 -0
  82. bloqade/qubit/_prelude.py +45 -0
  83. bloqade/qubit/analysis/__init__.py +1 -0
  84. bloqade/qubit/analysis/address_impl.py +40 -0
  85. bloqade/qubit/stdlib/__init__.py +2 -0
  86. bloqade/qubit/stdlib/_new.py +34 -0
  87. bloqade/qubit/stdlib/broadcast.py +62 -0
  88. bloqade/qubit/stdlib/simple.py +59 -0
  89. bloqade/qubit/stmts.py +60 -0
  90. bloqade/rewrite/passes/__init__.py +6 -0
  91. bloqade/rewrite/passes/aggressive_unroll.py +103 -0
  92. bloqade/rewrite/passes/callgraph.py +116 -0
  93. bloqade/rewrite/passes/canonicalize_ilist.py +20 -14
  94. bloqade/rewrite/rules/split_ifs.py +18 -1
  95. bloqade/squin/__init__.py +47 -14
  96. bloqade/squin/analysis/__init__.py +0 -1
  97. bloqade/squin/analysis/schedule.py +10 -11
  98. bloqade/squin/gate/__init__.py +2 -0
  99. bloqade/squin/gate/_dialect.py +3 -0
  100. bloqade/squin/gate/_interface.py +98 -0
  101. bloqade/squin/gate/stmts.py +125 -0
  102. bloqade/squin/groups.py +5 -22
  103. bloqade/squin/noise/__init__.py +1 -10
  104. bloqade/squin/noise/_dialect.py +1 -1
  105. bloqade/squin/noise/_interface.py +45 -0
  106. bloqade/squin/noise/stmts.py +66 -28
  107. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  108. bloqade/squin/rewrite/__init__.py +0 -2
  109. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  110. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  111. bloqade/squin/stdlib/__init__.py +0 -0
  112. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  113. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  114. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  115. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  116. bloqade/squin/stdlib/simple/__init__.py +33 -0
  117. bloqade/squin/stdlib/simple/gate.py +242 -0
  118. bloqade/squin/stdlib/simple/noise.py +126 -0
  119. bloqade/stim/__init__.py +1 -0
  120. bloqade/stim/_wrappers.py +6 -0
  121. bloqade/stim/dialects/auxiliary/emit.py +19 -18
  122. bloqade/stim/dialects/collapse/emit_str.py +7 -8
  123. bloqade/stim/dialects/gate/emit.py +9 -10
  124. bloqade/stim/dialects/noise/emit.py +17 -13
  125. bloqade/stim/dialects/noise/stmts.py +5 -3
  126. bloqade/stim/emit/__init__.py +1 -0
  127. bloqade/stim/emit/impls.py +16 -0
  128. bloqade/stim/emit/stim_str.py +48 -31
  129. bloqade/stim/groups.py +12 -2
  130. bloqade/stim/parse/lowering.py +14 -17
  131. bloqade/stim/passes/__init__.py +3 -1
  132. bloqade/stim/passes/flatten.py +26 -0
  133. bloqade/stim/passes/simplify_ifs.py +16 -2
  134. bloqade/stim/passes/squin_to_stim.py +18 -60
  135. bloqade/stim/rewrite/__init__.py +3 -4
  136. bloqade/stim/rewrite/get_record_util.py +24 -0
  137. bloqade/stim/rewrite/ifs_to_stim.py +29 -31
  138. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  139. bloqade/stim/rewrite/set_detector_to_stim.py +68 -0
  140. bloqade/stim/rewrite/set_observable_to_stim.py +52 -0
  141. bloqade/stim/rewrite/squin_measure.py +11 -79
  142. bloqade/stim/rewrite/squin_noise.py +134 -108
  143. bloqade/stim/rewrite/util.py +5 -192
  144. bloqade/test_utils.py +1 -1
  145. bloqade/types.py +10 -0
  146. bloqade/validation/__init__.py +2 -0
  147. bloqade/validation/analysis/__init__.py +5 -0
  148. bloqade/validation/analysis/analysis.py +41 -0
  149. bloqade/validation/analysis/lattice.py +58 -0
  150. bloqade/validation/kernel_validation.py +77 -0
  151. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/METADATA +5 -6
  152. bloqade_circuit-0.9.1.dist-info/RECORD +265 -0
  153. bloqade/pyqrack/squin/op.py +0 -166
  154. bloqade/pyqrack/squin/runtime.py +0 -535
  155. bloqade/pyqrack/squin/wire.py +0 -51
  156. bloqade/rewrite/rules/flatten_ilist.py +0 -51
  157. bloqade/rewrite/rules/inline_getitem_ilist.py +0 -31
  158. bloqade/squin/_typeinfer.py +0 -20
  159. bloqade/squin/analysis/address_impl.py +0 -71
  160. bloqade/squin/analysis/nsites/__init__.py +0 -9
  161. bloqade/squin/analysis/nsites/analysis.py +0 -50
  162. bloqade/squin/analysis/nsites/impls.py +0 -92
  163. bloqade/squin/analysis/nsites/lattice.py +0 -49
  164. bloqade/squin/cirq/__init__.py +0 -265
  165. bloqade/squin/cirq/emit/emit_circuit.py +0 -109
  166. bloqade/squin/cirq/emit/noise.py +0 -49
  167. bloqade/squin/cirq/emit/op.py +0 -125
  168. bloqade/squin/cirq/emit/qubit.py +0 -60
  169. bloqade/squin/cirq/emit/runtime.py +0 -242
  170. bloqade/squin/cirq/lowering.py +0 -440
  171. bloqade/squin/lowering.py +0 -54
  172. bloqade/squin/noise/_wrapper.py +0 -40
  173. bloqade/squin/noise/rewrite.py +0 -111
  174. bloqade/squin/op/__init__.py +0 -41
  175. bloqade/squin/op/_dialect.py +0 -3
  176. bloqade/squin/op/_wrapper.py +0 -121
  177. bloqade/squin/op/number.py +0 -5
  178. bloqade/squin/op/rewrite.py +0 -46
  179. bloqade/squin/op/stdlib.py +0 -62
  180. bloqade/squin/op/stmts.py +0 -276
  181. bloqade/squin/op/traits.py +0 -43
  182. bloqade/squin/op/types.py +0 -26
  183. bloqade/squin/qubit.py +0 -184
  184. bloqade/squin/rewrite/canonicalize.py +0 -60
  185. bloqade/squin/rewrite/desugar.py +0 -124
  186. bloqade/squin/types.py +0 -8
  187. bloqade/squin/wire.py +0 -201
  188. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  189. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  190. bloqade_circuit-0.6.2.dist-info/RECORD +0 -234
  191. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/WHEEL +0 -0
  192. {bloqade_circuit-0.6.2.dist-info → bloqade_circuit-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,15 @@
1
1
  from . import impls as impls
2
2
  from .lattice import (
3
+ Bottom as Bottom,
3
4
  Address as Address,
4
- NotQubit as NotQubit,
5
+ Unknown as Unknown,
5
6
  AddressReg as AddressReg,
6
- AnyAddress as AnyAddress,
7
- AddressWire as AddressWire,
7
+ UnknownReg as UnknownReg,
8
+ ConstResult as ConstResult,
8
9
  AddressQubit as AddressQubit,
9
- AddressTuple as AddressTuple,
10
+ PartialIList as PartialIList,
11
+ PartialTuple as PartialTuple,
12
+ UnknownQubit as UnknownQubit,
13
+ PartialLambda as PartialLambda,
10
14
  )
11
15
  from .analysis import AddressAnalysis as AddressAnalysis
@@ -1,13 +1,13 @@
1
- from typing import TypeVar
1
+ from typing import Any, Type, TypeVar
2
2
  from dataclasses import field
3
3
 
4
- from kirin import ir, interp
4
+ from kirin import ir, types, interp
5
5
  from kirin.analysis import Forward, const
6
+ from kirin.dialects.ilist import IList
6
7
  from kirin.analysis.forward import ForwardFrame
8
+ from kirin.analysis.const.lattice import PartialLambda
7
9
 
8
- from bloqade.types import QubitType
9
-
10
- from .lattice import Address
10
+ from .lattice import Address, AddressReg, ConstResult, PartialIList, PartialTuple
11
11
 
12
12
 
13
13
  class AddressAnalysis(Forward[Address]):
@@ -15,13 +15,16 @@ class AddressAnalysis(Forward[Address]):
15
15
  This analysis pass can be used to track the global addresses of qubits and wires.
16
16
  """
17
17
 
18
- keys = ["qubit.address"]
18
+ keys = ("qubit.address",)
19
+ _const_prop: const.Propagate
19
20
  lattice = Address
20
21
  next_address: int = field(init=False)
21
22
 
22
23
  def initialize(self):
23
24
  super().initialize()
24
25
  self.next_address: int = 0
26
+ self._const_prop = const.Propagate(self.dialects)
27
+ self._const_prop.initialize()
25
28
  return self
26
29
 
27
30
  @property
@@ -31,30 +34,117 @@ class AddressAnalysis(Forward[Address]):
31
34
 
32
35
  T = TypeVar("T")
33
36
 
34
- def get_const_value(self, typ: type[T], value: ir.SSAValue) -> T:
35
- if isinstance(hint := value.hints.get("const"), const.Value):
36
- data = hint.data
37
- if isinstance(data, typ):
38
- return hint.data
39
- raise interp.InterpreterError(
40
- f"Expected constant value <type = {typ}>, got {data}"
41
- )
42
- raise interp.InterpreterError(
43
- f"Expected constant value <type = {typ}>, got {value}"
44
- )
45
-
46
- def eval_stmt_fallback(
47
- self, frame: ForwardFrame[Address], stmt: ir.Statement
48
- ) -> tuple[Address, ...] | interp.SpecialValue[Address]:
49
- return tuple(
50
- (
51
- self.lattice.top()
52
- if result.type.is_subseteq(QubitType)
53
- else self.lattice.bottom()
54
- )
55
- for result in stmt.results
56
- )
57
-
58
- def run_method(self, method: ir.Method, args: tuple[Address, ...]):
59
- # NOTE: we do not support dynamic calls here, thus no need to propagate method object
60
- return self.run_callable(method.code, (self.lattice.bottom(),) + args)
37
+ def to_address(self, result: const.Result):
38
+ return ConstResult(result)
39
+
40
+ def try_eval_const_prop(
41
+ self,
42
+ frame: ForwardFrame[Address],
43
+ stmt: ir.Statement,
44
+ args: tuple[ConstResult, ...],
45
+ ) -> interp.StatementResult[Address]:
46
+ _frame = self._const_prop.initialize_frame(frame.code)
47
+ _frame.set_values(stmt.args, tuple(x.result for x in args))
48
+ result = self._const_prop.frame_eval(_frame, stmt)
49
+
50
+ match result:
51
+ case interp.ReturnValue(constant_ret):
52
+ return interp.ReturnValue(self.to_address(constant_ret))
53
+ case interp.YieldValue(constant_values):
54
+ return interp.YieldValue(tuple(map(self.to_address, constant_values)))
55
+ case interp.Successor(block, block_args):
56
+ return interp.Successor(block, *map(self.to_address, block_args))
57
+ case tuple():
58
+ return tuple(map(self.to_address, result))
59
+ case _:
60
+ return result
61
+
62
+ def unpack_iterable(self, iterable: Address):
63
+ """Extract the values of a container lattice element.
64
+
65
+ Args:
66
+ iterable: The lattice element representing a container.
67
+
68
+ Returns:
69
+ A tuple of the container type and the contained values.
70
+
71
+ """
72
+
73
+ def from_constant(constant: const.Result) -> Address:
74
+ return ConstResult(constant)
75
+
76
+ def from_literal(literal: Any) -> Address:
77
+ return ConstResult(const.Value(literal))
78
+
79
+ match iterable:
80
+ case PartialIList(data):
81
+ return PartialIList, data
82
+ case PartialTuple(data):
83
+ return PartialTuple, data
84
+ case AddressReg():
85
+ return PartialIList, iterable.qubits
86
+ case ConstResult(const.Value(IList() as data)):
87
+ return PartialIList, tuple(map(from_literal, data))
88
+ case ConstResult(const.Value(tuple() as data)):
89
+ return PartialTuple, tuple(map(from_literal, data))
90
+ case ConstResult(const.PartialTuple(data)):
91
+ return PartialTuple, tuple(map(from_constant, data))
92
+ case _:
93
+ return None, ()
94
+
95
+ def run_lattice(
96
+ self,
97
+ callee: Address,
98
+ inputs: tuple[Address, ...],
99
+ keys: tuple[str, ...],
100
+ kwargs: tuple[Address, ...],
101
+ ) -> Address:
102
+ """Run a callable lattice element with the given inputs and keyword arguments.
103
+
104
+ Args:
105
+ callee (Address): The lattice element representing the callable.
106
+ inputs (tuple[Address, ...]): The input lattice elements.
107
+ kwargs (tuple[str, ...]): The keyword argument names.
108
+
109
+ Returns:
110
+ Address: The resulting lattice element after invoking the callable.
111
+
112
+ """
113
+
114
+ match callee:
115
+ case PartialLambda(code=code):
116
+ _, ret = self.call(
117
+ code, callee, *inputs, **{k: v for k, v in zip(keys, kwargs)}
118
+ )
119
+ case ConstResult(const.Value(ir.Method() as method)):
120
+ _, ret = self.call(
121
+ method.code,
122
+ self.method_self(method),
123
+ *inputs,
124
+ **{k: v for k, v in zip(keys, kwargs)},
125
+ )
126
+ return ret
127
+ case _:
128
+ return Address.top()
129
+
130
+ def get_const_value(self, addr: Address, typ: Type[T]) -> T | None:
131
+ if not isinstance(addr, ConstResult):
132
+ return None
133
+
134
+ if not isinstance(result := addr.result, const.Value):
135
+ return None
136
+
137
+ if not isinstance(value := result.data, typ):
138
+ return None
139
+
140
+ return value
141
+
142
+ def eval_fallback(self, frame: ForwardFrame[Address], node: ir.Statement):
143
+ args = frame.get_values(node.args)
144
+ if types.is_tuple_of(args, ConstResult):
145
+ return self.try_eval_const_prop(frame, node, args)
146
+
147
+ return tuple(Address.from_type(result.type) for result in node.results)
148
+
149
+ def method_self(self, method: ir.Method) -> Address:
150
+ return ConstResult(const.Value(method))
@@ -2,32 +2,56 @@
2
2
  qubit.address method table for a few builtin dialects.
3
3
  """
4
4
 
5
- from kirin import interp
5
+ from itertools import chain
6
+
7
+ from kirin import ir, interp
6
8
  from kirin.analysis import ForwardFrame, const
7
9
  from kirin.dialects import cf, py, scf, func, ilist
8
10
 
9
11
  from .lattice import (
10
12
  Address,
11
- NotQubit,
12
- AddressReg,
13
- AddressQubit,
14
- AddressTuple,
13
+ ConstResult,
14
+ PartialIList,
15
+ PartialTuple,
16
+ PartialLambda,
15
17
  )
16
18
  from .analysis import AddressAnalysis
17
19
 
18
20
 
21
+ @py.constant.dialect.register(key="qubit.address")
22
+ class PyConstant(interp.MethodTable):
23
+ @interp.impl(py.Constant)
24
+ def constant(
25
+ self,
26
+ interp_: AddressAnalysis,
27
+ frame: ForwardFrame[Address],
28
+ stmt: py.Constant,
29
+ ):
30
+ return (ConstResult(const.Value(stmt.value.unwrap())),)
31
+
32
+
19
33
  @py.binop.dialect.register(key="qubit.address")
20
34
  class PyBinOp(interp.MethodTable):
21
-
22
35
  @interp.impl(py.Add)
23
- def add(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.Add):
36
+ def add(
37
+ self,
38
+ interp_: AddressAnalysis,
39
+ frame: ForwardFrame[Address],
40
+ stmt: py.Add,
41
+ ):
24
42
  lhs = frame.get(stmt.lhs)
25
43
  rhs = frame.get(stmt.rhs)
26
44
 
27
- if isinstance(lhs, AddressTuple) and isinstance(rhs, AddressTuple):
28
- return (AddressTuple(data=lhs.data + rhs.data),)
29
- else:
30
- return (NotQubit(),)
45
+ lhs_type, lhs_values = interp_.unpack_iterable(lhs)
46
+ rhs_type, rhs_values = interp_.unpack_iterable(rhs)
47
+
48
+ if lhs_type is None or rhs_type is None:
49
+ return (Address.top(),)
50
+
51
+ if lhs_type is not rhs_type:
52
+ return (Address.bottom(),)
53
+
54
+ return (lhs_type(tuple(chain(lhs_values, rhs_values))),)
31
55
 
32
56
 
33
57
  @py.tuple.dialect.register(key="qubit.address")
@@ -35,110 +59,299 @@ class PyTuple(interp.MethodTable):
35
59
  @interp.impl(py.tuple.New)
36
60
  def new_tuple(
37
61
  self,
38
- interp: AddressAnalysis,
39
- frame: interp.Frame,
62
+ interp_: AddressAnalysis,
63
+ frame: ForwardFrame[Address],
40
64
  stmt: py.tuple.New,
41
65
  ):
42
- return (AddressTuple(frame.get_values(stmt.args)),)
66
+ return (PartialTuple(frame.get_values(stmt.args)),)
43
67
 
44
68
 
45
69
  @ilist.dialect.register(key="qubit.address")
46
- class IList(interp.MethodTable):
70
+ class IListMethods(interp.MethodTable):
47
71
  @interp.impl(ilist.New)
48
72
  def new_ilist(
49
73
  self,
50
- interp: AddressAnalysis,
51
- frame: interp.Frame,
74
+ interp_: AddressAnalysis,
75
+ frame: ForwardFrame[Address],
52
76
  stmt: ilist.New,
53
77
  ):
54
- return (AddressTuple(frame.get_values(stmt.values)),)
55
-
78
+ return (PartialIList(frame.get_values(stmt.args)),)
56
79
 
57
- @py.list.dialect.register(key="qubit.address")
58
- class PyList(interp.MethodTable):
59
- @interp.impl(py.list.New)
60
- def new_ilist(
80
+ @interp.impl(ilist.ForEach)
81
+ @interp.impl(ilist.Map)
82
+ def map_(
61
83
  self,
62
- interp: AddressAnalysis,
63
- frame: interp.Frame,
64
- stmt: py.list.New,
84
+ interp_: AddressAnalysis,
85
+ frame: ForwardFrame[Address],
86
+ stmt: ilist.Map | ilist.ForEach,
65
87
  ):
66
- return (AddressTuple(frame.get_values(stmt.args)),)
88
+ fn = frame.get(stmt.fn)
89
+ collection = frame.get(stmt.collection)
90
+ collection_type, values = interp_.unpack_iterable(collection)
91
+
92
+ if collection_type is None:
93
+ return (Address.top(),)
94
+
95
+ if collection_type is not PartialIList:
96
+ return (Address.bottom(),)
97
+
98
+ results = []
99
+ for ele in values:
100
+ ret = interp_.run_lattice(fn, (ele,), (), ())
101
+ results.append(ret)
102
+
103
+ if isinstance(stmt, ilist.Map):
104
+ return (PartialIList(tuple(results)),)
105
+
106
+
107
+ @py.len.dialect.register(key="qubit.address")
108
+ class PyLen(interp.MethodTable):
109
+ @interp.impl(py.Len)
110
+ def len_(
111
+ self, interp_: AddressAnalysis, frame: ForwardFrame[Address], stmt: py.Len
112
+ ):
113
+ obj = frame.get(stmt.value)
114
+ _, values = interp_.unpack_iterable(obj)
115
+
116
+ if values is None:
117
+ return (Address.top(),)
118
+
119
+ return (ConstResult(const.Value(len(values))),)
67
120
 
68
121
 
69
122
  @py.indexing.dialect.register(key="qubit.address")
70
123
  class PyIndexing(interp.MethodTable):
71
124
  @interp.impl(py.GetItem)
72
- def getitem(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.GetItem):
73
-
125
+ def getitem(
126
+ self,
127
+ interp_: AddressAnalysis,
128
+ frame: ForwardFrame[Address],
129
+ stmt: py.GetItem,
130
+ ):
74
131
  # determine if the index is an int constant
75
132
  # or a slice
76
- hint = stmt.index.hints.get("const")
77
- if hint is None:
78
- return (NotQubit(),)
79
-
80
- if isinstance(hint, const.Value):
81
- idx = hint.data
82
- elif isinstance(hint, slice):
83
- idx = hint
84
- else:
85
- return (NotQubit(),)
86
-
87
- # The object being indexed into
88
133
  obj = frame.get(stmt.obj)
89
- # The `data` attributes holds onto other Address types
90
- # so we just extract that here
91
- if isinstance(obj, AddressTuple):
92
- return (obj.data[idx],)
93
- # If idx is an integer index into an AddressReg,
94
- # then it's safe to assume a single qubit is being accessed.
95
- # On the other hand, if it's a slice, we return
96
- # a new AddressReg to preserve the new sequence.
97
- elif isinstance(obj, AddressReg):
98
- if isinstance(idx, slice):
99
- return (AddressReg(data=obj.data[idx]),)
100
- if isinstance(idx, int):
101
- return (AddressQubit(obj.data[idx]),)
102
- else:
103
- return (NotQubit(),)
134
+ index = frame.get(stmt.index)
135
+
136
+ typ, values = interp_.unpack_iterable(obj)
137
+ if typ is None:
138
+ return (Address.top(),)
139
+
140
+ int_index = interp_.get_const_value(index, int)
141
+ if int_index is not None:
142
+ return (values[int_index],)
143
+
144
+ slice_index = interp_.get_const_value(index, slice)
145
+ if slice_index is not None:
146
+ return (typ(values[slice_index]),)
147
+
148
+ return (Address.top(),)
104
149
 
105
150
 
106
151
  @py.assign.dialect.register(key="qubit.address")
107
152
  class PyAssign(interp.MethodTable):
108
153
  @interp.impl(py.Alias)
109
- def alias(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.Alias):
154
+ def alias(
155
+ self,
156
+ interp: AddressAnalysis,
157
+ frame: ForwardFrame[Address],
158
+ stmt: py.Alias,
159
+ ):
110
160
  return (frame.get(stmt.value),)
111
161
 
112
162
 
163
+ # TODO: look for abstract method table for func.
113
164
  @func.dialect.register(key="qubit.address")
114
165
  class Func(interp.MethodTable):
115
166
  @interp.impl(func.Return)
116
- def return_(self, _: AddressAnalysis, frame: interp.Frame, stmt: func.Return):
167
+ def return_(
168
+ self,
169
+ _: AddressAnalysis,
170
+ frame: ForwardFrame[Address],
171
+ stmt: func.Return,
172
+ ):
117
173
  return interp.ReturnValue(frame.get(stmt.value))
118
174
 
119
175
  # TODO: replace with the generic implementation
120
176
  @interp.impl(func.Invoke)
121
- def invoke(self, interp_: AddressAnalysis, frame: interp.Frame, stmt: func.Invoke):
122
- _, ret = interp_.run_method(
123
- stmt.callee,
124
- interp_.permute_values(
125
- stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
126
- ),
177
+ def invoke(
178
+ self,
179
+ interp_: AddressAnalysis,
180
+ frame: ForwardFrame[Address],
181
+ stmt: func.Invoke,
182
+ ):
183
+ _, ret = interp_.call(
184
+ stmt.callee.code,
185
+ interp_.method_self(stmt.callee),
186
+ *frame.get_values(stmt.inputs),
127
187
  )
188
+
128
189
  return (ret,)
129
190
 
130
- # TODO: support lambda?
191
+ @interp.impl(func.Lambda)
192
+ def lambda_(
193
+ self,
194
+ inter_: AddressAnalysis,
195
+ frame: ForwardFrame[Address],
196
+ stmt: func.Lambda,
197
+ ):
198
+ arg_names = [
199
+ arg.name or str(idx) for idx, arg in enumerate(stmt.body.blocks[0].args)
200
+ ]
201
+ return (
202
+ PartialLambda(
203
+ arg_names,
204
+ stmt,
205
+ frame.get_values(stmt.captured),
206
+ ),
207
+ )
208
+
209
+ @interp.impl(func.Call)
210
+ def call(
211
+ self,
212
+ interp_: AddressAnalysis,
213
+ frame: ForwardFrame[Address],
214
+ stmt: func.Call,
215
+ ):
216
+ result = interp_.run_lattice(
217
+ frame.get(stmt.callee),
218
+ frame.get_values(stmt.inputs),
219
+ stmt.keys,
220
+ frame.get_values(stmt.kwargs),
221
+ )
222
+ return (result,)
223
+
224
+ @interp.impl(func.GetField)
225
+ def get_field(
226
+ self,
227
+ interp_: AddressAnalysis,
228
+ frame: ForwardFrame[Address],
229
+ stmt: func.GetField,
230
+ ):
231
+ self_mt = frame.get(stmt.obj)
232
+ match self_mt:
233
+ case PartialLambda(captured=captured):
234
+ return (captured[stmt.field],)
235
+ case ConstResult(const.Value(ir.Method() as mt)):
236
+ return (ConstResult(const.Value(mt.fields[stmt.field])),)
237
+
238
+ return (Address.top(),)
131
239
 
132
240
 
133
241
  @cf.dialect.register(key="qubit.address")
134
- class Cf(cf.typeinfer.TypeInfer):
135
- # NOTE: cf just re-use the type infer method table
136
- # it's the same process as type infer.
137
- pass
242
+ class Cf(interp.MethodTable):
243
+
244
+ @interp.impl(cf.Branch)
245
+ def branch(
246
+ self,
247
+ interp_: AddressAnalysis,
248
+ frame: ForwardFrame[Address],
249
+ stmt: cf.Branch,
250
+ ):
251
+ frame.worklist.append(
252
+ interp.Successor(stmt.successor, *frame.get_values(stmt.arguments))
253
+ )
254
+ return ()
255
+
256
+ @interp.impl(cf.ConditionalBranch)
257
+ def conditional_branch(
258
+ self,
259
+ interp_: const.Propagate,
260
+ frame: ForwardFrame[Address],
261
+ stmt: cf.ConditionalBranch,
262
+ ):
263
+ address_cond = frame.get(stmt.cond)
264
+
265
+ if isinstance(address_cond, ConstResult) and isinstance(
266
+ cond := address_cond.result, const.Value
267
+ ):
268
+ else_successor = interp.Successor(
269
+ stmt.else_successor, *frame.get_values(stmt.else_arguments)
270
+ )
271
+ then_successor = interp.Successor(
272
+ stmt.then_successor, *frame.get_values(stmt.then_arguments)
273
+ )
274
+ if cond.data:
275
+ frame.worklist.append(then_successor)
276
+ else:
277
+ frame.worklist.append(else_successor)
278
+ else:
279
+ frame.entries[stmt.cond] = ConstResult(const.Value(True))
280
+ then_successor = interp.Successor(
281
+ stmt.then_successor, *frame.get_values(stmt.then_arguments)
282
+ )
283
+ frame.worklist.append(then_successor)
284
+
285
+ frame.entries[stmt.cond] = ConstResult(const.Value(False))
286
+ else_successor = interp.Successor(
287
+ stmt.else_successor, *frame.get_values(stmt.else_arguments)
288
+ )
289
+ frame.worklist.append(else_successor)
290
+
291
+ frame.entries[stmt.cond] = address_cond
292
+ return ()
138
293
 
139
294
 
140
295
  @scf.dialect.register(key="qubit.address")
141
- class Scf(scf.absint.Methods):
296
+ class Scf(interp.MethodTable):
297
+ @interp.impl(scf.Yield)
298
+ def yield_(
299
+ self,
300
+ interp_: AddressAnalysis,
301
+ frame: ForwardFrame[Address],
302
+ stmt: scf.Yield,
303
+ ):
304
+ return interp.YieldValue(frame.get_values(stmt.values))
305
+
306
+ @interp.impl(scf.IfElse)
307
+ def ifelse(
308
+ self,
309
+ interp_: AddressAnalysis,
310
+ frame: ForwardFrame[Address],
311
+ stmt: scf.IfElse,
312
+ ):
313
+ address_cond = frame.get(stmt.cond)
314
+ # run specific branch
315
+ if isinstance(address_cond, ConstResult) and isinstance(
316
+ const_cond := address_cond.result, const.Value
317
+ ):
318
+ body = stmt.then_body if const_cond.data else stmt.else_body
319
+ with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
320
+ ret = interp_.frame_call_region(body_frame, stmt, body, address_cond)
321
+ # interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values())
322
+ return ret
323
+ else:
324
+ # run both branches
325
+ with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
326
+ then_results = interp_.frame_call_region(
327
+ then_frame,
328
+ stmt,
329
+ stmt.then_body,
330
+ address_cond,
331
+ )
332
+ frame.set_values(then_frame.entries.keys(), then_frame.entries.values())
333
+
334
+ with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
335
+ else_results = interp_.frame_call_region(
336
+ else_frame,
337
+ stmt,
338
+ stmt.else_body,
339
+ address_cond,
340
+ )
341
+ frame.set_values(else_frame.entries.keys(), else_frame.entries.values())
342
+ # TODO: pick the non-return value
343
+ if isinstance(then_results, interp.ReturnValue) and isinstance(
344
+ else_results, interp.ReturnValue
345
+ ):
346
+ return interp.ReturnValue(then_results.value.join(else_results.value))
347
+ elif isinstance(then_results, interp.ReturnValue):
348
+ ret = else_results
349
+ elif isinstance(else_results, interp.ReturnValue):
350
+ ret = then_results
351
+ else:
352
+ ret = interp_.join_results(then_results, else_results)
353
+
354
+ return ret
142
355
 
143
356
  @interp.impl(scf.For)
144
357
  def for_loop(
@@ -147,32 +360,22 @@ class Scf(scf.absint.Methods):
147
360
  frame: ForwardFrame[Address],
148
361
  stmt: scf.For,
149
362
  ):
150
- if not isinstance(hint := stmt.iterable.hints.get("const"), const.Value):
151
- return interp_.eval_stmt_fallback(frame, stmt)
152
-
153
- iterable = hint.data
154
363
  loop_vars = frame.get_values(stmt.initializers)
155
- body_block = stmt.body.blocks[0]
156
- block_args = body_block.args
157
-
158
- # NOTE: we need to actually run iteration in case there are
159
- # new allocations/re-assign in the loop body.
160
- for _ in iterable:
161
- with interp_.new_frame(stmt) as body_frame:
162
- body_frame.entries.update(frame.entries)
163
- body_frame.set_values(
164
- block_args,
165
- (NotQubit(),) + loop_vars,
364
+ iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable))
365
+
366
+ if iter_type is None:
367
+ return interp_.eval_fallback(frame, stmt)
368
+
369
+ for value in iterable:
370
+ with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
371
+ loop_vars = interp_.frame_call_region(
372
+ body_frame, stmt, stmt.body, value, *loop_vars
166
373
  )
167
- loop_vars = interp_.run_ssacfg_region(body_frame, stmt.body, ())
168
374
 
169
375
  if loop_vars is None:
170
376
  loop_vars = ()
377
+
171
378
  elif isinstance(loop_vars, interp.ReturnValue):
172
379
  return loop_vars
173
380
 
174
- if isinstance(body_block.last_stmt, func.Return):
175
- frame.worklist.append(interp.Successor(body_block, NotQubit(), *loop_vars))
176
- return # if terminate is Return, there is no result
177
-
178
381
  return loop_vars