da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +204 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +246 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.1.post1.dist-info/METADATA +85 -0
  92. da4ml-0.5.1.post1.dist-info/RECORD +96 -0
  93. da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
  94. da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
@@ -0,0 +1,239 @@
1
+ from hashlib import sha256
2
+ from math import ceil, log2
3
+ from uuid import UUID
4
+
5
+ import numpy as np
6
+
7
+ from ....cmvm.types import CombLogic, Op, QInterval, _minimal_kif
8
+
9
+
10
def make_neg(lines: list[str], idx: int, qint: QInterval, v0_name: str, neg_repo: dict[int, tuple[int, str]]):
    """Emit (at most once per ``idx``) a ``negative`` instance computing ``-v{idx}``.

    The result is memoized in ``neg_repo`` so every later request for the same
    ``idx`` reuses the already-declared ``v{idx}_neg`` wire instead of
    instantiating another ``negative`` module. On a cache hit nothing is
    appended to ``lines`` and ``qint``/``v0_name`` are ignored.

    Args:
        lines: Verilog line buffer; on a cache miss the wire declaration and the
            ``negative`` instance are appended here as one line.
        idx: SSA index of the value being negated.
        qint: ``(min, max, step)`` interval of ``v{idx}``.
        v0_name: Verilog expression connected to the ``negative`` input.
        neg_repo: Cache mapping ``idx`` -> ``(bit width, wire name)`` of the negated value.

    Returns:
        ``(bw_neg, name)``: bit width and wire name of the negated value.
    """
    if idx in neg_repo:
        return neg_repo[idx]
    _min, _max, step = qint
    bw_in = sum(_minimal_kif(qint))
    # Negation maps [min, max] onto [-max, -min]; size the output for that interval.
    bw_neg = sum(_minimal_kif(QInterval(-_max, -_min, step)))
    was_signed = int(_min < 0)
    neg_name = f'v{idx}_neg'
    lines.append(
        f'wire [{bw_neg - 1}:0] {neg_name}; negative #({bw_in}, {bw_neg}, {was_signed}) op_neg_{idx} ({v0_name}, {neg_name});'
    )
    neg_repo[idx] = (bw_neg, neg_name)
    return bw_neg, neg_name
26
+
27
+
28
def gen_mem_file(sol: CombLogic, op: Op) -> str:
    """Render the ``$readmemh`` text for the lookup-table op ``op``.

    The table referenced by ``op.data`` is padded for the interval of the
    operand ``sol.ops[op.id0]``. Each entry is masked to the table's output
    width and written as an upper-case hex word, zero-padded to
    ``ceil(width / 4)`` digits. There is one word per line.
    """
    assert op.opcode == 8
    assert sol.lookup_tables is not None
    lut = sol.lookup_tables[op.data]
    out_width = sum(lut.spec.out_kif)
    n_hex = ceil(out_width / 4)
    mask = (1 << out_width) - 1
    words = lut.padded_table(sol.ops[op.id0].qint) & mask
    return '\n'.join(hex(word)[2:].upper().zfill(n_hex) for word in words)
37
+
38
+
39
def get_table_name(sol: CombLogic, op: Op) -> str:
    """Return a content-addressed ``.mem`` file name for lookup-table op ``op``.

    The first 128 bits of the SHA-256 digest of the memory-file text are turned
    into a UUID, with the version field forced to 4. Identical table contents
    therefore always map to the same file name.
    """
    digest_hex = sha256(gen_mem_file(sol, op).encode('utf-8')).hexdigest()
    table_uuid = UUID(int=int(digest_hex[:32], 16), version=4)
    return f'table_{table_uuid}.mem'
45
+
46
+
47
+ def ssa_gen(sol: CombLogic, neg_repo: dict[int, tuple[int, str]], print_latency: bool = False) -> list[str]:
48
+ ops = sol.ops
49
+ kifs = list(map(_minimal_kif, (op.qint for op in ops)))
50
+ widths: list[int] = list(map(sum, kifs))
51
+ inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
52
+ inp_widths = list(map(sum, inp_kifs))
53
+ _inp_widths = np.cumsum([0] + inp_widths)
54
+ inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
55
+
56
+ lines: list[str] = []
57
+ ref_count = sol.ref_count
58
+
59
+ for i, op in enumerate(ops):
60
+ if ref_count[i] == 0:
61
+ continue
62
+
63
+ bw = widths[i]
64
+ v = f'v{i}[{bw - 1}:0]'
65
+ _def = f'wire [{bw - 1}:0] v{i};'
66
+ if bw == 0:
67
+ continue
68
+
69
+ match op.opcode:
70
+ case -1: # Input marker
71
+ i0, i1 = inp_idxs[op.id0]
72
+ line = f'{_def} assign {v} = model_inp[{i0}:{i1}];'
73
+
74
+ case 0 | 1: # Common a+/-b<<shift oprs
75
+ p0, p1 = kifs[op.id0], kifs[op.id1] # precision -> keep_neg, integers (no sign), fractional
76
+
77
+ bw0, bw1 = widths[op.id0], widths[op.id1] # width
78
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
79
+ shift = op.data + f0 - f1
80
+ v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
81
+
82
+ line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {op.opcode}) op_{i} ({v0}, {v1}, {v});'
83
+
84
+ case 2 | -2: # ReLU
85
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
86
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
87
+
88
+ v0_name = f'v{op.id0}'
89
+ bw0 = widths[op.id0]
90
+
91
+ if op.opcode == -2:
92
+ bw0, v0_name = make_neg(lines, op.id0, ops[op.id0].qint, v0_name, neg_repo)
93
+ if ops[op.id0].qint.min < 0:
94
+ line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}] & {{{bw}{{~{v0_name}[{bw0 - 1}]}}}};'
95
+ else:
96
+ line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
97
+
98
+ case 3 | -3: # Explicit quantization
99
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
100
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
101
+ v0_name = f'v{op.id0}'
102
+ bw0 = widths[op.id0]
103
+
104
+ if op.opcode == -3:
105
+ bw0, v0_name = make_neg(lines, op.id0, ops[op.id0].qint, v0_name, neg_repo)
106
+
107
+ if i0 >= bw0:
108
+ if op.opcode == 3:
109
+ assert ops[op.id0].qint.min < 0, f'{i}, {op.id0}'
110
+ else:
111
+ assert ops[op.id0].qint.max > 0, f'{i}, {op.id0}'
112
+
113
+ if i1 >= bw0:
114
+ v0_name = f'{{{i0 - i1 + 1}{{{v0_name}[{bw0 - 1}]}}}}'
115
+ else:
116
+ v0_name = f'{{{{{i0 - bw0 + 1}{{{v0_name}[{bw0 - 1}]}}}}, {v0_name}[{bw0 - 1}:{i1}]}}'
117
+ line = f'{_def} assign {v} = {v0_name};'
118
+ else:
119
+ line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
120
+
121
+ case 4: # constant addition
122
+ num = op.data
123
+ sign, mag = int(num < 0), abs(num)
124
+ bw1 = ceil(log2(mag + 1))
125
+ bw0 = widths[op.id0]
126
+ s0 = int(kifs[op.id0][0])
127
+ v0 = f'v{op.id0}[{bw0 - 1}:0]'
128
+ v1 = f"{bw1}'{bin(mag)[1:]}"
129
+ shift = kifs[op.id0][2] - kifs[i][2]
130
+
131
+ line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, 0, {bw}, {shift}, {sign}) op_{i} ({v0}, {v1}, {v});'
132
+
133
+ case 5: # constant
134
+ num = op.data
135
+ if num < 0:
136
+ num = 2**bw + num
137
+ line = f"{_def} assign {v} = '{bin(num)[1:]};"
138
+
139
+ case 6 | -6: # MSB Muxing
140
+ k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
141
+ p0, p1 = kifs[a], kifs[b]
142
+ inv = '1' if op.opcode == -6 else '0'
143
+ bwk, bw0, bw1 = widths[k], widths[a], widths[b]
144
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
145
+ _shift = (op.data >> 32) & 0xFFFFFFFF
146
+ _shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
147
+ shift = f0 - f1 + _shift
148
+ vk, v0, v1 = f'v{k}[{bwk - 1}]', f'v{a}[{bw0 - 1}:0]', f'v{b}[{bw1 - 1}:0]'
149
+ if bw0 == 0:
150
+ v0, bw0 = "1'b0", 1
151
+ if bw1 == 0:
152
+ v1, bw1 = "1'b0", 1
153
+
154
+ line = f'{_def} mux #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {inv}) op_{i} ({vk}, {v0}, {v1}, {v});'
155
+
156
+ case 7: # Multiplication
157
+ bw0, bw1 = widths[op.id0], widths[op.id1] # width
158
+ s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
159
+ v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
160
+
161
+ line = f'{_def} multiplier #({bw0}, {bw1}, {s0}, {s1}, {bw}) op_{i} ({v0}, {v1}, {v});'
162
+
163
+ case 8: # Lookup Table
164
+ name = get_table_name(sol, op)
165
+ bw0 = widths[op.id0]
166
+
167
+ line = f'{_def} lookup_table #({bw0}, {bw}, "{name}") op_{i} (v{op.id0}, {v});'
168
+
169
+ case _:
170
+ raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
171
+
172
+ if print_latency:
173
+ line += f' // {op.latency}'
174
+ lines.append(line)
175
+ return lines
176
+
177
+
178
def output_gen(sol: CombLogic, neg_repo: dict[int, tuple[int, str]]) -> list[str]:
    """Emit the ``assign model_out[...]`` statements that drive the module outputs.

    Outputs are packed back-to-back into ``model_out`` in order. Each one
    occupies the minimal width of its ``out_qint``. Outputs with a negative op
    index, or with zero width, get no assignment. Negated outputs are routed
    through ``make_neg``. That may append a ``negative`` instance to the
    returned lines, or reuse one already cached in ``neg_repo``.
    """
    out_widths = [sum(_minimal_kif(qint)) for qint in sol.out_qint]
    bounds = np.cumsum([0] + out_widths)
    msb_lsb = np.stack([bounds[1:] - 1, bounds[:-1]], axis=1)

    stmts = []
    for pos, src in enumerate(sol.out_idxs):
        if src < 0:
            continue
        hi, lo = msb_lsb[pos]
        if hi == lo - 1:
            continue  # zero-width output slot
        width = out_widths[pos]
        rhs_name = f'v{src}'
        if sol.out_negs[pos]:
            _, rhs_name = make_neg(stmts, src, sol.ops[src].qint, rhs_name, neg_repo)
        stmts.append(f'assign model_out[{hi}:{lo}] = {rhs_name}[{width - 1}:0];')
    return stmts
197
+
198
+
199
def comb_logic_gen(sol: CombLogic, fn_name: str, print_latency: bool = False, timescale: str | None = None):
    """Generate a purely combinational Verilog module for ``sol``.

    The module is named ``fn_name``. It exposes a flat ``model_inp`` bus and a
    flat ``model_out`` bus, whose widths are the summed minimal widths of
    ``sol.inp_qint`` and ``sol.out_qint``. The body holds the wire/instance
    lines from ``ssa_gen``, followed by the output assignments from
    ``output_gen``.

    Args:
        sol: Combinational logic to translate.
        fn_name: Verilog module name.
        print_latency: Forwarded to ``ssa_gen``; annotates each op with its latency.
        timescale: If not None, prepended to the module text followed by a blank
            line. Note: an empty string is still prepended, since only None is skipped.

    Returns:
        The Verilog source text of the module.
    """
    inp_bits = sum(map(sum, map(_minimal_kif, sol.inp_qint)))
    out_bits = sum(map(sum, map(_minimal_kif, sol.out_qint)))

    fn_signature = [
        f'module {fn_name} (',
        f' input [{inp_bits - 1}:0] model_inp,',
        f' output [{out_bits - 1}:0] model_out',
        ');',
    ]

    # One negation cache shared by the body and the output assignments, so a value
    # negated in the body is reused (not re-instantiated) for a negated output.
    neg_repo: dict[int, tuple[int, str]] = {}
    ssa_lines = ssa_gen(sol, neg_repo=neg_repo, print_latency=print_latency)
    output_lines = output_gen(sol, neg_repo)

    indent = ' '
    base_indent = '\n'
    body_indent = base_indent + indent
    # base_indent[1:] is the empty string; the signature lines are newline-joined.
    code = f"""{base_indent[1:]}{base_indent.join(fn_signature)}

// verilator lint_off UNUSEDSIGNAL
// Explicit quantization operation will drop bits if exists

{body_indent.join(ssa_lines)}

// verilator lint_on UNUSEDSIGNAL

{body_indent.join(output_lines)}

endmodule
"""
    if timescale is not None:
        code = f'{timescale}\n\n{code}'
    return code
233
+
234
+
235
def table_mem_gen(sol: CombLogic) -> dict[str, str]:
    """Map each lookup-table op's ``.mem`` file name to its ``$readmemh`` contents.

    Returns an empty dict when ``sol`` has no lookup tables. File names are
    content hashes, so ops whose tables have identical contents share one entry.
    """
    mem_files: dict[str, str] = {}
    if not sol.lookup_tables:
        return mem_files
    for op in sol.ops:
        if op.opcode != 8:
            continue
        mem_files[get_table_name(sol, op)] = gen_mem_file(sol, op)
    return mem_files
@@ -0,0 +1,113 @@
1
+ from itertools import accumulate
2
+
3
+ from ....cmvm.types import CombLogic, Pipeline, QInterval, _minimal_kif
4
+
5
+
6
def hetero_io_map(qints: list[QInterval], merge: bool = False):
    """Map between a tightly packed I/O bus and a uniformly padded ("regular") one.

    Packed ("hetero") layout: each element occupies exactly ``k + i + f`` bits,
    back to back.

    Regular layout: every element occupies ``max_bw`` bits. Its LSB is placed
    ``max_f - f`` bits above the slot base, so all binary points line up. The
    bits below are zero-padded. The ``max_I - (i + k)`` bits above either
    replicate the element's packed MSB (when ``k`` is set) or are zero.

    Args:
        qints: Intervals of the elements, in bus order.
        merge: If True, coalesce data ranges that are contiguous in the regular
            layout. Also coalesce adjacent pads that share a fill source.

    Returns:
        ``(regular, hetero, pads, (width_regular, width_packed))``:

        - ``regular`` and ``hetero``: parallel lists of ``(msb, lsb)`` data
          ranges for the non-empty elements.
        - ``pads``: ``(msb, lsb, copy_from)`` ranges of the regular bus not
          covered by data. ``copy_from`` is the packed bit to replicate, or
          ``-1`` for zeros.
        - The last tuple: the total widths of the two buses.
    """
    n_elem = len(qints)
    ks, _is, fs = zip(*map(_minimal_kif, qints))
    int_bits = [_i + _k for _i, _k in zip(_is, ks)]
    # Common integer width: largest integer part plus a sign bit if any element is signed.
    max_I = max(_is) + max(ks)
    max_f = max(fs)
    max_bw = max_I + max_f
    width_regular = max_bw * n_elem
    width_packed = sum(int_bits) + sum(fs)

    # Packed layout: element e occupies bits [offsets[e+1]-1 : offsets[e]].
    offsets = list(accumulate([0] + [I + f for I, f in zip(int_bits, fs)]))
    hetero = [(nxt - 1, cur) for nxt, cur in zip(offsets[1:], offsets[:-1])]

    regular: list[tuple[int, int]] = []
    pads: list[tuple[int, int, int]] = []
    for e in range(n_elem):
        base = max_bw * e
        top = base + max_bw - 1
        pad_lo = max_f - fs[e]
        pad_hi = max_I - int_bits[e]
        regular.append((top - pad_hi, base + pad_lo))
        if pad_lo != 0:
            pads.append((base + pad_lo - 1, base, -1))
        if pad_hi != 0:
            # Signed elements replicate their packed MSB; unsigned ones are zero-filled.
            pads.append((top, top - pad_hi + 1, hetero[e][0] if ks[e] else -1))

    # Drop zero-width elements (msb < lsb in the packed layout); their slots stay fully padded.
    keep = [hi >= lo for hi, lo in hetero]
    regular = [rng for rng, keep_it in zip(regular, keep) if keep_it]
    hetero = [rng for rng, keep_it in zip(hetero, keep) if keep_it]

    if merge:
        # Right-to-left so deleting entry j+1 never shifts an index still to be visited.
        for j in reversed(range(len(regular) - 1)):
            if regular[j + 1][1] - regular[j][0] != 1:
                continue
            upper_reg = regular.pop(j + 1)
            regular[j] = (upper_reg[0], regular[j][1])
            upper_het = hetero.pop(j + 1)
            hetero[j] = (upper_het[0], hetero[j][1])

        for j in reversed(range(len(pads) - 1)):
            lower, upper = pads[j], pads[j + 1]
            if upper[1] - lower[0] == 1 and lower[2] == upper[2]:
                pads[j] = (upper[0], lower[1], lower[2])
                del pads[j + 1]

    return regular, hetero, pads, (width_regular, width_packed)
60
+
61
+
62
def generate_io_wrapper(sol: CombLogic | Pipeline, module_name: str, pipelined: bool = False):
    """Generate ``{module_name}_wrapper``, adapting regular (padded) I/O buses to packed ones.

    The wrapper's ``model_inp`` / ``model_out`` use the regular layout from
    ``hetero_io_map``. The wrapped ``module_name`` instance is fed through
    ``packed_inp`` / ``packed_out`` in the packed layout. Input padding bits are
    left unread. Output padding bits are driven with zeros, or with a replicated
    packed bit.

    Args:
        sol: Logic whose ``inp_qint`` / ``out_qint`` define the two layouts.
        module_name: Name of the module being wrapped.
        pipelined: If True, add a ``clk`` port and forward it to the wrapped module.

    Returns:
        The Verilog source text of the wrapper module.
    """
    reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
    reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)

    w_reg_in, w_het_in = shape_in
    w_reg_out, w_het_out = shape_out

    inp_assignment = [f'assign packed_inp[{ih}:{jh}] = model_inp[{ir}:{jr}];' for (ih, jh), (ir, jr) in zip(het_in, reg_in)]

    # (model_out msb, statement) pairs. Data and pad ranges are keyed in the same
    # (model_out) index space so the statements come out in bus order.
    _out_assignment: list[tuple[int, str]] = []
    for (ih, jh), (ir, jr) in zip(het_out, reg_out):
        if ih == jh - 1:
            continue  # zero-width range
        _out_assignment.append((ir, f'assign model_out[{ir}:{jr}] = packed_out[{ih}:{jh}];'))

    for msb, lsb, copy_from in pad_out:
        n_bit = msb - lsb + 1
        pad = f"{n_bit}'b0" if copy_from == -1 else f'{{{n_bit}{{packed_out[{copy_from}]}}}}'
        _out_assignment.append((msb, f'assign model_out[{msb}:{lsb}] = {pad};'))
    _out_assignment.sort(key=lambda x: x[0])
    out_assignment = [v for _, v in _out_assignment]

    inp_assignment_str = '\n '.join(inp_assignment)
    out_assignment_str = '\n '.join(out_assignment)

    clk_and_rst_inp, clk_and_rst_bind = '', ''
    if pipelined:
        clk_and_rst_inp = '\n input clk,'
        clk_and_rst_bind = '\n .clk(clk),'

    return f"""`timescale 1 ns / 1 ps

module {module_name}_wrapper ({clk_and_rst_inp}
// verilator lint_off UNUSEDSIGNAL
input [{w_reg_in - 1}:0] model_inp,
// verilator lint_on UNUSEDSIGNAL
output [{w_reg_out - 1}:0] model_out
);
wire [{w_het_in - 1}:0] packed_inp;
wire [{w_het_out - 1}:0] packed_out;

{inp_assignment_str}

{module_name} op ({clk_and_rst_bind}
.model_inp(packed_inp),
.model_out(packed_out)
);

{out_assignment_str}

endmodule
"""
@@ -0,0 +1,67 @@
1
+ from ....cmvm.types import Pipeline, _minimal_kif
2
+ from .comb import comb_logic_gen
3
+
4
+
5
def pipeline_logic_gen(
    csol: Pipeline,
    name: str,
    print_latency=False,
    timescale: str | None = '`timescale 1 ns / 1 ps',
    register_layers: int = 1,
):
    """Generate a clocked top module ``name`` chaining the stages of ``csol``.

    Stage ``i`` is a combinational module ``{name}_stage{i}`` (from
    ``comb_logic_gen``). Its input is registered in ``stage{i}_inp``, behind
    ``register_layers - 1`` extra copy registers ``stage{i}_inp_copy{j}``. The
    last stage's output is registered into ``model_out``.

    Args:
        csol: Pipeline whose ``solutions`` are the stages.
        name: Top-level module name; the stages are named ``{name}_stage{i}``.
        print_latency: Forwarded to ``comb_logic_gen`` for every stage.
        timescale: If truthy, prepended to the top module; always forwarded to the stages.
        register_layers: Number of registers in front of each stage (>= 1).

    Returns:
        Dict mapping each module name (the stages and ``name``) to its Verilog source.

    Raises:
        ValueError: if ``register_layers`` is smaller than 1.
    """
    if register_layers < 1:
        # 0 or fewer layers would reference copy registers that are never declared.
        raise ValueError(f'register_layers must be >= 1, got {register_layers}')

    N = len(csol.solutions)
    inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
    # Stage i's output feeds stage i+1, so its width is taken from the next stage's
    # input (assumed consistent); the last stage uses the pipeline's out_qint.
    out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]

    registers = [f'reg [{width - 1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
    for i in range(0, register_layers - 1):
        registers += [f'reg [{width - 1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
    wires = [f'wire [{width - 1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]

    comb_logic = [f'{name}_stage{i} stage{i} (.model_inp(stage{i}_inp), .model_out(stage{i}_out));' for i in range(N)]

    # Per stage: source -> copy0 -> ... -> copy{register_layers-2} -> stage{i}_inp.
    serial_logic: list[str] = []
    for i in range(N):
        src = 'model_inp' if i == 0 else f'stage{i - 1}_out'
        for j in range(register_layers - 1):
            serial_logic.append(f'stage{i}_inp_copy{j} <= {src};')
            src = f'stage{i}_inp_copy{j}'
        serial_logic.append(f'stage{i}_inp <= {src};')
    serial_logic.append(f'model_out <= stage{N - 1}_out;')

    sep0 = '\n '
    sep1 = '\n '

    module = f"""module {name} (
input clk,
input [{inp_bits[0] - 1}:0] model_inp,
output reg [{out_bits[-1] - 1}:0] model_out
);

{sep0.join(registers)}
{sep0.join(wires)}

{sep0.join(comb_logic)}

always @(posedge clk) begin
{sep1.join(serial_logic)}
end
endmodule
"""

    if timescale:
        module = f'{timescale}\n\n{module}'

    ret: dict[str, str] = {}
    for i, s in enumerate(csol.solutions):
        stage_name = f'{name}_stage{i}'
        ret[stage_name] = comb_logic_gen(s, stage_name, print_latency=print_latency, timescale=timescale)
    ret[name] = module
    return ret
@@ -0,0 +1,27 @@
1
+ `timescale 1ns / 1ps
2
+
3
+
4
// Combinational ROM: out = lut_rom[in], where lut_rom has 2**BW_IN words of
// BW_OUT bits, loaded from MEM_FILE with $readmemh at elaboration time.
module lookup_table #(
    parameter BW_IN = 8,
    parameter BW_OUT = 8,
    parameter MEM_FILE = "whatever.mem"
) (
    input  [ BW_IN-1:0] in,
    output [BW_OUT-1:0] out
);

  // The read below is unclocked, so a LUT-based ("distributed") ROM is requested.
  // (The former `(BW_IN <= 999) ? "distributed" : "block"` always chose "distributed".)
  (* rom_style = "distributed" *)
  reg [BW_OUT-1:0] lut_rom[0:(1<<BW_IN)-1];

  initial begin
    $readmemh(MEM_FILE, lut_rom);
  end

  assign out = lut_rom[in];

endmodule
@@ -0,0 +1,37 @@
1
+ `timescale 1ns / 1ps
2
+
3
+
4
// Full-precision product of two integers of configurable signedness; out holds
// the low BW_OUT bits of the BW_INPUT0 + BW_INPUT1 bit product.
module multiplier #(
    parameter BW_INPUT0 = 32,
    parameter BW_INPUT1 = 32,
    parameter SIGNED0 = 0,
    parameter SIGNED1 = 0,
    parameter BW_OUT = 32
) (
    input  [BW_INPUT0-1:0] in0,
    input  [BW_INPUT1-1:0] in1,
    output [   BW_OUT-1:0] out
);

  localparam BW_BUF = BW_INPUT0 + BW_INPUT1;

  // Widen each operand by one bit: sign-extend a signed input, zero-extend an
  // unsigned one. Both are then valid signed numbers, so a single signed
  // multiply covers all four signedness combinations.
  wire op0_msb = (SIGNED0 == 1) ? in0[BW_INPUT0-1] : 1'b0;
  wire op1_msb = (SIGNED1 == 1) ? in1[BW_INPUT1-1] : 1'b0;
  wire signed [BW_INPUT0:0] op0 = {op0_msb, in0};
  wire signed [BW_INPUT1:0] op1 = {op1_msb, in1};

  // verilator lint_off UNUSEDSIGNAL
  wire [BW_BUF - 1:0] buffer;
  // verilator lint_on UNUSEDSIGNAL

  assign buffer = op0 * op1;
  assign out = buffer[BW_OUT-1:0];

endmodule
@@ -0,0 +1,58 @@
1
+ `timescale 1ns / 1ps
2
+
3
+
4
// Two-way select between in0 and (optionally negated) in1, after aligning both
// inputs on a common LSB inside a BW_BUF-bit buffer:
//   key = 1 -> out = in0
//   key = 0 -> out = in1, or -in1 when INVERT1 == 1
// SHIFT1 > 0 appends SHIFT1 zero LSBs to in1; SHIFT1 < 0 appends -SHIFT1 zero
// LSBs to in0. SIGNEDx selects sign- vs zero-extension of each input on the
// MSB side. out is the low BW_OUT bits of the selected buffer value.
module mux #(
    parameter BW_INPUT0 = 32,
    parameter BW_INPUT1 = 32,
    parameter SIGNED0 = 0,
    parameter SIGNED1 = 0,
    parameter BW_OUT = 32,
    parameter SHIFT1 = 0,
    parameter INVERT1 = 0
) (
    input key,
    input [BW_INPUT0-1:0] in0,
    input [BW_INPUT1-1:0] in1,
    output [BW_OUT-1:0] out
);

  // Bits each input spans once its LSB alignment padding is included.
  localparam IN0_NEED_BITS = (SHIFT1 < 0) ? BW_INPUT0 - SHIFT1 : BW_INPUT0;
  localparam IN1_NEED_BITS = (SHIFT1 > 0) ? BW_INPUT1 + SHIFT1 : BW_INPUT1;
  // Extra MSBs: one when exactly one input is signed, plus one when in1 is negated.
  localparam EXTRA_PAD = (SIGNED0 != SIGNED1) ? INVERT1 + 1 : INVERT1 + 0;
  localparam BW_BUF = (IN0_NEED_BITS > IN1_NEED_BITS) ? IN0_NEED_BITS + EXTRA_PAD : IN1_NEED_BITS + EXTRA_PAD;
  // MSB extension (left) and zero LSB padding (right) placing each input in the buffer.
  localparam IN0_PAD_LEFT = (SHIFT1 < 0) ? BW_BUF - BW_INPUT0 + SHIFT1 : BW_BUF - BW_INPUT0;
  localparam IN0_PAD_RIGHT = (SHIFT1 < 0) ? -SHIFT1 : 0;
  localparam IN1_PAD_LEFT = (SHIFT1 > 0) ? BW_BUF - BW_INPUT1 - SHIFT1 : BW_BUF - BW_INPUT1;
  localparam IN1_PAD_RIGHT = (SHIFT1 > 0) ? SHIFT1 : 0;


  // verilator lint_off UNUSEDSIGNAL
  wire [BW_BUF-1:0] in0_ext;
  wire [BW_BUF-1:0] in1_ext;
  // verilator lint_on UNUSEDSIGNAL

  // in0 extended into the buffer: sign- or zero-extended on the left, zero-padded on the right.
  generate
    if (SIGNED0 == 1) begin : in0_is_signed
      assign in0_ext = {{IN0_PAD_LEFT{in0[BW_INPUT0-1]}}, in0, {IN0_PAD_RIGHT{1'b0}}};
    end else begin : in0_is_unsigned
      assign in0_ext = {{IN0_PAD_LEFT{1'b0}}, in0, {IN0_PAD_RIGHT{1'b0}}};
    end
  endgenerate

  // in1 extended into the buffer the same way.
  generate
    if (SIGNED1 == 1) begin : in1_is_signed
      assign in1_ext = {{IN1_PAD_LEFT{in1[BW_INPUT1-1]}}, in1, {IN1_PAD_RIGHT{1'b0}}};
    end else begin : in1_is_unsigned
      assign in1_ext = {{IN1_PAD_LEFT{1'b0}}, in1, {IN1_PAD_RIGHT{1'b0}}};
    end
  endgenerate

  // key selects in0; otherwise in1 (two's-complement negated when INVERT1 == 1).
  generate
    if (INVERT1 == 1) begin : is_invert
      assign out = (key) ? in0_ext[BW_OUT-1:0] : -in1_ext[BW_OUT-1:0];
    end else begin : is_not_invert
      assign out = (key) ? in0_ext[BW_OUT-1:0] : in1_ext[BW_OUT-1:0];
    end
  endgenerate

endmodule
@@ -0,0 +1,31 @@
1
+ `timescale 1ns / 1ps
2
+
3
+
4
// Two's-complement negation with width conversion: out = -in, where in is
// treated as signed when IN_SIGNED == 1 and unsigned otherwise.
// If BW_OUT > BW_IN, in is extended first. Otherwise the negation is done at
// BW_IN bits and the low BW_OUT bits are kept.
module negative #(
    parameter BW_IN = 32,
    parameter BW_OUT = 32,
    parameter IN_SIGNED = 0
) (
    // verilator lint_off UNUSEDSIGNAL
    input  [ BW_IN-1:0] in,
    // verilator lint_on UNUSEDSIGNAL
    output [BW_OUT-1:0] out
);
  /* verilator lint_off WIDTHTRUNC */
  generate
    if (BW_IN < BW_OUT) begin : in_is_smaller
      // Extend to the output width (sign- or zero-extension), then negate.
      wire [BW_OUT-1:0] in_ext;
      if (IN_SIGNED == 1) begin : is_signed
        assign in_ext = {{BW_OUT - BW_IN{in[BW_IN-1]}}, in};
      end else begin : is_unsigned
        assign in_ext = {{BW_OUT - BW_IN{1'b0}}, in};
      end
      assign out = -in_ext;
    end else begin : in_is_bigger
      // Negate at the input width, then keep the low BW_OUT bits.
      wire [BW_IN-1:0] out_ext;
      assign out_ext = -in;
      assign out = out_ext[BW_OUT-1:0];
    end
  endgenerate
  /* verilator lint_on WIDTHTRUNC */
endmodule
@@ -0,0 +1,59 @@
1
+ `timescale 1ns / 1ps
2
+
3
+
4
// Shift-and-add / shift-and-subtract of two integers:
//   IS_SUB == 0 -> in0 + in1 * 2^SHIFT1
//   IS_SUB == 1 -> in0 - in1 * 2^SHIFT1
// Both operands are first placed in a common BW_ADD-bit buffer. Zero LSBs are
// appended to in1 when SHIFT1 > 0, or to in0 when SHIFT1 < 0, so bit 0 of the
// result carries the weight of the finer of the two aligned LSBs. SIGNEDx
// selects sign- vs zero-extension of each operand on the MSB side. out is the
// low BW_OUT bits of the BW_ADD-bit result.
module shift_adder #(
    parameter BW_INPUT0 = 32,
    parameter BW_INPUT1 = 32,
    parameter SIGNED0 = 0,
    parameter SIGNED1 = 0,
    parameter BW_OUT = 32,
    parameter SHIFT1 = 0,
    parameter IS_SUB = 0
) (
    input [BW_INPUT0-1:0] in0,
    input [BW_INPUT1-1:0] in1,
    output [BW_OUT-1:0] out
);

  // Bits each operand spans once its LSB alignment padding is included.
  localparam IN0_NEED_BITS = (SHIFT1 < 0) ? BW_INPUT0 - SHIFT1 : BW_INPUT0;
  localparam IN1_NEED_BITS = (SHIFT1 > 0) ? BW_INPUT1 + SHIFT1 : BW_INPUT1;
  // Extra MSBs: one when exactly one operand is signed, plus one for subtraction.
  localparam EXTRA_PAD = (SIGNED0 != SIGNED1) ? IS_SUB + 1 : IS_SUB + 0;
  // Widest aligned operand + EXTRA_PAD + one carry bit.
  localparam BW_ADD = (IN0_NEED_BITS > IN1_NEED_BITS) ? IN0_NEED_BITS + EXTRA_PAD + 1 : IN1_NEED_BITS + EXTRA_PAD + 1;
  // MSB extension (left) and zero LSB padding (right) placing each operand in the buffer.
  localparam IN0_PAD_LEFT = (SHIFT1 < 0) ? BW_ADD - BW_INPUT0 + SHIFT1 : BW_ADD - BW_INPUT0;
  localparam IN0_PAD_RIGHT = (SHIFT1 < 0) ? -SHIFT1 : 0;
  localparam IN1_PAD_LEFT = (SHIFT1 > 0) ? BW_ADD - BW_INPUT1 - SHIFT1 : BW_ADD - BW_INPUT1;
  localparam IN1_PAD_RIGHT = (SHIFT1 > 0) ? SHIFT1 : 0;

  wire [BW_ADD-1:0] in0_ext;
  wire [BW_ADD-1:0] in1_ext;

  // verilator lint_off UNUSEDSIGNAL
  wire [BW_ADD-1:0] accum;
  // verilator lint_on UNUSEDSIGNAL

  // in0 extended into the buffer: sign- or zero-extended on the left, zero-padded on the right.
  generate
    if (SIGNED0 == 1) begin : in0_is_signed
      assign in0_ext = {{IN0_PAD_LEFT{in0[BW_INPUT0-1]}}, in0, {IN0_PAD_RIGHT{1'b0}}};
    end else begin : in0_is_unsigned
      assign in0_ext = {{IN0_PAD_LEFT{1'b0}}, in0, {IN0_PAD_RIGHT{1'b0}}};
    end
  endgenerate

  // in1 extended into the buffer the same way.
  generate
    if (SIGNED1 == 1) begin : in1_is_signed
      assign in1_ext = {{IN1_PAD_LEFT{in1[BW_INPUT1-1]}}, in1, {IN1_PAD_RIGHT{1'b0}}};
    end else begin : in1_is_unsigned
      assign in1_ext = {{IN1_PAD_LEFT{1'b0}}, in1, {IN1_PAD_RIGHT{1'b0}}};
    end
  endgenerate

  generate
    if (IS_SUB == 1) begin : is_sub
      assign accum = in0_ext - in1_ext;
    end else begin : is_add
      assign accum = in0_ext + in1_ext;
    end
  endgenerate
  // Keep only the requested low-order bits of the sum/difference.
  assign out = accum[BW_OUT-1:0];

endmodule
@@ -0,0 +1,9 @@
1
+ from .comb import comb_logic_gen
2
+ from .io_wrapper import generate_io_wrapper
3
+ from .pipeline import pipeline_logic_gen
4
+
5
# Public Verilog code-generation entry points, re-exported from .comb, .io_wrapper and .pipeline.
__all__ = [
    'comb_logic_gen',
    'generate_io_wrapper',
    'pipeline_logic_gen',
]