da4ml 0.2.1__py3-none-any.whl → 0.3.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (55)
  1. da4ml/_version.py +2 -2
  2. da4ml/cmvm/types.py +95 -15
  3. da4ml/codegen/__init__.py +5 -4
  4. da4ml/codegen/cpp/__init__.py +2 -1
  5. da4ml/codegen/cpp/cpp_codegen.py +56 -23
  6. da4ml/codegen/cpp/hls_model.py +252 -0
  7. da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
  8. da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
  9. da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
  10. da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
  11. da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
  12. da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
  13. da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
  14. da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
  15. da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
  16. da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
  17. da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
  18. da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
  19. da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
  20. da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
  21. da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
  22. da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
  23. da4ml/codegen/cpp/source/binder_util.hh +56 -0
  24. da4ml/codegen/cpp/source/build_binder.mk +24 -0
  25. da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
  26. da4ml/codegen/verilog/__init__.py +2 -3
  27. da4ml/codegen/verilog/comb.py +65 -24
  28. da4ml/codegen/verilog/io_wrapper.py +36 -141
  29. da4ml/codegen/verilog/source/binder_util.hh +72 -0
  30. da4ml/codegen/verilog/source/mux.v +58 -0
  31. da4ml/codegen/verilog/source/negative.v +28 -0
  32. da4ml/codegen/verilog/source/shift_adder.v +4 -1
  33. da4ml/codegen/verilog/source/template.xdc +3 -0
  34. da4ml/codegen/verilog/verilog_model.py +36 -12
  35. da4ml/converter/__init__.py +0 -0
  36. da4ml/converter/hgq2/parser.py +105 -0
  37. da4ml/converter/hgq2/replica.py +383 -0
  38. da4ml/trace/__init__.py +2 -2
  39. da4ml/trace/fixed_variable.py +175 -16
  40. da4ml/trace/fixed_variable_array.py +109 -4
  41. da4ml/trace/ops/__init__.py +22 -6
  42. da4ml/trace/ops/conv_utils.py +147 -15
  43. da4ml/trace/ops/einsum_utils.py +9 -6
  44. da4ml/trace/ops/reduce_utils.py +103 -0
  45. da4ml/trace/pipeline.py +36 -34
  46. da4ml/trace/tracer.py +37 -7
  47. da4ml-0.3.0.post1.dist-info/METADATA +107 -0
  48. da4ml-0.3.0.post1.dist-info/RECORD +64 -0
  49. da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
  50. da4ml-0.2.1.dist-info/METADATA +0 -65
  51. da4ml-0.2.1.dist-info/RECORD +0 -39
  52. /da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
  53. {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/WHEEL +0 -0
  54. {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/licenses/LICENSE +0 -0
  55. {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/top_level.txt +0 -0
da4ml/trace/fixed_variable_array.py CHANGED
@@ -1,15 +1,97 @@
 from inspect import signature
-from typing import Any
+from typing import Any, TypeVar

 import numpy as np
 from numba.typed import List as NumbaList
 from numpy.typing import NDArray

 from ..cmvm import solve
-from .fixed_variable import FixedVariable, HWConfig, QInterval
+from .fixed_variable import FixedVariable, FixedVariableInput, HWConfig, QInterval
+from .ops import einsum, reduce
+
+T = TypeVar('T')
+
+
+def to_raw_arr(obj: T) -> T:
+    if isinstance(obj, tuple):
+        return tuple(to_raw_arr(x) for x in obj)  # type: ignore
+    elif isinstance(obj, list):
+        return [to_raw_arr(x) for x in obj]  # type: ignore
+    elif isinstance(obj, dict):
+        return {k: to_raw_arr(v) for k, v in obj.items()}  # type: ignore
+    if isinstance(obj, FixedVariableArray):
+        return obj._vars  # type: ignore
+    return obj
+
+
+def _max_of(a, b):
+    if isinstance(a, FixedVariable):
+        return a.max_of(b)
+    elif isinstance(b, FixedVariable):
+        return b.max_of(a)
+    else:
+        return max(a, b)
+
+
+def _min_of(a, b):
+    if isinstance(a, FixedVariable):
+        return a.min_of(b)
+    elif isinstance(b, FixedVariable):
+        return b.min_of(a)
+    else:
+        return min(a, b)


 class FixedVariableArray:
+    __array_priority__ = 100
+
+    def __array_function__(self, func, types, args, kwargs):
+        if func is np.matmul:
+            if len(args) == 1 and isinstance(args[0], np.ndarray):
+                return self.__matmul__(args[0])
+            elif len(args) == 2 and isinstance(args[0], np.ndarray) and isinstance(args[1], np.ndarray):
+                return self.__rmatmul__(args[1])
+
+        if func in (np.mean, np.sum, np.amax, np.amin, np.max, np.min):
+            match func:
+                case np.mean:
+                    _x = reduce(lambda x, y: x + y, self, *args[1:], **kwargs)
+                    return _x * (_x.size / self._vars.size)
+                case np.sum:
+                    return reduce(lambda x, y: x + y, self, *args[1:], **kwargs)
+                case np.max | np.amax:
+                    return reduce(_max_of, self, *args[1:], **kwargs)
+                case np.min | np.amin:
+                    return reduce(_min_of, self, *args[1:], **kwargs)
+                case _:
+                    raise NotImplementedError(f'Unsupported function: {func}')
+
+        if func is np.clip:
+            assert len(args) == 3, 'Clip function requires exactly three arguments'
+            x, low, high = args
+            _x, low, high = np.broadcast_arrays(x, low, high)
+            x = FixedVariableArray(_x, self.solver_options)
+            x = np.amax(np.stack((x, low), axis=-1), axis=-1)  # type: ignore
+            return np.amin(np.stack((x, high), axis=-1), axis=-1)
+
+        if func is np.einsum:
+            # assert len(args) == 2
+            sig = signature(np.einsum)
+            bind = sig.bind(*args, **kwargs)
+            eq = args[0]
+            operands = bind.arguments['operands']
+            if isinstance(operands[0], str):
+                operands = operands[1:]
+            assert len(operands) == 2, 'Einsum on FixedVariableArray requires exactly two operands'
+            assert bind.arguments.get('out', None) is None, 'Output argument is not supported'
+            return einsum(eq, *operands)
+
+        args, kwargs = to_raw_arr(args), to_raw_arr(kwargs)
+        return FixedVariableArray(
+            func(*args, **kwargs),
+            self.solver_options,
+        )
+
     def __init__(
         self,
         vars: NDArray,
@@ -121,9 +203,13 @@ class FixedVariableArray:
         return self._vars.shape

     def __add__(self, other):
+        if isinstance(other, FixedVariableArray):
+            return FixedVariableArray(self._vars + other._vars, self.solver_options)
         return FixedVariableArray(self._vars + other, self.solver_options)

     def __sub__(self, other):
+        if isinstance(other, FixedVariableArray):
+            return FixedVariableArray(self._vars - other._vars, self.solver_options)
         return FixedVariableArray(self._vars - other, self.solver_options)

     def __mul__(self, other):
@@ -149,7 +235,7 @@
         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
         ret = []
-        for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()):
+        for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()):  # type: ignore
             ret.append(v.relu(i=i, f=f, round_mode=round_mode))
         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)

@@ -166,7 +252,7 @@
         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
         ret = []
-        for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):
+        for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):  # type: ignore
             ret.append(v.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode))
         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)

@@ -185,3 +271,22 @@
     @property
     def dtype(self):
         return self._vars.dtype
+
+    @property
+    def size(self):
+        return self._vars.size
+
+    @property
+    def kif(self):
+        shape = self._vars.shape
+        kif = np.array([v.kif for v in self._vars.ravel()]).reshape(*shape, 3)
+        return np.moveaxis(kif, -1, 0)
+
+
+class FixedVariableArrayInput(FixedVariableArray):
+    def __init__(self, shape: tuple[int, ...] | int, hwconf: HWConfig, solver_options: dict[str, Any] | None = None, latency=0.0):
+        _vars = np.empty(shape, dtype=object)
+        _vars_f = _vars.ravel()
+        for i in range(_vars.size):
+            _vars_f[i] = FixedVariableInput(latency, hwconf)
+        super().__init__(_vars, solver_options)
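For orientation, a minimal sketch of how the new __array_function__ dispatch and FixedVariableArrayInput above might be exercised. The HWConfig arguments and the explicit module paths are assumptions, not taken from this diff:

    import numpy as np

    from da4ml.trace.fixed_variable import HWConfig
    from da4ml.trace.fixed_variable_array import FixedVariableArrayInput

    # Hypothetical HWConfig values; check da4ml.trace.fixed_variable.HWConfig for the real fields.
    x = FixedVariableArrayInput((4, 8), HWConfig(1, -1, -1))

    s = np.sum(x, axis=-1)   # dispatched to reduce(add, ...) via __array_function__
    m = np.amax(x, axis=-1)  # dispatched to reduce(_max_of, ...)
    print(s.shape, m.shape)  # expected: (4,) (4,)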
da4ml/trace/ops/__init__.py CHANGED
@@ -1,16 +1,22 @@
-from typing import TypeVar
+from typing import TYPE_CHECKING, TypeVar

 import numpy as np
 from numpy.typing import NDArray

-from ..fixed_variable_array import FixedVariable, FixedVariableArray
-from .conv_utils import conv
+from ..fixed_variable_array import FixedVariable
+from .conv_utils import conv, pool
 from .einsum_utils import einsum
+from .reduce_utils import reduce

-T = TypeVar('T', FixedVariableArray, NDArray[np.floating], list[FixedVariable])
+if TYPE_CHECKING:
+    from ..fixed_variable_array import FixedVariableArray
+
+T = TypeVar('T', 'FixedVariableArray', NDArray[np.floating], list[FixedVariable])


 def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN') -> T:
+    from ..fixed_variable_array import FixedVariableArray
+
     if isinstance(x, FixedVariableArray):
         return x.relu(i=i, f=f, round_mode=round_mode)
     elif isinstance(x, list):
@@ -35,12 +41,20 @@ def quantize(
     overflow_mode: str = 'WRAP',
     round_mode: str = 'TRN',
 ) -> T:
-    assert overflow_mode.upper() == 'WRAP', 'Only WRAP overflow mode is supported'
+    from ..fixed_variable_array import FixedVariableArray
+
     if isinstance(x, FixedVariableArray):
         return x.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode)
     else:
+        x = x.copy()
+        if overflow_mode in ('SAT', 'SAT_SM'):
+            step = 2.0**-f
+            _high = 2.0**i
+            high = _high - step
+            low = -_high * k if overflow_mode == 'SAT' else -high * k
+            x = np.clip(x, low, high)  # type: ignore
         if round_mode.upper() == 'RND':
-            x += 2.0 ** (-f - 1)
+            x += 2.0 ** (-f - 1)  # type: ignore
         b = k + i + f
         bias = 2.0 ** (b - 1) * k
         eps = 2.0**-f
@@ -52,4 +66,6 @@ __all__ = [
     'einsum',
     'relu',
     'quantize',
+    'pool',
+    'reduce',
 ]
da4ml/trace/ops/conv_utils.py CHANGED
@@ -1,10 +1,15 @@
+import typing
 from collections.abc import Sequence
+from math import ceil, prod
 from typing import TypeVar

 import numpy as np
 from numpy.typing import NDArray

-from ..fixed_variable_array import FixedVariableArray
+from .reduce_utils import reduce
+
+if typing.TYPE_CHECKING:
+    from ..fixed_variable_array import FixedVariableArray


 def r_im2col(kernel_size: Sequence[int], arr: np.ndarray, buffer: np.ndarray, axis: int):
@@ -33,23 +38,23 @@ def stride_arr(stride: int | tuple[int, ...], arr: np.ndarray):
     ndim = arr.ndim
     if isinstance(stride, int):
         stride = (stride,) * (ndim - 1)
-    assert len(stride) == ndim - 1, f'Invalid stride {stride} for array with {ndim} dimensions'

     _idx = tuple(slice(None, None, st) for st in stride)
-    return arr[*_idx]
+    return arr[_idx]


-T = TypeVar('T', FixedVariableArray, NDArray[np.integer | np.floating])
+TA = TypeVar('TA', 'FixedVariableArray', NDArray[np.integer | np.floating])


-def conv(
-    x: T,
+def _conv(
+    x: TA,
     kernel: NDArray[np.integer | np.floating],
     bias: NDArray[np.integer | np.floating] | None = None,
     strides: int | tuple[int, ...] = 1,
     padding: tuple[tuple[int, int], ...] | str = 'VALID',
-    format: str = 'channels_last',
-):
+) -> TA:
+    from ..fixed_variable_array import FixedVariableArray
+
     if isinstance(x, FixedVariableArray):
         solver_options = x.solver_options
         data = x._vars
@@ -63,10 +68,10 @@
     ch_in, ch_out = kernel.shape[-2:]
     _ch_in = data.shape[-1]
     assert ch_in == _ch_in, f'Invalid input shape {data.shape} for kernel {kernel.shape}'
-    assert kernel.ndim == ndim + 1
-
-    assert format in ('channels_last', 'channels_first'), f'Invalid format {format}'
-
+    if kernel.ndim != ndim + 1:
+        if kernel.ndim == ndim:
+            raise ValueError('Inputs should not contain batch dimension')
+        raise ValueError(f'Invalid kernel shape {kernel.shape} for input with {ndim} dimensions')
     if isinstance(strides, int):
         strides = (strides,) * (ndim - 1)
     assert len(strides) == ndim - 1, f'Invalid stride {strides} for array with {ndim} dimensions'
@@ -89,16 +94,143 @@

     data = np.pad(data, padding + ((0, 0),), mode='constant', constant_values=0.0)
     data = _im2col(kernel.shape, data)
+    data = stride_arr(strides, data)
     if is_symbolic:
         _data = FixedVariableArray(data, solver_options) @ kernel.reshape(-1, ch_out)
         data = _data._vars
     else:
         data = data @ kernel.reshape(-1, ch_out)
-    data = stride_arr(strides, data)
     if bias is not None:
         data = data + bias
+    if isinstance(x, FixedVariableArray):
+        return FixedVariableArray(data, solver_options)
+    return data
+
+
+def conv(
+    x: TA,
+    kernel: NDArray[np.integer | np.floating],
+    bias: NDArray[np.integer | np.floating] | None = None,
+    strides: int | tuple[int, ...] = 1,
+    padding: tuple[tuple[int, int], ...] | str = 'VALID',
+    format: str = 'channels_last',
+    groups: int | None = None,
+) -> TA:
+    from ..fixed_variable_array import FixedVariableArray
+
+    assert format in ('channels_last', 'channels_first'), f'Invalid format {format}'
+    if format == 'channels_first':
+        x = np.moveaxis(x, 0, -1)  # type: ignore
+
+    *_, _ch_in, ch_out = kernel.shape
+    ch_in = x.shape[-1]
+    assert ch_in % _ch_in == 0, f'groups is not integer (total_ch_in={ch_in}, kernel_ch_in={_ch_in})'
+    if groups is None:
+        groups = ch_in // _ch_in
+    else:
+        assert (
+            groups == ch_in // _ch_in
+        ), f'groups {groups} does not match input channels {ch_in} and kernel input channels {_ch_in}'
+    assert ch_out % groups == 0, f'groups is not integer (total_ch_out={ch_out}, groups={groups})'
+    _ch_out = ch_out // groups
+
+    buf: list[TA] = []
+    for gp in range(groups):
+        _kernel = kernel[..., gp * _ch_out : (gp + 1) * _ch_out]
+        _x = x[..., gp * _ch_in : (gp + 1) * _ch_in]
+        _buf = _conv(
+            _x,
+            _kernel,
+            strides=strides,
+            padding=padding,
+        )
+        buf.append(_buf)  # type: ignore
+
+    if isinstance(x, FixedVariableArray):
+        data = np.concatenate([b._vars for b in buf], axis=-1)  # type: ignore
+    else:
+        data = np.concatenate(buf, axis=-1)  # type: ignore
+
+    data = data + bias if bias is not None else data
+
     if format == 'channels_first':
-        data = np.moveaxis(data, -1, 1)
-    if solver_options is not None:
+        return np.moveaxis(data, -1, 0)  # type: ignore
+
+    if isinstance(x, FixedVariableArray):
+        return FixedVariableArray(data, x.solver_options)
+    return data
+
+
+def pool(
+    x: TA,
+    pool_size: Sequence[int],
+    strides: int | Sequence[int] | None = None,
+    padding: tuple[tuple[int, int], ...] | str = 'VALID',
+    pool_type: str = 'avg',
+    format: str = 'channels_last',
+) -> TA:
+    from ..fixed_variable import FixedVariable
+    from ..fixed_variable_array import FixedVariableArray
+
+    if isinstance(x, FixedVariableArray):
+        solver_options = x.solver_options
+        data = x._vars
+    else:
+        solver_options = None
+        data = x
+
+    if format == 'channels_first':
+        data = np.moveaxis(data, 0, -1)
+
+    strides = strides or pool_size
+
+    assert pool_type in ('avg', 'max'), f'Invalid pool type {pool_type}'
+    ndim = data.ndim
+    if isinstance(strides, int):
+        strides = (strides,) * (ndim - 1)
+    assert len(strides) == ndim - 1, f'Invalid stride {strides} for array with {ndim} dimensions'
+
+    if isinstance(padding, str):
+        padding = padding.upper()
+        if padding == 'VALID':
+            padding = ((0, 0),) * (ndim - 1)
+        elif padding == 'SAME':
+            _padding = []
+            for i in range(ndim - 1):
+                n_pad = ceil(data.shape[i] / strides[i]) * strides[i] + (pool_size[i] - strides[i]) - data.shape[i]
+                pad0 = n_pad // 2
+                pad1 = n_pad - pad0
+                _padding.append((pad0, pad1))
+            padding = tuple(_padding)
+        else:
+            raise ValueError(f'Invalid padding {padding}')
+    assert len(padding) == ndim - 1, f'Invalid padding {padding} for array with {ndim} dimensions'
+    assert all(len(p) == 2 for p in padding), f'Invalid padding {padding} for array with {ndim} dimensions'
+
+    data = np.pad(data, padding + ((0, 0),), mode='constant', constant_values=-np.inf)
+    ch_in = data.shape[-1]
+    fake_kernel_shape = tuple(pool_size) + (ch_in, ch_in)
+    data = _im2col(fake_kernel_shape, data)
+    data = data.reshape(*data.shape[:-1], prod(pool_size), ch_in)
+    data = stride_arr(tuple(strides), data)
+    if pool_type == 'avg':
+        div = np.sum(data != -np.inf, axis=-2)
+        data = np.where(data == -np.inf, 0, data)
+        data = reduce(lambda x, y: x + y, data, axis=-2) * (1 / div)
+    else:
+
+        def max_of(a, b):
+            if isinstance(a, FixedVariable):
+                return a.max_of(b)
+            if isinstance(b, FixedVariable):
+                return b.max_of(a)
+            return max(a, b)
+
+        data = reduce(lambda x, y: max_of(x, y), data, axis=-2)
+
+    if format == 'channels_first':
+        data = np.moveaxis(data, -1, 0)
+
+    if isinstance(x, FixedVariableArray):
         return FixedVariableArray(data, solver_options)
     return data
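As a rough usage sketch of the new pool helper on a plain NumPy array (assuming the usual im2col window semantics; the shapes shown are illustrative only):

    import numpy as np

    from da4ml.trace.ops import pool

    x = np.arange(36, dtype=float).reshape(6, 6, 1)  # (H, W, C), channels_last
    y = pool(x, pool_size=(2, 2), strides=2, padding='VALID', pool_type='max')
    print(y.shape)  # 2x2 windows with stride 2 -> expect (3, 3, 1)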
da4ml/trace/ops/einsum_utils.py CHANGED
@@ -1,10 +1,11 @@
 from math import prod
-from typing import TypedDict, overload
+from typing import TYPE_CHECKING, TypedDict, overload

 import numpy as np
 from numpy.typing import NDArray

-from ..fixed_variable_array import FixedVariableArray
+if TYPE_CHECKING:
+    from ..fixed_variable_array import FixedVariableArray


 class EinsumRecipe(TypedDict):
@@ -105,7 +106,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, .
     # Axes expansion in input0 or input1 only
     if '0' in sax_in0:
         if len(sax_in0) - 1 > len(shape0):
-            raise ValueError(f'Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given')
+            raise ValueError(f'Input0 requires at least {len(sax_in0) - 1} dimensions, but only {len(shape0)} given')
         # Replace auto expansion indices with free indices
         n_broadcast = len(shape0) - len(sax_in0) + 1
         in0 = in0.replace('0', free_indices[:n_broadcast])
@@ -118,7 +119,7 @@ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, .

     if '0' in sax_in1:
         if len(sax_in1) - 1 > len(shape1):
-            raise ValueError(f'Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given')
+            raise ValueError(f'Input1 requires at least {len(sax_in1) - 1} dimensions, but only {len(shape1)} given')
         # Replace expansion indices with free indices
         n_broadcast = len(shape1) - len(sax_in1) + 1
         in1 = in1.replace('0', free_indices[:n_broadcast])
@@ -271,11 +272,11 @@ def _einsum(fn: str, input0, input1) -> np.ndarray:


 @overload
-def einsum(fn: str, input0: FixedVariableArray, input1: NDArray[np.integer | np.floating]) -> FixedVariableArray: ...
+def einsum(fn: str, input0: 'FixedVariableArray', input1: NDArray[np.integer | np.floating]) -> 'FixedVariableArray': ...


 @overload
-def einsum(fn: str, input0: NDArray[np.integer | np.floating], input1: FixedVariableArray) -> FixedVariableArray: ...
+def einsum(fn: str, input0: NDArray[np.integer | np.floating], input1: 'FixedVariableArray') -> 'FixedVariableArray': ...


 @overload
@@ -285,6 +286,8 @@ def einsum(


 def einsum(fn: str, input0, input1):
+    from ..fixed_variable_array import FixedVariableArray
+
     fg0 = isinstance(input0, FixedVariableArray)
     fg1 = isinstance(input1, FixedVariableArray)
     if fg0 and fg1:
da4ml/trace/ops/reduce_utils.py ADDED
@@ -0,0 +1,103 @@
+import heapq
+import typing
+from collections.abc import Callable, Sequence
+from math import prod
+from typing import TypeVar
+
+import numpy as np
+from numpy.typing import NDArray
+
+if typing.TYPE_CHECKING:
+    from ..fixed_variable import FixedVariable
+    from ..fixed_variable_array import FixedVariableArray
+
+
+T = typing.TypeVar('T', 'FixedVariable', float, np.floating)
+TA = TypeVar('TA', 'FixedVariableArray', NDArray[np.integer | np.floating])
+
+
+class Packet:
+    def __init__(self, v):
+        self.value = v
+
+    def __gt__(self, other: 'Packet') -> bool:  # type: ignore
+        from ..fixed_variable_array import FixedVariable
+
+        a, b = self.value, other.value
+
+        if isinstance(a, FixedVariable):
+            if isinstance(b, FixedVariable):
+                if b.latency > a.latency:
+                    return False
+                if b.latency < a.latency:
+                    return True
+                if b._factor > 0 and a._factor < 0:
+                    return False
+                if b._factor < 0 and a._factor > 0:
+                    return True
+                return sum(a.kif[:2]) > sum(b.kif[:2])
+            return True
+
+        return False
+
+    def __lt__(self, other: 'Packet') -> bool:  # type: ignore
+        return not self.__gt__(other)
+
+
+def _reduce(operator: Callable[[T, T], T], arr: Sequence[T]) -> T:
+    from ..fixed_variable_array import FixedVariable
+
+    if isinstance(arr, np.ndarray):
+        arr = list(arr.ravel())
+    assert len(arr) > 0, 'Array must not be empty'
+    if len(arr) == 1:
+        return arr[0]
+    dtype = arr[0].__class__
+    if not issubclass(dtype, FixedVariable):
+        r = operator(arr[0], arr[1])
+        for i in range(2, len(arr)):
+            r = operator(r, arr[i])
+        return r
+
+    heap = [Packet(v) for v in arr]  # type: ignore
+    heapq.heapify(heap)
+    while len(heap) > 1:
+        v1 = heapq.heappop(heap).value
+        v2 = heapq.heappop(heap).value
+        v = operator(v1, v2)
+        heapq.heappush(heap, Packet(v))  # type: ignore
+    return heap[0].value
+
+
+def reduce(operator: Callable[[T, T], T], x: TA, axis: int | Sequence[int] | None = None, keepdims: bool = False) -> TA:
+    """
+    Reduce the array by summing over the specified axis.
+    """
+    from ..fixed_variable_array import FixedVariableArray
+
+    if isinstance(x, FixedVariableArray):
+        solver_config = x.solver_options
+        arr = x._vars
+    else:
+        solver_config = None
+        arr = x
+    all_axis = tuple(range(arr.ndim))
+    axis = axis if axis is not None else all_axis
+    axis = (axis,) if isinstance(axis, int) else tuple(axis)
+    axis = tuple(a if a >= 0 else a + arr.ndim for a in axis)
+
+    xpose_axis = sorted(all_axis, key=lambda a: (a in axis) * 1000 + a)
+    if keepdims:
+        target_shape = tuple(d if ax not in axis else 1 for ax, d in enumerate(arr.shape))
+    else:
+        target_shape = tuple(d for ax, d in enumerate(arr.shape) if ax not in axis)
+
+    dim_contract = prod(arr.shape[a] for a in axis)
+    arr = np.transpose(arr, xpose_axis)  # type: ignore
+    _arr = arr.reshape(-1, dim_contract)
+    _arr = np.array([_reduce(operator, _arr[i]) for i in range(_arr.shape[0])])
+    r = _arr.reshape(target_shape)  # type: ignore
+
+    if isinstance(x, FixedVariableArray):
+        return FixedVariableArray(r, solver_config)
+    return r
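On plain NumPy arrays the new reduce helper behaves like an ordinary fold over the selected axes (for FixedVariable entries it instead builds a latency-aware balanced reduction via the heap shown above); a minimal sketch:

    import numpy as np

    from da4ml.trace.ops import reduce

    x = np.arange(24, dtype=float).reshape(2, 3, 4)
    s = reduce(lambda a, b: a + b, x, axis=(1, 2), keepdims=True)
    print(s.shape)  # (2, 1, 1), same result as x.sum(axis=(1, 2), keepdims=True)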
da4ml/trace/pipeline.py CHANGED
@@ -31,6 +31,35 @@ def retime_pipeline(csol: CascadedSolution, verbose=True):
     return best


+def _get_new_idx(
+    idx: int,
+    locator: list[dict[int, int]],
+    opd: dict[int, list[Op]],
+    out_idxd: dict[int, list[int]],
+    ops: list[Op],
+    stage: int,
+    latency_cutoff: int,
+):
+    if idx < 0:
+        return idx
+    p0_stages = locator[idx].keys()
+    if stage not in p0_stages:
+        # Need to copy parent to later states
+        p0_stage = max(p0_stages)
+        p0_idx = locator[idx][p0_stage]
+        for j in range(p0_stage, stage):
+            op0 = ops[idx]
+            latency = float(latency_cutoff * (j + 1))
+            out_idxd.setdefault(j, []).append(locator[idx][j])
+            _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op0.qint, latency, 0.0)
+            opd.setdefault(j + 1, []).append(_copy_op)
+            p0_idx = len(opd[j + 1]) - 1
+            locator[idx][j + 1] = p0_idx
+    else:
+        p0_idx = locator[idx][stage]
+    return p0_idx
+
+
 def to_pipeline(sol: Solution, latency_cutoff: int, retiming=True, verbose=True) -> CascadedSolution:
     """Split the record into multiple stages based on the latency of the operations.
     Only useful for HDL generation.
@@ -80,46 +109,19 @@ def to_pipeline(sol: Solution, latency_cutoff: int, retiming=True, verbose=True)
             opd.setdefault(stage, []).append(op)
             locator.append({stage: len(opd[stage]) - 1})
             continue
-        p0_stages = locator[op.id0].keys()
-        if stage not in p0_stages:
-            # Need to copy parent to later states
-            p0_stage = max(p0_stages)
-            p0_idx = locator[op.id0][p0_stage]
-            for j in range(p0_stage, stage):
-                op0 = ops[op.id0]
-                latency = float(latency_cutoff * (j + 1))
-                out_idxd.setdefault(j, []).append(locator[op.id0][j])
-                _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op0.qint, latency, 0.0)
-                opd.setdefault(j + 1, []).append(_copy_op)
-                p0_idx = len(opd[j + 1]) - 1
-                locator[op.id0][j + 1] = p0_idx
-        else:
-            p0_idx = locator[op.id0][stage]
-
-        if op.opcode in (0, 1):
-            p1_stages = locator[op.id1].keys()
-            if stage not in p1_stages:
-                # Need to copy parent to later states
-                p1_stage = max(p1_stages)
-                p1_idx = locator[op.id1][p1_stage]
-                for j in range(p1_stage, stage):
-                    op1 = ops[op.id1]
-                    latency = float(latency_cutoff * (j + 1))
-                    out_idxd.setdefault(j, []).append(locator[op.id1][j])
-                    _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op1.qint, latency, 0.0)
-                    opd.setdefault(j + 1, []).append(_copy_op)
-                    p1_idx = len(opd[j + 1]) - 1
-                    locator[op.id1][j + 1] = p1_idx
-            else:
-                p1_idx = locator[op.id1][stage]
+
+        p0_idx = _get_new_idx(op.id0, locator, opd, out_idxd, ops, stage, latency_cutoff)
+        p1_idx = _get_new_idx(op.id1, locator, opd, out_idxd, ops, stage, latency_cutoff)
+        if op.opcode in (6, -6):
+            data = _get_new_idx(op.data, locator, opd, out_idxd, ops, stage, latency_cutoff)
         else:
-            p1_idx = op.id1
+            data = op.data

         if p1_idx == -1001:
             # Output to external buffer
             out_idxd.setdefault(stage, []).append(p0_idx)
         else:
-            _Op = Op(p0_idx, p1_idx, op.opcode, op.data, op.qint, op.latency, op.cost)
+            _Op = Op(p0_idx, p1_idx, op.opcode, data, op.qint, op.latency, op.cost)
             opd.setdefault(stage, []).append(_Op)
             locator.append({stage: len(opd[stage]) - 1})
     sols = []