da4ml 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml might be problematic. Click here for more details.
- da4ml/__init__.py +16 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +3 -34
- da4ml/cmvm/api.py +239 -73
- da4ml/cmvm/core/__init__.py +222 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +569 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +11 -0
- da4ml/codegen/cpp/__init__.py +3 -0
- da4ml/codegen/cpp/cpp_codegen.py +148 -0
- da4ml/codegen/cpp/source/vitis.h +30 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
- da4ml/codegen/verilog/__init__.py +13 -0
- da4ml/codegen/verilog/comb.py +146 -0
- da4ml/codegen/verilog/io_wrapper.py +255 -0
- da4ml/codegen/verilog/pipeline.py +49 -0
- da4ml/codegen/verilog/source/build_binder.mk +27 -0
- da4ml/codegen/verilog/source/build_prj.tcl +75 -0
- da4ml/codegen/verilog/source/ioutils.hh +117 -0
- da4ml/codegen/verilog/source/shift_adder.v +56 -0
- da4ml/codegen/verilog/source/template.xdc +29 -0
- da4ml/codegen/verilog/verilog_model.py +265 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +358 -0
- da4ml/trace/fixed_variable_array.py +177 -0
- da4ml/trace/ops/__init__.py +55 -0
- da4ml/trace/ops/conv_utils.py +104 -0
- da4ml/trace/ops/einsum_utils.py +299 -0
- da4ml/trace/pipeline.py +155 -0
- da4ml/trace/tracer.py +120 -0
- da4ml-0.2.0.dist-info/METADATA +65 -0
- da4ml-0.2.0.dist-info/RECORD +39 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
- da4ml/cmvm/balanced_reduction.py +0 -46
- da4ml/cmvm/cmvm.py +0 -328
- da4ml/cmvm/codegen.py +0 -159
- da4ml/cmvm/csd.py +0 -73
- da4ml/cmvm/fixed_variable.py +0 -205
- da4ml/cmvm/graph_compile.py +0 -85
- da4ml/cmvm/nb_fixed_precision.py +0 -98
- da4ml/cmvm/scoring.py +0 -55
- da4ml/cmvm/utils.py +0 -5
- da4ml-0.1.2.dist-info/METADATA +0 -122
- da4ml-0.1.2.dist-info/RECORD +0 -18
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from math import ceil, log2
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numba import jit
|
|
5
|
+
|
|
6
|
+
from .bit_decompose import _center, _volatile_int_arr_to_csd
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@jit
def prim_mst_dc(cost_mat: np.ndarray, dc: int = -1):
    """Minimum Spanning Tree (MST) using Prim's algorithm with a delay constraint. May not be optimal.

    Always starts from the root node (0). At every step, the cheapest edge from the current
    tree to an unvisited node is added. Edges that would exceed the delay budget are penalized
    with a very large cost rather than forbidden. If every candidate exceeds the budget, one of
    them is still taken.

    Parameters
    ----------
    cost_mat : np.ndarray
        Square adjacency matrix of the graph, where cost_mat[i, j] is the cost of the edge
        between i and j. When ``dc >= 0`` it must have an integer dtype, because ``np.iinfo``
        is used to build the penalty value.

    dc : int, optional
        The delay constraint, by default -1.
        Any negative value disables the constraint (the code only checks ``dc >= 0``).

        Delay of each edge is ceiling(log2(cost_mat[i, j])); costs below 1 are treated as 1,
        i.e. delay 0.

        Delay from the root node to any node is the **maximum** latency of each edge connecting in between,
        plus ceiling(log2(#number of connection edges)).
        Latency is **NOT** the sum of the latencies.

        As implemented, each hop adds 1:
        latency[child] = max(edge delay, latency[parent]) + 1.
        An edge is admissible when that value is
        <= (2**dc - 1) + ceil(log2(max(cost_mat[0]))).

    Returns
    -------
    np.ndarray
        int32 array of shape (N - 1, 2). Row k is the (parent, child) edge added at step k, so
        every parent enters the tree before any of its children.
    """

    N = len(cost_mat)
    # Per-edge delay; costs below 1 are clamped to 1 so they contribute zero delay.
    lat_mat = np.ceil(np.log2(np.maximum(cost_mat, 1)))
    parent = np.full(N, -2, dtype=np.int32)  # -2: not visited, -1: root

    parent[0] = -1
    idxs = np.arange(N)

    mapping = np.empty((N - 1, 2), dtype=np.int32)
    latency = np.zeros((N,), dtype=np.int32)

    if dc >= 0:
        # Latency budget: (2**dc - 1) hops on top of ceil(log2(largest cost of a root edge)).
        # The 1e-32 keeps log2 finite when that largest cost is 0.
        _dc = (2**dc - 1) + ceil(log2(np.max(cost_mat[0]) + 1e-32))
    else:
        _dc = -1

    for n_impl in range(1, N):  # n_impl: number of nodes already in the tree
        implemented = parent != -2
        # Candidate edges: rows are unvisited nodes, columns are tree nodes.
        _cost = cost_mat[~implemented][:, implemented]
        if dc >= 0:
            _lat = lat_mat[~implemented][:, implemented]
            # Penalize (rather than drop) edges whose resulting latency would exceed the budget.
            _cost = np.where(np.maximum(_lat, latency[implemented]) + 1 <= _dc, _cost, np.iinfo(_cost.dtype).max // 2)
        _idx = int(np.argmin(_cost))
        # Unravel the flat argmin over the (N - n_impl, n_impl) candidate matrix.
        _i, _j = _idx // n_impl, _idx % n_impl
        i, j = idxs[~implemented][_i], idxs[implemented][_j]
        parent[i] = j
        mapping[n_impl - 1, 0] = j
        mapping[n_impl - 1, 1] = i
        latency[i] = max(lat_mat[i, j], latency[j]) + 1  # type: ignore

    return mapping
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@jit
def kernel_decompose(kernel: np.ndarray, dc: int = -2):
    """Decompose a 2D kernel matrix into two matrices with the delay-constrained approx MST.

    Each column of the (centered) kernel is expressed as a residual column plus +/- an
    earlier column, following the edges of an approximate MST over the columns. The
    distance between two columns is the count of nonzero CSD digits of their
    difference (or sum, whichever is cheaper).

    Parameters
    ----------
    kernel : np.ndarray
        The input kernel matrix to decompose, shape (n_in, n_out).

    dc : int, optional
        Delay constraint, by default -2.
        If -2, no delay constraint is applied.
        If -1, return the trivial decomposition (m0 = kernel, m1 = I) without building the MST.

        The delay constraint limits the maximum latency (hops) of the decomposed
        multiplication structure.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The decomposed matrices (m0, m1) with kernel = m0 @ m1.
        m0 has shape (n_in, n_out) and m1 has shape (n_out, n_out).
        The power-of-two scales removed by ``_center`` are folded back in: rows of m0 and
        columns of m1.
    """
    kernel, shift0, shift1 = _center(kernel)
    scale0, scale1 = 2.0**shift0, 2.0**shift1
    n_in, n_out = kernel.shape
    m0, m1 = np.zeros((n_in, n_out), dtype=kernel.dtype), np.zeros((n_out, n_out), dtype=kernel.dtype)

    if dc == -1:
        # Trivial decomposition: the pairwise distances and the MST are not needed, so skip them.
        m0[:] = kernel
        m1[:] = np.eye(n_out, dtype=kernel.dtype)
        return m0 * scale0[:, None], m1 * scale1

    # Prepend an all-zero column (node 0) that acts as the MST root.
    m, n = kernel.shape[0], kernel.shape[1] + 1
    mat_aug = np.zeros((m, n), dtype=kernel.dtype)
    mat_aug[:, 1:] = kernel
    # Pairwise column differences / sums, shape (rows, n, n).
    diff0 = mat_aug[:, :, None] - mat_aug[:, None, :]
    diff1 = mat_aug[:, :, None] + mat_aug[:, None, :]
    # Nonzero CSD digits, summed over digit positions and rows -> (n, n) cost matrices.
    dist0 = np.sum(np.sum(_volatile_int_arr_to_csd(diff0) != 0, axis=3), axis=0)
    dist1 = np.sum(np.sum(_volatile_int_arr_to_csd(diff1) != 0, axis=3), axis=0)
    # sign == -1 where deriving a column from the *sum* is cheaper than from the difference.
    sign = np.where(dist1 - dist0 < 0, -1, 1)
    dist = np.minimum(dist0, dist1)
    mapping = prim_mst_dc(dist, dc=dc)

    cnt = 0
    # mapping rows are in attach order, so column _from of m1 is filled before _to needs it.
    for _from, _to in mapping:
        col0 = mat_aug[:, _to] - mat_aug[:, _from] * sign[_to, _from]
        if _from != 0:
            col1 = m1[:, _from - 1].copy() * sign[_to, _from]
        else:
            col1 = np.zeros(n_out, dtype=kernel.dtype)
        # Only spend a new m0 column when the residual is nonzero; otherwise _to == +/- _from.
        if np.any(col0 != 0):
            col1[cnt] = 1
            m0[:, cnt] = col0
            cnt += 1
        m1[:, _to - 1] = col1
    return m0 * scale0[:, None], m1 * scale1
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .cpp import cpp_logic_and_bridge_gen
|
|
2
|
+
from .verilog import comb_binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_binder_gen, pipeline_logic_gen
|
|
3
|
+
|
|
4
|
+
# Public code generators re-exported from the C++ (HLS) and Verilog backends.
__all__ = [
    'cpp_logic_and_bridge_gen',
    'comb_logic_gen',
    'generate_io_wrapper',
    'comb_binder_gen',
    'pipeline_logic_gen',
    'pipeline_binder_gen',
]
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
|
|
3
|
+
from ...cmvm.types import Op, QInterval, Solution, _minimal_kif
|
|
4
|
+
from ...trace.fixed_variable import _const_f
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def kif_to_vitis_type(k: bool | int, i: int, f: int):
|
|
8
|
+
if k == i == f == 0:
|
|
9
|
+
f = 1
|
|
10
|
+
return f'ap_{"" if k else "u"}fixed<{k+i+f},{k+i}>'
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def kif_to_hlslib_type(k: bool | int, i: int, f: int):
|
|
14
|
+
if k == i == f == 0:
|
|
15
|
+
f = 1
|
|
16
|
+
return f'ac_fixed<{int(k)},{k+i+f},{k+i}>'
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_typestr_fn(flavor: str):
    """Return the function that renders (k, i, f) precision triples as C++ type names.

    ``flavor`` is matched case-insensitively:
    - ``'vitis'`` selects ``kif_to_vitis_type``;
    - ``'hlslib'`` selects ``kif_to_hlslib_type``.

    Raises
    ------
    ValueError
        If ``flavor`` is not one of the supported names.
    """
    normalized = flavor.lower()
    if normalized == 'vitis':
        return kif_to_vitis_type
    if normalized == 'hlslib':
        return kif_to_hlslib_type
    raise ValueError(f'Unsupported flavor: {flavor}')
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
|
|
31
|
+
all_kifs = map(_minimal_kif, (op.qint for op in ops))
|
|
32
|
+
all_types = list(map(lambda x: typestr_fn(*x), all_kifs))
|
|
33
|
+
|
|
34
|
+
lines = []
|
|
35
|
+
|
|
36
|
+
for i, op in enumerate(ops):
|
|
37
|
+
_type = all_types[i]
|
|
38
|
+
|
|
39
|
+
ref0 = f'v{op.id0}'
|
|
40
|
+
|
|
41
|
+
match op.opcode:
|
|
42
|
+
case -1:
|
|
43
|
+
# Input marker
|
|
44
|
+
val = f'inp[{ops[op.id0].id0}]'
|
|
45
|
+
|
|
46
|
+
case 0 | 1:
|
|
47
|
+
# Common a+/-b<<shift op
|
|
48
|
+
ref1 = f'bit_shift<{op.data}>(v{op.id1})' if op.data != 0 else f'v{op.id1}'
|
|
49
|
+
val = f'{ref0} {"-" if op.opcode == 1 else "+"} {ref1}'
|
|
50
|
+
|
|
51
|
+
case 2 | -2:
|
|
52
|
+
if op.opcode == 2: # relu(inp)
|
|
53
|
+
if ops[op.id0].qint.min < 0:
|
|
54
|
+
val = f'{ref0} > 0 ? {_type}({ref0}) : {_type}(0)'
|
|
55
|
+
else:
|
|
56
|
+
val = ref0
|
|
57
|
+
else: # relu(-inp)
|
|
58
|
+
if ops[op.id0].qint.max > 0:
|
|
59
|
+
val = f'{ref0} > 0 ? {_type}(0) : {_type}(-{ref0})'
|
|
60
|
+
else:
|
|
61
|
+
val = f'-{ref0}'
|
|
62
|
+
|
|
63
|
+
case 3 | -3:
|
|
64
|
+
# Explicit quantization op, done implicitly via assignment
|
|
65
|
+
val = ref0 if op.opcode == 3 else f'-{ref0}'
|
|
66
|
+
|
|
67
|
+
case 4:
|
|
68
|
+
# Constant addition
|
|
69
|
+
_number = op.data * op.qint.step
|
|
70
|
+
sign, mag = ('-' if _number < 0 else '+'), abs(_number)
|
|
71
|
+
f = _const_f(mag)
|
|
72
|
+
const_type_str = typestr_fn(*_minimal_kif(QInterval(mag, mag, 2.0**-f)))
|
|
73
|
+
val = f'{ref0} {sign} {const_type_str}({mag})'
|
|
74
|
+
|
|
75
|
+
case 5:
|
|
76
|
+
_number = op.data * op.qint.step
|
|
77
|
+
val = f'{_number}'
|
|
78
|
+
|
|
79
|
+
case _:
|
|
80
|
+
raise ValueError(f'Unsupported opcode: {op.opcode}')
|
|
81
|
+
|
|
82
|
+
line = f'{_type} v{i} = {val};'
|
|
83
|
+
|
|
84
|
+
if print_latency:
|
|
85
|
+
line += f' // {op.latency}'
|
|
86
|
+
lines.append(line)
|
|
87
|
+
return lines
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def output_gen(sol: Solution, typestr_fn: Callable[[bool | int, int, int], str]):
    """Emit the C++ statements that write every solution output into ``out[]``.

    Outputs whose source index is negative are written as literal zero. Every other output
    is built from ``v<src>`` and cast to the output's own minimal type. Before the cast, the
    value is shifted by ``bit_shift<sol.out_shifts[i]>`` when that shift is nonzero, and
    negated when ``sol.out_negs[i]`` is set.
    """
    lines = []
    for pos, src in enumerate(sol.out_idxs):
        if src < 0:
            lines.append(f'out[{pos}] = 0;')
            continue
        out_type = typestr_fn(*_minimal_kif(sol.out_qint[pos]))
        shift = sol.out_shifts[pos]
        neg = '-' if sol.out_negs[pos] else ''
        operand = f'v{src}' if shift == 0 else f'bit_shift<{shift}>(v{src})'
        lines.append(f'out[{pos}] = {out_type}({neg}{operand});')
    return lines
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def cpp_logic_and_bridge_gen(
    sol: Solution,
    fn_name: str,
    flavor: str,
    pragmas: list[str] | None = None,
    n_indent: int = 4,
    n_base_indent: int = 0,
    print_latency: bool = False,
):
    """Generate a templated C++ kernel for ``sol`` plus an ``extern "C"`` bridge.

    Parameters
    ----------
    sol : Solution
        The solution to render; ``sol.shape`` provides (n_in, n_out).
    fn_name : str
        Name of the generated C++ function template.
    flavor : str
        Type-name flavor passed to ``get_typestr_fn`` ('vitis' or 'hlslib').
    pragmas : list[str] | None
        Lines emitted verbatim at the top of the function body.
    n_indent : int
        Spaces per indentation level.
    n_base_indent : int
        Indentation levels applied to the whole function.
    print_latency : bool
        Append each op's latency as a trailing comment.

    Returns
    -------
    tuple[str, str]
        ``(code, bridge)``: the kernel template source and the bridge source.
        The bridge includes "bridge.h" and "fn.h". It instantiates the kernel with one input
        type and one output type, each built from the element-wise max of the per-port minimal
        (k, i, f).

    Raises
    ------
    ValueError
        From ``get_typestr_fn`` for an unsupported flavor.

    NOTE(review): the bridge always calls ``vitis_bridge``, whatever ``flavor`` is. It also
    assumes at least one input and one output port; with none, ``typestr_fn`` receives no
    arguments and raises ``TypeError``.
    """
    typestr_fn = get_typestr_fn(flavor)
    # One element type per direction, wide enough for every port (element-wise max of k, i, f).
    in_kif = map(max, zip(*map(_minimal_kif, sol.inp_qint)))
    inp_type = typestr_fn(*in_kif)
    out_kif = map(max, zip(*map(_minimal_kif, sol.out_qint)))
    out_type = typestr_fn(*out_kif)

    n_in, n_out = sol.shape
    template_def = 'template <typename inp_t, typename out_t>'
    fn_signature = f'void {fn_name}(inp_t inp[{n_in}], out_t out[{n_out}])'
    pragmas = pragmas or []

    ssa_lines = ssa_gen(sol.ops, print_latency=print_latency, typestr_fn=typestr_fn)
    output_lines = output_gen(sol, typestr_fn=typestr_fn)

    indent = ' ' * n_indent
    base_indent = indent * n_base_indent
    # Each body section starts on a new, one-level-deeper line and is joined the same way.
    body_indent = '\n' + base_indent + indent
    code = f"""{base_indent}{template_def}
{base_indent}{fn_signature} {{ // {inp_type} -> {out_type}
{body_indent}{body_indent.join(pragmas)}
{body_indent}{body_indent.join(ssa_lines)}
{body_indent}{body_indent.join(output_lines)}
{base_indent}}}
"""
    bridge = f"""#include "bridge.h"
#include "fn.h"

extern "C" {{
void bridge(double *inp, double *out, int size) {{
auto fn = {fn_name}<{inp_type}, {out_type}>;
vitis_bridge<{inp_type}, {out_type}, {n_in}, {n_out}>(fn, inp, out, size);
}}
}}"""
    return code, bridge
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "ap_fixed.h"
|
|
3
|
+
|
|
4
|
+
// Multiply a signed fixed-point value by 2^s at zero hardware cost.
// The raw bits are copied unchanged into a type whose integer width is larger by s
// (s may be negative). The result uses the default quantization/overflow modes.
template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N> ap_fixed<b, i + s> bit_shift(ap_fixed<b, i, Q, O, N> x) {
#pragma HLS INLINE
    ap_fixed<b, i + s> shifted;
    shifted.range() = x.range();
    return shifted;
}
|
|
10
|
+
|
|
11
|
+
// Multiply an unsigned fixed-point value by 2^s at zero hardware cost.
// The raw bits are copied unchanged into a type whose integer width is larger by s
// (s may be negative). The result uses the default quantization/overflow modes.
template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N> ap_ufixed<b, i + s> bit_shift(ap_ufixed<b, i, Q, O, N> x) {
#pragma HLS INLINE
    ap_ufixed<b, i + s> shifted;
    shifted.range() = x.range();
    return shifted;
}
|
|
17
|
+
|
|
18
|
+
// Multiply a signed integer by 2^s at zero hardware cost.
// ap_int<b> has b integer bits, so its raw bits are reinterpreted as ap_fixed<b, b + s>,
// whose value is x * 2^s. This matches the ap_fixed overload, which maps
// ap_fixed<b, i> to ap_fixed<b, i + s>.
template <int s, int b> ap_fixed<b, b + s> bit_shift(ap_int<b> x) {
#pragma HLS INLINE
    ap_fixed<b, b + s> r;
    r.range() = x.range();
    return r;
}
|
|
24
|
+
|
|
25
|
+
// Multiply an unsigned integer by 2^s at zero hardware cost.
// ap_uint<b> has b integer bits, so its raw bits are reinterpreted as ap_ufixed<b, b + s>,
// whose value is x * 2^s. This matches the ap_ufixed overload, which maps
// ap_ufixed<b, i> to ap_ufixed<b, i + s>.
template <int s, int b> ap_ufixed<b, b + s> bit_shift(ap_uint<b> x) {
#pragma HLS INLINE
    ap_ufixed<b, b + s> r;
    r.range() = x.range();
    return r;
}
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "ap_fixed.h"
|
|
3
|
+
|
|
4
|
+
// Host-side harness for running a kernel over a batch of samples.
// For each of `size` samples it:
//   1. converts SIZE_IN doubles from `inp` into inp_t values;
//   2. calls `f(inputs, outputs)`;
//   3. converts the SIZE_OUT out_t results back into doubles in `out`.
// Both buffers are laid out sample after sample.
template <typename inp_t, typename out_t, size_t SIZE_IN, size_t SIZE_OUT, typename F>
void vitis_bridge(F f, double *inp, double *out, int size) {
    inp_t in_fixed_buf[SIZE_IN];
    out_t out_fixed_buf[SIZE_OUT];
    for (int sample = 0; sample < size; ++sample) {
        const double *src = inp + sample * SIZE_IN;
        double *dst = out + sample * SIZE_OUT;
        for (size_t k = 0; k < SIZE_IN; ++k)
            in_fixed_buf[k] = inp_t(src[k]);
        f(in_fixed_buf, out_fixed_buf);
        for (size_t k = 0; k < SIZE_OUT; ++k)
            dst[k] = double(out_fixed_buf[k]);
    }
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from .comb import comb_logic_gen
|
|
2
|
+
from .io_wrapper import comb_binder_gen, generate_io_wrapper, pipeline_binder_gen
|
|
3
|
+
from .pipeline import pipeline_logic_gen
|
|
4
|
+
from .verilog_model import VerilogModel
|
|
5
|
+
|
|
6
|
+
# Public Verilog code generators and the VerilogModel wrapper re-exported by this package.
__all__ = [
    'comb_logic_gen',
    'generate_io_wrapper',
    'comb_binder_gen',
    'pipeline_logic_gen',
    'pipeline_binder_gen',
    'VerilogModel',
]
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from math import ceil, log2
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from da4ml.cmvm.types import Op, QInterval, Solution, _minimal_kif
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def ssa_gen(ops: list[Op], print_latency: bool = False):
|
|
9
|
+
kifs = list(map(_minimal_kif, (op.qint for op in ops)))
|
|
10
|
+
widths = list(map(sum, kifs))
|
|
11
|
+
inp_kifs = [_minimal_kif(op.qint) for op in ops if op.opcode == -1]
|
|
12
|
+
inp_widths = list(map(sum, inp_kifs))
|
|
13
|
+
_inp_widths = np.cumsum([0] + inp_widths)
|
|
14
|
+
inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
|
|
15
|
+
|
|
16
|
+
lines = []
|
|
17
|
+
|
|
18
|
+
for i, op in enumerate(ops):
|
|
19
|
+
bw = widths[i]
|
|
20
|
+
v = f'v{i}[{bw-1}:0]'
|
|
21
|
+
_def = f'wire [{bw-1}:0] v{i};'
|
|
22
|
+
|
|
23
|
+
match op.opcode:
|
|
24
|
+
case -1: # Input marker
|
|
25
|
+
i0, i1 = inp_idxs[op.id0]
|
|
26
|
+
line = f'{_def} assign {v} = inp[{i0}:{i1}];'
|
|
27
|
+
case 2 | -2: # ReLU
|
|
28
|
+
lsb_bias = kifs[op.id0][2] - kifs[i][2]
|
|
29
|
+
i0, i1 = bw + lsb_bias - 1, lsb_bias
|
|
30
|
+
|
|
31
|
+
v0_name = f'v{op.id0}'
|
|
32
|
+
bw0 = widths[op.id0]
|
|
33
|
+
|
|
34
|
+
if op.opcode == -2:
|
|
35
|
+
_min, _max, step = ops[op.id0].qint
|
|
36
|
+
bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
|
|
37
|
+
lines.append(
|
|
38
|
+
f'wire [{bw_neg-1}:0] v{op.id0}_neg; assign v{op.id0}_neg[{bw_neg-1}:0] = -{v0_name}[{bw0-1}:0];'
|
|
39
|
+
)
|
|
40
|
+
v0_name = f'v{op.id0}_neg'
|
|
41
|
+
if ops[op.id0].qint.min < 0:
|
|
42
|
+
line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}] & {{{bw}{{~{v0_name}[{bw0-1}]}}}};'
|
|
43
|
+
else:
|
|
44
|
+
line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
|
|
45
|
+
case 3 | -3: # Explicit quantization
|
|
46
|
+
lsb_bias = kifs[op.id0][2] - kifs[i][2]
|
|
47
|
+
i0, i1 = bw + lsb_bias - 1, lsb_bias
|
|
48
|
+
v0_name = f'v{op.id0}'
|
|
49
|
+
bw0 = widths[op.id0]
|
|
50
|
+
|
|
51
|
+
if op.opcode == -3:
|
|
52
|
+
_min, _max, step = ops[op.id0].qint
|
|
53
|
+
bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
|
|
54
|
+
lines.append(
|
|
55
|
+
f'wire [{bw_neg-1}:0] v{op.id0}_neg; assign v{op.id0}_neg[{bw_neg-1}:0] = -{v0_name}[{bw0-1}:0];'
|
|
56
|
+
)
|
|
57
|
+
v0_name = f'v{op.id0}_neg'
|
|
58
|
+
|
|
59
|
+
line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
|
|
60
|
+
case 4: # constant addition
|
|
61
|
+
num = op.data
|
|
62
|
+
sign, mag = int(num < 0), abs(num)
|
|
63
|
+
line = f"{_def} assign {v} = '{bin(mag)[1:]};"
|
|
64
|
+
bw1 = ceil(log2(mag + 1))
|
|
65
|
+
bw0 = widths[op.id0]
|
|
66
|
+
s0 = int(kifs[op.id0][0])
|
|
67
|
+
v0 = f'v{op.id0}[{bw0-1}:0]'
|
|
68
|
+
v1 = f"'{bin(mag)[1:]}"
|
|
69
|
+
shift = int(log2(op.qint.step / ops[op.id0].qint.step))
|
|
70
|
+
line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, 0, {bw}, {shift}, {sign}) op_{i} ({v0}, {v1}, {v});'
|
|
71
|
+
case 5: # constant
|
|
72
|
+
num = op.data
|
|
73
|
+
if num < 0:
|
|
74
|
+
num = 2**bw + num
|
|
75
|
+
line = f"{_def} assign {v} = '{bin(num)[1:]};"
|
|
76
|
+
|
|
77
|
+
case 0 | 1: # Common a+/-b<<shift oprs
|
|
78
|
+
p0, p1 = kifs[op.id0], kifs[op.id1] # precision -> keep_neg, integers (no sign), fractional
|
|
79
|
+
|
|
80
|
+
bw0, bw1 = widths[op.id0], widths[op.id1] # width
|
|
81
|
+
s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
|
|
82
|
+
shift = op.data + f0 - f1
|
|
83
|
+
v0, v1 = f'v{op.id0}[{bw0-1}:0]', f'v{op.id1}[{bw1-1}:0]'
|
|
84
|
+
|
|
85
|
+
line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {op.opcode}) op_{i} ({v0}, {v1}, {v});'
|
|
86
|
+
case _:
|
|
87
|
+
raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
|
|
88
|
+
|
|
89
|
+
if print_latency:
|
|
90
|
+
line += f' // {op.latency}'
|
|
91
|
+
lines.append(line)
|
|
92
|
+
return lines
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def output_gen(sol: Solution):
    """Emit the Verilog assignments that pack every solution output into the ``out`` bus.

    Output ``i`` occupies ``out[msb:lsb]``, with a width equal to the bit count of its minimal
    precision. Fields are packed in order starting at bit 0. Outputs with a negative source
    index are driven to zero. Outputs with ``sol.out_negs[i]`` set go through an
    ``out_neg<i>`` wire.

    ``sol.out_shifts`` is not used here. Presumably a power-of-two shift only moves the
    implicit binary point, which ``sol.out_qint`` already reflects; confirm against the
    producer of ``Solution``.
    """
    lines = []
    widths = list(map(sum, map(_minimal_kif, sol.out_qint)))
    _widths = np.cumsum([0] + widths)
    out_idxs = np.stack([_widths[1:] - 1, _widths[:-1]], axis=1)
    for i, idx in enumerate(sol.out_idxs):
        i0, i1 = out_idxs[i]
        bw = widths[i]
        if idx < 0:
            # Constant-zero output: drive its bits explicitly instead of leaving them floating.
            if bw > 0:
                lines.append(f"assign out[{i0}:{i1}] = {bw}'b0;")
            continue
        kif0 = _minimal_kif(sol.ops[idx].qint)
        bw0 = sum(kif0)
        if sol.out_negs[i]:
            # out_neg may be wider than v{idx}, and Verilog zero-extends unsigned operands.
            # Signed sources are therefore sign-extended via $signed before negation.
            src = f'$signed(v{idx}[{bw0-1}:0])' if kif0[0] else f'v{idx}[{bw0-1}:0]'
            lines.append(f'wire [{bw-1}:0] out_neg{i}; assign out_neg{i} = -{src};')
            lines.append(f'assign out[{i0}:{i1}] = out_neg{i}[{bw-1}:0];')
        else:
            lines.append(f'assign out[{i0}:{i1}] = v{idx}[{bw-1}:0];')
    return lines
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, timescale: str | None = None):
    """Generate a purely combinational Verilog module implementing ``sol``.

    The module has one packed input bus ``inp`` and one packed output bus ``out``. Their
    widths are the summed bit counts of the minimal precisions of all inputs and all outputs.
    The body is produced by ``ssa_gen`` (one wire per op) and ``output_gen`` (output packing).

    Parameters
    ----------
    sol : Solution
        The solution to render.
    fn_name : str
        Name of the generated Verilog module.
    print_latency : bool
        Annotate each op with its latency as a trailing comment.
    timescale : str | None
        If given, prepended verbatim (followed by a blank line) to the module source.
        Presumably a ``timescale`` directive; confirm with callers.

    Returns
    -------
    str
        The Verilog module source.
    """
    # Total widths of the packed buses: one field per port, sized by its minimal precision.
    inp_bits = sum(map(sum, map(_minimal_kif, sol.inp_qint)))
    out_bits = sum(map(sum, map(_minimal_kif, sol.out_qint)))

    fn_signature = [
        f'module {fn_name} (',
        f' input [{inp_bits-1}:0] inp,',
        f' output [{out_bits-1}:0] out',
        ');',
    ]

    ssa_lines = ssa_gen(sol.ops, print_latency=print_latency)
    output_lines = output_gen(sol)

    indent = ' '
    base_indent = '\n'
    # Separator used to join body statements: newline followed by one indent level.
    body_indent = base_indent + indent
    code = f"""{base_indent[1:]}{base_indent.join(fn_signature)}

// verilator lint_off UNUSEDSIGNAL
// Explicit quantization operation will drop bits if exists

{body_indent.join(ssa_lines)}

// verilator lint_on UNUSEDSIGNAL

{body_indent.join(output_lines)}

endmodule
"""
    if timescale is not None:
        code = f'{timescale}\n\n{code}'
    return code
|