da4ml-0.3.3-py3-none-any.whl → da4ml-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. See the registry page for more details.

Files changed (55)
  1. da4ml/_version.py +2 -2
  2. da4ml/codegen/__init__.py +4 -7
  3. da4ml/codegen/hls/__init__.py +4 -0
  4. da4ml/codegen/{cpp/cpp_codegen.py → hls/hls_codegen.py} +19 -12
  5. da4ml/codegen/{cpp → hls}/hls_model.py +7 -7
  6. da4ml/codegen/rtl/__init__.py +15 -0
  7. da4ml/codegen/{verilog/source → rtl/common_source}/binder_util.hh +4 -4
  8. da4ml/codegen/{verilog/source → rtl/common_source}/build_binder.mk +7 -1
  9. da4ml/codegen/{verilog/source → rtl/common_source}/build_prj.tcl +28 -7
  10. da4ml/codegen/{verilog/verilog_model.py → rtl/rtl_model.py} +87 -16
  11. da4ml/codegen/{verilog → rtl/verilog}/__init__.py +0 -2
  12. da4ml/codegen/{verilog → rtl/verilog}/comb.py +32 -34
  13. da4ml/codegen/{verilog → rtl/verilog}/io_wrapper.py +8 -8
  14. da4ml/codegen/{verilog → rtl/verilog}/pipeline.py +10 -10
  15. da4ml/codegen/{verilog → rtl/verilog}/source/negative.v +2 -1
  16. da4ml/codegen/rtl/vhdl/__init__.py +10 -0
  17. da4ml/codegen/rtl/vhdl/comb.py +192 -0
  18. da4ml/codegen/rtl/vhdl/io_wrapper.py +157 -0
  19. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  20. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  21. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  22. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  23. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  24. da4ml/codegen/rtl/vhdl/source/template.xdc +32 -0
  25. {da4ml-0.3.3.dist-info → da4ml-0.4.0.dist-info}/METADATA +2 -2
  26. da4ml-0.4.0.dist-info/RECORD +76 -0
  27. da4ml/codegen/cpp/__init__.py +0 -4
  28. da4ml-0.3.3.dist-info/RECORD +0 -66
  29. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_binary.h +0 -0
  30. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_common.h +0 -0
  31. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_decl.h +0 -0
  32. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed.h +0 -0
  33. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed_base.h +0 -0
  34. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed_ref.h +0 -0
  35. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed_special.h +0 -0
  36. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_int.h +0 -0
  37. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_int_base.h +0 -0
  38. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_int_ref.h +0 -0
  39. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_int_special.h +0 -0
  40. /da4ml/codegen/{cpp → hls}/source/ap_types/ap_shift_reg.h +0 -0
  41. /da4ml/codegen/{cpp → hls}/source/ap_types/etc/ap_private.h +0 -0
  42. /da4ml/codegen/{cpp → hls}/source/ap_types/hls_math.h +0 -0
  43. /da4ml/codegen/{cpp → hls}/source/ap_types/hls_stream.h +0 -0
  44. /da4ml/codegen/{cpp → hls}/source/ap_types/utils/x_hls_utils.h +0 -0
  45. /da4ml/codegen/{cpp → hls}/source/binder_util.hh +0 -0
  46. /da4ml/codegen/{cpp → hls}/source/build_binder.mk +0 -0
  47. /da4ml/codegen/{cpp → hls}/source/vitis_bitshift.hh +0 -0
  48. /da4ml/codegen/{verilog/source → rtl/common_source}/ioutil.hh +0 -0
  49. /da4ml/codegen/{verilog/source → rtl/common_source}/template.xdc +0 -0
  50. /da4ml/codegen/{verilog → rtl/verilog}/source/multiplier.v +0 -0
  51. /da4ml/codegen/{verilog → rtl/verilog}/source/mux.v +0 -0
  52. /da4ml/codegen/{verilog → rtl/verilog}/source/shift_adder.v +0 -0
  53. {da4ml-0.3.3.dist-info → da4ml-0.4.0.dist-info}/WHEEL +0 -0
  54. {da4ml-0.3.3.dist-info → da4ml-0.4.0.dist-info}/licenses/LICENSE +0 -0
  55. {da4ml-0.3.3.dist-info → da4ml-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  from itertools import accumulate
2
2
 
3
- from ...cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
3
+ from ....cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
4
4
 
5
5
 
6
6
  def hetero_io_map(qints: list[QInterval], merge: bool = False):
@@ -66,18 +66,18 @@ def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipe
66
66
  w_reg_in, w_het_in = shape_in
67
67
  w_reg_out, w_het_out = shape_out
68
68
 
69
- inp_assignment = [f'assign packed_inp[{ih}:{jh}] = inp[{ir}:{jr}];' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]
69
+ inp_assignment = [f'assign packed_inp[{ih}:{jh}] = model_inp[{ir}:{jr}];' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]
70
70
  _out_assignment: list[tuple[int, str]] = []
71
71
 
72
72
  for i, ((ih, jh), (ir, jr)) in enumerate(zip(het_out, reg_out)):
73
73
  if ih == jh - 1:
74
74
  continue
75
- _out_assignment.append((ih, f'assign out[{ir}:{jr}] = packed_out[{ih}:{jh}];'))
75
+ _out_assignment.append((ih, f'assign model_out[{ir}:{jr}] = packed_out[{ih}:{jh}];'))
76
76
 
77
77
  for i, (i, j, copy_from) in enumerate(pad_out):
78
78
  n_bit = i - j + 1
79
79
  pad = f"{n_bit}'b0" if copy_from == -1 else f'{{{n_bit}{{packed_out[{copy_from}]}}}}'
80
- _out_assignment.append((i, f'assign out[{i}:{j}] = {pad};'))
80
+ _out_assignment.append((i, f'assign model_out[{i}:{j}] = {pad};'))
81
81
  _out_assignment.sort(key=lambda x: x[0])
82
82
  out_assignment = [v for _, v in _out_assignment]
83
83
 
@@ -93,9 +93,9 @@ def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipe
93
93
 
94
94
  module {module_name}_wrapper ({clk_and_rst_inp}
95
95
  // verilator lint_off UNUSEDSIGNAL
96
- input [{w_reg_in - 1}:0] inp,
96
+ input [{w_reg_in - 1}:0] model_inp,
97
97
  // verilator lint_on UNUSEDSIGNAL
98
- output [{w_reg_out - 1}:0] out
98
+ output [{w_reg_out - 1}:0] model_out
99
99
  );
100
100
  wire [{w_het_in - 1}:0] packed_inp;
101
101
  wire [{w_het_out - 1}:0] packed_out;
@@ -103,8 +103,8 @@ module {module_name}_wrapper ({clk_and_rst_inp}
103
103
  {inp_assignment_str}
104
104
 
105
105
  {module_name} op ({clk_and_rst_bind}
106
- .inp(packed_inp),
107
- .out(packed_out)
106
+ .model_inp(packed_inp),
107
+ .model_out(packed_out)
108
108
  );
109
109
 
110
110
  {out_assignment_str}
@@ -1,4 +1,4 @@
1
- from ...cmvm.types import CascadedSolution, _minimal_kif
1
+ from ....cmvm.types import CascadedSolution, _minimal_kif
2
2
  from .comb import comb_logic_gen
3
3
 
4
4
 
@@ -13,18 +13,18 @@ def pipeline_logic_gen(
13
13
  inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
14
14
  out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
15
15
 
16
- registers = [f'reg [{width}-1:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
16
+ registers = [f'reg [{width-1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
17
17
  for i in range(0, register_layers - 1):
18
- registers += [f'reg [{width}-1:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
19
- wires = [f'wire [{width}-1:0] stage{i}_out;' for i, width in enumerate(out_bits)]
18
+ registers += [f'reg [{width-1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
19
+ wires = [f'wire [{width-1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
20
20
 
21
- comb_logic = [f'{name}_stage{i} stage{i} (.inp(stage{i}_inp), .out(stage{i}_out));' for i in range(N)]
21
+ comb_logic = [f'{name}_stage{i} stage{i} (.model_inp(stage{i}_inp), .model_out(stage{i}_out));' for i in range(N)]
22
22
 
23
23
  if register_layers == 1:
24
- serial_logic = ['stage0_inp <= inp;']
24
+ serial_logic = ['stage0_inp <= model_inp;']
25
25
  serial_logic += [f'stage{i}_inp <= stage{i-1}_out;' for i in range(1, N)]
26
26
  else:
27
- serial_logic = ['stage0_inp_copy0 <= inp;']
27
+ serial_logic = ['stage0_inp_copy0 <= model_inp;']
28
28
  for j in range(1, register_layers - 1):
29
29
  serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j-1};')
30
30
  serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
@@ -34,15 +34,15 @@ def pipeline_logic_gen(
34
34
  serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};')
35
35
  serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
36
36
 
37
- serial_logic += [f'out <= stage{N-1}_out;']
37
+ serial_logic += [f'model_out <= stage{N-1}_out;']
38
38
 
39
39
  sep0 = '\n '
40
40
  sep1 = '\n '
41
41
 
42
42
  module = f"""module {name} (
43
43
  input clk,
44
- input [{inp_bits[0]-1}:0] inp,
45
- output reg [{out_bits[-1]-1}:0] out
44
+ input [{inp_bits[0]-1}:0] model_inp,
45
+ output reg [{out_bits[-1]-1}:0] model_out
46
46
  );
47
47
 
48
48
  {sep0.join(registers)}
@@ -11,6 +11,7 @@ module negative #(
11
11
  // verilator lint_off UNUSEDSIGNAL
12
12
  output [BW_OUT-1:0] out
13
13
  );
14
+ /* verilator lint_off WIDTHTRUNC */
14
15
  generate
15
16
  if (BW_IN < BW_OUT) begin : in_is_smaller
16
17
  wire [BW_OUT-1:0] in_ext;
@@ -24,5 +25,5 @@ module negative #(
24
25
  assign out = -in[BW_OUT-1:0];
25
26
  end
26
27
  endgenerate
27
-
28
+ /* verilator lint_on WIDTHTRUNC */
28
29
  endmodule
@@ -0,0 +1,10 @@
1
+ from .comb import comb_logic_gen
2
+ from .io_wrapper import binder_gen, generate_io_wrapper
3
+ from .pipeline import pipeline_logic_gen
4
+
5
+ __all__ = [
6
+ 'comb_logic_gen',
7
+ 'generate_io_wrapper',
8
+ 'pipeline_logic_gen',
9
+ 'binder_gen',
10
+ ]
@@ -0,0 +1,192 @@
1
+ from math import ceil, log2
2
+
3
+ import numpy as np
4
+
5
+ from ....cmvm.types import Op, QInterval, Solution, _minimal_kif
6
+
7
+
8
def make_neg(
    signals: list[str],
    assigns: list[str],
    op: Op,
    ops: list[Op],
    bw0: int,
    v0_name: str,
):
    """Declare and instantiate a negated copy of operand ``v{op.id0}``.

    Appends the ``v{op.id0}_neg`` signal declaration to ``signals`` and a
    ``work.negative`` entity instance to ``assigns``.

    Args:
        signals: architecture declaration lines; extended in place.
        assigns: architecture statement lines; extended in place.
        op: the op consuming the negated operand (its ``id0`` names the source).
        ops: all ops of the solution; ``ops[op.id0].qint`` sizes the result.
        bw0: bit width of the source signal.
        v0_name: name of the source signal fed into ``neg_in``.

    Returns:
        ``(width, name)`` of the negated signal. The width is the larger of
        ``bw0`` and the minimal width of the negated interval.
    """
    src = op.id0
    lo, hi, step = ops[src].qint
    # The negated interval is [-max, -min] on the same step grid.
    neg_width = max(sum(_minimal_kif(QInterval(-hi, -lo, step))), bw0)
    neg_name = f'v{src}_neg'
    signals.append(f'signal {neg_name} : std_logic_vector({neg_width-1} downto 0);')
    assigns.append(
        f'op_neg_{src} : entity work.negative generic map (BW_IN => {bw0}, BW_OUT => {neg_width}, IN_SIGNED => {int(lo < 0)}) '
        f'port map (neg_in => {v0_name}, neg_out => {neg_name});'
    )
    return neg_width, neg_name
26
+
27
+
28
+ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
29
+ ops = sol.ops
30
+ kifs = list(map(_minimal_kif, (op.qint for op in ops)))
31
+ widths = list(map(sum, kifs))
32
+ inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
33
+ inp_widths = list(map(sum, inp_kifs))
34
+ _inp_widths = np.cumsum([0] + inp_widths)
35
+ inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
36
+
37
+ signals = []
38
+ assigns = []
39
+ ref_count = sol.ref_count
40
+
41
+ for i, op in enumerate(ops):
42
+ if ref_count[i] == 0:
43
+ continue
44
+
45
+ bw = widths[i]
46
+ if bw == 0:
47
+ continue
48
+
49
+ match op.opcode:
50
+ case -1: # Input marker
51
+ i0, i1 = inp_idxs[op.id0]
52
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
53
+ line = f'v{i} <= model_inp({i0} downto {i1});'
54
+
55
+ case 0 | 1: # Common a+/-b<<shift oprs
56
+ p0, p1 = kifs[op.id0], kifs[op.id1]
57
+ bw0, bw1 = widths[op.id0], widths[op.id1]
58
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
59
+ shift = op.data + f0 - f1
60
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
61
+ line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{op.opcode}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
62
+
63
+ case 2 | -2: # ReLU
64
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
65
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
66
+ v0_name = f'v{op.id0}'
67
+ bw0 = widths[op.id0]
68
+ if op.opcode == -2 and op.id0 not in neg_defined:
69
+ neg_defined.add(op.id0)
70
+ bw0, v0_name = make_neg(signals, assigns, op, ops, bw0, v0_name)
71
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
72
+ if ops[op.id0].qint.min < 0:
73
+ if bw > 1:
74
+ line = f'v{i} <= {v0_name}({i0} downto {i1}) and ({bw - 1} downto 0 => not {v0_name}({bw0-1}));'
75
+ else:
76
+ line = f'v{i}(0) <= {v0_name}(0) and (not {v0_name}({bw0-1}));'
77
+ else:
78
+ line = f'v{i} <= {v0_name}({i0} downto {i1});'
79
+
80
+ case 3 | -3: # Explicit quantization
81
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
82
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
83
+ v0_name = f'v{op.id0}'
84
+ bw0 = widths[op.id0]
85
+ if op.opcode == -3 and op.id0 not in neg_defined:
86
+ neg_defined.add(op.id0)
87
+ bw0, v0_name = make_neg(signals, assigns, op, ops, bw0, v0_name)
88
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
89
+ line = f'v{i} <= {v0_name}({i0} downto {i1});'
90
+
91
+ case 4: # constant addition
92
+ num = op.data
93
+ sign, mag = int(num < 0), abs(num)
94
+ bw1 = ceil(log2(mag + 1)) if mag > 0 else 1
95
+ bw0 = widths[op.id0]
96
+ s0 = int(kifs[op.id0][0])
97
+ shift = kifs[op.id0][2] - kifs[i][2]
98
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
99
+ bin_val = format(mag, f'0{bw1}b')
100
+ line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>0,BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{sign}) port map(in0=>v{op.id0},in1=>"{bin_val}",result=>v{i});'
101
+ case 5: # constant
102
+ num = op.data
103
+ if num < 0:
104
+ num = 2**bw + num
105
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
106
+ bin_val = format(num, f'0{bw}b')
107
+ line = f'v{i} <= "{bin_val}";'
108
+
109
+ case 6 | -6: # MSB Muxing
110
+ k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
111
+ p0, p1 = kifs[a], kifs[b]
112
+ inv = '1' if op.opcode == -6 else '0'
113
+ bwk, bw0, bw1 = widths[k], widths[a], widths[b]
114
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
115
+ _shift = (op.data >> 32) & 0xFFFFFFFF
116
+ _shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
117
+ shift = f0 - f1 + _shift
118
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
119
+ line = f'op_{i}:entity work.mux generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},INVERT1=>{inv}) port map(key=>v{k}({bwk-1}),in0=>v{a},in1=>v{b},result=>v{i});'
120
+ case 7: # Multiplication
121
+ bw0, bw1 = widths[op.id0], widths[op.id1]
122
+ s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
123
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
124
+ line = f'op_{i}:entity work.multiplier generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
125
+
126
+ case _:
127
+ raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
128
+
129
+ if print_latency:
130
+ line += f' -- {op.latency}'
131
+ assigns.append(line)
132
+ return signals, assigns
133
+
134
+
135
def output_gen(sol: Solution, neg_defined: set[int]):
    """Drive the slices of ``model_out`` from the op signals named by ``sol.out_idxs``.

    Outputs with a negative source index or zero width are skipped. Outputs
    flagged in ``sol.out_negs`` are taken from ``v{idx}_neg``. That signal, and its
    ``negative`` instance, is declared here unless ``idx`` is already in
    ``neg_defined``, which is updated in place.

    Returns:
        ``(signals, assigns)``: extra declaration lines and statement lines.
    """
    out_widths = [sum(_minimal_kif(q)) for q in sol.out_qint]
    bounds = np.cumsum([0] + out_widths)
    # (msb, lsb) of each output field inside model_out
    msb_lsb = np.stack([bounds[1:] - 1, bounds[:-1]], axis=1)

    decls: list[str] = []
    stmts: list[str] = []
    for pos, src in enumerate(sol.out_idxs):
        if src < 0:
            continue
        hi, lo = msb_lsb[pos]
        if hi == lo - 1:  # zero-width field
            continue
        width = out_widths[pos]
        if not sol.out_negs[pos]:
            stmts.append(f'model_out({hi} downto {lo}) <= v{src}({width-1} downto 0);')
            continue
        if src not in neg_defined:
            neg_defined.add(src)
            src_kif = _minimal_kif(sol.ops[src].qint)
            src_width, src_signed = sum(src_kif), int(src_kif[0])
            decls.append(f'signal v{src}_neg:std_logic_vector({width-1} downto 0);')
            stmts.append(
                f'op_neg_{src}:entity work.negative generic map(BW_IN=>{src_width},BW_OUT=>{width},IN_SIGNED=>{src_signed}) port map(neg_in=>v{src},neg_out=>v{src}_neg);'
            )
        stmts.append(f'model_out({hi} downto {lo}) <= v{src}_neg({width-1} downto 0);')
    return decls, stmts
161
+
162
+
163
def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, timescale: str | None = None):
    """Render ``sol`` as a purely combinational VHDL entity named ``fn_name``.

    The entity has one flat ``model_inp`` and one flat ``model_out`` vector, sized
    from ``sol.inp_qint`` and ``sol.out_qint``.

    Args:
        sol: the solution to lower.
        fn_name: entity (and architecture target) name.
        print_latency: annotate each statement with its op latency.
        timescale: not used by this VHDL generator. NOTE(review): presumably kept
            for signature parity with the Verilog backend.

    Returns:
        The VHDL source as a string.
    """
    in_width = sum(sum(_minimal_kif(q)) for q in sol.inp_qint)
    out_width = sum(sum(_minimal_kif(q)) for q in sol.out_qint)

    # Shared so output_gen reuses negated signals already declared by ssa_gen.
    negs: set[int] = set()
    ssa_signals, ssa_assigns = ssa_gen(sol, neg_defined=negs, print_latency=print_latency)
    out_signals, out_assigns = output_gen(sol, negs)

    sep = '\n '
    declarations = sep.join([*ssa_signals, *out_signals])
    statements = sep.join([*ssa_assigns, *out_assigns])

    return f"""library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

entity {fn_name} is port(
model_inp:in std_logic_vector({in_width-1} downto 0);
model_out:out std_logic_vector({out_width-1} downto 0)
);
end entity {fn_name};

architecture rtl of {fn_name} is
{declarations}


begin
{statements}

end architecture rtl;

"""
@@ -0,0 +1,157 @@
1
+ from itertools import accumulate
2
+
3
+ from ....cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
4
+
5
+
6
+ def _loc(i: int, j: int):
7
+ return f'({i} downto {j})' if i != j else f'({i})'
8
+
9
+
10
def hetero_io_map(qints: list[QInterval], merge: bool = False):
    """Map fixed-point fields between a padded ("regular") and a packed bit layout.

    Each entry of ``qints`` is sized with ``_minimal_kif`` into ``(k, i, f)``.
    Two layouts of the ``N`` fields are computed, field 0 at the least
    significant end in both:

    * regular: every field occupies an equal slot of ``max_bw`` bits. It is placed
      ``max_f - f`` bits above the slot base, so fractional bits line up across
      fields, and the remaining top bits of the slot are padding.
    * hetero (packed): fields are concatenated back to back, each ``k + i + f``
      bits wide.

    Args:
        qints: per-field value intervals.
        merge: when True, coalesce neighbouring data ranges that are contiguous
            in the regular layout, and adjacent pads of the same kind.

    Returns:
        ``(regular, hetero, pads, (width_regular, width_packed))``:

        * ``regular`` and ``hetero`` are matching lists of inclusive
          ``(high, low)`` bit ranges, with zero-width fields removed.
        * ``pads`` lists ``(high, low, copy_from)`` regular-layout ranges that
          carry no data. ``copy_from`` is the packed-layout bit to replicate (the
          field's top bit, used when the field's ``k`` is non-zero) or ``-1`` for
          zero fill. Pads of zero-width fields are not removed.
    """
    N = len(qints)
    ks, _is, fs = zip(*map(_minimal_kif, qints))
    # Per-field integer width, counting the k bit(s).
    Is = [_i + _k for _i, _k in zip(_is, ks)]
    # Slot geometry: max(i) + max(k) integer bits (can exceed max(Is)) and max(f) fraction bits.
    max_I, max_f = max(_is) + max(ks), max(fs)
    max_bw = max_I + max_f
    width_regular, width_packed = max_bw * N, sum(Is) + sum(fs)

    regular: list[tuple[int, int]] = []
    pads: list[tuple[int, int, int]] = []

    # Packed layout: prefix sums of the field widths give each (high, low).
    bws = [I + f for I, f in zip(Is, fs)]
    _bw = list(accumulate([0] + bws))
    hetero = [(i - 1, j) for i, j in zip(_bw[1:], _bw[:-1])]

    for i in range(N):
        base = max_bw * i
        bias_low = max_f - fs[i]  # zero bits below the field to align its LSB
        bias_high = max_I - Is[i]  # spare bits above the field
        low = base + bias_low
        high = (base + max_bw - 1) - bias_high
        regular.append((high, low))

        if bias_low != 0:
            pads.append((base + bias_low - 1, base, -1))
        if bias_high != 0:
            # Replicate the field's packed top bit when k is set, else zero-fill.
            copy_from = hetero[i][0] if ks[i] else -1
            pads.append((base + max_bw - 1, base + max_bw - bias_high, copy_from))

    # Drop zero-width fields (high < low) from both data-range lists.
    mask = list(high < low for high, low in hetero)
    regular = [r for r, m in zip(regular, mask) if not m]
    hetero = [h for h, m in zip(hetero, mask) if not m]

    if not merge:
        return regular, hetero, pads, (width_regular, width_packed)

    # Merging consecutive intervals when possible.
    # Walk backwards so pop(i + 1) never shifts an index still to be visited.
    # The packed ranges are merged alongside, relying on fields that are
    # neighbours in the regular layout also being contiguous when packed.
    NN = len(regular) - 2
    for i in range(NN, -1, -1):
        this_high = regular[i][0]
        next_low = regular[i + 1][1]
        if next_low - this_high != 1:
            continue
        regular[i] = (regular[i + 1][0], regular[i][1])
        regular.pop(i + 1)
        hetero[i] = (hetero[i + 1][0], hetero[i][1])
        hetero.pop(i + 1)

    # Adjacent pads merge only when they fill from the same source (same copy_from).
    for i in range(len(pads) - 2, -1, -1):
        if pads[i + 1][1] - pads[i][0] == 1 and pads[i][2] == pads[i + 1][2]:
            pads[i] = (pads[i + 1][0], pads[i][1], pads[i][2])
            pads.pop(i + 1)

    return regular, hetero, pads, (width_regular, width_packed)
64
+
65
+
66
def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipelined: bool = False):
    """Generate a VHDL wrapper exposing ``module_name`` through the padded I/O layout.

    The wrapped entity works on tightly packed vectors (see ``hetero_io_map``).
    The wrapper's ``model_inp``/``model_out`` ports use the regular layout, in
    which every field occupies an equal-width slot. Input slots are sliced into
    ``packed_inp``. Output fields are copied back out of ``packed_out``, and the
    unused bits of each output slot are zero-filled or filled from a packed bit
    (sign extension).

    Args:
        sol: solution whose ``inp_qint``/``out_qint`` define the field layouts.
        module_name: name of the wrapped entity. The wrapper is ``{module_name}_wrapper``.
        pipelined: add a ``clk`` input and forward it to the wrapped entity.

    Returns:
        The VHDL source of the wrapper as a string.
    """
    reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
    reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)

    w_reg_in, w_het_in = shape_in
    w_reg_out, w_het_out = shape_out

    inp_assignment = [f'packed_inp{_loc(ih, jh)} <= model_inp{_loc(ir, jr)};' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]

    # (model_out msb, statement). Data and padding drivers are both keyed by their
    # position in model_out, so the sort emits them in bit order.
    _out_assignment: list[tuple[int, str]] = []
    for (ih, jh), (ir, jr) in zip(het_out, reg_out):
        if ih == jh - 1:  # zero-width field: nothing to copy
            continue
        _out_assignment.append((ir, f'model_out{_loc(ir, jr)} <= packed_out{_loc(ih, jh)};'))

    for high, low, copy_from in pad_out:
        n_bit = high - low + 1
        # copy_from == -1 means zero fill; otherwise replicate that packed bit.
        value = "'0'" if copy_from == -1 else f'packed_out({copy_from})'
        pad = f'(others => {value})' if n_bit > 1 else value
        _out_assignment.append((high, f'model_out{_loc(high, low)} <= {pad};'))
    _out_assignment.sort(key=lambda x: x[0])
    out_assignment = [v for _, v in _out_assignment]

    inp_assignment_str = '\n '.join(inp_assignment)
    out_assignment_str = '\n '.join(out_assignment)

    clk_and_rst_inp, clk_and_rst_bind = '', ''
    if pipelined:
        clk_and_rst_inp = '\n clk:in std_logic;'
        clk_and_rst_bind = '\n clk=>clk,'

    return f"""library ieee;
use ieee.std_logic_1164.all;
entity {module_name}_wrapper is port({clk_and_rst_inp}
model_inp:in std_logic_vector({w_reg_in-1} downto {0});
model_out:out std_logic_vector({w_reg_out-1} downto {0})
);
end entity {module_name}_wrapper;

architecture rtl of {module_name}_wrapper is
signal packed_inp:std_logic_vector({w_het_in-1} downto {0});
signal packed_out:std_logic_vector({w_het_out-1} downto {0});

begin
{inp_assignment_str}

op:entity work.{module_name} port map({clk_and_rst_bind}
model_inp=>packed_inp,
model_out=>packed_out
);

{out_assignment_str}

end architecture rtl;
"""
121
+
122
+
123
def binder_gen(csol: CascadedSolution | Solution, module_name: str, II: int = 1, latency_multiplier: int = 1):
    """Generate the C++ ``extern "C"`` binder around the simulated model ``V{module_name}``.

    NOTE(review): ``V{module_name}.h`` / ``V{module_name}`` follow Verilator's
    naming convention; presumably the model is built with Verilator. Confirm
    against the build scripts.

    The emitted ``{module_name}_config`` struct carries the I/O counts, the widest
    input/output field (max k + max i + max f), ``II`` and ``latency``. For a plain
    ``Solution`` both ``II`` and ``latency`` are written as 0. Otherwise latency
    is ``len(csol.solutions) * latency_multiplier``.

    Returns:
        The C++ source as a string.
    """

    def _widest(qints):
        k, i, f = zip(*map(_minimal_kif, qints))
        return max(k) + max(i) + max(f)

    max_inp_bw = _widest(csol.inp_qint)
    max_out_bw = _widest(csol.out_qint)

    if not isinstance(csol, Solution):
        latency = len(csol.solutions) * latency_multiplier
    else:
        II = latency = 0

    n_in, n_out = csol.shape
    return f"""#include <cstddef>
#include "binder_util.hh"
#include "V{module_name}.h"

struct {module_name}_config {{
static const size_t N_inp = {n_in};
static const size_t N_out = {n_out};
static const size_t max_inp_bw = {max_inp_bw};
static const size_t max_out_bw = {max_out_bw};
static const size_t II = {II};
static const size_t latency = {latency};
typedef V{module_name} dut_t;
}};

extern "C" {{
bool openmp_enabled() {{
return _openmp;
}}

void inference(int32_t *c_inp, int32_t *c_out, size_t n_samples) {{
batch_inference<{module_name}_config>(c_inp, c_out, n_samples);
}}
}}
"""
@@ -0,0 +1,71 @@
1
+ from ....cmvm.types import CascadedSolution, _minimal_kif
2
+ from .comb import comb_logic_gen
3
+
4
+
5
def pipeline_logic_gen(
    csol: CascadedSolution,
    name: str,
    print_latency=False,
    timescale: str | None = None,
    register_layers: int = 1,
):
    """Generate a registered VHDL pipeline from the stages of ``csol``.

    Each stage ``i`` becomes a combinational entity ``{name}_stage{i}`` (via
    ``comb_logic_gen``). The top entity ``name`` registers every stage input
    through a chain of ``register_layers`` registers on ``rising_edge(clk)``,
    and registers the final stage output into ``model_out``.

    Args:
        csol: cascaded solution; one pipeline stage per entry of ``csol.solutions``.
        name: name of the top entity; stage entities are ``{name}_stage{i}``.
        print_latency: forwarded to ``comb_logic_gen``.
        timescale: forwarded to ``comb_logic_gen``.
        register_layers: registers in front of each stage's input (at least 1).

    Returns:
        Mapping of entity name to VHDL source, for every stage and the top entity.

    Raises:
        ValueError: if ``register_layers`` is smaller than 1. Such values would
            otherwise produce invalid VHDL identifiers.
    """
    if register_layers < 1:
        raise ValueError(f'register_layers must be >= 1, got {register_layers}')

    N = len(csol.solutions)
    inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
    # Stage i feeds stage i + 1; the last stage drives model_out.
    out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]

    registers = [f'signal stage{i}_inp:std_logic_vector({width-1} downto 0);' for i, width in enumerate(inp_bits)]
    for i in range(0, register_layers - 1):
        registers += [f'signal stage{j}_inp_copy{i}:std_logic_vector({width-1} downto 0);' for j, width in enumerate(inp_bits)]
    wires = [f'signal stage{i}_out:std_logic_vector({width-1} downto 0);' for i, width in enumerate(out_bits)]

    comb_logic = [
        f'stage{i}:entity work.{name}_stage{i} port map(model_inp=>stage{i}_inp,model_out=>stage{i}_out);' for i in range(N)
    ]

    # Register chain per stage: source -> copy0 -> ... -> copy{L-2} -> stage{i}_inp.
    # With a single layer the source is registered directly into stage{i}_inp.
    serial_logic: list[str] = []
    for i in range(N):
        src = 'model_inp' if i == 0 else f'stage{i-1}_out'
        for j in range(register_layers - 1):
            dst = f'stage{i}_inp_copy{j}'
            serial_logic.append(f'{dst} <= {src};')
            src = dst
        serial_logic.append(f'stage{i}_inp <= {src};')
    serial_logic.append(f'model_out <= stage{N-1}_out;')

    blk = '\n '
    serial_str = f'{blk} '.join(serial_logic)

    module = f"""library ieee;
use ieee.std_logic_1164.all;
entity {name} is port(
clk:in std_logic;
model_inp:in std_logic_vector({inp_bits[0]-1} downto 0);
model_out:out std_logic_vector({out_bits[-1]-1} downto 0));
end entity {name};

architecture rtl of {name} is
{blk.join(registers)}
{blk.join(wires)}

begin
{blk.join(comb_logic)}

process(clk) begin
if rising_edge(clk) then
{serial_str}
end if;
end process;
end architecture rtl;
"""

    ret: dict[str, str] = {}
    for i, s in enumerate(csol.solutions):
        stage_name = f'{name}_stage{i}'
        ret[stage_name] = comb_logic_gen(s, stage_name, print_latency=print_latency, timescale=timescale)
    ret[name] = module
    return ret
@@ -0,0 +1,40 @@
1
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Generic-width integer multiplier.
-- Each operand is treated as two's complement when its SIGNEDx generic is 1
-- (and the other generic is 0 or 1). Every other combination multiplies both
-- operands as unsigned. The full BW_INPUT0 + BW_INPUT1 bit product is formed
-- first, and its BW_OUT least significant bits drive result.
entity multiplier is
    generic (
        BW_INPUT0 : integer := 32;
        BW_INPUT1 : integer := 32;
        SIGNED0   : integer := 0;
        SIGNED1   : integer := 0;
        BW_OUT    : integer := 32
    );
    port (
        in0    : in  std_logic_vector(BW_INPUT0-1 downto 0);
        in1    : in  std_logic_vector(BW_INPUT1-1 downto 0);
        result : out std_logic_vector(BW_OUT-1 downto 0)
    );
end entity multiplier;

architecture rtl of multiplier is
    constant BW_BUF : integer := BW_INPUT0 + BW_INPUT1;
    signal mult_buffer : std_logic_vector(BW_BUF-1 downto 0);
begin

    -- Exactly one of the four generate branches below is elaborated.
    gen_ss : if SIGNED0 = 1 and SIGNED1 = 1 generate
        mult_buffer <= std_logic_vector(resize(signed(in0) * signed(in1), BW_BUF));
    end generate gen_ss;

    -- Mixed signedness: prepend a '0' to the unsigned operand so it enters the
    -- signed product with its value unchanged. The product is one bit wider than
    -- BW_BUF, and signed resize drops that redundant bit.
    gen_su : if SIGNED0 = 1 and SIGNED1 = 0 generate
        mult_buffer <= std_logic_vector(resize(signed(in0) * signed('0' & in1), BW_BUF));
    end generate gen_su;

    gen_us : if SIGNED0 = 0 and SIGNED1 = 1 generate
        mult_buffer <= std_logic_vector(resize(signed('0' & in0) * signed(in1), BW_BUF));
    end generate gen_us;

    -- Fallback for every remaining generic combination (including both 0).
    gen_uu : if not ((SIGNED0 = 1 and (SIGNED1 = 1 or SIGNED1 = 0)) or (SIGNED0 = 0 and SIGNED1 = 1)) generate
        mult_buffer <= std_logic_vector(resize(unsigned(in0) * unsigned(in1), BW_BUF));
    end generate gen_uu;

    -- NOTE(review): assumes BW_OUT <= BW_BUF; a wider BW_OUT would make this
    -- slice out of range.
    result <= mult_buffer(BW_OUT-1 downto 0);

end architecture rtl;