da4ml 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml might be problematic. Click here for more details.
- da4ml/__init__.py +16 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +3 -34
- da4ml/cmvm/api.py +239 -73
- da4ml/cmvm/core/__init__.py +222 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +569 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +11 -0
- da4ml/codegen/cpp/__init__.py +3 -0
- da4ml/codegen/cpp/cpp_codegen.py +148 -0
- da4ml/codegen/cpp/source/vitis.h +30 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
- da4ml/codegen/verilog/__init__.py +13 -0
- da4ml/codegen/verilog/comb.py +146 -0
- da4ml/codegen/verilog/io_wrapper.py +255 -0
- da4ml/codegen/verilog/pipeline.py +49 -0
- da4ml/codegen/verilog/source/build_binder.mk +27 -0
- da4ml/codegen/verilog/source/build_prj.tcl +75 -0
- da4ml/codegen/verilog/source/ioutils.hh +117 -0
- da4ml/codegen/verilog/source/shift_adder.v +56 -0
- da4ml/codegen/verilog/source/template.xdc +29 -0
- da4ml/codegen/verilog/verilog_model.py +265 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +358 -0
- da4ml/trace/fixed_variable_array.py +177 -0
- da4ml/trace/ops/__init__.py +55 -0
- da4ml/trace/ops/conv_utils.py +104 -0
- da4ml/trace/ops/einsum_utils.py +299 -0
- da4ml/trace/pipeline.py +155 -0
- da4ml/trace/tracer.py +120 -0
- da4ml-0.2.0.dist-info/METADATA +65 -0
- da4ml-0.2.0.dist-info/RECORD +39 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
- da4ml/cmvm/balanced_reduction.py +0 -46
- da4ml/cmvm/cmvm.py +0 -328
- da4ml/cmvm/codegen.py +0 -159
- da4ml/cmvm/csd.py +0 -73
- da4ml/cmvm/fixed_variable.py +0 -205
- da4ml/cmvm/graph_compile.py +0 -85
- da4ml/cmvm/nb_fixed_precision.py +0 -98
- da4ml/cmvm/scoring.py +0 -55
- da4ml/cmvm/utils.py +0 -5
- da4ml-0.1.2.dist-info/METADATA +0 -122
- da4ml-0.1.2.dist-info/RECORD +0 -18
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from math import ceil, log2
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numba import jit
|
|
5
|
+
|
|
6
|
+
from ..types import DAState, QInterval
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@jit
def idx_mc(state: DAState):
    """Return the position of the most frequent two-term pair in ``state.freq_stat``.

    The position follows the dictionary's iteration order; on ties the first
    occurrence wins.
    """
    counts = list(state.freq_stat.values())
    return counts.index(max(counts))
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@jit
def idx_mc_dc(state: DAState, absolute: bool = False):
    """Return the position of the most frequent pair, penalizing latency mismatch.

    Each pair scores ``count - |latency(id1) - latency(id0)| * (max_count + 1)``.
    The position follows ``state.freq_stat`` iteration order; on ties the first
    occurrence wins. When ``absolute`` is True, -1 is returned if even the best
    score is negative, i.e. every candidate carries a latency penalty larger than
    its frequency.
    """
    counts = list(state.freq_stat.values())
    weight = max(counts) + 1
    scores = []
    for pair, count in state.freq_stat.items():
        lat_gap = abs(state.ops[pair.id1].latency - state.ops[pair.id0].latency)
        scores.append(count - lat_gap * weight)
    best = max(scores)
    if absolute and best < 0:
        return -1
    return scores.index(best)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@jit
def overlap_and_accum(qint0: QInterval, qint1: QInterval):
    """Estimate fixed-point bit counts for combining two QIntervals.

    Returns ``(n_overlap, n_accum)``:
      * ``n_accum``   - sign bit (if either operand can be negative) + integer bits
                        of the larger-magnitude operand + fractional bits of the
                        coarser step.
      * ``n_overlap`` - the same, but using the integer bits of the
                        smaller-magnitude operand.
    """
    lo0, hi0, res0 = qint0
    lo1, hi1, res1 = qint1
    # Use the exclusive upper bound (one step past max) when sizing the integer part.
    hi0 += res0
    hi1 += res1

    mag0 = max(abs(lo0), abs(hi0))
    mag1 = max(abs(lo1), abs(hi1))
    frac_bits = -log2(max(res0, res1))
    signed = int(qint0.min < 0 or qint1.min < 0)

    n_accum = signed + ceil(log2(max(mag0, mag1))) + frac_bits
    n_overlap = signed + ceil(log2(min(mag0, mag1))) + frac_bits
    return n_overlap, n_accum
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@jit
def idx_wmc(state: DAState):
    """Return the position of the pair with the highest weighted-most-common score.

    A pair's score is ``count * n_overlap``, where ``n_overlap`` comes from
    ``overlap_and_accum`` on the two operands' QIntervals. Scores are stored as
    float32. Returns -1 if the best score is negative; otherwise returns the
    argmax position in ``state.freq_stat`` iteration order.
    """
    weighted = np.empty(len(state.freq_stat), dtype=np.float32)
    pos = 0
    for pair, count in state.freq_stat.items():
        overlap, _ = overlap_and_accum(state.ops[pair.id0].qint, state.ops[pair.id1].qint)
        weighted[pos] = count * overlap
        pos += 1
    if np.max(weighted) < 0:
        return -1
    return np.argmax(weighted)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@jit
def idx_wmc_dc(state: DAState, absolute: bool = False):
    """Return the position of the best weighted-most-common pair with a latency penalty.

    A pair's score is ``count * n_overlap - 256 * |latency(id0) - latency(id1)|``,
    stored as float32. When ``absolute`` is True and the best score is negative,
    -1 is returned. Otherwise the argmax position is returned, in
    ``state.freq_stat`` iteration order.
    """
    weighted = np.empty(len(state.freq_stat), dtype=np.float32)
    pos = 0
    for pair, count in state.freq_stat.items():
        op0 = state.ops[pair.id0]
        op1 = state.ops[pair.id1]
        overlap, _ = overlap_and_accum(op0.qint, op1.qint)
        weighted[pos] = count * overlap - 256 * abs(op0.latency - op1.latency)
        pos += 1
    if absolute and np.max(weighted) < 0:
        return -1
    return np.argmax(weighted)
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
from math import ceil, log2
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numba import jit
|
|
5
|
+
|
|
6
|
+
from ..types import DAState, Op, Pair, QInterval
|
|
7
|
+
from ..util import csd_decompose
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@jit
def qint_add(qint0: QInterval, qint1: QInterval, shift: int, sub0=False, sub1=False) -> QInterval:
    """Return the QInterval of ``(+/-)qint0 + (+/-)qint1 * 2**shift``.

    ``sub0`` / ``sub1`` negate the corresponding operand's interval before the sum.
    The result's step is the finer of the two (shifted) steps.
    """
    lo0, hi0, res0 = qint0
    lo1, hi1, res1 = qint1
    if sub0:
        lo0, hi0 = -hi0, -lo0
    if sub1:
        lo1, hi1 = -hi1, -lo1

    scale = 2.0**shift
    return QInterval(lo0 + lo1 * scale, hi0 + hi1 * scale, min(res0, res1 * scale))
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@jit
def cost_add(
    qint0: QInterval, qint1: QInterval, shift: int, sub: bool = False, adder_size: int = -1, carry_size: int = -1
) -> tuple[float, float]:
    """Calculate the latency and cost of ``qint0 +/- qint1 * 2**shift``.

    Parameters
    ----------
    qint0 : QInterval
        The first operand.
    qint1 : QInterval
        The second operand, before shifting.
    shift : int
        The second operand is scaled by ``2**shift`` before the addition.
    sub : bool
        If True, the operation is a subtraction (a - b) instead of an addition (a + b).
    adder_size : int
        The atomic size of the adder. A negative value is treated as 65535
        (effectively unbounded), unless carry_size is also negative.
    carry_size : int
        The size of the look-ahead carry. A negative value is treated as 65535
        (effectively unbounded), unless adder_size is also negative.

    Returns
    -------
    tuple[float, float]
        The latency and cost of the addition operation. Both are ``(1.0, 1.0)``
        when adder_size and carry_size are both negative. Otherwise they are
        ``ceil(n_bits / carry_size)`` and ``ceil(n_bits / adder_size)``, where
        ``n_bits`` is the estimated adder width.
    """
    if adder_size < 0 and carry_size < 0:
        return 1.0, 1.0
    if adder_size < 0:
        adder_size = 65535
    if carry_size < 0:
        carry_size = 65535

    min0, max0, step0 = qint0
    min1, max1, step1 = qint1
    if sub:
        # The interval of -b is [-max1, -min1] (same as qint_add with sub1=True).
        # Merely swapping the endpoints would put the exclusive upper bound below
        # on the wrong side. It would also miss that -b may be negative, which
        # needs a sign bit.
        min1, max1 = -max1, -min1
    sf = 2.0**shift
    min1, max1, step1 = min1 * sf, max1 * sf, step1 * sf
    # Exclusive upper bounds (one step past max), used when sizing the integer part.
    max0, max1 = max0 + step0, max1 + step1

    # Bits below the coarser step are passed through from the finer operand, so
    # only bits from the coarser step upward need an adder.
    f = -log2(max(step0, step1))
    i = ceil(log2(max(abs(min0), abs(min1), abs(max0), abs(max1))))
    # Sign bit when either (possibly negated) operand can be negative.
    k = int(min0 < 0 or min1 < 0)
    n_accum = k + i + f
    # Align to the number of carry and adder bits, when they are block-based (e.g., 4/8 bits look-ahead carry in Xilinx FPGAs)
    # For Altera, the carry seems to be single bit adder chains, but need to check
    return float(ceil(n_accum / carry_size)), float(ceil(n_accum / adder_size))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@jit
def create_state(
    kernel: np.ndarray,
    qintervals: list[QInterval],
    inp_latencies: list[float],
    no_stat_init: bool = False,
):
    """Build the initial DAState for a constant matrix ``kernel``.

    The kernel is decomposed with ``csd_decompose``. Digit rows of inputs whose
    interval is exactly [0, 0] are zeroed. Unless ``no_stat_init`` is set, every
    two-term digit pair is counted, and only pairs seen at least twice are kept.
    One input Op is created per kernel row.

    Parameters
    ----------
    kernel : np.ndarray
        2-D matrix of shape (n_in, n_out); it is cast to float64.
    qintervals : list[QInterval]
        One interval per input row of ``kernel``.
    inp_latencies : list[float]
        One latency per input row of ``kernel``.
    no_stat_init : bool
        If True, skip pair counting and leave ``freq_stat`` empty.

    Returns
    -------
    DAState
        ``expr`` is a list of per-input (n_out, n_bits) digit slices, and
        ``freq_stat`` maps Pair -> count (count >= 2).

    Raises
    ------
    AssertionError
        If the input lengths do not match ``kernel.shape[0]`` or ``kernel`` is not 2-D.
    """
    assert len(qintervals) == kernel.shape[0]
    assert len(inp_latencies) == kernel.shape[0]
    assert kernel.ndim == 2

    kernel = kernel.astype(np.float64)
    n_in, n_out = kernel.shape
    # NOTE(review): redundant after astype above; kept as-is.
    kernel = np.asarray(kernel)
    # csd is indexed below as csd[input, output, bit], i.e. (n_in, n_out, n_bits).
    # Presumably its digits are in {-1, 0, +1} (canonical signed digit) - confirm in csd_decompose.
    csd, shift0, shift1 = csd_decompose(kernel)
    for i, qint in enumerate(qintervals):
        # An input that is constantly zero contributes nothing; drop its digits.
        if qint.min == qint.max == 0:
            csd[i] = 0
    n_bits = csd.shape[-1]
    expr = list(csd)
    shifts = (shift0, shift1)

    # Dirty numba typing trick
    # (seed the dict with a dummy entry so its Pair -> int types are fixed, then remove it)
    stat = {Pair(-1, -1, False, 0): 0}
    del stat[Pair(-1, -1, False, 0)]

    # Loop over outputs, in0, in1, shift0, shift1 to gather all two-term pairs
    # Force i1>=i0
    if not no_stat_init:
        # Initialize the stat dictionary
        # Skip if no_stat_init is True (skip optimization)
        for i_out in range(n_out):
            for i0 in range(n_in):
                for j0 in range(n_bits):
                    bit0 = csd[i0, i_out, j0]
                    if not bit0:
                        continue
                    for i1 in range(i0, n_in):
                        for j1 in range(n_bits):
                            bit1 = csd[i1, i_out, j1]
                            if not bit1:
                                continue
                            # Avoid count the same bit
                            # (for a self-pair only j1 < j0 is counted, so its shift is always negative)
                            if i0 == i1 and j0 <= j1:
                                continue
                            # Digits of differing value mark a subtraction; shift is the bit distance.
                            pair = Pair(i0, i1, bit0 != bit1, j1 - j0)
                            stat[pair] = stat.get(pair, 0) + 1

    # Only pairs that occur at least twice are worth extracting.
    for k in list(stat.keys()):
        if stat[k] < 2.0:
            del stat[k]

    # Input Ops: Op(i, -1, -1, 0, qint, latency, 0.0) - no operands, zero cost.
    ops = [Op(i, -1, -1, 0, qintervals[i], inp_latencies[i], 0.0) for i in range(n_in)]

    return DAState(
        shifts=shifts,
        expr=expr,
        ops=ops,
        freq_stat=stat,
        kernel=kernel,
    )
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@jit
def update_stats(
    state: DAState,
    pair: Pair,
):
    """Updates the statistics of any 2-term pair in the state that may be affected by implementing op.

    This is meant to run after the expr row for ``pair`` has been appended: the
    last row of ``state.expr`` is treated as the newly built op. The update works
    in three steps:

    1. Drop every freq_stat entry that involves ``pair.id0`` or ``pair.id1``.
    2. Recount pairs between each modified row (the new row, id0, id1) and every row.
    3. Prune entries whose count is below 2.

    ``state.freq_stat`` is mutated in place, and the same ``state`` is returned.
    """
    id0, id1 = pair.id0, pair.id1

    # Entries touching id0 or id1 are stale, because those rows just lost digits.
    ks = list(state.freq_stat.keys())
    for k in ks:
        if k.id0 == id0 or k.id1 == id1 or k.id1 == id0 or k.id0 == id1:
            del state.freq_stat[k]

    n_constructed = len(state.expr)
    # Rows whose digit content changed: the new op row, plus both operands.
    modified = [n_constructed - 1]
    modified.append(id0)
    if id1 != id0:
        modified.append(id1)

    n_bits = state.expr[0].shape[-1]

    # Loop over outputs, in0, in1, shift0, shift1 to gather all two-term pairs
    for i_out in range(state.kernel.shape[1]):
        for _in0 in modified:
            for _in1 in range(n_constructed):
                if _in1 in modified and _in0 > _in1:
                    # Avoid double counting of the two locations when _i0 != _i1
                    continue
                # Order inputs, as _in0 can be either in0 or in1, range of _in is not restricted
                # (id0/id1 are rebound here; `modified` was already built from the pair's ids)
                id0, id1 = (_in0, _in1) if _in0 <= _in1 else (_in1, _in0)
                for j0 in range(n_bits):
                    bit0 = state.expr[id0][i_out, j0]
                    if not bit0:
                        continue
                    for j1 in range(n_bits):
                        bit1 = state.expr[id1][i_out, j1]
                        if not bit1:
                            continue
                        # Same rule as create_state: a self-pair counts only j1 < j0.
                        if id0 == id1 and j0 <= j1:
                            continue
                        pair = Pair(id0, id1, bit0 != bit1, j1 - j0)
                        state.freq_stat[pair] = state.freq_stat.get(pair, 0) + 1

    # Keep only pairs that occur at least twice.
    ks, vs = list(state.freq_stat.keys()), list(state.freq_stat.values())
    for k, v in zip(ks, vs):
        if v < 2.0:
            del state.freq_stat[k]
    return state
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@jit
def gather_matching_idxs(state: DAState, pair: Pair):
    """Yield every ``(i_out, j0, j1)`` where ``pair`` occurs in the digit expression.

    ``j0`` indexes ``state.expr[pair.id0][i_out, :]`` and ``j1`` indexes
    ``state.expr[pair.id1][i_out, :]``, with ``j1 - j0 == pair.shift``. A match
    requires ``sign * bit1 * bit0 == 1``, where ``sign`` is -1 for subtraction
    pairs; for +/-1 digits this means ``bit1 == sign * bit0``.

    This is a lazy generator that reads ``state.expr`` as it advances. update_expr
    zeroes the matched digits while consuming it. For a self-pair
    (``id0 == id1``), that zeroing prevents a digit that was already consumed from
    starting another match.
    """
    id0, id1 = pair.id0, pair.id1
    shift = pair.shift
    sub = pair.sub
    n_out = state.kernel.shape[1]
    n_bits = state.expr[0].shape[-1]

    # Normalize to a non-negative shift by swapping operand roles; the yield
    # below swaps back, so j0 always refers to pair.id0's row.
    flip = False
    if shift < 0:
        id0, id1 = id1, id0
        shift = -shift
        flip = True

    sign = 1 if not sub else -1

    for j0 in range(n_bits - shift):
        for i_out in range(n_out):
            bit0 = state.expr[id0][i_out, j0]
            j1 = j0 + shift
            bit1 = state.expr[id1][i_out, j1]
            if sign * bit1 * bit0 != 1:
                continue

            if flip:
                yield i_out, j1, j0
            else:
                yield i_out, j0, j1
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@jit
def pair_to_op(pair: Pair, state: DAState, adder_size: int = -1, carry_size: int = -1):
    """Build the Op that implements ``pair`` on top of ``state``.

    The new Op's latency is the later operand latency plus the adder latency from
    ``cost_add``. Its interval is ``qint_add`` of the operands, with the second
    operand shifted by ``pair.shift`` and negated when ``pair.sub`` is set.
    """
    op0 = state.ops[pair.id0]
    op1 = state.ops[pair.id1]
    delta_latency, cost = cost_add(
        op0.qint,
        op1.qint,
        pair.shift,
        pair.sub,
        adder_size=adder_size,
        carry_size=carry_size,
    )
    combined = qint_add(op0.qint, op1.qint, shift=pair.shift, sub1=pair.sub)
    latency = max(op0.latency, op1.latency) + delta_latency
    return Op(pair.id0, pair.id1, int(pair.sub), pair.shift, combined, latency, cost)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@jit
def update_expr(
    state: DAState,
    pair: Pair,
    adder_size: int,
    carry_size: int,
):
    """Update the state by implementing ``pair`` as a new op; pair frequencies are not updated.

    Appends the Op built by ``pair_to_op``. Then, for each match yielded by
    ``gather_matching_idxs``, it moves the ``pair.id0`` digit into a new int8
    (n_out, n_bits) row and clears both matched digits. Finally the new row is
    appended to ``expr``.

    NOTE(review): ``state.expr.copy()`` is a shallow list copy, so clearing digits
    also mutates the arrays still referenced by the input ``state``. This in-place
    mutation is relied upon: ``gather_matching_idxs`` reads ``state.expr`` lazily,
    and the zeroing stops a digit from being matched twice. ``freq_stat`` is passed
    through shared, not copied.
    """
    id0, id1 = pair.id0, pair.id1
    op = pair_to_op(pair, state, adder_size=adder_size, carry_size=carry_size)
    n_out = state.kernel.shape[1]
    n_bits = state.expr[0].shape[-1]

    expr = state.expr.copy()
    ops = state.ops.copy()

    ops.append(op)

    new_slice = np.zeros((n_out, n_bits), dtype=np.int8)

    for i_out, j0, j1 in gather_matching_idxs(state, pair):
        # The new op carries the id0 digit's sign at position j0.
        new_slice[i_out, j0] = expr[id0][i_out, j0]
        expr[id0][i_out, j0] = 0
        expr[id1][i_out, j1] = 0

    expr.append(new_slice)

    return DAState(
        shifts=state.shifts,
        expr=expr,
        ops=ops,
        freq_stat=state.freq_stat,
        kernel=state.kernel,
    )
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@jit
def update_state(
    state: DAState,
    pair_chosen: Pair,
    adder_size: int,
    carry_size: int,
):
    """Apply one extraction step for ``pair_chosen``.

    First rewrite the digit expression and register the new op (update_expr).
    Then refresh the pair-frequency statistics for the affected rows (update_stats).
    """
    expanded = update_expr(state, pair_chosen, adder_size=adder_size, carry_size=carry_size)
    return update_stats(expanded, pair_chosen)
|