da4ml 0.5.0__cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +194 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +240 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.0.dist-info/METADATA +85 -0
  92. da4ml-0.5.0.dist-info/RECORD +96 -0
  93. da4ml-0.5.0.dist-info/WHEEL +6 -0
  94. da4ml-0.5.0.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
@@ -0,0 +1,206 @@
1
+ from math import ceil, log2
2
+
3
+ import numpy as np
4
+
5
+ from ....cmvm.types import CombLogic, QInterval, _minimal_kif
6
+ from ..verilog.comb import get_table_name
7
+
8
+
9
def make_neg(
    signals: list[str],
    assigns: list[str],
    idx: int,
    qint: QInterval,
    v0_name: str,
    neg_repo: dict[int, tuple[int, str]],
):
    """Emit (at most once per ``idx``) a ``negative`` instance computing ``-v0_name``.

    The declaration of the new ``v{idx}_neg`` signal is appended to ``signals``
    and the entity instantiation to ``assigns``. The result is memoised in
    ``neg_repo`` under ``idx``, so later requests for the same operand reuse
    the first instance instead of emitting another.

    Args:
        signals: VHDL signal declarations (appended to).
        assigns: VHDL architecture statements (appended to).
        idx: SSA index of the operand being negated.
        qint: quantization interval of the operand.
        v0_name: VHDL expression naming the operand.
        neg_repo: cache ``idx -> (bit_width, signal_name)``.

    Returns:
        ``(bit_width, signal_name)`` of the negated signal.
    """
    cached = neg_repo.get(idx)
    if cached is not None:
        return cached

    lo, hi, step = qint
    in_width = sum(_minimal_kif(qint))
    # The negated range is [-hi, -lo] with the same step.
    out_width = sum(_minimal_kif(QInterval(-hi, -lo, step)))
    in_signed = 1 if lo < 0 else 0
    neg_name = f'v{idx}_neg'

    signals.append(f'signal {neg_name} : std_logic_vector({out_width - 1} downto {0});')
    assigns.append(
        f'op_neg_{idx} : entity work.negative generic map (BW_IN => {in_width}, BW_OUT => {out_width}, IN_SIGNED => {in_signed}) port map (neg_in => {v0_name}, neg_out => {neg_name});'
    )

    entry = (out_width, neg_name)
    neg_repo[idx] = entry
    return entry
31
+
32
+
33
+ def ssa_gen(sol: CombLogic, neg_repo: dict[int, tuple[int, str]], print_latency: bool = False):
34
+ ops = sol.ops
35
+ kifs = list(map(_minimal_kif, (op.qint for op in ops)))
36
+ widths = list(map(sum, kifs))
37
+ inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
38
+ inp_widths = list(map(sum, inp_kifs))
39
+ _inp_widths = np.cumsum([0] + inp_widths)
40
+ inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
41
+
42
+ signals = []
43
+ assigns = []
44
+ ref_count = sol.ref_count
45
+
46
+ for i, op in enumerate(ops):
47
+ if ref_count[i] == 0:
48
+ continue
49
+
50
+ bw = widths[i]
51
+ if bw == 0:
52
+ continue
53
+
54
+ signals.append(f'signal v{i}:std_logic_vector({bw - 1} downto {0});')
55
+
56
+ match op.opcode:
57
+ case -1: # Input marker
58
+ i0, i1 = inp_idxs[op.id0]
59
+ line = f'v{i} <= model_inp({i0} downto {i1});'
60
+
61
+ case 0 | 1: # Common a+/-b<<shift oprs
62
+ p0, p1 = kifs[op.id0], kifs[op.id1]
63
+ bw0, bw1 = widths[op.id0], widths[op.id1]
64
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
65
+ shift = op.data + f0 - f1
66
+ line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{op.opcode}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
67
+
68
+ case 2 | -2: # ReLU
69
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
70
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
71
+ v0_name = f'v{op.id0}'
72
+ bw0 = widths[op.id0]
73
+ if op.opcode == -2:
74
+ bw0, v0_name = make_neg(signals, assigns, op.id0, ops[op.id0].qint, v0_name, neg_repo)
75
+ if ops[op.id0].qint.min < 0:
76
+ if bw > 1:
77
+ line = f'v{i} <= {v0_name}({i0} downto {i1}) and ({bw - 1} downto 0 => not {v0_name}({bw0 - 1}));'
78
+ else:
79
+ line = f'v{i}(0) <= {v0_name}(0) and (not {v0_name}({bw0 - 1}));'
80
+ else:
81
+ line = f'v{i} <= {v0_name}({i0} downto {i1});'
82
+
83
+ case 3 | -3: # Explicit quantization
84
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
85
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
86
+ v0_name = f'v{op.id0}'
87
+ bw0 = widths[op.id0]
88
+ if op.opcode == -3:
89
+ bw0, v0_name = make_neg(signals, assigns, op.id0, ops[op.id0].qint, v0_name, neg_repo)
90
+
91
+ if i0 >= bw0:
92
+ if op.opcode == 3:
93
+ assert ops[op.id0].qint.min < 0, f'{i}, {op.id0}'
94
+ else:
95
+ assert ops[op.id0].qint.max > 0, f'{i}, {op.id0}'
96
+
97
+ if i1 >= bw0:
98
+ v0_name = f'({i0 - i1} downto 0 => {v0_name}({bw0 - 1}))'
99
+ else:
100
+ v0_name = f'({i0 - bw0} downto 0 => {v0_name}({bw0 - 1})) & {v0_name}({bw0 - 1} downto {i1})'
101
+ line = f'v{i} <= {v0_name};'
102
+ else:
103
+ line = f'v{i} <= {v0_name}({i0} downto {i1});'
104
+
105
+ case 4: # constant addition
106
+ num = op.data
107
+ sign, mag = int(num < 0), abs(num)
108
+ bw1 = ceil(log2(mag + 1)) if mag > 0 else 1
109
+ bw0 = widths[op.id0]
110
+ s0 = int(kifs[op.id0][0])
111
+ shift = kifs[op.id0][2] - kifs[i][2]
112
+ bin_val = format(mag, f'0{bw1}b')
113
+ line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>0,BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{sign}) port map(in0=>v{op.id0},in1=>"{bin_val}",result=>v{i});'
114
+ case 5: # constant
115
+ num = op.data
116
+ if num < 0:
117
+ num = 2**bw + num
118
+ bin_val = format(num, f'0{bw}b')
119
+ line = f'v{i} <= "{bin_val}";'
120
+
121
+ case 6 | -6: # MSB Muxing
122
+ k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
123
+ p0, p1 = kifs[a], kifs[b]
124
+ inv = '1' if op.opcode == -6 else '0'
125
+ bwk, bw0, bw1 = widths[k], widths[a], widths[b]
126
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
127
+ _shift = (op.data >> 32) & 0xFFFFFFFF
128
+ _shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
129
+ shift = f0 - f1 + _shift
130
+ v0, v1 = f'v{a}', f'v{b}'
131
+ if bw0 == 0:
132
+ v0, bw0 = 'B"0"', 1
133
+ if bw1 == 0:
134
+ v1, bw1 = 'B"0"', 1
135
+ line = f'op_{i}:entity work.mux generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},INVERT1=>{inv}) port map(key=>v{k}({bwk - 1}),in0=>{v0},in1=>{v1},result=>v{i});'
136
+
137
+ case 7: # Multiplication
138
+ bw0, bw1 = widths[op.id0], widths[op.id1]
139
+ s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
140
+ line = f'op_{i}:entity work.multiplier generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
141
+
142
+ case 8: # Lookup Table
143
+ name = get_table_name(sol, op)
144
+ bw0 = widths[op.id0]
145
+ line = f'op_{i}:entity work.lookup_table generic map(BW_IN=>{bw0},BW_OUT=>{bw},MEM_FILE=>"{name}") port map(inp=>v{op.id0},outp=>v{i});'
146
+
147
+ case _:
148
+ raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
149
+
150
+ if print_latency:
151
+ line += f' -- {op.latency}'
152
+ assigns.append(line)
153
+ return signals, assigns
154
+
155
+
156
def output_gen(sol: CombLogic, neg_repo: dict[int, tuple[int, str]]):
    """Drive the slices of ``model_out`` from the SSA signals.

    Output ``n`` occupies a contiguous slice of ``model_out``, as wide as the
    minimal bit width of ``sol.out_qint[n]``.
    - Outputs with a negative source index, or with an empty slice, are skipped.
    - Outputs flagged in ``sol.out_negs`` are routed through a ``negative``
      instance cached in ``neg_repo`` (see ``make_neg``).

    Returns:
        ``(signals, assigns)``: extra signal declarations (for negations) and
        the output assignment statements.
    """
    signals = []
    assigns = []
    out_widths = [sum(_minimal_kif(q)) for q in sol.out_qint]
    offsets = np.cumsum([0] + out_widths)
    # (msb, lsb) slice of each output inside the flat model_out vector.
    slices = np.stack([offsets[1:] - 1, offsets[:-1]], axis=1)

    for n, src in enumerate(sol.out_idxs):
        if src < 0:
            continue
        msb, lsb = slices[n]
        if msb == lsb - 1:  # zero-width output slice
            continue

        width = out_widths[n]
        if sol.out_negs[n]:
            width, src_name = make_neg(signals, assigns, src, sol.ops[src].qint, f'v{src}', neg_repo)
        else:
            src_name = f'v{src}'
        assigns.append(f'model_out({msb} downto {lsb}) <= {src_name}({width - 1} downto {0});')

    return signals, assigns
175
+
176
+
177
def comb_logic_gen(sol: CombLogic, fn_name: str, print_latency: bool = False, timescale: str | None = None):
    """Render ``sol`` as a standalone combinational VHDL entity named ``fn_name``.

    The entity exposes flat ``model_inp`` and ``model_out`` ports. Their widths
    are the summed minimal bit widths of ``sol.inp_qint`` and ``sol.out_qint``.
    ``timescale`` is accepted but not used by this VHDL generator.

    Returns:
        The VHDL source text.
    """

    def total_bits(qints):
        return sum(sum(_minimal_kif(q)) for q in qints)

    inp_bits = total_bits(sol.inp_qint)
    out_bits = total_bits(sol.out_qint)

    # A single negation cache lets output_gen reuse negations emitted by ssa_gen.
    neg_repo: dict[int, tuple[int, str]] = {}
    ssa_signals, ssa_assigns = ssa_gen(sol, neg_repo=neg_repo, print_latency=print_latency)
    output_signals, output_assigns = output_gen(sol, neg_repo)

    sep = '\n '
    declarations = sep.join([*ssa_signals, *output_signals])
    statements = sep.join([*ssa_assigns, *output_assigns])

    return f"""library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

entity {fn_name} is port(
model_inp:in std_logic_vector({inp_bits - 1} downto {0});
model_out:out std_logic_vector({out_bits - 1} downto {0})
);
end entity {fn_name};

architecture rtl of {fn_name} is
{declarations}


begin
{statements}

end architecture rtl;

"""
@@ -0,0 +1,120 @@
1
+ from itertools import accumulate
2
+
3
+ from ....cmvm.types import CombLogic, Pipeline, QInterval, _minimal_kif
4
+
5
+
6
+ def _loc(i: int, j: int):
7
+ return f'({i} downto {j})' if i != j else f'({i})'
8
+
9
+
10
+ def hetero_io_map(qints: list[QInterval], merge: bool = False):
11
+ N = len(qints)
12
+ ks, _is, fs = zip(*map(_minimal_kif, qints))
13
+ Is = [_i + _k for _i, _k in zip(_is, ks)]
14
+ max_I, max_f = max(_is) + max(ks), max(fs)
15
+ max_bw = max_I + max_f
16
+ width_regular, width_packed = max_bw * N, sum(Is) + sum(fs)
17
+
18
+ regular: list[tuple[int, int]] = []
19
+ pads: list[tuple[int, int, int]] = []
20
+
21
+ bws = [I + f for I, f in zip(Is, fs)]
22
+ _bw = list(accumulate([0] + bws))
23
+ hetero = [(i - 1, j) for i, j in zip(_bw[1:], _bw[:-1])]
24
+
25
+ for i in range(N):
26
+ base = max_bw * i
27
+ bias_low = max_f - fs[i]
28
+ bias_high = max_I - Is[i]
29
+ low = base + bias_low
30
+ high = (base + max_bw - 1) - bias_high
31
+ regular.append((high, low))
32
+
33
+ if bias_low != 0:
34
+ pads.append((base + bias_low - 1, base, -1))
35
+ if bias_high != 0:
36
+ copy_from = hetero[i][0] if ks[i] else -1
37
+ pads.append((base + max_bw - 1, base + max_bw - bias_high, copy_from))
38
+
39
+ mask = list(high < low for high, low in hetero)
40
+ regular = [r for r, m in zip(regular, mask) if not m]
41
+ hetero = [h for h, m in zip(hetero, mask) if not m]
42
+
43
+ if not merge:
44
+ return regular, hetero, pads, (width_regular, width_packed)
45
+
46
+ # Merging consecutive intervals when possible
47
+ NN = len(regular) - 2
48
+ for i in range(NN, -1, -1):
49
+ this_high = regular[i][0]
50
+ next_low = regular[i + 1][1]
51
+ if next_low - this_high != 1:
52
+ continue
53
+ regular[i] = (regular[i + 1][0], regular[i][1])
54
+ regular.pop(i + 1)
55
+ hetero[i] = (hetero[i + 1][0], hetero[i][1])
56
+ hetero.pop(i + 1)
57
+
58
+ for i in range(len(pads) - 2, -1, -1):
59
+ if pads[i + 1][1] - pads[i][0] == 1 and pads[i][2] == pads[i + 1][2]:
60
+ pads[i] = (pads[i + 1][0], pads[i][1], pads[i][2])
61
+ pads.pop(i + 1)
62
+
63
+ return regular, hetero, pads, (width_regular, width_packed)
64
+
65
+
66
def generate_io_wrapper(sol: CombLogic | Pipeline, module_name: str, pipelined: bool = False):
    """Generate ``{module_name}_wrapper``, converting regular-layout I/O to the packed layout.

    The wrapper does three things (layouts as defined by ``hetero_io_map``):
    - slices ``model_inp`` (regular layout) into ``packed_inp``;
    - instantiates ``work.{module_name}``;
    - spreads ``packed_out`` back into ``model_out`` (regular layout), filling
      the padding bits with zeros or sign-extension.

    Args:
        sol: design providing ``inp_qint`` and ``out_qint``.
        module_name: name of the wrapped entity.
        pipelined: add a ``clk`` port and forward it to the wrapped entity.

    Returns:
        The VHDL source text of the wrapper entity.
    """
    reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
    reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)

    w_reg_in, w_het_in = shape_in
    w_reg_out, w_het_out = shape_out

    # Input padding bits are simply left unused.
    inp_assignment = [f'packed_inp{_loc(ih, jh)} <= model_inp{_loc(ir, jr)};' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]

    # (msb in model_out, statement): every key lives in model_out's index space,
    # so sorting lists the statements in model_out bit order.
    _out_assignment: list[tuple[int, str]] = []

    for (ih, jh), (ir, jr) in zip(het_out, reg_out):
        if ih == jh - 1:  # empty slice
            continue
        _out_assignment.append((ir, f'model_out{_loc(ir, jr)} <= packed_out{_loc(ih, jh)};'))

    for hi, lo, copy_from in pad_out:
        n_bit = hi - lo + 1
        value = "'0'" if copy_from == -1 else f'packed_out({copy_from})'
        pad = f'(others => {value})' if n_bit > 1 else value
        _out_assignment.append((hi, f'model_out{_loc(hi, lo)} <= {pad};'))
    _out_assignment.sort(key=lambda x: x[0])
    out_assignment = [v for _, v in _out_assignment]

    inp_assignment_str = '\n '.join(inp_assignment)
    out_assignment_str = '\n '.join(out_assignment)

    clk_and_rst_inp, clk_and_rst_bind = '', ''
    if pipelined:
        clk_and_rst_inp = '\n clk:in std_logic;'
        clk_and_rst_bind = '\n clk=>clk,'

    return f"""library ieee;
use ieee.std_logic_1164.all;
entity {module_name}_wrapper is port({clk_and_rst_inp}
model_inp:in std_logic_vector({w_reg_in - 1} downto {0});
model_out:out std_logic_vector({w_reg_out - 1} downto {0})
);
end entity {module_name}_wrapper;

architecture rtl of {module_name}_wrapper is
signal packed_inp:std_logic_vector({w_het_in - 1} downto {0});
signal packed_out:std_logic_vector({w_het_out - 1} downto {0});

begin
{inp_assignment_str}

op:entity work.{module_name} port map({clk_and_rst_bind}
model_inp=>packed_inp,
model_out=>packed_out
);

{out_assignment_str}

end architecture rtl;
"""
@@ -0,0 +1,71 @@
1
+ from ....cmvm.types import Pipeline, _minimal_kif
2
+ from .comb import comb_logic_gen
3
+
4
+
5
+ def pipeline_logic_gen(
6
+ csol: Pipeline,
7
+ name: str,
8
+ print_latency=False,
9
+ timescale: str | None = None,
10
+ register_layers: int = 1,
11
+ ):
12
+ N = len(csol.solutions)
13
+ inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
14
+ out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
15
+
16
+ registers = [f'signal stage{i}_inp:std_logic_vector({width - 1} downto 0);' for i, width in enumerate(inp_bits)]
17
+ for i in range(0, register_layers - 1):
18
+ registers += [f'signal stage{j}_inp_copy{i}:std_logic_vector({width - 1} downto 0);' for j, width in enumerate(inp_bits)]
19
+ wires = [f'signal stage{i}_out:std_logic_vector({width - 1} downto 0);' for i, width in enumerate(out_bits)]
20
+
21
+ comb_logic = [
22
+ f'stage{i}:entity work.{name}_stage{i} port map(model_inp=>stage{i}_inp,model_out=>stage{i}_out);' for i in range(N)
23
+ ]
24
+
25
+ if register_layers == 1:
26
+ serial_logic = ['stage0_inp <= model_inp;']
27
+ serial_logic += [f'stage{i}_inp <= stage{i - 1}_out;' for i in range(1, N)]
28
+ else:
29
+ serial_logic = ['stage0_inp_copy0 <= model_inp;']
30
+ for j in range(1, register_layers - 1):
31
+ serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j - 1};')
32
+ serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
33
+ for i in range(1, N):
34
+ serial_logic.append(f'stage{i}_inp_copy0 <= stage{i - 1}_out;')
35
+ for j in range(1, register_layers - 1):
36
+ serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j - 1};')
37
+ serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
38
+
39
+ serial_logic += [f'model_out <= stage{N - 1}_out;']
40
+
41
+ blk = '\n '
42
+
43
+ module = f"""library ieee;
44
+ use ieee.std_logic_1164.all;
45
+ entity {name} is port(
46
+ clk:in std_logic;
47
+ model_inp:in std_logic_vector({inp_bits[0] - 1} downto 0);
48
+ model_out:out std_logic_vector({out_bits[-1] - 1} downto 0));
49
+ end entity {name};
50
+
51
+ architecture rtl of {name} is
52
+ {blk.join(registers)}
53
+ {blk.join(wires)}
54
+
55
+ begin
56
+ {blk.join(comb_logic)}
57
+
58
+ process(clk) begin
59
+ if rising_edge(clk) then
60
+ {f'{blk} '.join(serial_logic)}
61
+ end if;
62
+ end process;
63
+ end architecture rtl;
64
+ """
65
+
66
+ ret: dict[str, str] = {}
67
+ for i, s in enumerate(csol.solutions):
68
+ stage_name = f'{name}_stage{i}'
69
+ ret[stage_name] = comb_logic_gen(s, stage_name, print_latency=print_latency, timescale=timescale)
70
+ ret[name] = module
71
+ return ret
@@ -0,0 +1,52 @@
1
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;
use std.textio.all;
use ieee.std_logic_textio.all;

-- Combinational (unclocked) ROM lookup: outp is the word stored at address inp.
-- The contents are loaded during elaboration from MEM_FILE, one hex word per line.
entity lookup_table is
    generic (
        BW_IN    : positive := 8;             -- address width; ROM depth is 2**BW_IN
        BW_OUT   : positive := 8;             -- data word width
        MEM_FILE : string   := "whatever.mem" -- path of the hex initialisation file
    );
    port (
        inp  : in  std_logic_vector(BW_IN - 1 downto 0);
        outp : out std_logic_vector(BW_OUT - 1 downto 0)
    );
end entity lookup_table;

architecture rtl of lookup_table is
    subtype rom_index_t is natural range 0 to (2 ** BW_IN) - 1;
    type rom_array_t is array (rom_index_t) of std_logic_vector(BW_OUT - 1 downto 0);

    -- Load the ROM contents from an external hex file.
    -- Words past the end of the file stay zero; lines past the ROM depth are ignored.
    impure function init_rom return rom_array_t is
        file rom_file     : text;
        variable rom_data : rom_array_t := (others => (others => '0'));
        variable line_in  : line;
        variable idx      : integer := 0;
        -- NOTE(review): data_val is declared but never used.
        variable data_val : std_logic_vector(BW_OUT - 1 downto 0);
        -- hread consumes whole hex digits, so read into a buffer rounded up to a multiple of 4 bits.
        variable temp_val : std_logic_vector(((BW_OUT + 3) / 4) * 4 - 1 downto 0);
    begin
        file_open(rom_file, MEM_FILE, read_mode);

        while not endfile(rom_file) loop
            exit when idx > rom_index_t'high;
            readline(rom_file, line_in);
            hread(line_in, temp_val);
            -- Keep only the low BW_OUT bits of the (digit-aligned) value.
            rom_data(idx) := temp_val(BW_OUT - 1 downto 0);
            idx := idx + 1;
        end loop;

        file_close(rom_file);
        return rom_data;
    end function init_rom;

    signal ROM_CONTENTS : rom_array_t := init_rom;

    -- Synthesis hint (vendor attribute, e.g. Vivado's rom_style): implement the
    -- ROM as distributed logic.
    attribute rom_style : string;
    attribute rom_style of ROM_CONTENTS : signal is "distributed";
begin
    outp <= ROM_CONTENTS(to_integer(unsigned(inp)));
end architecture rtl;
@@ -0,0 +1,40 @@
1
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Combinational multiplier: result is the low BW_OUT bits of in0 * in1.
-- Each operand is read as signed (SIGNEDx = 1) or unsigned (SIGNEDx = 0).
entity multiplier is
    generic (
        BW_INPUT0 : integer := 32;
        BW_INPUT1 : integer := 32;
        SIGNED0   : integer := 0;  -- 1: in0 is two's complement
        SIGNED1   : integer := 0;  -- 1: in1 is two's complement
        BW_OUT    : integer := 32  -- NOTE(review): must not exceed BW_INPUT0 + BW_INPUT1 (slice below)
    );
    port (
        in0    : in  std_logic_vector(BW_INPUT0-1 downto 0);
        in1    : in  std_logic_vector(BW_INPUT1-1 downto 0);
        result : out std_logic_vector(BW_OUT-1 downto 0)
    );
end entity multiplier;

architecture rtl of multiplier is
    -- Width of the full-precision product.
    constant BW_BUF : integer := BW_INPUT0 + BW_INPUT1;
    signal mult_buffer : std_logic_vector(BW_BUF-1 downto 0);
begin

    gen_mult : process(in0, in1)
    begin
        if SIGNED0 = 1 and SIGNED1 = 1 then
            mult_buffer <= std_logic_vector(resize(signed(in0) * signed(in1), BW_BUF));
        elsif SIGNED0 = 1 and SIGNED1 = 0 then
            -- Mixed signedness: prefix the unsigned operand with '0' so it can
            -- be multiplied as a non-negative signed value.
            mult_buffer <= std_logic_vector(resize(signed(in0) * signed('0' & in1), BW_BUF));
        elsif SIGNED0 = 0 and SIGNED1 = 1 then
            mult_buffer <= std_logic_vector(resize(signed('0' & in0) * signed(in1), BW_BUF));
        else
            mult_buffer <= std_logic_vector(resize(unsigned(in0) * unsigned(in1), BW_BUF));
        end if;
    end process;

    -- Truncate the product to the requested output width.
    result <= mult_buffer(BW_OUT-1 downto 0);

end architecture rtl;
@@ -0,0 +1,102 @@
1
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Two-input multiplexer with operand alignment.
--   result = in0_ext when key = '1'
--   result = in1_ext when key /= '1' (or -in1_ext when INVERT1 = 1)
-- Both inputs are sign- or zero-extended (per SIGNED0/SIGNED1) to a common
-- width BW_BUF. SHIFT1 > 0 shifts in1 left; SHIFT1 < 0 shifts in0 left by
-- -SHIFT1. The low BW_OUT bits of the selected value are returned.
entity mux is
    generic (
        BW_INPUT0 : integer := 32;
        BW_INPUT1 : integer := 32;
        SIGNED0   : integer := 0;  -- 1: in0 is two's complement
        SIGNED1   : integer := 0;  -- 1: in1 is two's complement
        BW_OUT    : integer := 32;
        SHIFT1    : integer := 0;  -- relative left shift of in1 versus in0
        INVERT1   : integer := 0   -- 1: select -in1 instead of in1
    );
    port (
        key    : in  std_logic;
        in0    : in  std_logic_vector(BW_INPUT0-1 downto 0);
        in1    : in  std_logic_vector(BW_INPUT1-1 downto 0);
        result : out std_logic_vector(BW_OUT-1 downto 0)
    );
end entity mux;

architecture rtl of mux is
    function max(L, R: integer) return integer is
    begin
        if L > R then
            return L;
        else
            return R;
        end if;
    end function;

    function if_then_else(cond: boolean; val_true: integer; val_false: integer) return integer is
    begin
        if cond then
            return val_true;
        else
            return val_false;
        end if;
    end function;

    -- Width each operand occupies after alignment by SHIFT1.
    constant IN0_NEED_BITS : integer := if_then_else(SHIFT1 < 0, BW_INPUT0 - SHIFT1, BW_INPUT0);
    constant IN1_NEED_BITS : integer := if_then_else(SHIFT1 > 0, BW_INPUT1 + SHIFT1, BW_INPUT1);
    -- Headroom: one bit when in1 may be negated, plus one when the operands' signedness differs.
    constant EXTRA_PAD     : integer := if_then_else(SIGNED0 /= SIGNED1, INVERT1 + 1, INVERT1);
    constant BW_BUF        : integer := max(IN0_NEED_BITS, IN1_NEED_BITS) + EXTRA_PAD;

    signal in0_ext : std_logic_vector(BW_BUF-1 downto 0);
    signal in1_ext : std_logic_vector(BW_BUF-1 downto 0);
    signal out_buf : std_logic_vector(BW_BUF-1 downto 0);

begin

    -- Extension and shifting for input 0.
    -- Shifts use numeric_std.shift_left on signed/unsigned, which zero-fills from
    -- the right. "sll" applied directly to std_logic_vector is only predefined
    -- from VHDL-2008 onward.
    gen_in0_shift_neg: if SHIFT1 < 0 generate
        gen_in0_signed: if SIGNED0 = 1 generate
            in0_ext <= std_logic_vector(shift_left(resize(signed(in0), BW_BUF), -SHIFT1));
        end generate;
        gen_in0_unsigned: if SIGNED0 = 0 generate
            in0_ext <= std_logic_vector(shift_left(resize(unsigned(in0), BW_BUF), -SHIFT1));
        end generate;
    end generate;

    gen_in0_shift_pos: if SHIFT1 >= 0 generate
        gen_in0_signed: if SIGNED0 = 1 generate
            in0_ext <= std_logic_vector(resize(signed(in0), BW_BUF));
        end generate;
        gen_in0_unsigned: if SIGNED0 = 0 generate
            in0_ext <= std_logic_vector(resize(unsigned(in0), BW_BUF));
        end generate;
    end generate;

    -- Extension and shifting for input 1
    gen_in1_shift_pos: if SHIFT1 > 0 generate
        gen_in1_signed: if SIGNED1 = 1 generate
            in1_ext <= std_logic_vector(shift_left(resize(signed(in1), BW_BUF), SHIFT1));
        end generate;
        gen_in1_unsigned: if SIGNED1 = 0 generate
            in1_ext <= std_logic_vector(shift_left(resize(unsigned(in1), BW_BUF), SHIFT1));
        end generate;
    end generate;

    gen_in1_shift_neg: if SHIFT1 <= 0 generate
        gen_in1_signed: if SIGNED1 = 1 generate
            in1_ext <= std_logic_vector(resize(signed(in1), BW_BUF));
        end generate;
        gen_in1_unsigned: if SIGNED1 = 0 generate
            in1_ext <= std_logic_vector(resize(unsigned(in1), BW_BUF));
        end generate;
    end generate;

    -- Mux logic
    gen_invert: if INVERT1 = 1 generate
        out_buf <= in0_ext when key = '1' else std_logic_vector(-signed(in1_ext));
    end generate;

    gen_no_invert: if INVERT1 = 0 generate
        out_buf <= in0_ext when key = '1' else in1_ext;
    end generate;

    result <= out_buf(BW_OUT-1 downto 0);

end architecture rtl;
@@ -0,0 +1,35 @@
1
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Two's-complement negation: neg_out = -neg_in, computed in BW_OUT bits.
-- When BW_OUT > BW_IN, the input is first sign- or zero-extended (per IN_SIGNED).
-- Otherwise only the low BW_OUT bits of the input are negated.
entity negative is
    generic (
        BW_IN     : integer := 32;
        BW_OUT    : integer := 32;
        IN_SIGNED : integer := 0   -- 1: neg_in is two's complement, 0: unsigned
    );
    port (
        neg_in  : in  std_logic_vector(BW_IN-1 downto 0);
        neg_out : out std_logic_vector(BW_OUT-1 downto 0)
    );
end entity negative;

architecture rtl of negative is
    -- Input extended to the output width (used only when BW_IN < BW_OUT).
    signal in_ext : std_logic_vector(BW_OUT-1 downto 0);
begin

    -- Widening: extend according to the input's signedness, then negate.
    gen_lt : if BW_IN < BW_OUT generate
        gen_signed : if IN_SIGNED = 1 generate
            in_ext <= std_logic_vector(resize(signed(neg_in), BW_OUT));
        end generate;
        gen_unsigned : if IN_SIGNED = 0 generate
            in_ext <= std_logic_vector(resize(unsigned(neg_in), BW_OUT));
        end generate;
        neg_out <= std_logic_vector(-signed(in_ext));
    end generate;

    -- Same width or narrowing: negate the low BW_OUT bits (result is modulo 2**BW_OUT).
    gen_ge : if BW_IN >= BW_OUT generate
        neg_out <= std_logic_vector(-signed(neg_in(BW_OUT-1 downto 0)));
    end generate;

end architecture rtl;