bloqade-circuit 0.7.12__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bloqade-circuit might be problematic. (The original page linked to more details here; the link was not preserved in this text.)

Files changed (136)
  1. bloqade/analysis/address/__init__.py +8 -4
  2. bloqade/analysis/address/analysis.py +119 -29
  3. bloqade/analysis/address/impls.py +290 -87
  4. bloqade/analysis/address/lattice.py +209 -24
  5. bloqade/analysis/fidelity/analysis.py +2 -2
  6. bloqade/analysis/measure_id/impls.py +3 -27
  7. bloqade/cirq_utils/__init__.py +3 -1
  8. bloqade/cirq_utils/emit/__init__.py +3 -0
  9. bloqade/cirq_utils/emit/base.py +243 -0
  10. bloqade/cirq_utils/emit/gate.py +104 -0
  11. bloqade/cirq_utils/emit/noise.py +90 -0
  12. bloqade/cirq_utils/emit/qubit.py +35 -0
  13. bloqade/cirq_utils/lowering.py +664 -0
  14. bloqade/native/__init__.py +0 -1
  15. bloqade/native/_prelude.py +3 -3
  16. bloqade/native/dialects/gate/__init__.py +2 -0
  17. bloqade/native/dialects/gate/_dialect.py +3 -0
  18. bloqade/native/dialects/{gates → gate}/_interface.py +5 -5
  19. bloqade/native/dialects/{gates → gate}/stmts.py +5 -5
  20. bloqade/native/stdlib/broadcast.py +19 -19
  21. bloqade/native/stdlib/simple.py +14 -13
  22. bloqade/native/upstream/__init__.py +5 -0
  23. bloqade/native/upstream/squin2native.py +136 -0
  24. bloqade/pyqrack/__init__.py +1 -2
  25. bloqade/pyqrack/device.py +6 -17
  26. bloqade/pyqrack/native.py +17 -17
  27. bloqade/pyqrack/reg.py +1 -6
  28. bloqade/pyqrack/squin/gate/__init__.py +1 -0
  29. bloqade/pyqrack/squin/gate/gate.py +136 -0
  30. bloqade/pyqrack/squin/noise/native.py +120 -54
  31. bloqade/pyqrack/squin/qubit.py +25 -41
  32. bloqade/pyqrack/target.py +2 -2
  33. bloqade/qasm2/dialects/core/address.py +21 -12
  34. bloqade/qasm2/dialects/noise/fidelity.py +2 -6
  35. bloqade/qasm2/dialects/noise/model.py +2 -1
  36. bloqade/qasm2/passes/parallel.py +3 -1
  37. bloqade/qasm2/rewrite/__init__.py +0 -1
  38. bloqade/qasm2/rewrite/noise/heuristic_noise.py +7 -17
  39. bloqade/qasm2/rewrite/parallel_to_glob.py +28 -15
  40. bloqade/qasm2/rewrite/parallel_to_uop.py +2 -8
  41. bloqade/qubit/__init__.py +12 -0
  42. bloqade/qubit/_dialect.py +3 -0
  43. bloqade/qubit/_interface.py +49 -0
  44. bloqade/qubit/_prelude.py +45 -0
  45. bloqade/qubit/analysis/__init__.py +1 -0
  46. bloqade/qubit/analysis/address_impl.py +40 -0
  47. bloqade/qubit/stdlib/__init__.py +2 -0
  48. bloqade/qubit/stdlib/_new.py +34 -0
  49. bloqade/qubit/stdlib/broadcast.py +62 -0
  50. bloqade/qubit/stdlib/simple.py +59 -0
  51. bloqade/qubit/stmts.py +60 -0
  52. bloqade/rewrite/passes/aggressive_unroll.py +2 -1
  53. bloqade/squin/__init__.py +44 -17
  54. bloqade/squin/analysis/__init__.py +0 -1
  55. bloqade/squin/analysis/schedule.py +2 -2
  56. bloqade/squin/gate/__init__.py +2 -0
  57. bloqade/squin/gate/_dialect.py +3 -0
  58. bloqade/squin/gate/_interface.py +98 -0
  59. bloqade/squin/gate/stmts.py +119 -0
  60. bloqade/squin/groups.py +4 -21
  61. bloqade/squin/noise/__init__.py +1 -9
  62. bloqade/squin/noise/_dialect.py +1 -1
  63. bloqade/squin/noise/_interface.py +45 -0
  64. bloqade/squin/noise/stmts.py +65 -29
  65. bloqade/squin/rewrite/U3_to_clifford.py +70 -51
  66. bloqade/squin/rewrite/__init__.py +0 -2
  67. bloqade/squin/rewrite/remove_dangling_qubits.py +2 -2
  68. bloqade/squin/rewrite/wrap_analysis.py +4 -35
  69. bloqade/squin/stdlib/broadcast/__init__.py +34 -0
  70. bloqade/squin/stdlib/broadcast/_qubit.py +4 -0
  71. bloqade/squin/stdlib/broadcast/gate.py +260 -0
  72. bloqade/squin/stdlib/broadcast/noise.py +144 -0
  73. bloqade/squin/stdlib/simple/__init__.py +33 -0
  74. bloqade/squin/stdlib/simple/gate.py +242 -0
  75. bloqade/squin/stdlib/simple/noise.py +126 -0
  76. bloqade/stim/__init__.py +1 -0
  77. bloqade/stim/_wrappers.py +6 -0
  78. bloqade/stim/dialects/noise/emit.py +6 -1
  79. bloqade/stim/dialects/noise/stmts.py +5 -3
  80. bloqade/stim/emit/stim_str.py +2 -0
  81. bloqade/stim/parse/lowering.py +12 -17
  82. bloqade/stim/passes/__init__.py +0 -1
  83. bloqade/stim/passes/flatten.py +26 -0
  84. bloqade/stim/passes/simplify_ifs.py +6 -1
  85. bloqade/stim/passes/squin_to_stim.py +4 -70
  86. bloqade/stim/rewrite/__init__.py +0 -4
  87. bloqade/stim/rewrite/ifs_to_stim.py +23 -29
  88. bloqade/stim/rewrite/qubit_to_stim.py +90 -41
  89. bloqade/stim/rewrite/squin_measure.py +9 -18
  90. bloqade/stim/rewrite/squin_noise.py +132 -108
  91. bloqade/stim/rewrite/util.py +5 -204
  92. bloqade/types.py +10 -0
  93. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/METADATA +2 -2
  94. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/RECORD +96 -100
  95. bloqade/native/dialects/gates/__init__.py +0 -3
  96. bloqade/native/dialects/gates/_dialect.py +0 -3
  97. bloqade/pyqrack/squin/op.py +0 -180
  98. bloqade/pyqrack/squin/runtime.py +0 -543
  99. bloqade/pyqrack/squin/wire.py +0 -51
  100. bloqade/squin/_typeinfer.py +0 -20
  101. bloqade/squin/analysis/address_impl.py +0 -71
  102. bloqade/squin/analysis/nsites/__init__.py +0 -9
  103. bloqade/squin/analysis/nsites/analysis.py +0 -50
  104. bloqade/squin/analysis/nsites/impls.py +0 -99
  105. bloqade/squin/analysis/nsites/lattice.py +0 -49
  106. bloqade/squin/cirq/__init__.py +0 -306
  107. bloqade/squin/cirq/emit/emit_circuit.py +0 -129
  108. bloqade/squin/cirq/emit/noise.py +0 -49
  109. bloqade/squin/cirq/emit/op.py +0 -176
  110. bloqade/squin/cirq/emit/qubit.py +0 -58
  111. bloqade/squin/cirq/emit/runtime.py +0 -242
  112. bloqade/squin/cirq/lowering.py +0 -439
  113. bloqade/squin/lowering.py +0 -80
  114. bloqade/squin/noise/_wrapper.py +0 -36
  115. bloqade/squin/noise/rewrite.py +0 -129
  116. bloqade/squin/op/__init__.py +0 -41
  117. bloqade/squin/op/_dialect.py +0 -3
  118. bloqade/squin/op/_wrapper.py +0 -121
  119. bloqade/squin/op/number.py +0 -5
  120. bloqade/squin/op/rewrite.py +0 -46
  121. bloqade/squin/op/stdlib.py +0 -62
  122. bloqade/squin/op/stmts.py +0 -300
  123. bloqade/squin/op/traits.py +0 -43
  124. bloqade/squin/op/types.py +0 -128
  125. bloqade/squin/parallel.py +0 -200
  126. bloqade/squin/qubit.py +0 -194
  127. bloqade/squin/rewrite/canonicalize.py +0 -60
  128. bloqade/squin/rewrite/desugar.py +0 -102
  129. bloqade/squin/stdlib/channel.py +0 -86
  130. bloqade/squin/stdlib/gate.py +0 -201
  131. bloqade/squin/types.py +0 -8
  132. bloqade/squin/wire.py +0 -201
  133. bloqade/stim/rewrite/wire_identity_elimination.py +0 -24
  134. bloqade/stim/rewrite/wire_to_stim.py +0 -57
  135. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/WHEEL +0 -0
  136. {bloqade_circuit-0.7.12.dist-info → bloqade_circuit-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,15 @@
1
1
  from . import impls as impls
2
2
  from .lattice import (
3
+ Bottom as Bottom,
3
4
  Address as Address,
4
- NotQubit as NotQubit,
5
+ Unknown as Unknown,
5
6
  AddressReg as AddressReg,
6
- AnyAddress as AnyAddress,
7
- AddressWire as AddressWire,
7
+ UnknownReg as UnknownReg,
8
+ ConstResult as ConstResult,
8
9
  AddressQubit as AddressQubit,
9
- AddressTuple as AddressTuple,
10
+ PartialIList as PartialIList,
11
+ PartialTuple as PartialTuple,
12
+ UnknownQubit as UnknownQubit,
13
+ PartialLambda as PartialLambda,
10
14
  )
11
15
  from .analysis import AddressAnalysis as AddressAnalysis
@@ -1,13 +1,13 @@
1
- from typing import TypeVar
1
+ from typing import Any, Type, TypeVar
2
2
  from dataclasses import field
3
3
 
4
- from kirin import ir, interp
4
+ from kirin import ir, types, interp
5
5
  from kirin.analysis import Forward, const
6
+ from kirin.dialects.ilist import IList
6
7
  from kirin.analysis.forward import ForwardFrame
8
+ from kirin.analysis.const.lattice import PartialLambda
7
9
 
8
- from bloqade.types import QubitType
9
-
10
- from .lattice import Address
10
+ from .lattice import Address, AddressReg, ConstResult, PartialIList, PartialTuple
11
11
 
12
12
 
13
13
  class AddressAnalysis(Forward[Address]):
@@ -16,12 +16,15 @@ class AddressAnalysis(Forward[Address]):
16
16
  """
17
17
 
18
18
  keys = ["qubit.address"]
19
+ _const_prop: const.Propagate
19
20
  lattice = Address
20
21
  next_address: int = field(init=False)
21
22
 
22
23
  def initialize(self):
23
24
  super().initialize()
24
25
  self.next_address: int = 0
26
+ self._const_prop = const.Propagate(self.dialects)
27
+ self._const_prop.initialize()
25
28
  return self
26
29
 
27
30
  @property
@@ -31,30 +34,117 @@ class AddressAnalysis(Forward[Address]):
31
34
 
32
35
  T = TypeVar("T")
33
36
 
34
- def get_const_value(self, typ: type[T], value: ir.SSAValue) -> T:
35
- if isinstance(hint := value.hints.get("const"), const.Value):
36
- data = hint.data
37
- if isinstance(data, typ):
38
- return hint.data
39
- raise interp.InterpreterError(
40
- f"Expected constant value <type = {typ}>, got {data}"
41
- )
42
- raise interp.InterpreterError(
43
- f"Expected constant value <type = {typ}>, got {value}"
44
- )
45
-
46
- def eval_stmt_fallback(
47
- self, frame: ForwardFrame[Address], stmt: ir.Statement
48
- ) -> tuple[Address, ...] | interp.SpecialValue[Address]:
49
- return tuple(
50
- (
51
- self.lattice.top()
52
- if result.type.is_subseteq(QubitType)
53
- else self.lattice.bottom()
54
- )
55
- for result in stmt.results
56
- )
37
+ def to_address(self, result: const.Result):
38
+ return ConstResult(result)
39
+
40
+ def try_eval_const_prop(
41
+ self,
42
+ frame: ForwardFrame[Address],
43
+ stmt: ir.Statement,
44
+ args: tuple[ConstResult, ...],
45
+ ) -> interp.StatementResult[Address]:
46
+ _frame = self._const_prop.initialize_frame(frame.code)
47
+ _frame.set_values(stmt.args, tuple(x.result for x in args))
48
+ result = self._const_prop.eval_stmt(_frame, stmt)
49
+
50
+ match result:
51
+ case interp.ReturnValue(constant_ret):
52
+ return interp.ReturnValue(self.to_address(constant_ret))
53
+ case interp.YieldValue(constant_values):
54
+ return interp.YieldValue(tuple(map(self.to_address, constant_values)))
55
+ case interp.Successor(block, block_args):
56
+ return interp.Successor(block, *map(self.to_address, block_args))
57
+ case tuple():
58
+ return tuple(map(self.to_address, result))
59
+ case _:
60
+ return result
61
+
62
+ def unpack_iterable(self, iterable: Address):
63
+ """Extract the values of a container lattice element.
64
+
65
+ Args:
66
+ iterable: The lattice element representing a container.
67
+
68
+ Returns:
69
+ A tuple of the container type and the contained values.
70
+
71
+ """
72
+
73
+ def from_constant(constant: const.Result) -> Address:
74
+ return ConstResult(constant)
75
+
76
+ def from_literal(literal: Any) -> Address:
77
+ return ConstResult(const.Value(literal))
78
+
79
+ match iterable:
80
+ case PartialIList(data):
81
+ return PartialIList, data
82
+ case PartialTuple(data):
83
+ return PartialTuple, data
84
+ case AddressReg():
85
+ return PartialIList, iterable.qubits
86
+ case ConstResult(const.Value(IList() as data)):
87
+ return PartialIList, tuple(map(from_literal, data))
88
+ case ConstResult(const.Value(tuple() as data)):
89
+ return PartialTuple, tuple(map(from_literal, data))
90
+ case ConstResult(const.PartialTuple(data)):
91
+ return PartialTuple, tuple(map(from_constant, data))
92
+ case _:
93
+ return None, ()
94
+
95
+ def run_lattice(
96
+ self,
97
+ callee: Address,
98
+ inputs: tuple[Address, ...],
99
+ kwargs: tuple[str, ...],
100
+ ) -> Address:
101
+ """Run a callable lattice element with the given inputs and keyword arguments.
102
+
103
+ Args:
104
+ callee (Address): The lattice element representing the callable.
105
+ inputs (tuple[Address, ...]): The input lattice elements.
106
+ kwargs (tuple[str, ...]): The keyword argument names.
107
+
108
+ Returns:
109
+ Address: The resulting lattice element after invoking the callable.
110
+
111
+ """
112
+
113
+ match callee:
114
+ case PartialLambda(code=code, argnames=argnames):
115
+ _, ret = self.run_callable(
116
+ code, (callee,) + self.permute_values(argnames, inputs, kwargs)
117
+ )
118
+ return ret
119
+ case ConstResult(const.Value(ir.Method() as method)):
120
+ _, ret = self.run_method(
121
+ method,
122
+ self.permute_values(method.arg_names, inputs, kwargs),
123
+ )
124
+ return ret
125
+ case _:
126
+ return Address.top()
127
+
128
+ def get_const_value(self, addr: Address, typ: Type[T]) -> T | None:
129
+ if not isinstance(addr, ConstResult):
130
+ return None
131
+
132
+ if not isinstance(result := addr.result, const.Value):
133
+ return None
134
+
135
+ if not isinstance(value := result.data, typ):
136
+ return None
137
+
138
+ return value
139
+
140
+ def eval_stmt_fallback(self, frame: ForwardFrame[Address], stmt: ir.Statement):
141
+ args = frame.get_values(stmt.args)
142
+ if types.is_tuple_of(args, ConstResult):
143
+ return self.try_eval_const_prop(frame, stmt, args)
144
+
145
+ return tuple(Address.from_type(result.type) for result in stmt.results)
57
146
 
58
147
  def run_method(self, method: ir.Method, args: tuple[Address, ...]):
59
148
  # NOTE: we do not support dynamic calls here, thus no need to propagate method object
60
- return self.run_callable(method.code, (self.lattice.bottom(),) + args)
149
+ self_mt = ConstResult(const.Value(method))
150
+ return self.run_callable(method.code, (self_mt,) + args)
@@ -2,32 +2,56 @@
2
2
  qubit.address method table for a few builtin dialects.
3
3
  """
4
4
 
5
- from kirin import interp
5
+ from itertools import chain
6
+
7
+ from kirin import ir, interp
6
8
  from kirin.analysis import ForwardFrame, const
7
9
  from kirin.dialects import cf, py, scf, func, ilist
8
10
 
9
11
  from .lattice import (
10
12
  Address,
11
- NotQubit,
12
- AddressReg,
13
- AddressQubit,
14
- AddressTuple,
13
+ ConstResult,
14
+ PartialIList,
15
+ PartialTuple,
16
+ PartialLambda,
15
17
  )
16
18
  from .analysis import AddressAnalysis
17
19
 
18
20
 
21
+ @py.constant.dialect.register(key="qubit.address")
22
+ class PyConstant(interp.MethodTable):
23
+ @interp.impl(py.Constant)
24
+ def constant(
25
+ self,
26
+ interp_: AddressAnalysis,
27
+ frame: ForwardFrame[Address],
28
+ stmt: py.Constant,
29
+ ):
30
+ return (ConstResult(const.Value(stmt.value.unwrap())),)
31
+
32
+
19
33
  @py.binop.dialect.register(key="qubit.address")
20
34
  class PyBinOp(interp.MethodTable):
21
-
22
35
  @interp.impl(py.Add)
23
- def add(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.Add):
36
+ def add(
37
+ self,
38
+ interp_: AddressAnalysis,
39
+ frame: ForwardFrame[Address],
40
+ stmt: py.Add,
41
+ ):
24
42
  lhs = frame.get(stmt.lhs)
25
43
  rhs = frame.get(stmt.rhs)
26
44
 
27
- if isinstance(lhs, AddressTuple) and isinstance(rhs, AddressTuple):
28
- return (AddressTuple(data=lhs.data + rhs.data),)
29
- else:
30
- return (NotQubit(),)
45
+ lhs_type, lhs_values = interp_.unpack_iterable(lhs)
46
+ rhs_type, rhs_values = interp_.unpack_iterable(rhs)
47
+
48
+ if lhs_type is None or rhs_type is None:
49
+ return (Address.top(),)
50
+
51
+ if lhs_type is not rhs_type:
52
+ return (Address.bottom(),)
53
+
54
+ return (lhs_type(tuple(chain(lhs_values, rhs_values))),)
31
55
 
32
56
 
33
57
  @py.tuple.dialect.register(key="qubit.address")
@@ -35,110 +59,299 @@ class PyTuple(interp.MethodTable):
35
59
  @interp.impl(py.tuple.New)
36
60
  def new_tuple(
37
61
  self,
38
- interp: AddressAnalysis,
39
- frame: interp.Frame,
62
+ interp_: AddressAnalysis,
63
+ frame: ForwardFrame[Address],
40
64
  stmt: py.tuple.New,
41
65
  ):
42
- return (AddressTuple(frame.get_values(stmt.args)),)
66
+ return (PartialTuple(frame.get_values(stmt.args)),)
43
67
 
44
68
 
45
69
  @ilist.dialect.register(key="qubit.address")
46
- class IList(interp.MethodTable):
70
+ class IListMethods(interp.MethodTable):
47
71
  @interp.impl(ilist.New)
48
72
  def new_ilist(
49
73
  self,
50
- interp: AddressAnalysis,
51
- frame: interp.Frame,
74
+ interp_: AddressAnalysis,
75
+ frame: ForwardFrame[Address],
52
76
  stmt: ilist.New,
53
77
  ):
54
- return (AddressTuple(frame.get_values(stmt.values)),)
78
+ return (PartialIList(frame.get_values(stmt.args)),)
55
79
 
56
-
57
- @py.list.dialect.register(key="qubit.address")
58
- class PyList(interp.MethodTable):
59
- @interp.impl(py.list.New)
60
- def new_ilist(
80
+ @interp.impl(ilist.ForEach)
81
+ @interp.impl(ilist.Map)
82
+ def map_(
61
83
  self,
62
- interp: AddressAnalysis,
63
- frame: interp.Frame,
64
- stmt: py.list.New,
84
+ interp_: AddressAnalysis,
85
+ frame: ForwardFrame[Address],
86
+ stmt: ilist.Map | ilist.ForEach,
65
87
  ):
66
- return (AddressTuple(frame.get_values(stmt.args)),)
88
+ fn = frame.get(stmt.fn)
89
+ collection = frame.get(stmt.collection)
90
+ collection_type, values = interp_.unpack_iterable(collection)
91
+
92
+ if collection_type is None:
93
+ return (Address.top(),)
94
+
95
+ if collection_type is not PartialIList:
96
+ return (Address.bottom(),)
97
+
98
+ results = []
99
+ for ele in values:
100
+ ret = interp_.run_lattice(fn, (ele,), ())
101
+ results.append(ret)
102
+
103
+ if isinstance(stmt, ilist.Map):
104
+ return (PartialIList(tuple(results)),)
105
+
106
+
107
+ @py.len.dialect.register(key="qubit.address")
108
+ class PyLen(interp.MethodTable):
109
+ @interp.impl(py.Len)
110
+ def len_(
111
+ self, interp_: AddressAnalysis, frame: ForwardFrame[Address], stmt: py.Len
112
+ ):
113
+ obj = frame.get(stmt.value)
114
+ _, values = interp_.unpack_iterable(obj)
115
+
116
+ if values is None:
117
+ return (Address.top(),)
118
+
119
+ return (ConstResult(const.Value(len(values))),)
67
120
 
68
121
 
69
122
  @py.indexing.dialect.register(key="qubit.address")
70
123
  class PyIndexing(interp.MethodTable):
71
124
  @interp.impl(py.GetItem)
72
- def getitem(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.GetItem):
73
-
125
+ def getitem(
126
+ self,
127
+ interp_: AddressAnalysis,
128
+ frame: ForwardFrame[Address],
129
+ stmt: py.GetItem,
130
+ ):
74
131
  # determine if the index is an int constant
75
132
  # or a slice
76
- hint = stmt.index.hints.get("const")
77
- if hint is None:
78
- return (NotQubit(),)
79
-
80
- if isinstance(hint, const.Value):
81
- idx = hint.data
82
- elif isinstance(hint, slice):
83
- idx = hint
84
- else:
85
- return (NotQubit(),)
86
-
87
- # The object being indexed into
88
133
  obj = frame.get(stmt.obj)
89
- # The `data` attributes holds onto other Address types
90
- # so we just extract that here
91
- if isinstance(obj, AddressTuple):
92
- return (obj.data[idx],)
93
- # If idx is an integer index into an AddressReg,
94
- # then it's safe to assume a single qubit is being accessed.
95
- # On the other hand, if it's a slice, we return
96
- # a new AddressReg to preserve the new sequence.
97
- elif isinstance(obj, AddressReg):
98
- if isinstance(idx, slice):
99
- return (AddressReg(data=obj.data[idx]),)
100
- if isinstance(idx, int):
101
- return (AddressQubit(obj.data[idx]),)
102
- else:
103
- return (NotQubit(),)
134
+ index = frame.get(stmt.index)
135
+
136
+ typ, values = interp_.unpack_iterable(obj)
137
+ if typ is None:
138
+ return (Address.top(),)
139
+
140
+ int_index = interp_.get_const_value(index, int)
141
+ if int_index is not None:
142
+ return (values[int_index],)
143
+
144
+ slice_index = interp_.get_const_value(index, slice)
145
+ if slice_index is not None:
146
+ return (typ(values[slice_index]),)
147
+
148
+ return (Address.top(),)
104
149
 
105
150
 
106
151
  @py.assign.dialect.register(key="qubit.address")
107
152
  class PyAssign(interp.MethodTable):
108
153
  @interp.impl(py.Alias)
109
- def alias(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.Alias):
154
+ def alias(
155
+ self,
156
+ interp: AddressAnalysis,
157
+ frame: ForwardFrame[Address],
158
+ stmt: py.Alias,
159
+ ):
110
160
  return (frame.get(stmt.value),)
111
161
 
112
162
 
163
+ # TODO: look for abstract method table for func.
113
164
  @func.dialect.register(key="qubit.address")
114
165
  class Func(interp.MethodTable):
115
166
  @interp.impl(func.Return)
116
- def return_(self, _: AddressAnalysis, frame: interp.Frame, stmt: func.Return):
167
+ def return_(
168
+ self,
169
+ _: AddressAnalysis,
170
+ frame: ForwardFrame[Address],
171
+ stmt: func.Return,
172
+ ):
117
173
  return interp.ReturnValue(frame.get(stmt.value))
118
174
 
119
175
  # TODO: replace with the generic implementation
120
176
  @interp.impl(func.Invoke)
121
- def invoke(self, interp_: AddressAnalysis, frame: interp.Frame, stmt: func.Invoke):
177
+ def invoke(
178
+ self,
179
+ interp_: AddressAnalysis,
180
+ frame: ForwardFrame[Address],
181
+ stmt: func.Invoke,
182
+ ):
183
+
184
+ args = interp_.permute_values(
185
+ stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
186
+ )
122
187
  _, ret = interp_.run_method(
123
188
  stmt.callee,
124
- interp_.permute_values(
125
- stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
126
- ),
189
+ args,
127
190
  )
191
+
128
192
  return (ret,)
129
193
 
130
- # TODO: support lambda?
194
+ @interp.impl(func.Lambda)
195
+ def lambda_(
196
+ self,
197
+ inter_: AddressAnalysis,
198
+ frame: ForwardFrame[Address],
199
+ stmt: func.Lambda,
200
+ ):
201
+ arg_names = [
202
+ arg.name or str(idx) for idx, arg in enumerate(stmt.body.blocks[0].args)
203
+ ]
204
+ return (
205
+ PartialLambda(
206
+ arg_names,
207
+ stmt,
208
+ frame.get_values(stmt.captured),
209
+ ),
210
+ )
211
+
212
+ @interp.impl(func.Call)
213
+ def call(
214
+ self,
215
+ interp_: AddressAnalysis,
216
+ frame: ForwardFrame[Address],
217
+ stmt: func.Call,
218
+ ):
219
+ result = interp_.run_lattice(
220
+ frame.get(stmt.callee),
221
+ frame.get_values(stmt.inputs),
222
+ stmt.kwargs,
223
+ )
224
+ return (result,)
225
+
226
+ @interp.impl(func.GetField)
227
+ def get_field(
228
+ self,
229
+ interp_: AddressAnalysis,
230
+ frame: ForwardFrame[Address],
231
+ stmt: func.GetField,
232
+ ):
233
+ self_mt = frame.get(stmt.obj)
234
+ match self_mt:
235
+ case PartialLambda(captured=captured):
236
+ return (captured[stmt.field],)
237
+ case ConstResult(const.Value(ir.Method() as mt)):
238
+ return (ConstResult(const.Value(mt.fields[stmt.field])),)
239
+
240
+ return (Address.top(),)
131
241
 
132
242
 
133
243
  @cf.dialect.register(key="qubit.address")
134
- class Cf(cf.typeinfer.TypeInfer):
135
- # NOTE: cf just re-use the type infer method table
136
- # it's the same process as type infer.
137
- pass
244
+ class Cf(interp.MethodTable):
245
+
246
+ @interp.impl(cf.Branch)
247
+ def branch(
248
+ self,
249
+ interp_: AddressAnalysis,
250
+ frame: ForwardFrame[Address],
251
+ stmt: cf.Branch,
252
+ ):
253
+ frame.worklist.append(
254
+ interp.Successor(stmt.successor, *frame.get_values(stmt.arguments))
255
+ )
256
+ return ()
257
+
258
+ @interp.impl(cf.ConditionalBranch)
259
+ def conditional_branch(
260
+ self,
261
+ interp_: const.Propagate,
262
+ frame: ForwardFrame[Address],
263
+ stmt: cf.ConditionalBranch,
264
+ ):
265
+ address_cond = frame.get(stmt.cond)
266
+
267
+ if isinstance(address_cond, ConstResult) and isinstance(
268
+ cond := address_cond.result, const.Value
269
+ ):
270
+ else_successor = interp.Successor(
271
+ stmt.else_successor, *frame.get_values(stmt.else_arguments)
272
+ )
273
+ then_successor = interp.Successor(
274
+ stmt.then_successor, *frame.get_values(stmt.then_arguments)
275
+ )
276
+ if cond.data:
277
+ frame.worklist.append(then_successor)
278
+ else:
279
+ frame.worklist.append(else_successor)
280
+ else:
281
+ frame.entries[stmt.cond] = ConstResult(const.Value(True))
282
+ then_successor = interp.Successor(
283
+ stmt.then_successor, *frame.get_values(stmt.then_arguments)
284
+ )
285
+ frame.worklist.append(then_successor)
286
+
287
+ frame.entries[stmt.cond] = ConstResult(const.Value(False))
288
+ else_successor = interp.Successor(
289
+ stmt.else_successor, *frame.get_values(stmt.else_arguments)
290
+ )
291
+ frame.worklist.append(else_successor)
292
+
293
+ frame.entries[stmt.cond] = address_cond
294
+ return ()
138
295
 
139
296
 
140
297
  @scf.dialect.register(key="qubit.address")
141
- class Scf(scf.absint.Methods):
298
+ class Scf(interp.MethodTable):
299
+ @interp.impl(scf.Yield)
300
+ def yield_(
301
+ self,
302
+ interp_: AddressAnalysis,
303
+ frame: ForwardFrame[Address],
304
+ stmt: scf.Yield,
305
+ ):
306
+ return interp.YieldValue(frame.get_values(stmt.values))
307
+
308
+ @interp.impl(scf.IfElse)
309
+ def ifelse(
310
+ self,
311
+ interp_: AddressAnalysis,
312
+ frame: ForwardFrame[Address],
313
+ stmt: scf.IfElse,
314
+ ):
315
+ address_cond = frame.get(stmt.cond)
316
+ # run specific branch
317
+ if isinstance(address_cond, ConstResult) and isinstance(
318
+ const_cond := address_cond.result, const.Value
319
+ ):
320
+ body = stmt.then_body if const_cond.data else stmt.else_body
321
+ with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
322
+ ret = interp_.run_ssacfg_region(body_frame, body, (address_cond,))
323
+ # interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values())
324
+ return ret
325
+ else:
326
+ # run both branches
327
+ with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
328
+ then_results = interp_.run_ssacfg_region(
329
+ then_frame, stmt.then_body, (address_cond,)
330
+ )
331
+ interp_.set_values(
332
+ frame, then_frame.entries.keys(), then_frame.entries.values()
333
+ )
334
+
335
+ with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
336
+ else_results = interp_.run_ssacfg_region(
337
+ else_frame, stmt.else_body, (address_cond,)
338
+ )
339
+ interp_.set_values(
340
+ frame, else_frame.entries.keys(), else_frame.entries.values()
341
+ )
342
+ # TODO: pick the non-return value
343
+ if isinstance(then_results, interp.ReturnValue) and isinstance(
344
+ else_results, interp.ReturnValue
345
+ ):
346
+ return interp.ReturnValue(then_results.value.join(else_results.value))
347
+ elif isinstance(then_results, interp.ReturnValue):
348
+ ret = else_results
349
+ elif isinstance(else_results, interp.ReturnValue):
350
+ ret = then_results
351
+ else:
352
+ ret = interp_.join_results(then_results, else_results)
353
+
354
+ return ret
142
355
 
143
356
  @interp.impl(scf.For)
144
357
  def for_loop(
@@ -147,32 +360,22 @@ class Scf(scf.absint.Methods):
147
360
  frame: ForwardFrame[Address],
148
361
  stmt: scf.For,
149
362
  ):
150
- if not isinstance(hint := stmt.iterable.hints.get("const"), const.Value):
363
+ loop_vars = frame.get_values(stmt.initializers)
364
+ iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable))
365
+
366
+ if iter_type is None:
151
367
  return interp_.eval_stmt_fallback(frame, stmt)
152
368
 
153
- iterable = hint.data
154
- loop_vars = frame.get_values(stmt.initializers)
155
- body_block = stmt.body.blocks[0]
156
- block_args = body_block.args
157
-
158
- # NOTE: we need to actually run iteration in case there are
159
- # new allocations/re-assign in the loop body.
160
- for _ in iterable:
161
- with interp_.new_frame(stmt) as body_frame:
162
- body_frame.entries.update(frame.entries)
163
- body_frame.set_values(
164
- block_args,
165
- (NotQubit(),) + loop_vars,
369
+ for value in iterable:
370
+ with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
371
+ loop_vars = interp_.run_ssacfg_region(
372
+ body_frame, stmt.body, (value,) + loop_vars
166
373
  )
167
- loop_vars = interp_.run_ssacfg_region(body_frame, stmt.body, ())
168
374
 
169
375
  if loop_vars is None:
170
376
  loop_vars = ()
377
+
171
378
  elif isinstance(loop_vars, interp.ReturnValue):
172
379
  return loop_vars
173
380
 
174
- if isinstance(body_block.last_stmt, func.Return):
175
- frame.worklist.append(interp.Successor(body_block, NotQubit(), *loop_vars))
176
- return # if terminate is Return, there is no result
177
-
178
381
  return loop_vars