da4ml-0.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +194 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +240 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.0.dist-info/METADATA +85 -0
  92. da4ml-0.5.0.dist-info/RECORD +96 -0
  93. da4ml-0.5.0.dist-info/WHEEL +6 -0
  94. da4ml-0.5.0.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/cmvm/util/__init__.py
@@ -0,0 +1,7 @@
+ from .bit_decompose import csd_decompose
+ from .mat_decompose import kernel_decompose
+
+ __all__ = [
+     'csd_decompose',
+     'kernel_decompose',
+ ]
da4ml/cmvm/util/bit_decompose.py
@@ -0,0 +1,86 @@
+ import numpy as np
+ from numba import jit
+ from numpy.typing import NDArray
+
+
+ @jit
+ def _volatile_int_arr_to_csd(x: NDArray) -> NDArray[np.int8]:
+     x = x
+     N = np.max(np.ceil(np.log2(np.abs(x) * 1.5 + 1e-19)))
+     N = int(max(N, 1))
+     buf = np.zeros((*np.shape(x), N), dtype=np.int8)
+
+     for n in range(N - 1, -1, -1):
+         _2pn = 2**n
+         thres = _2pn / 1.5
+         bit = (x > thres).astype(np.int8)
+         bit -= (x < -thres).astype(np.int8)
+         x -= _2pn * bit.astype(x.dtype)
+         buf[..., n] = bit
+     return buf
+
+
+ @jit(error_model='numpy')
+ def _shift_centering(arr: NDArray):
+     low, high = -64, 64
+     if np.all(arr == 0):
+         high = low = 0
+     while high - low > 1:
+         mid = (high + low) // 2
+         xs = arr * (2.0**mid)
+         if np.all(xs == np.floor(xs)):
+             high = mid
+         else:
+             low = mid
+     return -high
+
+
+ @jit(error_model='numpy')
+ def shift_centering(arr: NDArray, axis: int):
+     n = arr.shape[axis]
+     shifts = np.empty(n, dtype=np.int8)
+     for i in range(n):
+         shifts[i] = _shift_centering(arr.take(i, axis=axis))
+     return shifts
+
+
+ @jit
+ def _center(arr: NDArray):
+     shift1 = shift_centering(arr, 1)  # d_out
+     arr = arr * (2.0**-shift1)
+     shift0 = shift_centering(arr, 0)  # d_in
+     arr = arr * (2.0 ** -shift0[:, None])
+     return arr, shift0.astype(np.int8), shift1.astype(np.int8)
+
+
+ @jit
+ def csd_decompose(arr: NDArray, center=True):
+     """
+     Convert a 2D array to CSD representation.
+
+     Parameters
+     ----------
+     arr : ndarray
+         Input array to be converted.
+     center : bool, optional
+         If True, the array is centered before conversion. Default is True.
+         If False, the function may accept non-2D arrays.
+
+     Returns
+     -------
+     csd : ndarray
+         CSD representation of the input array after centering, if center is True.
+     shift0 : ndarray
+         Shift values for the first axis.
+     shift1 : ndarray
+         Shift values for the second axis.
+     """
+
+     if center:
+         arr, shift0, shift1 = _center(arr)
+     else:
+         shift0 = np.zeros(arr.shape[0], dtype=np.int8)
+         shift1 = np.zeros(arr.shape[1], dtype=np.int8)
+         arr = arr.copy()
+     csd = _volatile_int_arr_to_csd(arr)
+     return csd, shift0, shift1
da4ml/cmvm/util/mat_decompose.py
@@ -0,0 +1,121 @@
+ from math import ceil, log2
+
+ import numpy as np
+ from numba import jit
+
+ from .bit_decompose import _center, _volatile_int_arr_to_csd
+
+
+ @jit
+ def prim_mst_dc(cost_mat: np.ndarray, dc: int = -1):
+     """Minimum spanning tree (MST) using Prim's algorithm with a delay constraint. May not be optimal.
+     Always starts from the root node (0).
+
+     Parameters
+     ----------
+     cost_mat : np.ndarray
+         The adjacency matrix of the graph, where cost_mat[i, j] is the cost of the edge between i and j.
+
+     dc : int, optional
+         The delay constraint, by default -1.
+         If -1, no delay constraint is applied.
+
+         The delay of each edge is ceil(log2(cost_mat[i, j])).
+
+         The delay from the root node to any node is the **maximum** delay over the edges connecting them,
+         plus ceil(log2(#edges on the path)).
+         It is **NOT** the sum of the edge delays.
+
+     Returns
+     -------
+     np.ndarray
+         The adjacency list of the MST, where each row is a pair of nodes (parent, child).
+     """
+
+     N = len(cost_mat)
+     lat_mat = np.ceil(np.log2(np.maximum(cost_mat, 1)))
+     parent = np.full(N, -2, dtype=np.int32)  # -2: not visited, -1: root
+
+     parent[0] = -1
+     idxs = np.arange(N)
+
+     mapping = np.empty((N - 1, 2), dtype=np.int32)
+     latency = np.zeros((N,), dtype=np.int32)
+
+     if dc >= 0:
+         _dc = (2**dc - 1) + ceil(log2(np.max(cost_mat[0]) + 1e-32))
+     else:
+         _dc = -1
+
+     for n_impl in range(1, N):
+         implemented = parent != -2
+         _cost = cost_mat[~implemented][:, implemented]
+         if dc >= 0:
+             _lat = lat_mat[~implemented][:, implemented]
+             _cost = np.where(np.maximum(_lat, latency[implemented]) + 1 <= _dc, _cost, np.iinfo(_cost.dtype).max // 2)
+         _idx = int(np.argmin(_cost))
+         _i, _j = _idx // n_impl, _idx % n_impl
+         i, j = idxs[~implemented][_i], idxs[implemented][_j]
+         parent[i] = j
+         mapping[n_impl - 1, 0] = j
+         mapping[n_impl - 1, 1] = i
+         latency[i] = max(lat_mat[i, j], latency[j]) + 1  # type: ignore
+
+     return mapping
+
+
+ @jit
+ def kernel_decompose(kernel: np.ndarray, dc: int = -2):
+     """Decompose a 2D kernel matrix into two matrices with the delay-constrained approx MST.
+
+     Parameters
+     ----------
+     kernel : np.ndarray
+         The input kernel matrix to decompose.
+
+     dc : int, optional
+         Delay constraint, by default -2.
+         If -2, no delay constraint is applied.
+         If -1, return the trivial decomposition (m0 = kernel, m1 = I).
+
+         The delay constraint limits the maximum latency (hops) of the decomposed
+         multiplication structure.
+
+     Returns
+     -------
+     tuple[np.ndarray, np.ndarray]
+         The decomposed matrices (m0, m1): kernel = m0 @ m1
+     """
+     kernel, shift0, shift1 = _center(kernel)
+     scale0, scale1 = 2.0**shift0, 2.0**shift1
+     m, n = kernel.shape[0], kernel.shape[1] + 1
+     mat_aug = np.zeros((m, n), dtype=kernel.dtype)
+     mat_aug[:, 1:] = kernel
+     diff0 = mat_aug[:, :, None] - mat_aug[:, None, :]
+     diff1 = mat_aug[:, :, None] + mat_aug[:, None, :]
+     dist0 = np.sum(np.sum(_volatile_int_arr_to_csd(diff0) != 0, axis=3), axis=0)
+     dist1 = np.sum(np.sum(_volatile_int_arr_to_csd(diff1) != 0, axis=3), axis=0)
+     sign = np.where(dist1 - dist0 < 0, -1, 1)
+     dist = np.minimum(dist0, dist1)
+     mapping = prim_mst_dc(dist, dc=dc)
+     n_in, n_out = kernel.shape
+     m0, m1 = np.zeros((n_in, n_out), dtype=kernel.dtype), np.zeros((n_out, n_out), dtype=kernel.dtype)
+
+     if dc == -1:
+         m0[:] = kernel
+         m1[:] = np.eye(n_out, dtype=kernel.dtype)
+         return m0 * scale0[:, None], m1 * scale1
+
+     cnt = 0
+     for _from, _to in mapping:
+         col0 = mat_aug[:, _to] - mat_aug[:, _from] * sign[_to, _from]
+         if _from != 0:
+             col1 = m1[:, _from - 1].copy() * sign[_to, _from]
+         else:
+             col1 = np.zeros(n_out, dtype=kernel.dtype)
+         if np.any(col0 != 0):
+             col1[cnt] = 1
+             m0[:, cnt] = col0
+             cnt += 1
+         m1[:, _to - 1] = col1
+     return m0 * scale0[:, None], m1 * scale1
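The decomposition factors the kernel through shared subexpressions: each column of m1 rebuilds an output column from its MST parent's column plus at most one fresh column of m0, so kernel = m0 @ m1 by construction. A minimal round-trip check (a sketch, not part of the wheel):

import numpy as np
from da4ml.cmvm.util import kernel_decompose

rng = np.random.default_rng(42)
kernel = np.round(rng.normal(0.0, 1.0, (8, 6)) * 32) / 32  # fixed-point-friendly values

m0, m1 = kernel_decompose(kernel)         # default dc=-2: MST without delay constraint
t0, t1 = kernel_decompose(kernel, dc=-1)  # trivial decomposition (m0 = kernel, m1 = I)

np.testing.assert_allclose(m0 @ m1, kernel)
np.testing.assert_allclose(t0 @ t1, kernel)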
da4ml/codegen/__init__.py
@@ -0,0 +1,9 @@
+ from .hls import HLSModel
+ from .rtl import RTLModel, VerilogModel, VHDLModel
+
+ __all__ = [
+     'HLSModel',
+     'VerilogModel',
+     'VHDLModel',
+     'RTLModel',
+ ]
da4ml/codegen/hls/__init__.py
@@ -0,0 +1,4 @@
+ from .hls_codegen import hls_logic_and_bridge_gen
+ from .hls_model import HLSModel
+
+ __all__ = ['hls_logic_and_bridge_gen', 'HLSModel']
da4ml/codegen/hls/hls_codegen.py
@@ -0,0 +1,196 @@
+ from collections.abc import Callable
+
+ from ...cmvm.types import CombLogic, QInterval, _minimal_kif
+ from ...trace.fixed_variable import _const_f
+
+
+ def kif_to_vitis_type(k: bool | int = 1, i: int = 0, f: int = 0):
+     if k == i == f == 0:
+         f = 1
+     return f'ap_{"" if k else "u"}fixed<{k + i + f},{k + i}>'
+
+
+ def kif_to_hlslib_type(k: bool | int = 1, i: int = 0, f: int = 0):
+     if k == i == f == 0:
+         f = 1
+     return f'ac_fixed<{int(k)},{k + i + f},{k + i}>'
+
+
+ def kif_to_oneapi_type(k: bool | int = 1, i: int = 0, f: int = 0):
+     # OneAPI requires at least 2 bits for all ac_fixed as of 2025.1
+     return f'ac_fixed<{int(k)},{max(k + i + f, 2)},{k + i}>'
+
+
+ def get_typestr_fn(flavor: str):
+     match flavor.lower():
+         case 'vitis':
+             typestr_fn = kif_to_vitis_type
+         case 'hlslib':
+             typestr_fn = kif_to_hlslib_type
+         case 'oneapi':
+             typestr_fn = kif_to_oneapi_type
+         case _:
+             raise ValueError(f'Unsupported flavor: {flavor}')
+     return typestr_fn
+
+
+ def ssa_gen(sol: CombLogic, print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
+     ops = sol.ops
+     all_kifs = list(map(_minimal_kif, (op.qint for op in ops)))
+     all_types = list(map(lambda x: typestr_fn(*x), all_kifs))
+
+     lines = []
+     ref_count = sol.ref_count
+     for i, op in enumerate(ops):
+         if ref_count[i] == 0:
+             # Skip unused ops
+             continue
+
+         _type = all_types[i]
+
+         ref0 = f'v{op.id0}'
+
+         match op.opcode:
+             case -1:
+                 # Input marker
+                 val = f'model_inp[{op.id0}]'
+             case 0 | 1:
+                 # Common a +/- (b << shift) op
+                 ref1 = f'bit_shift<{op.data}>(v{op.id1})' if op.data != 0 else f'v{op.id1}'
+                 val = f'{ref0} {"-" if op.opcode == 1 else "+"} {ref1}'
+             case 2 | -2:
+                 if op.opcode == 2:  # relu(model_inp)
+                     if ops[op.id0].qint.min < 0:
+                         val = f'{ref0} > 0 ? {_type}({ref0}) : {_type}(0)'
+                     else:
+                         val = ref0
+                 else:  # relu(-model_inp)
+                     if ops[op.id0].qint.max > 0:
+                         val = f'{ref0} > 0 ? {_type}(0) : {_type}(-{ref0})'
+                     else:
+                         val = f'-{ref0}'
+             case 3 | -3:
+                 # Explicit quantization op, done implicitly via assignment
+                 val = ref0 if op.opcode == 3 else f'-{ref0}'
+             case 4:
+                 # Constant addition
+                 _number = op.data * op.qint.step
+                 sign, mag = ('-' if _number < 0 else '+'), abs(_number)
+                 f = _const_f(mag)
+                 const_type_str = typestr_fn(*_minimal_kif(QInterval(mag, mag, 2.0**-f)))
+                 val = f'{ref0} {sign} {const_type_str}({mag})'
+             case 5:
+                 # Define constant
+                 _number = op.data * op.qint.step
+                 val = f'{_number}'
+             case 6 | -6:
+                 # MSB Mux
+                 id_c = op.data & 0xFFFFFFFF
+                 bw_k = sum(all_kifs[id_c])
+                 shift = (op.data >> 32) & 0xFFFFFFFF
+                 shift = shift if shift < 0x80000000 else shift - 0x100000000
+                 ref_k = f'v{id_c}[{bw_k - 1}]'
+                 sign = '-' if op.opcode == -6 else ''
+                 ref1 = f'v{op.id1}' if shift == 0 else f'bit_shift<{shift}>(v{op.id1})'
+                 bw0, bw1 = sum(all_kifs[op.id0]), sum(all_kifs[op.id1])
+                 if bw0 == 0:
+                     ref0 = '0'
+                 if bw1 == 0:
+                     ref1 = '0'
+                 val = f'{ref_k} ? {_type}({ref0}) : {_type}({sign}{ref1})'
+             case 7:
+                 # Multiplication
+                 ref1 = f'v{op.id1}'
+                 val = f'{ref0} * {ref1}'
+             case _:
+                 raise ValueError(f'Unsupported opcode: {op.opcode}')
+
+         line = f'{_type} v{i} = {val};'
+
+         if print_latency:
+             line += f' // {op.latency}'
+         lines.append(line)
+     return lines
+
+
+ def output_gen(sol: CombLogic, typestr_fn: Callable[[bool | int, int, int], str]):
+     lines = []
+     for i, idx in enumerate(sol.out_idxs):
+         if idx < 0:
+             lines.append(f'model_out[{i}] = 0;')
+             continue
+         _type = typestr_fn(*_minimal_kif(sol.out_qint[i]))
+         shift = sol.out_shifts[i]
+         neg_str = '-' if sol.out_negs[i] else ''
+         if shift == 0:
+             lines.append(f'model_out[{i}] = {_type}({neg_str}v{idx});')
+         else:
+             lines.append(f'model_out[{i}] = {_type}({neg_str}bit_shift<{shift}>(v{idx}));')
+     return lines
+
+
+ def get_io_types(sol: CombLogic, flavor: str):
+     typestr_fn = get_typestr_fn(flavor)
+     in_kif = map(max, zip(*map(_minimal_kif, sol.inp_qint)))
+     inp_type = typestr_fn(*in_kif)
+     out_kif = map(max, zip(*map(_minimal_kif, sol.out_qint)))
+     out_type = typestr_fn(*out_kif)
+     return inp_type, out_type
+
+
+ def hls_logic_and_bridge_gen(
+     sol: CombLogic,
+     fn_name: str,
+     flavor: str,
+     pragmas: list[str] | None = None,
+     n_indent: int = 4,
+     n_base_indent: int = 0,
+     print_latency: bool = False,
+ ):
+     typestr_fn = get_typestr_fn(flavor)
+     inp_t, out_t = get_io_types(sol, flavor)
+
+     n_in, n_out = sol.shape
+     template_def = 'template <typename inp_t, typename out_t>'
+     fn_signature = f'void {fn_name}(inp_t model_inp[{n_in}], out_t model_out[{n_out}])'
+     pragmas = pragmas or []
+
+     ssa_lines = ssa_gen(sol, print_latency=print_latency, typestr_fn=typestr_fn)
+     output_lines = output_gen(sol, typestr_fn=typestr_fn)
+
+     indent = ' ' * n_indent
+     base_indent = indent * n_base_indent
+     body_indent = '\n' + base_indent + indent
+     code = f"""{base_indent}{template_def}
+ {base_indent}{fn_signature} {{ // {inp_t} -> {out_t}
+ {base_indent + indent}{body_indent.join(pragmas)}
+ {body_indent}{body_indent.join(ssa_lines)}
+ {body_indent}{body_indent.join(output_lines)}
+ {base_indent}}}
+ """
+     bridge = f"""#include "binder_util.hh"
+ #include "{fn_name}.hh"
+
+ struct {fn_name}_config {{
+     static const size_t N_inp = {n_in};
+     static const size_t N_out = {n_out};
+     typedef {inp_t} inp_t;
+     typedef {out_t} out_t;
+     constexpr static auto f = {fn_name}<inp_t, out_t>;
+ }};
+
+ extern "C" {{
+
+ bool openmp_enabled() {{
+     return _openmp;
+ }}
+
+ void inference_f64(double *model_inp, double *model_out, size_t size, size_t n_threads) {{
+     batch_inference<{fn_name}_config, double>(model_inp, model_out, size, n_threads);
+ }}
+
+ void inference_f32(float *model_inp, float *model_out, size_t size, size_t n_threads) {{
+     batch_inference<{fn_name}_config, float>(model_inp, model_out, size, n_threads);
+ }}
+ }}"""
+     return code, bridge
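The kif helpers map a (k, i, f) = (sign, integer, fraction) bit split onto each flavor's fixed-point type: total width k + i + f, integer width (including the sign bit) k + i. A small illustration (a sketch; outputs inferred from the functions above):

from da4ml.codegen.hls.hls_codegen import kif_to_hlslib_type, kif_to_oneapi_type, kif_to_vitis_type

print(kif_to_vitis_type(1, 3, 4))   # ap_fixed<8,4>   : signed, 8-bit word, 4 integer bits
print(kif_to_hlslib_type(1, 3, 4))  # ac_fixed<1,8,4>
print(kif_to_vitis_type(0, 0, 0))   # ap_ufixed<1,0>  : zero-width type padded to 1 bit
print(kif_to_oneapi_type(0, 0, 0))  # ac_fixed<0,2,0> : oneAPI needs at least 2 bits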
da4ml/codegen/hls/hls_model.py
@@ -0,0 +1,255 @@
+ import ctypes
+ import os
+ import re
+ import shutil
+ import subprocess
+ import sys
+ from collections.abc import Sequence
+ from pathlib import Path
+ from typing import TypeVar
+ from uuid import uuid4
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from da4ml.cmvm.types import CombLogic
+ from da4ml.codegen.hls.hls_codegen import get_io_types, hls_logic_and_bridge_gen
+
+ from ... import codegen
+ from ...cmvm.types import _minimal_kif
+
+ T = TypeVar('T', bound=np.floating)
+
+
+ class HLSModel:
+     def __init__(
+         self,
+         solution: CombLogic,
+         prj_name: str,
+         path: str | Path,
+         flavor: str = 'vitis',
+         print_latency: bool = True,
+         part_name: str = 'xcvu13p-flga2577-2-e',
+         pragma: Sequence[str] | None = None,
+         clock_period: int = 5,
+         clock_uncertainty: float = 0.1,
+         io_delay_minmax: tuple[float, float] = (0.2, 0.4),
+     ):
+         self._solution = solution
+         self._prj_name = prj_name
+         self._path = Path(path).resolve()
+         self._flavor = flavor.lower()
+         assert self._flavor in ('vitis', 'hlslib', 'oneapi'), f'Unsupported HLS flavor: {self._flavor}'
+         self._print_latency = print_latency
+         self._part_name = part_name
+         self._clock_period = clock_period
+         self._clock_uncertainty = clock_uncertainty
+         self._io_delay_minmax = io_delay_minmax
+         self.__src_root = Path(codegen.__file__).parent
+         self._lib = None
+         self._uuid = None
+
+         if pragma is None:
+             if self._flavor == 'vitis':
+                 self._pragma = (
+                     '#pragma HLS ARRAY_PARTITION variable=inp complete',
+                     '#pragma HLS ARRAY_PARTITION variable=out complete',
+                     '#pragma HLS PIPELINE II=1',
+                 )
+             else:
+                 self._pragma = ()
+         else:
+             self._pragma = tuple(pragma)
+
+     def write(self):
+         if not self._path.exists():
+             self._path.mkdir(parents=True, exist_ok=True)
+         template_def, bridge = hls_logic_and_bridge_gen(
+             self._solution,
+             self._prj_name,
+             self._flavor,
+             ['#pragma HLS INLINE'],
+             4,
+             0,
+             self._print_latency,
+         )
+
+         headers = ['#pragma once', '#include "bitshift.hh"']
+
+         inp_type, out_type = get_io_types(self._solution, self._flavor)
+         n_in, n_out = len(self._solution.inp_qint), len(self._solution.out_qint)
+         template_signature = (
+             f'template <typename inp_t, typename out_t>\nvoid {self._prj_name}(inp_t inp[{n_in}], out_t out[{n_out}]);'
+         )
+         fn_signature = f'void {self._prj_name}_fn({inp_type} inp[{n_in}], {out_type} out[{n_out}])'
+
+         with open(self._path / f'{self._prj_name}.hh', 'w') as f:
+             f.write('\n'.join(headers) + '\n\n')
+             f.write(f'{template_signature}\n\n{fn_signature};\n')
+
+         pragma_str = '\n'.join(self._pragma)
+         cpp_def = f"""
+ #include "{self._prj_name}.hh"
+
+ {template_def}
+
+ {fn_signature} {{
+ {pragma_str}
+     {self._prj_name}<{inp_type}, {out_type}>(inp, out);
+ }}
+ """
+         with open(self._path / f'{self._prj_name}.cc', 'w') as f:
+             f.write(cpp_def)
+
+         with open(self._path / f'{self._prj_name}_bridge.cc', 'w') as f:
+             f.write(bridge)
+
+         shutil.copy(self.__src_root / 'hls/source/binder_util.hh', self._path)
+         shutil.copy(self.__src_root / f'hls/source/{self._flavor}_bitshift.hh', self._path / 'bitshift.hh')
+         shutil.copy(self.__src_root / 'hls/source/build_binder.mk', self._path)
+         if self._flavor == 'vitis':
+             shutil.copytree(self.__src_root / 'hls/source/ap_types', self._path / 'ap_types', dirs_exist_ok=True)
+         else:
+             pass
+
+         self._solution.save(self._path / 'project.json')
+
+     def _compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
+         """Same as compile, but does not write the project files first.
+
+         Parameters
+         ----------
+         verbose : bool, optional
+             Verbose output, by default False
+         openmp : bool, optional
+             Enable openmp, by default True
+         o3 : bool, optional
+             Turn on the -O3 flag, by default False
+         clean : bool, optional
+             Remove obsolete shared object files, by default True
+
+         Raises
+         ------
+         RuntimeError
+             If compilation fails
+         """
+
+         self._uuid = str(uuid4())
+         args = ['make', '-f', 'build_binder.mk']
+         env = os.environ.copy()
+         env['PRJ_NAME'] = self._prj_name
+         env['STAMP'] = self._uuid
+         env['EXTRA_CXXFLAGS'] = '-fopenmp' if openmp else ''
+         if o3:
+             args.append('fast')
+
+         if clean:
+             m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
+             for p in self._path.iterdir():
+                 if not p.is_dir() and m.match(p.name):
+                     p.unlink()
+
+         try:
+             r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
+         except subprocess.CalledProcessError as e:
+             print(e.stderr.decode(), file=sys.stderr)
+             print(e.stdout.decode(), file=sys.stdout)
+             raise RuntimeError('Compilation failed!!') from e
+         if r.returncode != 0:
+             print(r.stderr.decode(), file=sys.stderr)
+             print(r.stdout.decode(), file=sys.stderr)
+             raise RuntimeError('Compilation failed!!')
+
+         self._load_lib(self._uuid)
+
+     def _load_lib(self, uuid: str | None = None):
+         uuid = uuid if uuid is not None else self._uuid
+         self._uuid = uuid
+         lib_path = self._path / f'lib{self._prj_name}_{uuid}.so'
+         if not lib_path.exists():
+             raise RuntimeError(f'Library {lib_path} does not exist')
+         self._lib = ctypes.CDLL(str(lib_path))
+
+     def compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
+         """Compile the model to a shared object file.
+
+         Parameters
+         ----------
+         verbose : bool, optional
+             Verbose output, by default False
+         openmp : bool, optional
+             Enable openmp, by default True
+         o3 : bool, optional
+             Turn on the -O3 flag, by default False
+         clean : bool, optional
+             Remove obsolete shared object files, by default True
+
+         Raises
+         ------
+         RuntimeError
+             If compilation fails
+         """
+         self.write()
+         self._compile(verbose, openmp, o3, clean)
+
+     def predict(self, data: NDArray[T] | Sequence[NDArray[T]], n_threads: int = 0) -> NDArray[T]:
+         """Run the model on the input data.
+
+         Parameters
+         ----------
+         data : NDArray[np.floating] | Sequence[NDArray[np.floating]]
+             Input data to the model. The shape is ignored, and the number of samples is
+             determined by the size of the data.
+         n_threads : int, optional
+             Number of threads to run the inference with, by default 0
+
+         Returns
+         -------
+         NDArray[np.floating]
+             Output of the model in shape (n_samples, output_size).
+         """
+         assert self._lib is not None, 'Library not loaded, call .compile() first.'
+         inp_size, out_size = self._solution.shape
+
+         if isinstance(data, Sequence):
+             data = np.concatenate([a.reshape(a.shape[0], -1) for a in data], axis=-1)
+
+         dtype = data.dtype
+         if dtype not in (np.float32, np.float64):
+             raise TypeError(f'Unsupported input data type: {dtype}. Expected float32 or float64.')
+         c_dtype = ctypes.c_float if dtype == np.float32 else ctypes.c_double
+
+         assert data.size % inp_size == 0, f'Input size {data.size} is not divisible by {inp_size}'
+         n_sample = data.size // inp_size
+
+         inp_data = np.ascontiguousarray(data)
+         out_data = np.empty(n_sample * out_size, dtype=dtype)
+
+         inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(c_dtype))
+         out_buf = out_data.ctypes.data_as(ctypes.POINTER(c_dtype))
+         if dtype == np.float32:
+             self._lib.inference_f32(inp_buf, out_buf, n_sample, n_threads)
+         else:
+             self._lib.inference_f64(inp_buf, out_buf, n_sample, n_threads)
+
+         return out_data.reshape(n_sample, out_size)  # type: ignore
+
+     def __repr__(self):
+         inp_size, out_size = self._solution.shape
+         cost = round(self._solution.cost)
+         inp_kifs = tuple(zip(*map(_minimal_kif, self._solution.inp_qint)))
+         out_kifs = tuple(zip(*map(_minimal_kif, self._solution.out_qint)))
+         in_bits, out_bits = np.sum(inp_kifs), np.sum(out_kifs)
+
+         spec = f"""Top Function: {self._prj_name}\n====================
+ {inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)
+ combinational @ delay={self._solution.latency}
+ Estimated cost: {cost} LUTs"""
+
+         is_compiled = self._lib is not None
+         if is_compiled:
+             assert self._uuid is not None
+             openmp = 'with OpenMP' if self._lib.openmp_enabled() else ''  # type: ignore
+             spec += f'\nEmulator is compiled {openmp} ({self._uuid[-12:]})'
+         else:
+             spec += '\nEmulator is **not compiled**'
+         return spec
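Typical use of HLSModel follows write -> compile -> predict. A hedged end-to-end sketch: `solution` stands for a CombLogic produced by da4ml's tracing/CMVM front end (not shown in this diff), and the project name, path, and input shape are placeholders:

import numpy as np
from da4ml.codegen import HLSModel

model = HLSModel(solution, prj_name='mlp', path='build/mlp', flavor='vitis')
model.compile()  # write() + make -f build_binder.mk -> lib<prj>_<uuid>.so

x = np.random.rand(128, 16).astype(np.float32)  # 16 must match the solution's input size
y = model.predict(x, n_threads=4)               # shape (128, out_size)
print(model)                                    # cost/latency summary and compile state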