da4ml 0.4.1__py3-none-any.whl → 0.5.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml might be problematic. Click here for more details.
- da4ml/__init__.py +2 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +2 -2
- da4ml/cmvm/api.py +15 -4
- da4ml/cmvm/core/__init__.py +2 -2
- da4ml/cmvm/types.py +32 -18
- da4ml/cmvm/util/bit_decompose.py +2 -2
- da4ml/codegen/hls/hls_codegen.py +10 -5
- da4ml/codegen/hls/hls_model.py +7 -4
- da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/{build_prj.tcl → build_vivado_prj.tcl} +39 -18
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +11 -13
- da4ml/codegen/rtl/rtl_model.py +105 -54
- da4ml/codegen/rtl/verilog/__init__.py +2 -1
- da4ml/codegen/rtl/verilog/comb.py +47 -7
- da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
- da4ml/codegen/rtl/verilog/pipeline.py +12 -12
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/vhdl/comb.py +27 -21
- da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
- da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/converter/__init__.py +57 -1
- da4ml/converter/hgq2/parser.py +4 -25
- da4ml/converter/hgq2/replica.py +208 -22
- da4ml/trace/fixed_variable.py +239 -29
- da4ml/trace/fixed_variable_array.py +276 -48
- da4ml/trace/ops/__init__.py +31 -15
- da4ml/trace/ops/reduce_utils.py +3 -3
- da4ml/trace/pipeline.py +40 -18
- da4ml/trace/tracer.py +33 -8
- da4ml/typing/__init__.py +3 -0
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/METADATA +2 -1
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/RECORD +39 -35
- da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/WHEEL +0 -0
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,15 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from decimal import Decimal
|
|
1
3
|
from inspect import signature
|
|
2
|
-
from typing import
|
|
4
|
+
from typing import TypeVar
|
|
3
5
|
|
|
4
6
|
import numpy as np
|
|
5
7
|
from numba.typed import List as NumbaList
|
|
6
8
|
from numpy.typing import NDArray
|
|
7
9
|
|
|
8
|
-
from ..cmvm import solve
|
|
9
|
-
from .fixed_variable import FixedVariable, FixedVariableInput, HWConfig, QInterval
|
|
10
|
-
from .ops import einsum, reduce
|
|
10
|
+
from ..cmvm.api import solve, solver_options_t
|
|
11
|
+
from .fixed_variable import FixedVariable, FixedVariableInput, HWConfig, LookupTable, QInterval
|
|
12
|
+
from .ops import _quantize, einsum, reduce
|
|
11
13
|
|
|
12
14
|
T = TypeVar('T')
|
|
13
15
|
|
|
@@ -42,7 +44,79 @@ def _min_of(a, b):
|
|
|
42
44
|
return min(a, b)
|
|
43
45
|
|
|
44
46
|
|
|
47
|
+
def mmm(mat0: np.ndarray, mat1: np.ndarray):
|
|
48
|
+
shape = mat0.shape[:-1] + mat1.shape[1:]
|
|
49
|
+
mat0, mat1 = mat0.reshape((-1, mat0.shape[-1])), mat1.reshape((mat1.shape[0], -1))
|
|
50
|
+
_shape = (mat0.shape[0], mat1.shape[1])
|
|
51
|
+
_vars = np.empty(_shape, dtype=object)
|
|
52
|
+
for i in range(mat0.shape[0]):
|
|
53
|
+
for j in range(mat1.shape[1]):
|
|
54
|
+
vec0 = mat0[i]
|
|
55
|
+
vec1 = mat1[:, j]
|
|
56
|
+
_vars[i, j] = reduce(lambda x, y: x + y, vec0 * vec1)
|
|
57
|
+
return _vars.reshape(shape)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def cmvm(cm: np.ndarray, v: 'FixedVariableArray', solver_options: solver_options_t) -> np.ndarray:
|
|
61
|
+
mask = offload_mask(cm, v)
|
|
62
|
+
if np.any(mask):
|
|
63
|
+
offload_cm = cm * mask.astype(cm.dtype)
|
|
64
|
+
cm = cm * (~mask).astype(cm.dtype)
|
|
65
|
+
else:
|
|
66
|
+
offload_cm = None
|
|
67
|
+
_qintervals = [QInterval(float(_v.low), float(_v.high), float(_v.step)) for _v in v._vars]
|
|
68
|
+
_latencies = [float(_v.latency) for _v in v._vars]
|
|
69
|
+
qintervals = NumbaList(_qintervals) # type: ignore
|
|
70
|
+
latencies = NumbaList(_latencies) # type: ignore
|
|
71
|
+
hwconf = v._vars.ravel()[0].hwconf
|
|
72
|
+
solver_options.setdefault('adder_size', hwconf.adder_size)
|
|
73
|
+
solver_options.setdefault('carry_size', hwconf.carry_size)
|
|
74
|
+
_mat = np.ascontiguousarray(cm.astype(np.float32))
|
|
75
|
+
sol = solve(_mat, qintervals=qintervals, latencies=latencies, **solver_options)
|
|
76
|
+
_r: np.ndarray = sol(v._vars)
|
|
77
|
+
if offload_cm is not None:
|
|
78
|
+
_r = _r + mmm(v._vars, offload_cm)
|
|
79
|
+
return _r
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def offload_mask(cm: NDArray, v: 'FixedVariableArray') -> NDArray[np.bool_]:
|
|
83
|
+
assert v.ndim == 1
|
|
84
|
+
assert cm.ndim == 2
|
|
85
|
+
assert cm.shape[0] == v.shape[0]
|
|
86
|
+
bits = np.sum(v.kif, axis=0)[:, None]
|
|
87
|
+
return (bits == 0) & (cm != 0)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
_unary_functions = (
|
|
91
|
+
np.sin,
|
|
92
|
+
np.cos,
|
|
93
|
+
np.tan,
|
|
94
|
+
np.exp,
|
|
95
|
+
np.log,
|
|
96
|
+
np.invert,
|
|
97
|
+
np.sqrt,
|
|
98
|
+
np.tanh,
|
|
99
|
+
np.sinh,
|
|
100
|
+
np.cosh,
|
|
101
|
+
np.arccos,
|
|
102
|
+
np.arcsin,
|
|
103
|
+
np.arctan,
|
|
104
|
+
np.arcsinh,
|
|
105
|
+
np.arccosh,
|
|
106
|
+
np.arctanh,
|
|
107
|
+
np.exp2,
|
|
108
|
+
np.expm1,
|
|
109
|
+
np.log2,
|
|
110
|
+
np.log10,
|
|
111
|
+
np.log1p,
|
|
112
|
+
np.cbrt,
|
|
113
|
+
np.reciprocal,
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
|
|
45
117
|
class FixedVariableArray:
|
|
118
|
+
"""Symbolic array of FixedVariable for tracing operations. Supports numpy ufuncs and array functions."""
|
|
119
|
+
|
|
46
120
|
__array_priority__ = 100
|
|
47
121
|
|
|
48
122
|
def __array_function__(self, func, types, args, kwargs):
|
|
@@ -52,17 +126,19 @@ class FixedVariableArray:
|
|
|
52
126
|
elif len(args) == 2 and isinstance(args[0], np.ndarray) and isinstance(args[1], np.ndarray):
|
|
53
127
|
return self.__rmatmul__(args[1])
|
|
54
128
|
|
|
55
|
-
if func in (np.mean, np.sum, np.amax, np.amin, np.max, np.min):
|
|
129
|
+
if func in (np.mean, np.sum, np.amax, np.amin, np.prod, np.max, np.min):
|
|
56
130
|
match func:
|
|
57
131
|
case np.mean:
|
|
58
|
-
_x = reduce(lambda x, y: x + y,
|
|
132
|
+
_x = reduce(lambda x, y: x + y, *args, **kwargs)
|
|
59
133
|
return _x * (_x.size / self._vars.size)
|
|
60
134
|
case np.sum:
|
|
61
|
-
return reduce(lambda x, y: x + y,
|
|
135
|
+
return reduce(lambda x, y: x + y, *args, **kwargs)
|
|
62
136
|
case np.max | np.amax:
|
|
63
|
-
return reduce(_max_of,
|
|
137
|
+
return reduce(_max_of, *args, **kwargs)
|
|
64
138
|
case np.min | np.amin:
|
|
65
|
-
return reduce(_min_of,
|
|
139
|
+
return reduce(_min_of, *args, **kwargs)
|
|
140
|
+
case np.prod:
|
|
141
|
+
return reduce(lambda x, y: x * y, *args, **kwargs)
|
|
66
142
|
case _:
|
|
67
143
|
raise NotImplementedError(f'Unsupported function: {func}')
|
|
68
144
|
|
|
@@ -86,7 +162,7 @@ class FixedVariableArray:
|
|
|
86
162
|
assert bind.arguments.get('out', None) is None, 'Output argument is not supported'
|
|
87
163
|
return einsum(eq, *operands)
|
|
88
164
|
|
|
89
|
-
if func
|
|
165
|
+
if func is np.dot:
|
|
90
166
|
assert len(args) in (2, 3), 'Dot function requires exactly two or three arguments'
|
|
91
167
|
|
|
92
168
|
assert len(args) == 2
|
|
@@ -107,19 +183,85 @@ class FixedVariableArray:
|
|
|
107
183
|
self.solver_options,
|
|
108
184
|
)
|
|
109
185
|
|
|
186
|
+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
|
|
187
|
+
assert method == '__call__', f'Only __call__ method is supported for ufuncs, got {method}'
|
|
188
|
+
|
|
189
|
+
match ufunc:
|
|
190
|
+
case np.add | np.subtract | np.multiply | np.true_divide | np.negative:
|
|
191
|
+
inputs = [to_raw_arr(x) for x in inputs]
|
|
192
|
+
return FixedVariableArray(ufunc(*inputs, **kwargs), self.solver_options)
|
|
193
|
+
case np.negative:
|
|
194
|
+
assert len(inputs) == 1
|
|
195
|
+
return FixedVariableArray(ufunc(to_raw_arr(inputs[0]), **kwargs), self.solver_options)
|
|
196
|
+
case np.maximum | np.minimum:
|
|
197
|
+
op = _max_of if ufunc is np.maximum else _min_of
|
|
198
|
+
a, b = np.broadcast_arrays(inputs[0], inputs[1])
|
|
199
|
+
shape = a.shape
|
|
200
|
+
a, b = a.ravel(), b.ravel()
|
|
201
|
+
r = np.empty(a.size, dtype=object)
|
|
202
|
+
for i in range(a.size):
|
|
203
|
+
r[i] = op(a[i], b[i])
|
|
204
|
+
return FixedVariableArray(r.reshape(shape), self.solver_options)
|
|
205
|
+
case np.matmul:
|
|
206
|
+
assert len(inputs) == 2
|
|
207
|
+
assert isinstance(inputs[0], FixedVariableArray) or isinstance(inputs[1], FixedVariableArray)
|
|
208
|
+
if isinstance(inputs[0], FixedVariableArray):
|
|
209
|
+
return inputs[0].matmul(inputs[1])
|
|
210
|
+
else:
|
|
211
|
+
return inputs[1].rmatmul(inputs[0])
|
|
212
|
+
case np.power:
|
|
213
|
+
assert len(inputs) == 2
|
|
214
|
+
base, exp = inputs
|
|
215
|
+
return base**exp
|
|
216
|
+
|
|
217
|
+
case np.abs | np.absolute:
|
|
218
|
+
assert len(inputs) == 1
|
|
219
|
+
assert inputs[0] is self
|
|
220
|
+
mask: np.ndarray = (self.kif[0] == 0).ravel()
|
|
221
|
+
arr = self._vars.ravel()
|
|
222
|
+
|
|
223
|
+
r = np.empty(arr.size, dtype=object)
|
|
224
|
+
for i in range(arr.size):
|
|
225
|
+
if mask[i]:
|
|
226
|
+
r[i] = arr[i]
|
|
227
|
+
continue
|
|
228
|
+
v = arr[i]
|
|
229
|
+
v = v.msb_mux(-v, v)
|
|
230
|
+
v.low = Decimal(0)
|
|
231
|
+
r[i] = v
|
|
232
|
+
return FixedVariableArray(r.reshape(self.shape), self.solver_options)
|
|
233
|
+
|
|
234
|
+
case np.square:
|
|
235
|
+
assert len(inputs) == 1
|
|
236
|
+
assert inputs[0] is self
|
|
237
|
+
return self**2
|
|
238
|
+
|
|
239
|
+
if ufunc in _unary_functions:
|
|
240
|
+
assert len(inputs) == 1
|
|
241
|
+
assert inputs[0] is self
|
|
242
|
+
return self.apply(ufunc)
|
|
243
|
+
|
|
244
|
+
raise NotImplementedError(f'Unsupported ufunc: {ufunc}')
|
|
245
|
+
|
|
110
246
|
def __init__(
|
|
111
247
|
self,
|
|
112
248
|
vars: NDArray,
|
|
113
|
-
solver_options:
|
|
249
|
+
solver_options: solver_options_t | None = None,
|
|
114
250
|
):
|
|
115
|
-
|
|
251
|
+
_vars = np.array(vars)
|
|
252
|
+
_vars_f = _vars.ravel()
|
|
253
|
+
hwconf = next(iter(v for v in _vars_f if isinstance(v, FixedVariable))).hwconf
|
|
254
|
+
for i, v in enumerate(_vars_f):
|
|
255
|
+
if not isinstance(v, FixedVariable):
|
|
256
|
+
_vars_f[i] = FixedVariable(float(v), float(v), 1.0, hwconf=hwconf)
|
|
257
|
+
self._vars = _vars
|
|
116
258
|
_solver_options = signature(solve).parameters
|
|
117
259
|
_solver_options = {k: v.default for k, v in _solver_options.items() if v.default is not v.empty}
|
|
118
260
|
if solver_options is not None:
|
|
119
261
|
_solver_options.update(solver_options)
|
|
120
262
|
_solver_options.pop('qintervals', None)
|
|
121
263
|
_solver_options.pop('latencies', None)
|
|
122
|
-
self.solver_options = _solver_options
|
|
264
|
+
self.solver_options: solver_options_t = _solver_options # type: ignore
|
|
123
265
|
|
|
124
266
|
@classmethod
|
|
125
267
|
def from_lhs(
|
|
@@ -129,7 +271,7 @@ class FixedVariableArray:
|
|
|
129
271
|
step: NDArray[np.floating],
|
|
130
272
|
hwconf: HWConfig,
|
|
131
273
|
latency: np.ndarray | float = 0.0,
|
|
132
|
-
solver_options:
|
|
274
|
+
solver_options: solver_options_t | None = None,
|
|
133
275
|
):
|
|
134
276
|
shape = low.shape
|
|
135
277
|
assert shape == high.shape == step.shape
|
|
@@ -162,7 +304,7 @@ class FixedVariableArray:
|
|
|
162
304
|
f: NDArray[np.integer],
|
|
163
305
|
hwconf: HWConfig,
|
|
164
306
|
latency: NDArray[np.floating] | float = 0.0,
|
|
165
|
-
solver_options:
|
|
307
|
+
solver_options: solver_options_t | None = None,
|
|
166
308
|
):
|
|
167
309
|
mask = k + i + f <= 0
|
|
168
310
|
k = np.where(mask, 0, k)
|
|
@@ -173,47 +315,34 @@ class FixedVariableArray:
|
|
|
173
315
|
high, low = _high - step, -_high * k
|
|
174
316
|
return cls.from_lhs(low, high, step, hwconf, latency, solver_options)
|
|
175
317
|
|
|
176
|
-
def
|
|
318
|
+
def matmul(self, other):
|
|
177
319
|
if isinstance(other, FixedVariableArray):
|
|
178
320
|
other = other._vars
|
|
179
321
|
if not isinstance(other, np.ndarray):
|
|
180
322
|
other = np.array(other)
|
|
181
323
|
if any(isinstance(x, FixedVariable) for x in other.ravel()):
|
|
182
324
|
mat0, mat1 = self._vars, other
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
for i in range(mat0.shape[0]):
|
|
188
|
-
for j in range(mat1.shape[1]):
|
|
189
|
-
vec0 = mat0[i]
|
|
190
|
-
vec1 = mat1[:, j]
|
|
191
|
-
_vars[i, j] = reduce(lambda x, y: x + y, vec0 * vec1)
|
|
192
|
-
return FixedVariableArray(_vars.reshape(shape), self.solver_options)
|
|
193
|
-
|
|
194
|
-
kwargs = (self.solver_options or {}).copy()
|
|
325
|
+
_vars = mmm(mat0, mat1)
|
|
326
|
+
return FixedVariableArray(_vars, self.solver_options)
|
|
327
|
+
|
|
328
|
+
solver_options = (self.solver_options or {}).copy()
|
|
195
329
|
shape0, shape1 = self.shape, other.shape
|
|
196
330
|
assert shape0[-1] == shape1[0], f'Matrix shapes do not match: {shape0} @ {shape1}'
|
|
197
|
-
|
|
331
|
+
contract_len = shape1[0]
|
|
198
332
|
out_shape = shape0[:-1] + shape1[1:]
|
|
199
|
-
mat0, mat1 = self.reshape((-1,
|
|
333
|
+
mat0, mat1 = self.reshape((-1, contract_len)), other.reshape((contract_len, -1))
|
|
200
334
|
r = []
|
|
201
335
|
for i in range(mat0.shape[0]):
|
|
202
336
|
vec = mat0[i]
|
|
203
|
-
|
|
204
|
-
_latencies = [float(v.latency) for v in vec._vars]
|
|
205
|
-
qintervals = NumbaList(_qintervals) # type: ignore
|
|
206
|
-
latencies = NumbaList(_latencies) # type: ignore
|
|
207
|
-
hwconf = self._vars.ravel()[0].hwconf
|
|
208
|
-
kwargs.update(adder_size=hwconf.adder_size, carry_size=hwconf.carry_size)
|
|
209
|
-
_mat = np.ascontiguousarray(mat1.astype(np.float32))
|
|
210
|
-
sol = solve(_mat, qintervals=qintervals, latencies=latencies, **kwargs)
|
|
211
|
-
_r = sol(vec._vars)
|
|
337
|
+
_r = cmvm(mat1, vec, solver_options)
|
|
212
338
|
r.append(_r)
|
|
213
339
|
r = np.array(r).reshape(out_shape)
|
|
214
340
|
return FixedVariableArray(r, self.solver_options)
|
|
215
341
|
|
|
216
|
-
def
|
|
342
|
+
def __matmul__(self, other):
|
|
343
|
+
return self.matmul(other)
|
|
344
|
+
|
|
345
|
+
def rmatmul(self, other):
|
|
217
346
|
mat1 = np.moveaxis(other, -1, 0)
|
|
218
347
|
mat0 = np.moveaxis(self, 0, -1) # type: ignore
|
|
219
348
|
ndim0, ndim1 = mat0.ndim, mat1.ndim
|
|
@@ -223,6 +352,9 @@ class FixedVariableArray:
|
|
|
223
352
|
axes = _axes[ndim0 - 1 :] + _axes[: ndim0 - 1]
|
|
224
353
|
return r.transpose(axes)
|
|
225
354
|
|
|
355
|
+
def __rmatmul__(self, other):
|
|
356
|
+
return self.rmatmul(other)
|
|
357
|
+
|
|
226
358
|
def __getitem__(self, item):
|
|
227
359
|
vars = self._vars[item]
|
|
228
360
|
if isinstance(vars, np.ndarray):
|
|
@@ -269,10 +401,17 @@ class FixedVariableArray:
|
|
|
269
401
|
|
|
270
402
|
def __pow__(self, power: int | float):
|
|
271
403
|
_power = int(power)
|
|
272
|
-
|
|
273
|
-
|
|
404
|
+
if _power == power and _power >= 0:
|
|
405
|
+
return FixedVariableArray(self._vars**_power, self.solver_options)
|
|
406
|
+
else:
|
|
407
|
+
return self.apply(lambda x: x**power)
|
|
274
408
|
|
|
275
|
-
def relu(
|
|
409
|
+
def relu(
|
|
410
|
+
self,
|
|
411
|
+
i: NDArray[np.integer] | None = None,
|
|
412
|
+
f: NDArray[np.integer] | None = None,
|
|
413
|
+
round_mode: str = 'TRN',
|
|
414
|
+
):
|
|
276
415
|
shape = self._vars.shape
|
|
277
416
|
i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
|
|
278
417
|
f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
|
|
@@ -290,9 +429,11 @@ class FixedVariableArray:
|
|
|
290
429
|
round_mode: str = 'TRN',
|
|
291
430
|
):
|
|
292
431
|
shape = self._vars.shape
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
432
|
+
if any(x is None for x in (k, i, f)):
|
|
433
|
+
kif = self.kif
|
|
434
|
+
k = np.broadcast_to(k, shape) if k is not None else kif[0] # type: ignore
|
|
435
|
+
i = np.broadcast_to(i, shape) if i is not None else kif[1] # type: ignore
|
|
436
|
+
f = np.broadcast_to(f, shape) if f is not None else kif[2] # type: ignore
|
|
296
437
|
ret = []
|
|
297
438
|
for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()): # type: ignore
|
|
298
439
|
ret.append(v.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode))
|
|
@@ -324,17 +465,28 @@ class FixedVariableArray:
|
|
|
324
465
|
|
|
325
466
|
@property
|
|
326
467
|
def kif(self):
|
|
468
|
+
"""[k, i, f] array"""
|
|
327
469
|
shape = self._vars.shape
|
|
328
470
|
kif = np.array([v.kif for v in self._vars.ravel()]).reshape(*shape, 3)
|
|
329
471
|
return np.moveaxis(kif, -1, 0)
|
|
330
472
|
|
|
473
|
+
def apply(self, fn: Callable[[NDArray], NDArray]) -> 'RetardedFixedVariableArray':
|
|
474
|
+
"""Apply a unary operator to all elements, returning a RetardedFixedVariableArray."""
|
|
475
|
+
return RetardedFixedVariableArray(
|
|
476
|
+
self._vars,
|
|
477
|
+
self.solver_options,
|
|
478
|
+
operator=fn,
|
|
479
|
+
)
|
|
480
|
+
|
|
331
481
|
|
|
332
482
|
class FixedVariableArrayInput(FixedVariableArray):
|
|
483
|
+
"""Similar to FixedVariableArray, but initializes all elements as FixedVariableInput - the precisions are unspecified when initialized, and the highest precision requested (i.e., quantized to) will be recorded for generation of the logic."""
|
|
484
|
+
|
|
333
485
|
def __init__(
|
|
334
486
|
self,
|
|
335
487
|
shape: tuple[int, ...] | int,
|
|
336
|
-
hwconf: HWConfig = HWConfig(1, -1, -1),
|
|
337
|
-
solver_options:
|
|
488
|
+
hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
|
|
489
|
+
solver_options: solver_options_t | None = None,
|
|
338
490
|
latency=0.0,
|
|
339
491
|
):
|
|
340
492
|
_vars = np.empty(shape, dtype=object)
|
|
@@ -342,3 +494,79 @@ class FixedVariableArrayInput(FixedVariableArray):
|
|
|
342
494
|
for i in range(_vars.size):
|
|
343
495
|
_vars_f[i] = FixedVariableInput(latency, hwconf)
|
|
344
496
|
super().__init__(_vars, solver_options)
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def make_table(fn: Callable[[NDArray], NDArray], qint: QInterval) -> LookupTable:
|
|
500
|
+
low, high, step = qint
|
|
501
|
+
n = round((high - low) / step) + 1
|
|
502
|
+
return LookupTable(fn(np.linspace(low, high, n)))
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
class RetardedFixedVariableArray(FixedVariableArray):
|
|
506
|
+
"""Ephemeral FixedVariableArray generated from operations of unspecified output precision.
|
|
507
|
+
This object translates to normal FixedVariableArray upon quantization.
|
|
508
|
+
Does not inherit the maximum precision like FixedVariableArrayInput.
|
|
509
|
+
|
|
510
|
+
This object can be used in two ways:
|
|
511
|
+
1. Quantization with specified precision, which converts to FixedVariableArray.
|
|
512
|
+
2. Apply an further unary operation, which returns another RetardedFixedVariableArray. (e.g., composite functions)
|
|
513
|
+
"""
|
|
514
|
+
|
|
515
|
+
def __init__(self, vars: NDArray, solver_options: solver_options_t | None, operator: Callable[[NDArray], NDArray]):
|
|
516
|
+
self._operator = operator
|
|
517
|
+
super().__init__(vars, solver_options)
|
|
518
|
+
|
|
519
|
+
def __array_function__(self, ufunc, method, *inputs, **kwargs):
|
|
520
|
+
raise RuntimeError('RetardedFixedVariableArray only supports quantization or further unary operations.')
|
|
521
|
+
|
|
522
|
+
def apply(self, fn: Callable[[NDArray], NDArray]) -> 'RetardedFixedVariableArray':
|
|
523
|
+
return RetardedFixedVariableArray(
|
|
524
|
+
self._vars,
|
|
525
|
+
self.solver_options,
|
|
526
|
+
operator=lambda x: fn(self._operator(x)),
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
def quantize(
|
|
530
|
+
self,
|
|
531
|
+
k: NDArray[np.integer] | np.integer | int | None = None,
|
|
532
|
+
i: NDArray[np.integer] | np.integer | int | None = None,
|
|
533
|
+
f: NDArray[np.integer] | np.integer | int | None = None,
|
|
534
|
+
overflow_mode: str = 'WRAP',
|
|
535
|
+
round_mode: str = 'TRN',
|
|
536
|
+
):
|
|
537
|
+
if any(x is None for x in (k, i, f)):
|
|
538
|
+
assert all(x is not None for x in (k, i, f)), 'Either all or none of k, i, f must be specified'
|
|
539
|
+
_k = _i = _f = [None] * self.size
|
|
540
|
+
else:
|
|
541
|
+
_k = np.broadcast_to(k, self.shape).ravel() # type: ignore
|
|
542
|
+
_i = np.broadcast_to(i, self.shape).ravel() # type: ignore
|
|
543
|
+
_f = np.broadcast_to(f, self.shape).ravel() # type: ignore
|
|
544
|
+
|
|
545
|
+
op = lambda x: _quantize(self._operator(x), k, i, f, overflow_mode, round_mode) # type: ignore
|
|
546
|
+
|
|
547
|
+
local_tables: dict[tuple[QInterval, tuple[int, int, int] | None], LookupTable] = {}
|
|
548
|
+
variables = []
|
|
549
|
+
for v, _kk, _ii, _ff in zip(self._vars.ravel(), _k, _i, _f):
|
|
550
|
+
if (_kk is None) or (_ii is None) or (_ff is None):
|
|
551
|
+
op = self._operator
|
|
552
|
+
_key = v.qint
|
|
553
|
+
else:
|
|
554
|
+
op = lambda x: _quantize(self._operator(x), _kk, _ii, _ff, overflow_mode, round_mode) # type: ignore
|
|
555
|
+
_key = (v.qint, (int(_kk), int(_ii), int(_ff)))
|
|
556
|
+
|
|
557
|
+
if _key in local_tables:
|
|
558
|
+
table = local_tables[_key]
|
|
559
|
+
else:
|
|
560
|
+
table = make_table(op, v.qint)
|
|
561
|
+
local_tables[_key] = table
|
|
562
|
+
variables.append(v.lookup(table))
|
|
563
|
+
|
|
564
|
+
variables = np.array(variables).reshape(self._vars.shape)
|
|
565
|
+
return FixedVariableArray(variables, self.solver_options)
|
|
566
|
+
|
|
567
|
+
def __repr__(self):
|
|
568
|
+
return 'Retarded' + super().__repr__()
|
|
569
|
+
|
|
570
|
+
@property
|
|
571
|
+
def kif(self):
|
|
572
|
+
raise RuntimeError('RetardedFixedVariableArray does not have defined kif until quantized.')
|
da4ml/trace/ops/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, TypeVar
|
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
from numpy.typing import NDArray
|
|
5
|
+
from quantizers.fixed_point.fixed_point_ops_np import get_fixed_quantizer_np
|
|
5
6
|
|
|
6
7
|
from ..fixed_variable_array import FixedVariable
|
|
7
8
|
from .conv_utils import conv, pool
|
|
@@ -22,9 +23,11 @@ def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | No
|
|
|
22
23
|
elif isinstance(x, list):
|
|
23
24
|
return [xx.relu(i=ii, f=ff, round_mode=round_mode) for xx, ii, ff in zip(x, i, f)] # type: ignore
|
|
24
25
|
else:
|
|
26
|
+
round_mode = round_mode.upper()
|
|
27
|
+
assert round_mode in ('TRN', 'RND')
|
|
25
28
|
x = np.maximum(x, 0)
|
|
26
29
|
if f is not None:
|
|
27
|
-
if round_mode
|
|
30
|
+
if round_mode == 'RND':
|
|
28
31
|
x += 2.0 ** (-f - 1)
|
|
29
32
|
sf = 2.0**f
|
|
30
33
|
x = np.floor(x * sf) / sf
|
|
@@ -33,6 +36,18 @@ def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | No
|
|
|
33
36
|
return x
|
|
34
37
|
|
|
35
38
|
|
|
39
|
+
def _quantize(
|
|
40
|
+
x: NDArray[np.floating],
|
|
41
|
+
k: NDArray[np.integer] | np.integer | int,
|
|
42
|
+
i: NDArray[np.integer] | np.integer | int,
|
|
43
|
+
f: NDArray[np.integer] | np.integer | int,
|
|
44
|
+
overflow_mode: str = 'WRAP',
|
|
45
|
+
round_mode: str = 'TRN',
|
|
46
|
+
) -> NDArray[np.floating]:
|
|
47
|
+
q = get_fixed_quantizer_np(round_mode=round_mode, overflow_mode=overflow_mode)
|
|
48
|
+
return q(x, k=k, i=i, f=f) # type: ignore
|
|
49
|
+
|
|
50
|
+
|
|
36
51
|
def quantize(
|
|
37
52
|
x: T,
|
|
38
53
|
k: NDArray[np.integer] | np.integer | int,
|
|
@@ -43,22 +58,23 @@ def quantize(
|
|
|
43
58
|
) -> T:
|
|
44
59
|
from ..fixed_variable_array import FixedVariableArray
|
|
45
60
|
|
|
46
|
-
if isinstance(x, FixedVariableArray):
|
|
61
|
+
if isinstance(x, (FixedVariableArray, FixedVariable)):
|
|
47
62
|
return x.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode)
|
|
63
|
+
elif isinstance(x, list):
|
|
64
|
+
ret: list[FixedVariable] = []
|
|
65
|
+
for i in range(len(x)):
|
|
66
|
+
ret.append(
|
|
67
|
+
x[i].quantize(
|
|
68
|
+
k=int(k[i] if isinstance(k, (list, np.ndarray)) else k),
|
|
69
|
+
i=int(i[i] if isinstance(i, (list, np.ndarray)) else i),
|
|
70
|
+
f=int(f[i] if isinstance(f, (list, np.ndarray)) else f),
|
|
71
|
+
overflow_mode=overflow_mode,
|
|
72
|
+
round_mode=round_mode,
|
|
73
|
+
)
|
|
74
|
+
)
|
|
75
|
+
return ret # type: ignore
|
|
48
76
|
else:
|
|
49
|
-
x
|
|
50
|
-
if overflow_mode in ('SAT', 'SAT_SYM'):
|
|
51
|
-
step = 2.0**-f
|
|
52
|
-
_high = 2.0**i
|
|
53
|
-
high = _high - step
|
|
54
|
-
low = -_high * k if overflow_mode == 'SAT' else -high * k
|
|
55
|
-
x = np.clip(x, low, high) # type: ignore
|
|
56
|
-
if round_mode.upper() == 'RND':
|
|
57
|
-
x += 2.0 ** (-f - 1) # type: ignore
|
|
58
|
-
b = k + i + f
|
|
59
|
-
bias = 2.0 ** (b - 1) * k
|
|
60
|
-
eps = 2.0**-f
|
|
61
|
-
return eps * ((np.floor(x / eps) + bias) % 2.0**b - bias) # type: ignore
|
|
77
|
+
return _quantize(x, k, i, f, overflow_mode, round_mode)
|
|
62
78
|
|
|
63
79
|
|
|
64
80
|
__all__ = [
|
da4ml/trace/ops/reduce_utils.py
CHANGED
|
@@ -100,6 +100,6 @@ def reduce(operator: Callable[[T, T], T], x: TA, axis: int | Sequence[int] | Non
|
|
|
100
100
|
|
|
101
101
|
if isinstance(x, FixedVariableArray):
|
|
102
102
|
r = FixedVariableArray(r, solver_config)
|
|
103
|
-
if r.
|
|
104
|
-
return r.
|
|
105
|
-
return r if r.
|
|
103
|
+
if r.shape == ():
|
|
104
|
+
return r._vars.item() # type: ignore
|
|
105
|
+
return r if r.shape != () or keepdims else r.item() # type: ignore
|
da4ml/trace/pipeline.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from math import ceil, floor
|
|
2
2
|
|
|
3
|
-
from ..cmvm.types import
|
|
3
|
+
from ..cmvm.types import CombLogic, Op, Pipeline
|
|
4
4
|
from .fixed_variable import FixedVariable, HWConfig
|
|
5
5
|
from .tracer import comb_trace
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
def retime_pipeline(csol:
|
|
8
|
+
def retime_pipeline(csol: Pipeline, verbose=True):
|
|
9
9
|
n_stages = len(csol[0])
|
|
10
10
|
cutoff_high = ceil(max(max(sol.out_latency) / (i + 1) for i, sol in enumerate(csol[0])))
|
|
11
11
|
cutoff_low = 0
|
|
@@ -60,14 +60,14 @@ def _get_new_idx(
|
|
|
60
60
|
return p0_idx
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def to_pipeline(
|
|
63
|
+
def to_pipeline(comb: CombLogic, latency_cutoff: float, retiming=True, verbose=True) -> Pipeline:
|
|
64
64
|
"""Split the record into multiple stages based on the latency of the operations.
|
|
65
65
|
Only useful for HDL generation.
|
|
66
66
|
|
|
67
67
|
Parameters
|
|
68
68
|
----------
|
|
69
|
-
sol :
|
|
70
|
-
The
|
|
69
|
+
sol : CombLogic
|
|
70
|
+
The combinational logic to be pipelined into multiple stages.
|
|
71
71
|
latency_cutoff : float
|
|
72
72
|
The latency cutoff for splitting the operations.
|
|
73
73
|
retiming : bool
|
|
@@ -83,8 +83,8 @@ def to_pipeline(sol: Solution, latency_cutoff: float, retiming=True, verbose=Tru
|
|
|
83
83
|
CascadedSolution
|
|
84
84
|
The cascaded solution with multiple stages.
|
|
85
85
|
"""
|
|
86
|
-
assert len(
|
|
87
|
-
for i, op in enumerate(
|
|
86
|
+
assert len(comb.ops) > 0, 'No operations in the record'
|
|
87
|
+
for i, op in enumerate(comb.ops):
|
|
88
88
|
if op.id1 != -1:
|
|
89
89
|
break
|
|
90
90
|
|
|
@@ -96,9 +96,9 @@ def to_pipeline(sol: Solution, latency_cutoff: float, retiming=True, verbose=Tru
|
|
|
96
96
|
|
|
97
97
|
locator: list[dict[int, int]] = []
|
|
98
98
|
|
|
99
|
-
ops =
|
|
100
|
-
lat = max(ops[i].latency for i in
|
|
101
|
-
for i in
|
|
99
|
+
ops = comb.ops.copy()
|
|
100
|
+
lat = max(ops[i].latency for i in comb.out_idxs)
|
|
101
|
+
for i in comb.out_idxs:
|
|
102
102
|
op_out = ops[i]
|
|
103
103
|
ops.append(Op(i, -1001, -1001, 0, op_out.qint, lat, 0.0))
|
|
104
104
|
|
|
@@ -113,7 +113,10 @@ def to_pipeline(sol: Solution, latency_cutoff: float, retiming=True, verbose=Tru
|
|
|
113
113
|
p0_idx = _get_new_idx(op.id0, locator, opd, out_idxd, ops, stage, latency_cutoff)
|
|
114
114
|
p1_idx = _get_new_idx(op.id1, locator, opd, out_idxd, ops, stage, latency_cutoff)
|
|
115
115
|
if op.opcode in (6, -6):
|
|
116
|
-
|
|
116
|
+
k = op.data & 0xFFFFFFFF
|
|
117
|
+
_shift = (op.data >> 32) & 0xFFFFFFFF
|
|
118
|
+
k = _get_new_idx(k, locator, opd, out_idxd, ops, stage, latency_cutoff)
|
|
119
|
+
data = _shift << 32 | k
|
|
117
120
|
else:
|
|
118
121
|
data = op.data
|
|
119
122
|
|
|
@@ -126,34 +129,53 @@ def to_pipeline(sol: Solution, latency_cutoff: float, retiming=True, verbose=Tru
|
|
|
126
129
|
locator.append({stage: len(opd[stage]) - 1})
|
|
127
130
|
sols = []
|
|
128
131
|
max_stage = max(opd.keys())
|
|
129
|
-
n_in =
|
|
132
|
+
n_in = comb.shape[0]
|
|
130
133
|
for i, stage in enumerate(opd.keys()):
|
|
131
134
|
_ops = opd[stage]
|
|
132
135
|
_out_idx = out_idxd[stage]
|
|
133
136
|
n_out = len(_out_idx)
|
|
134
137
|
|
|
135
138
|
if i == max_stage:
|
|
136
|
-
out_shifts =
|
|
137
|
-
out_negs =
|
|
139
|
+
out_shifts = comb.out_shifts
|
|
140
|
+
out_negs = comb.out_negs
|
|
138
141
|
else:
|
|
139
142
|
out_shifts = [0] * len(_out_idx)
|
|
140
143
|
out_negs = [False] * len(_out_idx)
|
|
141
144
|
|
|
142
|
-
|
|
145
|
+
if comb.lookup_tables is not None:
|
|
146
|
+
_ops, lookup_tables = remap_table_idxs(comb, _ops)
|
|
147
|
+
else:
|
|
148
|
+
lookup_tables = None
|
|
149
|
+
_sol = CombLogic(
|
|
143
150
|
shape=(n_in, n_out),
|
|
144
151
|
inp_shift=[0] * n_in,
|
|
145
152
|
out_idxs=_out_idx,
|
|
146
153
|
out_shifts=out_shifts,
|
|
147
154
|
out_negs=out_negs,
|
|
148
155
|
ops=_ops,
|
|
149
|
-
carry_size=
|
|
150
|
-
adder_size=
|
|
156
|
+
carry_size=comb.carry_size,
|
|
157
|
+
adder_size=comb.adder_size,
|
|
158
|
+
lookup_tables=lookup_tables,
|
|
151
159
|
)
|
|
152
160
|
sols.append(_sol)
|
|
153
161
|
|
|
154
162
|
n_in = n_out
|
|
155
|
-
csol =
|
|
163
|
+
csol = Pipeline(tuple(sols))
|
|
156
164
|
|
|
157
165
|
if retiming:
|
|
158
166
|
csol = retime_pipeline(csol, verbose=verbose)
|
|
159
167
|
return csol
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def remap_table_idxs(comb: CombLogic, _ops):
|
|
171
|
+
assert comb.lookup_tables is not None
|
|
172
|
+
table_idxs = sorted(list({op.data for op in _ops if op.opcode == 8}))
|
|
173
|
+
remap = {j: i for i, j in enumerate(table_idxs)}
|
|
174
|
+
_ops_remap = []
|
|
175
|
+
for op in _ops:
|
|
176
|
+
if op.opcode == 8:
|
|
177
|
+
op = Op(op.id0, op.id1, op.opcode, remap[op.data], op.qint, op.latency, op.cost)
|
|
178
|
+
_ops_remap.append(op)
|
|
179
|
+
_ops = _ops_remap
|
|
180
|
+
lookup_tables = tuple(comb.lookup_tables[i] for i in table_idxs)
|
|
181
|
+
return _ops, lookup_tables
|