da4ml-0.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +194 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +240 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.0.dist-info/METADATA +85 -0
- da4ml-0.5.0.dist-info/RECORD +96 -0
- da4ml-0.5.0.dist-info/WHEEL +6 -0
- da4ml-0.5.0.dist-info/entry_points.txt +3 -0
- da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/cmvm/api.py
ADDED
@@ -0,0 +1,264 @@
+from math import ceil, log2
+from typing import TypedDict
+
+import numpy as np
+from numba import jit, prange
+
+from .core import _solve, create_state, to_solution
+from .types import Pipeline, QInterval
+from .util import kernel_decompose
+
+
+@jit(cache=True)
+def minimal_latency(
+    kernel: np.ndarray,
+    qintervals: list[QInterval],
+    latencies: list[float],
+    carry_size: int = -1,
+    adder_size: int = -1,
+):
+    """Fast latency calculation for a given kernel, QInterval, and input latencies.
+    When carry_size=-1 and the input latency is a constant `l`,
+    this will be the same as `l + max(ceil(log2(max(#CSD bits for each column, 1))))`.
+
+    Parameters
+    ----------
+    kernel : np.ndarray
+        The input kernel matrix.
+    qintervals : list[QInterval]
+        List of QIntervals for each input.
+    latencies : list[float]
+        List of latencies for each input.
+    carry_size : int, optional
+        The size of the carry unit for latency computation, by default -1 (fixed latency for each addition operation)
+    adder_size : int, optional
+        The size of the adder unit for latency computation, by default -1 (fixed cost for each addition operation)
+
+    Returns
+    -------
+    float
+        The minimal latency for the given kernel, QInterval, and input latencies.
+    """
+
+    state = create_state(kernel, qintervals, latencies, no_stat_init=True)
+    sol = to_solution(state, adder_size=adder_size, carry_size=carry_size)
+    latencies = [sol.ops[i].latency if i >= 0 else 0.0 for i in sol.out_idxs]
+    return max(latencies)
+
+
+@jit(cache=True)
+def jit_solve(
+    kernel: np.ndarray,
+    method0: str = 'wmc',
+    method1: str = 'auto',
+    hard_dc: int = -1,
+    decompose_dc: int = -2,
+    qintervals: list[QInterval] | None = None,
+    latencies: list[float] | None = None,
+    adder_size: int = -1,
+    carry_size: int = -1,
+) -> Pipeline:
+    """Optimized implementation of a CMVM computation with two cascaded matrices.
+
+    Parameters
+    ----------
+    kernel : np.ndarray
+        The input kernel matrix to be implemented.
+    method0 : str, optional
+        Optimization method for the first stage. Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`].
+    method1 : str, optional
+        Optimization method for the second stage. When 'auto', it is selected based on hard_dc and method0, by default 'auto'
+    hard_dc : int, optional
+        Hard depth constraint (additional latency allowed beyond the minimal latency), by default -1 (no constraint)
+    decompose_dc : int, optional
+        Decomposition depth constraint, by default -2 (automatic, derived from hard_dc)
+    qintervals : list[QInterval] | None, optional
+        List of quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
+    latencies : list[float] | None, optional
+        List of input latencies, by default None (0. for all inputs)
+    adder_size : int, optional
+        Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
+    carry_size : int, optional
+        Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)
+
+    Returns
+    -------
+    Pipeline
+        A solution containing the optimized implementation of the CMVM computation with cascaded stages.
+    """
+
+    if hard_dc < 0:
+        hard_dc = int(1e9)
+
+    if method1 == 'auto':
+        if hard_dc >= 6 or method0.endswith('dc'):
+            method1 = method0
+        else:
+            method1 = method0 + '-dc'
+        if hard_dc == 0 and not method0.endswith('dc'):
+            method0 = method0 + '-dc'
+
+    if qintervals is None:
+        _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+    else:
+        _qintervals = list(qintervals)
+    if latencies is None:
+        _inp_latencies = [0.0] * kernel.shape[0]
+    else:
+        _inp_latencies = [float(lat) for lat in latencies]
+    assert len(_qintervals) == kernel.shape[0]
+    assert len(_inp_latencies) == kernel.shape[0]
+
+    min_lat = minimal_latency(kernel, _qintervals, _inp_latencies, carry_size=carry_size, adder_size=adder_size)
+    latency_allowed = hard_dc + min_lat
+    if decompose_dc == -2:
+        decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
+    else:
+        decompose_dc = min(hard_dc, decompose_dc, ceil(log2(kernel.shape[0])))
+
+    while True:
+        if decompose_dc < 0 and hard_dc >= 0:
+            if method0 != 'dummy':
+                method0, method1 = 'wmc-dc', 'wmc-dc'
+            else:
+                method0, method1 = 'dummy', 'dummy'
+        mat0, mat1 = kernel_decompose(kernel, dc=decompose_dc)
+        sol0 = _solve(
+            mat0, method=method0, qintervals=_qintervals, latencies=_inp_latencies, adder_size=adder_size, carry_size=carry_size
+        )
+        latencies0 = [sol0.ops[i].latency if i >= 0 else 0.0 for i in sol0.out_idxs]
+        qintervals0 = [sol0.ops[i].qint if i >= 0 else QInterval(0.0, 0.0, np.inf) for i in sol0.out_idxs]
+        if max(latencies0) > latency_allowed:
+            if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
+                decompose_dc -= 1
+                continue
+        sol1 = _solve(
+            mat1, method=method1, qintervals=qintervals0, latencies=latencies0, adder_size=adder_size, carry_size=carry_size
+        )
+        latencies1 = [sol1.ops[i].latency if i >= 0 else 0.0 for i in sol1.out_idxs]
+        if max(latencies1) > latency_allowed:
+            # Prevent infinite loop, shouldn't happen though
+            if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
+                decompose_dc -= 1
+                continue
+        break
+    if max(latencies1) > latency_allowed:
+        # When latency depends on the bw, may happen
+        print(f'Latency constraint not satisfied: {int(latency_allowed)} < {int(max(latencies1))}')
+    return Pipeline((sol0, sol1))
+
+
+@jit(cache=True, parallel=True)
+def solve(
+    kernel: np.ndarray,
+    method0: str = 'wmc',
+    method1: str = 'auto',
+    hard_dc: int = -1,
+    decompose_dc: int = -2,
+    qintervals: list[QInterval] | None = None,
+    latencies: list[float] | None = None,
+    adder_size: int = -1,
+    carry_size: int = -1,
+    search_all_decompose_dc: bool = True,
+) -> Pipeline:
+    """Solve the CMVM problem with two cascaded matrices.
+
+    Parameters
+    ----------
+    kernel : np.ndarray
+        The input kernel matrix to be implemented.
+    method0 : str, optional
+        Optimization method for the first stage. Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`].
+    method1 : str, optional
+        Optimization method for the second stage. When 'auto', it is selected based on hard_dc and method0, by default 'auto'
+    hard_dc : int, optional
+        Hard depth constraint (additional latency allowed beyond the minimal latency), by default -1 (no constraint)
+    decompose_dc : int, optional
+        Decomposition depth constraint, by default -2 (automatic, derived from hard_dc)
+    qintervals : list[QInterval] | None, optional
+        List of quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
+    latencies : list[float] | None, optional
+        List of input latencies, by default None (0. for all inputs)
+    adder_size : int, optional
+        Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
+    carry_size : int, optional
+        Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)
+    search_all_decompose_dc : bool, optional
+        If True, search all feasible decomposition depth constraints and keep the lowest-cost one.
+        If False, use the provided decompose_dc value. Default is True.
+
+    Returns
+    -------
+    Pipeline
+        A solution containing the optimized implementation of the CMVM computation with cascaded stages.
+    """
+
+    if qintervals is None:
+        _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+    else:
+        _qintervals = list(qintervals)
+    if latencies is None:
+        _latencies = [0.0] * kernel.shape[0]
+    else:
+        _latencies = [float(lat) for lat in latencies]
+
+    if not search_all_decompose_dc:
+        return jit_solve(
+            kernel,
+            method0=method0,
+            method1=method1,
+            hard_dc=hard_dc,
+            decompose_dc=decompose_dc,
+            qintervals=_qintervals,
+            latencies=_latencies,
+            adder_size=adder_size,
+            carry_size=carry_size,
+        )
+
+    if hard_dc < 0:
+        hard_dc = int(1e9)
+
+    max_decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
+    try_decompose_dcs = list(range(-1, max_decompose_dc + 1))
+
+    costs = np.empty(len(try_decompose_dcs), dtype=np.float64)
+
+    for i in prange(len(try_decompose_dcs)):
+        decompose_dc = try_decompose_dcs[i]
+        _csol = jit_solve(
+            kernel,
+            method0=method0,
+            method1=method1,
+            hard_dc=hard_dc,
+            decompose_dc=decompose_dc,
+            qintervals=_qintervals,
+            latencies=_latencies,
+            adder_size=adder_size,
+            carry_size=carry_size,
+        )
+        _cost = sum([sum([op.cost for op in sol.ops]) for sol in _csol.solutions])
+        costs[i] = _cost
+
+    decompose_dc = try_decompose_dcs[np.argmin(costs)]
+    csol = jit_solve(
+        kernel,
+        method0=method0,
+        method1=method1,
+        hard_dc=hard_dc,
+        decompose_dc=decompose_dc,
+        qintervals=_qintervals,
+        latencies=_latencies,
+        adder_size=adder_size,
+        carry_size=carry_size,
+    )
+    return csol
+
+
+class solver_options_t(TypedDict, total=False):
+    method0: str
+    method1: str
+    hard_dc: int
+    decompose_dc: int
+    adder_size: int
+    carry_size: int
+    search_all_decompose_dc: bool
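
The `solve` entry point above is the main CMVM API in this file. As a usage sketch only (not part of the package diff): assuming `solve` is importable from `da4ml.cmvm.api` and `QInterval` from `da4ml.cmvm.types` as laid out in the file list, a toy kernel could be optimized roughly as follows; the matrix values and options are illustrative.

import numpy as np

from da4ml.cmvm.api import solve
from da4ml.cmvm.types import QInterval

# A small 3-input, 2-output constant matrix; the values are arbitrary.
kernel = np.array([[3.0, -5.0], [7.0, 1.0], [2.0, 6.0]])

# 8-bit signed inputs arriving at t=0 (the same defaults solve() falls back to).
qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
latencies = [0.0] * kernel.shape[0]

pipeline = solve(kernel, method0='wmc', hard_dc=-1, qintervals=qintervals, latencies=latencies)

# Aggregate adder cost and output latency, mirroring how solve() ranks decompositions.
cost = sum(sum(op.cost for op in sol.ops) for sol in pipeline.solutions)
last = pipeline.solutions[-1]
latency = max(last.ops[i].latency if i >= 0 else 0.0 for i in last.out_idxs)
print(cost, latency)
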
da4ml/cmvm/core/__init__.py
ADDED
@@ -0,0 +1,221 @@
+import heapq
+from math import log2
+
+import numpy as np
+from numba import jit
+
+from ..types import CombLogic, DAState, Op, QInterval
+from .indexers import (
+    idx_mc,
+    idx_mc_dc,
+    idx_wmc,
+    idx_wmc_dc,
+)
+from .state_opr import cost_add, create_state, qint_add, update_state
+
+
+@jit(cache=True)
+def cmvm(
+    kernel: np.ndarray,
+    method: str = 'wmc',
+    qintervals: list[QInterval] | None = None,
+    inp_latencies: list[float] | None = None,
+    adder_size: int = -1,
+    carry_size: int = -1,
+) -> DAState:
+    """Optimizes the kernel using the CMVM algorithm.
+
+    Parameters
+    ----------
+    kernel : np.ndarray
+        The kernel to optimize.
+    method : str, optional
+        Which indexing method to use, by default 'wmc' (weighted most common)
+        Must be one of [`mc`, `mc-dc`, `mc-pdc`, `wmc`, `wmc-dc`, `wmc-pdc`, `dummy`]
+    qintervals : list[QInterval] | None, optional
+        List of QIntervals for each input, by default None
+        If None, defaults to [-128., 127., 1.] for each input.
+    inp_latencies : list[float] | None, optional
+        List of latencies for each input, by default None
+        If None, defaults to 0. for each input.
+    adder_size : int, optional
+        The atomic size of the adder for cost computation, by default -1
+        if -1, each adder can be arbitrarily large, and the cost will be the number of adders
+    carry_size : int, optional
+        The size of the carry unit for latency computation, by default -1
+        if -1, each carry unit can be arbitrarily large, and the latency will be the depth of the adder tree
+
+    Returns
+    -------
+    DAState
+        The optimized kernel as a DAState object.
+    """
+
+    if qintervals is None:
+        _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+    else:
+        _qintervals = [QInterval(*qi) for qi in qintervals]
+    if inp_latencies is None:
+        _inp_latencies = [0.0] * kernel.shape[0]
+    else:
+        _inp_latencies = [float(lat) for lat in inp_latencies]
+    assert len(_qintervals) == kernel.shape[0]
+    assert len(_inp_latencies) == kernel.shape[0]
+
+    state = create_state(kernel, _qintervals, _inp_latencies)
+    while True:
+        if len(state.freq_stat) == 0:
+            break
+        match method:
+            case 'mc':
+                pair_idx = idx_mc(state)
+            case 'mc-dc':
+                pair_idx = idx_mc_dc(state, absolute=True)
+            case 'mc-pdc':
+                pair_idx = idx_mc_dc(state, absolute=False)
+            case 'wmc':
+                pair_idx = idx_wmc(state)
+            case 'wmc-dc':
+                pair_idx = idx_wmc_dc(state, absolute=True)
+            case 'wmc-pdc':
+                pair_idx = idx_wmc_dc(state, absolute=False)
+            case 'dummy':
+                break
+            case _:
+                raise ValueError(f'Unknown method: {method}')
+        if pair_idx < 0:
+            break
+        pair_chosen = list(state.freq_stat.keys())[pair_idx]
+        state = update_state(state, pair_chosen, adder_size=adder_size, carry_size=carry_size)
+    return state
+
+
+@jit(cache=True)
+def to_solution(
+    state: DAState,
+    adder_size: int,
+    carry_size: int,
+):
+    """Converts the DAState to a CombLogic object with balanced tree reduction for the non-extracted bits in the kernel.
+
+    Parameters
+    ----------
+    state : DAState
+        The DAState to convert.
+    adder_size : int, optional
+        The atomic size of the adder for cost computation, by default -1
+        if -1, each adder can be arbitrarily large, and the cost will be the number of adders
+    carry_size : int, optional
+        The size of the carry unit for latency computation, by default -1
+        if -1, each carry unit can be arbitrarily large, and the latency will be the depth of the adder tree
+
+    Returns
+    -------
+    CombLogic
+        The CombLogic object with the optimized kernel.
+    """
+
+    ops = state.ops.copy()
+    n_out = state.kernel.shape[1]
+    expr = np.empty((len(state.expr), *state.expr[0].shape), dtype=np.int8)
+    for i, v in enumerate(state.expr):
+        expr[i] = v
+    in_shifts, out_shifts = state.shifts
+
+    out_qints = []
+    out_lats = []
+    out_idx = []
+    in_shift = in_shifts.copy()
+    out_shift = out_shifts.copy()
+    out_neg = []
+
+    _global_id = len(ops)
+    for i_out in range(n_out):
+        idx, shifts = np.where(expr[:, i_out] != 0)
+        sub = np.empty(len(idx), dtype=np.int64)
+        for i, (i_in, shift) in enumerate(zip(idx, shifts)):
+            sub[i] = expr[i_in, i_out, shift] == -1
+
+        qints: list[QInterval] = [state.ops[i].qint for i in idx]
+        lats: list[float] = [state.ops[i].latency for i in idx]
+
+        # No reduction required, dump the realized value directly
+        if len(sub) == 1:
+            out_shift[i_out] = out_shift[i_out] + shifts[0]
+            out_qints.append(qints[0])
+            out_lats.append(lats[0])
+            out_idx.append(idx[0])
+            out_neg.append(sub[0])
+            continue
+        # Output is zero
+        if len(sub) == 0:
+            out_idx.append(-1)  # -1 means output constant zero
+            out_qints.append(QInterval(0.0, 0.0, np.inf))
+            out_lats.append(0.0)
+            out_neg.append(False)
+            continue
+
+        # Sort by latency -> location of rightmost bit -> lower bound
+        left_align: list[int] = []
+        for i, qint in enumerate(qints):
+            n_int = int(log2(max(abs(qint.max + qint.step), abs(qint.min))))
+            left_align.append(n_int + shifts[i])
+        heap = list(zip(lats, sub, left_align, qints, idx, shifts))
+        heapq.heapify(heap)
+
+        while len(heap) > 1:
+            lat0, sub0, _, qint0, id0, shift0 = heapq.heappop(heap)
+            lat1, sub1, _, qint1, id1, shift1 = heapq.heappop(heap)
+
+            if sub0:
+                shift = shift0 - shift1
+                qint = qint_add(qint1, qint0, shift, sub1, sub0)
+                dlat, dcost = cost_add(qint1, qint0, shift=shift, sub=1 ^ sub1, adder_size=adder_size, carry_size=carry_size)
+                lat = max(lat0, lat1) + dlat
+                op = Op(id1, id0, 1 ^ sub1, shift, qint, lat, dcost)
+                shift = shift1
+            else:
+                shift = shift1 - shift0
+                qint = qint_add(qint0, qint1, shift, sub0, sub1)
+                dlat, dcost = cost_add(qint0, qint1, shift=shift, sub=sub1, adder_size=adder_size, carry_size=carry_size)
+                lat = max(lat0, lat1) + dlat
+                op = Op(id0, id1, sub1, shift, qint, lat, dcost)
+                shift = shift0
+
+            left_align = int(log2(max(abs(qint.max + qint.step), abs(qint.min)))) + shift
+            heapq.heappush(heap, (lat, sub0 & sub1, left_align, qint, _global_id, shift))
+            ops.append(op)
+            _global_id += 1
+
+        lat, sub, _, qint, id0, shift0 = heap[0]
+        out_idx.append(_global_id - 1)
+        out_qints.append(qint)
+        out_lats.append(lat)
+        out_neg.append(sub)
+        out_shift[i_out] = out_shift[i_out] + shift0
+
+    return CombLogic(
+        shape=state.kernel.shape,  # type: ignore
+        inp_shifts=list(in_shift),
+        out_idxs=out_idx,
+        out_shifts=list(out_shift),
+        out_negs=out_neg,
+        ops=ops,
+        carry_size=carry_size,
+        adder_size=adder_size,
+    )
+
+
+@jit
+def _solve(
+    kernel: np.ndarray,
+    method: str,
+    qintervals: list[QInterval],
+    latencies: list[float],
+    adder_size: int,
+    carry_size: int,
+):
+    state = cmvm(
+        kernel, method=method, qintervals=qintervals, inp_latencies=latencies, adder_size=adder_size, carry_size=carry_size
+    )
+    return to_solution(state, adder_size=adder_size, carry_size=carry_size)
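
The balanced tree reduction that `to_solution` performs over the non-extracted terms can be shown in isolation. The sketch below is illustrative only (not package code) and keeps just the latency bookkeeping of the heap loop above: the two earliest-ready operands are combined first, each addition adds one unit of latency (the carry_size=-1 regime), and the partial sum is pushed back until a single operand remains.

import heapq


def balanced_reduce(latencies: list[float]) -> float:
    """Latency at which the sum of all operands becomes available."""
    heap = list(latencies)
    heapq.heapify(heap)
    while len(heap) > 1:
        lat0 = heapq.heappop(heap)          # earliest-ready operand
        lat1 = heapq.heappop(heap)          # second earliest
        heapq.heappush(heap, max(lat0, lat1) + 1.0)  # one adder of fixed latency
    return heap[0]


# Eight operands ready at t=0 finish at depth ceil(log2(8)) = 3.
print(balanced_reduce([0.0] * 8))             # 3.0
# With skewed arrival times the late operand is merged last, keeping the depth low.
print(balanced_reduce([0.0, 0.0, 0.0, 2.0]))  # 3.0
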
da4ml/cmvm/core/indexers.py
ADDED
@@ -0,0 +1,83 @@
+from math import ceil, log2
+
+import numpy as np
+from numba import jit
+
+from ..types import DAState, QInterval
+
+
+@jit
+def idx_mc(state: DAState):
+    """Choose the pair with the highest frequency."""
+    freqs = list(state.freq_stat.values())
+    max_freq = max(freqs)
+    pair_idx = freqs.index(max_freq)
+    return pair_idx
+
+
+@jit
+def idx_mc_dc(state: DAState, absolute: bool = False):
+    """Choose the pair with the highest frequency with a latency penalty.
+    If absolute is True, return -1 if any latency overhead may be incurred."""
+    freqs = list(state.freq_stat.values())
+    factor = max(freqs) + 1
+    ops = state.ops
+    lat_penalty = [abs(ops[pair.id1].latency - ops[pair.id0].latency) * factor for pair in state.freq_stat.keys()]
+    score = [freq - lat_penalty[i] for i, freq in enumerate(freqs)]
+    max_score = max(score)
+    if absolute and max_score < 0:
+        return -1
+    pair_idx = score.index(max_score)
+    return pair_idx
+
+
+@jit
+def overlap_and_accum(qint0: QInterval, qint1: QInterval):
+    """Calculate the overlap and total number of bits for two QIntervals, when represented in fixed-point format."""
+
+    min0, max0, step0 = qint0
+    min1, max1, step1 = qint1
+    max0, max1 = max0 + step0, max1 + step1
+
+    f = -log2(max(step0, step1))
+    i_high = ceil(log2(max(abs(min0), abs(min1), abs(max0), abs(max1))))
+    i_low = ceil(log2(min(max(abs(min0), abs(max0)), max(abs(min1), abs(max1)))))
+    k = int(qint0.min < 0 or qint1.min < 0)
+    n_accum = k + i_high + f
+    n_overlap = k + i_low + f
+    return n_overlap, n_accum
+
+
+@jit
+def idx_wmc(state: DAState):
+    """Choose the pair with the highest weighted most common subexpression (WMC) score."""
+    freqs = list(state.freq_stat.values())
+    keys = list(state.freq_stat.keys())
+    score = np.empty(len(freqs), dtype=np.float32)
+    for i, (k, v) in enumerate(zip(keys, freqs)):
+        id0, id1 = k.id0, k.id1
+        qint0, qint1 = state.ops[id0].qint, state.ops[id1].qint
+        n_overlap, _ = overlap_and_accum(qint0, qint1)
+        score[i] = v * n_overlap
+    max_score = np.max(score)
+    if max_score < 0:
+        return -1
+    return np.argmax(score)
+
+
+@jit
+def idx_wmc_dc(state: DAState, absolute: bool = False):
+    """Choose the pair with the highest weighted most common subexpression (WMC) score with latency and cost penalty.
+    When absolute is True, return -1 if any latency overhead may be incurred."""
+    freqs = list(state.freq_stat.values())
+    keys = list(state.freq_stat.keys())
+    score = np.empty(len(freqs), dtype=np.float32)
+    for i, (k, v) in enumerate(zip(keys, freqs)):
+        id0, id1 = k.id0, k.id1
+        qint0, qint1 = state.ops[id0].qint, state.ops[id1].qint
+        lat0, lat1 = state.ops[id0].latency, state.ops[id1].latency
+        n_overlap, _ = overlap_and_accum(qint0, qint1)
+        score[i] = v * n_overlap - 256 * abs(lat0 - lat1)
+    if absolute and np.max(score) < 0:
+        return -1
+    return np.argmax(score)