da4ml-0.3.2-py3-none-any.whl → da4ml-0.4.0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

Files changed (60)
  1. da4ml/_version.py +2 -2
  2. da4ml/codegen/__init__.py +4 -7
  3. da4ml/codegen/hls/__init__.py +4 -0
  4. da4ml/codegen/{cpp/cpp_codegen.py → hls/hls_codegen.py} +19 -12
  5. da4ml/codegen/{cpp → hls}/hls_model.py +7 -7
  6. da4ml/codegen/hls/source/binder_util.hh +50 -0
  7. da4ml/codegen/{cpp → hls}/source/vitis_bitshift.hh +5 -3
  8. da4ml/codegen/rtl/__init__.py +15 -0
  9. da4ml/codegen/{verilog/source → rtl/common_source}/binder_util.hh +4 -4
  10. da4ml/codegen/{verilog/source → rtl/common_source}/build_binder.mk +7 -1
  11. da4ml/codegen/{verilog/source → rtl/common_source}/build_prj.tcl +28 -7
  12. da4ml/codegen/{verilog/verilog_model.py → rtl/rtl_model.py} +87 -16
  13. da4ml/codegen/{verilog → rtl/verilog}/__init__.py +0 -2
  14. da4ml/codegen/{verilog → rtl/verilog}/comb.py +32 -34
  15. da4ml/codegen/{verilog → rtl/verilog}/io_wrapper.py +8 -8
  16. da4ml/codegen/{verilog → rtl/verilog}/pipeline.py +10 -10
  17. da4ml/codegen/{verilog → rtl/verilog}/source/negative.v +2 -1
  18. da4ml/codegen/rtl/vhdl/__init__.py +10 -0
  19. da4ml/codegen/rtl/vhdl/comb.py +192 -0
  20. da4ml/codegen/rtl/vhdl/io_wrapper.py +157 -0
  21. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  22. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  23. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  24. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  25. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  26. da4ml/codegen/rtl/vhdl/source/template.xdc +32 -0
  27. da4ml/converter/hgq2/parser.py +4 -2
  28. da4ml/trace/fixed_variable.py +4 -0
  29. da4ml/trace/fixed_variable_array.py +4 -0
  30. da4ml/trace/ops/reduce_utils.py +3 -3
  31. {da4ml-0.3.2.dist-info → da4ml-0.4.0.dist-info}/METADATA +2 -2
  32. da4ml-0.4.0.dist-info/RECORD +76 -0
  33. da4ml/codegen/cpp/__init__.py +0 -4
  34. da4ml/codegen/cpp/source/binder_util.hh +0 -56
  35. da4ml-0.3.2.dist-info/RECORD +0 -66
  36. da4ml/codegen/{cpp → hls}/source/ap_types/ap_binary.h +0 -0
  37. da4ml/codegen/{cpp → hls}/source/ap_types/ap_common.h +0 -0
  38. da4ml/codegen/{cpp → hls}/source/ap_types/ap_decl.h +0 -0
  39. da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed.h +0 -0
  40. da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed_base.h +0 -0
  41. da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed_ref.h +0 -0
  42. da4ml/codegen/{cpp → hls}/source/ap_types/ap_fixed_special.h +0 -0
  43. da4ml/codegen/{cpp → hls}/source/ap_types/ap_int.h +0 -0
  44. da4ml/codegen/{cpp → hls}/source/ap_types/ap_int_base.h +0 -0
  45. da4ml/codegen/{cpp → hls}/source/ap_types/ap_int_ref.h +0 -0
  46. da4ml/codegen/{cpp → hls}/source/ap_types/ap_int_special.h +0 -0
  47. da4ml/codegen/{cpp → hls}/source/ap_types/ap_shift_reg.h +0 -0
  48. da4ml/codegen/{cpp → hls}/source/ap_types/etc/ap_private.h +0 -0
  49. da4ml/codegen/{cpp → hls}/source/ap_types/hls_math.h +0 -0
  50. da4ml/codegen/{cpp → hls}/source/ap_types/hls_stream.h +0 -0
  51. da4ml/codegen/{cpp → hls}/source/ap_types/utils/x_hls_utils.h +0 -0
  52. da4ml/codegen/{cpp → hls}/source/build_binder.mk +0 -0
  53. da4ml/codegen/{verilog/source → rtl/common_source}/ioutil.hh +0 -0
  54. da4ml/codegen/{verilog/source → rtl/common_source}/template.xdc +0 -0
  55. da4ml/codegen/{verilog → rtl/verilog}/source/multiplier.v +0 -0
  56. da4ml/codegen/{verilog → rtl/verilog}/source/mux.v +0 -0
  57. da4ml/codegen/{verilog → rtl/verilog}/source/shift_adder.v +0 -0
  58. {da4ml-0.3.2.dist-info → da4ml-0.4.0.dist-info}/WHEEL +0 -0
  59. {da4ml-0.3.2.dist-info → da4ml-0.4.0.dist-info}/licenses/LICENSE +0 -0
  60. {da4ml-0.3.2.dist-info → da4ml-0.4.0.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,25 @@ from math import ceil, log2
2
2
 
3
3
  import numpy as np
4
4
 
5
- from ...cmvm.types import QInterval, Solution, _minimal_kif
5
+ from ....cmvm.types import Op, QInterval, Solution, _minimal_kif
6
+
7
+
8
+ def make_neg(
9
+ lines: list[str],
10
+ op: Op,
11
+ ops: list[Op],
12
+ bw0: int,
13
+ v0_name: str,
14
+ ):
15
+ _min, _max, step = ops[op.id0].qint
16
+ bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
17
+ was_signed = int(_min < 0)
18
+ lines.append(
19
+ f'wire [{bw_neg - 1}:0] v{op.id0}_neg; negative #({bw0}, {bw_neg}, {was_signed}) op_neg_{op.id0} ({v0_name}, v{op.id0}_neg);'
20
+ )
21
+ bw0 = bw_neg
22
+ v0_name = f'v{op.id0}_neg'
23
+ return bw0, v0_name
6
24
 
7
25
 
8
26
  def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
@@ -30,7 +48,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
30
48
  match op.opcode:
31
49
  case -1: # Input marker
32
50
  i0, i1 = inp_idxs[op.id0]
33
- line = f'{_def} assign {v} = inp[{i0}:{i1}];'
51
+ line = f'{_def} assign {v} = model_inp[{i0}:{i1}];'
34
52
 
35
53
  case 0 | 1: # Common a+/-b<<shift oprs
36
54
  p0, p1 = kifs[op.id0], kifs[op.id1] # precision -> keep_neg, integers (no sign), fractional
@@ -49,45 +67,25 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
49
67
  v0_name = f'v{op.id0}'
50
68
  bw0 = widths[op.id0]
51
69
 
52
- if op.opcode == -2:
53
- _min, _max, step = ops[op.id0].qint
54
- bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
55
- if op.id0 not in neg_defined:
56
- neg_defined.add(op.id0)
57
- was_signed = int(kifs[op.id0][0])
58
- lines.append(
59
- f'wire [{bw_neg - 1}:0] v{op.id0}_neg; negative #({bw0}, {bw_neg}, {was_signed}) op_neg_{op.id0} ({v0_name}, v{op.id0}_neg);'
60
- )
61
- bw0 = bw_neg
62
- v0_name = f'v{op.id0}_neg'
70
+ if op.opcode == -2 and op.id0 not in neg_defined:
71
+ neg_defined.add(op.id0)
72
+ bw0, v0_name = make_neg(lines, op, ops, bw0, v0_name)
63
73
  if ops[op.id0].qint.min < 0:
64
74
  line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}] & {{{bw}{{~{v0_name}[{bw0 - 1}]}}}};'
65
75
  else:
66
76
  line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
77
+
67
78
  case 3 | -3: # Explicit quantization
68
79
  lsb_bias = kifs[op.id0][2] - kifs[i][2]
69
80
  i0, i1 = bw + lsb_bias - 1, lsb_bias
70
81
  v0_name = f'v{op.id0}'
71
82
  bw0 = widths[op.id0]
72
83
 
73
- if op.opcode == -3:
74
- _min, _max, step = ops[op.id0].qint
75
- lines.append('/* verilator lint_off WIDTHTRUNC */')
76
- bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
77
- if op.id0 not in neg_defined:
78
- neg_defined.add(op.id0)
79
- # lines.append('/* verilator lint_off WIDTHTRUNC */')
80
- # lines.append(
81
- # f'wire [{bw_neg - 1}:0] v{op.id0}_neg; assign v{op.id0}_neg[{bw_neg - 1}:0] = -{v0_name}[{bw0 - 1}:0];'
82
- # )
83
- # lines.append('/* verilator lint_on WIDTHTRUNC */')
84
- was_signed = int(kifs[op.id0][0])
85
- lines.append(
86
- f'wire [{bw_neg - 1}:0] v{op.id0}_neg; negative #({bw0}, {bw_neg}, {was_signed}) op_neg_{op.id0} ({v0_name}, v{op.id0}_neg);'
87
- )
88
- v0_name = f'v{op.id0}_neg'
89
-
84
+ if op.opcode == -3 and op.id0 not in neg_defined:
85
+ neg_defined.add(op.id0)
86
+ bw0, v0_name = make_neg(lines, op, ops, bw0, v0_name)
90
87
  line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
88
+
91
89
  case 4: # constant addition
92
90
  num = op.data
93
91
  sign, mag = int(num < 0), abs(num)
@@ -152,10 +150,10 @@ def output_gen(sol: Solution, neg_defined: set[int]):
152
150
  lines.append(
153
151
  f'wire [{bw - 1}:0] v{idx}_neg; negative #({bw0}, {bw}, {was_signed}) op_neg_{idx} (v{idx}, v{idx}_neg);'
154
152
  )
155
- lines.append(f'assign out[{i0}:{i1}] = v{idx}_neg[{bw - 1}:0];')
153
+ lines.append(f'assign model_out[{i0}:{i1}] = v{idx}_neg[{bw - 1}:0];')
156
154
 
157
155
  else:
158
- lines.append(f'assign out[{i0}:{i1}] = v{idx}[{bw - 1}:0];')
156
+ lines.append(f'assign model_out[{i0}:{i1}] = v{idx}[{bw - 1}:0];')
159
157
  return lines
160
158
 
161
159
 
@@ -165,8 +163,8 @@ def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, tim
165
163
 
166
164
  fn_signature = [
167
165
  f'module {fn_name} (',
168
- f' input [{inp_bits - 1}:0] inp,',
169
- f' output [{out_bits - 1}:0] out',
166
+ f' input [{inp_bits - 1}:0] model_inp,',
167
+ f' output [{out_bits - 1}:0] model_out',
170
168
  ');',
171
169
  ]
172
170
 
@@ -1,6 +1,6 @@
1
1
  from itertools import accumulate
2
2
 
3
- from ...cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
3
+ from ....cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
4
4
 
5
5
 
6
6
  def hetero_io_map(qints: list[QInterval], merge: bool = False):
@@ -66,18 +66,18 @@ def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipe
66
66
  w_reg_in, w_het_in = shape_in
67
67
  w_reg_out, w_het_out = shape_out
68
68
 
69
- inp_assignment = [f'assign packed_inp[{ih}:{jh}] = inp[{ir}:{jr}];' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]
69
+ inp_assignment = [f'assign packed_inp[{ih}:{jh}] = model_inp[{ir}:{jr}];' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]
70
70
  _out_assignment: list[tuple[int, str]] = []
71
71
 
72
72
  for i, ((ih, jh), (ir, jr)) in enumerate(zip(het_out, reg_out)):
73
73
  if ih == jh - 1:
74
74
  continue
75
- _out_assignment.append((ih, f'assign out[{ir}:{jr}] = packed_out[{ih}:{jh}];'))
75
+ _out_assignment.append((ih, f'assign model_out[{ir}:{jr}] = packed_out[{ih}:{jh}];'))
76
76
 
77
77
  for i, (i, j, copy_from) in enumerate(pad_out):
78
78
  n_bit = i - j + 1
79
79
  pad = f"{n_bit}'b0" if copy_from == -1 else f'{{{n_bit}{{packed_out[{copy_from}]}}}}'
80
- _out_assignment.append((i, f'assign out[{i}:{j}] = {pad};'))
80
+ _out_assignment.append((i, f'assign model_out[{i}:{j}] = {pad};'))
81
81
  _out_assignment.sort(key=lambda x: x[0])
82
82
  out_assignment = [v for _, v in _out_assignment]
83
83
 
@@ -93,9 +93,9 @@ def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipe
93
93
 
94
94
  module {module_name}_wrapper ({clk_and_rst_inp}
95
95
  // verilator lint_off UNUSEDSIGNAL
96
- input [{w_reg_in - 1}:0] inp,
96
+ input [{w_reg_in - 1}:0] model_inp,
97
97
  // verilator lint_on UNUSEDSIGNAL
98
- output [{w_reg_out - 1}:0] out
98
+ output [{w_reg_out - 1}:0] model_out
99
99
  );
100
100
  wire [{w_het_in - 1}:0] packed_inp;
101
101
  wire [{w_het_out - 1}:0] packed_out;
@@ -103,8 +103,8 @@ module {module_name}_wrapper ({clk_and_rst_inp}
103
103
  {inp_assignment_str}
104
104
 
105
105
  {module_name} op ({clk_and_rst_bind}
106
- .inp(packed_inp),
107
- .out(packed_out)
106
+ .model_inp(packed_inp),
107
+ .model_out(packed_out)
108
108
  );
109
109
 
110
110
  {out_assignment_str}
@@ -1,4 +1,4 @@
1
- from ...cmvm.types import CascadedSolution, _minimal_kif
1
+ from ....cmvm.types import CascadedSolution, _minimal_kif
2
2
  from .comb import comb_logic_gen
3
3
 
4
4
 
@@ -13,18 +13,18 @@ def pipeline_logic_gen(
13
13
  inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
14
14
  out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
15
15
 
16
- registers = [f'reg [{width}-1:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
16
+ registers = [f'reg [{width-1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
17
17
  for i in range(0, register_layers - 1):
18
- registers += [f'reg [{width}-1:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
19
- wires = [f'wire [{width}-1:0] stage{i}_out;' for i, width in enumerate(out_bits)]
18
+ registers += [f'reg [{width-1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
19
+ wires = [f'wire [{width-1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
20
20
 
21
- comb_logic = [f'{name}_stage{i} stage{i} (.inp(stage{i}_inp), .out(stage{i}_out));' for i in range(N)]
21
+ comb_logic = [f'{name}_stage{i} stage{i} (.model_inp(stage{i}_inp), .model_out(stage{i}_out));' for i in range(N)]
22
22
 
23
23
  if register_layers == 1:
24
- serial_logic = ['stage0_inp <= inp;']
24
+ serial_logic = ['stage0_inp <= model_inp;']
25
25
  serial_logic += [f'stage{i}_inp <= stage{i-1}_out;' for i in range(1, N)]
26
26
  else:
27
- serial_logic = ['stage0_inp_copy0 <= inp;']
27
+ serial_logic = ['stage0_inp_copy0 <= model_inp;']
28
28
  for j in range(1, register_layers - 1):
29
29
  serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j-1};')
30
30
  serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
@@ -34,15 +34,15 @@ def pipeline_logic_gen(
34
34
  serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};')
35
35
  serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
36
36
 
37
- serial_logic += [f'out <= stage{N-1}_out;']
37
+ serial_logic += [f'model_out <= stage{N-1}_out;']
38
38
 
39
39
  sep0 = '\n '
40
40
  sep1 = '\n '
41
41
 
42
42
  module = f"""module {name} (
43
43
  input clk,
44
- input [{inp_bits[0]-1}:0] inp,
45
- output reg [{out_bits[-1]-1}:0] out
44
+ input [{inp_bits[0]-1}:0] model_inp,
45
+ output reg [{out_bits[-1]-1}:0] model_out
46
46
  );
47
47
 
48
48
  {sep0.join(registers)}
@@ -11,6 +11,7 @@ module negative #(
11
11
  // verilator lint_off UNUSEDSIGNAL
12
12
  output [BW_OUT-1:0] out
13
13
  );
14
+ /* verilator lint_off WIDTHTRUNC */
14
15
  generate
15
16
  if (BW_IN < BW_OUT) begin : in_is_smaller
16
17
  wire [BW_OUT-1:0] in_ext;
@@ -24,5 +25,5 @@ module negative #(
24
25
  assign out = -in[BW_OUT-1:0];
25
26
  end
26
27
  endgenerate
27
-
28
+ /* verilator lint_on WIDTHTRUNC */
28
29
  endmodule
@@ -0,0 +1,10 @@
1
+ from .comb import comb_logic_gen
2
+ from .io_wrapper import binder_gen, generate_io_wrapper
3
+ from .pipeline import pipeline_logic_gen
4
+
5
+ __all__ = [
6
+ 'comb_logic_gen',
7
+ 'generate_io_wrapper',
8
+ 'pipeline_logic_gen',
9
+ 'binder_gen',
10
+ ]
@@ -0,0 +1,192 @@
1
+ from math import ceil, log2
2
+
3
+ import numpy as np
4
+
5
+ from ....cmvm.types import Op, QInterval, Solution, _minimal_kif
6
+
7
+
8
+ def make_neg(
9
+ signals: list[str],
10
+ assigns: list[str],
11
+ op: Op,
12
+ ops: list[Op],
13
+ bw0: int,
14
+ v0_name: str,
15
+ ):
16
+ _min, _max, step = ops[op.id0].qint
17
+ bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
18
+ was_signed = int(_min < 0)
19
+ signals.append(f'signal v{op.id0}_neg : std_logic_vector({bw_neg-1} downto {0});')
20
+ assigns.append(
21
+ f'op_neg_{op.id0} : entity work.negative generic map (BW_IN => {bw0}, BW_OUT => {bw_neg}, IN_SIGNED => {was_signed}) port map (neg_in => {v0_name}, neg_out => v{op.id0}_neg);'
22
+ )
23
+ bw0 = bw_neg
24
+ v0_name = f'v{op.id0}_neg'
25
+ return bw0, v0_name
26
+
27
+
28
+ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
29
+ ops = sol.ops
30
+ kifs = list(map(_minimal_kif, (op.qint for op in ops)))
31
+ widths = list(map(sum, kifs))
32
+ inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
33
+ inp_widths = list(map(sum, inp_kifs))
34
+ _inp_widths = np.cumsum([0] + inp_widths)
35
+ inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
36
+
37
+ signals = []
38
+ assigns = []
39
+ ref_count = sol.ref_count
40
+
41
+ for i, op in enumerate(ops):
42
+ if ref_count[i] == 0:
43
+ continue
44
+
45
+ bw = widths[i]
46
+ if bw == 0:
47
+ continue
48
+
49
+ match op.opcode:
50
+ case -1: # Input marker
51
+ i0, i1 = inp_idxs[op.id0]
52
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
53
+ line = f'v{i} <= model_inp({i0} downto {i1});'
54
+
55
+ case 0 | 1: # Common a+/-b<<shift oprs
56
+ p0, p1 = kifs[op.id0], kifs[op.id1]
57
+ bw0, bw1 = widths[op.id0], widths[op.id1]
58
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
59
+ shift = op.data + f0 - f1
60
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
61
+ line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{op.opcode}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
62
+
63
+ case 2 | -2: # ReLU
64
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
65
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
66
+ v0_name = f'v{op.id0}'
67
+ bw0 = widths[op.id0]
68
+ if op.opcode == -2 and op.id0 not in neg_defined:
69
+ neg_defined.add(op.id0)
70
+ bw0, v0_name = make_neg(signals, assigns, op, ops, bw0, v0_name)
71
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
72
+ if ops[op.id0].qint.min < 0:
73
+ if bw > 1:
74
+ line = f'v{i} <= {v0_name}({i0} downto {i1}) and ({bw - 1} downto 0 => not {v0_name}({bw0-1}));'
75
+ else:
76
+ line = f'v{i}(0) <= {v0_name}(0) and (not {v0_name}({bw0-1}));'
77
+ else:
78
+ line = f'v{i} <= {v0_name}({i0} downto {i1});'
79
+
80
+ case 3 | -3: # Explicit quantization
81
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
82
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
83
+ v0_name = f'v{op.id0}'
84
+ bw0 = widths[op.id0]
85
+ if op.opcode == -3 and op.id0 not in neg_defined:
86
+ neg_defined.add(op.id0)
87
+ bw0, v0_name = make_neg(signals, assigns, op, ops, bw0, v0_name)
88
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
89
+ line = f'v{i} <= {v0_name}({i0} downto {i1});'
90
+
91
+ case 4: # constant addition
92
+ num = op.data
93
+ sign, mag = int(num < 0), abs(num)
94
+ bw1 = ceil(log2(mag + 1)) if mag > 0 else 1
95
+ bw0 = widths[op.id0]
96
+ s0 = int(kifs[op.id0][0])
97
+ shift = kifs[op.id0][2] - kifs[i][2]
98
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
99
+ bin_val = format(mag, f'0{bw1}b')
100
+ line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>0,BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{sign}) port map(in0=>v{op.id0},in1=>"{bin_val}",result=>v{i});'
101
+ case 5: # constant
102
+ num = op.data
103
+ if num < 0:
104
+ num = 2**bw + num
105
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
106
+ bin_val = format(num, f'0{bw}b')
107
+ line = f'v{i} <= "{bin_val}";'
108
+
109
+ case 6 | -6: # MSB Muxing
110
+ k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
111
+ p0, p1 = kifs[a], kifs[b]
112
+ inv = '1' if op.opcode == -6 else '0'
113
+ bwk, bw0, bw1 = widths[k], widths[a], widths[b]
114
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
115
+ _shift = (op.data >> 32) & 0xFFFFFFFF
116
+ _shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
117
+ shift = f0 - f1 + _shift
118
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
119
+ line = f'op_{i}:entity work.mux generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},INVERT1=>{inv}) port map(key=>v{k}({bwk-1}),in0=>v{a},in1=>v{b},result=>v{i});'
120
+ case 7: # Multiplication
121
+ bw0, bw1 = widths[op.id0], widths[op.id1]
122
+ s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
123
+ signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
124
+ line = f'op_{i}:entity work.multiplier generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
125
+
126
+ case _:
127
+ raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
128
+
129
+ if print_latency:
130
+ line += f' -- {op.latency}'
131
+ assigns.append(line)
132
+ return signals, assigns
133
+
134
+
135
+ def output_gen(sol: Solution, neg_defined: set[int]):
136
+ assigns = []
137
+ signals = []
138
+ widths = list(map(sum, map(_minimal_kif, sol.out_qint)))
139
+ _widths = np.cumsum([0] + widths)
140
+ out_idxs = np.stack([_widths[1:] - 1, _widths[:-1]], axis=1)
141
+ for i, idx in enumerate(sol.out_idxs):
142
+ if idx < 0:
143
+ continue
144
+ i0, i1 = out_idxs[i]
145
+ if i0 == i1 - 1:
146
+ continue
147
+ bw = widths[i]
148
+ if sol.out_negs[i]:
149
+ if idx not in neg_defined:
150
+ neg_defined.add(idx)
151
+ bw0 = sum(_minimal_kif(sol.ops[idx].qint))
152
+ was_signed = int(_minimal_kif(sol.ops[idx].qint)[0])
153
+ signals.append(f'signal v{idx}_neg:std_logic_vector({bw-1} downto {0});')
154
+ assigns.append(
155
+ f'op_neg_{idx}:entity work.negative generic map(BW_IN=>{bw0},BW_OUT=>{bw},IN_SIGNED=>{was_signed}) port map(neg_in=>v{idx},neg_out=>v{idx}_neg);'
156
+ )
157
+ assigns.append(f'model_out({i0} downto {i1}) <= v{idx}_neg({bw-1} downto {0});')
158
+ else:
159
+ assigns.append(f'model_out({i0} downto {i1}) <= v{idx}({bw-1} downto {0});')
160
+ return signals, assigns
161
+
162
+
163
+ def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, timescale: str | None = None):
164
+ inp_bits = sum(map(sum, map(_minimal_kif, sol.inp_qint)))
165
+ out_bits = sum(map(sum, map(_minimal_kif, sol.out_qint)))
166
+
167
+ neg_defined = set()
168
+ ssa_signals, ssa_assigns = ssa_gen(sol, neg_defined=neg_defined, print_latency=print_latency)
169
+ output_signals, output_assigns = output_gen(sol, neg_defined)
170
+ blk = '\n '
171
+
172
+ code = f"""library ieee;
173
+ use ieee.std_logic_1164.all;
174
+ use ieee.numeric_std.all;
175
+
176
+ entity {fn_name} is port(
177
+ model_inp:in std_logic_vector({inp_bits-1} downto {0});
178
+ model_out:out std_logic_vector({out_bits-1} downto {0})
179
+ );
180
+ end entity {fn_name};
181
+
182
+ architecture rtl of {fn_name} is
183
+ {blk.join(ssa_signals + output_signals)}
184
+
185
+
186
+ begin
187
+ {blk.join(ssa_assigns + output_assigns)}
188
+
189
+ end architecture rtl;
190
+
191
+ """
192
+ return code
@@ -0,0 +1,157 @@
1
+ from itertools import accumulate
2
+
3
+ from ....cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
4
+
5
+
6
+ def _loc(i: int, j: int):
7
+ return f'({i} downto {j})' if i != j else f'({i})'
8
+
9
+
10
+ def hetero_io_map(qints: list[QInterval], merge: bool = False):
11
+ N = len(qints)
12
+ ks, _is, fs = zip(*map(_minimal_kif, qints))
13
+ Is = [_i + _k for _i, _k in zip(_is, ks)]
14
+ max_I, max_f = max(_is) + max(ks), max(fs)
15
+ max_bw = max_I + max_f
16
+ width_regular, width_packed = max_bw * N, sum(Is) + sum(fs)
17
+
18
+ regular: list[tuple[int, int]] = []
19
+ pads: list[tuple[int, int, int]] = []
20
+
21
+ bws = [I + f for I, f in zip(Is, fs)]
22
+ _bw = list(accumulate([0] + bws))
23
+ hetero = [(i - 1, j) for i, j in zip(_bw[1:], _bw[:-1])]
24
+
25
+ for i in range(N):
26
+ base = max_bw * i
27
+ bias_low = max_f - fs[i]
28
+ bias_high = max_I - Is[i]
29
+ low = base + bias_low
30
+ high = (base + max_bw - 1) - bias_high
31
+ regular.append((high, low))
32
+
33
+ if bias_low != 0:
34
+ pads.append((base + bias_low - 1, base, -1))
35
+ if bias_high != 0:
36
+ copy_from = hetero[i][0] if ks[i] else -1
37
+ pads.append((base + max_bw - 1, base + max_bw - bias_high, copy_from))
38
+
39
+ mask = list(high < low for high, low in hetero)
40
+ regular = [r for r, m in zip(regular, mask) if not m]
41
+ hetero = [h for h, m in zip(hetero, mask) if not m]
42
+
43
+ if not merge:
44
+ return regular, hetero, pads, (width_regular, width_packed)
45
+
46
+ # Merging consecutive intervals when possible
47
+ NN = len(regular) - 2
48
+ for i in range(NN, -1, -1):
49
+ this_high = regular[i][0]
50
+ next_low = regular[i + 1][1]
51
+ if next_low - this_high != 1:
52
+ continue
53
+ regular[i] = (regular[i + 1][0], regular[i][1])
54
+ regular.pop(i + 1)
55
+ hetero[i] = (hetero[i + 1][0], hetero[i][1])
56
+ hetero.pop(i + 1)
57
+
58
+ for i in range(len(pads) - 2, -1, -1):
59
+ if pads[i + 1][1] - pads[i][0] == 1 and pads[i][2] == pads[i + 1][2]:
60
+ pads[i] = (pads[i + 1][0], pads[i][1], pads[i][2])
61
+ pads.pop(i + 1)
62
+
63
+ return regular, hetero, pads, (width_regular, width_packed)
64
+
65
+
66
+ def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipelined: bool = False):
67
+ reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
68
+ reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)
69
+
70
+ w_reg_in, w_het_in = shape_in
71
+ w_reg_out, w_het_out = shape_out
72
+
73
+ inp_assignment = [f'packed_inp{_loc(ih,jh)} <= model_inp{_loc(ir,jr)};' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]
74
+ _out_assignment: list[tuple[int, str]] = []
75
+
76
+ for i, ((ih, jh), (ir, jr)) in enumerate(zip(het_out, reg_out)):
77
+ if ih == jh - 1:
78
+ continue
79
+ _out_assignment.append((ih, f'model_out{_loc(ir,jr)} <= packed_out{_loc(ih,jh)};'))
80
+
81
+ for i, (i, j, copy_from) in enumerate(pad_out):
82
+ n_bit = i - j + 1
83
+ value = "'0'" if copy_from == -1 else f'packed_out({copy_from})'
84
+ pad = f'(others => {value})' if n_bit > 1 else value
85
+ _out_assignment.append((i, f'model_out{_loc(i,j)} <= {pad};'))
86
+ _out_assignment.sort(key=lambda x: x[0])
87
+ out_assignment = [v for _, v in _out_assignment]
88
+
89
+ inp_assignment_str = '\n '.join(inp_assignment)
90
+ out_assignment_str = '\n '.join(out_assignment)
91
+
92
+ clk_and_rst_inp, clk_and_rst_bind = '', ''
93
+ if pipelined:
94
+ clk_and_rst_inp = '\n clk:in std_logic;'
95
+ clk_and_rst_bind = '\n clk=>clk,'
96
+
97
+ return f"""library ieee;
98
+ use ieee.std_logic_1164.all;
99
+ entity {module_name}_wrapper is port({clk_and_rst_inp}
100
+ model_inp:in std_logic_vector({w_reg_in-1} downto {0});
101
+ model_out:out std_logic_vector({w_reg_out-1} downto {0})
102
+ );
103
+ end entity {module_name}_wrapper;
104
+
105
+ architecture rtl of {module_name}_wrapper is
106
+ signal packed_inp:std_logic_vector({w_het_in-1} downto {0});
107
+ signal packed_out:std_logic_vector({w_het_out-1} downto {0});
108
+
109
+ begin
110
+ {inp_assignment_str}
111
+
112
+ op:entity work.{module_name} port map({clk_and_rst_bind}
113
+ model_inp=>packed_inp,
114
+ model_out=>packed_out
115
+ );
116
+
117
+ {out_assignment_str}
118
+
119
+ end architecture rtl;
120
+ """
121
+
122
+
123
+ def binder_gen(csol: CascadedSolution | Solution, module_name: str, II: int = 1, latency_multiplier: int = 1):
124
+ k_in, i_in, f_in = zip(*map(_minimal_kif, csol.inp_qint))
125
+ k_out, i_out, f_out = zip(*map(_minimal_kif, csol.out_qint))
126
+ max_inp_bw = max(k_in) + max(i_in) + max(f_in)
127
+ max_out_bw = max(k_out) + max(i_out) + max(f_out)
128
+ if isinstance(csol, Solution):
129
+ II = latency = 0
130
+ else:
131
+ latency = len(csol.solutions) * latency_multiplier
132
+
133
+ n_in, n_out = csol.shape
134
+ return f"""#include <cstddef>
135
+ #include "binder_util.hh"
136
+ #include "V{module_name}.h"
137
+
138
+ struct {module_name}_config {{
139
+ static const size_t N_inp = {n_in};
140
+ static const size_t N_out = {n_out};
141
+ static const size_t max_inp_bw = {max_inp_bw};
142
+ static const size_t max_out_bw = {max_out_bw};
143
+ static const size_t II = {II};
144
+ static const size_t latency = {latency};
145
+ typedef V{module_name} dut_t;
146
+ }};
147
+
148
+ extern "C" {{
149
+ bool openmp_enabled() {{
150
+ return _openmp;
151
+ }}
152
+
153
+ void inference(int32_t *c_inp, int32_t *c_out, size_t n_samples) {{
154
+ batch_inference<{module_name}_config>(c_inp, c_out, n_samples);
155
+ }}
156
+ }}
157
+ """
@@ -0,0 +1,71 @@
1
+ from ....cmvm.types import CascadedSolution, _minimal_kif
2
+ from .comb import comb_logic_gen
3
+
4
+
5
+ def pipeline_logic_gen(
6
+ csol: CascadedSolution,
7
+ name: str,
8
+ print_latency=False,
9
+ timescale: str | None = None,
10
+ register_layers: int = 1,
11
+ ):
12
+ N = len(csol.solutions)
13
+ inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
14
+ out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
15
+
16
+ registers = [f'signal stage{i}_inp:std_logic_vector({width-1} downto 0);' for i, width in enumerate(inp_bits)]
17
+ for i in range(0, register_layers - 1):
18
+ registers += [f'signal stage{j}_inp_copy{i}:std_logic_vector({width-1} downto 0);' for j, width in enumerate(inp_bits)]
19
+ wires = [f'signal stage{i}_out:std_logic_vector({width-1} downto 0);' for i, width in enumerate(out_bits)]
20
+
21
+ comb_logic = [
22
+ f'stage{i}:entity work.{name}_stage{i} port map(model_inp=>stage{i}_inp,model_out=>stage{i}_out);' for i in range(N)
23
+ ]
24
+
25
+ if register_layers == 1:
26
+ serial_logic = ['stage0_inp <= model_inp;']
27
+ serial_logic += [f'stage{i}_inp <= stage{i-1}_out;' for i in range(1, N)]
28
+ else:
29
+ serial_logic = ['stage0_inp_copy0 <= model_inp;']
30
+ for j in range(1, register_layers - 1):
31
+ serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j-1};')
32
+ serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
33
+ for i in range(1, N):
34
+ serial_logic.append(f'stage{i}_inp_copy0 <= stage{i-1}_out;')
35
+ for j in range(1, register_layers - 1):
36
+ serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};')
37
+ serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
38
+
39
+ serial_logic += [f'model_out <= stage{N-1}_out;']
40
+
41
+ blk = '\n '
42
+
43
+ module = f"""library ieee;
44
+ use ieee.std_logic_1164.all;
45
+ entity {name} is port(
46
+ clk:in std_logic;
47
+ model_inp:in std_logic_vector({inp_bits[0]-1} downto 0);
48
+ model_out:out std_logic_vector({out_bits[-1]-1} downto 0));
49
+ end entity {name};
50
+
51
+ architecture rtl of {name} is
52
+ {blk.join(registers)}
53
+ {blk.join(wires)}
54
+
55
+ begin
56
+ {blk.join(comb_logic)}
57
+
58
+ process(clk) begin
59
+ if rising_edge(clk) then
60
+ {f'{blk} '.join(serial_logic)}
61
+ end if;
62
+ end process;
63
+ end architecture rtl;
64
+ """
65
+
66
+ ret: dict[str, str] = {}
67
+ for i, s in enumerate(csol.solutions):
68
+ stage_name = f'{name}_stage{i}'
69
+ ret[stage_name] = comb_logic_gen(s, stage_name, print_latency=print_latency, timescale=timescale)
70
+ ret[name] = module
71
+ return ret