da4ml 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +239 -73
  5. da4ml/cmvm/core/__init__.py +222 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +49 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +75 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +265 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +177 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +120 -0
  35. da4ml-0.2.0.dist-info/METADATA +65 -0
  36. da4ml-0.2.0.dist-info/RECORD +39 -0
  37. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.2.dist-info/METADATA +0 -122
  48. da4ml-0.1.2.dist-info/RECORD +0 -18
  49. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
  50. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/__init__.py CHANGED
@@ -1,17 +1,17 @@
- from .cmvm.api import cost, fn_from_kernel
- from .cmvm.cmvm import compile_kernel
- from .cmvm.codegen import PyCodegenBackend, VitisCodegenBackend
- from .cmvm.graph_compile import graph_compile_states
- from .cmvm.utils import DAState, OpCode, Score
+ # from .cmvm.api import cost, fn_from_kernel
+ # from .cmvm.cmvm import compile_kernel
+ # from .cmvm.codegen import PyCodegenBackend, VitisCodegenBackend
+ # from .cmvm.graph_compile import graph_compile_states
+ # from .cmvm.utils import DAState, OpCode, Score

- __all__ = [
- 'DAState',
- 'OpCode',
- 'Score',
- 'cost',
- 'compile_kernel',
- 'fn_from_kernel',
- 'graph_compile_states',
- 'PyCodegenBackend',
- 'VitisCodegenBackend',
- ]
+ # __all__ = [
+ # 'DAState',
+ # 'OpCode',
+ # 'Score',
+ # 'cost',
+ # 'compile_kernel',
+ # 'fn_from_kernel',
+ # 'graph_compile_states',
+ # 'PyCodegenBackend',
+ # 'VitisCodegenBackend',
+ # ]
da4ml/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.1.2'
- __version_tuple__ = version_tuple = (0, 1, 2)
+ __version__ = version = '0.2.0'
+ __version_tuple__ = version_tuple = (0, 2, 0)
da4ml/cmvm/__init__.py CHANGED
@@ -1,35 +1,4 @@
- import numpy as np
+ from .api import minimal_latency, solve
+ from .types import Op, QInterval, Solution

- from .api import cost, fn_from_kernel
- from .cmvm import compile_kernel
- from .codegen import PyCodegenBackend
- from .graph_compile import graph_compile_states
- from .utils import DAState, OpCode, Score
-
- d_in = 2
- d_out = 2
- kernel = np.ones((d_in, d_out), dtype=np.float32)
- signs = [False] * d_in
- bits = [8] * d_in
- int_bits = [0] * d_in
- symmetrics = [False] * d_in
- depths = [0] * d_in
-
- print('The da4ml library is compiling. This will take a while...', end=' ')
- _ = fn_from_kernel(
- kernel=kernel,
- signs=signs,
- bits=bits,
- int_bits=int_bits,
- symmetrics=symmetrics,
- depths=depths,
- n_beams=1,
- dc=None,
- n_inp_max=-1,
- n_out_max=-1,
- codegen_backend=PyCodegenBackend(),
- )
- print('Done')
-
-
- __all__ = ['DAState', 'OpCode', 'Score', 'cost', 'compile_kernel', 'fn_from_kernel', 'graph_compile_states']
+ __all__ = ['minimal_latency', 'solve', 'QInterval', 'Op', 'Solution']
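
With 0.2.0 the package-level API of da4ml.cmvm shrinks to the two solver entry points plus the core result types. A minimal import sketch of the new surface, derived from the __all__ above (the line itself is an illustration, not part of the diff); call patterns are sketched after the api.py diff below:

    from da4ml.cmvm import Op, QInterval, Solution, minimal_latency, solve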
da4ml/cmvm/api.py CHANGED
@@ -1,91 +1,257 @@
- import re
- import typing
- from collections.abc import Callable
+ from math import ceil, log2

  import numpy as np
+ from numba import jit, prange

- from .cmvm import compile_kernel
- from .codegen import Namer, PyCodegenBackend
- from .graph_compile import graph_compile_states
+ from .core import _solve, create_state, to_solution
+ from .types import CascadedSolution, QInterval
+ from .util import kernel_decompose

- m = re.compile(r'Latency: (\d+)')
- T = typing.TypeVar('T')

+ @jit(cache=True)
+ def minimal_latency(
+ kernel: np.ndarray,
+ qintervals: list[QInterval],
+ latencies: list[float],
+ carry_size: int = -1,
+ adder_size: int = -1,
+ ):
+ """Fast latency calculation for a given kernel, QInterval, and input latencies.
+ When carry_size=-1, and the input latency is constant `l`:
+ this will be the same as `l + max(ceiling(log2(max(#CSD bits for each column, 1))))`.
+
+ Parameters
+ ----------
+ kernel : np.ndarray
+ The input kernel matrix.
+ qintervals : list[QInterval]
+ List of QIntervals for each input.
+ latencies : list[float]
+ List of latencies for each input
+ carry_size : int, optional
+ The size of the carry unit for latency computation, by default -1 (fixed latency for each addition operation)
+ adder_size : int, optional
+ The size of the adder unit for latency computation, by default -1 (fixed cost for each addition operation)
+
+ Returns
+ -------
+ float
+ The minimal latency for the given kernel, QInterval, and input latencies.
+ """

- def fn_from_kernel(
+ state = create_state(kernel, qintervals, latencies, no_stat_init=True)
+ sol = to_solution(state, adder_size=adder_size, carry_size=carry_size)
+ latencies = [sol.ops[i].latency if i >= 0 else 0.0 for i in sol.out_idxs]
+ return max(latencies)
+
+
+ @jit(cache=True)
+ def jit_solve(
  kernel: np.ndarray,
- signs: list[bool],
- bits: list[int],
- int_bits: list[int],
- symmetrics: list[bool],
- depths: list[int] | None = None,
- n_beams: int = 1,
- dc: int | None = None,
- n_inp_max: int = -1,
- n_out_max: int = -1,
- codegen_backend: PyCodegenBackend = PyCodegenBackend(),
- signed_balanced_reduction: bool = True,
- ) -> tuple[Callable[[list[T]], list[T]], str]:
- """Compile a CMVM operation, with the constant kernel, into a function with only accumulation/subtraction/shift operations.
+ method0: str = 'wmc',
+ method1: str = 'auto',
+ hard_dc: int = -1,
+ decompose_dc: int = -2,
+ qintervals: list[QInterval] | None = None,
+ latencies: list[float] | None = None,
+ adder_size: int = -1,
+ carry_size: int = -1,
+ ) -> CascadedSolution:
+ """Optimized implementation of a CMVM computation with cascaded two matrices.

  Parameters
  ----------
  kernel : np.ndarray
- The kernel to compile. Must be of shape (n_inp, n_out).
- signs : list[bool]
- If the input is signed. Must be of length n_inp.
- bits : list[int]
- The bitwidth of the inputs. Must be of length n_inp.
- int_bits : list[int]
- The number of integer bits in the inputs (incl. sign bit!). Must be of length n_inp.
- symmetrics : list[bool]
- If the input is symmetricly quantized. Must be of length n_inp.
- depths : list[int]|None, optional
- The depth associated with each input. Must be of length n_inp. Defaults to [0]*n_inp.
- n_beams : int, optional
- Number of beams to use in beam search. Defaults to 1. (Currently disabled!)
- dc : int | None, optional
- Delay constraint. Not (properly) implemented yet. Defaults to None.
- n_inp_max : int, optional
- Number of inputs to process in one block. Defaults to -1 (no limit). Decrease to improve performance, but result will be less optimal.
- n_out_max : int, optional
- Number of outputs to process in one block. Defaults to -1 (no limit). Decrease to improve performance, but result will be less optimal.
- codegen_backend : PyCodegenBackend, optional
- The codegen backend to be used. Defaults to PyCodegenBackend().
- signed_balanced_reduction : bool, optional
- If the reduction tree should isolate the plus and minus terms. Set to False to improve latency. Defaults to True.
+ The input kernel matrix to be implemented.
+ method0 : str, optional
+ Optimization method for the first stage. Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`].
+ method1 : str, optional
+ Optimization method for the second stage. When 'auto', it will select based on hard_dc and method0, by default 'auto'
+ hard_dc : int, optional
+ Hard depth constraint (additional latency allowed beyond minimal latency), by default -1 (no constraint)
+ decompose_dc : int, optional
+ Decomposition depth constraint, by default -1 (no constraint, follows hard_dc)
+ qintervals : list[QInterval] | None, optional
+ List of quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
+ inp_latencies : list[float] | None, optional
+ List of input latencies, by default None (0. for all inputs)
+ adder_size : int, optional
+ Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
+ carry_size : int, optional
+ Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)

  Returns
  -------
- tuple[Callable[[list[T]], list[T]], str]
- fn : Callable[[list[T]], list[T]]
- The compiled python function. It takes a list of inputs and returns a list of outputs with only accumulation/subtraction/powers of 2 operations.
- fn_str : str
- The code of the compiled function, depending on the codegen_backend used.
+ CascadedSolution
+ A solution containing the optimized implementation of the CMVM computation with cascaded stages.
  """

- assert n_beams == 1, 'n_beams>1 is disabled for now. Change line 159 & 160 in this file to enable it.'
- if depths is None:
- depths = [0] * len(signs)
- states = compile_kernel(
- kernel=kernel,
- signs=signs,
- bits=bits,
- int_bits=int_bits,
- symmetrics=symmetrics,
- depths=depths,
- n_beams=n_beams,
- dc=dc,
- n_inp_max=n_inp_max,
- n_out_max=n_out_max,
- )
- with Namer().tmp_scope():
- inp, out = graph_compile_states(states, signed_balanced_reduction)
- fn, fn_str = codegen_backend(inp, out)
- return fn, fn_str
+ if hard_dc < 0:
+ hard_dc = int(1e9)
+
+ if method1 == 'auto':
+ if hard_dc >= 6 or method0.endswith('dc'):
+ method1 = method0
+ else:
+ method1 = method0 + '-dc'
+ if hard_dc == 0 and not method0.endswith('dc'):
+ method0 = method0 + '-dc'

+ if qintervals is None:
+ _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+ else:
+ _qintervals = list(qintervals)
+ if latencies is None:
+ _inp_latencies = [0.0] * kernel.shape[0]
+ else:
+ _inp_latencies = [float(lat) for lat in latencies]
+ assert len(_qintervals) == kernel.shape[0]
+ assert len(_inp_latencies) == kernel.shape[0]

- def cost(fn_str: str):
- n_add = fn_str.count('\n') - 3 - fn_str.count('out[')
- latency = m.findall(fn_str)[-1]
- return n_add, int(latency)
+ min_lat = minimal_latency(kernel, _qintervals, _inp_latencies, carry_size=carry_size, adder_size=adder_size)
+ latency_allowed = hard_dc + min_lat
+ if decompose_dc == -2:
+ decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
+ else:
+ decompose_dc = min(hard_dc, decompose_dc, ceil(log2(kernel.shape[0])))
+
+ while True:
+ if decompose_dc < 0 and hard_dc >= 0:
+ if method0 != 'dummy':
+ method0, method1 = 'wmc-dc', 'wmc-dc'
+ else:
+ method0, method1 = 'dummy', 'dummy'
+ mat0, mat1 = kernel_decompose(kernel, dc=decompose_dc)
+ sol0 = _solve(
+ mat0, method=method0, qintervals=_qintervals, latencies=_inp_latencies, adder_size=adder_size, carry_size=carry_size
+ )
+ latencies0 = [sol0.ops[i].latency if i >= 0 else 0.0 for i in sol0.out_idxs]
+ qintervals0 = [sol0.ops[i].qint if i >= 0 else QInterval(0.0, 0.0, np.inf) for i in sol0.out_idxs]
+ if max(latencies0) > latency_allowed:
+ if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
+ decompose_dc -= 1
+ continue
+ sol1 = _solve(
+ mat1, method=method1, qintervals=qintervals0, latencies=latencies0, adder_size=adder_size, carry_size=carry_size
+ )
+ latencies1 = [sol1.ops[i].latency if i >= 0 else 0.0 for i in sol1.out_idxs]
+ if max(latencies1) > latency_allowed:
+ # Prevent infinite loop, shouldn't happen though
+ if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
+ decompose_dc -= 1
+ continue
+ if sum([op.cost for op in sol1.ops]) * 4 > sum([op.cost for op in sol0.ops]) and decompose_dc > 0:
+ # If the second stage is too expensive, the decomposition usually doesn't worth it
+ decompose_dc -= 1
+ continue
+ break
+ if max(latencies1) > latency_allowed:
+ # When latency depends on the bw, may happen
+ print(f'Latency constraint not satisfied: {int(latency_allowed)} < {int(max(latencies1))}')
+ return CascadedSolution((sol0, sol1))
+
+
+ @jit(cache=True, parallel=True)
+ def solve(
+ kernel: np.ndarray,
+ method0: str = 'wmc',
+ method1: str = 'auto',
+ hard_dc: int = -1,
+ decompose_dc: int = -2,
+ qintervals: tuple[QInterval, ...] | None = None,
+ latencies: tuple[float, ...] | None = None,
+ adder_size: int = -1,
+ carry_size: int = -1,
+ search_all_decompose_dc: bool = True,
+ ) -> CascadedSolution:
+ """Solve the CMVM problem with cascaded two matrices.
+
+ Parameters
+ ----------
+ kernel : np.ndarray
+ The input kernel matrix to be implemented.
+ method0 : str, optional
+ Optimization method for the first stage. Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`].
+ method1 : str, optional
+ Optimization method for the second stage. When 'auto', it will select based on hard_dc and method0, by default 'auto'
+ hard_dc : int, optional
+ Hard depth constraint (additional latency allowed beyond minimal latency), by default -1 (no constraint)
+ decompose_dc : int, optional
+ Decomposition depth constraint, by default -1 (no constraint, follows hard_dc)
+ qintervals : list[QInterval] | None, optional
+ List of quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
+ inp_latencies : list[float] | None, optional
+ List of input latencies, by default None (0. for all inputs)
+ adder_size : int, optional
+ Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
+ carry_size : int, optional
+ Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)
+ search_all_decompose_dc : bool, optional
+ If True, search for all possible decomposition depth constraints. If False, use the provided decompose_dc value.
+ Default is True.
+
+ Returns
+ -------
+ CascadedSolution
+ A solution containing the optimized implementation of the CMVM computation with cascaded stages.
+ """
+
+ if qintervals is None:
+ _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+ else:
+ _qintervals = list(qintervals)
+ if latencies is None:
+ _latencies = [0.0] * kernel.shape[0]
+ else:
+ _latencies = [float(lat) for lat in latencies]
+
+ if not search_all_decompose_dc:
+ return jit_solve(
+ kernel,
+ method0=method0,
+ method1=method1,
+ hard_dc=hard_dc,
+ decompose_dc=decompose_dc,
+ qintervals=_qintervals,
+ latencies=_latencies,
+ adder_size=adder_size,
+ carry_size=carry_size,
+ )
+
+ if hard_dc < 0:
+ hard_dc = int(1e9)
+
+ max_decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
+ try_decompose_dcs = list(range(-1, max_decompose_dc + 1))
+
+ costs = np.empty(len(try_decompose_dcs), dtype=np.float64)
+
+ for i in prange(len(try_decompose_dcs)):
+ decompose_dc = try_decompose_dcs[i]
+ _csol = jit_solve(
+ kernel,
+ method0=method0,
+ method1=method1,
+ hard_dc=hard_dc,
+ decompose_dc=decompose_dc,
+ qintervals=_qintervals,
+ latencies=_latencies,
+ adder_size=adder_size,
+ carry_size=carry_size,
+ )
+ _cost = sum([sum([op.cost for op in sol.ops]) for sol in _csol.solutions])
+ costs[i] = _cost
+
+ decompose_dc = try_decompose_dcs[np.argmin(costs)]
+ csol = jit_solve(
+ kernel,
+ method0=method0,
+ method1=method1,
+ hard_dc=hard_dc,
+ decompose_dc=decompose_dc,
+ qintervals=_qintervals,
+ latencies=_latencies,
+ adder_size=adder_size,
+ carry_size=carry_size,
+ )
+ return csol
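
The docstrings above describe the new solver entry point. Below is a hedged usage sketch, assuming the documented defaults (a QInterval of (-128, 127, 1), i.e. 8-bit signed integer inputs, and zero latency per kernel row) and reading costs the same way jit_solve does via CascadedSolution.solutions. It is an illustration only, not code from the package, and whether the numba-jitted entry points accept plain Python lists at runtime is an assumption.

    import numpy as np
    from da4ml.cmvm import QInterval, minimal_latency, solve

    # Toy constant kernel: 16 inputs, 8 outputs
    kernel = np.random.randint(-8, 8, size=(16, 8)).astype(np.float64)
    qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]  # 8-bit signed inputs
    latencies = [0.0] * kernel.shape[0]                             # all inputs arrive at t=0

    # Lower bound on attainable latency under the default adder/carry model
    min_lat = minimal_latency(kernel, qintervals, latencies)

    # Two-stage CMVM solution; hard_dc=2 allows two extra adder levels beyond the minimum,
    # qintervals/latencies are left at their documented defaults
    csol = solve(kernel, method0='wmc', hard_dc=2)
    n_adders = sum(op.cost for op in csol.solutions[0].ops) + sum(op.cost for op in csol.solutions[1].ops)
    print(min_lat, n_adders)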
da4ml/cmvm/core/__init__.py ADDED
@@ -0,0 +1,222 @@
+ import heapq
+ from math import log2
+
+ import numpy as np
+ from numba import jit
+
+ from ..types import DAState, Op, QInterval, Solution
+ from .indexers import (
+ idx_mc,
+ idx_mc_dc,
+ idx_wmc,
+ idx_wmc_dc,
+ )
+ from .state_opr import cost_add, create_state, qint_add, update_state
+
+
+ @jit(cache=True)
+ def cmvm(
+ kernel: np.ndarray,
+ method: str = 'wmc',
+ qintervals: list[QInterval] | None = None,
+ inp_latencies: list[float] | None = None,
+ adder_size: int = -1,
+ carry_size: int = -1,
+ ) -> DAState:
+ """Optimizes the kernel using the CMVM algorithm.
+
+ Parameters
+ ----------
+ kernel : np.ndarray
+ The kernel to optimize.
+ method : str, optional
+ Which indexing method to use, by default 'wmc' (weighted most common)
+ Must be one of [`mc`, `mc-dc`, `mc-pdc`, `wmc`, `wmc-dc`, `wmc-pdc`, `dummy`]
+ qintervals : list[QInterval] | None, optional
+ List of QIntervals for each input, by default None
+ If None, defaults to [-128., 127., 1.] for each input.
+ inp_latencies : list[float] | None, optional
+ List of latencies for each input, by default None
+ If None, defaults to 0. for each input.
+ adder_size : int, optional
+ The atomic size of the adder for cost computation, by default -1
+ if -1, each adder can be arbitrary large, and the cost will be the number of adders
+ carry_size : int, optional
+ The size of the carry unit for latency computation, by default -1
+ if -1, each carry unit can be arbitrary large, and the cost will be the depth of the adder tree
+
+ Returns
+ -------
+ DAState
+ The optimized kernel as a DAState object.
+ """
+
+ if qintervals is None:
+ _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+ else:
+ _qintervals = [QInterval(*qi) for qi in qintervals]
+ if inp_latencies is None:
+ _inp_latencies = [0.0] * kernel.shape[0]
+ else:
+ _inp_latencies = [float(lat) for lat in inp_latencies]
+ assert len(_qintervals) == kernel.shape[0]
+ assert len(_inp_latencies) == kernel.shape[0]
+
+ state = create_state(kernel, _qintervals, _inp_latencies)
+ while True:
+ if len(state.freq_stat) == 0:
+ break
+ match method:
+ case 'mc':
+ pair_idx = idx_mc(state)
+ case 'mc-dc':
+ pair_idx = idx_mc_dc(state, absolute=True)
+ case 'mc-pdc':
+ pair_idx = idx_mc_dc(state, absolute=False)
+ case 'wmc':
+ pair_idx = idx_wmc(state)
+ case 'wmc-dc':
+ pair_idx = idx_wmc_dc(state, absolute=True)
+ case 'wmc-pdc':
+ pair_idx = idx_wmc_dc(state, absolute=False)
+ case 'dummy':
+ break
+ case _:
+ raise ValueError(f'Unknown method: {method}')
+ if pair_idx < 0:
+ break
+ pair_chosen = list(state.freq_stat.keys())[pair_idx]
+ state = update_state(state, pair_chosen, adder_size=adder_size, carry_size=carry_size)
+ return state
+
+
+ @jit(cache=True)
+ def to_solution(
+ state: DAState,
+ adder_size: int,
+ carry_size: int,
+ ):
+ """Converts the DAState to a Solution object with balanced tree reduction for the non-extracted bits in the kernel.
+
+ Parameters
+ ----------
+ state : DAState
+ The DAState to convert.
+ adder_size : int, optional
+ The atomic size of the adder for cost computation, by default -1
+ if -1, each adder can be arbitrary large, and the cost will be the number of adders
+ carry_size : int, optional
+ The size of the carry unit for latency computation, by default -1
+ if -1, each carry unit can be arbitrary large, and the cost will be the depth of the adder tree
+
+ Returns
+ -------
+ Solution
+ The Solution object with the optimized kernel.
+ """
+
+ ops = state.ops.copy()
+ n_out = state.kernel.shape[1]
+ expr = np.empty((len(state.expr), *state.expr[0].shape), dtype=np.int8)
+ for i, v in enumerate(state.expr):
+ expr[i] = v
+ in_shifts, out_shifts = state.shifts
+
+ out_qints = []
+ out_lats = []
+ out_idx = []
+ in_shift = in_shifts.copy()
+ out_shift = out_shifts.copy()
+ out_neg = []
+
+ _global_id = len(ops)
+ for i_out in range(n_out):
+ heap = []
+ idx, shifts = np.where(expr[:, i_out] != 0)
+ sub = np.empty(len(idx), dtype=np.int64)
+ for i, (i_in, shift) in enumerate(zip(idx, shifts)):
+ sub[i] = expr[i_in, i_out, shift] == -1
+
+ qints: list[QInterval] = [state.ops[i].qint for i in idx]
+ lats: list[float] = [state.ops[i].latency for i in idx]
+
+ # No reduction required, dump the realized value directly
+ if len(sub) == 1:
+ out_shift[i_out] = out_shift[i_out] + shifts[0]
+ out_qints.append(qints[0])
+ out_lats.append(lats[0])
+ out_idx.append(idx[0])
+ out_neg.append(sub[0])
+ continue
+ # Output is zero
+ if len(sub) == 0:
+ out_idx.append(-1) # -1 means output constant zero
+ out_qints.append(QInterval(0.0, 0.0, np.inf))
+ out_lats.append(0.0)
+ out_neg.append(False)
+ continue
+
+ # Sort by latency -> location of rightmost bit -> lower bound
+ left_align: list[int] = []
+ for i, qint in enumerate(qints):
+ n_int = int(log2(max(abs(qint.max + qint.step), abs(qint.min))))
+ left_align.append(n_int + shifts[i])
+ heap = list(zip(lats, sub, left_align, qints, idx, shifts))
+ heapq.heapify(heap)
+
+ while len(heap) > 1:
+ lat0, sub0, _, qint0, id0, shift0 = heapq.heappop(heap)
+ lat1, sub1, _, qint1, id1, shift1 = heapq.heappop(heap)
+
+ if sub0:
+ shift = shift0 - shift1
+ qint = qint_add(qint1, qint0, shift, sub1, sub0)
+ dlat, dcost = cost_add(qint1, qint0, shift=shift, sub=1 ^ sub1, adder_size=adder_size, carry_size=carry_size)
+ lat = max(lat0, lat1) + dlat
+ op = Op(id1, id0, 1 ^ sub1, shift, qint, lat, dcost)
+ shift = shift1
+ else:
+ shift = shift1 - shift0
+ qint = qint_add(qint0, qint1, shift, sub0, sub1)
+ dlat, dcost = cost_add(qint0, qint1, shift=shift, sub=sub1, adder_size=adder_size, carry_size=carry_size)
+ lat = max(lat0, lat1) + dlat
+ op = Op(id0, id1, sub1, shift, qint, lat, dcost)
+ shift = shift0
+
+ left_align = int(log2(max(abs(qint.max + qint.step), abs(qint.min)))) + shift
+ heapq.heappush(heap, (lat, sub0 & sub1, left_align, qint, _global_id, shift))
+ ops.append(op)
+ _global_id += 1
+
+ lat, sub, _, qint, id0, shift0 = heap[0]
+ out_idx.append(_global_id - 1)
+ out_qints.append(qint)
+ out_lats.append(lat)
+ out_neg.append(sub)
+ out_shift[i_out] = out_shift[i_out] + shift0
+
+ return Solution(
+ shape=state.kernel.shape, # type: ignore
+ inp_shift=list(in_shift),
+ out_idxs=out_idx,
+ out_shifts=list(out_shift),
+ out_negs=out_neg,
+ ops=ops,
+ carry_size=carry_size,
+ adder_size=adder_size,
+ )
+
+
+ @jit
+ def _solve(
+ kernel: np.ndarray,
+ method: str,
+ qintervals: list[QInterval],
+ latencies: list[float],
+ adder_size: int,
+ carry_size: int,
+ ):
+ state = cmvm(
+ kernel, method=method, qintervals=qintervals, inp_latencies=latencies, adder_size=adder_size, carry_size=carry_size
+ )
+ return to_solution(state, adder_size=adder_size, carry_size=carry_size)
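
_solve above is simply cmvm (greedy common-subexpression extraction, driven by one of the indexers) followed by to_solution (balanced-tree reduction of whatever remains in the kernel, with cost and latency bookkeeping). A minimal sketch of driving the two steps directly, assuming the module path da4ml.cmvm.core from the file list above and the defaults documented in cmvm; treat it as an illustration rather than the package's own example:

    import numpy as np
    from da4ml.cmvm.core import cmvm, to_solution

    # 2-input, 2-output constant kernel
    kernel = np.array([[3.0, 5.0], [6.0, -7.0]])

    # Greedy pass: 'wmc' picks the weighted most-common two-term pattern each iteration
    state = cmvm(kernel, method='wmc')  # defaults: QInterval(-128, 127, 1) inputs, latency 0

    # Balanced-tree reduction of the non-extracted bits; -1 keeps the abstract adder/carry model
    sol = to_solution(state, adder_size=-1, carry_size=-1)
    print(sum(op.cost for op in sol.ops), max(sol.ops[i].latency for i in sol.out_idxs if i >= 0))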