da4ml 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +235 -73
  5. da4ml/cmvm/core/__init__.py +221 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +67 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +74 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +268 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +187 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +122 -0
  35. da4ml-0.2.1.dist-info/METADATA +65 -0
  36. da4ml-0.2.1.dist-info/RECORD +39 -0
  37. {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.2.dist-info/METADATA +0 -122
  48. da4ml-0.1.2.dist-info/RECORD +0 -18
  49. {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
  50. {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,121 @@
1
+ from math import ceil, log2
2
+
3
+ import numpy as np
4
+ from numba import jit
5
+
6
+ from .bit_decompose import _center, _volatile_int_arr_to_csd
7
+
8
+
9
@jit
def prim_mst_dc(cost_mat: np.ndarray, dc: int = -1):
    """Minimum Spanning Tree (MST) using Prim's algorithm with a delay constraint. May not be optimal.
    Always start from the root node (0).

    Parameters
    ----------
    cost_mat : np.ndarray
        The adjacency matrix of the graph, where cost_mat[i, j] is the cost of the edge between i and j.
        When `dc >= 0`, rejected edges are masked with `np.iinfo(dtype).max // 2`, so an integer
        dtype is required in that case.

    dc : int, optional
        The delay constraint, by default -1.
        Only `dc >= 0` enables the constraint; any negative value disables it.

    Delay of each edge is ceiling(log2(max(cost_mat[i, j], 1))).

    As implemented, node latencies follow the recurrence
    ``latency[child] = max(edge_delay(parent, child), latency[parent]) + 1`` with the root at 0,
    i.e. edge delays combine by **maximum** (they are NOT summed) and each hop adds 1.
    With `dc >= 0`, a node may only be attached through an edge that keeps its latency within
    ``(2**dc - 1) + ceil(log2(max(cost_mat[0])))``. If no remaining edge satisfies this, all
    candidates carry the same sentinel cost and `argmin` still picks one, so the constraint can
    be exceeded silently.

    Returns
    -------
    np.ndarray
        An (N - 1, 2) int32 array; row k is the (parent, child) pair of the k-th edge added.
    """

    N = len(cost_mat)
    # Per-edge delay: ceil(log2(cost)), costs below 1 are treated as 1 (delay 0).
    lat_mat = np.ceil(np.log2(np.maximum(cost_mat, 1)))
    parent = np.full(N, -2, dtype=np.int32)  # -2: not visited, -1: root

    parent[0] = -1
    idxs = np.arange(N)

    mapping = np.empty((N - 1, 2), dtype=np.int32)
    latency = np.zeros((N,), dtype=np.int32)

    if dc >= 0:
        # Latency budget: (2**dc - 1) on top of ceil(log2) of the largest root-edge cost.
        # The 1e-32 keeps log2 finite when that cost is 0.
        _dc = (2**dc - 1) + ceil(log2(np.max(cost_mat[0]) + 1e-32))
    else:
        _dc = -1

    for n_impl in range(1, N):
        # `n_impl` nodes (root included) are already in the tree at this point.
        implemented = parent != -2
        # Candidate edges: rows = nodes not yet in the tree, columns = nodes already in it.
        _cost = cost_mat[~implemented][:, implemented]
        if dc >= 0:
            _lat = lat_mat[~implemented][:, implemented]
            # Edges that would push the new node past the budget get a large sentinel cost.
            _cost = np.where(np.maximum(_lat, latency[implemented]) + 1 <= _dc, _cost, np.iinfo(_cost.dtype).max // 2)
        _idx = int(np.argmin(_cost))
        # Unravel the flat argmin over the (N - n_impl, n_impl) candidate matrix.
        _i, _j = _idx // n_impl, _idx % n_impl
        i, j = idxs[~implemented][_i], idxs[implemented][_j]
        parent[i] = j
        mapping[n_impl - 1, 0] = j
        mapping[n_impl - 1, 1] = i
        latency[i] = max(lat_mat[i, j], latency[j]) + 1  # type: ignore

    return mapping
65
+
66
+
67
@jit
def kernel_decompose(kernel: np.ndarray, dc: int = -2):
    """Decompose a 2D kernel matrix into two matrices with the delay-constrained approx MST.

    Parameters
    ----------
    kernel : np.ndarray
        The 2D input kernel matrix to decompose; its shape is unpacked as (n_in, n_out).

    dc : int, optional
        Delay constraint, by default -2.
        If -1, return the trivial decomposition (m0 = kernel, m1 = I) immediately, without
        building the distance matrices or the MST.
        Otherwise it is forwarded to `prim_mst_dc`, which applies no delay constraint for
        negative values (so -2 means "unconstrained").

    The delay constraint limits the maximum latency (hops) of the decomposed
    multiplication structure.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The decomposed matrices (m0, m1), of shapes (n_in, n_out) and (n_out, n_out):
        kernel = m0 @ m1. Rows of m0 are scaled by 2**shift0 and m1 by 2**shift1 (broadcast
        over its last axis), using the shifts returned by `_center`.
    """
    kernel, shift0, shift1 = _center(kernel)
    scale0, scale1 = 2.0**shift0, 2.0**shift1
    n_in, n_out = kernel.shape
    m0, m1 = np.zeros((n_in, n_out), dtype=kernel.dtype), np.zeros((n_out, n_out), dtype=kernel.dtype)

    if dc == -1:
        # Trivial decomposition. Checked before the pairwise CSD distances and the MST are
        # computed, since neither is used on this path.
        m0[:] = kernel
        m1[:] = np.eye(n_out, dtype=kernel.dtype)
        return m0 * scale0[:, None], m1 * scale1

    m, n = kernel.shape[0], kernel.shape[1] + 1
    # Augmented matrix: column 0 is an all-zero root node, columns 1.. are the kernel columns.
    mat_aug = np.zeros((m, n), dtype=kernel.dtype)
    mat_aug[:, 1:] = kernel
    # Pairwise column differences / sums, shape (m, n, n).
    diff0 = mat_aug[:, :, None] - mat_aug[:, None, :]
    diff1 = mat_aug[:, :, None] + mat_aug[:, None, :]
    # Edge costs: nonzero digits of the (presumably CSD) decomposition, summed over digits and rows.
    dist0 = np.sum(np.sum(_volatile_int_arr_to_csd(diff0) != 0, axis=3), axis=0)
    dist1 = np.sum(np.sum(_volatile_int_arr_to_csd(diff1) != 0, axis=3), axis=0)
    # sign[a, b] == -1 when combining columns a and b by addition is cheaper than by subtraction.
    sign = np.where(dist1 - dist0 < 0, -1, 1)
    dist = np.minimum(dist0, dist1)
    mapping = prim_mst_dc(dist, dc=dc)

    cnt = 0
    for _from, _to in mapping:
        # Residual of the child column relative to its (sign-adjusted) parent column.
        col0 = mat_aug[:, _to] - mat_aug[:, _from] * sign[_to, _from]
        if _from != 0:
            # Reuse the parent's m1 column (kernel columns are offset by 1 in mat_aug).
            col1 = m1[:, _from - 1].copy() * sign[_to, _from]
        else:
            col1 = np.zeros(n_out, dtype=kernel.dtype)
        if np.any(col0 != 0):
            # Nonzero residual: store it as a new column of m0 and select it in m1.
            col1[cnt] = 1
            m0[:, cnt] = col0
            cnt += 1
        m1[:, _to - 1] = col1
    return m0 * scale0[:, None], m1 * scale1
@@ -0,0 +1,11 @@
1
+ from .cpp import cpp_logic_and_bridge_gen
2
+ from .verilog import comb_binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_binder_gen, pipeline_logic_gen
3
+
4
# Public API of the code-generation package, re-exported from the .cpp and .verilog subpackages.
__all__ = [
    'cpp_logic_and_bridge_gen',
    'comb_logic_gen',
    'generate_io_wrapper',
    'comb_binder_gen',
    'pipeline_logic_gen',
    'pipeline_binder_gen',
]
@@ -0,0 +1,3 @@
1
+ from .cpp_codegen import cpp_logic_and_bridge_gen
2
+
3
# Public API of the C++ code-generation subpackage.
__all__ = ['cpp_logic_and_bridge_gen']
@@ -0,0 +1,148 @@
1
+ from collections.abc import Callable
2
+
3
+ from ...cmvm.types import Op, QInterval, Solution, _minimal_kif
4
+ from ...trace.fixed_variable import _const_f
5
+
6
+
7
+ def kif_to_vitis_type(k: bool | int = 1, i: int = 0, f: int = 0):
8
+ if k == i == f == 0:
9
+ f = 1
10
+ return f'ap_{"" if k else "u"}fixed<{k+i+f},{k+i}>'
11
+
12
+
13
+ def kif_to_hlslib_type(k: bool | int = 1, i: int = 0, f: int = 0):
14
+ if k == i == f == 0:
15
+ f = 1
16
+ return f'ac_fixed<{int(k)},{k+i+f},{k+i}>'
17
+
18
+
19
def get_typestr_fn(flavor: str):
    """Return the type-string formatter for an HLS flavor, matched case-insensitively.

    Raises
    ------
    ValueError
        If `flavor` is neither 'vitis' nor 'hlslib'.
    """
    key = flavor.lower()
    if key == 'vitis':
        return kif_to_vitis_type
    if key == 'hlslib':
        return kif_to_hlslib_type
    raise ValueError(f'Unsupported flavor: {flavor}')
28
+
29
+
30
+ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
31
+ all_kifs = map(_minimal_kif, (op.qint for op in ops))
32
+ all_types = list(map(lambda x: typestr_fn(*x), all_kifs))
33
+
34
+ lines = []
35
+
36
+ for i, op in enumerate(ops):
37
+ _type = all_types[i]
38
+
39
+ ref0 = f'v{op.id0}'
40
+
41
+ match op.opcode:
42
+ case -1:
43
+ # Input marker
44
+ val = f'inp[{ops[op.id0].id0}]'
45
+
46
+ case 0 | 1:
47
+ # Common a+/-b<<shift op
48
+ ref1 = f'bit_shift<{op.data}>(v{op.id1})' if op.data != 0 else f'v{op.id1}'
49
+ val = f'{ref0} {"-" if op.opcode == 1 else "+"} {ref1}'
50
+
51
+ case 2 | -2:
52
+ if op.opcode == 2: # relu(inp)
53
+ if ops[op.id0].qint.min < 0:
54
+ val = f'{ref0} > 0 ? {_type}({ref0}) : {_type}(0)'
55
+ else:
56
+ val = ref0
57
+ else: # relu(-inp)
58
+ if ops[op.id0].qint.max > 0:
59
+ val = f'{ref0} > 0 ? {_type}(0) : {_type}(-{ref0})'
60
+ else:
61
+ val = f'-{ref0}'
62
+
63
+ case 3 | -3:
64
+ # Explicit quantization op, done implicitly via assignment
65
+ val = ref0 if op.opcode == 3 else f'-{ref0}'
66
+
67
+ case 4:
68
+ # Constant addition
69
+ _number = op.data * op.qint.step
70
+ sign, mag = ('-' if _number < 0 else '+'), abs(_number)
71
+ f = _const_f(mag)
72
+ const_type_str = typestr_fn(*_minimal_kif(QInterval(mag, mag, 2.0**-f)))
73
+ val = f'{ref0} {sign} {const_type_str}({mag})'
74
+
75
+ case 5:
76
+ _number = op.data * op.qint.step
77
+ val = f'{_number}'
78
+
79
+ case _:
80
+ raise ValueError(f'Unsupported opcode: {op.opcode}')
81
+
82
+ line = f'{_type} v{i} = {val};'
83
+
84
+ if print_latency:
85
+ line += f' // {op.latency}'
86
+ lines.append(line)
87
+ return lines
88
+
89
+
90
def output_gen(sol: Solution, typestr_fn: Callable[[bool | int, int, int], str]):
    """Emit the C++ statements that write the SSA results into the `out` array.

    An output whose entry in `sol.out_idxs` is negative is written as 0. Every other output
    is cast to its own minimal type. Depending on `sol.out_negs` it is negated, and a nonzero
    `sol.out_shifts` entry wraps it in ``bit_shift<shift>``.
    """
    lines = []
    for pos, src in enumerate(sol.out_idxs):
        if src < 0:
            lines.append(f'out[{pos}] = 0;')
            continue
        out_type = typestr_fn(*_minimal_kif(sol.out_qint[pos]))
        amount = sol.out_shifts[pos]
        minus = '-' if sol.out_negs[pos] else ''
        operand = f'v{src}' if amount == 0 else f'bit_shift<{amount}>(v{src})'
        lines.append(f'out[{pos}] = {out_type}({minus}{operand});')
    return lines
104
+
105
+
106
def cpp_logic_and_bridge_gen(
    sol: Solution,
    fn_name: str,
    flavor: str,
    pragmas: list[str] | None = None,
    n_indent: int = 4,
    n_base_indent: int = 0,
    print_latency: bool = False,
):
    """Generate the C++ kernel source for `sol` plus a C-callable bridge.

    Returns
    -------
    tuple[str, str]
        (code, bridge). `code` defines the function template ``fn_name<inp_t, out_t>``, whose body
        holds the pragmas, the SSA statements and the output assignments. `bridge` defines
        ``extern "C" void bridge(double *inp, double *out, int size)``. That function instantiates
        the template with the element-wise maximum (k, i, f) spec of the inputs/outputs and
        forwards to `vitis_bridge`.

    Raises
    ------
    ValueError
        For an unsupported `flavor` (from `get_typestr_fn`).
    """
    typestr_fn = get_typestr_fn(flavor)

    def _merged_type(qints):
        # Element-wise max of the (k, i, f) specs across all ports.
        return typestr_fn(*(max(col) for col in zip(*(_minimal_kif(q) for q in qints))))

    inp_type = _merged_type(sol.inp_qint)
    out_type = _merged_type(sol.out_qint)

    n_in, n_out = sol.shape
    template_def = 'template <typename inp_t, typename out_t>'
    fn_signature = f'void {fn_name}(inp_t inp[{n_in}], out_t out[{n_out}])'
    pragmas = pragmas or []

    ssa_lines = ssa_gen(sol.ops, print_latency=print_latency, typestr_fn=typestr_fn)
    output_lines = output_gen(sol, typestr_fn=typestr_fn)

    unit = ' ' * n_indent
    outer = unit * n_base_indent
    nl_inner = '\n' + outer + unit
    code = f"""{outer}{template_def}
{outer}{fn_signature} {{ // {inp_type} -> {out_type}
{nl_inner}{nl_inner.join(pragmas)}
{nl_inner}{nl_inner.join(ssa_lines)}
{nl_inner}{nl_inner.join(output_lines)}
{outer}}}
"""
    bridge = f"""#include "bridge.h"
#include "fn.h"

extern "C" {{
void bridge(double *inp, double *out, int size) {{
auto fn = {fn_name}<{inp_type}, {out_type}>;
vitis_bridge<{inp_type}, {out_type}, {n_in}, {n_out}>(fn, inp, out, size);
}}
}}"""
    return code, bridge
@@ -0,0 +1,30 @@
1
+ #pragma once
2
+ #include "ap_fixed.h"
3
+
4
+ template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N> ap_fixed<b, i + s> bit_shift(ap_fixed<b, i, Q, O, N> x) {
5
+ #pragma HLS INLINE
6
+ ap_fixed<b, i + s> r;
7
+ r.range() = x.range();
8
+ return r;
9
+ };
10
+
11
+ template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N> ap_ufixed<b, i + s> bit_shift(ap_ufixed<b, i, Q, O, N> x) {
12
+ #pragma HLS INLINE
13
+ ap_ufixed<b, i + s> r;
14
+ r.range() = x.range();
15
+ return r;
16
+ };
17
+
18
+ template <int s, int b> ap_fixed<b, s> bit_shift(ap_int<b> x) {
19
+ #pragma HLS INLINE
20
+ ap_fixed<b, s> r;
21
+ r.range() = x.range();
22
+ return r;
23
+ };
24
+
25
+ template <int s, int b> ap_ufixed<b, s> bit_shift(ap_uint<b> x) {
26
+ #pragma HLS INLINE
27
+ ap_ufixed<b, s> r;
28
+ r.range() = x.range();
29
+ return r;
30
+ };
@@ -0,0 +1,17 @@
1
+ #pragma once
2
+ #include "ap_fixed.h"
3
+
4
// Run `f` over `size` consecutive samples. For each sample, SIZE_IN doubles from `inp` are
// converted to inp_t, `f(in_buf, out_buf)` is called, and SIZE_OUT results are written back
// to `out` as doubles. Both arrays are laid out sample-major (sample i starts at i * SIZE_*).
template <typename inp_t, typename out_t, size_t SIZE_IN, size_t SIZE_OUT, typename F>
void vitis_bridge(F f, double *inp, double *out, int size) {
    inp_t in_fixed_buf[SIZE_IN];
    out_t out_fixed_buf[SIZE_OUT];
    for (int i = 0; i < size; i++) {
        // size_t indices/offsets: avoids signed/unsigned comparisons against SIZE_IN/SIZE_OUT.
        const size_t in_base = static_cast<size_t>(i) * SIZE_IN;
        const size_t out_base = static_cast<size_t>(i) * SIZE_OUT;
        for (size_t j = 0; j < SIZE_IN; j++) {
            in_fixed_buf[j] = inp_t(inp[in_base + j]);
        }
        f(in_fixed_buf, out_fixed_buf);
        for (size_t j = 0; j < SIZE_OUT; j++) {
            out[out_base + j] = double(out_fixed_buf[j]);
        }
    }
}
@@ -0,0 +1,13 @@
1
+ from .comb import comb_logic_gen
2
+ from .io_wrapper import comb_binder_gen, generate_io_wrapper, pipeline_binder_gen
3
+ from .pipeline import pipeline_logic_gen
4
+ from .verilog_model import VerilogModel
5
+
6
# Public API of the Verilog code-generation subpackage, re-exported from its submodules.
__all__ = [
    'comb_logic_gen',
    'generate_io_wrapper',
    'comb_binder_gen',
    'pipeline_logic_gen',
    'pipeline_binder_gen',
    'VerilogModel',
]
@@ -0,0 +1,146 @@
1
+ from math import ceil, log2
2
+
3
+ import numpy as np
4
+
5
+ from da4ml.cmvm.types import Op, QInterval, Solution, _minimal_kif
6
+
7
+
8
+ def ssa_gen(ops: list[Op], print_latency: bool = False):
9
+ kifs = list(map(_minimal_kif, (op.qint for op in ops)))
10
+ widths = list(map(sum, kifs))
11
+ inp_kifs = [_minimal_kif(op.qint) for op in ops if op.opcode == -1]
12
+ inp_widths = list(map(sum, inp_kifs))
13
+ _inp_widths = np.cumsum([0] + inp_widths)
14
+ inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
15
+
16
+ lines = []
17
+
18
+ for i, op in enumerate(ops):
19
+ bw = widths[i]
20
+ v = f'v{i}[{bw-1}:0]'
21
+ _def = f'wire [{bw-1}:0] v{i};'
22
+
23
+ match op.opcode:
24
+ case -1: # Input marker
25
+ i0, i1 = inp_idxs[op.id0]
26
+ line = f'{_def} assign {v} = inp[{i0}:{i1}];'
27
+ case 2 | -2: # ReLU
28
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
29
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
30
+
31
+ v0_name = f'v{op.id0}'
32
+ bw0 = widths[op.id0]
33
+
34
+ if op.opcode == -2:
35
+ _min, _max, step = ops[op.id0].qint
36
+ bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
37
+ lines.append(
38
+ f'wire [{bw_neg-1}:0] v{op.id0}_neg; assign v{op.id0}_neg[{bw_neg-1}:0] = -{v0_name}[{bw0-1}:0];'
39
+ )
40
+ v0_name = f'v{op.id0}_neg'
41
+ if ops[op.id0].qint.min < 0:
42
+ line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}] & {{{bw}{{~{v0_name}[{bw0-1}]}}}};'
43
+ else:
44
+ line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
45
+ case 3 | -3: # Explicit quantization
46
+ lsb_bias = kifs[op.id0][2] - kifs[i][2]
47
+ i0, i1 = bw + lsb_bias - 1, lsb_bias
48
+ v0_name = f'v{op.id0}'
49
+ bw0 = widths[op.id0]
50
+
51
+ if op.opcode == -3:
52
+ _min, _max, step = ops[op.id0].qint
53
+ bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
54
+ lines.append(
55
+ f'wire [{bw_neg-1}:0] v{op.id0}_neg; assign v{op.id0}_neg[{bw_neg-1}:0] = -{v0_name}[{bw0-1}:0];'
56
+ )
57
+ v0_name = f'v{op.id0}_neg'
58
+
59
+ line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
60
+ case 4: # constant addition
61
+ num = op.data
62
+ sign, mag = int(num < 0), abs(num)
63
+ line = f"{_def} assign {v} = '{bin(mag)[1:]};"
64
+ bw1 = ceil(log2(mag + 1))
65
+ bw0 = widths[op.id0]
66
+ s0 = int(kifs[op.id0][0])
67
+ v0 = f'v{op.id0}[{bw0-1}:0]'
68
+ v1 = f"'{bin(mag)[1:]}"
69
+ shift = int(log2(op.qint.step / ops[op.id0].qint.step))
70
+ line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, 0, {bw}, {shift}, {sign}) op_{i} ({v0}, {v1}, {v});'
71
+ case 5: # constant
72
+ num = op.data
73
+ if num < 0:
74
+ num = 2**bw + num
75
+ line = f"{_def} assign {v} = '{bin(num)[1:]};"
76
+
77
+ case 0 | 1: # Common a+/-b<<shift oprs
78
+ p0, p1 = kifs[op.id0], kifs[op.id1] # precision -> keep_neg, integers (no sign), fractional
79
+
80
+ bw0, bw1 = widths[op.id0], widths[op.id1] # width
81
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
82
+ shift = op.data + f0 - f1
83
+ v0, v1 = f'v{op.id0}[{bw0-1}:0]', f'v{op.id1}[{bw1-1}:0]'
84
+
85
+ line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {op.opcode}) op_{i} ({v0}, {v1}, {v});'
86
+ case _:
87
+ raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
88
+
89
+ if print_latency:
90
+ line += f' // {op.latency}'
91
+ lines.append(line)
92
+ return lines
93
+
94
+
95
def output_gen(sol: Solution):
    """Emit the assignments that pack each output's bits into the flat `out` port.

    Output k occupies ``out[msb:lsb]``, with its width taken from ``_minimal_kif(sol.out_qint[k])``.
    Outputs are packed from bit 0 upward in output order. Outputs whose entry in `sol.out_idxs`
    is negative get no driver. Negated outputs go through an ``out_neg<k>`` wire.

    NOTE(review): `sol.out_shifts` is not consulted here, presumably because a power-of-two scale
    only moves the binary point and leaves the bit pattern unchanged; confirm.

    Returns
    -------
    list[str]
        The generated Verilog lines.
    """
    lines = []
    widths = list(map(sum, map(_minimal_kif, sol.out_qint)))
    _widths = np.cumsum([0] + widths)
    # Row k holds the [msb, lsb] slice of `out` for output k.
    out_idxs = np.stack([_widths[1:] - 1, _widths[:-1]], axis=1)
    for i, idx in enumerate(sol.out_idxs):
        if idx < 0:
            continue
        i0, i1 = out_idxs[i]
        bw = widths[i]
        kif0 = _minimal_kif(sol.ops[idx].qint)
        bw0 = sum(kif0)
        if sol.out_negs[i]:
            src = f'v{idx}[{bw0-1}:0]'
            if kif0[0]:
                # Part-selects are unsigned: sign-extend a signed source before negating it
                # into the (possibly wider) bw-bit wire.
                src = f'$signed({src})'
            lines.append(f'wire [{bw-1}:0] out_neg{i}; assign out_neg{i} = -{src};')
            lines.append(f'assign out[{i0}:{i1}] = out_neg{i}[{bw-1}:0];')
        else:
            lines.append(f'assign out[{i0}:{i1}] = v{idx}[{bw-1}:0];')
    return lines
112
+
113
+
114
def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, timescale: str | None = None):
    """Generate a combinational Verilog module named `fn_name` implementing `sol`.

    The module has a single flat input port `inp` and a single flat output port `out`. Their
    widths are the summed minimal bit widths of all inputs / outputs. The SSA body is wrapped in
    verilator ``UNUSEDSIGNAL`` lint suppression. If `timescale` is given, it is prepended and
    followed by a blank line.
    """

    def _total_bits(qints):
        return sum(sum(_minimal_kif(q)) for q in qints)

    inp_bits = _total_bits(sol.inp_qint)
    out_bits = _total_bits(sol.out_qint)

    header = '\n'.join(
        (
            f'module {fn_name} (',
            f' input [{inp_bits-1}:0] inp,',
            f' output [{out_bits-1}:0] out',
            ');',
        )
    )

    ssa_lines = ssa_gen(sol.ops, print_latency=print_latency)
    output_lines = output_gen(sol)

    sep = '\n' + ' '
    body = sep.join(ssa_lines)
    outputs = sep.join(output_lines)
    code = f"""{header}

// verilator lint_off UNUSEDSIGNAL
// Explicit quantization operation will drop bits if exists

{body}

// verilator lint_on UNUSEDSIGNAL

{outputs}

endmodule
"""
    if timescale is None:
        return code
    return f'{timescale}\n\n{code}'