da4ml 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml might be problematic. Click here for more details.
- da4ml/__init__.py +16 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +3 -34
- da4ml/cmvm/api.py +239 -73
- da4ml/cmvm/core/__init__.py +222 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +569 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +11 -0
- da4ml/codegen/cpp/__init__.py +3 -0
- da4ml/codegen/cpp/cpp_codegen.py +148 -0
- da4ml/codegen/cpp/source/vitis.h +30 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
- da4ml/codegen/verilog/__init__.py +13 -0
- da4ml/codegen/verilog/comb.py +146 -0
- da4ml/codegen/verilog/io_wrapper.py +255 -0
- da4ml/codegen/verilog/pipeline.py +49 -0
- da4ml/codegen/verilog/source/build_binder.mk +27 -0
- da4ml/codegen/verilog/source/build_prj.tcl +75 -0
- da4ml/codegen/verilog/source/ioutils.hh +117 -0
- da4ml/codegen/verilog/source/shift_adder.v +56 -0
- da4ml/codegen/verilog/source/template.xdc +29 -0
- da4ml/codegen/verilog/verilog_model.py +265 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +358 -0
- da4ml/trace/fixed_variable_array.py +177 -0
- da4ml/trace/ops/__init__.py +55 -0
- da4ml/trace/ops/conv_utils.py +104 -0
- da4ml/trace/ops/einsum_utils.py +299 -0
- da4ml/trace/pipeline.py +155 -0
- da4ml/trace/tracer.py +120 -0
- da4ml-0.2.0.dist-info/METADATA +65 -0
- da4ml-0.2.0.dist-info/RECORD +39 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
- da4ml/cmvm/balanced_reduction.py +0 -46
- da4ml/cmvm/cmvm.py +0 -328
- da4ml/cmvm/codegen.py +0 -159
- da4ml/cmvm/csd.py +0 -73
- da4ml/cmvm/fixed_variable.py +0 -205
- da4ml/cmvm/graph_compile.py +0 -85
- da4ml/cmvm/nb_fixed_precision.py +0 -98
- da4ml/cmvm/scoring.py +0 -55
- da4ml/cmvm/utils.py +0 -5
- da4ml-0.1.2.dist-info/METADATA +0 -122
- da4ml-0.1.2.dist-info/RECORD +0 -18
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/__init__.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
|
1
|
-
from .cmvm.api import cost, fn_from_kernel
|
|
2
|
-
from .cmvm.cmvm import compile_kernel
|
|
3
|
-
from .cmvm.codegen import PyCodegenBackend, VitisCodegenBackend
|
|
4
|
-
from .cmvm.graph_compile import graph_compile_states
|
|
5
|
-
from .cmvm.utils import DAState, OpCode, Score
|
|
1
|
+
# from .cmvm.api import cost, fn_from_kernel
|
|
2
|
+
# from .cmvm.cmvm import compile_kernel
|
|
3
|
+
# from .cmvm.codegen import PyCodegenBackend, VitisCodegenBackend
|
|
4
|
+
# from .cmvm.graph_compile import graph_compile_states
|
|
5
|
+
# from .cmvm.utils import DAState, OpCode, Score
|
|
6
6
|
|
|
7
|
-
__all__ = [
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
]
|
|
7
|
+
# __all__ = [
|
|
8
|
+
# 'DAState',
|
|
9
|
+
# 'OpCode',
|
|
10
|
+
# 'Score',
|
|
11
|
+
# 'cost',
|
|
12
|
+
# 'compile_kernel',
|
|
13
|
+
# 'fn_from_kernel',
|
|
14
|
+
# 'graph_compile_states',
|
|
15
|
+
# 'PyCodegenBackend',
|
|
16
|
+
# 'VitisCodegenBackend',
|
|
17
|
+
# ]
|
da4ml/_version.py
CHANGED
da4ml/cmvm/__init__.py
CHANGED
|
@@ -1,35 +1,4 @@
|
|
|
1
|
-
import
|
|
1
|
+
from .api import minimal_latency, solve
|
|
2
|
+
from .types import Op, QInterval, Solution
|
|
2
3
|
|
|
3
|
-
|
|
4
|
-
from .cmvm import compile_kernel
|
|
5
|
-
from .codegen import PyCodegenBackend
|
|
6
|
-
from .graph_compile import graph_compile_states
|
|
7
|
-
from .utils import DAState, OpCode, Score
|
|
8
|
-
|
|
9
|
-
d_in = 2
|
|
10
|
-
d_out = 2
|
|
11
|
-
kernel = np.ones((d_in, d_out), dtype=np.float32)
|
|
12
|
-
signs = [False] * d_in
|
|
13
|
-
bits = [8] * d_in
|
|
14
|
-
int_bits = [0] * d_in
|
|
15
|
-
symmetrics = [False] * d_in
|
|
16
|
-
depths = [0] * d_in
|
|
17
|
-
|
|
18
|
-
print('The da4ml library is compiling. This will take a while...', end=' ')
|
|
19
|
-
_ = fn_from_kernel(
|
|
20
|
-
kernel=kernel,
|
|
21
|
-
signs=signs,
|
|
22
|
-
bits=bits,
|
|
23
|
-
int_bits=int_bits,
|
|
24
|
-
symmetrics=symmetrics,
|
|
25
|
-
depths=depths,
|
|
26
|
-
n_beams=1,
|
|
27
|
-
dc=None,
|
|
28
|
-
n_inp_max=-1,
|
|
29
|
-
n_out_max=-1,
|
|
30
|
-
codegen_backend=PyCodegenBackend(),
|
|
31
|
-
)
|
|
32
|
-
print('Done')
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
__all__ = ['DAState', 'OpCode', 'Score', 'cost', 'compile_kernel', 'fn_from_kernel', 'graph_compile_states']
|
|
4
|
+
__all__ = ['minimal_latency', 'solve', 'QInterval', 'Op', 'Solution']
|
da4ml/cmvm/api.py
CHANGED
|
@@ -1,91 +1,257 @@
|
|
|
1
|
-
import
|
|
2
|
-
import typing
|
|
3
|
-
from collections.abc import Callable
|
|
1
|
+
from math import ceil, log2
|
|
4
2
|
|
|
5
3
|
import numpy as np
|
|
4
|
+
from numba import jit, prange
|
|
6
5
|
|
|
7
|
-
from .
|
|
8
|
-
from .
|
|
9
|
-
from .
|
|
6
|
+
from .core import _solve, create_state, to_solution
|
|
7
|
+
from .types import CascadedSolution, QInterval
|
|
8
|
+
from .util import kernel_decompose
|
|
10
9
|
|
|
11
|
-
m = re.compile(r'Latency: (\d+)')
|
|
12
|
-
T = typing.TypeVar('T')
|
|
13
10
|
|
|
11
|
+
@jit(cache=True)
def minimal_latency(
    kernel: np.ndarray,
    qintervals: list[QInterval],
    latencies: list[float],
    carry_size: int = -1,
    adder_size: int = -1,
):
    """Quickly compute the latency reached for a kernel without running the optimizer.

    The state is created from the inputs and converted straight to a solution,
    so only the reduction performed by `to_solution` contributes.

    When carry_size=-1, and the input latency is constant `l`:
    this will be the same as `l + max(ceiling(log2(max(#CSD bits for each column, 1))))`.

    Parameters
    ----------
    kernel : np.ndarray
        The input kernel matrix.
    qintervals : list[QInterval]
        One QInterval per input.
    latencies : list[float]
        One latency per input.
    carry_size : int, optional
        The size of the carry unit for latency computation, by default -1 (fixed latency for each addition operation)
    adder_size : int, optional
        The size of the adder unit for latency computation, by default -1 (fixed cost for each addition operation)

    Returns
    -------
    float
        The largest output latency of the resulting solution.
    """

    # NOTE(review): no_stat_init=True presumably skips building the pair statistics
    # used by the optimizer -- confirm against create_state.
    initial_state = create_state(kernel, qintervals, latencies, no_stat_init=True)
    solution = to_solution(initial_state, adder_size=adder_size, carry_size=carry_size)

    # Outputs with a negative index have no producing op and count as latency 0.0.
    out_latencies = [solution.ops[idx].latency if idx >= 0 else 0.0 for idx in solution.out_idxs]
    return max(out_latencies)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@jit(cache=True)
def jit_solve(
    kernel: np.ndarray,
    method0: str = 'wmc',
    method1: str = 'auto',
    hard_dc: int = -1,
    decompose_dc: int = -2,
    qintervals: list[QInterval] | None = None,
    latencies: list[float] | None = None,
    adder_size: int = -1,
    carry_size: int = -1,
) -> CascadedSolution:
    """Optimized implementation of a CMVM computation with cascaded two matrices.

    The kernel is split into two matrices with `kernel_decompose`; each one is
    solved separately and the second stage consumes the output intervals and
    latencies of the first. `decompose_dc` is lowered and the search repeated
    until the latency limit is met or no further fallback is left.

    Parameters
    ----------
    kernel : np.ndarray
        The input kernel matrix to be implemented.
    method0 : str, optional
        Optimization method for the first stage, by default 'wmc'.
        Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`]; `dummy` is also checked for explicitly.
        When hard_dc is 0 and the name does not end with 'dc', '-dc' is appended.
    method1 : str, optional
        Optimization method for the second stage, by default 'auto'.
        When 'auto', it is method0 if hard_dc >= 6 (or unconstrained) or method0 ends with 'dc',
        and method0 + '-dc' otherwise.
    hard_dc : int, optional
        Hard depth constraint (additional latency allowed beyond minimal latency), by default -1 (no constraint)
    decompose_dc : int, optional
        Depth constraint passed to `kernel_decompose`, by default -2.
        The value -2 selects `min(hard_dc, ceil(log2(kernel.shape[0])))`; any other value is
        clipped to that same upper bound. While it is negative, both stages are forced to
        'wmc-dc' (or to 'dummy' when method0 is 'dummy').
    qintervals : list[QInterval] | None, optional
        List of quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
    latencies : list[float] | None, optional
        List of input latencies, by default None (0. for all inputs)
    adder_size : int, optional
        Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
    carry_size : int, optional
        Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)

    Returns
    -------
    CascadedSolution
        A solution containing the optimized implementation of the CMVM computation with cascaded stages.
        If the latency limit still cannot be met, a message is printed and the last solution is returned.
    """

    # A negative hard_dc means "unconstrained"; represent it as a very large budget.
    if hard_dc < 0:
        hard_dc = int(1e9)

    if method1 == 'auto':
        if hard_dc >= 6 or method0.endswith('dc'):
            method1 = method0
        else:
            method1 = method0 + '-dc'
    # NOTE(review): this check is taken to be independent of the 'auto' branch above -- confirm nesting.
    if hard_dc == 0 and not method0.endswith('dc'):
        method0 = method0 + '-dc'

    if qintervals is None:
        _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
    else:
        _qintervals = list(qintervals)
    if latencies is None:
        _inp_latencies = [0.0] * kernel.shape[0]
    else:
        _inp_latencies = [float(lat) for lat in latencies]
    assert len(_qintervals) == kernel.shape[0]
    assert len(_inp_latencies) == kernel.shape[0]

    # The latency budget is the un-optimized minimal latency plus the allowed slack.
    min_lat = minimal_latency(kernel, _qintervals, _inp_latencies, carry_size=carry_size, adder_size=adder_size)
    latency_allowed = hard_dc + min_lat
    if decompose_dc == -2:
        decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
    else:
        decompose_dc = min(hard_dc, decompose_dc, ceil(log2(kernel.shape[0])))

    while True:
        # hard_dc is non-negative at this point, so only decompose_dc decides this branch.
        if decompose_dc < 0 and hard_dc >= 0:
            if method0 != 'dummy':
                method0, method1 = 'wmc-dc', 'wmc-dc'
            else:
                method0, method1 = 'dummy', 'dummy'
        mat0, mat1 = kernel_decompose(kernel, dc=decompose_dc)
        sol0 = _solve(
            mat0, method=method0, qintervals=_qintervals, latencies=_inp_latencies, adder_size=adder_size, carry_size=carry_size
        )
        # Outputs of stage 0 (index -1 marks a missing op) become the inputs of stage 1.
        latencies0 = [sol0.ops[i].latency if i >= 0 else 0.0 for i in sol0.out_idxs]
        qintervals0 = [sol0.ops[i].qint if i >= 0 else QInterval(0.0, 0.0, np.inf) for i in sol0.out_idxs]
        if max(latencies0) > latency_allowed:
            # NOTE(review): the retry is taken to be conditional on this inner check -- confirm nesting.
            if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
                decompose_dc -= 1
                continue
        sol1 = _solve(
            mat1, method=method1, qintervals=qintervals0, latencies=latencies0, adder_size=adder_size, carry_size=carry_size
        )
        latencies1 = [sol1.ops[i].latency if i >= 0 else 0.0 for i in sol1.out_idxs]
        if max(latencies1) > latency_allowed:
            # Prevent infinite loop, shouldn't happen though
            if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
                decompose_dc -= 1
                continue
        if sum([op.cost for op in sol1.ops]) * 4 > sum([op.cost for op in sol0.ops]) and decompose_dc > 0:
            # If the second stage is too expensive, the decomposition is usually not worth it
            decompose_dc -= 1
            continue
        break
    if max(latencies1) > latency_allowed:
        # May happen when the latency depends on the bit width
        print(f'Latency constraint not satisfied: {int(latency_allowed)} < {int(max(latencies1))}')
    return CascadedSolution((sol0, sol1))
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@jit(cache=True, parallel=True)
def solve(
    kernel: np.ndarray,
    method0: str = 'wmc',
    method1: str = 'auto',
    hard_dc: int = -1,
    decompose_dc: int = -2,
    qintervals: tuple[QInterval, ...] | None = None,
    latencies: tuple[float, ...] | None = None,
    adder_size: int = -1,
    carry_size: int = -1,
    search_all_decompose_dc: bool = True,
) -> CascadedSolution:
    """Solve the CMVM problem with cascaded two matrices.

    Parameters
    ----------
    kernel : np.ndarray
        The input kernel matrix to be implemented.
    method0 : str, optional
        Optimization method for the first stage. Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`].
    method1 : str, optional
        Optimization method for the second stage. When 'auto', it will select based on hard_dc and method0, by default 'auto'
    hard_dc : int, optional
        Hard depth constraint (additional latency allowed beyond minimal latency), by default -1 (no constraint)
    decompose_dc : int, optional
        Decomposition depth constraint forwarded to `jit_solve`, by default -2.
        Only used when `search_all_decompose_dc` is False; otherwise every candidate value is tried instead.
    qintervals : tuple[QInterval, ...] | None, optional
        Quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
    latencies : tuple[float, ...] | None, optional
        Input latencies, by default None (0. for all inputs)
    adder_size : int, optional
        Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
    carry_size : int, optional
        Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)
    search_all_decompose_dc : bool, optional
        If True, try every decomposition depth constraint from -1 up to
        `min(hard_dc, ceil(log2(kernel.shape[0])))` and keep the cheapest one.
        If False, use the provided decompose_dc value.
        Default is True.

    Returns
    -------
    CascadedSolution
        A solution containing the optimized implementation of the CMVM computation with cascaded stages.
    """

    if qintervals is None:
        _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
    else:
        _qintervals = list(qintervals)
    if latencies is None:
        _latencies = [0.0] * kernel.shape[0]
    else:
        _latencies = [float(lat) for lat in latencies]

    if not search_all_decompose_dc:
        return jit_solve(
            kernel,
            method0=method0,
            method1=method1,
            hard_dc=hard_dc,
            decompose_dc=decompose_dc,
            qintervals=_qintervals,
            latencies=_latencies,
            adder_size=adder_size,
            carry_size=carry_size,
        )

    # A negative hard_dc means "unconstrained"; represent it as a very large budget.
    if hard_dc < 0:
        hard_dc = int(1e9)

    max_decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
    try_decompose_dcs = list(range(-1, max_decompose_dc + 1))

    costs = np.empty(len(try_decompose_dcs), dtype=np.float64)

    # Evaluate every candidate; only the total cost is kept, not the solution itself.
    for i in prange(len(try_decompose_dcs)):
        trial_dc = try_decompose_dcs[i]
        _csol = jit_solve(
            kernel,
            method0=method0,
            method1=method1,
            hard_dc=hard_dc,
            decompose_dc=trial_dc,
            qintervals=_qintervals,
            latencies=_latencies,
            adder_size=adder_size,
            carry_size=carry_size,
        )
        _cost = sum([sum([op.cost for op in sol.ops]) for sol in _csol.solutions])
        costs[i] = _cost

    # Solve once more with the cheapest candidate to obtain the solution to return.
    best_dc = try_decompose_dcs[np.argmin(costs)]
    csol = jit_solve(
        kernel,
        method0=method0,
        method1=method1,
        hard_dc=hard_dc,
        decompose_dc=best_dc,
        qintervals=_qintervals,
        latencies=_latencies,
        adder_size=adder_size,
        carry_size=carry_size,
    )
    return csol
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import heapq
|
|
2
|
+
from math import log2
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from numba import jit
|
|
6
|
+
|
|
7
|
+
from ..types import DAState, Op, QInterval, Solution
|
|
8
|
+
from .indexers import (
|
|
9
|
+
idx_mc,
|
|
10
|
+
idx_mc_dc,
|
|
11
|
+
idx_wmc,
|
|
12
|
+
idx_wmc_dc,
|
|
13
|
+
)
|
|
14
|
+
from .state_opr import cost_add, create_state, qint_add, update_state
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@jit(cache=True)
def cmvm(
    kernel: np.ndarray,
    method: str = 'wmc',
    qintervals: list[QInterval] | None = None,
    inp_latencies: list[float] | None = None,
    adder_size: int = -1,
    carry_size: int = -1,
) -> DAState:
    """Run the CMVM optimization loop on a kernel.

    Starting from a freshly created state, a pair is selected with the chosen
    indexer and applied via `update_state`, until the state has no entries
    left in `freq_stat` or the indexer returns a negative index.

    Parameters
    ----------
    kernel : np.ndarray
        The kernel to optimize.
    method : str, optional
        Which indexing method to use, by default 'wmc' (weighted most common)
        Must be one of [`mc`, `mc-dc`, `mc-pdc`, `wmc`, `wmc-dc`, `wmc-pdc`, `dummy`]
    qintervals : list[QInterval] | None, optional
        List of QIntervals for each input, by default None
        If None, defaults to [-128., 127., 1.] for each input.
    inp_latencies : list[float] | None, optional
        List of latencies for each input, by default None
        If None, defaults to 0. for each input.
    adder_size : int, optional
        The atomic size of the adder for cost computation, by default -1
        if -1, each adder can be arbitrary large, and the cost will be the number of adders
    carry_size : int, optional
        The size of the carry unit for latency computation, by default -1
        if -1, each carry unit can be arbitrary large, and the cost will be the depth of the adder tree

    Returns
    -------
    DAState
        The optimized kernel as a DAState object.
    """

    n_in = kernel.shape[0]

    if qintervals is None:
        _qintervals = [QInterval(-128.0, 127.0, 1.0)] * n_in
    else:
        _qintervals = [QInterval(*qi) for qi in qintervals]

    if inp_latencies is None:
        _inp_latencies = [0.0] * n_in
    else:
        _inp_latencies = [float(lat) for lat in inp_latencies]

    assert len(_qintervals) == kernel.shape[0]
    assert len(_inp_latencies) == kernel.shape[0]

    state = create_state(kernel, _qintervals, _inp_latencies)

    # The method is only inspected while there is still something in freq_stat.
    while len(state.freq_stat) != 0:
        if method == 'dummy':
            # 'dummy' performs no extraction at all.
            break

        if method == 'mc':
            pair_idx = idx_mc(state)
        elif method == 'mc-dc':
            pair_idx = idx_mc_dc(state, absolute=True)
        elif method == 'mc-pdc':
            pair_idx = idx_mc_dc(state, absolute=False)
        elif method == 'wmc':
            pair_idx = idx_wmc(state)
        elif method == 'wmc-dc':
            pair_idx = idx_wmc_dc(state, absolute=True)
        elif method == 'wmc-pdc':
            pair_idx = idx_wmc_dc(state, absolute=False)
        else:
            raise ValueError(f'Unknown method: {method}')

        # A negative index from the indexer ends the loop.
        if pair_idx < 0:
            break

        candidates = list(state.freq_stat.keys())
        pair_chosen = candidates[pair_idx]
        state = update_state(state, pair_chosen, adder_size=adder_size, carry_size=carry_size)

    return state
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@jit(cache=True)
def to_solution(
    state: DAState,
    adder_size: int,
    carry_size: int,
):
    """Converts the DAState to a Solution object with balanced tree reduction for the non-extracted bits in the kernel.

    For every output column, the remaining non-zero terms are put on a heap
    keyed by latency first; the two smallest entries are repeatedly popped and
    merged into a new Op until one entry is left.

    Parameters
    ----------
    state : DAState
        The DAState to convert.
    adder_size : int
        The atomic size of the adder for cost computation
        if -1, each adder can be arbitrary large, and the cost will be the number of adders
    carry_size : int
        The size of the carry unit for latency computation
        if -1, each carry unit can be arbitrary large, and the cost will be the depth of the adder tree

    Returns
    -------
    Solution
        The Solution object with the optimized kernel.
    """

    ops = state.ops.copy()
    n_out = state.kernel.shape[1]

    # Stack the per-op expression arrays into one int8 array indexed as [op, output, shift].
    expr = np.empty((len(state.expr), *state.expr[0].shape), dtype=np.int8)
    for k, plane in enumerate(state.expr):
        expr[k] = plane

    in_shifts, out_shifts = state.shifts
    in_shift = in_shifts.copy()
    out_shift = out_shifts.copy()

    # Per-output results. out_qints and out_lats are collected but are not passed to Solution.
    out_idx = []
    out_qints = []
    out_lats = []
    out_neg = []

    next_id = len(ops)
    for i_out in range(n_out):
        idx, shifts = np.where(expr[:, i_out] != 0)

        # sub[k] is 1 when the k-th remaining term enters with a -1 digit.
        sub = np.empty(len(idx), dtype=np.int64)
        for k, (i_in, bit) in enumerate(zip(idx, shifts)):
            sub[k] = expr[i_in, i_out, bit] == -1

        qints: list[QInterval] = [state.ops[i].qint for i in idx]
        lats: list[float] = [state.ops[i].latency for i in idx]

        # No reduction required, dump the realized value directly
        if len(sub) == 1:
            out_shift[i_out] = out_shift[i_out] + shifts[0]
            out_qints.append(qints[0])
            out_lats.append(lats[0])
            out_idx.append(idx[0])
            out_neg.append(sub[0])
            continue

        # Output is zero
        if len(sub) == 0:
            out_idx.append(-1)  # -1 means output constant zero
            out_qints.append(QInterval(0.0, 0.0, np.inf))
            out_lats.append(0.0)
            out_neg.append(False)
            continue

        # Heap entries compare as tuples: latency, then sign flag, then the alignment value.
        left_align: list[int] = []
        for k, term_qint in enumerate(qints):
            n_int = int(log2(max(abs(term_qint.max + term_qint.step), abs(term_qint.min))))
            left_align.append(n_int + shifts[k])
        heap = list(zip(lats, sub, left_align, qints, idx, shifts))
        heapq.heapify(heap)

        while len(heap) > 1:
            lat0, sub0, _, qint0, id0, shift0 = heapq.heappop(heap)
            lat1, sub1, _, qint1, id1, shift1 = heapq.heappop(heap)

            if sub0:
                # The first popped term is negated: make the second one the left operand.
                shift = shift0 - shift1
                qint = qint_add(qint1, qint0, shift, sub1, sub0)
                dlat, dcost = cost_add(qint1, qint0, shift=shift, sub=1 ^ sub1, adder_size=adder_size, carry_size=carry_size)
                lat = max(lat0, lat1) + dlat
                op = Op(id1, id0, 1 ^ sub1, shift, qint, lat, dcost)
                shift = shift1
            else:
                shift = shift1 - shift0
                qint = qint_add(qint0, qint1, shift, sub0, sub1)
                dlat, dcost = cost_add(qint0, qint1, shift=shift, sub=sub1, adder_size=adder_size, carry_size=carry_size)
                lat = max(lat0, lat1) + dlat
                op = Op(id0, id1, sub1, shift, qint, lat, dcost)
                shift = shift0

            # The merged term stays negated only when both operands were.
            merged_align = int(log2(max(abs(qint.max + qint.step), abs(qint.min)))) + shift
            heapq.heappush(heap, (lat, sub0 & sub1, merged_align, qint, next_id, shift))
            ops.append(op)
            next_id += 1

        # The single entry left on the heap is the op appended last.
        lat, neg, _, qint, id0, shift0 = heap[0]
        out_idx.append(next_id - 1)
        out_qints.append(qint)
        out_lats.append(lat)
        out_neg.append(neg)
        out_shift[i_out] = out_shift[i_out] + shift0

    return Solution(
        shape=state.kernel.shape,  # type: ignore
        inp_shift=list(in_shift),
        out_idxs=out_idx,
        out_shifts=list(out_shift),
        out_negs=out_neg,
        ops=ops,
        carry_size=carry_size,
        adder_size=adder_size,
    )
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
@jit
def _solve(
    kernel: np.ndarray,
    method: str,
    qintervals: list[QInterval],
    latencies: list[float],
    adder_size: int,
    carry_size: int,
):
    """Optimize one kernel with `cmvm` and convert the resulting state with `to_solution`.

    Parameters
    ----------
    kernel : np.ndarray
        The kernel to optimize.
    method : str
        Indexing method forwarded to `cmvm`.
    qintervals : list[QInterval]
        QIntervals for each input.
    latencies : list[float]
        Latencies for each input, forwarded to `cmvm` as `inp_latencies`.
    adder_size : int
        Adder size used by both `cmvm` and `to_solution`.
    carry_size : int
        Carry size used by both `cmvm` and `to_solution`.

    Returns
    -------
    The value returned by `to_solution` for the optimized state.
    """
    optimized_state = cmvm(
        kernel,
        method=method,
        qintervals=qintervals,
        inp_latencies=latencies,
        adder_size=adder_size,
        carry_size=carry_size,
    )
    solution = to_solution(optimized_state, adder_size=adder_size, carry_size=carry_size)
    return solution
|