da4ml-0.4.1-py3-none-any.whl → da4ml-0.5.0b0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

Files changed (40):
  1. da4ml/__init__.py +2 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +2 -2
  4. da4ml/cmvm/api.py +15 -4
  5. da4ml/cmvm/core/__init__.py +2 -2
  6. da4ml/cmvm/types.py +32 -18
  7. da4ml/cmvm/util/bit_decompose.py +2 -2
  8. da4ml/codegen/hls/hls_codegen.py +10 -5
  9. da4ml/codegen/hls/hls_model.py +7 -4
  10. da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
  11. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  12. da4ml/codegen/rtl/common_source/{build_prj.tcl → build_vivado_prj.tcl} +39 -18
  13. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  14. da4ml/codegen/rtl/common_source/template.xdc +11 -13
  15. da4ml/codegen/rtl/rtl_model.py +105 -54
  16. da4ml/codegen/rtl/verilog/__init__.py +2 -1
  17. da4ml/codegen/rtl/verilog/comb.py +47 -7
  18. da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
  19. da4ml/codegen/rtl/verilog/pipeline.py +12 -12
  20. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  21. da4ml/codegen/rtl/vhdl/comb.py +27 -21
  22. da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
  23. da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
  24. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  25. da4ml/converter/__init__.py +57 -1
  26. da4ml/converter/hgq2/parser.py +4 -25
  27. da4ml/converter/hgq2/replica.py +208 -22
  28. da4ml/trace/fixed_variable.py +239 -29
  29. da4ml/trace/fixed_variable_array.py +276 -48
  30. da4ml/trace/ops/__init__.py +31 -15
  31. da4ml/trace/ops/reduce_utils.py +3 -3
  32. da4ml/trace/pipeline.py +40 -18
  33. da4ml/trace/tracer.py +33 -8
  34. da4ml/typing/__init__.py +3 -0
  35. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/METADATA +2 -1
  36. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/RECORD +39 -35
  37. da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
  38. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/WHEEL +0 -0
  39. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/licenses/LICENSE +0 -0
  40. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,8 @@ from math import ceil, log2
2
2
 
3
3
  import numpy as np
4
4
 
5
- from ....cmvm.types import Op, QInterval, Solution, _minimal_kif
5
+ from ....cmvm.types import CombLogic, Op, QInterval, _minimal_kif
6
+ from ..verilog.comb import get_table_name
6
7
 
7
8
 
8
9
  def make_neg(
@@ -16,7 +17,7 @@ def make_neg(
16
17
  _min, _max, step = ops[op.id0].qint
17
18
  bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
18
19
  was_signed = int(_min < 0)
19
- signals.append(f'signal v{op.id0}_neg : std_logic_vector({bw_neg-1} downto {0});')
20
+ signals.append(f'signal v{op.id0}_neg : std_logic_vector({bw_neg - 1} downto {0});')
20
21
  assigns.append(
21
22
  f'op_neg_{op.id0} : entity work.negative generic map (BW_IN => {bw0}, BW_OUT => {bw_neg}, IN_SIGNED => {was_signed}) port map (neg_in => {v0_name}, neg_out => v{op.id0}_neg);'
22
23
  )
@@ -25,7 +26,7 @@ def make_neg(
25
26
  return bw0, v0_name
26
27
 
27
28
 
28
- def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
29
+ def ssa_gen(sol: CombLogic, neg_defined: set[int], print_latency: bool = False):
29
30
  ops = sol.ops
30
31
  kifs = list(map(_minimal_kif, (op.qint for op in ops)))
31
32
  widths = list(map(sum, kifs))
@@ -46,10 +47,11 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
46
47
  if bw == 0:
47
48
  continue
48
49
 
50
+ signals.append(f'signal v{i}:std_logic_vector({bw - 1} downto {0});')
51
+
49
52
  match op.opcode:
50
53
  case -1: # Input marker
51
54
  i0, i1 = inp_idxs[op.id0]
52
- signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
53
55
  line = f'v{i} <= model_inp({i0} downto {i1});'
54
56
 
55
57
  case 0 | 1: # Common a+/-b<<shift oprs
@@ -57,7 +59,6 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
57
59
  bw0, bw1 = widths[op.id0], widths[op.id1]
58
60
  s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
59
61
  shift = op.data + f0 - f1
60
- signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
61
62
  line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{op.opcode}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
62
63
 
63
64
  case 2 | -2: # ReLU
@@ -68,12 +69,11 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
68
69
  if op.opcode == -2 and op.id0 not in neg_defined:
69
70
  neg_defined.add(op.id0)
70
71
  bw0, v0_name = make_neg(signals, assigns, op, ops, bw0, v0_name)
71
- signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
72
72
  if ops[op.id0].qint.min < 0:
73
73
  if bw > 1:
74
- line = f'v{i} <= {v0_name}({i0} downto {i1}) and ({bw - 1} downto 0 => not {v0_name}({bw0-1}));'
74
+ line = f'v{i} <= {v0_name}({i0} downto {i1}) and ({bw - 1} downto 0 => not {v0_name}({bw0 - 1}));'
75
75
  else:
76
- line = f'v{i}(0) <= {v0_name}(0) and (not {v0_name}({bw0-1}));'
76
+ line = f'v{i}(0) <= {v0_name}(0) and (not {v0_name}({bw0 - 1}));'
77
77
  else:
78
78
  line = f'v{i} <= {v0_name}({i0} downto {i1});'
79
79
 
@@ -85,7 +85,6 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
85
85
  if op.opcode == -3 and op.id0 not in neg_defined:
86
86
  neg_defined.add(op.id0)
87
87
  bw0, v0_name = make_neg(signals, assigns, op, ops, bw0, v0_name)
88
- signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
89
88
  line = f'v{i} <= {v0_name}({i0} downto {i1});'
90
89
 
91
90
  case 4: # constant addition
@@ -95,14 +94,12 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
95
94
  bw0 = widths[op.id0]
96
95
  s0 = int(kifs[op.id0][0])
97
96
  shift = kifs[op.id0][2] - kifs[i][2]
98
- signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
99
97
  bin_val = format(mag, f'0{bw1}b')
100
98
  line = f'op_{i}:entity work.shift_adder generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>0,BW_OUT=>{bw},SHIFT1=>{shift},IS_SUB=>{sign}) port map(in0=>v{op.id0},in1=>"{bin_val}",result=>v{i});'
101
99
  case 5: # constant
102
100
  num = op.data
103
101
  if num < 0:
104
102
  num = 2**bw + num
105
- signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
106
103
  bin_val = format(num, f'0{bw}b')
107
104
  line = f'v{i} <= "{bin_val}";'
108
105
 
@@ -115,14 +112,23 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
115
112
  _shift = (op.data >> 32) & 0xFFFFFFFF
116
113
  _shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
117
114
  shift = f0 - f1 + _shift
118
- signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
119
- line = f'op_{i}:entity work.mux generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},INVERT1=>{inv}) port map(key=>v{k}({bwk-1}),in0=>v{a},in1=>v{b},result=>v{i});'
115
+ v0, v1 = f'v{a}', f'v{b}'
116
+ if bw0 == 0:
117
+ v0, bw0 = 'B"0"', 1
118
+ if bw1 == 0:
119
+ v1, bw1 = 'B"0"', 1
120
+ line = f'op_{i}:entity work.mux generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw},SHIFT1=>{shift},INVERT1=>{inv}) port map(key=>v{k}({bwk - 1}),in0=>{v0},in1=>{v1},result=>v{i});'
121
+
120
122
  case 7: # Multiplication
121
123
  bw0, bw1 = widths[op.id0], widths[op.id1]
122
124
  s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
123
- signals.append(f'signal v{i}:std_logic_vector({bw-1} downto {0});')
124
125
  line = f'op_{i}:entity work.multiplier generic map(BW_INPUT0=>{bw0},BW_INPUT1=>{bw1},SIGNED0=>{s0},SIGNED1=>{s1},BW_OUT=>{bw}) port map(in0=>v{op.id0},in1=>v{op.id1},result=>v{i});'
125
126
 
127
+ case 8: # Lookup Table
128
+ name = get_table_name(sol, op)
129
+ bw0 = widths[op.id0]
130
+ line = f'op_{i}:entity work.lookup_table generic map(BW_IN=>{bw0},BW_OUT=>{bw},MEM_FILE=>"{name}") port map(inp=>v{op.id0},outp=>v{i});'
131
+
126
132
  case _:
127
133
  raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
128
134
 
@@ -132,7 +138,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
132
138
  return signals, assigns
133
139
 
134
140
 
135
- def output_gen(sol: Solution, neg_defined: set[int]):
141
+ def output_gen(sol: CombLogic, neg_defined: set[int]):
136
142
  assigns = []
137
143
  signals = []
138
144
  widths = list(map(sum, map(_minimal_kif, sol.out_qint)))
@@ -150,17 +156,17 @@ def output_gen(sol: Solution, neg_defined: set[int]):
150
156
  neg_defined.add(idx)
151
157
  bw0 = sum(_minimal_kif(sol.ops[idx].qint))
152
158
  was_signed = int(_minimal_kif(sol.ops[idx].qint)[0])
153
- signals.append(f'signal v{idx}_neg:std_logic_vector({bw-1} downto {0});')
159
+ signals.append(f'signal v{idx}_neg:std_logic_vector({bw - 1} downto {0});')
154
160
  assigns.append(
155
161
  f'op_neg_{idx}:entity work.negative generic map(BW_IN=>{bw0},BW_OUT=>{bw},IN_SIGNED=>{was_signed}) port map(neg_in=>v{idx},neg_out=>v{idx}_neg);'
156
162
  )
157
- assigns.append(f'model_out({i0} downto {i1}) <= v{idx}_neg({bw-1} downto {0});')
163
+ assigns.append(f'model_out({i0} downto {i1}) <= v{idx}_neg({bw - 1} downto {0});')
158
164
  else:
159
- assigns.append(f'model_out({i0} downto {i1}) <= v{idx}({bw-1} downto {0});')
165
+ assigns.append(f'model_out({i0} downto {i1}) <= v{idx}({bw - 1} downto {0});')
160
166
  return signals, assigns
161
167
 
162
168
 
163
- def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, timescale: str | None = None):
169
+ def comb_logic_gen(sol: CombLogic, fn_name: str, print_latency: bool = False, timescale: str | None = None):
164
170
  inp_bits = sum(map(sum, map(_minimal_kif, sol.inp_qint)))
165
171
  out_bits = sum(map(sum, map(_minimal_kif, sol.out_qint)))
166
172
 
@@ -174,8 +180,8 @@ use ieee.std_logic_1164.all;
174
180
  use ieee.numeric_std.all;
175
181
 
176
182
  entity {fn_name} is port(
177
- model_inp:in std_logic_vector({inp_bits-1} downto {0});
178
- model_out:out std_logic_vector({out_bits-1} downto {0})
183
+ model_inp:in std_logic_vector({inp_bits - 1} downto {0});
184
+ model_out:out std_logic_vector({out_bits - 1} downto {0})
179
185
  );
180
186
  end entity {fn_name};
181
187
 
@@ -1,6 +1,6 @@
1
1
  from itertools import accumulate
2
2
 
3
- from ....cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
3
+ from ....cmvm.types import CombLogic, Pipeline, QInterval, _minimal_kif
4
4
 
5
5
 
6
6
  def _loc(i: int, j: int):
@@ -63,26 +63,26 @@ def hetero_io_map(qints: list[QInterval], merge: bool = False):
63
63
  return regular, hetero, pads, (width_regular, width_packed)
64
64
 
65
65
 
66
- def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipelined: bool = False):
66
+ def generate_io_wrapper(sol: CombLogic | Pipeline, module_name: str, pipelined: bool = False):
67
67
  reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
68
68
  reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)
69
69
 
70
70
  w_reg_in, w_het_in = shape_in
71
71
  w_reg_out, w_het_out = shape_out
72
72
 
73
- inp_assignment = [f'packed_inp{_loc(ih,jh)} <= model_inp{_loc(ir,jr)};' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]
73
+ inp_assignment = [f'packed_inp{_loc(ih, jh)} <= model_inp{_loc(ir, jr)};' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]
74
74
  _out_assignment: list[tuple[int, str]] = []
75
75
 
76
76
  for i, ((ih, jh), (ir, jr)) in enumerate(zip(het_out, reg_out)):
77
77
  if ih == jh - 1:
78
78
  continue
79
- _out_assignment.append((ih, f'model_out{_loc(ir,jr)} <= packed_out{_loc(ih,jh)};'))
79
+ _out_assignment.append((ih, f'model_out{_loc(ir, jr)} <= packed_out{_loc(ih, jh)};'))
80
80
 
81
81
  for i, (i, j, copy_from) in enumerate(pad_out):
82
82
  n_bit = i - j + 1
83
83
  value = "'0'" if copy_from == -1 else f'packed_out({copy_from})'
84
84
  pad = f'(others => {value})' if n_bit > 1 else value
85
- _out_assignment.append((i, f'model_out{_loc(i,j)} <= {pad};'))
85
+ _out_assignment.append((i, f'model_out{_loc(i, j)} <= {pad};'))
86
86
  _out_assignment.sort(key=lambda x: x[0])
87
87
  out_assignment = [v for _, v in _out_assignment]
88
88
 
@@ -97,14 +97,14 @@ def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipe
97
97
  return f"""library ieee;
98
98
  use ieee.std_logic_1164.all;
99
99
  entity {module_name}_wrapper is port({clk_and_rst_inp}
100
- model_inp:in std_logic_vector({w_reg_in-1} downto {0});
101
- model_out:out std_logic_vector({w_reg_out-1} downto {0})
100
+ model_inp:in std_logic_vector({w_reg_in - 1} downto {0});
101
+ model_out:out std_logic_vector({w_reg_out - 1} downto {0})
102
102
  );
103
103
  end entity {module_name}_wrapper;
104
104
 
105
105
  architecture rtl of {module_name}_wrapper is
106
- signal packed_inp:std_logic_vector({w_het_in-1} downto {0});
107
- signal packed_out:std_logic_vector({w_het_out-1} downto {0});
106
+ signal packed_inp:std_logic_vector({w_het_in - 1} downto {0});
107
+ signal packed_out:std_logic_vector({w_het_out - 1} downto {0});
108
108
 
109
109
  begin
110
110
  {inp_assignment_str}
@@ -120,12 +120,12 @@ end architecture rtl;
120
120
  """
121
121
 
122
122
 
123
- def binder_gen(csol: CascadedSolution | Solution, module_name: str, II: int = 1, latency_multiplier: int = 1):
123
+ def binder_gen(csol: Pipeline | CombLogic, module_name: str, II: int = 1, latency_multiplier: int = 1):
124
124
  k_in, i_in, f_in = zip(*map(_minimal_kif, csol.inp_qint))
125
125
  k_out, i_out, f_out = zip(*map(_minimal_kif, csol.out_qint))
126
126
  max_inp_bw = max(k_in) + max(i_in) + max(f_in)
127
127
  max_out_bw = max(k_out) + max(i_out) + max(f_out)
128
- if isinstance(csol, Solution):
128
+ if isinstance(csol, CombLogic):
129
129
  II = latency = 0
130
130
  else:
131
131
  latency = len(csol.solutions) * latency_multiplier
@@ -1,9 +1,9 @@
1
- from ....cmvm.types import CascadedSolution, _minimal_kif
1
+ from ....cmvm.types import Pipeline, _minimal_kif
2
2
  from .comb import comb_logic_gen
3
3
 
4
4
 
5
5
  def pipeline_logic_gen(
6
- csol: CascadedSolution,
6
+ csol: Pipeline,
7
7
  name: str,
8
8
  print_latency=False,
9
9
  timescale: str | None = None,
@@ -13,10 +13,10 @@ def pipeline_logic_gen(
13
13
  inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
14
14
  out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
15
15
 
16
- registers = [f'signal stage{i}_inp:std_logic_vector({width-1} downto 0);' for i, width in enumerate(inp_bits)]
16
+ registers = [f'signal stage{i}_inp:std_logic_vector({width - 1} downto 0);' for i, width in enumerate(inp_bits)]
17
17
  for i in range(0, register_layers - 1):
18
- registers += [f'signal stage{j}_inp_copy{i}:std_logic_vector({width-1} downto 0);' for j, width in enumerate(inp_bits)]
19
- wires = [f'signal stage{i}_out:std_logic_vector({width-1} downto 0);' for i, width in enumerate(out_bits)]
18
+ registers += [f'signal stage{j}_inp_copy{i}:std_logic_vector({width - 1} downto 0);' for j, width in enumerate(inp_bits)]
19
+ wires = [f'signal stage{i}_out:std_logic_vector({width - 1} downto 0);' for i, width in enumerate(out_bits)]
20
20
 
21
21
  comb_logic = [
22
22
  f'stage{i}:entity work.{name}_stage{i} port map(model_inp=>stage{i}_inp,model_out=>stage{i}_out);' for i in range(N)
@@ -24,19 +24,19 @@ def pipeline_logic_gen(
24
24
 
25
25
  if register_layers == 1:
26
26
  serial_logic = ['stage0_inp <= model_inp;']
27
- serial_logic += [f'stage{i}_inp <= stage{i-1}_out;' for i in range(1, N)]
27
+ serial_logic += [f'stage{i}_inp <= stage{i - 1}_out;' for i in range(1, N)]
28
28
  else:
29
29
  serial_logic = ['stage0_inp_copy0 <= model_inp;']
30
30
  for j in range(1, register_layers - 1):
31
- serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j-1};')
31
+ serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j - 1};')
32
32
  serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
33
33
  for i in range(1, N):
34
- serial_logic.append(f'stage{i}_inp_copy0 <= stage{i-1}_out;')
34
+ serial_logic.append(f'stage{i}_inp_copy0 <= stage{i - 1}_out;')
35
35
  for j in range(1, register_layers - 1):
36
- serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};')
36
+ serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j - 1};')
37
37
  serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
38
38
 
39
- serial_logic += [f'model_out <= stage{N-1}_out;']
39
+ serial_logic += [f'model_out <= stage{N - 1}_out;']
40
40
 
41
41
  blk = '\n '
42
42
 
@@ -44,8 +44,8 @@ def pipeline_logic_gen(
44
44
  use ieee.std_logic_1164.all;
45
45
  entity {name} is port(
46
46
  clk:in std_logic;
47
- model_inp:in std_logic_vector({inp_bits[0]-1} downto 0);
48
- model_out:out std_logic_vector({out_bits[-1]-1} downto 0));
47
+ model_inp:in std_logic_vector({inp_bits[0] - 1} downto 0);
48
+ model_out:out std_logic_vector({out_bits[-1] - 1} downto 0));
49
49
  end entity {name};
50
50
 
51
51
  architecture rtl of {name} is
@@ -0,0 +1,52 @@
1
+ library ieee;
2
+ use ieee.std_logic_1164.all;
3
+ use ieee.numeric_std.all;
4
+ use std.textio.all;
5
+ use ieee.std_logic_textio.all;
6
+
7
+ entity lookup_table is
8
+ generic (
9
+ BW_IN : positive := 8;
10
+ BW_OUT : positive := 8;
11
+ MEM_FILE : string := "whatever.mem"
12
+ );
13
+ port (
14
+ inp : in std_logic_vector(BW_IN - 1 downto 0);
15
+ outp : out std_logic_vector(BW_OUT - 1 downto 0)
16
+ );
17
+ end entity lookup_table;
18
+
19
+ architecture rtl of lookup_table is
20
+ subtype rom_index_t is natural range 0 to (2 ** BW_IN) - 1;
21
+ type rom_array_t is array (rom_index_t) of std_logic_vector(BW_OUT - 1 downto 0);
22
+
23
+ -- Load the ROM contents from an external hex file.
24
+ impure function init_rom return rom_array_t is
25
+ file rom_file : text;
26
+ variable rom_data : rom_array_t := (others => (others => '0'));
27
+ variable line_in : line;
28
+ variable idx : integer := 0;
29
+ variable data_val : std_logic_vector(BW_OUT - 1 downto 0);
30
+ variable temp_val : std_logic_vector(((BW_OUT + 3) / 4) * 4 - 1 downto 0);
31
+ begin
32
+ file_open(rom_file, MEM_FILE, read_mode);
33
+
34
+ while not endfile(rom_file) loop
35
+ exit when idx > rom_index_t'high;
36
+ readline(rom_file, line_in);
37
+ hread(line_in, temp_val);
38
+ rom_data(idx) := temp_val(BW_OUT - 1 downto 0);
39
+ idx := idx + 1;
40
+ end loop;
41
+
42
+ file_close(rom_file);
43
+ return rom_data;
44
+ end function init_rom;
45
+
46
+ signal ROM_CONTENTS : rom_array_t := init_rom;
47
+
48
+ attribute rom_style : string;
49
+ attribute rom_style of ROM_CONTENTS : signal is "distributed";
50
+ begin
51
+ outp <= ROM_CONTENTS(to_integer(unsigned(inp)));
52
+ end architecture rtl;
@@ -1,3 +1,59 @@
1
- from .hgq2 import trace_model
1
+ from collections.abc import Callable
2
+ from typing import Any, Literal, overload
3
+
4
+ from ..cmvm.api import solver_options_t
5
+ from ..trace import FixedVariableArray, HWConfig
2
6
 
3
7
  __all__ = ['trace_model']
8
+
9
+
10
+ @overload
11
+ def trace_model( # type: ignore
12
+ model: Callable,
13
+ hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
14
+ solver_options: solver_options_t | None = None,
15
+ verbose: bool = False,
16
+ inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
17
+ dump: Literal[False] = False,
18
+ ) -> tuple[FixedVariableArray, FixedVariableArray]: ...
19
+
20
+
21
+ @overload
22
+ def trace_model( # type: ignore
23
+ model: Callable,
24
+ hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
25
+ solver_options: solver_options_t | None = None,
26
+ verbose: bool = False,
27
+ inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
28
+ dump: Literal[True] = False, # type: ignore
29
+ ) -> dict[str, FixedVariableArray]: ...
30
+
31
+
32
+ def trace_model( # type: ignore
33
+ model: Callable,
34
+ hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
35
+ solver_options: dict[str, Any] | None = None,
36
+ verbose: bool = False,
37
+ inputs: tuple[FixedVariableArray, ...] | None = None,
38
+ dump=False,
39
+ ):
40
+ hwconf = HWConfig(*hwconf) if isinstance(hwconf, tuple) else hwconf
41
+
42
+ module = type(model).__module__
43
+ if module.startswith('keras.'):
44
+ import keras
45
+
46
+ from .hgq2 import trace_model as keras_trace_model
47
+
48
+ assert isinstance(model, keras.Model)
49
+
50
+ return keras_trace_model(
51
+ model,
52
+ hwconf,
53
+ solver_options=solver_options,
54
+ verbose=verbose,
55
+ inputs=inputs,
56
+ dump=dump,
57
+ )
58
+ else:
59
+ raise ValueError(f'Unsupported model type: {type(model)}')
@@ -1,11 +1,12 @@
1
1
  from collections.abc import Sequence
2
2
  from dataclasses import dataclass
3
- from typing import Any, Literal, overload
3
+ from typing import Any
4
4
 
5
5
  import keras
6
6
  import numpy as np
7
7
  from keras import KerasTensor, Operation
8
8
 
9
+ from ...cmvm.api import solver_options_t
9
10
  from ...trace import FixedVariableArray, FixedVariableArrayInput, HWConfig, comb_trace
10
11
  from ...trace.fixed_variable import FixedVariable
11
12
  from .replica import _registry
@@ -116,32 +117,10 @@ def _apply_nn(
116
117
  return {k.name: v for k, v in tensor_map.items()}
117
118
 
118
119
 
119
- @overload
120
120
  def trace_model( # type: ignore
121
121
  model: keras.Model,
122
- hwconf: HWConfig = HWConfig(1, -1, -1),
123
- solver_options: dict[str, Any] | None = None,
124
- verbose: bool = False,
125
- inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
126
- dump: Literal[False] = False,
127
- ) -> tuple[FixedVariableArray, FixedVariableArray]: ...
128
-
129
-
130
- @overload
131
- def trace_model( # type: ignore
132
- model: keras.Model,
133
- hwconf: HWConfig = HWConfig(1, -1, -1),
134
- solver_options: dict[str, Any] | None = None,
135
- verbose: bool = False,
136
- inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
137
- dump: Literal[True] = False, # type: ignore
138
- ) -> dict[str, FixedVariableArray]: ...
139
-
140
-
141
- def trace_model( # type: ignore
142
- model: keras.Model,
143
- hwconf: HWConfig = HWConfig(1, -1, -1),
144
- solver_options: dict[str, Any] | None = None,
122
+ hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
123
+ solver_options: solver_options_t | None = None,
145
124
  verbose: bool = False,
146
125
  inputs: tuple[FixedVariableArray, ...] | None = None,
147
126
  dump=False,