da4ml 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

Files changed (59) hide show
  1. da4ml/_version.py +2 -2
  2. da4ml/cmvm/api.py +2 -6
  3. da4ml/cmvm/core/__init__.py +0 -1
  4. da4ml/cmvm/types.py +99 -19
  5. da4ml/codegen/__init__.py +5 -4
  6. da4ml/codegen/cpp/__init__.py +2 -1
  7. da4ml/codegen/cpp/cpp_codegen.py +58 -25
  8. da4ml/codegen/cpp/hls_model.py +252 -0
  9. da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
  10. da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
  11. da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
  12. da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
  13. da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
  14. da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
  15. da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
  16. da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
  17. da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
  18. da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
  19. da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
  20. da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
  21. da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
  22. da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
  23. da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
  24. da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
  25. da4ml/codegen/cpp/source/binder_util.hh +56 -0
  26. da4ml/codegen/cpp/source/build_binder.mk +24 -0
  27. da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
  28. da4ml/codegen/verilog/__init__.py +2 -3
  29. da4ml/codegen/verilog/comb.py +65 -24
  30. da4ml/codegen/verilog/io_wrapper.py +36 -141
  31. da4ml/codegen/verilog/pipeline.py +21 -3
  32. da4ml/codegen/verilog/source/binder_util.hh +72 -0
  33. da4ml/codegen/verilog/source/build_prj.tcl +0 -1
  34. da4ml/codegen/verilog/source/mux.v +58 -0
  35. da4ml/codegen/verilog/source/negative.v +28 -0
  36. da4ml/codegen/verilog/source/shift_adder.v +4 -1
  37. da4ml/codegen/verilog/source/template.xdc +3 -0
  38. da4ml/codegen/verilog/verilog_model.py +42 -15
  39. da4ml/converter/__init__.py +0 -0
  40. da4ml/converter/hgq2/parser.py +105 -0
  41. da4ml/converter/hgq2/replica.py +383 -0
  42. da4ml/trace/__init__.py +2 -2
  43. da4ml/trace/fixed_variable.py +177 -18
  44. da4ml/trace/fixed_variable_array.py +124 -9
  45. da4ml/trace/ops/__init__.py +22 -6
  46. da4ml/trace/ops/conv_utils.py +146 -14
  47. da4ml/trace/ops/einsum_utils.py +9 -6
  48. da4ml/trace/ops/reduce_utils.py +103 -0
  49. da4ml/trace/pipeline.py +36 -34
  50. da4ml/trace/tracer.py +37 -5
  51. da4ml-0.3.0.dist-info/METADATA +107 -0
  52. da4ml-0.3.0.dist-info/RECORD +64 -0
  53. da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
  54. da4ml-0.2.0.dist-info/METADATA +0 -65
  55. da4ml-0.2.0.dist-info/RECORD +0 -39
  56. /da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
  57. {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/WHEEL +0 -0
  58. {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/licenses/LICENSE +0 -0
  59. {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,15 @@
1
+ import typing
1
2
  from collections.abc import Sequence
3
+ from math import ceil, prod
2
4
  from typing import TypeVar
3
5
 
4
6
  import numpy as np
5
7
  from numpy.typing import NDArray
6
8
 
7
- from ..fixed_variable_array import FixedVariableArray
9
+ from .reduce_utils import reduce
10
+
11
+ if typing.TYPE_CHECKING:
12
+ from ..fixed_variable_array import FixedVariableArray
8
13
 
9
14
 
10
15
  def r_im2col(kernel_size: Sequence[int], arr: np.ndarray, buffer: np.ndarray, axis: int):
@@ -33,23 +38,23 @@ def stride_arr(stride: int | tuple[int, ...], arr: np.ndarray):
33
38
  ndim = arr.ndim
34
39
  if isinstance(stride, int):
35
40
  stride = (stride,) * (ndim - 1)
36
- assert len(stride) == ndim - 1, f'Invalid stride {stride} for array with {ndim} dimensions'
37
41
 
38
42
  _idx = tuple(slice(None, None, st) for st in stride)
39
43
  return arr[*_idx]
40
44
 
41
45
 
42
- T = TypeVar('T', FixedVariableArray, NDArray[np.integer | np.floating])
46
+ TA = TypeVar('TA', 'FixedVariableArray', NDArray[np.integer | np.floating])
43
47
 
44
48
 
45
- def conv(
46
- x: T,
49
+ def _conv(
50
+ x: TA,
47
51
  kernel: NDArray[np.integer | np.floating],
48
52
  bias: NDArray[np.integer | np.floating] | None = None,
49
53
  strides: int | tuple[int, ...] = 1,
50
54
  padding: tuple[tuple[int, int], ...] | str = 'VALID',
51
- format: str = 'channels_last',
52
- ):
55
+ ) -> TA:
56
+ from ..fixed_variable_array import FixedVariableArray
57
+
53
58
  if isinstance(x, FixedVariableArray):
54
59
  solver_options = x.solver_options
55
60
  data = x._vars
@@ -63,10 +68,10 @@ def conv(
63
68
  ch_in, ch_out = kernel.shape[-2:]
64
69
  _ch_in = data.shape[-1]
65
70
  assert ch_in == _ch_in, f'Invalid input shape {data.shape} for kernel {kernel.shape}'
66
- assert kernel.ndim == ndim + 1
67
-
68
- assert format in ('channels_last', 'channels_first'), f'Invalid format {format}'
69
-
71
+ if kernel.ndim != ndim + 1:
72
+ if kernel.ndim == ndim:
73
+ raise ValueError('Inputs should not contain batch dimension')
74
+ raise ValueError(f'Invalid kernel shape {kernel.shape} for input with {ndim} dimensions')
70
75
  if isinstance(strides, int):
71
76
  strides = (strides,) * (ndim - 1)
72
77
  assert len(strides) == ndim - 1, f'Invalid stride {strides} for array with {ndim} dimensions'
@@ -89,16 +94,143 @@ def conv(
89
94
 
90
95
  data = np.pad(data, padding + ((0, 0),), mode='constant', constant_values=0.0)
91
96
  data = _im2col(kernel.shape, data)
97
+ data = stride_arr(strides, data)
92
98
  if is_symbolic:
93
99
  _data = FixedVariableArray(data, solver_options) @ kernel.reshape(-1, ch_out)
94
100
  data = _data._vars
95
101
  else:
96
102
  data = data @ kernel.reshape(-1, ch_out)
97
- data = stride_arr(strides, data)
98
103
  if bias is not None:
99
104
  data = data + bias
105
+ if isinstance(x, FixedVariableArray):
106
+ return FixedVariableArray(data, solver_options)
107
+ return data
108
+
109
+
110
+ def conv(
111
+ x: TA,
112
+ kernel: NDArray[np.integer | np.floating],
113
+ bias: NDArray[np.integer | np.floating] | None = None,
114
+ strides: int | tuple[int, ...] = 1,
115
+ padding: tuple[tuple[int, int], ...] | str = 'VALID',
116
+ format: str = 'channels_last',
117
+ groups: int | None = None,
118
+ ) -> TA:
119
+ from ..fixed_variable_array import FixedVariableArray
120
+
121
+ assert format in ('channels_last', 'channels_first'), f'Invalid format {format}'
122
+ if format == 'channels_first':
123
+ x = np.moveaxis(x, 0, -1) # type: ignore
124
+
125
+ *_, _ch_in, ch_out = kernel.shape
126
+ ch_in = x.shape[-1]
127
+ assert ch_in % _ch_in == 0, f'groups is not integer (total_ch_in={ch_in}, kernel_ch_in={_ch_in})'
128
+ if groups is None:
129
+ groups = ch_in // _ch_in
130
+ else:
131
+ assert (
132
+ groups == ch_in // _ch_in
133
+ ), f'groups {groups} does not match input channels {ch_in} and kernel input channels {_ch_in}'
134
+ assert ch_out % groups == 0, f'groups is not integer (total_ch_out={ch_out}, groups={groups})'
135
+ _ch_out = ch_out // groups
136
+
137
+ buf: list[TA] = []
138
+ for gp in range(groups):
139
+ _kernel = kernel[..., gp * _ch_out : (gp + 1) * _ch_out]
140
+ _x = x[..., gp * _ch_in : (gp + 1) * _ch_in]
141
+ _buf = _conv(
142
+ _x,
143
+ _kernel,
144
+ strides=strides,
145
+ padding=padding,
146
+ )
147
+ buf.append(_buf) # type: ignore
148
+
149
+ if isinstance(x, FixedVariableArray):
150
+ data = np.concatenate([b._vars for b in buf], axis=-1) # type: ignore
151
+ else:
152
+ data = np.concatenate(buf, axis=-1) # type: ignore
153
+
154
+ data = data + bias if bias is not None else data
155
+
100
156
  if format == 'channels_first':
101
- data = np.moveaxis(data, -1, 1)
102
- if solver_options is not None:
157
+ return np.moveaxis(data, -1, 0) # type: ignore
158
+
159
+ if isinstance(x, FixedVariableArray):
160
+ return FixedVariableArray(data, x.solver_options)
161
+ return data
162
+
163
+
164
+ def pool(
165
+ x: TA,
166
+ pool_size: Sequence[int],
167
+ strides: int | Sequence[int] | None = None,
168
+ padding: tuple[tuple[int, int], ...] | str = 'VALID',
169
+ pool_type: str = 'avg',
170
+ format: str = 'channels_last',
171
+ ) -> TA:
172
+ from ..fixed_variable import FixedVariable
173
+ from ..fixed_variable_array import FixedVariableArray
174
+
175
+ if isinstance(x, FixedVariableArray):
176
+ solver_options = x.solver_options
177
+ data = x._vars
178
+ else:
179
+ solver_options = None
180
+ data = x
181
+
182
+ if format == 'channels_first':
183
+ data = np.moveaxis(data, 0, -1)
184
+
185
+ strides = strides or pool_size
186
+
187
+ assert pool_type in ('avg', 'max'), f'Invalid pool type {pool_type}'
188
+ ndim = data.ndim
189
+ if isinstance(strides, int):
190
+ strides = (strides,) * (ndim - 1)
191
+ assert len(strides) == ndim - 1, f'Invalid stride {strides} for array with {ndim} dimensions'
192
+
193
+ if isinstance(padding, str):
194
+ padding = padding.upper()
195
+ if padding == 'VALID':
196
+ padding = ((0, 0),) * (ndim - 1)
197
+ elif padding == 'SAME':
198
+ _padding = []
199
+ for i in range(ndim - 1):
200
+ n_pad = ceil(data.shape[i] / strides[i]) * strides[i] + (pool_size[i] - strides[i]) - data.shape[i]
201
+ pad0 = n_pad // 2
202
+ pad1 = n_pad - pad0
203
+ _padding.append((pad0, pad1))
204
+ padding = tuple(_padding)
205
+ else:
206
+ raise ValueError(f'Invalid padding {padding}')
207
+ assert len(padding) == ndim - 1, f'Invalid padding {padding} for array with {ndim} dimensions'
208
+ assert all(len(p) == 2 for p in padding), f'Invalid padding {padding} for array with {ndim} dimensions'
209
+
210
+ data = np.pad(data, padding + ((0, 0),), mode='constant', constant_values=-np.inf)
211
+ ch_in = data.shape[-1]
212
+ fake_kernel_shape = tuple(pool_size) + (ch_in, ch_in)
213
+ data = _im2col(fake_kernel_shape, data)
214
+ data = data.reshape(*data.shape[:-1], prod(pool_size), ch_in)
215
+ data = stride_arr(tuple(strides), data)
216
+ if pool_type == 'avg':
217
+ div = np.sum(data != -np.inf, axis=-2)
218
+ data = np.where(data == -np.inf, 0, data)
219
+ data = reduce(lambda x, y: x + y, data, axis=-2) * (1 / div)
220
+ else:
221
+
222
+ def max_of(a, b):
223
+ if isinstance(a, FixedVariable):
224
+ return a.max_of(b)
225
+ if isinstance(b, FixedVariable):
226
+ return b.max_of(a)
227
+ return max(a, b)
228
+
229
+ data = reduce(lambda x, y: max_of(x, y), data, axis=-2)
230
+
231
+ if format == 'channels_first':
232
+ data = np.moveaxis(data, -1, 0)
233
+
234
+ if isinstance(x, FixedVariableArray):
103
235
  return FixedVariableArray(data, solver_options)
104
236
  return data
@@ -1,10 +1,11 @@
1
1
  from math import prod
2
- from typing import TypedDict, overload
2
+ from typing import TYPE_CHECKING, TypedDict, overload
3
3
 
4
4
  import numpy as np
5
5
  from numpy.typing import NDArray
6
6
 
7
- from ..fixed_variable_array import FixedVariableArray
7
+ if TYPE_CHECKING:
8
+ from ..fixed_variable_array import FixedVariableArray
8
9
 
9
10
 
10
11
  class EinsumRecipe(TypedDict):
@@ -105,7 +106,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, .
105
106
  # Axes expansion in input0 or input1 only
106
107
  if '0' in sax_in0:
107
108
  if len(sax_in0) - 1 > len(shape0):
108
- raise ValueError(f'Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given')
109
+ raise ValueError(f'Input0 requires at least {len(sax_in0) - 1} dimensions, but only {len(shape0)} given')
109
110
  # Replace auto expansion indices with free indices
110
111
  n_broadcast = len(shape0) - len(sax_in0) + 1
111
112
  in0 = in0.replace('0', free_indices[:n_broadcast])
@@ -118,7 +119,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, .
118
119
 
119
120
  if '0' in sax_in1:
120
121
  if len(sax_in1) - 1 > len(shape1):
121
- raise ValueError(f'Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given')
122
+ raise ValueError(f'Input1 requires at least {len(sax_in1) - 1} dimensions, but only {len(shape1)} given')
122
123
  # Replace expansion indices with free indices
123
124
  n_broadcast = len(shape1) - len(sax_in1) + 1
124
125
  in1 = in1.replace('0', free_indices[:n_broadcast])
@@ -271,11 +272,11 @@ def _einsum(fn: str, input0, input1) -> np.ndarray:
271
272
 
272
273
 
273
274
  @overload
274
- def einsum(fn: str, input0: FixedVariableArray, input1: NDArray[np.integer | np.floating]) -> FixedVariableArray: ...
275
+ def einsum(fn: str, input0: 'FixedVariableArray', input1: NDArray[np.integer | np.floating]) -> 'FixedVariableArray': ...
275
276
 
276
277
 
277
278
  @overload
278
- def einsum(fn: str, input0: NDArray[np.integer | np.floating], input1: FixedVariableArray) -> FixedVariableArray: ...
279
+ def einsum(fn: str, input0: NDArray[np.integer | np.floating], input1: 'FixedVariableArray') -> 'FixedVariableArray': ...
279
280
 
280
281
 
281
282
  @overload
@@ -285,6 +286,8 @@ def einsum(
285
286
 
286
287
 
287
288
  def einsum(fn: str, input0, input1):
289
+ from ..fixed_variable_array import FixedVariableArray
290
+
288
291
  fg0 = isinstance(input0, FixedVariableArray)
289
292
  fg1 = isinstance(input1, FixedVariableArray)
290
293
  if fg0 and fg1:
@@ -0,0 +1,103 @@
1
+ import heapq
2
+ import typing
3
+ from collections.abc import Callable, Sequence
4
+ from math import prod
5
+ from typing import TypeVar
6
+
7
+ import numpy as np
8
+ from numpy.typing import NDArray
9
+
10
+ if typing.TYPE_CHECKING:
11
+ from ..fixed_variable import FixedVariable
12
+ from ..fixed_variable_array import FixedVariableArray
13
+
14
+
15
+ T = typing.TypeVar('T', 'FixedVariable', float, np.floating)
16
+ TA = TypeVar('TA', 'FixedVariableArray', NDArray[np.integer | np.floating])
17
+
18
+
19
+ class Packet:
20
+ def __init__(self, v):
21
+ self.value = v
22
+
23
+ def __gt__(self, other: 'Packet') -> bool: # type: ignore
24
+ from ..fixed_variable_array import FixedVariable
25
+
26
+ a, b = self.value, other.value
27
+
28
+ if isinstance(a, FixedVariable):
29
+ if isinstance(b, FixedVariable):
30
+ if b.latency > a.latency:
31
+ return False
32
+ if b.latency < a.latency:
33
+ return True
34
+ if b._factor > 0 and a._factor < 0:
35
+ return False
36
+ if b._factor < 0 and a._factor > 0:
37
+ return True
38
+ return sum(a.kif[:2]) > sum(b.kif[:2])
39
+ return True
40
+
41
+ return False
42
+
43
+ def __lt__(self, other: 'Packet') -> bool: # type: ignore
44
+ return not self.__gt__(other)
45
+
46
+
47
+ def _reduce(operator: Callable[[T, T], T], arr: Sequence[T]) -> T:
48
+ from ..fixed_variable_array import FixedVariable
49
+
50
+ if isinstance(arr, np.ndarray):
51
+ arr = list(arr.ravel())
52
+ assert len(arr) > 0, 'Array must not be empty'
53
+ if len(arr) == 1:
54
+ return arr[0]
55
+ dtype = arr[0].__class__
56
+ if not issubclass(dtype, FixedVariable):
57
+ r = operator(arr[0], arr[1])
58
+ for i in range(2, len(arr)):
59
+ r = operator(r, arr[i])
60
+ return r
61
+
62
+ heap = [Packet(v) for v in arr] # type: ignore
63
+ heapq.heapify(heap)
64
+ while len(heap) > 1:
65
+ v1 = heapq.heappop(heap).value
66
+ v2 = heapq.heappop(heap).value
67
+ v = operator(v1, v2)
68
+ heapq.heappush(heap, Packet(v)) # type: ignore
69
+ return heap[0].value
70
+
71
+
72
+ def reduce(operator: Callable[[T, T], T], x: TA, axis: int | Sequence[int] | None = None, keepdims: bool = False) -> TA:
73
+ """
74
+ Reduce the array by summing over the specified axis.
75
+ """
76
+ from ..fixed_variable_array import FixedVariableArray
77
+
78
+ if isinstance(x, FixedVariableArray):
79
+ solver_config = x.solver_options
80
+ arr = x._vars
81
+ else:
82
+ solver_config = None
83
+ arr = x
84
+ all_axis = tuple(range(arr.ndim))
85
+ axis = axis if axis is not None else all_axis
86
+ axis = (axis,) if isinstance(axis, int) else tuple(axis)
87
+ axis = tuple(a if a >= 0 else a + arr.ndim for a in axis)
88
+
89
+ xpose_axis = sorted(all_axis, key=lambda a: (a in axis) * 1000 + a)
90
+ if keepdims:
91
+ target_shape = tuple(d if ax not in axis else 1 for ax, d in enumerate(arr.shape))
92
+ else:
93
+ target_shape = tuple(d for ax, d in enumerate(arr.shape) if ax not in axis)
94
+
95
+ dim_contract = prod(arr.shape[a] for a in axis)
96
+ arr = np.transpose(arr, xpose_axis) # type: ignore
97
+ _arr = arr.reshape(-1, dim_contract)
98
+ _arr = np.array([_reduce(operator, _arr[i]) for i in range(_arr.shape[0])])
99
+ r = _arr.reshape(target_shape) # type: ignore
100
+
101
+ if isinstance(x, FixedVariableArray):
102
+ return FixedVariableArray(r, solver_config)
103
+ return r
da4ml/trace/pipeline.py CHANGED
@@ -31,6 +31,35 @@ def retime_pipeline(csol: CascadedSolution, verbose=True):
31
31
  return best
32
32
 
33
33
 
34
+ def _get_new_idx(
35
+ idx: int,
36
+ locator: list[dict[int, int]],
37
+ opd: dict[int, list[Op]],
38
+ out_idxd: dict[int, list[int]],
39
+ ops: list[Op],
40
+ stage: int,
41
+ latency_cutoff: int,
42
+ ):
43
+ if idx < 0:
44
+ return idx
45
+ p0_stages = locator[idx].keys()
46
+ if stage not in p0_stages:
47
+ # Need to copy parent to later states
48
+ p0_stage = max(p0_stages)
49
+ p0_idx = locator[idx][p0_stage]
50
+ for j in range(p0_stage, stage):
51
+ op0 = ops[idx]
52
+ latency = float(latency_cutoff * (j + 1))
53
+ out_idxd.setdefault(j, []).append(locator[idx][j])
54
+ _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op0.qint, latency, 0.0)
55
+ opd.setdefault(j + 1, []).append(_copy_op)
56
+ p0_idx = len(opd[j + 1]) - 1
57
+ locator[idx][j + 1] = p0_idx
58
+ else:
59
+ p0_idx = locator[idx][stage]
60
+ return p0_idx
61
+
62
+
34
63
  def to_pipeline(sol: Solution, latency_cutoff: int, retiming=True, verbose=True) -> CascadedSolution:
35
64
  """Split the record into multiple stages based on the latency of the operations.
36
65
  Only useful for HDL generation.
@@ -80,46 +109,19 @@ def to_pipeline(sol: Solution, latency_cutoff: int, retiming=True, verbose=True)
80
109
  opd.setdefault(stage, []).append(op)
81
110
  locator.append({stage: len(opd[stage]) - 1})
82
111
  continue
83
- p0_stages = locator[op.id0].keys()
84
- if stage not in p0_stages:
85
- # Need to copy parent to later states
86
- p0_stage = max(p0_stages)
87
- p0_idx = locator[op.id0][p0_stage]
88
- for j in range(p0_stage, stage):
89
- op0 = ops[op.id0]
90
- latency = float(latency_cutoff * (j + 1))
91
- out_idxd.setdefault(j, []).append(locator[op.id0][j])
92
- _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op0.qint, latency, 0.0)
93
- opd.setdefault(j + 1, []).append(_copy_op)
94
- p0_idx = len(opd[j + 1]) - 1
95
- locator[op.id0][j + 1] = p0_idx
96
- else:
97
- p0_idx = locator[op.id0][stage]
98
-
99
- if op.opcode in (0, 1):
100
- p1_stages = locator[op.id1].keys()
101
- if stage not in p1_stages:
102
- # Need to copy parent to later states
103
- p1_stage = max(p1_stages)
104
- p1_idx = locator[op.id1][p1_stage]
105
- for j in range(p1_stage, stage):
106
- op1 = ops[op.id1]
107
- latency = float(latency_cutoff * (j + 1))
108
- out_idxd.setdefault(j, []).append(locator[op.id1][j])
109
- _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op1.qint, latency, 0.0)
110
- opd.setdefault(j + 1, []).append(_copy_op)
111
- p1_idx = len(opd[j + 1]) - 1
112
- locator[op.id1][j + 1] = p1_idx
113
- else:
114
- p1_idx = locator[op.id1][stage]
112
+
113
+ p0_idx = _get_new_idx(op.id0, locator, opd, out_idxd, ops, stage, latency_cutoff)
114
+ p1_idx = _get_new_idx(op.id1, locator, opd, out_idxd, ops, stage, latency_cutoff)
115
+ if op.opcode in (6, -6):
116
+ data = _get_new_idx(op.data, locator, opd, out_idxd, ops, stage, latency_cutoff)
115
117
  else:
116
- p1_idx = op.id1
118
+ data = op.data
117
119
 
118
120
  if p1_idx == -1001:
119
121
  # Output to external buffer
120
122
  out_idxd.setdefault(stage, []).append(p0_idx)
121
123
  else:
122
- _Op = Op(p0_idx, p1_idx, op.opcode, op.data, op.qint, op.latency, op.cost)
124
+ _Op = Op(p0_idx, p1_idx, op.opcode, data, op.qint, op.latency, op.cost)
123
125
  opd.setdefault(stage, []).append(_Op)
124
126
  locator.append({stage: len(opd[stage]) - 1})
125
127
  sols = []
da4ml/trace/tracer.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from collections.abc import Sequence
2
2
  from decimal import Decimal
3
+ from itertools import chain
3
4
  from math import log2
4
5
  from typing import overload
5
6
  from uuid import UUID
@@ -11,20 +12,20 @@ from .fixed_variable import FixedVariable, _const_f
11
12
  from .fixed_variable_array import FixedVariableArray
12
13
 
13
14
 
14
- def _recursive_trace(v: FixedVariable, gathered: dict[UUID, FixedVariable]):
15
- if v in gathered:
15
+ def _recursive_gather(v: FixedVariable, gathered: dict[UUID, FixedVariable]):
16
+ if v.id in gathered:
16
17
  return
17
18
  assert v._from is not None
18
19
  for _v in v._from:
19
20
  if _v.id not in gathered:
20
- _recursive_trace(_v, gathered)
21
+ _recursive_gather(_v, gathered)
21
22
  gathered[v.id] = v
22
23
 
23
24
 
24
25
  def gather_variables(inputs: Sequence[FixedVariable], outputs: Sequence[FixedVariable]):
25
26
  gathered = {v.id: v for v in inputs}
26
27
  for o in outputs:
27
- _recursive_trace(o, gathered)
28
+ _recursive_gather(o, gathered)
28
29
 
29
30
  variables = list(gathered.values())
30
31
 
@@ -85,6 +86,19 @@ def _comb_trace(inputs: Sequence[FixedVariable], outputs: Sequence[FixedVariable
85
86
  qint = QInterval(qint.min, qint.min, step)
86
87
  data = qint.min / step
87
88
  ops.append(Op(-1, -1, 5, int(data), qint, v.latency, v.cost))
89
+ case 'msb_mux':
90
+ qint = v.unscaled.qint
91
+ key, in0, in1 = v._from
92
+ opcode = 6 if in1._factor > 0 else -6
93
+ idk, id0, id1 = index[key.id], index[in0.id], index[in1.id]
94
+ f0, f1 = in0._factor, in1._factor
95
+ shift = int(log2(abs(f1 / f0)))
96
+ data = idk + (shift << 32)
97
+ assert idk < i and id0 < i and id1 < i
98
+ assert key._factor > 0, f'Cannot mux on v{key.id} with negative factor {key._factor}'
99
+ op = Op(id0, id1, opcode, data, qint, v.latency, v.cost)
100
+ ops.append(op)
101
+
88
102
  case _:
89
103
  raise NotImplementedError(f'Operation "{v.opr}" is not supported in tracing')
90
104
  out_index = [index[v.id] for v in outputs]
@@ -101,6 +115,15 @@ def comb_trace(inputs: FixedVariableArray, outputs: FixedVariableArray) -> Solut
101
115
 
102
116
  def comb_trace(inputs, outputs):
103
117
  inputs, outputs = list(np.ravel(inputs)), list(np.ravel(outputs))
118
+
119
+ if any(not isinstance(v, FixedVariable) for v in outputs):
120
+ hwconf = inputs[0].hwconf
121
+ latency = max(v.latency for v in chain(inputs, outputs) if isinstance(v, FixedVariable))
122
+ outputs = list(outputs)
123
+ for i, v in enumerate(outputs):
124
+ if not isinstance(v, FixedVariable):
125
+ outputs[i] = FixedVariable.from_const(v, hwconf, latency, 1)
126
+
104
127
  ops, out_index = _comb_trace(inputs, outputs)
105
128
  shape = len(inputs), len(outputs)
106
129
  inp_shift = [0] * shape[0]
@@ -108,7 +131,7 @@ def comb_trace(inputs, outputs):
108
131
  out_shift = [int(log2(abs(sf))) for sf in out_sf]
109
132
  out_neg = [sf < 0 for sf in out_sf]
110
133
 
111
- return Solution(
134
+ sol = Solution(
112
135
  shape,
113
136
  inp_shift,
114
137
  out_index,
@@ -118,3 +141,12 @@ def comb_trace(inputs, outputs):
118
141
  outputs[0].hwconf.carry_size,
119
142
  outputs[0].hwconf.adder_size,
120
143
  )
144
+
145
+ ref_count = sol.ref_count
146
+
147
+ for i in range(len(ops)):
148
+ if ref_count[i] == 0:
149
+ op = ops[i]
150
+ sol.ops[i] = Op(-1, -1, op[2], 0, QInterval(0, 0, 1), op[5], op[6])
151
+
152
+ return sol
@@ -0,0 +1,107 @@
1
+ Metadata-Version: 2.4
2
+ Name: da4ml
3
+ Version: 0.3.0
4
+ Summary: Digital Arithmetic for Machine Learning
5
+ Author-email: Chang Sun <chsun@cern.ch>
6
+ License: GNU Lesser General Public License v3 (LGPLv3)
7
+ Project-URL: repository, https://github.com/calad0i/da4ml
8
+ Keywords: CMVM,distributed arithmetic,hls4ml,MCM,subexpression elimination
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: Programming Language :: Python :: 3 :: Only
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Requires-Python: >=3.10
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE
20
+ Requires-Dist: llvmlite>=0.44
21
+ Requires-Dist: numba>=0.61
22
+ Dynamic: license-file
23
+
24
+ # da4ml: Distributed Arithmetic for Machine Learning
25
+
26
+ This project performs Constant Matrix-Vector Multiplication (CMVM) with Distributed Arithmetic (DA) for Machine Learning (ML) on Field Programmable Gate Arrays (FPGAs).
27
+
28
+ CMVM optimization is done through greedy CSE of two-term subexpressions, with possible Delay Constraints (DC). The optimization is done in jitted Python (Numba), and a list of optimized operations is generated as traced Python code.
29
+
30
+ The project generates Verilog or Vitis HLS code for the optimized CMVM operations. This project can be used in conjunction with [`hls4ml`](https://github.com/fastmachinelearning/hls4ml/) for optimizing the neural networks deployed on FPGAs. For a subset of neural networks, the full design can be generated standalone in Verilog or Vitis HLS.
31
+
32
+
33
+ ## Installation
34
+
35
+ The project is available on PyPI and can be installed with pip:
36
+
37
+ ```bash
38
+ pip install da4ml
39
+ ```
40
+
41
+ Note that `numba>=0.61` is required for the project to work. The project does not work with `python<3.10`. If the project fails to compile, try upgrading `numba` and `llvmlite` to the latest versions.
42
+
43
+ ## `hls4ml`
44
+
45
+ The major use of this project is through the `distributed_arithmetic` strategy in `hls4ml`:
46
+
47
+ ```python
48
+ model_hls = hls4ml.converters.convert_from_keras_model(
49
+ model,
50
+ hls_config={
51
+ 'Model': {
52
+ ...
53
+ 'Strategy': 'distributed_arithmetic',
54
+ },
55
+ ...
56
+ },
57
+ ...
58
+ )
59
+ ```
60
+
61
+ Currently, `Dense/Conv1D/Conv2D` layers are supported for both `io_parallel` and `io_stream` dataflows. However, notice that distributed arithmetic implies `reuse_factor=1`, as the whole kernel is implemented in combinational logic.
62
+
63
+ ## Standalone usage
64
+
65
+ ### `HGQ2`
66
+
67
+ For some models trained with `HGQ2`, `da4ml` can be used to generate the whole model in Verilog or Vitis HLS:
68
+
69
+ ```python
70
+ from da4ml.codegen import HLSModel, VerilogModel
71
+ from da4ml.converter.hgq2.parser import trace_model
72
+ from da4ml.trace import comb_trace
73
+
74
+ inp, out = trace_model(hgq2_model)
75
+ sol = comb_trace(inp[0], out[0]) # Currently, only models with 1 input and 1 output are supported
76
+
77
+ # Pipelined Verilog model generation
78
+ # `latency_cutoff` is used to control automatic pipelining behavior. To disable pipelining, set it to -1.
79
+ verilog_model = VerilogModel(sol, prj_name='barbar', path='/tmp/barbar', latency_cutoff=5)
80
+ verilog_model.compile() # write and verilator binding
81
+ verilog_model.predict(inputs)
82
+
83
+ vitis_hls_model = HLSModel(sol, prj_name='foo', path='/tmp/foo', flavor='vitis') # Only vitis is supported for now
84
+ vitis_hls_model.compile() # write and hls binding
85
+ vitis_hls_model.predict(inputs)
86
+ ```
87
+
88
+ ### Functional Definition
89
+ For generic operations, one can define a combinational logic with the functional API:
90
+
91
+ ```python
92
+ from da4ml.trace import FixedVariableArray, HWConfig, comb_trace
93
+ from da4ml.trace.ops import einsum, relu, quantize, conv, pool
94
+
95
+ # k, i, f are numpy arrays of integers: keep_negative (0/1), integer bits (excl. sign), fractional bits
96
+ inp = FixedVariableArray.from_kif(k, i, f, HWConfig(1, -1, -1), solver_options={'hard_dc':2})
97
+ out = inp @ kernel
98
+ out = relu(out)
99
+ out = einsum(equation, out, weights)
100
+ ...
101
+
102
+ comb = comb_trace(inp, out)
103
+ ```
104
+
105
+ `+`, `-`, `@` are supported as well as `einsum`, `relu`, `quantize` (WRAP, with TRN or RND), `conv`, `pool` (average only). For multiplications, only power-of-two multipliers are supported, otherwise use `einsum` or `@` operators.
106
+
107
+ `comb_trace` returns a `Solution` object that contains a list of low-level operations used to implement the combinational logic, which in turn can be used to generate Verilog or Vitis HLS code.