da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +204 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +246 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.1.post1.dist-info/METADATA +85 -0
- da4ml-0.5.1.post1.dist-info/RECORD +96 -0
- da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
- da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
- da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from numba import jit
|
|
3
|
+
from numpy.typing import NDArray
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@jit
def _volatile_int_arr_to_csd(x: NDArray) -> NDArray[np.int8]:
    """Convert an integer-valued array to canonical signed digit (CSD) form, in place.

    "Volatile": ``x`` itself is overwritten with the residual left after each digit
    is extracted. Callers that need their input intact must pass a copy.

    Parameters
    ----------
    x : NDArray
        Integer-valued array of any shape. Float dtypes are accepted; digits are
        cast to ``x.dtype`` before being subtracted from the residual.

    Returns
    -------
    NDArray[np.int8]
        Digits in {-1, 0, +1} with shape ``(*x.shape, N)``. Digit ``n`` carries
        weight ``2**n``. ``N`` is at least 1.
    """
    # Number of digit positions needed. The 1.5 factor matches the per-digit
    # threshold used below; the tiny epsilon keeps log2 finite for zeros.
    N = np.max(np.ceil(np.log2(np.abs(x) * 1.5 + 1e-19)))
    N = int(max(N, 1))
    buf = np.zeros((*np.shape(x), N), dtype=np.int8)

    # Greedy MSB-first extraction: emit +1 / -1 at position n when the residual
    # exceeds +/- 2**n / 1.5, then remove that contribution from the residual.
    for n in range(N - 1, -1, -1):
        _2pn = 2**n
        thres = _2pn / 1.5
        bit = (x > thres).astype(np.int8)
        bit -= (x < -thres).astype(np.int8)
        x -= _2pn * bit.astype(x.dtype)
        buf[..., n] = bit
    return buf
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@jit(error_model='numpy')
def _shift_centering(arr: NDArray):
    """Find the power-of-two shift that makes ``arr`` integer-valued.

    Bisects over exponents in the window [-64, 64] for the boundary exponent
    ``e`` at which ``arr * 2**e`` first has no fractional part, and returns
    ``-e``. An all-zero array yields 0.
    """
    if np.all(arr == 0):
        return 0

    lo, hi = -64, 64
    # Invariant: arr * 2**hi is treated as integral, arr * 2**lo as not.
    while hi - lo > 1:
        probe = (lo + hi) // 2
        scaled = arr * (2.0**probe)
        if not np.all(scaled == np.floor(scaled)):
            lo = probe
        else:
            hi = probe
    return -hi
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@jit(error_model='numpy')
def shift_centering(arr: NDArray, axis: int):
    """Compute one power-of-two centering shift per slice of ``arr`` along ``axis``.

    For every index ``i`` along ``axis``, the slice ``arr.take(i, axis=axis)`` is
    passed to ``_shift_centering``. That returns a shift ``s`` chosen so that
    ``slice * 2**-s`` is integer-valued (0 for an all-zero slice).

    Parameters
    ----------
    arr : NDArray
        Array to analyse.
    axis : int
        Axis along which slices are taken.

    Returns
    -------
    NDArray[np.int8]
        One shift per slice. ``_shift_centering`` searches exponents in [-64, 64],
        so the results fit in int8.
    """
    n = arr.shape[axis]
    shifts = np.empty(n, dtype=np.int8)
    for i in range(n):
        shifts[i] = _shift_centering(arr.take(i, axis=axis))
    return shifts
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@jit
def _center(arr: NDArray):
    """Factor power-of-two scales out of the columns, then the rows, of a 2D array.

    Column shifts are computed and removed first. Row shifts are then computed on
    the column-scaled array. The order matters: ``shift0`` depends on ``shift1``.

    Parameters
    ----------
    arr : NDArray
        2D array (``shift0[:, None]`` below assumes rank 2).

    Returns
    -------
    arr : NDArray
        ``arr * 2**-shift1[None, :] * 2**-shift0[:, None]``. This is integer-valued
        as long as the needed shifts fall inside ``_shift_centering``'s search window.
    shift0 : NDArray[np.int8]
        Per-row shifts (axis 0).
    shift1 : NDArray[np.int8]
        Per-column shifts (axis 1).
    """
    shift1 = shift_centering(arr, 1)  # d_out: one shift per column
    arr = arr * (2.0**-shift1)
    shift0 = shift_centering(arr, 0)  # d_in: one shift per row, after column scaling
    arr = arr * (2.0 ** -shift0[:, None])
    return arr, shift0.astype(np.int8), shift1.astype(np.int8)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@jit(cache=True)
def csd_decompose(arr: NDArray, center=True):
    """
    Convert a 2D array to CSD (canonical signed digit) representation.

    Parameters
    ----------
    arr : ndarray
        Input array to be converted. It is not modified: the digit extraction
        runs on the centered array or on a copy.
    center : bool, optional
        If True, the array is centered before conversion (centering assumes a
        2D array). Default is True.
        If False, no centering is done and arrays with more than two dimensions
        are accepted. At least two dimensions are still required, because
        ``shift0`` and ``shift1`` are sized from ``arr.shape[0]`` and
        ``arr.shape[1]``.

    Returns
    -------
    csd : ndarray
        int8 digits in {-1, 0, 1} with shape ``(*arr.shape, N)``; digit ``n`` has
        weight ``2**n``. Represents the input array after centering, if center is True.
    shift0 : ndarray
        int8 shift values for the first axis (all zeros when center is False).
    shift1 : ndarray
        int8 shift values for the second axis (all zeros when center is False).
        When centering, the input equals the centered values times
        ``2**shift0[:, None] * 2**shift1[None, :]``.
    """

    if center:
        arr, shift0, shift1 = _center(arr)
    else:
        shift0 = np.zeros(arr.shape[0], dtype=np.int8)
        shift1 = np.zeros(arr.shape[1], dtype=np.int8)
        # _volatile_int_arr_to_csd overwrites its argument; protect the caller's array.
        arr = arr.copy()
    csd = _volatile_int_arr_to_csd(arr)
    return csd, shift0, shift1
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from math import ceil, log2
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numba import jit
|
|
5
|
+
|
|
6
|
+
from .bit_decompose import _center, _volatile_int_arr_to_csd
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@jit
def prim_mst_dc(cost_mat: np.ndarray, dc: int = -1):
    """Minimum Spanning Tree (MST) using Prim's algorithm with a delay constraint. May not be optimal.

    Always start from the root node (0).

    Parameters
    ----------
    cost_mat : np.ndarray
        The adjacency matrix of the graph, where cost_mat[i, j] is the cost of the edge between i and j.
        Must have an integer dtype when ``dc >= 0``: ``np.iinfo`` of its dtype is used to
        build the "ineligible edge" sentinel.

    dc : int, optional
        The delay constraint, by default -1.
        Any negative value means no delay constraint is applied.

        Delay of each edge is ceiling(log2(cost_mat[i, j])), with costs below 1 clamped to 1
        (delay 0).

        As implemented, a node i attached to parent j gets delay
        ``max(edge_delay(i, j), delay(j)) + 1``, and the root has delay 0. So every hop adds
        one level on top of the worst edge delay along the path. Latency is **NOT** the sum
        of the edge latencies. An edge is eligible only when that value is at most
        ``(2**dc - 1) + ceil(log2(max(cost_mat[0])))``, i.e. ``dc`` bounds the extra
        hop depth on a log2 scale.

    Returns
    -------
    np.ndarray
        int32 array of shape (N - 1, 2). Each row is a (parent, child) pair, in the order
        the edges were added. A parent is always the root or the child of an earlier row.
    """

    N = len(cost_mat)
    lat_mat = np.ceil(np.log2(np.maximum(cost_mat, 1)))
    parent = np.full(N, -2, dtype=np.int32)  # -2: not visited, -1: root

    parent[0] = -1
    idxs = np.arange(N)

    mapping = np.empty((N - 1, 2), dtype=np.int32)
    latency = np.zeros((N,), dtype=np.int32)

    if dc >= 0:
        # Delay budget: 2**dc - 1 extra hops on top of the largest root-edge delay.
        # The epsilon avoids log2(0) when the root row is all zeros.
        _dc = (2**dc - 1) + ceil(log2(np.max(cost_mat[0]) + 1e-32))
    else:
        _dc = -1

    for n_impl in range(1, N):
        implemented = parent != -2
        # Candidate edges: rows = not-yet-visited nodes, columns = visited nodes.
        _cost = cost_mat[~implemented][:, implemented]
        if dc >= 0:
            _lat = lat_mat[~implemented][:, implemented]
            # Edges that would exceed the budget get a large sentinel cost. If none is
            # eligible, all entries are equal and argmin falls back to the first candidate,
            # so the constraint can be violated in that case.
            _cost = np.where(np.maximum(_lat, latency[implemented]) + 1 <= _dc, _cost, np.iinfo(_cost.dtype).max // 2)
        # Flat argmin over the (unvisited x visited) sub-matrix, which has n_impl columns.
        _idx = int(np.argmin(_cost))
        _i, _j = _idx // n_impl, _idx % n_impl
        i, j = idxs[~implemented][_i], idxs[implemented][_j]
        parent[i] = j
        mapping[n_impl - 1, 0] = j
        mapping[n_impl - 1, 1] = i
        latency[i] = max(lat_mat[i, j], latency[j]) + 1  # type: ignore

    return mapping
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@jit(cache=True)
def kernel_decompose(kernel: np.ndarray, dc: int = -2):
    """Decompose a 2D kernel matrix into two matrices with the delay-constrained approx MST.

    Parameters
    ----------
    kernel : np.ndarray
        The input kernel matrix to decompose, shape (n_in, n_out).

    dc : int, optional
        Delay constraint, by default -2.
        If -1, return trivial decomposition (m0 = kernel, m1 = I) without building the MST.
        Any other negative value (e.g. the default -2) applies no delay constraint.

        The delay constraint limits the maximum latency (hops) of the decomposed
        multiplication structure.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The decomposed matrices (m0, m1): kernel = m0 @ m1
    """
    # Factor power-of-two row/column scales out so the decomposition works on integers;
    # they are re-applied to the returned factors.
    kernel, shift0, shift1 = _center(kernel)
    scale0, scale1 = 2.0**shift0, 2.0**shift1
    n_in, n_out = kernel.shape

    if dc == -1:
        # Trivial decomposition requested: skip the pairwise CSD distance computation
        # (O(n_in * n_out**2) digits) and the MST, whose results would be discarded.
        m0, m1 = np.zeros((n_in, n_out), dtype=kernel.dtype), np.zeros((n_out, n_out), dtype=kernel.dtype)
        m0[:] = kernel
        m1[:] = np.eye(n_out, dtype=kernel.dtype)
        return m0 * scale0[:, None], m1 * scale1

    # Prepend an all-zero column that acts as the MST root (node 0).
    m, n = kernel.shape[0], kernel.shape[1] + 1
    mat_aug = np.zeros((m, n), dtype=kernel.dtype)
    mat_aug[:, 1:] = kernel
    # Pairwise column differences and sums; their CSD digit counts are the edge costs.
    diff0 = mat_aug[:, :, None] - mat_aug[:, None, :]
    diff1 = mat_aug[:, :, None] + mat_aug[:, None, :]
    dist0 = np.sum(np.sum(_volatile_int_arr_to_csd(diff0) != 0, axis=3), axis=0)
    dist1 = np.sum(np.sum(_volatile_int_arr_to_csd(diff1) != 0, axis=3), axis=0)
    # sign = -1 where adding the parent column is cheaper than subtracting it.
    sign = np.where(dist1 - dist0 < 0, -1, 1)
    dist = np.minimum(dist0, dist1)
    mapping = prim_mst_dc(dist, dc=dc)
    m0, m1 = np.zeros((n_in, n_out), dtype=kernel.dtype), np.zeros((n_out, n_out), dtype=kernel.dtype)

    cnt = 0
    # Edges come in Prim order, so the parent's m1 column is already filled in.
    for _from, _to in mapping:
        col0 = mat_aug[:, _to] - mat_aug[:, _from] * sign[_to, _from]
        if _from != 0:
            col1 = m1[:, _from - 1].copy() * sign[_to, _from]
        else:
            col1 = np.zeros(n_out, dtype=kernel.dtype)
        if np.any(col0 != 0):
            # Non-zero residual column: allocate a new m0 column for it.
            col1[cnt] = 1
            m0[:, cnt] = col0
            cnt += 1
        m1[:, _to - 1] = col1
    return m0 * scale0[:, None], m1 * scale1
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
|
|
3
|
+
from ...cmvm.types import CombLogic, QInterval, _minimal_kif
|
|
4
|
+
from ...trace.fixed_variable import _const_f
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def kif_to_vitis_type(k: bool | int = 1, i: int = 0, f: int = 0):
|
|
8
|
+
if k == i == f == 0:
|
|
9
|
+
f = 1
|
|
10
|
+
return f'ap_{"" if k else "u"}fixed<{k + i + f},{k + i}>'
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def kif_to_hlslib_type(k: bool | int = 1, i: int = 0, f: int = 0):
|
|
14
|
+
if k == i == f == 0:
|
|
15
|
+
f = 1
|
|
16
|
+
return f'ac_fixed<{int(k)},{k + i + f},{k + i}>'
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def kif_to_oneapi_type(k: bool | int = 1, i: int = 0, f: int = 0):
|
|
20
|
+
# OneAPI requires at least 2 bits for all ac_fixed as of 2025.1
|
|
21
|
+
return f'ac_fixed<{int(k)},{max(k + i + f, 2)},{k + i}>'
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_typestr_fn(flavor: str):
    """Return the (k, i, f) -> C++ type-string formatter for an HLS flavor.

    ``flavor`` is matched case-insensitively against 'vitis', 'hlslib' and
    'oneapi'. Any other value raises ``ValueError``.
    """
    formatters = {
        'vitis': kif_to_vitis_type,
        'hlslib': kif_to_hlslib_type,
        'oneapi': kif_to_oneapi_type,
    }
    typestr_fn = formatters.get(flavor.lower())
    if typestr_fn is None:
        raise ValueError(f'Unsupported flavor: {flavor}')
    return typestr_fn
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def ssa_gen(sol: CombLogic, print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
    """Emit one C++ SSA assignment per referenced op of ``sol``.

    Op ``i`` becomes ``<type> v<i> = <expr>;``, where ``<type>`` is ``typestr_fn``
    applied to the op's minimal (k, i, f). Ops whose ``sol.ref_count`` entry is 0
    are skipped.

    Parameters
    ----------
    sol : CombLogic
        Solution whose ``ops`` are translated.
    print_latency : bool
        If True, append `` // <op.latency>`` to each line.
    typestr_fn : Callable[[bool | int, int, int], str]
        Formatter turning a (k, i, f) triple into a C++ type name.

    Returns
    -------
    list[str]
        The generated lines, without indentation.

    Raises
    ------
    ValueError
        If an op has an opcode not handled below.
    """
    ops = sol.ops
    # Precompute (k, i, f) and the C++ type of every op; some cases read other ops' widths.
    all_kifs = list(map(_minimal_kif, (op.qint for op in ops)))
    all_types = list(map(lambda x: typestr_fn(*x), all_kifs))

    lines = []
    ref_count = sol.ref_count
    for i, op in enumerate(ops):
        if ref_count[i] == 0:
            # Skip unused ops
            continue

        _type = all_types[i]

        ref0 = f'v{op.id0}'

        match op.opcode:
            case -1:
                # Input marker: op.id0 is the index into model_inp
                val = f'model_inp[{op.id0}]'
            case 0 | 1:
                # Common a+/-b<<shift op (opcode 1 subtracts); op.data is the shift amount
                ref1 = f'bit_shift<{op.data}>(v{op.id1})' if op.data != 0 else f'v{op.id1}'
                val = f'{ref0} {"-" if op.opcode == 1 else "+"} {ref1}'
            case 2 | -2:
                if op.opcode == 2:  # relu(model_inp)
                    # Only emit the comparison when the operand can be negative.
                    if ops[op.id0].qint.min < 0:
                        val = f'{ref0} > 0 ? {_type}({ref0}) : {_type}(0)'
                    else:
                        val = ref0
                else:  # relu(-model_inp)
                    # Only emit the comparison when the operand can be positive.
                    if ops[op.id0].qint.max > 0:
                        val = f'{ref0} > 0 ? {_type}(0) : {_type}(-{ref0})'
                    else:
                        val = f'-{ref0}'
            case 3 | -3:
                # Explicit quantization op, done implicitly via assignment (opcode -3 negates)
                val = ref0 if op.opcode == 3 else f'-{ref0}'
            case 4:
                # Constant addition: op.data is the constant in units of op.qint.step
                _number = op.data * op.qint.step
                sign, mag = ('-' if _number < 0 else '+'), abs(_number)
                # Type the literal from its own magnitude; _const_f presumably gives its fractional bits
                f = _const_f(mag)
                const_type_str = typestr_fn(*_minimal_kif(QInterval(mag, mag, 2.0**-f)))
                val = f'{ref0} {sign} {const_type_str}({mag})'
            case 5:
                # Define constant: op.data in units of op.qint.step
                _number = op.data * op.qint.step
                val = f'{_number}'
            case 6 | -6:
                # MSB Mux: low 32 bits of op.data = index of the condition op,
                # high 32 bits = signed shift applied to the second operand.
                id_c = op.data & 0xFFFFFFFF
                bw_k = sum(all_kifs[id_c])
                shift = (op.data >> 32) & 0xFFFFFFFF
                shift = shift if shift < 0x80000000 else shift - 0x100000000
                # Select on the top bit (index bw_k - 1) of the condition value.
                ref_k = f'v{id_c}[{bw_k - 1}]'
                sign = '-' if op.opcode == -6 else ''
                ref1 = f'v{op.id1}' if shift == 0 else f'bit_shift<{shift}>(v{op.id1})'
                # Zero-width operands are replaced by a literal 0.
                bw0, bw1 = sum(all_kifs[op.id0]), sum(all_kifs[op.id1])
                if bw0 == 0:
                    ref0 = '0'
                if bw1 == 0:
                    ref1 = '0'
                val = f'{ref_k} ? {_type}({ref0}) : {_type}({sign}{ref1})'
            case 7:
                # Multiplication
                ref1 = f'v{op.id1}'
                val = f'{ref0} * {ref1}'
            case _:
                raise ValueError(f'Unsupported opcode: {op.opcode}')

        line = f'{_type} v{i} = {val};'

        if print_latency:
            line += f' // {op.latency}'
        lines.append(line)
    return lines
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def output_gen(sol: CombLogic, typestr_fn: Callable[[bool | int, int, int], str]):
    """Emit C++ assignments that copy SSA values into ``model_out``.

    A negative entry in ``sol.out_idxs`` marks a constant-zero output. Any other
    entry names the SSA value ``v<idx>``. That value is optionally shifted with
    ``bit_shift<>`` and negated, then cast to the output's minimal type.

    Returns
    -------
    list[str]
        One assignment line per output, without indentation.
    """
    lines = []
    for pos, src in enumerate(sol.out_idxs):
        target = f'model_out[{pos}]'
        if src < 0:
            lines.append(f'{target} = 0;')
            continue
        cast = typestr_fn(*_minimal_kif(sol.out_qint[pos]))
        shift = sol.out_shifts[pos]
        neg = '-' if sol.out_negs[pos] else ''
        operand = f'v{src}' if shift == 0 else f'bit_shift<{shift}>(v{src})'
        lines.append(f'{target} = {cast}({neg}{operand});')
    return lines
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def get_io_types(sol: CombLogic, flavor: str):
    """Return ``(input_type, output_type)`` C++ type strings wide enough for every element.

    Each element's minimal (k, i, f) comes from ``_minimal_kif``. The element-wise
    maximum over all inputs (resp. outputs) is rendered with the flavor's formatter.
    """
    to_type_str = get_typestr_fn(flavor)

    def _covering_type(qints):
        # Column-wise max of the (k, i, f) triples: one type that fits them all.
        per_field = zip(*(_minimal_kif(q) for q in qints))
        return to_type_str(*(max(values) for values in per_field))

    return _covering_type(sol.inp_qint), _covering_type(sol.out_qint)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def hls_logic_and_bridge_gen(
    sol: CombLogic,
    fn_name: str,
    flavor: str,
    pragmas: list[str] | None = None,
    n_indent: int = 4,
    n_base_indent: int = 0,
    print_latency: bool = False,
):
    """Generate the templated HLS function for ``sol`` and its C bridge source.

    Parameters
    ----------
    sol : CombLogic
        Solution to translate.
    fn_name : str
        Name of the generated template function. The bridge includes ``<fn_name>.hh``.
    flavor : str
        HLS flavor passed to ``get_typestr_fn`` / ``get_io_types``.
    pragmas : list[str] | None, optional
        Lines placed at the top of the function body; None means none.
    n_indent : int, optional
        Spaces per indentation level.
    n_base_indent : int, optional
        Indentation levels applied to the whole function.
    print_latency : bool, optional
        Annotate SSA lines with op latencies.

    Returns
    -------
    tuple[str, str]
        ``code``: the ``template <typename inp_t, typename out_t>`` function body.
        ``bridge``: C++ source exposing ``openmp_enabled``, ``inference_f64`` and
        ``inference_f32`` with C linkage. ``_openmp`` and ``batch_inference`` are
        presumably provided by ``binder_util.hh`` — confirm.
    """
    typestr_fn = get_typestr_fn(flavor)
    inp_t, out_t = get_io_types(sol, flavor)

    n_in, n_out = sol.shape
    template_def = 'template <typename inp_t, typename out_t>'
    fn_signature = f'void {fn_name}(inp_t model_inp[{n_in}], out_t model_out[{n_out}])'
    pragmas = pragmas or []

    ssa_lines = ssa_gen(sol, print_latency=print_latency, typestr_fn=typestr_fn)
    output_lines = output_gen(sol, typestr_fn=typestr_fn)

    indent = ' ' * n_indent
    base_indent = indent * n_base_indent
    # Joiner that starts each body line on a new, indented line.
    body_indent = '\n' + base_indent + indent
    # Doubled braces are literal '{' / '}' in the f-strings below.
    code = f"""{base_indent}{template_def}
{base_indent}{fn_signature} {{ // {inp_t} -> {out_t}
{base_indent + indent}{body_indent.join(pragmas)}
{body_indent}{body_indent.join(ssa_lines)}
{body_indent}{body_indent.join(output_lines)}
{base_indent}}}
"""
    bridge = f"""#include "binder_util.hh"
#include "{fn_name}.hh"

struct {fn_name}_config {{
static const size_t N_inp = {n_in};
static const size_t N_out = {n_out};
typedef {inp_t} inp_t;
typedef {out_t} out_t;
constexpr static auto f = {fn_name}<inp_t, out_t>;
}};

extern "C" {{

bool openmp_enabled() {{
return _openmp;
}}

void inference_f64(double *model_inp, double *model_out, size_t size, size_t n_threads) {{
batch_inference<{fn_name}_config, double>(model_inp, model_out, size, n_threads);
}}

void inference_f32(float *model_inp, float *model_out, size_t size, size_t n_threads) {{
batch_inference<{fn_name}_config, float>(model_inp, model_out, size, n_threads);
}}
}}"""
    return code, bridge
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
import ctypes
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import shutil
|
|
5
|
+
import subprocess
|
|
6
|
+
import sys
|
|
7
|
+
from collections.abc import Sequence
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TypeVar
|
|
10
|
+
from uuid import uuid4
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
from numpy.typing import NDArray
|
|
14
|
+
|
|
15
|
+
from da4ml.cmvm.types import CombLogic
|
|
16
|
+
from da4ml.codegen.hls.hls_codegen import get_io_types, hls_logic_and_bridge_gen
|
|
17
|
+
|
|
18
|
+
from ... import codegen
|
|
19
|
+
from ...cmvm.types import _minimal_kif
|
|
20
|
+
|
|
21
|
+
# Floating-point dtype shared by the input and output arrays of HLSModel.predict.
T = TypeVar('T', bound=np.floating)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class HLSModel:
    """Write, compile and run a C++ HLS emulation of a ``CombLogic`` solution.

    ``write()`` emits the HLS sources and a C bridge into ``path``. ``compile()``
    builds them into ``lib<prj_name>_<uuid>.so`` with ``make`` and loads it via
    ctypes. ``predict()`` runs batched inference through that library.
    """

    def __init__(
        self,
        solution: CombLogic,
        prj_name: str,
        path: str | Path,
        flavor: str = 'vitis',
        print_latency: bool = True,
        part_name: str = 'xcvu13p-flga2577-2-e',
        pragma: Sequence[str] | None = None,
        clock_period: int = 5,
        clock_uncertainty: float = 0.1,
        io_delay_minmax: tuple[float, float] = (0.2, 0.4),
    ):
        """
        Parameters
        ----------
        solution : CombLogic
            The solution to generate code for.
        prj_name : str
            Top-function / project name; also used for the generated file names.
        path : str | Path
            Output directory (created by ``write()`` if missing).
        flavor : str, optional
            'vitis', 'hlslib' or 'oneapi' (case-insensitive), by default 'vitis'.
        print_latency : bool, optional
            Annotate each generated SSA line with its latency, by default True.
        part_name, clock_period, clock_uncertainty, io_delay_minmax
            Stored on the instance; not consumed by the methods of this class.
        pragma : Sequence[str] | None, optional
            Pragmas for the ``<prj_name>_fn`` wrapper. When None, vitis gets complete
            partitioning of ``inp``/``out`` plus ``PIPELINE II=1``; other flavors get none.
        """
        self._solution = solution
        self._prj_name = prj_name
        self._path = Path(path).resolve()
        self._flavor = flavor.lower()
        assert self._flavor in ('vitis', 'hlslib', 'oneapi'), f'Unsupported HLS flavor: {self._flavor}'
        self._print_latency = print_latency
        self._part_name = part_name
        self._clock_period = clock_period
        self._clock_uncertainty = clock_uncertainty
        self._io_delay_minmax = io_delay_minmax
        self.__src_root = Path(codegen.__file__).parent
        self._lib = None
        self._uuid = None

        if pragma is None:
            if self._flavor == 'vitis':
                self._pragma = (
                    '#pragma HLS ARRAY_PARTITION variable=inp complete',
                    '#pragma HLS ARRAY_PARTITION variable=out complete',
                    '#pragma HLS PIPELINE II=1',
                )
            else:
                self._pragma = ()
        else:
            self._pragma = tuple(pragma)

    def write(self):
        """Write the HLS sources, the C bridge, support files and ``project.json`` to ``path``."""
        if not self._path.exists():
            self._path.mkdir(parents=True, exist_ok=True)
        template_def, bridge = hls_logic_and_bridge_gen(
            self._solution,
            self._prj_name,
            self._flavor,
            ['#pragma HLS INLINE'],
            4,
            0,
            self._print_latency,
        )

        headers = ['#pragma once', '#include "bitshift.hh"']

        inp_type, out_type = get_io_types(self._solution, self._flavor)
        n_in, n_out = len(self._solution.inp_qint), len(self._solution.out_qint)
        template_signature = (
            f'template <typename inp_t, typename out_t>\nvoid {self._prj_name}(inp_t inp[{n_in}], out_t out[{n_out}]);'
        )
        # Concrete (non-template) wrapper that the pragmas and the top-level design use.
        fn_signature = f'void {self._prj_name}_fn({inp_type} inp[{n_in}], {out_type} out[{n_out}])'

        with open(self._path / f'{self._prj_name}.hh', 'w') as f:
            f.write('\n'.join(headers) + '\n\n')
            f.write(f'{template_signature}\n\n{fn_signature};\n')

        pragma_str = '\n'.join(self._pragma)
        cpp_def = f"""
#include "{self._prj_name}.hh"

{template_def}

{fn_signature} {{
{pragma_str}
{self._prj_name}<{inp_type}, {out_type}>(inp, out);
}}
"""
        with open(self._path / f'{self._prj_name}.cc', 'w') as f:
            f.write(cpp_def)

        with open(self._path / f'{self._prj_name}_bridge.cc', 'w') as f:
            f.write(bridge)

        shutil.copy(self.__src_root / 'hls/source/binder_util.hh', self._path)
        # NOTE(review): only vitis_bitshift.hh is listed under hls/source in this release,
        # so non-vitis flavors may fail here with FileNotFoundError — confirm.
        shutil.copy(self.__src_root / f'hls/source/{self._flavor}_bitshift.hh', self._path / 'bitshift.hh')
        shutil.copy(self.__src_root / 'hls/source/build_binder.mk', self._path)
        if self._flavor == 'vitis':
            shutil.copytree(self.__src_root / 'hls/source/ap_types', self._path / 'ap_types', dirs_exist_ok=True)

        self._solution.save(self._path / 'project.json')

    def _compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
        """Same as compile, but does not (re)write the source files first.

        Parameters
        ----------
        verbose : bool, optional
            Let make print to the terminal instead of capturing its output, by default False
        openmp : bool, optional
            Enable openmp, by default True
        o3 : bool, optional
            Turn on -O3 flag, by default False
        clean : bool, optional
            Remove obsolete shared object files, by default True

        Raises
        ------
        RuntimeError
            If compilation fails
        """

        self._uuid = str(uuid4())
        args = ['make', '-f', 'build_binder.mk']
        env = os.environ.copy()
        env['PRJ_NAME'] = self._prj_name
        env['STAMP'] = self._uuid
        env['EXTRA_CXXFLAGS'] = '-fopenmp' if openmp else ''
        if o3:
            args.append('fast')

        if clean:
            # Remove previously built lib*<uuid>.so files from the output directory.
            m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
            for p in self._path.iterdir():
                if not p.is_dir() and m.match(p.name):
                    p.unlink()

        try:
            subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
        except subprocess.CalledProcessError as e:
            # With verbose=True nothing was captured (make already wrote to the terminal),
            # so e.stderr / e.stdout are None and must not be decoded.
            if e.stderr is not None:
                print(e.stderr.decode(), file=sys.stderr)
            if e.stdout is not None:
                print(e.stdout.decode(), file=sys.stdout)
            raise RuntimeError('Compilation failed!!') from e

        self._load_lib(self._uuid)

    def _load_lib(self, uuid: str | None = None):
        """Load ``lib<prj_name>_<uuid>.so`` from ``path`` (``uuid`` defaults to the last build).

        Raises
        ------
        RuntimeError
            If the shared object does not exist.
        """
        uuid = uuid if uuid is not None else self._uuid
        self._uuid = uuid
        lib_path = self._path / f'lib{self._prj_name}_{uuid}.so'
        if not lib_path.exists():
            raise RuntimeError(f'Library {lib_path} does not exist')
        lib = ctypes.CDLL(str(lib_path))
        # Declare the bridge's C signatures. Without them ctypes passes Python ints as C
        # ``int`` (not ``size_t``) and reads the C++ ``bool`` return as a full ``int``.
        lib.openmp_enabled.argtypes = []
        lib.openmp_enabled.restype = ctypes.c_bool
        for name, c_elem in (('inference_f32', ctypes.c_float), ('inference_f64', ctypes.c_double)):
            fn = getattr(lib, name)
            fn.argtypes = [ctypes.POINTER(c_elem), ctypes.POINTER(c_elem), ctypes.c_size_t, ctypes.c_size_t]
            fn.restype = None
        self._lib = lib

    def compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
        """Compile the model to a shared object file

        Parameters
        ----------
        verbose : bool, optional
            Verbose output, by default False
        openmp : bool, optional
            Enable openmp, by default True
        o3 : bool, optional
            Turn on -O3 flag, by default False
        clean : bool, optional
            Remove obsolete shared object files, by default True

        Raises
        ------
        RuntimeError
            If compilation fails
        """
        self.write()
        self._compile(verbose, openmp, o3, clean)

    def predict(self, data: NDArray[T] | Sequence[NDArray[T]], n_threads: int = 0) -> NDArray[T]:
        """Run the model on the input data.

        Parameters
        ----------
        data: NDArray[np.floating] | Sequence[NDArray[np.floating]]
            Input data to the model (float32 or float64). The shape is ignored, and the
            number of samples is determined by the size of the data. A sequence of arrays
            is flattened per sample and concatenated along the feature axis.
        n_threads : int, optional
            Forwarded to the bridge's ``batch_inference``. The meaning of 0 (presumably
            "library default") is defined there — confirm.

        Returns
        -------
        NDArray[np.floating]
            Output of the model in shape (n_samples, output_size), same dtype as the input.

        Raises
        ------
        TypeError
            If the input dtype is neither float32 nor float64.
        """
        assert self._lib is not None, 'Library not loaded, call .compile() first.'
        inp_size, out_size = self._solution.shape

        if isinstance(data, Sequence):
            data = np.concatenate([a.reshape(a.shape[0], -1) for a in data], axis=-1)

        dtype = data.dtype
        if dtype not in (np.float32, np.float64):
            raise TypeError(f'Unsupported input data type: {dtype}. Expected float32 or float64.')
        c_dtype = ctypes.c_float if dtype == np.float32 else ctypes.c_double

        assert data.size % inp_size == 0, f'Input size {data.size} is not divisible by {inp_size}'
        n_sample = data.size // inp_size

        inp_data = np.ascontiguousarray(data)
        out_data = np.empty(n_sample * out_size, dtype=dtype)

        inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(c_dtype))
        out_buf = out_data.ctypes.data_as(ctypes.POINTER(c_dtype))
        if dtype == np.float32:
            self._lib.inference_f32(inp_buf, out_buf, n_sample, n_threads)
        else:
            self._lib.inference_f64(inp_buf, out_buf, n_sample, n_threads)

        return out_data.reshape(n_sample, out_size)  # type: ignore

    def __repr__(self):
        inp_size, out_size = self._solution.shape
        cost = round(self._solution.cost)
        # Transposed (k, i, f) columns; summing everything gives the total port bit count.
        inp_kifs = tuple(zip(*map(_minimal_kif, self._solution.inp_qint)))
        out_kifs = tuple(zip(*map(_minimal_kif, self._solution.out_qint)))
        in_bits, out_bits = np.sum(inp_kifs), np.sum(out_kifs)

        spec = f"""Top Function: {self._prj_name}\n====================
{inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)
combinational @ delay={self._solution.latency}
Estimated cost: {cost} LUTs"""

        is_compiled = self._lib is not None
        if is_compiled:
            assert self._uuid is not None
            openmp = 'with OpenMP' if self._lib.openmp_enabled() else ''  # type: ignore
            spec += f'\nEmulator is compiled {openmp} ({self._uuid[-12:]})'
        else:
            spec += '\nEmulator is **not compiled**'
        return spec
|