da4ml 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml has been flagged as potentially problematic; see the registry's advisory page for details.
- da4ml/__init__.py +16 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +3 -34
- da4ml/cmvm/api.py +239 -73
- da4ml/cmvm/core/__init__.py +222 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +569 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +11 -0
- da4ml/codegen/cpp/__init__.py +3 -0
- da4ml/codegen/cpp/cpp_codegen.py +148 -0
- da4ml/codegen/cpp/source/vitis.h +30 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
- da4ml/codegen/verilog/__init__.py +13 -0
- da4ml/codegen/verilog/comb.py +146 -0
- da4ml/codegen/verilog/io_wrapper.py +255 -0
- da4ml/codegen/verilog/pipeline.py +49 -0
- da4ml/codegen/verilog/source/build_binder.mk +27 -0
- da4ml/codegen/verilog/source/build_prj.tcl +75 -0
- da4ml/codegen/verilog/source/ioutils.hh +117 -0
- da4ml/codegen/verilog/source/shift_adder.v +56 -0
- da4ml/codegen/verilog/source/template.xdc +29 -0
- da4ml/codegen/verilog/verilog_model.py +265 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +358 -0
- da4ml/trace/fixed_variable_array.py +177 -0
- da4ml/trace/ops/__init__.py +55 -0
- da4ml/trace/ops/conv_utils.py +104 -0
- da4ml/trace/ops/einsum_utils.py +299 -0
- da4ml/trace/pipeline.py +155 -0
- da4ml/trace/tracer.py +120 -0
- da4ml-0.2.0.dist-info/METADATA +65 -0
- da4ml-0.2.0.dist-info/RECORD +39 -0
- {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
- da4ml/cmvm/balanced_reduction.py +0 -46
- da4ml/cmvm/cmvm.py +0 -328
- da4ml/cmvm/codegen.py +0 -159
- da4ml/cmvm/csd.py +0 -73
- da4ml/cmvm/fixed_variable.py +0 -205
- da4ml/cmvm/graph_compile.py +0 -85
- da4ml/cmvm/nb_fixed_precision.py +0 -98
- da4ml/cmvm/scoring.py +0 -55
- da4ml/cmvm/utils.py +0 -5
- da4ml-0.1.1.dist-info/METADATA +0 -121
- da4ml-0.1.1.dist-info/RECORD +0 -18
- {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info/licenses}/LICENSE +0 -0
- {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/cmvm/cmvm.py
DELETED
|
@@ -1,328 +0,0 @@
|
|
|
1
|
-
import heapq
|
|
2
|
-
from math import ceil
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
from numba import njit, prange
|
|
6
|
-
from numpy.typing import NDArray
|
|
7
|
-
|
|
8
|
-
from .csd import to_csd
|
|
9
|
-
from .nb_fixed_precision import NBFixedPrecision
|
|
10
|
-
from .scoring import scorer
|
|
11
|
-
from .utils import DAState, OpCode, Score
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
@njit
def extract_pairs(
    csd: list[NDArray[np.int8]], precisions: list[NBFixedPrecision], updated: list[int] | None = None, dc: int | None = None
):
    """Collect candidate digit pairs shared by two CSD rows.

    For every ordered row pair (pos0, pos1), every relative shift ``dshift`` and every
    sign relation ``dsign`` (0 = equal signs, 1 = opposite signs), this counts how often
    nonzero digits ``csd[pos0][n, s]`` and ``csd[pos1][n, s + dshift]`` co-occur over
    all outputs ``n``.

    Args:
        csd: per-row digit arrays, each of shape (d_out, n_bit).
        precisions: one NBFixedPrecision per row of ``csd``.
        updated: if given, only output slots of these row positions are rescanned;
            otherwise every slot with a nonzero digit and a nonzero-width precision is.
        dc: if given, restrict the scan to rows whose ``_depth`` is at most
            ``min_depth + dc`` (only applied when at least two rows qualify).

    Returns:
        A list of ``(-score * (count - 1), count, pos0, pos1, dshift, dsign)`` tuples for
        combinations occurring more than once. ``score`` comes from ``scorer``. The list
        is heapified only when ``updated`` is None.

    Raises:
        ValueError: if ``csd`` is empty.
    """
    d_in = len(csd)
    if d_in == 0:
        raise ValueError('csd must have at least one element')
    d_out, n_bit = csd[0].shape
    _stat = np.zeros((d_in, d_in, n_bit, 2), dtype=np.int64)
    process_locs = np.zeros((d_in, d_out), dtype=np.bool_)

    if updated is not None:
        for i in range(len(updated)):
            pos = updated[i]
            for n in range(d_out):
                process_locs[pos, n] = True
    else:
        for pos in range(d_in):
            for n in range(d_out):
                if np.any(csd[pos][n]) and precisions[pos].b != 0:
                    process_locs[pos, n] = True

    if dc is not None:
        depths = np.zeros(d_in, dtype=np.int64)
        for pos in range(d_in):
            depths[pos] = precisions[pos]._depth
        depth_min = np.min(depths)
        mask = depths <= depth_min + dc
        if np.count_nonzero(mask) >= 2:
            for n in range(d_out):
                process_locs[:, n] &= mask

    args: list[tuple[int, int, int]] = []
    for pos0 in range(d_in):
        for pos1 in range(d_in):
            for n in range(d_out):
                if process_locs[pos0, n] or process_locs[pos1, n]:
                    args.append((pos0, pos1, n))

    # A plain range on purpose: entries that differ only in ``n`` accumulate into the
    # same _stat[pos0, pos1] cells, so a parallel prange would race on those writes.
    for idx in range(len(args)):
        pos0, pos1, n = args[idx]
        for shift0 in range(n_bit):
            if csd[pos0][n, shift0] == 0:
                continue
            # For pos0 == pos1 (or pos0 > pos1) skip dshift == 0, so a digit is never
            # paired with itself and each unordered same-shift pair is counted once.
            lower = shift0 if pos0 < pos1 else shift0 + 1
            for shift1 in range(lower, n_bit):
                if csd[pos1][n, shift1] == 0:
                    continue
                dsign = int(csd[pos0][n, shift0] != csd[pos1][n, shift1])
                _stat[pos0, pos1, shift1 - shift0, dsign] += 1

    stat: list[tuple[float, int, int, int, int, int]] = []
    for pos0 in range(d_in):
        for pos1 in range(d_in):
            for dshift in range(n_bit):
                for dsign in range(2):
                    if _stat[pos0, pos1, dshift, dsign] > 1:
                        count = int(_stat[pos0, pos1, dshift, dsign])
                        score = scorer(precisions[pos0], precisions[pos1], dshift, dsign)
                        # count > 1 is guaranteed here; sharing one adder saves count - 1.
                        stat.append((-score * (count - 1), count, pos0, pos1, dshift, dsign))
    if updated is None:
        heapq.heapify(stat)
    return stat
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
@njit
def init_var(
    k: bool,
    b: int,
    i: int,
    symmetric: bool = False,
    _depth: int = 0,
) -> NBFixedPrecision:
    """Build the NBFixedPrecision describing one input variable.

    The integer range is ``[0, 2**b - 1]`` when ``k`` is false. When ``k`` is true it is
    ``[-(2**b) + symmetric, 2**b - 1]``, so a symmetric signed range drops the most
    negative value. The shift passed on is ``b - i``.
    """
    upper = 2**b - 1
    if k:
        lower = symmetric - upper - 1
    else:
        lower = 0
    return NBFixedPrecision(lower, upper, b - i, symmetric, _depth)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
@njit
def init_vars(
    ks: list[bool],
    bs: list[int],
    is_: list[int],
    symmetrics: list[bool],
    depths: list[int],
):
    """Create one NBFixedPrecision per input by calling ``init_var`` element-wise.

    All argument lists are indexed in parallel; the count is taken from ``ks``.
    """
    result = []
    idx = 0
    while idx < len(ks):
        result.append(init_var(ks[idx], bs[idx], is_[idx], symmetrics[idx], depths[idx]))
        idx += 1
    return result
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
@njit
def init_state(
    kernel: np.ndarray,
    signs: list[bool],
    bits: list[int],
    int_bits: list[int],
    symmetrics: list[bool],
    depths: list[int],
):
    """Build the starting DAState for a 2-D ``kernel`` (one row per input).

    Steps:
    - The kernel is CSD-decomposed with ``to_csd``.
    - Every input precision is shifted by the common lowest digit exponent ``shifts[0]``.
    - One input OpCode is recorded per row.
    - The candidate pair heap is seeded with ``extract_pairs``.

    The initial score's ``potential`` is the sum of the negated first entries of all
    pairs; the other score fields start at 0.
    """
    assert kernel.ndim == 2
    assert len(signs) == len(bits) == len(int_bits) == len(symmetrics) == len(depths) == kernel.shape[0]
    csd, shifts = to_csd(kernel)
    base_shift = shifts[0]

    vars_ = init_vars(signs, bits, int_bits, symmetrics, depths)
    op_codes = []
    for row in range(len(csd)):
        # NOTE(review): -10 appears to mark "input" op codes; confirm against OpCode users.
        op_codes.append(OpCode(row, -10, base_shift, 0, 0, 0))
        vars_[row] = vars_[row] << base_shift
    pairs = extract_pairs(csd, vars_)

    potential = 0.0
    for pair in pairs:
        potential -= pair[0]

    return DAState(csd, vars_, op_codes, pairs, Score(potential, 0.0, 0.0, 0.0), kernel)
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
@njit
def update_state(state: DAState, pair: tuple[float, int, int, int, int, int], dc=None):
    """Apply one candidate pair to ``state`` and return the resulting new DAState.

    The pair ``(neg_cum_score, count, pos0, pos1, dshift, dsign)`` selects digits
    ``csd[pos0][n, s]`` and ``csd[pos1][n, s + dshift]`` whose product matches the pair's
    sign relation. The pair is applied as follows:
    - Those digits are moved into a new CSD row for the combined variable.
    - The variable ``v0 +/- (v1 << dshift)`` is appended.
    - An OpCode is appended.
    - The candidate pairs touching ``pos0``/``pos1`` are recomputed.

    ``state`` itself is left untouched. Beam search derives several successors from
    the same parent, so csd rows, variables, op codes and the pair heap are all copied.

    Args:
        state: parent state.
        pair: candidate as produced by ``extract_pairs``.
        dc: depth constraint forwarded to ``extract_pairs``.
    """
    neg_cum_score, _count, pos0, pos1, dshift, dsign = pair
    variables = state.variables.copy()
    op_codes = state.op_codes.copy()
    csd = []
    for col in state.csd:
        csd.append(col.copy())

    realized = state.score.realized - neg_cum_score

    d_out, n_bit = csd[0].shape
    new_csd_col = np.zeros((d_out, n_bit), dtype=np.int8)
    # ``dsign`` is a 0/1 flag (1 = opposite signs). ``sign`` is the digit product (+1/-1)
    # that identifies matching digit pairs. Keep both instead of overwriting the flag.
    sign = -1 if dsign else 1
    for n in range(d_out):
        for shift0 in range(0, n_bit - dshift):
            if csd[pos0][n, shift0] * csd[pos1][n, shift0 + dshift] == sign:
                new_csd_col[n, shift0] = csd[pos0][n, shift0]
                csd[pos0][n, shift0] = 0
                csd[pos1][n, shift0 + dshift] = 0

    csd.append(new_csd_col)
    v0, v1 = variables[pos0], variables[pos1] << dshift
    # Equal-sign digits combine as a sum, opposite-sign digits as a difference.
    v = v0 - v1 if dsign else v0 + v1
    variables.append(v)
    op_codes.append(OpCode(pos0, pos1, 0, dshift, 1, sign))
    updated = [pos0, pos1, len(variables) - 1]

    d_pairs = extract_pairs(csd, variables, updated, dc)

    # Rebuild the heap: drop stale pairs touching the consumed rows, add the fresh ones,
    # then restore the heap invariant. (Popping arbitrary indices from a heap list
    # breaks the invariant, and pairs[0] must stay the best candidate.)
    pairs = []
    for _pair in state.pairs:
        if _pair[2] == pos0 or _pair[2] == pos1 or _pair[3] == pos0 or _pair[3] == pos1:
            continue
        pairs.append(_pair)
    for _pair in d_pairs:
        pairs.append(_pair)
    heapq.heapify(pairs)

    cur_potential = 0.0
    for _pair in pairs:
        cur_potential -= _pair[0]

    lost = state.score.potential - cur_potential
    value = realized - lost
    score = Score(state.score.potential, realized, lost, value)
    return DAState(csd, variables, op_codes, pairs, score, state.kernel)
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
@njit(cache=True)
def get_top_n_pairs(state: DAState, n: int):
    """Return up to ``n`` best candidate pairs of ``state``, best first.

    Pair tuples store the negated gain first, so the best candidates are the smallest
    tuples. ``heapq.nsmallest`` is used because a heap list is only partially ordered:
    slicing ``pairs[:n]`` is not the top ``n`` for ``n > 1``. ``nsmallest`` also does not
    depend on the heap invariant holding.
    """
    k = max(0, min(n, len(state.pairs)))
    return heapq.nsmallest(k, state.pairs)
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
@njit
def cmvm_cse(state: DAState, progress=None, beams: int = 1, dc=None):
    """Beam-search common-subexpression elimination starting from ``state``.

    Each iteration expands every live state with its best candidate pairs (via
    ``get_top_n_pairs`` / ``update_state``). The expansions, plus the option of keeping a
    state as is when it has no pairs, are ranked by ``score.realized`` plus the pair's
    gain. At most ``beams`` of them survive. The loop stops when no live state has a first
    pair entry below zero. The surviving state with the highest ``score.realized`` is
    returned; the first one wins on ties.

    Args:
        state: starting state; must have at least one candidate pair (asserted).
        progress: not used in this body.
        beams: beam width.
        dc: forwarded to ``update_state``.
    """
    assert len(state.pairs) > 0, f'{len(state.pairs)}'
    top_pairs = get_top_n_pairs(state, beams)
    states_0 = [update_state(state, top_pairs[i], dc) for i in range(len(top_pairs))]

    # next_score[i, j]: score of expanding state i with its j-th best pair. Column 0
    # holds the plain realized score when state i has no pairs. -1 marks unused slots.
    next_score = np.full((beams, beams), -1, dtype=np.float64)
    use_n = np.empty(beams, dtype=np.int64)
    while True:
        use_n[:] = 0
        next_score[:] = -1
        # Continue only while some state has a pair whose first entry is negative
        # (pair tuples store the negated gain). NOTE(review): treating pairs[0] as the
        # best pair assumes the pair list is a valid heap.
        for i in range(len(states_0)):
            if len(states_0[i].pairs) > 0 and states_0[i].pairs[0][0] < 0:
                break
        else:
            break

        for i in range(len(states_0)):
            _state = states_0[i]
            next_score[i, 0] = _state.score.realized
            next_score[i, 1:] = -1
            top_pairs = get_top_n_pairs(_state, beams)
            for j in range(len(top_pairs)):
                next_score[i, j] = -top_pairs[j][0] + _state.score.realized

        # Take the `beams` highest slots; use_n[i] counts how many expansions of state i
        # survive. next_score is square, so dividing by shape[0] recovers (row, column).
        order = np.argsort(next_score.ravel())[::-1]
        for i in range(beams):
            i_st, i_pair = np.divmod(order[i], next_score.shape[0])
            if next_score[i_st, i_pair] > 0:
                use_n[i_st] += 1

        states_1 = []
        for i in range(beams):
            if use_n[i] == 0:
                continue

            _state = states_0[i]
            # States with no pairs, or whose first pair entry is positive, are carried
            # forward unchanged. NOTE(review): an entry of exactly 0 is not treated as
            # "nothing to gain" here, unlike the termination check above.
            if len(_state.pairs) == 0 or _state.pairs[0][0] > 0:
                states_1.append(_state)
                continue

            top_pairs = get_top_n_pairs(_state, use_n[i])
            for j in range(use_n[i]):
                states_1.append(update_state(_state, top_pairs[j], dc))

        # NOTE(review): if no slot scored > 0, states_1 is empty and the final
        # states_0[0] lookup below would raise IndexError.
        states_0 = states_1

    _max = states_0[0].score.realized
    _idx = 0
    for i in range(len(states_0)):
        _state = states_0[i]
        if _state.score.realized > _max:
            _max = _state.score.realized
            _idx = i
    return states_0[_idx]
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
@njit
def compile_kernel_mono(
    kernel: np.ndarray,
    signs: list[bool],
    bits: list[int],
    int_bits: list[int],
    symmetrics: list[bool],
    depths: list[int],
    n_beams: int = 1,
    dc: int | None = None,
):
    """Build the initial state for one kernel block and run CSE on it.

    ``n_beams`` is the beam width passed to ``cmvm_cse`` as ``beams``. ``dc`` is
    forwarded unchanged.
    """
    initial = init_state(kernel, signs, bits, int_bits, symmetrics, depths)
    return cmvm_cse(initial, beams=n_beams, dc=dc)
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
# Plain Python (not jitted): slices and normalizes the inputs, then calls jitted kernels.
def compile_kernel(
    kernel: np.ndarray,
    signs: list[bool],
    bits: list[int],
    int_bits: list[int],
    symmetrics: list[bool],
    depths: list[int],
    n_beams: int = 1,
    dc: int | None = None,
    n_inp_max: int = -1,
    n_out_max: int = -1,
) -> list[list[DAState]]:
    """Split ``kernel`` into input/output tiles and optimize each tile independently.

    Tiling:
    - Inputs (rows) are split into ``ceil(d_inp / n_inp_max)`` near-equal parts when
      ``0 < n_inp_max < d_inp``; otherwise they stay in one part.
    - Outputs (columns) are split the same way using ``n_out_max``.

    Each tile goes through ``compile_kernel_mono``. If that raises AssertionError (e.g.
    ``cmvm_cse``'s check for at least one candidate pair), the tile's plain
    ``init_state`` is stored instead.

    Returns:
        ``states[i][j]``: state for input part ``i`` and output part ``j``.
    """
    d_inp, d_out = kernel.shape

    n_inp_part = ceil(d_inp / n_inp_max) if 0 < n_inp_max < d_inp else 1
    n_out_part = ceil(d_out / n_out_max) if 0 < n_out_max < d_out else 1

    inp_edges = np.arange(0, n_inp_part + 1) * ceil(d_inp / n_inp_part)
    out_edges = np.arange(0, n_out_part + 1) * ceil(d_out / n_out_part)

    states = [[None] * n_out_part for _ in range(n_inp_part)]
    for j in range(n_out_part):
        out_lo, out_hi = out_edges[j], min(out_edges[j + 1], d_out)
        for i in range(n_inp_part):
            inp_lo, inp_hi = inp_edges[i], min(inp_edges[i + 1], d_inp)
            rows = slice(inp_lo, inp_hi)
            # unify input type to prevent recompilation
            args = (
                np.ascontiguousarray(kernel[rows, out_lo:out_hi]),
                [bool(v) for v in signs[rows]],
                [int(v) for v in bits[rows]],
                [int(v) for v in int_bits[rows]],
                [bool(v) for v in symmetrics[rows]],
                [int(v) for v in depths[rows]],
            )
            try:
                states[i][j] = compile_kernel_mono(*args, n_beams, dc)
            except AssertionError:
                states[i][j] = init_state(*args)

    return states  # type: ignore
|
da4ml/cmvm/codegen.py
DELETED
|
@@ -1,159 +0,0 @@
|
|
|
1
|
-
import types
|
|
2
|
-
from math import log2
|
|
3
|
-
|
|
4
|
-
from .fixed_variable import FixedVariable, Namer
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class PyCodegenBackend:
    """Generate a Python function that evaluates a graph of FixedVariable additions.

    Constants are inlined. Every other non-input variable is defined as a sum or
    difference of its two ``_from`` operands, in order of increasing ``_depth``. Outputs
    are written to ``out[i]``. Subclasses override the hooks to target other languages.
    """

    _comment = '#'

    def __init__(self, namer=Namer(), fn_name: str = 'placeholder', **kwargs):
        # NOTE(review): the ``Namer()`` default is evaluated once, when the class is
        # defined, so every backend built without an explicit ``namer`` shares it.
        self._namer = namer
        self._attrs = {'fn_name': fn_name, **kwargs}

    def reference_code(self, v: FixedVariable):
        """How the variable should be referenced in the code.

        Constants become a literal. Otherwise the result is the variable name, negated
        and scaled when ``v._factor`` is a non-unit signed power of two. Non-power-of-two
        factors are rejected by an assert.
        """
        if v.int_min == v.int_max:
            return f'{v.min}'

        neg = v._factor < 0
        shift = log2(abs(v._factor))
        assert shift % 1 == 0
        shift = int(shift)
        s_sign = '-' if neg else ''
        s_shift = f' * {2.**shift}' if shift != 0 else ''
        return f'{s_sign}{v.name}{s_shift}'

    def def_code(self, v: FixedVariable):
        """How the variable should be defined in the code (``name = a + b`` / ``a - b``).

        Raises:
            ValueError: if ``v`` is a constant.
        """
        if v.int_min == v.int_max:
            raise ValueError('Constant variable should not be defined')
        assert v._from is not None, 'Variable not derived from other variables cannot be defined in runtime'
        v1_str = self.reference_code(v._from[0])
        v2_str = self.reference_code(v._from[1])
        # Fold a negated second operand into a subtraction.
        if v2_str[0] == '-':
            return f'{v.name} = {v1_str} - {v2_str[1:]}'
        return f'{v.name} = {v1_str} + {v2_str}'

    def _resolve_variable(self, v: FixedVariable, _recorded: dict[str, FixedVariable]):
        """Record ``v`` and, recursively, every variable it is derived from, keyed by name."""
        if v.name in _recorded:
            return

        if v.int_min == v.int_max:
            _recorded[v.name] = v
            return

        if v._from is None:
            raise ValueError('Variable not derived from other variables cannot be defined in runtime')

        self._resolve_variable(v._from[0], _recorded)
        self._resolve_variable(v._from[1], _recorded)
        _recorded[v.name] = v

    def resolve_all_variables(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        """Return ``{name: variable}`` for the inputs plus everything the outputs depend on."""
        _recorded = {v.name: v for v in inputs}
        for v in outputs:
            self._resolve_variable(v, _recorded)
        return _recorded

    def gen_lines(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        """Return the body lines, grouped by ``_depth``, ending with one line per output."""
        variables = self.resolve_all_variables(inputs, outputs)
        keys = list(variables.keys())
        keys = sorted(keys, key=lambda x: variables[x]._depth)
        codes = []
        cur_depth = -1
        s_inputs = set(inputs)
        for key in keys:
            v = variables[key]
            if v.int_min == v.int_max or v in s_inputs:
                continue
            if cur_depth != v._depth:
                cur_depth = v._depth
                codes.append(f'{self._comment} ========================== Latency: {cur_depth} ==========================')
            codes.append(self.def_code(v))
        for i, out in enumerate(outputs):
            codes.append(f'out[{i}] = {self.reference_code(out)}')
        return codes

    def gen_fn(self, inputs: list[FixedVariable], outputs: list[FixedVariable], **kwargs):
        """Build the Python function; return ``(callable, source_string)``."""
        fn_name = kwargs.get('fn_name', self._attrs['fn_name'])
        code = self.gen_lines(inputs, outputs)
        code_str = '\n    '.join(code)
        fn_str = f"""def {fn_name}(inp: list[float]):
    out = [0.]*{len(outputs)}
    {code_str}
    return out
"""
        # Execute the generated module and look the function up by name. Indexing
        # ``co_consts`` depends on constant ordering and annotation handling, which vary
        # across Python versions. The function still uses this module's globals.
        # The source is generated from internal variable names, not from external input.
        namespace: dict = {}
        exec(compile(fn_str, '<string>', 'exec'), globals(), namespace)
        fn = namespace[fn_name]
        return fn, fn_str

    def __call__(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        return self.gen_fn(inputs, outputs)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
class VitisCodegenBackend(PyCodegenBackend):
    """Emit a Vitis HLS C++ function template using ``ap_fixed``/``ap_ufixed`` types.

    ``gen_fn`` returns the C++ source together with a Python reference implementation
    produced by a separate ``PyCodegenBackend``.
    """

    _comment = '//'

    def __init__(self, namer=Namer(), fn_name: str = 'placeholder', **kwargs):
        super().__init__(namer, fn_name, **kwargs)

    def reference_code(self, v: FixedVariable):
        """How the variable should be referenced in the code.

        Constants become a typed ``ap_[u]fixed`` literal. Power-of-two scaling becomes
        ``bit_shift<shift>(name)``, with an optional leading ``-``.
        """
        if v.int_min == v.int_max:
            k, b, i = v.k, v.b, v.i
            u = '' if k else 'u'
            type_str = f'ap_{u}fixed<{max(b+k,1)}, {i+k}>'
            return f'{type_str}({v.min})'

        neg = v._factor < 0
        shift = log2(abs(v._factor))
        assert shift % 1 == 0
        shift = int(shift)
        s_sign = '-' if neg else ''
        if shift == 0:
            return f'{s_sign}{v.name}'
        return f'{s_sign}bit_shift<{shift}>({v.name})'

    def gen_lines(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        """Reuse the parent's lines, rewriting the trailing output lines as C++ statements."""
        codes = super().gen_lines(inputs, outputs)
        n = len(outputs)
        for i, out in enumerate(outputs):
            codes[-n + i] = f'out[{i}] = {self.reference_code(out)};'
        return codes

    def def_code(self, v: FixedVariable):
        """How the variable should be defined in the code, typed from the unscaled value."""
        if v.int_min == v.int_max:
            raise ValueError('Constant variable should not be defined')
        assert v._from is not None, 'Variable not derived from other variables cannot be defined in runtime'
        v1_str = self.reference_code(v._from[0])
        v2_str = self.reference_code(v._from[1])
        vv = v * (1 / v._factor)
        k, b, i = vv.k, vv.b, vv.i
        b, i = b + k, i + k  # b and i did not include sign bit
        u = '' if k else 'u'
        type_str = f'ap_{u}fixed<{b}, {i}>'
        if v2_str[0] == '-':
            return f'{type_str} {v.name} = {v1_str} - {v2_str[1:]};'
        return f'{type_str} {v.name} = {v1_str} + {v2_str};'

    def gen_fn(self, inputs: list[FixedVariable], outputs: list[FixedVariable], **kwargs):
        """Return ``(python_reference_fn, cpp_source)``."""
        attrs = {**self._attrs, **kwargs}
        fn_name = attrs['fn_name']
        code = self.gen_lines(inputs, outputs)
        code_str = '\n    '.join(code)

        fn_str = f"""template <typename inp_t, typename out_t>
void {fn_name}(inp_t inp[{len(inputs)}], out_t out[{len(outputs)}]) {{
    {code_str}
}}
"""
        # The Python reference comes from a separate PyCodegenBackend instance, which
        # already uses '#' comments. This backend's own ``_comment`` is never touched,
        # so an exception here cannot leave it stuck on '#'.
        fn, _ = PyCodegenBackend().gen_fn(inputs, outputs, fn_name=fn_name)
        return fn, fn_str

    def __call__(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        return self.gen_fn(inputs, outputs)
|
da4ml/cmvm/csd.py
DELETED
|
@@ -1,73 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
from numba import njit
|
|
3
|
-
from numpy.typing import NDArray
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
@njit
def _volatile_int_arr_to_csd(x: NDArray) -> NDArray[np.int8]:
    """Decompose an integer-valued array into canonical-signed-digit form.

    Returns an int8 array of shape ``x.shape + (N,)``. Entry ``[..., n]`` in {-1, 0, 1} is
    the digit of weight ``2**n``. Digits are chosen greedily from the most significant
    end, using the threshold ``2**n / 1.5``.

    Warning: ``x`` is modified in place (reduced towards zero), hence "volatile".
    """
    n_float = np.max(np.ceil(np.log2(np.abs(x + 1e-19) * 1.5)))
    n_digits = int(max(n_float, 1))
    digits = np.zeros((*np.shape(x), n_digits), dtype=np.int8)

    for n in range(n_digits - 1, -1, -1):
        weight = 2**n
        cutoff = weight / 1.5
        d = (x > cutoff).astype(np.int8) - (x < -cutoff).astype(np.int8)
        x -= weight * d
        digits[..., n] = d
    return digits
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
@njit(error_model='numpy')
def to_csd(x: NDArray) -> tuple[list[NDArray[np.int8]], NDArray[np.int8]]:
    """CSD-decompose ``x`` after scaling it to integers by a power of two.

    The scale exponent ``hi`` is found by binary search over [-32, 32] as the smallest
    exponent that makes ``x * 2**hi`` integer-valued; it is 0 for an all-zero ``x``.

    Returns:
        ``(rows, shifts)``:
        - ``rows``: the CSD digit array split along its first axis.
        - ``shifts[k]``: ``k - hi``, the power of two that digit ``k`` stands for in the
          original ``x``.
    """
    hi = 0
    if not np.all(x == 0):
        lo, hi = -32, 32
        while hi - lo > 1:
            mid = (hi + lo) // 2
            scaled = x * (2.0**mid)
            if np.all(scaled == np.floor(scaled)):
                hi = mid
            else:
                lo = mid
    digits = _volatile_int_arr_to_csd(x * (2.0**hi))
    return list(digits), np.arange(digits.shape[-1], dtype=np.int8) - hi
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
@njit
def _volatile_int_arr_to_binary(x: NDArray) -> NDArray[np.int8]:
    """Decompose an integer-valued array into signed binary digits.

    Returns an int8 array of shape ``x.shape + (N,)``. Entry ``[..., n]`` in {-1, 0, 1} is
    the digit of weight ``2**n``. Each element's nonzero digits carry that element's
    sign, i.e. this is sign-magnitude binary.

    Warning: ``x`` is modified in place (reduced towards zero), hence "volatile".
    """
    n_float = np.max(np.ceil(np.log2(np.abs(x) + 1)))
    n_digits = int(max(n_float, 1))
    digits = np.zeros((*np.shape(x), n_digits), dtype=np.int8)

    for n in range(n_digits - 1, -1, -1):
        weight = 2**n
        d = (x >= weight).astype(np.int8) - (x <= -weight).astype(np.int8)
        x -= weight * d
        digits[..., n] = d
    return digits
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
@njit(error_model='numpy')
def to_binary(x: NDArray) -> tuple[list[NDArray[np.int8]], NDArray[np.int8]]:
    """Signed-binary decomposition of ``x`` after scaling it to integers by a power of two.

    The scale exponent search matches ``to_csd``: a binary search over [-32, 32], with 0
    used for an all-zero ``x``. Returns the digit rows (split along the first axis) and
    ``shifts[k] = k - hi``.
    """
    hi = 0
    if not np.all(x == 0):
        lo, hi = -32, 32
        while hi - lo > 1:
            mid = (hi + lo) // 2
            scaled = x * (2.0**mid)
            if np.all(scaled == np.floor(scaled)):
                hi = mid
            else:
                lo = mid
    digits = _volatile_int_arr_to_binary(x * (2.0**hi))
    return list(digits), np.arange(digits.shape[-1], dtype=np.int8) - hi
|