da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +204 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +246 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.1.post1.dist-info/METADATA +85 -0
- da4ml-0.5.1.post1.dist-info/RECORD +96 -0
- da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
- da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
- da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
from hashlib import sha256
|
|
2
|
+
from math import ceil, log2
|
|
3
|
+
from uuid import UUID
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from ....cmvm.types import CombLogic, Op, QInterval, _minimal_kif
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def make_neg(
    lines: list[str], idx: int, qint: QInterval, v0_name: str, neg_repo: dict[int, tuple[int, str]]
) -> tuple[int, str]:
    """Declare (at most once) a wire ``v{idx}_neg`` holding ``-v{idx}``.

    The Verilog declaration plus a ``negative`` primitive instance is appended to
    ``lines`` the first time a given ``idx`` is requested. Later requests return
    the cached entry from ``neg_repo`` without emitting anything.

    Args:
        lines: output statement list; the declaration is appended here.
        idx: op index of the value being negated.
        qint: quantization interval of ``v{idx}``; used to size the input and the
            negated output (the latter from the mirrored interval ``[-max, -min]``).
        v0_name: Verilog expression connected to the ``negative`` instance input.
        neg_repo: cache ``idx -> (bit width, wire name)`` of already-emitted negations;
            updated in place.

    Returns:
        ``(bw_neg, name)``: bit width and name of the negated wire.
    """
    if idx in neg_repo:
        return neg_repo[idx]

    _min, _max, step = qint
    bw0 = sum(_minimal_kif(qint))
    # Negation mirrors the interval, which may change both signedness and width.
    bw_neg = sum(_minimal_kif(QInterval(-_max, -_min, step)))
    was_signed = int(_min < 0)
    neg_name = f'v{idx}_neg'

    lines.append(
        f'wire [{bw_neg - 1}:0] {neg_name}; negative #({bw0}, {bw_neg}, {was_signed}) op_neg_{idx} ({v0_name}, {neg_name});'
    )
    neg_repo[idx] = (bw_neg, neg_name)
    return bw_neg, neg_name
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def gen_mem_file(sol: CombLogic, op: Op) -> str:
    """Render the ``$readmemh`` image of the lookup table used by ``op``.

    Every table entry is masked to the table's output width and written as an
    uppercase hexadecimal word, zero-padded to a fixed number of digits, one
    word per line.

    Args:
        sol: program holding the lookup tables and the ops.
        op: a lookup-table op (opcode 8); ``op.data`` indexes ``sol.lookup_tables``
            and ``op.id0`` is the op whose interval drives the table padding.

    Returns:
        The newline-joined hex words.
    """
    assert op.opcode == 8
    assert sol.lookup_tables is not None

    lut = sol.lookup_tables[op.data]
    bits = sum(lut.spec.out_kif)
    n_hex = (bits + 3) // 4  # hex digits per word, rounded up
    mask = (1 << bits) - 1  # keep only the output bits (two's complement for negatives)

    entries = lut.padded_table(sol.ops[op.id0].qint) & mask
    return '\n'.join(format(int(word), 'X').zfill(n_hex) for word in entries)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_table_name(sol: CombLogic, op: Op) -> str:
    """Return a content-addressed memory-file name for the table used by ``op``.

    The name is derived from a SHA-256 hash of the rendered ROM image, so
    identical table contents always produce the same file name.
    """
    digest = sha256(gen_mem_file(sol, op).encode('utf-8')).digest()
    # First 128 bits of the hash, formatted as a UUID with the version-4 bits set.
    table_id = UUID(bytes=digest[:16], version=4)
    return f'table_{table_id}.mem'
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def ssa_gen(sol: CombLogic, neg_repo: dict[int, tuple[int, str]], print_latency: bool = False) -> list[str]:
|
|
48
|
+
ops = sol.ops
|
|
49
|
+
kifs = list(map(_minimal_kif, (op.qint for op in ops)))
|
|
50
|
+
widths: list[int] = list(map(sum, kifs))
|
|
51
|
+
inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
|
|
52
|
+
inp_widths = list(map(sum, inp_kifs))
|
|
53
|
+
_inp_widths = np.cumsum([0] + inp_widths)
|
|
54
|
+
inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
|
|
55
|
+
|
|
56
|
+
lines: list[str] = []
|
|
57
|
+
ref_count = sol.ref_count
|
|
58
|
+
|
|
59
|
+
for i, op in enumerate(ops):
|
|
60
|
+
if ref_count[i] == 0:
|
|
61
|
+
continue
|
|
62
|
+
|
|
63
|
+
bw = widths[i]
|
|
64
|
+
v = f'v{i}[{bw - 1}:0]'
|
|
65
|
+
_def = f'wire [{bw - 1}:0] v{i};'
|
|
66
|
+
if bw == 0:
|
|
67
|
+
continue
|
|
68
|
+
|
|
69
|
+
match op.opcode:
|
|
70
|
+
case -1: # Input marker
|
|
71
|
+
i0, i1 = inp_idxs[op.id0]
|
|
72
|
+
line = f'{_def} assign {v} = model_inp[{i0}:{i1}];'
|
|
73
|
+
|
|
74
|
+
case 0 | 1: # Common a+/-b<<shift oprs
|
|
75
|
+
p0, p1 = kifs[op.id0], kifs[op.id1] # precision -> keep_neg, integers (no sign), fractional
|
|
76
|
+
|
|
77
|
+
bw0, bw1 = widths[op.id0], widths[op.id1] # width
|
|
78
|
+
s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
|
|
79
|
+
shift = op.data + f0 - f1
|
|
80
|
+
v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
|
|
81
|
+
|
|
82
|
+
line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {op.opcode}) op_{i} ({v0}, {v1}, {v});'
|
|
83
|
+
|
|
84
|
+
case 2 | -2: # ReLU
|
|
85
|
+
lsb_bias = kifs[op.id0][2] - kifs[i][2]
|
|
86
|
+
i0, i1 = bw + lsb_bias - 1, lsb_bias
|
|
87
|
+
|
|
88
|
+
v0_name = f'v{op.id0}'
|
|
89
|
+
bw0 = widths[op.id0]
|
|
90
|
+
|
|
91
|
+
if op.opcode == -2:
|
|
92
|
+
bw0, v0_name = make_neg(lines, op.id0, ops[op.id0].qint, v0_name, neg_repo)
|
|
93
|
+
if ops[op.id0].qint.min < 0:
|
|
94
|
+
line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}] & {{{bw}{{~{v0_name}[{bw0 - 1}]}}}};'
|
|
95
|
+
else:
|
|
96
|
+
line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
|
|
97
|
+
|
|
98
|
+
case 3 | -3: # Explicit quantization
|
|
99
|
+
lsb_bias = kifs[op.id0][2] - kifs[i][2]
|
|
100
|
+
i0, i1 = bw + lsb_bias - 1, lsb_bias
|
|
101
|
+
v0_name = f'v{op.id0}'
|
|
102
|
+
bw0 = widths[op.id0]
|
|
103
|
+
|
|
104
|
+
if op.opcode == -3:
|
|
105
|
+
bw0, v0_name = make_neg(lines, op.id0, ops[op.id0].qint, v0_name, neg_repo)
|
|
106
|
+
|
|
107
|
+
if i0 >= bw0:
|
|
108
|
+
if op.opcode == 3:
|
|
109
|
+
assert ops[op.id0].qint.min < 0, f'{i}, {op.id0}'
|
|
110
|
+
else:
|
|
111
|
+
assert ops[op.id0].qint.max > 0, f'{i}, {op.id0}'
|
|
112
|
+
|
|
113
|
+
if i1 >= bw0:
|
|
114
|
+
v0_name = f'{{{i0 - i1 + 1}{{{v0_name}[{bw0 - 1}]}}}}'
|
|
115
|
+
else:
|
|
116
|
+
v0_name = f'{{{{{i0 - bw0 + 1}{{{v0_name}[{bw0 - 1}]}}}}, {v0_name}[{bw0 - 1}:{i1}]}}'
|
|
117
|
+
line = f'{_def} assign {v} = {v0_name};'
|
|
118
|
+
else:
|
|
119
|
+
line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
|
|
120
|
+
|
|
121
|
+
case 4: # constant addition
|
|
122
|
+
num = op.data
|
|
123
|
+
sign, mag = int(num < 0), abs(num)
|
|
124
|
+
bw1 = ceil(log2(mag + 1))
|
|
125
|
+
bw0 = widths[op.id0]
|
|
126
|
+
s0 = int(kifs[op.id0][0])
|
|
127
|
+
v0 = f'v{op.id0}[{bw0 - 1}:0]'
|
|
128
|
+
v1 = f"{bw1}'{bin(mag)[1:]}"
|
|
129
|
+
shift = kifs[op.id0][2] - kifs[i][2]
|
|
130
|
+
|
|
131
|
+
line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, 0, {bw}, {shift}, {sign}) op_{i} ({v0}, {v1}, {v});'
|
|
132
|
+
|
|
133
|
+
case 5: # constant
|
|
134
|
+
num = op.data
|
|
135
|
+
if num < 0:
|
|
136
|
+
num = 2**bw + num
|
|
137
|
+
line = f"{_def} assign {v} = '{bin(num)[1:]};"
|
|
138
|
+
|
|
139
|
+
case 6 | -6: # MSB Muxing
|
|
140
|
+
k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
|
|
141
|
+
p0, p1 = kifs[a], kifs[b]
|
|
142
|
+
inv = '1' if op.opcode == -6 else '0'
|
|
143
|
+
bwk, bw0, bw1 = widths[k], widths[a], widths[b]
|
|
144
|
+
s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
|
|
145
|
+
_shift = (op.data >> 32) & 0xFFFFFFFF
|
|
146
|
+
_shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
|
|
147
|
+
shift = f0 - f1 + _shift
|
|
148
|
+
vk, v0, v1 = f'v{k}[{bwk - 1}]', f'v{a}[{bw0 - 1}:0]', f'v{b}[{bw1 - 1}:0]'
|
|
149
|
+
if bw0 == 0:
|
|
150
|
+
v0, bw0 = "1'b0", 1
|
|
151
|
+
if bw1 == 0:
|
|
152
|
+
v1, bw1 = "1'b0", 1
|
|
153
|
+
|
|
154
|
+
line = f'{_def} mux #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {inv}) op_{i} ({vk}, {v0}, {v1}, {v});'
|
|
155
|
+
|
|
156
|
+
case 7: # Multiplication
|
|
157
|
+
bw0, bw1 = widths[op.id0], widths[op.id1] # width
|
|
158
|
+
s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
|
|
159
|
+
v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
|
|
160
|
+
|
|
161
|
+
line = f'{_def} multiplier #({bw0}, {bw1}, {s0}, {s1}, {bw}) op_{i} ({v0}, {v1}, {v});'
|
|
162
|
+
|
|
163
|
+
case 8: # Lookup Table
|
|
164
|
+
name = get_table_name(sol, op)
|
|
165
|
+
bw0 = widths[op.id0]
|
|
166
|
+
|
|
167
|
+
line = f'{_def} lookup_table #({bw0}, {bw}, "{name}") op_{i} (v{op.id0}, {v});'
|
|
168
|
+
|
|
169
|
+
case _:
|
|
170
|
+
raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
|
|
171
|
+
|
|
172
|
+
if print_latency:
|
|
173
|
+
line += f' // {op.latency}'
|
|
174
|
+
lines.append(line)
|
|
175
|
+
return lines
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def output_gen(sol: CombLogic, neg_repo: dict[int, tuple[int, str]]) -> list[str]:
    """Drive the flat ``model_out`` bus from the SSA wires ``v{idx}``.

    Output slot ``n`` takes its value from op ``sol.out_idxs[n]``. The value is
    negated first through :func:`make_neg` when ``sol.out_negs[n]`` is set. Slots
    with a negative source index or zero width are left unassigned.

    Args:
        sol: program whose outputs are emitted.
        neg_repo: negation cache shared with :func:`ssa_gen`; updated in place.

    Returns:
        Assign statements, preceded by any negation wires they need.
    """
    out_widths = [sum(_minimal_kif(qint)) for qint in sol.out_qint]
    offsets = np.cumsum([0] + out_widths)  # start bit of each slot in model_out

    stmts: list[str] = []
    for slot, src in enumerate(sol.out_idxs):
        if src < 0:
            continue
        width = out_widths[slot]
        if width == 0:
            continue
        hi, lo = offsets[slot + 1] - 1, offsets[slot]

        if sol.out_negs[slot]:
            _, wire = make_neg(stmts, src, sol.ops[src].qint, f'v{src}', neg_repo)
        else:
            wire = f'v{src}'
        stmts.append(f'assign model_out[{hi}:{lo}] = {wire}[{width - 1}:0];')
    return stmts
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def comb_logic_gen(sol: CombLogic, fn_name: str, print_latency: bool = False, timescale: str | None = None):
    """Render a complete combinational Verilog module named ``fn_name``.

    The module has a flat input bus ``model_inp`` and output bus ``model_out``,
    whose widths are the summed minimal widths of ``sol.inp_qint`` and
    ``sol.out_qint``. The body is the SSA netlist from :func:`ssa_gen` followed
    by the output assignments from :func:`output_gen`; both share one negation
    cache so that each negated wire is declared only once.

    Args:
        sol: combinational program to render.
        fn_name: Verilog module name.
        print_latency: annotate each op statement with its latency.
        timescale: optional directive placed (followed by a blank line) before the module.

    Returns:
        The Verilog source text.
    """

    def _total_bits(qints) -> int:
        return sum(sum(_minimal_kif(q)) for q in qints)

    n_in, n_out = _total_bits(sol.inp_qint), _total_bits(sol.out_qint)

    negs: dict[int, tuple[int, str]] = {}
    body = ssa_gen(sol, neg_repo=negs, print_latency=print_latency)
    outputs = output_gen(sol, negs)

    sep = '\n    '
    header = (
        f'module {fn_name} (\n'
        f'    input [{n_in - 1}:0] model_inp,\n'
        f'    output [{n_out - 1}:0] model_out\n'
        ');'
    )
    code = f"""{header}

    // verilator lint_off UNUSEDSIGNAL
    // Explicit quantization operation will drop bits if exists

    {sep.join(body)}

    // verilator lint_on UNUSEDSIGNAL

    {sep.join(outputs)}

endmodule
"""
    if timescale is not None:
        code = f'{timescale}\n\n{code}'
    return code
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def table_mem_gen(sol: CombLogic) -> dict[str, str]:
    """Collect the ROM images of every lookup-table op in ``sol``.

    Returns:
        Mapping from content-addressed file name to ``$readmemh`` text. Ops sharing
        identical table contents collapse onto a single entry. Empty when ``sol``
        has no lookup tables.
    """
    mem_files: dict[str, str] = {}
    if not sol.lookup_tables:
        return mem_files
    for op in sol.ops:
        if op.opcode != 8:  # only lookup-table ops carry a ROM
            continue
        fname = get_table_name(sol, op)
        mem_files[fname] = gen_mem_file(sol, op)
    return mem_files
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from itertools import accumulate
|
|
2
|
+
|
|
3
|
+
from ....cmvm.types import CombLogic, Pipeline, QInterval, _minimal_kif
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def hetero_io_map(qints: list[QInterval], merge: bool = False):
    """Map each value between a regular (uniform) and a packed (heterogeneous) bit layout.

    In the regular layout every value occupies ``max_bw`` bits. All values share
    one binary point (``max_f`` fractional bits) and one integer width ``max_I``;
    the integer width includes a sign bit if any value is signed. In the packed
    layout values are concatenated back to back, each at its own minimal width.

    Args:
        qints: per-value quantization intervals, in order.
        merge: when True, fuse neighbouring data slices whose regular ranges are
            contiguous, and neighbouring pads that are contiguous and share the
            same fill source.

    Returns:
        ``(regular, hetero, pads, (width_regular, width_packed))`` where:

        - ``regular[n]``/``hetero[n]`` are matching inclusive ``(high, low)`` bit
          ranges in the regular and packed buses. Zero-width values are dropped.
        - ``pads`` lists ``(high, low, copy_from)`` ranges of the regular bus not
          covered by data. ``copy_from`` is the packed bit to replicate (sign
          extension), or -1 for zero fill.
        - the widths are the total regular and packed bus widths.
    """
    N = len(qints)
    # per value: keep-negative flag, integer bits (no sign), fractional bits
    ks, _is, fs = zip(*map(_minimal_kif, qints))
    # integer bits including the sign bit
    Is = [_i + _k for _i, _k in zip(_is, ks)]
    # common regular format: one extra sign bit if any value is signed, enough
    # integer bits for the largest value and fractional bits for the finest one
    max_I, max_f = max(_is) + max(ks), max(fs)
    max_bw = max_I + max_f
    width_regular, width_packed = max_bw * N, sum(Is) + sum(fs)

    regular: list[tuple[int, int]] = []
    pads: list[tuple[int, int, int]] = []

    # (high, low) slice of every value in the packed bus
    bws = [I + f for I, f in zip(Is, fs)]
    _bw = list(accumulate([0] + bws))
    hetero = [(i - 1, j) for i, j in zip(_bw[1:], _bw[:-1])]

    for i in range(N):
        base = max_bw * i
        # align binary points: spare fractional bits sit below the value,
        # spare integer bits above it
        bias_low = max_f - fs[i]
        bias_high = max_I - Is[i]
        low = base + bias_low
        high = (base + max_bw - 1) - bias_high
        regular.append((high, low))

        if bias_low != 0:
            pads.append((base + bias_low - 1, base, -1))  # zero-fill below the LSB
        if bias_high != 0:
            # sign-extend from the packed MSB for signed values, zero-fill otherwise
            copy_from = hetero[i][0] if ks[i] else -1
            pads.append((base + max_bw - 1, base + max_bw - bias_high, copy_from))

    # drop zero-width values (their packed slice is empty: high < low)
    mask = list(high < low for high, low in hetero)
    regular = [r for r, m in zip(regular, mask) if not m]
    hetero = [h for h, m in zip(hetero, mask) if not m]

    if not merge:
        return regular, hetero, pads, (width_regular, width_packed)

    # Merging consecutive intervals when possible
    # (walk backwards so pop(i + 1) does not disturb indices still to visit;
    # packed slices are contiguous by construction, so only the regular ranges
    # need checking)
    NN = len(regular) - 2
    for i in range(NN, -1, -1):
        this_high = regular[i][0]
        next_low = regular[i + 1][1]
        if next_low - this_high != 1:
            continue
        regular[i] = (regular[i + 1][0], regular[i][1])
        regular.pop(i + 1)
        hetero[i] = (hetero[i + 1][0], hetero[i][1])
        hetero.pop(i + 1)

    # fuse contiguous pads that are filled from the same source
    for i in range(len(pads) - 2, -1, -1):
        if pads[i + 1][1] - pads[i][0] == 1 and pads[i][2] == pads[i + 1][2]:
            pads[i] = (pads[i + 1][0], pads[i][1], pads[i][2])
            pads.pop(i + 1)

    return regular, hetero, pads, (width_regular, width_packed)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def generate_io_wrapper(sol: CombLogic | Pipeline, module_name: str, pipelined: bool = False):
    """Render ``{module_name}_wrapper``, exposing ``module_name`` with a regular I/O layout.

    The wrapped module reads and writes densely packed buses, with each value at its
    own minimal width. The wrapper exposes ``model_inp``/``model_out`` in the
    regular layout computed by :func:`hetero_io_map`. Input slices are copied into
    ``packed_inp``. ``packed_out`` is expanded back into ``model_out``, and the pad
    bits are filled with zeros or with a replicated packed sign bit.

    Args:
        sol: the combinational or pipelined design; only ``inp_qint``/``out_qint`` are read.
        module_name: name of the module being wrapped (instantiated as ``op``).
        pipelined: add a ``clk`` input and forward it to the wrapped module.

    Returns:
        The Verilog source of the wrapper module.
    """
    reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
    reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)

    w_reg_in, w_het_in = shape_in
    w_reg_out, w_het_out = shape_out

    inp_assignment = [f'assign packed_inp[{ih}:{jh}] = model_inp[{ir}:{jr}];' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]

    # (model_out high bit, statement): every entry is keyed in model_out
    # coordinates so the sort below orders statements by output bit position.
    _out_assignment: list[tuple[int, str]] = []

    for (ih, jh), (ir, jr) in zip(het_out, reg_out):
        if ih == jh - 1:  # zero-width slice
            continue
        _out_assignment.append((ir, f'assign model_out[{ir}:{jr}] = packed_out[{ih}:{jh}];'))

    for high, low, copy_from in pad_out:
        n_bit = high - low + 1
        pad = f"{n_bit}'b0" if copy_from == -1 else f'{{{n_bit}{{packed_out[{copy_from}]}}}}'
        _out_assignment.append((high, f'assign model_out[{high}:{low}] = {pad};'))
    _out_assignment.sort(key=lambda x: x[0])
    out_assignment = [v for _, v in _out_assignment]

    inp_assignment_str = '\n    '.join(inp_assignment)
    out_assignment_str = '\n    '.join(out_assignment)

    clk_and_rst_inp, clk_and_rst_bind = '', ''
    if pipelined:
        clk_and_rst_inp = '\n    input clk,'
        clk_and_rst_bind = '\n        .clk(clk),'

    return f"""`timescale 1 ns / 1 ps

module {module_name}_wrapper ({clk_and_rst_inp}
    // verilator lint_off UNUSEDSIGNAL
    input [{w_reg_in - 1}:0] model_inp,
    // verilator lint_on UNUSEDSIGNAL
    output [{w_reg_out - 1}:0] model_out
);
    wire [{w_het_in - 1}:0] packed_inp;
    wire [{w_het_out - 1}:0] packed_out;

    {inp_assignment_str}

    {module_name} op ({clk_and_rst_bind}
        .model_inp(packed_inp),
        .model_out(packed_out)
    );

    {out_assignment_str}

endmodule
"""
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from ....cmvm.types import Pipeline, _minimal_kif
|
|
2
|
+
from .comb import comb_logic_gen
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def pipeline_logic_gen(
    csol: Pipeline,
    name: str,
    print_latency: bool = False,
    timescale: str | None = '`timescale 1 ns / 1 ps',
    register_layers: int = 1,
):
    """Generate the Verilog sources of a pipelined design.

    Each stage of ``csol`` becomes a combinational module ``{name}_stage{i}``
    (via :func:`comb_logic_gen`). A top module ``name`` chains the stages. Every
    stage input passes through ``register_layers`` registers
    (``stage{i}_inp_copy*`` then ``stage{i}_inp``), and ``model_out`` is
    registered once more on ``clk``.

    Args:
        csol: the pipeline; ``csol.solutions`` are the stages in order.
        name: top-level module name; also the prefix of the stage modules.
        print_latency: forwarded to :func:`comb_logic_gen`.
        timescale: directive prepended to every generated source when truthy.
        register_layers: number of registers in front of each stage input (>= 1).

    Returns:
        Mapping from module name to Verilog source: one entry per stage plus the top module.

    Raises:
        ValueError: if ``register_layers < 1`` (the copy-register chain would
            reference undeclared registers) or the pipeline has no stages.
    """
    if register_layers < 1:
        raise ValueError(f'register_layers must be >= 1, got {register_layers}')
    N = len(csol.solutions)
    if N == 0:
        raise ValueError('pipeline has no stages')

    inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
    # stage i drives stage i + 1's input; the last stage drives model_out
    out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]

    registers = [f'reg [{width - 1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
    for i in range(0, register_layers - 1):
        registers += [f'reg [{width - 1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
    wires = [f'wire [{width - 1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]

    comb_logic = [f'{name}_stage{i} stage{i} (.model_inp(stage{i}_inp), .model_out(stage{i}_out));' for i in range(N)]

    if register_layers == 1:
        serial_logic = ['stage0_inp <= model_inp;']
        serial_logic += [f'stage{i}_inp <= stage{i - 1}_out;' for i in range(1, N)]
    else:
        # shift chain: source -> copy0 -> ... -> copy{register_layers - 2} -> stage input
        serial_logic = ['stage0_inp_copy0 <= model_inp;']
        for j in range(1, register_layers - 1):
            serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j - 1};')
        serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
        for i in range(1, N):
            serial_logic.append(f'stage{i}_inp_copy0 <= stage{i - 1}_out;')
            for j in range(1, register_layers - 1):
                serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j - 1};')
            serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')

    serial_logic += [f'model_out <= stage{N - 1}_out;']

    sep0 = '\n    '
    sep1 = '\n        '

    module = f"""module {name} (
    input clk,
    input [{inp_bits[0] - 1}:0] model_inp,
    output reg [{out_bits[-1] - 1}:0] model_out
);

    {sep0.join(registers)}
    {sep0.join(wires)}

    {sep0.join(comb_logic)}

    always @(posedge clk) begin
        {sep1.join(serial_logic)}
    end
endmodule
"""

    if timescale:
        module = f'{timescale}\n\n{module}'

    ret: dict[str, str] = {}
    for i, s in enumerate(csol.solutions):
        stage_name = f'{name}_stage{i}'
        ret[stage_name] = comb_logic_gen(s, stage_name, print_latency=print_latency, timescale=timescale)
    ret[name] = module
    return ret
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
`timescale 1ns / 1ps


// Read-only lookup table: out = lut_rom[in].
// The ROM has 2**BW_IN words of BW_OUT bits each and is initialised from the
// hex file MEM_FILE with $readmemh.
module lookup_table #(
    parameter BW_IN    = 8,
    parameter BW_OUT   = 8,
    parameter MEM_FILE = "whatever.mem"
) (
    input  [ BW_IN-1:0] in,
    output [BW_OUT-1:0] out
);

    (* rom_style = (BW_IN <= 999) ? "distributed" : "block" *)
    reg [BW_OUT-1:0] lut_rom[0:(1<<BW_IN)-1];

    initial $readmemh(MEM_FILE, lut_rom);

    // Combinational (asynchronous) read.
    assign out = lut_rom[in];

endmodule
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
`timescale 1ns / 1ps


// Combinational multiplier: out = low BW_OUT bits of in0 * in1.
// SIGNED0 / SIGNED1 select a two's-complement (1) or unsigned (0)
// interpretation of each operand.
module multiplier #(
    parameter BW_INPUT0 = 32,
    parameter BW_INPUT1 = 32,
    parameter SIGNED0   = 0,
    parameter SIGNED1   = 0,
    parameter BW_OUT    = 32
) (
    input  [BW_INPUT0-1:0] in0,
    input  [BW_INPUT1-1:0] in1,
    output [   BW_OUT-1:0] out
);

    // The full product of a BW_INPUT0-bit and a BW_INPUT1-bit operand fits in
    // BW_INPUT0 + BW_INPUT1 bits for every signedness combination.
    localparam BW_BUF = BW_INPUT0 + BW_INPUT1;

    // verilator lint_off UNUSEDSIGNAL
    wire [BW_BUF-1:0] buffer;
    // verilator lint_on UNUSEDSIGNAL

    generate
        if (SIGNED0 == 1 && SIGNED1 == 1) begin : signed_signed
            assign buffer = $signed(in0) * $signed(in1);
        end else if (SIGNED0 == 1 && SIGNED1 == 0) begin : signed_unsigned
            // Zero-extend the unsigned operand by one bit so it can enter a signed product.
            assign buffer = $signed(in0) * $signed({1'b0, in1});
        end else if (SIGNED0 == 0 && SIGNED1 == 1) begin : unsigned_signed
            assign buffer = $signed({1'b0, in0}) * $signed(in1);
        end else begin : unsigned_unsigned
            assign buffer = in0 * in1;
        end
    endgenerate

    assign out = buffer[BW_OUT-1:0];
endmodule
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
`timescale 1ns / 1ps


// Two-way selector with binary-point alignment:
//   out = key ? in0 : (INVERT1 ? -(in1 << SHIFT1) : (in1 << SHIFT1))
// A negative SHIFT1 instead shifts in0 left by -SHIFT1. Both operands are
// extended to a common BW_BUF-bit width (sign- or zero-extension per SIGNEDx)
// and the low BW_OUT bits of the selected value drive `out`.
module mux #(
    parameter BW_INPUT0 = 32,
    parameter BW_INPUT1 = 32,
    parameter SIGNED0   = 0,
    parameter SIGNED1   = 0,
    parameter BW_OUT    = 32,
    parameter SHIFT1    = 0,
    parameter INVERT1   = 0
) (
    input                  key,
    input  [BW_INPUT0-1:0] in0,
    input  [BW_INPUT1-1:0] in1,
    output [   BW_OUT-1:0] out
);

    // Width each operand occupies after alignment.
    localparam IN0_NEED_BITS = (SHIFT1 < 0) ? BW_INPUT0 - SHIFT1 : BW_INPUT0;
    localparam IN1_NEED_BITS = (SHIFT1 > 0) ? BW_INPUT1 + SHIFT1 : BW_INPUT1;
    // One guard bit when the signedness differs, one more when in1 is negated.
    localparam EXTRA_PAD = (SIGNED0 != SIGNED1) ? INVERT1 + 1 : INVERT1 + 0;
    localparam BW_BUF = ((IN0_NEED_BITS > IN1_NEED_BITS) ? IN0_NEED_BITS : IN1_NEED_BITS) + EXTRA_PAD;
    localparam IN0_PAD_LEFT = (SHIFT1 < 0) ? BW_BUF - BW_INPUT0 + SHIFT1 : BW_BUF - BW_INPUT0;
    localparam IN0_PAD_RIGHT = (SHIFT1 < 0) ? -SHIFT1 : 0;
    localparam IN1_PAD_LEFT = (SHIFT1 > 0) ? BW_BUF - BW_INPUT1 - SHIFT1 : BW_BUF - BW_INPUT1;
    localparam IN1_PAD_RIGHT = (SHIFT1 > 0) ? SHIFT1 : 0;

    // verilator lint_off UNUSEDSIGNAL
    // Bit replicated above each operand: its MSB when signed, otherwise 0.
    wire fill0 = (SIGNED0 == 1) ? in0[BW_INPUT0-1] : 1'b0;
    wire fill1 = (SIGNED1 == 1) ? in1[BW_INPUT1-1] : 1'b0;
    wire [BW_BUF-1:0] in0_ext = {{IN0_PAD_LEFT{fill0}}, in0, {IN0_PAD_RIGHT{1'b0}}};
    wire [BW_BUF-1:0] in1_ext = {{IN1_PAD_LEFT{fill1}}, in1, {IN1_PAD_RIGHT{1'b0}}};
    // verilator lint_on UNUSEDSIGNAL

    wire [BW_OUT-1:0] alt = (INVERT1 == 1) ? -in1_ext[BW_OUT-1:0] : in1_ext[BW_OUT-1:0];

    assign out = key ? in0_ext[BW_OUT-1:0] : alt;

endmodule
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
`timescale 1ns / 1ps


// Two's-complement negation: out = -in, resized to BW_OUT bits.
// IN_SIGNED selects sign- (1) or zero-extension (0) of `in` when widening.
module negative #(
    parameter BW_IN = 32,
    parameter BW_OUT = 32,
    parameter IN_SIGNED = 0
) (
    // verilator lint_off UNUSEDSIGNAL
    input  [ BW_IN-1:0] in,
    // verilator lint_on UNUSEDSIGNAL
    output [BW_OUT-1:0] out
);
    /* verilator lint_off WIDTHTRUNC */
    generate
        if (BW_IN < BW_OUT) begin : in_is_smaller
            // Widen first so the negation is computed at the output width.
            wire [BW_OUT-1:0] in_ext;
            if (IN_SIGNED == 1) begin : is_signed
                assign in_ext = {{BW_OUT - BW_IN{in[BW_IN-1]}}, in};
            end else begin : is_unsigned
                assign in_ext = {{BW_OUT - BW_IN{1'b0}}, in};
            end
            assign out = -in_ext;
        end else begin : in_is_bigger
            // Negate at the input width and keep the low BW_OUT bits; the
            // truncated result is exact whenever -in fits in BW_OUT bits.
            wire [BW_IN-1:0] out_ext;
            assign out_ext = -in;
            assign out = out_ext[BW_OUT-1:0];
        end
    endgenerate
    /* verilator lint_on WIDTHTRUNC */
endmodule
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
`timescale 1ns / 1ps


// Aligned adder/subtractor:
//   out = in0 + (in1 << SHIFT1)   (IS_SUB == 0)
//   out = in0 - (in1 << SHIFT1)   (IS_SUB == 1)
// A negative SHIFT1 instead shifts in0 left by -SHIFT1. Both operands are
// extended to BW_ADD bits (sign- or zero-extension per SIGNEDx) and the low
// BW_OUT bits of the result drive `out`.
module shift_adder #(
    parameter BW_INPUT0 = 32,
    parameter BW_INPUT1 = 32,
    parameter SIGNED0   = 0,
    parameter SIGNED1   = 0,
    parameter BW_OUT    = 32,
    parameter SHIFT1    = 0,
    parameter IS_SUB    = 0
) (
    input  [BW_INPUT0-1:0] in0,
    input  [BW_INPUT1-1:0] in1,
    output [   BW_OUT-1:0] out
);

    // Width each operand occupies after alignment.
    localparam IN0_NEED_BITS = (SHIFT1 < 0) ? BW_INPUT0 - SHIFT1 : BW_INPUT0;
    localparam IN1_NEED_BITS = (SHIFT1 > 0) ? BW_INPUT1 + SHIFT1 : BW_INPUT1;
    // One guard bit when the signedness differs, one more for subtraction.
    localparam EXTRA_PAD = (SIGNED0 != SIGNED1) ? IS_SUB + 1 : IS_SUB + 0;
    // Plus one carry bit for the addition itself.
    localparam BW_ADD = ((IN0_NEED_BITS > IN1_NEED_BITS) ? IN0_NEED_BITS : IN1_NEED_BITS) + EXTRA_PAD + 1;
    localparam IN0_PAD_LEFT = (SHIFT1 < 0) ? BW_ADD - BW_INPUT0 + SHIFT1 : BW_ADD - BW_INPUT0;
    localparam IN0_PAD_RIGHT = (SHIFT1 < 0) ? -SHIFT1 : 0;
    localparam IN1_PAD_LEFT = (SHIFT1 > 0) ? BW_ADD - BW_INPUT1 - SHIFT1 : BW_ADD - BW_INPUT1;
    localparam IN1_PAD_RIGHT = (SHIFT1 > 0) ? SHIFT1 : 0;

    // verilator lint_off UNUSEDSIGNAL
    // Bit replicated above each operand: its MSB when signed, otherwise 0.
    wire fill0 = (SIGNED0 == 1) ? in0[BW_INPUT0-1] : 1'b0;
    wire fill1 = (SIGNED1 == 1) ? in1[BW_INPUT1-1] : 1'b0;
    // verilator lint_on UNUSEDSIGNAL

    wire [BW_ADD-1:0] in0_ext = {{IN0_PAD_LEFT{fill0}}, in0, {IN0_PAD_RIGHT{1'b0}}};
    wire [BW_ADD-1:0] in1_ext = {{IN1_PAD_LEFT{fill1}}, in1, {IN1_PAD_RIGHT{1'b0}}};

    // verilator lint_off UNUSEDSIGNAL
    wire [BW_ADD-1:0] accum = (IS_SUB == 1) ? in0_ext - in1_ext : in0_ext + in1_ext;
    // verilator lint_on UNUSEDSIGNAL

    assign out = accum[BW_OUT-1:0];

endmodule
|