da4ml 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml might be problematic. Click here for more details.
- da4ml/_version.py +2 -2
- da4ml/codegen/__init__.py +4 -7
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/{cpp/cpp_codegen.py → hls/hls_codegen.py} +19 -12
- da4ml/codegen/{cpp → hls}/hls_model.py +7 -7
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/{verilog/source → rtl/common_source}/binder_util.hh +4 -4
- da4ml/codegen/{verilog/source → rtl/common_source}/build_binder.mk +7 -1
- da4ml/codegen/{verilog/source → rtl/common_source}/build_prj.tcl +28 -7
- da4ml/codegen/{verilog/verilog_model.py → rtl/rtl_model.py} +87 -16
- da4ml/codegen/{verilog → rtl/verilog}/__init__.py +0 -2
- da4ml/codegen/{verilog → rtl/verilog}/comb.py +32 -34
- da4ml/codegen/{verilog → rtl/verilog}/io_wrapper.py +8 -8
- da4ml/codegen/{verilog → rtl/verilog}/pipeline.py +10 -10
- da4ml/codegen/{verilog → rtl/verilog}/source/negative.v +2 -1
- da4ml/codegen/rtl/vhdl/__init__.py +10 -0
- da4ml/codegen/rtl/vhdl/comb.py +192 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +157 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/codegen/rtl/vhdl/source/template.xdc +32 -0
- {da4ml-0.3.3.dist-info → da4ml-0.4.0.dist-info}/METADATA +2 -2
- da4ml-0.4.0.dist-info/RECORD +76 -0
- da4ml/codegen/cpp/__init__.py +0 -4
- da4ml-0.3.3.dist-info/RECORD +0 -66
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_binary.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_common.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_decl.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed_base.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed_ref.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed_special.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_int.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_int_base.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_int_ref.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_int_special.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/ap_shift_reg.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/etc/ap_private.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/hls_math.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/hls_stream.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/ap_types/utils/x_hls_utils.h +0 -0
- /da4ml/codegen/{cpp → hls}/source/binder_util.hh +0 -0
- /da4ml/codegen/{cpp → hls}/source/build_binder.mk +0 -0
- /da4ml/codegen/{cpp → hls}/source/vitis_bitshift.hh +0 -0
- /da4ml/codegen/{verilog/source → rtl/common_source}/ioutil.hh +0 -0
- /da4ml/codegen/{verilog/source → rtl/common_source}/template.xdc +0 -0
- /da4ml/codegen/{verilog → rtl/verilog}/source/multiplier.v +0 -0
- /da4ml/codegen/{verilog → rtl/verilog}/source/mux.v +0 -0
- /da4ml/codegen/{verilog → rtl/verilog}/source/shift_adder.v +0 -0
- {da4ml-0.3.3.dist-info → da4ml-0.4.0.dist-info}/WHEEL +0 -0
- {da4ml-0.3.3.dist-info → da4ml-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.3.3.dist-info → da4ml-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from itertools import accumulate
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from ....cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def hetero_io_map(qints: list[QInterval], merge: bool = False):
|
|
@@ -66,18 +66,18 @@ def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipe
|
|
|
66
66
|
w_reg_in, w_het_in = shape_in
|
|
67
67
|
w_reg_out, w_het_out = shape_out
|
|
68
68
|
|
|
69
|
-
inp_assignment = [f'assign packed_inp[{ih}:{jh}] =
|
|
69
|
+
inp_assignment = [f'assign packed_inp[{ih}:{jh}] = model_inp[{ir}:{jr}];' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]
|
|
70
70
|
_out_assignment: list[tuple[int, str]] = []
|
|
71
71
|
|
|
72
72
|
for i, ((ih, jh), (ir, jr)) in enumerate(zip(het_out, reg_out)):
|
|
73
73
|
if ih == jh - 1:
|
|
74
74
|
continue
|
|
75
|
-
_out_assignment.append((ih, f'assign
|
|
75
|
+
_out_assignment.append((ih, f'assign model_out[{ir}:{jr}] = packed_out[{ih}:{jh}];'))
|
|
76
76
|
|
|
77
77
|
for i, (i, j, copy_from) in enumerate(pad_out):
|
|
78
78
|
n_bit = i - j + 1
|
|
79
79
|
pad = f"{n_bit}'b0" if copy_from == -1 else f'{{{n_bit}{{packed_out[{copy_from}]}}}}'
|
|
80
|
-
_out_assignment.append((i, f'assign
|
|
80
|
+
_out_assignment.append((i, f'assign model_out[{i}:{j}] = {pad};'))
|
|
81
81
|
_out_assignment.sort(key=lambda x: x[0])
|
|
82
82
|
out_assignment = [v for _, v in _out_assignment]
|
|
83
83
|
|
|
@@ -93,9 +93,9 @@ def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipe
|
|
|
93
93
|
|
|
94
94
|
module {module_name}_wrapper ({clk_and_rst_inp}
|
|
95
95
|
// verilator lint_off UNUSEDSIGNAL
|
|
96
|
-
input [{w_reg_in - 1}:0]
|
|
96
|
+
input [{w_reg_in - 1}:0] model_inp,
|
|
97
97
|
// verilator lint_on UNUSEDSIGNAL
|
|
98
|
-
output [{w_reg_out - 1}:0]
|
|
98
|
+
output [{w_reg_out - 1}:0] model_out
|
|
99
99
|
);
|
|
100
100
|
wire [{w_het_in - 1}:0] packed_inp;
|
|
101
101
|
wire [{w_het_out - 1}:0] packed_out;
|
|
@@ -103,8 +103,8 @@ module {module_name}_wrapper ({clk_and_rst_inp}
|
|
|
103
103
|
{inp_assignment_str}
|
|
104
104
|
|
|
105
105
|
{module_name} op ({clk_and_rst_bind}
|
|
106
|
-
.
|
|
107
|
-
.
|
|
106
|
+
.model_inp(packed_inp),
|
|
107
|
+
.model_out(packed_out)
|
|
108
108
|
);
|
|
109
109
|
|
|
110
110
|
{out_assignment_str}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from
|
|
1
|
+
from ....cmvm.types import CascadedSolution, _minimal_kif
|
|
2
2
|
from .comb import comb_logic_gen
|
|
3
3
|
|
|
4
4
|
|
|
@@ -13,18 +13,18 @@ def pipeline_logic_gen(
|
|
|
13
13
|
inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
|
|
14
14
|
out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
|
|
15
15
|
|
|
16
|
-
registers = [f'reg [{width
|
|
16
|
+
registers = [f'reg [{width-1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
|
|
17
17
|
for i in range(0, register_layers - 1):
|
|
18
|
-
registers += [f'reg [{width
|
|
19
|
-
wires = [f'wire [{width
|
|
18
|
+
registers += [f'reg [{width-1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
|
|
19
|
+
wires = [f'wire [{width-1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
|
|
20
20
|
|
|
21
|
-
comb_logic = [f'{name}_stage{i} stage{i} (.
|
|
21
|
+
comb_logic = [f'{name}_stage{i} stage{i} (.model_inp(stage{i}_inp), .model_out(stage{i}_out));' for i in range(N)]
|
|
22
22
|
|
|
23
23
|
if register_layers == 1:
|
|
24
|
-
serial_logic = ['stage0_inp <=
|
|
24
|
+
serial_logic = ['stage0_inp <= model_inp;']
|
|
25
25
|
serial_logic += [f'stage{i}_inp <= stage{i-1}_out;' for i in range(1, N)]
|
|
26
26
|
else:
|
|
27
|
-
serial_logic = ['stage0_inp_copy0 <=
|
|
27
|
+
serial_logic = ['stage0_inp_copy0 <= model_inp;']
|
|
28
28
|
for j in range(1, register_layers - 1):
|
|
29
29
|
serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j-1};')
|
|
30
30
|
serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
|
|
@@ -34,15 +34,15 @@ def pipeline_logic_gen(
|
|
|
34
34
|
serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};')
|
|
35
35
|
serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
|
|
36
36
|
|
|
37
|
-
serial_logic += [f'
|
|
37
|
+
serial_logic += [f'model_out <= stage{N-1}_out;']
|
|
38
38
|
|
|
39
39
|
sep0 = '\n '
|
|
40
40
|
sep1 = '\n '
|
|
41
41
|
|
|
42
42
|
module = f"""module {name} (
|
|
43
43
|
input clk,
|
|
44
|
-
input [{inp_bits[0]-1}:0]
|
|
45
|
-
output reg [{out_bits[-1]-1}:0]
|
|
44
|
+
input [{inp_bits[0]-1}:0] model_inp,
|
|
45
|
+
output reg [{out_bits[-1]-1}:0] model_out
|
|
46
46
|
);
|
|
47
47
|
|
|
48
48
|
{sep0.join(registers)}
|
|
@@ -11,6 +11,7 @@ module negative #(
|
|
|
11
11
|
// verilator lint_off UNUSEDSIGNAL
|
|
12
12
|
output [BW_OUT-1:0] out
|
|
13
13
|
);
|
|
14
|
+
/* verilator lint_off WIDTHTRUNC */
|
|
14
15
|
generate
|
|
15
16
|
if (BW_IN < BW_OUT) begin : in_is_smaller
|
|
16
17
|
wire [BW_OUT-1:0] in_ext;
|
|
@@ -24,5 +25,5 @@ module negative #(
|
|
|
24
25
|
assign out = -in[BW_OUT-1:0];
|
|
25
26
|
end
|
|
26
27
|
endgenerate
|
|
27
|
-
|
|
28
|
+
/* verilator lint_on WIDTHTRUNC */
|
|
28
29
|
endmodule
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
from math import ceil, log2
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ....cmvm.types import Op, QInterval, Solution, _minimal_kif
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def make_neg(
    signals: list[str],
    assigns: list[str],
    op: Op,
    ops: list[Op],
    bw0: int,
    v0_name: str,
):
    """Instantiate a ``negative`` entity that negates operand ``ops[op.id0]``.

    A declaration for ``v<id0>_neg`` is appended to *signals* and the entity
    instantiation to *assigns*; both lists are mutated in place.

    Args:
        signals: accumulated VHDL signal declarations.
        assigns: accumulated VHDL concurrent statements.
        op: the op whose ``id0`` operand is being negated.
        ops: all ops of the solution (used to read the operand's interval).
        bw0: current bit width of the operand signal.
        v0_name: current name of the operand signal.

    Returns:
        ``(width, name)`` of the negated signal. The width is the larger of the
        negated interval's minimal width and *bw0*.
    """
    src = op.id0
    lo, hi, step = ops[src].qint
    neg_width = max(sum(_minimal_kif(QInterval(-hi, -lo, step))), bw0)
    in_signed = 1 if lo < 0 else 0
    neg_name = f'v{src}_neg'

    signals.append(f'signal {neg_name} : std_logic_vector({neg_width-1} downto {0});')
    instance = (
        f'op_neg_{src} : entity work.negative generic map '
        f'(BW_IN => {bw0}, BW_OUT => {neg_width}, IN_SIGNED => {in_signed}) '
        f'port map (neg_in => {v0_name}, neg_out => {neg_name});'
    )
    assigns.append(instance)
    return neg_width, neg_name
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
29
|
+
ops = sol.ops
|
|
30
|
+
kifs = list(map(_minimal_kif, (op.qint for op in ops)))
|
|
31
|
+
widths = list(map(sum, kifs))
|
|
32
|
+
inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
|
|
33
|
+
inp_widths = list(map(sum, inp_kifs))
|
|
34
|
+
_inp_widths = np.cumsum([0] + inp_widths)
|
|
35
|
+
inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
|
|
36
|
+
|
|
37
|
+
signals = []
|
|
38
|
+
assigns = []
|
|
39
|
+
ref_count = sol.ref_count
|
|
40
|
+
|
|
41
|
+
for i, op in enumerate(ops):
|
|
42
|
+
if ref_count[i] == 0:
|
|
43
|
+
continue
|
|
44
|
+
|
|
45
|
+
bw = widths[i]
|
|
46
|
+
if bw == 0:
|
|
47
|
+
continue
|
|
48
|
+
|
|
49
|
+
match op.opcode:
|
|
50
|
+
case -1: # Input marker
|
|
51
|
+
i0, i1 = inp_idxs[op.id0]
|
|
52
|
+
signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
|
|
53
|
+
line = f'v{i} <= model_inp({i0} downto {i1});'
|
|
54
|
+
|
|
55
|
+
case 0 | 1: # Common a+/-b<<shift oprs
|
|
56
|
+
p0, p1 = kifs[op.id0], kifs[op.id1]
|
|
57
|
+
bw0, bw1 = widths[op.id0], widths[op.id1]
|
|
58
|
+
s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
|
|
59
|
+
shift = op.data + f0 - f1
|
|
60
|
+
signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
|
|
61
|
+
line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{op.opcode}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
|
|
62
|
+
|
|
63
|
+
case 2 | -2: # ReLU
|
|
64
|
+
lsb_bias = kifs[op.id0][2] - kifs[i][2]
|
|
65
|
+
i0, i1 = bw + lsb_bias - 1, lsb_bias
|
|
66
|
+
v0_name = f'v{op.id0}'
|
|
67
|
+
bw0 = widths[op.id0]
|
|
68
|
+
if op.opcode == -2 and op.id0 not in neg_defined:
|
|
69
|
+
neg_defined.add(op.id0)
|
|
70
|
+
bw0, v0_name = make_neg(signals, assigns, op, ops, bw0, v0_name)
|
|
71
|
+
signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
|
|
72
|
+
if ops[op.id0].qint.min < 0:
|
|
73
|
+
if bw > 1:
|
|
74
|
+
line = f'v{i} <= {v0_name}({i0} downto {i1}) and ({bw - 1} downto 0 => not {v0_name}({bw0-1}));'
|
|
75
|
+
else:
|
|
76
|
+
line = f'v{i}(0) <= {v0_name}(0) and (not {v0_name}({bw0-1}));'
|
|
77
|
+
else:
|
|
78
|
+
line = f'v{i} <= {v0_name}({i0} downto {i1});'
|
|
79
|
+
|
|
80
|
+
case 3 | -3: # Explicit quantization
|
|
81
|
+
lsb_bias = kifs[op.id0][2] - kifs[i][2]
|
|
82
|
+
i0, i1 = bw + lsb_bias - 1, lsb_bias
|
|
83
|
+
v0_name = f'v{op.id0}'
|
|
84
|
+
bw0 = widths[op.id0]
|
|
85
|
+
if op.opcode == -3 and op.id0 not in neg_defined:
|
|
86
|
+
neg_defined.add(op.id0)
|
|
87
|
+
bw0, v0_name = make_neg(signals, assigns, op, ops, bw0, v0_name)
|
|
88
|
+
signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
|
|
89
|
+
line = f'v{i} <= {v0_name}({i0} downto {i1});'
|
|
90
|
+
|
|
91
|
+
case 4: # constant addition
|
|
92
|
+
num = op.data
|
|
93
|
+
sign, mag = int(num < 0), abs(num)
|
|
94
|
+
bw1 = ceil(log2(mag + 1)) if mag > 0 else 1
|
|
95
|
+
bw0 = widths[op.id0]
|
|
96
|
+
s0 = int(kifs[op.id0][0])
|
|
97
|
+
shift = kifs[op.id0][2] - kifs[i][2]
|
|
98
|
+
signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
|
|
99
|
+
bin_val = format(mag, f'0{bw1}b')
|
|
100
|
+
line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>0,BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{sign}) port map(in0=>v{op.id0},in1=>"{bin_val}",result=>v{i});'
|
|
101
|
+
case 5: # constant
|
|
102
|
+
num = op.data
|
|
103
|
+
if num < 0:
|
|
104
|
+
num = 2**bw + num
|
|
105
|
+
signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
|
|
106
|
+
bin_val = format(num, f'0{bw}b')
|
|
107
|
+
line = f'v{i} <= "{bin_val}";'
|
|
108
|
+
|
|
109
|
+
case 6 | -6: # MSB Muxing
|
|
110
|
+
k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
|
|
111
|
+
p0, p1 = kifs[a], kifs[b]
|
|
112
|
+
inv = '1' if op.opcode == -6 else '0'
|
|
113
|
+
bwk, bw0, bw1 = widths[k], widths[a], widths[b]
|
|
114
|
+
s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
|
|
115
|
+
_shift = (op.data >> 32) & 0xFFFFFFFF
|
|
116
|
+
_shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
|
|
117
|
+
shift = f0 - f1 + _shift
|
|
118
|
+
signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
|
|
119
|
+
line = f'op_{i}:entity work.mux generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},INVERT1=>{inv}) port map(key=>v{k}({bwk-1}),in0=>v{a},in1=>v{b},result=>v{i});'
|
|
120
|
+
case 7: # Multiplication
|
|
121
|
+
bw0, bw1 = widths[op.id0], widths[op.id1]
|
|
122
|
+
s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
|
|
123
|
+
signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
|
|
124
|
+
line = f'op_{i}:entity work.multiplier generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
|
|
125
|
+
|
|
126
|
+
case _:
|
|
127
|
+
raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
|
|
128
|
+
|
|
129
|
+
if print_latency:
|
|
130
|
+
line += f' -- {op.latency}'
|
|
131
|
+
assigns.append(line)
|
|
132
|
+
return signals, assigns
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def output_gen(sol: Solution, neg_defined: set[int]):
    """Drive each slice of ``model_out`` from the op chosen by ``sol.out_idxs``.

    Outputs with a negative index or a zero width are skipped. A negated output
    reads ``v<idx>_neg``. If that signal does not exist yet (``idx`` not in
    *neg_defined*), it is declared and a ``negative`` entity is instantiated
    here, and ``idx`` is added to *neg_defined*.

    Returns:
        ``(signals, assigns)``: extra declarations and concurrent statements.
    """
    signals = []
    assigns = []
    out_widths = [sum(_minimal_kif(q)) for q in sol.out_qint]
    edges = np.cumsum([0] + out_widths)
    # (high, low) bit range of every output inside model_out
    bounds = np.stack([edges[1:] - 1, edges[:-1]], axis=1)

    for pos, idx in enumerate(sol.out_idxs):
        if idx < 0:
            continue
        hi, lo = bounds[pos]
        if hi == lo - 1:
            continue  # zero-width output
        bw = out_widths[pos]
        src = f'v{idx}'
        if sol.out_negs[pos]:
            if idx not in neg_defined:
                neg_defined.add(idx)
                src_kif = _minimal_kif(sol.ops[idx].qint)
                bw0 = sum(src_kif)
                was_signed = int(src_kif[0])
                signals.append(f'signal v{idx}_neg:std_logic_vector({bw-1} downto {0});')
                assigns.append(
                    f'op_neg_{idx}:entity work.negative generic map(BW_IN=>{bw0},BW_OUT=>{bw},IN_SIGNED=>{was_signed}) port map(neg_in=>v{idx},neg_out=>v{idx}_neg);'
                )
            src = f'v{idx}_neg'
        assigns.append(f'model_out({hi} downto {lo}) <= {src}({bw-1} downto {0});')
    return signals, assigns
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, timescale: str | None = None):
    """Render *sol* as a combinational VHDL entity named *fn_name*.

    The entity has one ``model_inp`` port holding all inputs concatenated and one
    ``model_out`` port holding all outputs. Both are sized from the minimal
    widths of ``sol.inp_qint`` / ``sol.out_qint``. ``timescale`` is accepted
    but not used by this generator.

    Returns:
        The VHDL source text.
    """

    def _bits(qints) -> int:
        # total bit count of the intervals at their minimal (k, i, f) encodings
        return sum(sum(_minimal_kif(q)) for q in qints)

    inp_bits = _bits(sol.inp_qint)
    out_bits = _bits(sol.out_qint)

    neg_defined: set[int] = set()
    ssa_signals, ssa_assigns = ssa_gen(sol, neg_defined=neg_defined, print_latency=print_latency)
    out_signals, out_assigns = output_gen(sol, neg_defined)

    joiner = '\n '
    decls = joiner.join([*ssa_signals, *out_signals])
    stmts = joiner.join([*ssa_assigns, *out_assigns])

    return f"""library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

entity {fn_name} is port(
model_inp:in std_logic_vector({inp_bits-1} downto 0);
model_out:out std_logic_vector({out_bits-1} downto 0)
);
end entity {fn_name};

architecture rtl of {fn_name} is
{decls}


begin
{stmts}

end architecture rtl;

"""
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
from itertools import accumulate
|
|
2
|
+
|
|
3
|
+
from ....cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _loc(i: int, j: int):
|
|
7
|
+
return f'({i} downto {j})' if i != j else f'({i})'
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def hetero_io_map(qints: list[QInterval], merge: bool = False):
    """Relate a packed layout of heterogeneous-width values to a fixed-stride one.

    In the packed ("hetero") layout the values are concatenated back to back at
    their minimal widths. In the regular layout value ``n`` occupies a slot of
    ``max_bw`` bits starting at ``max_bw * n``. It is placed so that every
    value's binary point sits at the same offset inside its slot.

    Args:
        qints: intervals of the values, in order.
        merge: when True, coalesce adjacent ranges into single ranges.

    Returns:
        ``(regular, hetero, pads, (width_regular, width_packed))``.
        ``regular[n]`` / ``hetero[n]`` are matching ``(high, low)`` bit ranges
        for the non-empty values. ``pads`` lists ``(high, low, copy_from)``
        regular-layout ranges not covered by a value, where ``copy_from`` is
        the packed bit to replicate (sign extension) or ``-1`` for zero fill.
    """
    n = len(qints)
    ks, _is, fs = zip(*map(_minimal_kif, qints))
    int_bits = [i_ + k_ for i_, k_ in zip(_is, ks)]  # integer bits including sign
    max_int = max(_is) + max(ks)
    max_frac = max(fs)
    slot = max_int + max_frac
    shape = (slot * n, sum(int_bits) + sum(fs))

    offsets = list(accumulate([0] + [ib + f for ib, f in zip(int_bits, fs)]))
    hetero = [(offsets[k + 1] - 1, offsets[k]) for k in range(n)]

    regular: list[tuple[int, int]] = []
    pads: list[tuple[int, int, int]] = []
    for idx in range(n):
        base = slot * idx
        lo_gap = max_frac - fs[idx]  # missing fractional bits below the value
        hi_gap = max_int - int_bits[idx]  # missing integer bits above the value
        regular.append(((base + slot - 1) - hi_gap, base + lo_gap))
        if lo_gap != 0:
            pads.append((base + lo_gap - 1, base, -1))
        if hi_gap != 0:
            fill_src = hetero[idx][0] if ks[idx] else -1
            pads.append((base + slot - 1, base + slot - hi_gap, fill_src))

    # Drop zero-width values (high < low) from both layouts.
    keep = [not (hi < lo) for hi, lo in hetero]
    regular = [r for r, ok in zip(regular, keep) if ok]
    hetero = [h for h, ok in zip(hetero, keep) if ok]

    if not merge:
        return regular, hetero, pads, shape

    # Walk backwards so that merging (and shrinking the lists) never
    # invalidates the indices still to be visited.
    for k in reversed(range(len(regular) - 1)):
        if regular[k + 1][1] - regular[k][0] != 1:
            continue
        regular[k : k + 2] = [(regular[k + 1][0], regular[k][1])]
        hetero[k : k + 2] = [(hetero[k + 1][0], hetero[k][1])]

    for k in reversed(range(len(pads) - 1)):
        hi_prev, lo_prev, src_prev = pads[k]
        hi_next, lo_next, src_next = pads[k + 1]
        if lo_next - hi_prev == 1 and src_prev == src_next:
            pads[k : k + 2] = [(hi_next, lo_prev, src_prev)]

    return regular, hetero, pads, shape
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipelined: bool = False):
    """Generate VHDL for ``<module_name>_wrapper``, a fixed-stride shell around *module_name*.

    The wrapper's ``model_inp`` / ``model_out`` use the regular layout from
    ``hetero_io_map``. Inside, inputs are packed into the minimal-width vector
    that *module_name* expects. Its packed outputs are spread back into the
    regular layout, and the gaps are filled by pads (zeros or sign bits).

    Args:
        sol: the (possibly cascaded) solution whose I/O intervals are wrapped.
        module_name: name of the wrapped entity.
        pipelined: add a ``clk`` port and forward it to the wrapped entity.

    Returns:
        The VHDL source text.
    """
    reg_in, het_in, _, (w_reg_in, w_het_in) = hetero_io_map(sol.inp_qint, merge=True)
    reg_out, het_out, pad_out, (w_reg_out, w_het_out) = hetero_io_map(sol.out_qint, merge=True)

    inp_assignment = [
        f'packed_inp{_loc(ih, jh)} <= model_inp{_loc(ir, jr)};' for (ih, jh), (ir, jr) in zip(het_in, reg_in)
    ]

    # Every output statement is keyed by its top bit in model_out. Slices and
    # pads then share one index space, and sorting lists them in output-bit order.
    _out_assignment: list[tuple[int, str]] = []
    for (ih, jh), (ir, jr) in zip(het_out, reg_out):
        if ih == jh - 1:
            continue  # zero-width slice
        _out_assignment.append((ir, f'model_out{_loc(ir, jr)} <= packed_out{_loc(ih, jh)};'))

    for hi, lo, copy_from in pad_out:
        value = "'0'" if copy_from == -1 else f'packed_out({copy_from})'
        pad = f'(others => {value})' if hi > lo else value
        _out_assignment.append((hi, f'model_out{_loc(hi, lo)} <= {pad};'))

    _out_assignment.sort(key=lambda x: x[0])
    out_assignment = [v for _, v in _out_assignment]

    inp_assignment_str = '\n '.join(inp_assignment)
    out_assignment_str = '\n '.join(out_assignment)

    clk_and_rst_inp, clk_and_rst_bind = '', ''
    if pipelined:
        clk_and_rst_inp = '\n clk:in std_logic;'
        clk_and_rst_bind = '\n clk=>clk,'

    return f"""library ieee;
use ieee.std_logic_1164.all;
entity {module_name}_wrapper is port({clk_and_rst_inp}
model_inp:in std_logic_vector({w_reg_in-1} downto {0});
model_out:out std_logic_vector({w_reg_out-1} downto {0})
);
end entity {module_name}_wrapper;

architecture rtl of {module_name}_wrapper is
signal packed_inp:std_logic_vector({w_het_in-1} downto {0});
signal packed_out:std_logic_vector({w_het_out-1} downto {0});

begin
{inp_assignment_str}

op:entity work.{module_name} port map({clk_and_rst_bind}
model_inp=>packed_inp,
model_out=>packed_out
);

{out_assignment_str}

end architecture rtl;
"""
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def binder_gen(csol: CascadedSolution | Solution, module_name: str, II: int = 1, latency_multiplier: int = 1):
    """Generate the C++ binder source exposing ``inference`` for *module_name*.

    The emitted ``<module_name>_config`` struct records the I/O counts, the
    widest regular-layout slot for inputs and outputs, and II/latency. For a
    plain ``Solution`` both II and latency are 0. Otherwise latency is
    ``len(csol.solutions) * latency_multiplier``. The ``V<module_name>.h``
    include and ``dut_t`` typedef presumably refer to a Verilator-generated
    model class — NOTE(review): confirm against the build flow.
    """

    def _slot_width(qints):
        # widest k + i + f across the intervals (each term maximized separately)
        ks, is_, fs = zip(*map(_minimal_kif, qints))
        return ks, is_, fs

    k_in, i_in, f_in = _slot_width(csol.inp_qint)
    k_out, i_out, f_out = _slot_width(csol.out_qint)
    max_inp_bw = max(k_in) + max(i_in) + max(f_in)
    max_out_bw = max(k_out) + max(i_out) + max(f_out)

    if not isinstance(csol, Solution):
        latency = len(csol.solutions) * latency_multiplier
    else:
        II = latency = 0

    n_in, n_out = csol.shape
    return f"""#include <cstddef>
#include "binder_util.hh"
#include "V{module_name}.h"

struct {module_name}_config {{
static const size_t N_inp = {n_in};
static const size_t N_out = {n_out};
static const size_t max_inp_bw = {max_inp_bw};
static const size_t max_out_bw = {max_out_bw};
static const size_t II = {II};
static const size_t latency = {latency};
typedef V{module_name} dut_t;
}};

extern "C" {{
bool openmp_enabled() {{
return _openmp;
}}

void inference(int32_t *c_inp, int32_t *c_out, size_t n_samples) {{
batch_inference<{module_name}_config>(c_inp, c_out, n_samples);
}}
}}
"""
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from ....cmvm.types import CascadedSolution, _minimal_kif
|
|
2
|
+
from .comb import comb_logic_gen
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def pipeline_logic_gen(
    csol: CascadedSolution,
    name: str,
    print_latency: bool = False,
    timescale: str | None = None,
    register_layers: int = 1,
):
    """Generate a pipelined VHDL design from a cascade of combinational stages.

    Stage ``i`` of *csol* becomes entity ``<name>_stage<i>``, rendered by
    ``comb_logic_gen``. The top-level entity *name* feeds each stage's input
    through ``register_layers`` registers clocked on ``rising_edge(clk)``. The
    last stage's output is registered into ``model_out``.

    Args:
        csol: the cascaded solution; one stage per entry of ``csol.solutions``.
        name: top-level entity name; stages are named ``<name>_stage<i>``.
        print_latency: forwarded to ``comb_logic_gen``.
        timescale: forwarded to ``comb_logic_gen``.
        register_layers: registers in front of every stage; must be >= 1.

    Returns:
        Dict mapping entity name to VHDL source, with the stages first and the
        top-level entity last.

    Raises:
        ValueError: if ``register_layers`` < 1. Such a value previously
            produced VHDL that referenced undeclared copy registers.
    """
    if register_layers < 1:
        raise ValueError(f'register_layers must be at least 1, got {register_layers}')

    N = len(csol.solutions)
    inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
    out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]

    registers = [f'signal stage{i}_inp:std_logic_vector({width-1} downto 0);' for i, width in enumerate(inp_bits)]
    for i in range(register_layers - 1):
        registers += [f'signal stage{j}_inp_copy{i}:std_logic_vector({width-1} downto 0);' for j, width in enumerate(inp_bits)]
    wires = [f'signal stage{i}_out:std_logic_vector({width-1} downto 0);' for i, width in enumerate(out_bits)]

    comb_logic = [
        f'stage{i}:entity work.{name}_stage{i} port map(model_inp=>stage{i}_inp,model_out=>stage{i}_out);' for i in range(N)
    ]

    # Clocked assignments. Each stage input comes either straight from its
    # source (one layer) or through a chain of copy registers ending in stage<i>_inp.
    serial_logic: list[str] = []
    for i in range(N):
        src = 'model_inp' if i == 0 else f'stage{i-1}_out'
        if register_layers == 1:
            serial_logic.append(f'stage{i}_inp <= {src};')
            continue
        serial_logic.append(f'stage{i}_inp_copy0 <= {src};')
        serial_logic.extend(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};' for j in range(1, register_layers - 1))
        serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
    serial_logic.append(f'model_out <= stage{N-1}_out;')

    blk = '\n '

    module = f"""library ieee;
use ieee.std_logic_1164.all;
entity {name} is port(
clk:in std_logic;
model_inp:in std_logic_vector({inp_bits[0]-1} downto 0);
model_out:out std_logic_vector({out_bits[-1]-1} downto 0));
end entity {name};

architecture rtl of {name} is
{blk.join(registers)}
{blk.join(wires)}

begin
{blk.join(comb_logic)}

process(clk) begin
if rising_edge(clk) then
{f'{blk} '.join(serial_logic)}
end if;
end process;
end architecture rtl;
"""

    ret: dict[str, str] = {}
    for i, s in enumerate(csol.solutions):
        stage_name = f'{name}_stage{i}'
        ret[stage_name] = comb_logic_gen(s, stage_name, print_latency=print_latency, timescale=timescale)
    ret[name] = module
    return ret
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Combinational multiplier; each operand's signedness is selected by a generic.
-- The product is formed in a BW_BUF-bit buffer whose low BW_OUT bits drive result.
entity multiplier is
    generic (
        BW_INPUT0 : integer := 32;
        BW_INPUT1 : integer := 32;
        SIGNED0 : integer := 0;  -- 1: in0 is two's complement; otherwise unsigned
        SIGNED1 : integer := 0;  -- 1: in1 is two's complement; otherwise unsigned
        BW_OUT : integer := 32
    );
    port (
        in0 : in std_logic_vector(BW_INPUT0-1 downto 0);
        in1 : in std_logic_vector(BW_INPUT1-1 downto 0);
        result : out std_logic_vector(BW_OUT-1 downto 0)
    );
end entity multiplier;

architecture rtl of multiplier is
    function max_int(a : integer; b : integer) return integer is
    begin
        if a > b then
            return a;
        else
            return b;
        end if;
    end function max_int;

    -- The exact product needs BW_INPUT0 + BW_INPUT1 bits. The buffer is widened
    -- to BW_OUT when that is larger, so the slice below stays in range and the
    -- product is sign-/zero-extended by resize instead of failing elaboration.
    constant BW_BUF : integer := max_int(BW_INPUT0 + BW_INPUT1, BW_OUT);
    signal mult_buffer : std_logic_vector(BW_BUF-1 downto 0);
begin

    gen_mult : process(in0, in1)
    begin
        if SIGNED0 = 1 and SIGNED1 = 1 then
            mult_buffer <= std_logic_vector(resize(signed(in0) * signed(in1), BW_BUF));
        elsif SIGNED0 = 1 and SIGNED1 = 0 then
            -- Prepend '0' so the unsigned operand is read as a non-negative signed value.
            mult_buffer <= std_logic_vector(resize(signed(in0) * signed('0' & in1), BW_BUF));
        elsif SIGNED0 = 0 and SIGNED1 = 1 then
            mult_buffer <= std_logic_vector(resize(signed('0' & in0) * signed(in1), BW_BUF));
        else
            mult_buffer <= std_logic_vector(resize(unsigned(in0) * unsigned(in1), BW_BUF));
        end if;
    end process;

    result <= mult_buffer(BW_OUT-1 downto 0);

end architecture rtl;
|