da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +204 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +246 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.1.post1.dist-info/METADATA +85 -0
- da4ml-0.5.1.post1.dist-info/RECORD +96 -0
- da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
- da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
- da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
|
@@ -0,0 +1,600 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from inspect import signature
|
|
3
|
+
from typing import TypeVar
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numba.typed import List as NumbaList
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
|
|
9
|
+
from ..cmvm.api import solve, solver_options_t
|
|
10
|
+
from .fixed_variable import FixedVariable, FixedVariableInput, HWConfig, LookupTable, QInterval
|
|
11
|
+
from .ops import _quantize, einsum, reduce
|
|
12
|
+
|
|
13
|
+
T = TypeVar('T')
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def to_raw_arr(obj: T) -> T:
|
|
17
|
+
if isinstance(obj, tuple):
|
|
18
|
+
return tuple(to_raw_arr(x) for x in obj) # type: ignore
|
|
19
|
+
elif isinstance(obj, list):
|
|
20
|
+
return [to_raw_arr(x) for x in obj] # type: ignore
|
|
21
|
+
elif isinstance(obj, dict):
|
|
22
|
+
return {k: to_raw_arr(v) for k, v in obj.items()} # type: ignore
|
|
23
|
+
if isinstance(obj, FixedVariableArray):
|
|
24
|
+
return obj._vars # type: ignore
|
|
25
|
+
return obj
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _max_of(a, b):
    """Maximum of two operands, using FixedVariable.max_of when either is symbolic.

    The left operand's ``max_of`` is preferred; if only the right operand is a
    FixedVariable its ``max_of`` is called with the left one. Two non-symbolic
    operands fall back to the builtin ``max``.
    """
    for first, second in ((a, b), (b, a)):
        if isinstance(first, FixedVariable):
            return first.max_of(second)
    return max(a, b)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _min_of(a, b):
    """Minimum of two operands, using FixedVariable.min_of when either is symbolic.

    The left operand's ``min_of`` is preferred; if only the right operand is a
    FixedVariable its ``min_of`` is called with the left one. Two non-symbolic
    operands fall back to the builtin ``min``.
    """
    for first, second in ((a, b), (b, a)):
        if isinstance(first, FixedVariable):
            return first.min_of(second)
    return min(a, b)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def mmm(mat0: np.ndarray, mat1: np.ndarray):
    """Matrix product traced as explicit multiply-and-accumulate.

    The last axis of ``mat0`` is contracted with the first axis of ``mat1``; the
    result has shape ``mat0.shape[:-1] + mat1.shape[1:]``. Every output element is
    the project ``reduce`` with ``+`` over the element-wise products of a row of
    ``mat0`` and a column of ``mat1``.
    """
    out_shape = mat0.shape[:-1] + mat1.shape[1:]
    lhs = mat0.reshape((-1, mat0.shape[-1]))
    rhs = mat1.reshape((mat1.shape[0], -1))
    n_rows, n_cols = lhs.shape[0], rhs.shape[1]

    out = np.empty((n_rows, n_cols), dtype=object)
    for row in range(n_rows):
        row_vec = lhs[row]
        for col in range(n_cols):
            out[row, col] = reduce(lambda x, y: x + y, row_vec * rhs[:, col])
    return out.reshape(out_shape)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def cmvm(cm: np.ndarray, v: 'FixedVariableArray', solver_options: solver_options_t) -> np.ndarray:
    """Trace the product of the 1-D symbolic vector ``v`` with the constant matrix ``cm``.

    Non-zero entries of ``cm`` whose input has zero total bit-width (see
    ``offload_mask``) are removed from the matrix handed to ``solve`` and are
    instead traced with ``mmm``; the two partial results are added.

    Parameters
    ----------
    cm : np.ndarray
        Constant matrix of shape ``(len(v), n_out)``.
    v : FixedVariableArray
        1-D array of symbolic inputs.
    solver_options : solver_options_t
        Keyword arguments for ``solve``. ``adder_size`` and ``carry_size`` fall
        back to the inputs' ``hwconf`` when absent. The caller's mapping is not
        modified.

    Returns
    -------
    np.ndarray
        The solver solution applied to ``v``'s variables, plus the offloaded
        part when there is one.
    """
    # Work on a copy: the hwconf-derived defaults below must not leak into the caller's dict.
    solver_options = solver_options.copy()

    mask = offload_mask(cm, v)
    if np.any(mask):
        offload_cm = cm * mask.astype(cm.dtype)
        cm = cm * (~mask).astype(cm.dtype)
    else:
        offload_cm = None

    _qintervals = [QInterval(float(_v.low), float(_v.high), float(_v.step)) for _v in v._vars]
    _latencies = [float(_v.latency) for _v in v._vars]
    qintervals = NumbaList(_qintervals)  # type: ignore
    latencies = NumbaList(_latencies)  # type: ignore

    hwconf = v._vars.ravel()[0].hwconf
    solver_options.setdefault('adder_size', hwconf.adder_size)
    solver_options.setdefault('carry_size', hwconf.carry_size)

    # The solver receives a C-contiguous float32 copy of the matrix.
    _mat = np.ascontiguousarray(cm.astype(np.float32))
    sol = solve(_mat, qintervals=qintervals, latencies=latencies, **solver_options)
    _r: np.ndarray = sol(v._vars)
    if offload_cm is not None:
        _r = _r + mmm(v._vars, offload_cm)
    return _r
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def offload_mask(cm: NDArray, v: 'FixedVariableArray') -> NDArray[np.bool_]:
    """Mark the matrix entries that can be taken out of the CMVM problem.

    Returns a boolean array shaped like ``cm`` that is True where ``cm`` is
    non-zero AND the matching element of ``v`` has ``k + i + f == 0``.
    """
    assert v.ndim == 1
    assert cm.ndim == 2
    assert cm.shape[0] == v.shape[0]
    k, i, f = v.kif
    zero_width = (k + i + f) == 0
    return zero_width.reshape(-1, 1) & (cm != 0)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
_unary_functions = (
|
|
90
|
+
np.sin,
|
|
91
|
+
np.cos,
|
|
92
|
+
np.tan,
|
|
93
|
+
np.exp,
|
|
94
|
+
np.log,
|
|
95
|
+
np.invert,
|
|
96
|
+
np.sqrt,
|
|
97
|
+
np.tanh,
|
|
98
|
+
np.sinh,
|
|
99
|
+
np.cosh,
|
|
100
|
+
np.arccos,
|
|
101
|
+
np.arcsin,
|
|
102
|
+
np.arctan,
|
|
103
|
+
np.arcsinh,
|
|
104
|
+
np.arccosh,
|
|
105
|
+
np.arctanh,
|
|
106
|
+
np.exp2,
|
|
107
|
+
np.expm1,
|
|
108
|
+
np.log2,
|
|
109
|
+
np.log10,
|
|
110
|
+
np.log1p,
|
|
111
|
+
np.cbrt,
|
|
112
|
+
np.reciprocal,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class FixedVariableArray:
|
|
117
|
+
"""Symbolic array of FixedVariable for tracing operations. Supports numpy ufuncs and array functions."""
|
|
118
|
+
|
|
119
|
+
__array_priority__ = 100
|
|
120
|
+
|
|
121
|
+
def __array_function__(self, func, types, args, kwargs):
|
|
122
|
+
if func in (np.mean, np.sum, np.amax, np.amin, np.prod, np.max, np.min):
|
|
123
|
+
match func:
|
|
124
|
+
case np.mean:
|
|
125
|
+
_x = reduce(lambda x, y: x + y, *args, **kwargs)
|
|
126
|
+
return _x * (_x.size / self._vars.size)
|
|
127
|
+
case np.sum:
|
|
128
|
+
return reduce(lambda x, y: x + y, *args, **kwargs)
|
|
129
|
+
case np.max | np.amax:
|
|
130
|
+
return reduce(_max_of, *args, **kwargs)
|
|
131
|
+
case np.min | np.amin:
|
|
132
|
+
return reduce(_min_of, *args, **kwargs)
|
|
133
|
+
case np.prod:
|
|
134
|
+
return reduce(lambda x, y: x * y, *args, **kwargs)
|
|
135
|
+
case _:
|
|
136
|
+
raise NotImplementedError(f'Unsupported function: {func}')
|
|
137
|
+
|
|
138
|
+
if func is np.clip:
|
|
139
|
+
assert len(args) == 3, 'Clip function requires exactly three arguments'
|
|
140
|
+
x, low, high = args
|
|
141
|
+
_x, low, high = np.broadcast_arrays(x, low, high)
|
|
142
|
+
x = FixedVariableArray(_x, self.solver_options)
|
|
143
|
+
x = np.amax(np.stack((x, low), axis=-1), axis=-1) # type: ignore
|
|
144
|
+
return np.amin(np.stack((x, high), axis=-1), axis=-1)
|
|
145
|
+
|
|
146
|
+
if func is np.einsum:
|
|
147
|
+
# assert len(args) == 2
|
|
148
|
+
sig = signature(np.einsum)
|
|
149
|
+
bind = sig.bind(*args, **kwargs)
|
|
150
|
+
eq = args[0]
|
|
151
|
+
operands = bind.arguments['operands']
|
|
152
|
+
if isinstance(operands[0], str):
|
|
153
|
+
operands = operands[1:]
|
|
154
|
+
assert len(operands) == 2, 'Einsum on FixedVariableArray requires exactly two operands'
|
|
155
|
+
assert bind.arguments.get('out', None) is None, 'Output argument is not supported'
|
|
156
|
+
return einsum(eq, *operands)
|
|
157
|
+
|
|
158
|
+
if func is np.dot:
|
|
159
|
+
assert len(args) in (2, 3), 'Dot function requires exactly two or three arguments'
|
|
160
|
+
|
|
161
|
+
assert len(args) == 2
|
|
162
|
+
a, b = args
|
|
163
|
+
if not isinstance(a, FixedVariableArray):
|
|
164
|
+
a = np.array(a)
|
|
165
|
+
if not isinstance(b, FixedVariableArray):
|
|
166
|
+
b = np.array(b)
|
|
167
|
+
if a.shape[-1] == b.shape[0]:
|
|
168
|
+
return a @ b
|
|
169
|
+
|
|
170
|
+
assert a.size == 1 or b.size == 1, f'Error in dot product: {a.shape} @ {b.shape}'
|
|
171
|
+
return a * b
|
|
172
|
+
|
|
173
|
+
args, kwargs = to_raw_arr(args), to_raw_arr(kwargs)
|
|
174
|
+
return FixedVariableArray(
|
|
175
|
+
func(*args, **kwargs),
|
|
176
|
+
self.solver_options,
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
|
|
180
|
+
assert method == '__call__', f'Only __call__ method is supported for ufuncs, got {method}'
|
|
181
|
+
|
|
182
|
+
match ufunc:
|
|
183
|
+
case np.add | np.subtract | np.multiply | np.true_divide | np.negative:
|
|
184
|
+
inputs = [to_raw_arr(x) for x in inputs]
|
|
185
|
+
return FixedVariableArray(ufunc(*inputs, **kwargs), self.solver_options)
|
|
186
|
+
|
|
187
|
+
case np.negative:
|
|
188
|
+
assert len(inputs) == 1
|
|
189
|
+
return FixedVariableArray(ufunc(to_raw_arr(inputs[0]), **kwargs), self.solver_options)
|
|
190
|
+
|
|
191
|
+
case np.maximum | np.minimum:
|
|
192
|
+
op = _max_of if ufunc is np.maximum else _min_of
|
|
193
|
+
a, b = np.broadcast_arrays(inputs[0], inputs[1])
|
|
194
|
+
shape = a.shape
|
|
195
|
+
a, b = a.ravel(), b.ravel()
|
|
196
|
+
r = np.empty(a.size, dtype=object)
|
|
197
|
+
for i in range(a.size):
|
|
198
|
+
r[i] = op(a[i], b[i])
|
|
199
|
+
return FixedVariableArray(r.reshape(shape), self.solver_options)
|
|
200
|
+
|
|
201
|
+
case np.matmul:
|
|
202
|
+
assert len(inputs) == 2
|
|
203
|
+
assert isinstance(inputs[0], FixedVariableArray) or isinstance(inputs[1], FixedVariableArray)
|
|
204
|
+
if isinstance(inputs[0], FixedVariableArray):
|
|
205
|
+
return inputs[0].matmul(inputs[1])
|
|
206
|
+
else:
|
|
207
|
+
return inputs[1].rmatmul(inputs[0])
|
|
208
|
+
|
|
209
|
+
case np.power:
|
|
210
|
+
assert len(inputs) == 2
|
|
211
|
+
base, exp = inputs
|
|
212
|
+
return base**exp
|
|
213
|
+
|
|
214
|
+
case np.abs | np.absolute:
|
|
215
|
+
assert len(inputs) == 1
|
|
216
|
+
assert inputs[0] is self
|
|
217
|
+
arr = self._vars.ravel()
|
|
218
|
+
r = np.array([v.__abs__() for v in arr])
|
|
219
|
+
return FixedVariableArray(r.reshape(self.shape), self.solver_options)
|
|
220
|
+
|
|
221
|
+
case np.square:
|
|
222
|
+
assert len(inputs) == 1
|
|
223
|
+
assert inputs[0] is self
|
|
224
|
+
return self**2
|
|
225
|
+
|
|
226
|
+
if ufunc in _unary_functions:
|
|
227
|
+
assert len(inputs) == 1
|
|
228
|
+
assert inputs[0] is self
|
|
229
|
+
return self.apply(ufunc)
|
|
230
|
+
|
|
231
|
+
raise NotImplementedError(f'Unsupported ufunc: {ufunc}')
|
|
232
|
+
|
|
233
|
+
def __init__(
|
|
234
|
+
self,
|
|
235
|
+
vars: NDArray,
|
|
236
|
+
solver_options: solver_options_t | None = None,
|
|
237
|
+
):
|
|
238
|
+
_vars = np.array(vars)
|
|
239
|
+
_vars_f = _vars.ravel()
|
|
240
|
+
hwconf = next(iter(v for v in _vars_f if isinstance(v, FixedVariable))).hwconf
|
|
241
|
+
for i, v in enumerate(_vars_f):
|
|
242
|
+
if not isinstance(v, FixedVariable):
|
|
243
|
+
_vars_f[i] = FixedVariable(float(v), float(v), 1.0, hwconf=hwconf)
|
|
244
|
+
self._vars = _vars
|
|
245
|
+
_solver_options = signature(solve).parameters
|
|
246
|
+
_solver_options = {k: v.default for k, v in _solver_options.items() if v.default is not v.empty}
|
|
247
|
+
if solver_options is not None:
|
|
248
|
+
_solver_options.update(solver_options)
|
|
249
|
+
_solver_options.pop('qintervals', None)
|
|
250
|
+
_solver_options.pop('latencies', None)
|
|
251
|
+
self.solver_options: solver_options_t = _solver_options # type: ignore
|
|
252
|
+
|
|
253
|
+
@classmethod
|
|
254
|
+
def from_lhs(
|
|
255
|
+
cls,
|
|
256
|
+
low: NDArray[np.floating],
|
|
257
|
+
high: NDArray[np.floating],
|
|
258
|
+
step: NDArray[np.floating],
|
|
259
|
+
hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
|
|
260
|
+
latency: np.ndarray | float = 0.0,
|
|
261
|
+
solver_options: solver_options_t | None = None,
|
|
262
|
+
):
|
|
263
|
+
low, high, step = np.array(low), np.array(high), np.array(step)
|
|
264
|
+
shape = low.shape
|
|
265
|
+
assert shape == high.shape == step.shape
|
|
266
|
+
|
|
267
|
+
low, high, step = low.ravel(), high.ravel(), step.ravel()
|
|
268
|
+
latency = np.full_like(low, latency) if isinstance(latency, (int, float)) else latency.ravel()
|
|
269
|
+
|
|
270
|
+
vars = []
|
|
271
|
+
for l, h, s, lat in zip(low, high, step, latency):
|
|
272
|
+
var = FixedVariable(
|
|
273
|
+
low=float(l),
|
|
274
|
+
high=float(h),
|
|
275
|
+
step=float(s),
|
|
276
|
+
hwconf=hwconf,
|
|
277
|
+
latency=float(
|
|
278
|
+
lat,
|
|
279
|
+
),
|
|
280
|
+
)
|
|
281
|
+
vars.append(var)
|
|
282
|
+
vars = np.array(vars).reshape(shape)
|
|
283
|
+
return cls(vars, solver_options)
|
|
284
|
+
|
|
285
|
+
__array_priority__ = 100
|
|
286
|
+
|
|
287
|
+
@classmethod
|
|
288
|
+
def from_kif(
|
|
289
|
+
cls,
|
|
290
|
+
k: NDArray[np.bool_ | np.integer],
|
|
291
|
+
i: NDArray[np.integer],
|
|
292
|
+
f: NDArray[np.integer],
|
|
293
|
+
hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
|
|
294
|
+
latency: NDArray[np.floating] | float = 0.0,
|
|
295
|
+
solver_options: solver_options_t | None = None,
|
|
296
|
+
):
|
|
297
|
+
mask = k + i + f <= 0
|
|
298
|
+
k = np.where(mask, 0, k)
|
|
299
|
+
i = np.where(mask, 0, i)
|
|
300
|
+
f = np.where(mask, 0, f)
|
|
301
|
+
step = 2.0**-f
|
|
302
|
+
_high = 2.0**i
|
|
303
|
+
high, low = _high - step, -_high * k
|
|
304
|
+
return cls.from_lhs(low, high, step, hwconf, latency, solver_options)
|
|
305
|
+
|
|
306
|
+
def matmul(self, other) -> 'FixedVariableArray':
|
|
307
|
+
if self.collapsed:
|
|
308
|
+
self_mat = np.array([v.low for v in self._vars.ravel()], dtype=np.float64).reshape(self._vars.shape)
|
|
309
|
+
if isinstance(other, FixedVariableArray):
|
|
310
|
+
if not other.collapsed:
|
|
311
|
+
return self_mat @ other # type: ignore
|
|
312
|
+
other_mat = np.array([v.low for v in other._vars.ravel()], dtype=np.float64).reshape(other._vars.shape)
|
|
313
|
+
else:
|
|
314
|
+
other_mat = np.array(other, dtype=np.float64)
|
|
315
|
+
|
|
316
|
+
r = self_mat @ other_mat
|
|
317
|
+
return FixedVariableArray.from_lhs(
|
|
318
|
+
low=r,
|
|
319
|
+
high=r,
|
|
320
|
+
step=np.ones_like(r),
|
|
321
|
+
hwconf=self._vars.ravel()[0].hwconf,
|
|
322
|
+
solver_options=self.solver_options,
|
|
323
|
+
)
|
|
324
|
+
|
|
325
|
+
if isinstance(other, FixedVariableArray):
|
|
326
|
+
other = other._vars
|
|
327
|
+
if not isinstance(other, np.ndarray):
|
|
328
|
+
other = np.array(other)
|
|
329
|
+
if any(isinstance(x, FixedVariable) for x in other.ravel()):
|
|
330
|
+
mat0, mat1 = self._vars, other
|
|
331
|
+
_vars = mmm(mat0, mat1)
|
|
332
|
+
return FixedVariableArray(_vars, self.solver_options)
|
|
333
|
+
|
|
334
|
+
solver_options = (self.solver_options or {}).copy()
|
|
335
|
+
shape0, shape1 = self.shape, other.shape
|
|
336
|
+
assert shape0[-1] == shape1[0], f'Matrix shapes do not match: {shape0} @ {shape1}'
|
|
337
|
+
contract_len = shape1[0]
|
|
338
|
+
out_shape = shape0[:-1] + shape1[1:]
|
|
339
|
+
mat0, mat1 = self.reshape((-1, contract_len)), other.reshape((contract_len, -1))
|
|
340
|
+
r = []
|
|
341
|
+
for i in range(mat0.shape[0]):
|
|
342
|
+
vec = mat0[i]
|
|
343
|
+
_r = cmvm(mat1, vec, solver_options)
|
|
344
|
+
r.append(_r)
|
|
345
|
+
r = np.array(r).reshape(out_shape)
|
|
346
|
+
return FixedVariableArray(r, self.solver_options)
|
|
347
|
+
|
|
348
|
+
def __matmul__(self, other):
|
|
349
|
+
return self.matmul(other)
|
|
350
|
+
|
|
351
|
+
def rmatmul(self, other):
|
|
352
|
+
mat1 = np.moveaxis(other, -1, 0)
|
|
353
|
+
mat0 = np.moveaxis(self, 0, -1) # type: ignore
|
|
354
|
+
ndim0, ndim1 = mat0.ndim, mat1.ndim
|
|
355
|
+
r = mat0 @ mat1
|
|
356
|
+
|
|
357
|
+
_axes = tuple(range(0, ndim0 + ndim1 - 2))
|
|
358
|
+
axes = _axes[ndim0 - 1 :] + _axes[: ndim0 - 1]
|
|
359
|
+
return r.transpose(axes)
|
|
360
|
+
|
|
361
|
+
def __rmatmul__(self, other):
|
|
362
|
+
return self.rmatmul(other)
|
|
363
|
+
|
|
364
|
+
def __getitem__(self, item):
|
|
365
|
+
vars = self._vars[item]
|
|
366
|
+
if isinstance(vars, np.ndarray):
|
|
367
|
+
return FixedVariableArray(vars, self.solver_options)
|
|
368
|
+
else:
|
|
369
|
+
return vars
|
|
370
|
+
|
|
371
|
+
def __len__(self):
|
|
372
|
+
return len(self._vars)
|
|
373
|
+
|
|
374
|
+
@property
|
|
375
|
+
def shape(self):
|
|
376
|
+
return self._vars.shape
|
|
377
|
+
|
|
378
|
+
def __add__(self, other):
|
|
379
|
+
if isinstance(other, FixedVariableArray):
|
|
380
|
+
return FixedVariableArray(self._vars + other._vars, self.solver_options)
|
|
381
|
+
return FixedVariableArray(self._vars + other, self.solver_options)
|
|
382
|
+
|
|
383
|
+
def __sub__(self, other):
|
|
384
|
+
if isinstance(other, FixedVariableArray):
|
|
385
|
+
return FixedVariableArray(self._vars - other._vars, self.solver_options)
|
|
386
|
+
return FixedVariableArray(self._vars - other, self.solver_options)
|
|
387
|
+
|
|
388
|
+
def __mul__(self, other):
|
|
389
|
+
if isinstance(other, FixedVariableArray):
|
|
390
|
+
return FixedVariableArray(self._vars * other._vars, self.solver_options)
|
|
391
|
+
return FixedVariableArray(self._vars * other, self.solver_options)
|
|
392
|
+
|
|
393
|
+
def __truediv__(self, other):
|
|
394
|
+
return FixedVariableArray(self._vars * (1 / other), self.solver_options)
|
|
395
|
+
|
|
396
|
+
def __radd__(self, other):
|
|
397
|
+
return self + other
|
|
398
|
+
|
|
399
|
+
def __neg__(self):
|
|
400
|
+
return FixedVariableArray(-self._vars, self.solver_options)
|
|
401
|
+
|
|
402
|
+
def __repr__(self):
|
|
403
|
+
shape = self._vars.shape
|
|
404
|
+
hwconf_str = str(self._vars.ravel()[0].hwconf)[8:]
|
|
405
|
+
max_lat = max(v.latency for v in self._vars.ravel())
|
|
406
|
+
return f'FixedVariableArray(shape={shape}, hwconf={hwconf_str}, latency={max_lat})'
|
|
407
|
+
|
|
408
|
+
def __pow__(self, power: int | float):
|
|
409
|
+
_power = int(power)
|
|
410
|
+
if _power == power and _power >= 0:
|
|
411
|
+
return FixedVariableArray(self._vars**_power, self.solver_options)
|
|
412
|
+
else:
|
|
413
|
+
return self.apply(lambda x: x**power)
|
|
414
|
+
|
|
415
|
+
def relu(
|
|
416
|
+
self,
|
|
417
|
+
i: NDArray[np.integer] | None = None,
|
|
418
|
+
f: NDArray[np.integer] | None = None,
|
|
419
|
+
round_mode: str = 'TRN',
|
|
420
|
+
):
|
|
421
|
+
shape = self._vars.shape
|
|
422
|
+
i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
|
|
423
|
+
f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
|
|
424
|
+
ret = []
|
|
425
|
+
for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()): # type: ignore
|
|
426
|
+
ret.append(v.relu(i=i, f=f, round_mode=round_mode))
|
|
427
|
+
return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)
|
|
428
|
+
|
|
429
|
+
def quantize(
|
|
430
|
+
self,
|
|
431
|
+
k: NDArray[np.integer] | np.integer | int | None = None,
|
|
432
|
+
i: NDArray[np.integer] | np.integer | int | None = None,
|
|
433
|
+
f: NDArray[np.integer] | np.integer | int | None = None,
|
|
434
|
+
overflow_mode: str = 'WRAP',
|
|
435
|
+
round_mode: str = 'TRN',
|
|
436
|
+
):
|
|
437
|
+
shape = self._vars.shape
|
|
438
|
+
if any(x is None for x in (k, i, f)):
|
|
439
|
+
kif = self.kif
|
|
440
|
+
k = np.broadcast_to(k, shape) if k is not None else kif[0] # type: ignore
|
|
441
|
+
i = np.broadcast_to(i, shape) if i is not None else kif[1] # type: ignore
|
|
442
|
+
f = np.broadcast_to(f, shape) if f is not None else kif[2] # type: ignore
|
|
443
|
+
ret = []
|
|
444
|
+
for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()): # type: ignore
|
|
445
|
+
ret.append(v.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode))
|
|
446
|
+
return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)
|
|
447
|
+
|
|
448
|
+
def flatten(self):
|
|
449
|
+
return FixedVariableArray(self._vars.flatten(), self.solver_options)
|
|
450
|
+
|
|
451
|
+
def reshape(self, *shape):
|
|
452
|
+
return FixedVariableArray(self._vars.reshape(*shape), self.solver_options)
|
|
453
|
+
|
|
454
|
+
def transpose(self, axes=None):
|
|
455
|
+
return FixedVariableArray(self._vars.transpose(axes), self.solver_options)
|
|
456
|
+
|
|
457
|
+
def ravel(self):
|
|
458
|
+
return FixedVariableArray(self._vars.ravel(), self.solver_options)
|
|
459
|
+
|
|
460
|
+
@property
|
|
461
|
+
def dtype(self):
|
|
462
|
+
return self._vars.dtype
|
|
463
|
+
|
|
464
|
+
@property
|
|
465
|
+
def size(self):
|
|
466
|
+
return self._vars.size
|
|
467
|
+
|
|
468
|
+
@property
|
|
469
|
+
def ndim(self):
|
|
470
|
+
return self._vars.ndim
|
|
471
|
+
|
|
472
|
+
@property
|
|
473
|
+
def kif(self):
|
|
474
|
+
"""[k, i, f] array"""
|
|
475
|
+
shape = self._vars.shape
|
|
476
|
+
kif = np.array([v.kif for v in self._vars.ravel()]).reshape(*shape, 3)
|
|
477
|
+
return np.moveaxis(kif, -1, 0)
|
|
478
|
+
|
|
479
|
+
@property
|
|
480
|
+
def lhs(self):
|
|
481
|
+
"""[low, high, step] array"""
|
|
482
|
+
shape = self._vars.shape
|
|
483
|
+
lhs = np.array([(v.low, v.high, v.step) for v in self._vars.ravel()], dtype=np.float32).reshape(*shape, 3)
|
|
484
|
+
return np.moveaxis(lhs, -1, 0)
|
|
485
|
+
|
|
486
|
+
@property
|
|
487
|
+
def latency(self):
|
|
488
|
+
"""Maximum latency among all elements."""
|
|
489
|
+
return np.array([v.latency for v in self._vars.ravel()]).reshape(self._vars.shape)
|
|
490
|
+
|
|
491
|
+
@property
|
|
492
|
+
def collapsed(self):
|
|
493
|
+
return all(v.low == v.high for v in self._vars.ravel())
|
|
494
|
+
|
|
495
|
+
def apply(self, fn: Callable[[NDArray], NDArray]) -> 'RetardedFixedVariableArray':
|
|
496
|
+
"""Apply a unary operator to all elements, returning a RetardedFixedVariableArray."""
|
|
497
|
+
return RetardedFixedVariableArray(
|
|
498
|
+
self._vars,
|
|
499
|
+
self.solver_options,
|
|
500
|
+
operator=fn,
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
@property
|
|
504
|
+
def T(self):
|
|
505
|
+
return self.transpose()
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
class FixedVariableArrayInput(FixedVariableArray):
    """Symbolic input array whose elements are FixedVariableInput placeholders.

    Unlike FixedVariableArray, the precision of each element is left unspecified
    at construction; the highest precision it is later quantized to is recorded
    and used when generating the logic.
    """

    def __init__(
        self,
        shape: tuple[int, ...] | int,
        hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
        solver_options: solver_options_t | None = None,
        latency=0.0,
    ):
        # Fill an object array of the requested shape, in row-major order, with one
        # fresh FixedVariableInput per element (all sharing latency and hwconf).
        placeholders = np.empty(shape, dtype=object)
        for index in np.ndindex(placeholders.shape):
            placeholders[index] = FixedVariableInput(latency, hwconf)
        super().__init__(placeholders, solver_options)
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def make_table(fn: Callable[[NDArray], NDArray], qint: QInterval) -> LookupTable:
    """Tabulate ``fn`` on every grid point of ``qint``.

    The grid runs from the interval's first bound to its second (which may be
    descending), both included, spaced by its step; ``fn`` is called once on the
    whole grid and its output wrapped in a LookupTable.
    """
    start, stop, step = qint
    n_points = 1 + round(abs(stop - start) / step)
    grid = np.linspace(start, stop, n_points)
    return LookupTable(fn(grid))
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
class RetardedFixedVariableArray(FixedVariableArray):
    """Ephemeral FixedVariableArray generated from operations of unspecified output precision.

    This object translates to a normal FixedVariableArray upon quantization.
    It does not inherit the maximum precision like FixedVariableArrayInput.

    ``_vars`` holds the *inputs* of a pending element-wise operator (``_operator``),
    which is only materialized, as lookup tables, by ``quantize``. Operations
    inherited from FixedVariableArray read ``_vars`` directly and would silently
    drop the operator, so they are either re-implemented here to carry the
    operator along or rejected with a RuntimeError.

    This object can be used in the following ways:
    1. Quantization (with all of k/i/f specified, or none of them), which converts
       to FixedVariableArray.
    2. Applying a further unary operation, which returns another
       RetardedFixedVariableArray (e.g., composite functions): ``apply``, the
       element-wise numpy ufuncs accepted by ``__array_ufunc__``, ``**``, negation,
       and ``+ - * /`` with a Python/NumPy real scalar.
    3. Shape manipulation (reshape, transpose, flatten, ravel, slicing to a
       sub-array), which keeps the pending operator.
    """

    _UNSUPPORTED_MSG = 'RetardedFixedVariableArray only supports quantization or further unary operations.'
    # Single-argument ufuncs composed into the pending operator besides _unary_functions.
    _EXTRA_UNARY_UFUNCS = (np.negative, np.absolute, np.square)
    # Two-argument ufuncs composed into the pending operator when the other operand is a scalar.
    _SCALAR_BINARY_UFUNCS = (np.add, np.subtract, np.multiply, np.true_divide, np.power)

    def __init__(self, vars: NDArray, solver_options: solver_options_t | None, operator: Callable[[NDArray], NDArray]):
        self._operator = operator
        super().__init__(vars, solver_options)

    @staticmethod
    def _is_scalar(x) -> bool:
        """True for Python / NumPy real scalars."""
        return isinstance(x, (int, float, np.integer, np.floating))

    def _with_vars(self, vars: NDArray) -> 'RetardedFixedVariableArray':
        """Wrap ``vars`` with this array's pending operator and solver options."""
        return RetardedFixedVariableArray(vars, self.solver_options, operator=self._operator)

    def __array_function__(self, func, types, args, kwargs):
        raise RuntimeError(self._UNSUPPORTED_MSG)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Fold element-wise ufuncs into the pending operator; reject everything else."""
        assert method == '__call__', f'Only __call__ method is supported for ufuncs, got {method}'
        if len(inputs) == 1 and inputs[0] is self:
            if ufunc in _unary_functions or ufunc in self._EXTRA_UNARY_UFUNCS:
                return self.apply(ufunc)
        elif len(inputs) == 2 and ufunc in self._SCALAR_BINARY_UFUNCS:
            a, b = inputs
            if a is self and self._is_scalar(b):
                return self.apply(lambda x: ufunc(x, b))
            if b is self and self._is_scalar(a):
                return self.apply(lambda x: ufunc(a, x))
        raise RuntimeError(self._UNSUPPORTED_MSG)

    # Arithmetic is routed through the ufuncs so that scalar operands are composed
    # into the operator and anything else is rejected by __array_ufunc__.
    def __add__(self, other):
        return np.add(self, other)

    def __radd__(self, other):
        return np.add(other, self)

    def __sub__(self, other):
        return np.subtract(self, other)

    def __rsub__(self, other):
        return np.subtract(other, self)

    def __mul__(self, other):
        return np.multiply(self, other)

    def __rmul__(self, other):
        return np.multiply(other, self)

    def __truediv__(self, other):
        return np.true_divide(self, other)

    def __neg__(self):
        return self.apply(np.negative)

    def __pow__(self, power: int | float):
        if not self._is_scalar(power):
            raise RuntimeError(self._UNSUPPORTED_MSG)
        return self.apply(lambda x: x**power)

    def matmul(self, other) -> 'FixedVariableArray':
        raise RuntimeError(self._UNSUPPORTED_MSG)

    def rmatmul(self, other):
        raise RuntimeError(self._UNSUPPORTED_MSG)

    def relu(
        self,
        i: NDArray[np.integer] | None = None,
        f: NDArray[np.integer] | None = None,
        round_mode: str = 'TRN',
    ):
        raise RuntimeError(self._UNSUPPORTED_MSG)

    def __getitem__(self, item):
        vars = self._vars[item]
        if isinstance(vars, np.ndarray):
            return self._with_vars(vars)
        # A single element is a bare FixedVariable, which cannot carry the operator.
        raise RuntimeError(self._UNSUPPORTED_MSG)

    def reshape(self, *shape):
        return self._with_vars(self._vars.reshape(*shape))

    def transpose(self, axes=None):
        return self._with_vars(self._vars.transpose(axes))

    def flatten(self):
        return self._with_vars(self._vars.flatten())

    def ravel(self):
        return self._with_vars(self._vars.ravel())

    def apply(self, fn: Callable[[NDArray], NDArray]) -> 'RetardedFixedVariableArray':
        """Compose ``fn`` after the pending operator."""
        inner = self._operator
        return RetardedFixedVariableArray(
            self._vars,
            self.solver_options,
            operator=lambda x: fn(inner(x)),
        )

    def quantize(
        self,
        k: NDArray[np.integer] | np.integer | int | None = None,
        i: NDArray[np.integer] | np.integer | int | None = None,
        f: NDArray[np.integer] | np.integer | int | None = None,
        overflow_mode: str = 'WRAP',
        round_mode: str = 'TRN',
    ):
        """Materialize the pending operator as per-element lookup tables.

        Either all of ``k``, ``i``, ``f`` are given (broadcast to the array's shape;
        the operator output is quantized with ``_quantize``) or none is (the table
        holds the operator's raw output). Elements with an identical input interval
        and precision share one table within this call.

        Raises:
            AssertionError: if only some of ``k``, ``i``, ``f`` are given.
        """
        if any(x is None for x in (k, i, f)):
            assert all(x is None for x in (k, i, f)), 'Either all or none of k, i, f must be specified'
            _k = _i = _f = [None] * self.size
        else:
            _k = np.broadcast_to(k, self.shape).ravel()  # type: ignore
            _i = np.broadcast_to(i, self.shape).ravel()  # type: ignore
            _f = np.broadcast_to(f, self.shape).ravel()  # type: ignore

        local_tables: dict[tuple[QInterval, tuple[int, int, int]] | QInterval, LookupTable] = {}
        variables = []
        for v, kk, ii, ff in zip(self._vars.ravel(), _k, _i, _f):
            v: FixedVariable
            # For a negative factor the bounds are swapped, so make_table lays the
            # grid out from qint.max down to qint.min.
            qint = v.qint if v._factor >= 0 else QInterval(v.qint.max, v.qint.min, v.qint.step)
            if (kk is None) or (ii is None) or (ff is None):
                op = self._operator
                key = qint
            else:
                # Default arguments bind this element's precision into the closure.
                op = lambda x, kk=kk, ii=ii, ff=ff: _quantize(  # noqa: E731
                    self._operator(x), kk, ii, ff, overflow_mode, round_mode
                )  # type: ignore
                key = (qint, (int(kk), int(ii), int(ff)))

            if key in local_tables:
                table = local_tables[key]
            else:
                table = make_table(op, qint)
                local_tables[key] = table
            variables.append(v.lookup(table))

        variables = np.array(variables).reshape(self._vars.shape)
        return FixedVariableArray(variables, self.solver_options)

    def __repr__(self):
        return 'Retarded' + super().__repr__()

    @property
    def kif(self):
        raise RuntimeError('RetardedFixedVariableArray does not have defined kif until quantized.')

    @property
    def lhs(self):
        raise RuntimeError('RetardedFixedVariableArray does not have defined lhs until quantized.')
|