da4ml 0.5.0__cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +194 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +240 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.0.dist-info/METADATA +85 -0
- da4ml-0.5.0.dist-info/RECORD +96 -0
- da4ml-0.5.0.dist-info/WHEEL +6 -0
- da4ml-0.5.0.dist-info/entry_points.txt +3 -0
- da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from math import ceil, log2
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ....cmvm.types import CombLogic, QInterval, _minimal_kif
|
|
6
|
+
from ..verilog.comb import get_table_name
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def make_neg(
    signals: list[str],
    assigns: list[str],
    idx: int,
    qint: QInterval,
    v0_name: str,
    neg_repo: dict[int, tuple[int, str]],
):
    """Return ``(width, name)`` of a VHDL signal carrying the negation of op ``idx``.

    On the first request for ``idx`` this declares ``v{idx}_neg`` (appended to
    ``signals``) and instantiates a ``work.negative`` entity that drives it from
    ``v0_name`` (appended to ``assigns``). The result is memoised in ``neg_repo``,
    so later calls for the same ``idx`` return the cached pair and emit nothing.

    Args:
        signals: Declaration lines; appended to in place.
        assigns: Statement lines; appended to in place.
        idx: Op index whose value is negated.
        qint: Value interval of op ``idx`` as ``(min, max, step)``.
        v0_name: VHDL expression/signal holding the un-negated value.
        neg_repo: Cache mapping op index -> ``(width, signal name)``.
    """
    cached = neg_repo.get(idx)
    if cached is not None:
        return cached

    lo, hi, step = qint
    in_width = sum(_minimal_kif(qint))
    # The negated value spans [-hi, -lo], which may need a different width.
    out_width = sum(_minimal_kif(QInterval(-hi, -lo, step)))
    neg_name = f'v{idx}_neg'

    signals.append(f'signal {neg_name} : std_logic_vector({out_width - 1} downto {0});')
    signed_flag = 1 if lo < 0 else 0
    assigns.append(
        f'op_neg_{idx} : entity work.negative generic map (BW_IN => {in_width}, BW_OUT => {out_width}, IN_SIGNED => {signed_flag}) port map (neg_in => {v0_name}, neg_out => {neg_name});'
    )

    entry = (out_width, neg_name)
    neg_repo[idx] = entry
    return entry
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def ssa_gen(sol: CombLogic, neg_repo: dict[int, tuple[int, str]], print_latency: bool = False):
|
|
34
|
+
ops = sol.ops
|
|
35
|
+
kifs = list(map(_minimal_kif, (op.qint for op in ops)))
|
|
36
|
+
widths = list(map(sum, kifs))
|
|
37
|
+
inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
|
|
38
|
+
inp_widths = list(map(sum, inp_kifs))
|
|
39
|
+
_inp_widths = np.cumsum([0] + inp_widths)
|
|
40
|
+
inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
|
|
41
|
+
|
|
42
|
+
signals = []
|
|
43
|
+
assigns = []
|
|
44
|
+
ref_count = sol.ref_count
|
|
45
|
+
|
|
46
|
+
for i, op in enumerate(ops):
|
|
47
|
+
if ref_count[i] == 0:
|
|
48
|
+
continue
|
|
49
|
+
|
|
50
|
+
bw = widths[i]
|
|
51
|
+
if bw == 0:
|
|
52
|
+
continue
|
|
53
|
+
|
|
54
|
+
signals.append(f'signal v{i}:std_logic_vector({bw - 1} downto {0});')
|
|
55
|
+
|
|
56
|
+
match op.opcode:
|
|
57
|
+
case -1: # Input marker
|
|
58
|
+
i0, i1 = inp_idxs[op.id0]
|
|
59
|
+
line = f'v{i} <= model_inp({i0} downto {i1});'
|
|
60
|
+
|
|
61
|
+
case 0 | 1: # Common a+/-b<<shift oprs
|
|
62
|
+
p0, p1 = kifs[op.id0], kifs[op.id1]
|
|
63
|
+
bw0, bw1 = widths[op.id0], widths[op.id1]
|
|
64
|
+
s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
|
|
65
|
+
shift = op.data + f0 - f1
|
|
66
|
+
line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{op.opcode}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
|
|
67
|
+
|
|
68
|
+
case 2 | -2: # ReLU
|
|
69
|
+
lsb_bias = kifs[op.id0][2] - kifs[i][2]
|
|
70
|
+
i0, i1 = bw + lsb_bias - 1, lsb_bias
|
|
71
|
+
v0_name = f'v{op.id0}'
|
|
72
|
+
bw0 = widths[op.id0]
|
|
73
|
+
if op.opcode == -2:
|
|
74
|
+
bw0, v0_name = make_neg(signals, assigns, op.id0, ops[op.id0].qint, v0_name, neg_repo)
|
|
75
|
+
if ops[op.id0].qint.min < 0:
|
|
76
|
+
if bw > 1:
|
|
77
|
+
line = f'v{i} <= {v0_name}({i0} downto {i1}) and ({bw - 1} downto 0 => not {v0_name}({bw0 - 1}));'
|
|
78
|
+
else:
|
|
79
|
+
line = f'v{i}(0) <= {v0_name}(0) and (not {v0_name}({bw0 - 1}));'
|
|
80
|
+
else:
|
|
81
|
+
line = f'v{i} <= {v0_name}({i0} downto {i1});'
|
|
82
|
+
|
|
83
|
+
case 3 | -3: # Explicit quantization
|
|
84
|
+
lsb_bias = kifs[op.id0][2] - kifs[i][2]
|
|
85
|
+
i0, i1 = bw + lsb_bias - 1, lsb_bias
|
|
86
|
+
v0_name = f'v{op.id0}'
|
|
87
|
+
bw0 = widths[op.id0]
|
|
88
|
+
if op.opcode == -3:
|
|
89
|
+
bw0, v0_name = make_neg(signals, assigns, op.id0, ops[op.id0].qint, v0_name, neg_repo)
|
|
90
|
+
|
|
91
|
+
if i0 >= bw0:
|
|
92
|
+
if op.opcode == 3:
|
|
93
|
+
assert ops[op.id0].qint.min < 0, f'{i}, {op.id0}'
|
|
94
|
+
else:
|
|
95
|
+
assert ops[op.id0].qint.max > 0, f'{i}, {op.id0}'
|
|
96
|
+
|
|
97
|
+
if i1 >= bw0:
|
|
98
|
+
v0_name = f'({i0 - i1} downto 0 => {v0_name}({bw0 - 1}))'
|
|
99
|
+
else:
|
|
100
|
+
v0_name = f'({i0 - bw0} downto 0 => {v0_name}({bw0 - 1})) & {v0_name}({bw0 - 1} downto {i1})'
|
|
101
|
+
line = f'v{i} <= {v0_name};'
|
|
102
|
+
else:
|
|
103
|
+
line = f'v{i} <= {v0_name}({i0} downto {i1});'
|
|
104
|
+
|
|
105
|
+
case 4: # constant addition
|
|
106
|
+
num = op.data
|
|
107
|
+
sign, mag = int(num < 0), abs(num)
|
|
108
|
+
bw1 = ceil(log2(mag + 1)) if mag > 0 else 1
|
|
109
|
+
bw0 = widths[op.id0]
|
|
110
|
+
s0 = int(kifs[op.id0][0])
|
|
111
|
+
shift = kifs[op.id0][2] - kifs[i][2]
|
|
112
|
+
bin_val = format(mag, f'0{bw1}b')
|
|
113
|
+
line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>0,BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{sign}) port map(in0=>v{op.id0},in1=>"{bin_val}",result=>v{i});'
|
|
114
|
+
case 5: # constant
|
|
115
|
+
num = op.data
|
|
116
|
+
if num < 0:
|
|
117
|
+
num = 2**bw + num
|
|
118
|
+
bin_val = format(num, f'0{bw}b')
|
|
119
|
+
line = f'v{i} <= "{bin_val}";'
|
|
120
|
+
|
|
121
|
+
case 6 | -6: # MSB Muxing
|
|
122
|
+
k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
|
|
123
|
+
p0, p1 = kifs[a], kifs[b]
|
|
124
|
+
inv = '1' if op.opcode == -6 else '0'
|
|
125
|
+
bwk, bw0, bw1 = widths[k], widths[a], widths[b]
|
|
126
|
+
s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
|
|
127
|
+
_shift = (op.data >> 32) & 0xFFFFFFFF
|
|
128
|
+
_shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
|
|
129
|
+
shift = f0 - f1 + _shift
|
|
130
|
+
v0, v1 = f'v{a}', f'v{b}'
|
|
131
|
+
if bw0 == 0:
|
|
132
|
+
v0, bw0 = 'B"0"', 1
|
|
133
|
+
if bw1 == 0:
|
|
134
|
+
v1, bw1 = 'B"0"', 1
|
|
135
|
+
line = f'op_{i}:entity work.mux generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},INVERT1=>{inv}) port map(key=>v{k}({bwk - 1}),in0=>{v0},in1=>{v1},result=>v{i});'
|
|
136
|
+
|
|
137
|
+
case 7: # Multiplication
|
|
138
|
+
bw0, bw1 = widths[op.id0], widths[op.id1]
|
|
139
|
+
s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
|
|
140
|
+
line = f'op_{i}:entity work.multiplier generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
|
|
141
|
+
|
|
142
|
+
case 8: # Lookup Table
|
|
143
|
+
name = get_table_name(sol, op)
|
|
144
|
+
bw0 = widths[op.id0]
|
|
145
|
+
line = f'op_{i}:entity work.lookup_table generic map(BW_IN=>{bw0},BW_OUT=>{bw},MEM_FILE=>"{name}") port map(inp=>v{op.id0},outp=>v{i});'
|
|
146
|
+
|
|
147
|
+
case _:
|
|
148
|
+
raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
|
|
149
|
+
|
|
150
|
+
if print_latency:
|
|
151
|
+
line += f' -- {op.latency}'
|
|
152
|
+
assigns.append(line)
|
|
153
|
+
return signals, assigns
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def output_gen(sol: CombLogic, neg_repo: dict[int, tuple[int, str]]):
    """Drive the packed ``model_out`` vector from the op signals listed in ``sol.out_idxs``.

    Output ``pos`` occupies the ``model_out`` bit slice obtained by packing the minimal
    widths of ``sol.out_qint`` back to back, with output 0 at the LSB end. The function:
    - skips outputs whose source index is negative or whose slot has zero width;
    - routes outputs flagged in ``sol.out_negs`` through a negated signal from ``make_neg``.

    Args:
        sol: Combinational logic graph whose outputs are wired up.
        neg_repo: Negation cache shared with ``ssa_gen``.

    Returns:
        Tuple ``(signals, assigns)`` of extra VHDL declarations and statements.
    """
    signals: list[str] = []
    assigns: list[str] = []
    slot_widths = [sum(_minimal_kif(q)) for q in sol.out_qint]

    # Running offsets: output pos spans bits [bounds[pos], bounds[pos + 1]) of model_out.
    bounds = [0]
    for w in slot_widths:
        bounds.append(bounds[-1] + w)

    for pos, src in enumerate(sol.out_idxs):
        if src < 0:
            continue
        msb, lsb = bounds[pos + 1] - 1, bounds[pos]
        if msb == lsb - 1:
            continue  # zero-width output slot
        if sol.out_negs[pos]:
            width, src_name = make_neg(signals, assigns, src, sol.ops[src].qint, f'v{src}', neg_repo)
        else:
            width, src_name = slot_widths[pos], f'v{src}'
        assigns.append(f'model_out({msb} downto {lsb}) <= {src_name}({width - 1} downto {0});')
    return signals, assigns
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def comb_logic_gen(sol: CombLogic, fn_name: str, print_latency: bool = False, timescale: str | None = None):
    """Render ``sol`` as a self-contained VHDL entity/architecture named ``fn_name``.

    The entity has one packed input port ``model_inp`` and one packed output port
    ``model_out``. Their widths are the summed minimal widths of ``sol.inp_qint`` and
    ``sol.out_qint``.

    Args:
        sol: Combinational logic graph to render.
        fn_name: Entity name.
        print_latency: Forwarded to ``ssa_gen``; annotates statements with op latency.
        timescale: Accepted but not used by this VHDL generator.

    Returns:
        The VHDL source text.
    """

    def _total_bits(qints):
        # Sum of the minimal (k + i + f) widths of every interval.
        return sum(sum(_minimal_kif(q)) for q in qints)

    inp_bits = _total_bits(sol.inp_qint)
    out_bits = _total_bits(sol.out_qint)

    # ssa_gen must run first: output_gen reuses the negations it caches in neg_repo.
    neg_repo: dict[int, tuple[int, str]] = {}
    ssa_signals, ssa_assigns = ssa_gen(sol, neg_repo=neg_repo, print_latency=print_latency)
    output_signals, output_assigns = output_gen(sol, neg_repo)

    sep = '\n '
    all_signals = sep.join(ssa_signals + output_signals)
    all_assigns = sep.join(ssa_assigns + output_assigns)

    return f"""library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

entity {fn_name} is port(
model_inp:in std_logic_vector({inp_bits - 1} downto {0});
model_out:out std_logic_vector({out_bits - 1} downto {0})
);
end entity {fn_name};

architecture rtl of {fn_name} is
{all_signals}


begin
{all_assigns}

end architecture rtl;

"""
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from itertools import accumulate
|
|
2
|
+
|
|
3
|
+
from ....cmvm.types import CombLogic, Pipeline, QInterval, _minimal_kif
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _loc(i: int, j: int):
|
|
7
|
+
return f'({i} downto {j})' if i != j else f'({i})'
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def hetero_io_map(qints: list[QInterval], merge: bool = False):
    """Compute bit-slice correspondences between the "regular" and "packed" I/O layouts.

    Packed layout: each element's minimal-width field is placed back to back, with
    element 0 starting at bit 0.

    Regular layout: element ``i`` occupies a ``max_bw``-bit slot starting at bit
    ``max_bw * i``. Each field is positioned so that all elements share ``max_f``
    fractional bits.

    Args:
        qints: Value intervals of the elements.
        merge: If true, coalesce adjacent slices, and adjacent pads with the same fill
            source, so that fewer assignments are needed.

    Returns:
        ``(regular, hetero, pads, (width_regular, width_packed))``:
        - ``regular[n]`` and ``hetero[n]`` are matching ``(high, low)`` bit ranges in the
          regular and packed vectors. Zero-width elements have no entry.
        - ``pads`` lists ``(high, low, copy_from)`` ranges of the regular vector that carry
          no data. ``copy_from`` is the packed-vector bit to replicate (the element's sign
          bit), or -1 for zero fill.
        - The widths are the total sizes of the two layouts.
        An empty ``qints`` yields ``([], [], [], (0, 0))``.
    """
    N = len(qints)
    if N == 0:
        # zip(*...) below cannot unpack an empty sequence; there is nothing to map.
        return [], [], [], (0, 0)

    ks, _is, fs = zip(*map(_minimal_kif, qints))
    Is = [_i + _k for _i, _k in zip(_is, ks)]  # integer bits including the sign bit
    # NOTE(review): max(i) + max(k) can exceed max(i + k). Presumably this guarantees room
    # for unsigned elements when a regular slot is read as signed — confirm intent.
    max_I, max_f = max(_is) + max(ks), max(fs)
    max_bw = max_I + max_f
    width_regular, width_packed = max_bw * N, sum(Is) + sum(fs)

    regular: list[tuple[int, int]] = []
    pads: list[tuple[int, int, int]] = []

    bws = [I + f for I, f in zip(Is, fs)]
    _bw = list(accumulate([0] + bws))
    hetero = [(i - 1, j) for i, j in zip(_bw[1:], _bw[:-1])]

    for i in range(N):
        base = max_bw * i
        bias_low = max_f - fs[i]  # unused fractional bits below the field
        bias_high = max_I - Is[i]  # unused integer bits above the field
        low = base + bias_low
        high = (base + max_bw - 1) - bias_high
        regular.append((high, low))

        if bias_low != 0:
            pads.append((base + bias_low - 1, base, -1))
        if bias_high != 0:
            # Signed elements are sign-extended from their packed MSB; unsigned ones zero-filled.
            copy_from = hetero[i][0] if ks[i] else -1
            pads.append((base + max_bw - 1, base + max_bw - bias_high, copy_from))

    # Drop zero-width elements (their packed range is empty: high < low).
    mask = list(high < low for high, low in hetero)
    regular = [r for r, m in zip(regular, mask) if not m]
    hetero = [h for h, m in zip(hetero, mask) if not m]

    if not merge:
        return regular, hetero, pads, (width_regular, width_packed)

    # Merging consecutive intervals when possible. Packed ranges are always contiguous,
    # so only adjacency in the regular layout has to be checked. Iterate backwards so
    # popping i + 1 does not disturb indices still to be visited.
    NN = len(regular) - 2
    for i in range(NN, -1, -1):
        this_high = regular[i][0]
        next_low = regular[i + 1][1]
        if next_low - this_high != 1:
            continue
        regular[i] = (regular[i + 1][0], regular[i][1])
        regular.pop(i + 1)
        hetero[i] = (hetero[i + 1][0], hetero[i][1])
        hetero.pop(i + 1)

    # Merge adjacent pads that are filled from the same source.
    for i in range(len(pads) - 2, -1, -1):
        if pads[i + 1][1] - pads[i][0] == 1 and pads[i][2] == pads[i + 1][2]:
            pads[i] = (pads[i + 1][0], pads[i][1], pads[i][2])
            pads.pop(i + 1)

    return regular, hetero, pads, (width_regular, width_packed)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def generate_io_wrapper(sol: CombLogic | Pipeline, module_name: str, pipelined: bool = False):
    """Generate a VHDL ``{module_name}_wrapper`` entity that translates between I/O layouts.

    The wrapper's ``model_inp``/``model_out`` ports use the regular (fixed-slot, padded)
    layout computed by ``hetero_io_map``. It instantiates ``work.{module_name}``, whose
    ports use the packed layout.

    Regular output bits that carry no data are driven either with '0' or with a copy of
    the element's packed sign bit. The output assignments are emitted in ascending order
    of their ``model_out`` high bit.

    Args:
        sol: Design whose ``inp_qint``/``out_qint`` describe the I/O elements.
        module_name: Name of the wrapped entity.
        pipelined: If true, add a ``clk`` input and forward it to the wrapped entity.

    Returns:
        VHDL source text of the wrapper.
    """
    reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
    reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)

    w_reg_in, w_het_in = shape_in
    w_reg_out, w_het_out = shape_out

    inp_assignment = [f'packed_inp{_loc(ih, jh)} <= model_inp{_loc(ir, jr)};' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]

    # (model_out high bit, statement) pairs, so all statements can be ordered by output position.
    _out_assignment: list[tuple[int, str]] = []

    for (ih, jh), (ir, jr) in zip(het_out, reg_out):
        if ih == jh - 1:
            continue  # zero-width element
        _out_assignment.append((ir, f'model_out{_loc(ir, jr)} <= packed_out{_loc(ih, jh)};'))

    for p_hi, p_lo, copy_from in pad_out:
        n_bit = p_hi - p_lo + 1
        value = "'0'" if copy_from == -1 else f'packed_out({copy_from})'
        pad = f'(others => {value})' if n_bit > 1 else value
        _out_assignment.append((p_hi, f'model_out{_loc(p_hi, p_lo)} <= {pad};'))

    _out_assignment.sort(key=lambda x: x[0])
    out_assignment = [v for _, v in _out_assignment]

    inp_assignment_str = '\n '.join(inp_assignment)
    out_assignment_str = '\n '.join(out_assignment)

    clk_and_rst_inp, clk_and_rst_bind = '', ''
    if pipelined:
        clk_and_rst_inp = '\n clk:in std_logic;'
        clk_and_rst_bind = '\n clk=>clk,'

    return f"""library ieee;
use ieee.std_logic_1164.all;
entity {module_name}_wrapper is port({clk_and_rst_inp}
model_inp:in std_logic_vector({w_reg_in - 1} downto {0});
model_out:out std_logic_vector({w_reg_out - 1} downto {0})
);
end entity {module_name}_wrapper;

architecture rtl of {module_name}_wrapper is
signal packed_inp:std_logic_vector({w_het_in - 1} downto {0});
signal packed_out:std_logic_vector({w_het_out - 1} downto {0});

begin
{inp_assignment_str}

op:entity work.{module_name} port map({clk_and_rst_bind}
model_inp=>packed_inp,
model_out=>packed_out
);

{out_assignment_str}

end architecture rtl;
"""
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from ....cmvm.types import Pipeline, _minimal_kif
|
|
2
|
+
from .comb import comb_logic_gen
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def pipeline_logic_gen(
    csol: Pipeline,
    name: str,
    print_latency=False,
    timescale: str | None = None,
    register_layers: int = 1,
):
    """Generate VHDL for a pipelined design: one entity per stage plus a top-level entity.

    Stage ``i`` is rendered by ``comb_logic_gen`` as entity ``{name}_stage{i}``. The
    top-level entity ``name`` feeds each stage's input through ``register_layers``
    registers clocked on ``rising_edge(clk)``; layers beyond the first use
    ``stage{i}_inp_copy{j}`` signals. It also registers ``model_out``.

    Args:
        csol: Pipeline whose ``solutions`` are the stages, in order.
        name: Top-level entity name; also the prefix of the stage entity names.
        print_latency: Forwarded to ``comb_logic_gen``.
        timescale: Forwarded to ``comb_logic_gen``.
        register_layers: Number of register layers in front of each stage (>= 1).

    Returns:
        Dict mapping entity name to VHDL source: the stages first, the top level last.

    Raises:
        ValueError: If ``csol`` has no stages, or if ``register_layers`` < 1. In the
            latter case the generated code would reference undeclared signals.
    """
    N = len(csol.solutions)
    if N == 0:
        raise ValueError('pipeline_logic_gen: the pipeline has no stages')
    if register_layers < 1:
        raise ValueError(f'pipeline_logic_gen: register_layers must be >= 1, got {register_layers}')

    inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
    # Stage i's output width is taken to be stage i + 1's input width; the last stage
    # drives the pipeline output.
    out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]

    registers = [f'signal stage{i}_inp:std_logic_vector({width - 1} downto 0);' for i, width in enumerate(inp_bits)]
    for i in range(0, register_layers - 1):
        registers += [f'signal stage{j}_inp_copy{i}:std_logic_vector({width - 1} downto 0);' for j, width in enumerate(inp_bits)]
    wires = [f'signal stage{i}_out:std_logic_vector({width - 1} downto 0);' for i, width in enumerate(out_bits)]

    comb_logic = [
        f'stage{i}:entity work.{name}_stage{i} port map(model_inp=>stage{i}_inp,model_out=>stage{i}_out);' for i in range(N)
    ]

    if register_layers == 1:
        serial_logic = ['stage0_inp <= model_inp;']
        serial_logic += [f'stage{i}_inp <= stage{i - 1}_out;' for i in range(1, N)]
    else:
        # Chain: source -> copy0 -> copy1 -> ... -> copy{register_layers - 2} -> stage input.
        serial_logic = ['stage0_inp_copy0 <= model_inp;']
        for j in range(1, register_layers - 1):
            serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j - 1};')
        serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
        for i in range(1, N):
            serial_logic.append(f'stage{i}_inp_copy0 <= stage{i - 1}_out;')
            for j in range(1, register_layers - 1):
                serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j - 1};')
            serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')

    serial_logic += [f'model_out <= stage{N - 1}_out;']

    blk = '\n '

    module = f"""library ieee;
use ieee.std_logic_1164.all;
entity {name} is port(
clk:in std_logic;
model_inp:in std_logic_vector({inp_bits[0] - 1} downto 0);
model_out:out std_logic_vector({out_bits[-1] - 1} downto 0));
end entity {name};

architecture rtl of {name} is
{blk.join(registers)}
{blk.join(wires)}

begin
{blk.join(comb_logic)}

process(clk) begin
if rising_edge(clk) then
{f'{blk} '.join(serial_logic)}
end if;
end process;
end architecture rtl;
"""

    ret: dict[str, str] = {}
    for i, s in enumerate(csol.solutions):
        stage_name = f'{name}_stage{i}'
        ret[stage_name] = comb_logic_gen(s, stage_name, print_latency=print_latency, timescale=timescale)
    ret[name] = module
    return ret
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
-- Combinational ROM: outp is the BW_OUT-bit word stored at address inp.
-- The 2**BW_IN-entry table is filled at elaboration time from the file MEM_FILE.
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;
use std.textio.all;
use ieee.std_logic_textio.all;

entity lookup_table is
    generic (
        BW_IN : positive := 8;              -- address width; the table has 2**BW_IN entries
        BW_OUT : positive := 8;             -- width of each stored word
        MEM_FILE : string := "whatever.mem" -- contents file read by init_rom
    );
    port (
        inp : in std_logic_vector(BW_IN - 1 downto 0);   -- address, interpreted as unsigned
        outp : out std_logic_vector(BW_OUT - 1 downto 0) -- word stored at address inp
    );
end entity lookup_table;

architecture rtl of lookup_table is
    subtype rom_index_t is natural range 0 to (2 ** BW_IN) - 1;
    type rom_array_t is array (rom_index_t) of std_logic_vector(BW_OUT - 1 downto 0);

    -- Load the ROM contents from an external hex file: one hexadecimal word per line,
    -- line k holding the word for address k. Reading stops at end of file or once every
    -- address has been filled; addresses without a corresponding line stay all zeros.
    impure function init_rom return rom_array_t is
        file rom_file : text;
        variable rom_data : rom_array_t := (others => (others => '0'));
        variable line_in : line;
        variable idx : integer := 0;
        variable data_val : std_logic_vector(BW_OUT - 1 downto 0); -- NOTE(review): declared but never used
        -- Read buffer rounded up to a whole number of hex digits (4 bits each).
        variable temp_val : std_logic_vector(((BW_OUT + 3) / 4) * 4 - 1 downto 0);
    begin
        file_open(rom_file, MEM_FILE, read_mode);

        while not endfile(rom_file) loop
            exit when idx > rom_index_t'high; -- ignore lines beyond the table size
            readline(rom_file, line_in);
            hread(line_in, temp_val);
            -- Keep only the low BW_OUT bits of the nibble-padded value.
            rom_data(idx) := temp_val(BW_OUT - 1 downto 0);
            idx := idx + 1;
        end loop;

        file_close(rom_file);
        return rom_data;
    end function init_rom;

    signal ROM_CONTENTS : rom_array_t := init_rom;

    -- Synthesis hint requesting a distributed (LUT-based) ROM. This is presumably the
    -- Xilinx/Vivado rom_style attribute; other tools may ignore it.
    attribute rom_style : string;
    attribute rom_style of ROM_CONTENTS : signal is "distributed";
begin
    -- Asynchronous read: the address is converted from unsigned to an index.
    outp <= ROM_CONTENTS(to_integer(unsigned(inp)));
end architecture rtl;
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Combinational multiplier. result is the low BW_OUT bits of in0 * in1.
-- An operand is treated as signed when its SIGNEDx generic is 1; in all other cases
-- it is treated as unsigned. The full product is formed at BW_INPUT0 + BW_INPUT1 bits
-- before slicing.
entity multiplier is
    generic (
        BW_INPUT0 : integer := 32;
        BW_INPUT1 : integer := 32;
        SIGNED0   : integer := 0;
        SIGNED1   : integer := 0;
        BW_OUT    : integer := 32
    );
    port (
        in0    : in  std_logic_vector(BW_INPUT0-1 downto 0);
        in1    : in  std_logic_vector(BW_INPUT1-1 downto 0);
        result : out std_logic_vector(BW_OUT-1 downto 0)
    );
end entity multiplier;

architecture rtl of multiplier is
    -- Collapse the operand signedness combination into one elaboration-time constant:
    -- 3 = signed x signed, 2 = signed x unsigned, 1 = unsigned x signed, 0 = otherwise.
    function product_kind(s0, s1 : integer) return integer is
    begin
        if s0 = 1 and s1 = 1 then
            return 3;
        elsif s0 = 1 and s1 = 0 then
            return 2;
        elsif s0 = 0 and s1 = 1 then
            return 1;
        end if;
        return 0;
    end function;

    constant KIND   : integer := product_kind(SIGNED0, SIGNED1);
    constant BW_BUF : integer := BW_INPUT0 + BW_INPUT1;
    signal product  : std_logic_vector(BW_BUF-1 downto 0);
begin
    gen_ss : if KIND = 3 generate
        product <= std_logic_vector(resize(signed(in0) * signed(in1), BW_BUF));
    end generate;

    -- For mixed signedness, a '0' is prepended to the unsigned operand so that it can
    -- take part in a signed multiplication.
    gen_su : if KIND = 2 generate
        product <= std_logic_vector(resize(signed(in0) * signed('0' & in1), BW_BUF));
    end generate;

    gen_us : if KIND = 1 generate
        product <= std_logic_vector(resize(signed('0' & in0) * signed(in1), BW_BUF));
    end generate;

    gen_uu : if KIND = 0 generate
        product <= std_logic_vector(resize(unsigned(in0) * unsigned(in1), BW_BUF));
    end generate;

    result <= product(BW_OUT-1 downto 0);
end architecture rtl;
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Two-input selector with operand alignment.
-- result is the low BW_OUT bits of in0_ext when key = '1', and of in1_ext otherwise
-- (negated when INVERT1 = 1).
-- Each operand is sign-extended (SIGNEDx = 1) or zero-extended (SIGNEDx = 0) to the
-- internal width BW_BUF.
-- A positive SHIFT1 shifts in1 left by SHIFT1; a negative SHIFT1 shifts in0 left by
-- -SHIFT1.
entity mux is
    generic (
        BW_INPUT0 : integer := 32;
        BW_INPUT1 : integer := 32;
        SIGNED0 : integer := 0;
        SIGNED1 : integer := 0;
        BW_OUT : integer := 32;
        SHIFT1 : integer := 0;
        INVERT1 : integer := 0
    );
    port (
        key : in std_logic;
        in0 : in std_logic_vector(BW_INPUT0-1 downto 0);
        in1 : in std_logic_vector(BW_INPUT1-1 downto 0);
        result : out std_logic_vector(BW_OUT-1 downto 0)
    );
end entity mux;

architecture rtl of mux is
    function max(L, R: integer) return integer is
    begin
        if L > R then
            return L;
        else
            return R;
        end if;
    end function;

    function if_then_else(cond: boolean; val_true: integer; val_false: integer) return integer is
    begin
        if cond then
            return val_true;
        else
            return val_false;
        end if;
    end function;

    -- Bits needed by each operand after its shift.
    constant IN0_NEED_BITS : integer := if_then_else(SHIFT1 < 0, BW_INPUT0 - SHIFT1, BW_INPUT0);
    constant IN1_NEED_BITS : integer := if_then_else(SHIFT1 > 0, BW_INPUT1 + SHIFT1, BW_INPUT1);
    -- Headroom: one extra bit when signedness differs (so a zero-extended operand stays
    -- non-negative), plus one when in1 may be negated.
    constant EXTRA_PAD : integer := if_then_else(SIGNED0 /= SIGNED1, INVERT1 + 1, INVERT1);
    constant BW_BUF : integer := max(IN0_NEED_BITS, IN1_NEED_BITS) + EXTRA_PAD;

    signal in0_ext : std_logic_vector(BW_BUF-1 downto 0);
    signal in1_ext : std_logic_vector(BW_BUF-1 downto 0);
    signal out_buf : std_logic_vector(BW_BUF-1 downto 0);

begin

    -- Extension and shifting for input 0.
    -- Shifts use numeric_std shift_left on signed/unsigned values: "sll" on
    -- std_logic_vector only exists from VHDL-2008 onward.
    gen_in0_shift_neg: if SHIFT1 < 0 generate
        gen_in0_signed: if SIGNED0 = 1 generate
            in0_ext <= std_logic_vector(shift_left(resize(signed(in0), BW_BUF), -SHIFT1));
        end generate;
        gen_in0_unsigned: if SIGNED0 = 0 generate
            in0_ext <= std_logic_vector(shift_left(resize(unsigned(in0), BW_BUF), -SHIFT1));
        end generate;
    end generate;

    gen_in0_shift_pos: if SHIFT1 >= 0 generate
        gen_in0_signed: if SIGNED0 = 1 generate
            in0_ext <= std_logic_vector(resize(signed(in0), BW_BUF));
        end generate;
        gen_in0_unsigned: if SIGNED0 = 0 generate
            in0_ext <= std_logic_vector(resize(unsigned(in0), BW_BUF));
        end generate;
    end generate;

    -- Extension and shifting for input 1
    gen_in1_shift_pos: if SHIFT1 > 0 generate
        gen_in1_signed: if SIGNED1 = 1 generate
            in1_ext <= std_logic_vector(shift_left(resize(signed(in1), BW_BUF), SHIFT1));
        end generate;
        gen_in1_unsigned: if SIGNED1 = 0 generate
            in1_ext <= std_logic_vector(shift_left(resize(unsigned(in1), BW_BUF), SHIFT1));
        end generate;
    end generate;

    gen_in1_shift_neg: if SHIFT1 <= 0 generate
        gen_in1_signed: if SIGNED1 = 1 generate
            in1_ext <= std_logic_vector(resize(signed(in1), BW_BUF));
        end generate;
        gen_in1_unsigned: if SIGNED1 = 0 generate
            in1_ext <= std_logic_vector(resize(unsigned(in1), BW_BUF));
        end generate;
    end generate;

    -- Mux logic
    gen_invert: if INVERT1 = 1 generate
        out_buf <= in0_ext when key = '1' else std_logic_vector(-signed(in1_ext));
    end generate;

    gen_no_invert: if INVERT1 = 0 generate
        out_buf <= in0_ext when key = '1' else in1_ext;
    end generate;

    result <= out_buf(BW_OUT-1 downto 0);

end architecture rtl;
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Two's-complement negation. neg_out = -neg_in, computed modulo 2**BW_OUT.
-- When the output is wider than the input, the input is first sign-extended
-- (IN_SIGNED = 1) or zero-extended (IN_SIGNED = 0). Otherwise the input is truncated
-- to its low BW_OUT bits before negating.
entity negative is
    generic (
        BW_IN     : integer := 32;
        BW_OUT    : integer := 32;
        IN_SIGNED : integer := 0
    );
    port (
        neg_in  : in  std_logic_vector(BW_IN-1 downto 0);
        neg_out : out std_logic_vector(BW_OUT-1 downto 0)
    );
end entity negative;

architecture rtl of negative is
begin
    widen : if BW_IN < BW_OUT generate
        signal widened : std_logic_vector(BW_OUT-1 downto 0);
    begin
        sext : if IN_SIGNED = 1 generate
            widened <= std_logic_vector(resize(signed(neg_in), BW_OUT));
        end generate;
        zext : if IN_SIGNED = 0 generate
            widened <= std_logic_vector(resize(unsigned(neg_in), BW_OUT));
        end generate;
        neg_out <= std_logic_vector(-signed(widened));
    end generate;

    -- Negation commutes with truncation modulo 2**BW_OUT, so truncate first.
    narrow : if BW_IN >= BW_OUT generate
        neg_out <= std_logic_vector(-signed(neg_in(BW_OUT-1 downto 0)));
    end generate;
end architecture rtl;
|