da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +204 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +246 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.1.post1.dist-info/METADATA +85 -0
- da4ml-0.5.1.post1.dist-info/RECORD +96 -0
- da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
- da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
- da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
|
@@ -0,0 +1,965 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import typing
|
|
3
|
+
from collections.abc import Callable, Generator
|
|
4
|
+
from copy import copy
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from decimal import Decimal
|
|
7
|
+
from hashlib import sha256
|
|
8
|
+
from math import ceil, floor, log2
|
|
9
|
+
from typing import NamedTuple, overload
|
|
10
|
+
from uuid import UUID
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
from numpy.typing import NDArray
|
|
14
|
+
|
|
15
|
+
from ..cmvm.core import cost_add
|
|
16
|
+
from ..cmvm.types import QInterval, _minimal_kif
|
|
17
|
+
from ..cmvm.util.bit_decompose import _shift_centering
|
|
18
|
+
|
|
19
|
+
# Module-private RNG; FixedVariable draws the 128 random bits of each node's UUID4 from it.
rd: random.Random = random.Random()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Hardware model parameters used by the cost/latency estimators:
#   adder_size, carry_size -- forwarded to ``cost_add`` when pricing adder-type operations;
#   latency_cutoff         -- pipelining boundary; values <= 0 disable the boundary adjustment.
HWConfig = NamedTuple(
    'HWConfig',
    [
        ('adder_size', int),
        ('carry_size', int),
        ('latency_cutoff', float),
    ],
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# A callable that maps a floating-point array to a floating-point array.
_float_array_t = NDArray[np.floating]
ufunc_t = Callable[[_float_array_t], _float_array_t]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class TraceContext:
    """Registry of lookup tables, deduplicated by content hash.

    Each distinct table (keyed by ``table.spec.hash``) receives a sequential integer index
    the first time it is registered; registering an identical table again returns the
    already-stored ``(table, index)`` entry.
    """

    # Class-level defaults are kept so existing attribute access keeps working, but every
    # instance gets its own registry and counter in __init__. With only the class-level
    # dict, all instances would share one registry while the counter (rebound per instance
    # on first increment) would not, letting two contexts hand out colliding indices.
    _tables: 'dict[str, tuple[LookupTable, int]]' = {}
    hwconf: HWConfig = HWConfig(1, -1, -1)
    _table_counter = 0

    def __init__(self) -> None:
        self._tables: 'dict[str, tuple[LookupTable, int]]' = {}
        self._table_counter = 0

    def register_table(self, table: 'LookupTable|np.ndarray'):
        """Register ``table`` and return its ``(table, index)`` entry.

        Raw ``np.ndarray`` values are wrapped in ``LookupTable`` first. If a table with the
        same spec hash is already registered, the stored entry is returned unchanged and no
        new index is consumed.
        """
        if isinstance(table, np.ndarray):
            table = LookupTable(table)
        key = table.spec.hash
        if key not in self._tables:
            self._tables[key] = (table, self._table_counter)
            self._table_counter += 1
        return self._tables[key]

    def index_table(self, hash: str) -> int:
        """Index assigned to the table with spec hash ``hash`` (KeyError if unknown)."""
        return self._tables[hash][1]

    def get_table_from_index(self, index: int) -> 'LookupTable':
        """Return the registered table whose index is ``index``.

        Raises
        ------
        KeyError
            If no registered table has that index.
        """
        for table, idx in self._tables.values():
            if idx == index:
                return table
        raise KeyError(f'No table found with index {index}')
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Module-level default TraceContext; tables registered through it share one index sequence.
table_context = TraceContext()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass
class TableSpec:
    """Content-addressed description of a quantized lookup table."""

    hash: str  # hex digest of the table's integer codes and output scaling (built by to_spec)
    out_qint: QInterval  # (min, max, step) of the table's output values
    inp_width: int  # address bits needed to index every entry of the table

    @property
    def out_kif(self) -> tuple[bool, int, int]:
        """(k, i, f) output format, as computed by ``_minimal_kif`` from ``out_qint``."""
        qint = self.out_qint
        return _minimal_kif(qint)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def to_spec(table: NDArray[np.floating]) -> tuple[TableSpec, NDArray[np.int32]]:
    """Quantize a float table to int32 codes and describe it with a TableSpec.

    The number of fractional bits comes from ``_shift_centering``; values are scaled by
    ``2**f_out`` and cast to int32. The spec hash covers the int32 bytes *and* ``f_out``, so
    tables with identical codes but different scaling hash differently.

    Returns
    -------
    tuple[TableSpec, NDArray[np.int32]]
        The spec and the integer code table.
    """
    frac_bits = -_shift_centering(np.array(table))
    codes = (table * 2**frac_bits).astype(np.int32)

    digest = sha256(codes.data)
    digest.update(f'{frac_bits}'.encode())

    address_bits = ceil(log2(table.size))
    value_range = QInterval(float(np.min(table)), float(np.max(table)), float(2**-frac_bits))
    spec = TableSpec(hash=digest.hexdigest(), inp_width=address_bits, out_qint=value_range)
    return spec, codes
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def interpret_as(
|
|
84
|
+
x: int | NDArray[np.integer],
|
|
85
|
+
k: int,
|
|
86
|
+
i: int,
|
|
87
|
+
f: int,
|
|
88
|
+
) -> float | NDArray[np.floating]:
|
|
89
|
+
b = k + i + f
|
|
90
|
+
bias = 2.0 ** (b - 1) * k
|
|
91
|
+
eps = 2.0**-f
|
|
92
|
+
floor_fn = np.floor if isinstance(x, np.ndarray) else floor
|
|
93
|
+
return eps * (floor_fn(x + bias) % 2.0**b - bias)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class LookupTable:
    """A 1-D table of int32 output codes together with its TableSpec.

    Build it from raw float values (quantized via ``to_spec``) or from pre-quantized int32
    codes plus an explicit ``spec`` (as ``from_dict`` does).
    """

    def __init__(self, values: NDArray, spec: TableSpec | None = None):
        assert values.ndim == 1, 'Lookup table values must be 1-dimensional'
        if spec is None:
            self.spec, self.table = to_spec(values)
            return
        assert values.dtype == np.int32, f'{values.dtype}'
        self.spec = spec
        self.table = values

    @overload
    def lookup(self, var: 'FixedVariable', qint_in: QInterval) -> 'FixedVariable': ...

    @overload
    def lookup(self, var: np.floating | float, qint_in: QInterval | tuple[float, float, float]) -> float: ...

    def lookup(self, var, qint_in: QInterval | tuple[float, float, float]):
        """Apply the table to a traced variable or to a concrete number.

        For a FixedVariable the call is delegated to ``var.lookup(self)`` (``qint_in`` is
        not used). For a number, ``qint_in`` = (min, max, step) describes the input grid:
        the entry at ``round((var - min) / step)`` is decoded with the output format.
        """
        if isinstance(var, FixedVariable):
            return var.lookup(self)
        lo, hi, step = qint_in
        assert lo <= var <= hi, f'Value {var} out of range [{lo}, {hi}]'
        slot = round((var - lo) / step)
        return interpret_as(int(self.table[slot]), *self.spec.out_kif)

    @property
    def float_table(self) -> NDArray[np.floating]:
        """All table codes decoded to their fixed-point float values."""
        return interpret_as(self.table, *self.spec.out_kif)  # type: ignore

    def to_dict(self) -> dict:
        """JSON-friendly representation; inverse of ``from_dict``."""
        spec = self.spec
        qint = spec.out_qint
        return {
            'spec': {
                'hash': spec.hash,
                'out_qint': {'min': qint.min, 'max': qint.max, 'step': qint.step},
                'inp_width': spec.inp_width,
            },
            'table': self.table.tolist(),
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'LookupTable':
        """Rebuild a table from the dictionary produced by ``to_dict``."""
        spec_blob = data['spec']
        qint_blob = spec_blob['out_qint']
        spec = TableSpec(
            hash=spec_blob['hash'],
            out_qint=QInterval(qint_blob['min'], qint_blob['max'], qint_blob['step']),
            inp_width=spec_blob['inp_width'],
        )
        codes = np.array(data['table'], dtype=np.int32)
        return cls(codes, spec=spec)

    def _get_pads(self, qint: QInterval) -> tuple[int, int]:
        """Zero entries to add before/after the table to fill the full input address space.

        The left pad counts ``qint.step``-sized slots from the bottom of the (k, i, f) input
        range (-2**i when signed, 0 otherwise) up to ``qint.min``; the right pad fills the
        rest of the ``2**(k + i + f)`` slots.
        """
        k, i, f = _minimal_kif(qint)
        origin = qint.min + 2**i if k else qint.min
        left = round(origin / qint.step)
        right = 2 ** (k + i + f) - len(self.table) - left
        return left, right

    def padded_table(self, qint: QInterval) -> NDArray[np.int32]:
        """Table zero-padded to the full input address space of ``qint``.

        For signed inputs the padded table is rotated by half its length, so that the entry
        for the most negative input (-2**i) lands at index 2**(width-1), i.e. the table is
        indexed by the input's raw two's-complement bit pattern.
        """
        left, right = self._get_pads(qint)
        full = np.pad(self.table, (left, right), mode='constant', constant_values=0)
        if qint.min < 0:
            full = np.roll(full, len(full) // 2)
        return full
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _const_f(const: float | Decimal):
|
|
173
|
+
"""Get the minimum f such that const * 2^f is an integer."""
|
|
174
|
+
const = float(const)
|
|
175
|
+
if const == 0:
|
|
176
|
+
return 0
|
|
177
|
+
_low, _high = -32, 32
|
|
178
|
+
while _high - _low > 1:
|
|
179
|
+
_mid = (_high + _low) // 2
|
|
180
|
+
_value = const * (2.0**_mid)
|
|
181
|
+
if _value == int(_value):
|
|
182
|
+
_high = _mid
|
|
183
|
+
else:
|
|
184
|
+
_low = _mid
|
|
185
|
+
return _high
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def to_csd_powers(x: float) -> Generator[float, None, None]:
    """Convert a float to a list of +/- powers of two in CSD representation.

    Terms are yielded most-significant first, each already rescaled to the units of ``x``;
    zero yields nothing. ``x`` is first scaled by ``2**_const_f(|x|)``; any residue left
    below the last digit (possible when ``_const_f`` caps the fractional bits) is dropped.
    """
    if x == 0:
        return
    shift = _const_f(abs(x))
    scale = 2**-shift
    residue = x * 2**shift
    n_digits = ceil(log2(abs(residue) * 1.5 + 1e-19))
    for n in reversed(range(n_digits)):
        weight = 2**n
        threshold = weight / 1.5
        digit = (residue > threshold) - (residue < -threshold)
        term = weight * digit
        residue -= term
        if term != 0:
            yield term * scale
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class FixedVariable:
|
|
207
|
+
__normal__variable__ = True
|
|
208
|
+
|
|
209
|
+
def __init__(
    self,
    low: float | Decimal,
    high: float | Decimal,
    step: float | Decimal,
    latency: float | None = None,
    hwconf: HWConfig | tuple[int, int, int] = HWConfig(-1, -1, -1),
    opr: str = 'new',
    cost: float | None = None,
    _from: tuple['FixedVariable', ...] = (),
    _factor: float | Decimal = 1.0,
    _data: Decimal | None = None,
    _id: UUID | None = None,
) -> None:
    """Create a traced fixed-point value spanning ``[low, high]`` on a grid of ``step``.

    Parameters
    ----------
    low, high, step : float | Decimal
        Range and resolution, stored as ``Decimal``. ``low == high`` makes the variable a
        constant: ``opr`` becomes ``'const'``, parents are dropped and the step is derived
        from the value via ``_const_f``.
    latency, cost : float | None
        Used verbatim only when both are given; otherwise both are recomputed by
        ``get_cost_and_latency``.
    hwconf : HWConfig | tuple[int, int, int]
        Hardware model parameters (normalized to ``HWConfig``).
    opr : str
        Operation that produced this value.
    _from : tuple[FixedVariable, ...]
        Parent variables.
    _factor : float | Decimal
        Scale factor of this node (stored as ``Decimal``).
    _data : Decimal | None
        Operation payload; mandatory for ``'cadd'``.
    _id : UUID | None
        Node identity; a random UUID4 is drawn from ``rd`` when omitted.

    Raises
    ------
    ValueError
        If ``opr == 'const'`` but ``low != high``.
    """
    # NOTE(review): assumes only this ordering check depends on __normal__variable__
    # (letting special subclasses skip it) -- confirm against the upstream source.
    if self.__normal__variable__:
        assert low <= high, f'low {low} must be less than high {high}'

    if low != high and opr == 'const':
        raise ValueError('Constant variable must have low == high')
    if low == high:
        opr, _from = 'const', ()
        step = 2.0 ** -_const_f(low)

    self.low, self.high, self.step = Decimal(low), Decimal(high), Decimal(step)
    self._factor = Decimal(_factor)
    self._from: tuple[FixedVariable, ...] = _from
    self.opr = opr
    self._data = _data
    self.id = _id or UUID(int=rd.getrandbits(128), version=4)
    self.hwconf = HWConfig(*hwconf)

    if opr == 'cadd':
        assert _data is not None, 'cadd must have data'

    if cost is not None and latency is not None:
        _cost, _latency = cost, latency
    else:
        _cost, _latency = self.get_cost_and_latency()
    self.latency = _latency
    self.cost = _cost

    # Constant parents are replaced by copies stamped with this node's latency.
    self._from = tuple(
        parent._with(latency=self.latency) if parent.opr == 'const' else parent for parent in self._from
    )
|
|
258
|
+
|
|
259
|
+
def _with(self, renew_id=True, **kwargs):
    """Return a shallow copy with ``kwargs`` applied as attribute overrides.

    With no overrides, ``self`` itself is returned (no copy, id untouched). Otherwise the
    copy receives a fresh random UUID4 unless ``renew_id`` is False.
    """
    if not kwargs:
        return self
    clone = copy(self)
    for attr_name, attr_value in kwargs.items():
        setattr(clone, attr_name, attr_value)
    if renew_id:
        clone.id = UUID(int=rd.getrandbits(128), version=4)
    return clone
|
|
268
|
+
|
|
269
|
+
def get_cost_and_latency(self) -> tuple[float, float]:
    """Estimate ``(cost, latency)`` of this node from its operation and its parents.

    * ``'const'`` and ``'new'``: ``(0.0, 0.0)``.
    * ``'lookup'``: LUT model driven by the input and output bit widths.
    * ``'vadd'``, ``'min'``, ``'max'``: one ``cost_add`` over the two operands.
      ``'cadd'``: width-based cost of the constant, no added latency.
      ``'vmul'``: built from ``cost_add`` of each operand with itself, scaled by the other
      operand's width. When ``hwconf.latency_cutoff > 0`` and the operation would cross a
      multiple of the cutoff, it is restarted at that boundary.
    * ``'relu'``, ``'wrap'``: the parent's latency; half a unit per output bit when the parent
      factor is negative, plus another half per bit for ``'relu'``.

    Raises
    ------
    NotImplementedError
        For any other ``opr``.
    """
    opr = self.opr

    if opr == 'const':
        return 0.0, 0.0

    if opr == 'lookup':
        assert len(self._from) == 1
        src = self._from[0]
        bits_in = sum(src.kif)
        bits_out = sum(self.kif)
        latency = max(bits_in - 6, 1) + src.latency
        # Assume LUT6 with extra o5 output
        cost = 2 ** max(bits_in - 5, 0) * ceil(bits_out / 2)
        if bits_in < 5:
            cost *= bits_in / 5
        return cost, latency

    if opr in ('relu', 'wrap'):
        assert len(self._from) == 1
        src = self._from[0]
        cost = 0.0
        # Assume LUT5 used here (2 fan-out per LUT6, thus *1/2)
        if src._factor < 0:
            cost += sum(self.kif) / 2
        if opr == 'relu':
            cost += sum(self.kif) / 2
        return cost, src.latency

    if opr == 'new':
        # new variable, no cost
        return 0.0, 0.0

    if opr not in ('vadd', 'cadd', 'min', 'max', 'vmul'):
        raise NotImplementedError(f'Operation {self.opr} is unknown')

    adder_size = self.hwconf.adder_size
    carry_size = self.hwconf.carry_size
    latency_cutoff = self.hwconf.latency_cutoff

    if opr == 'cadd':
        assert len(self._from) == 1
        assert self._data is not None, 'cadd must have data'
        f_const = _const_f(self._data)
        cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -f_const))) + f_const
        base_latency = self._from[0].latency
        dlat = 0.0
    elif opr == 'vmul':
        assert len(self._from) == 2
        lhs, rhs = self._from
        width_l, width_r = sum(lhs.kif), sum(rhs.kif)
        q_l, q_r = lhs.qint, rhs.qint
        dlat_l, cost_l = cost_add(q_l, q_l, 0, False, adder_size, carry_size)
        dlat_r, cost_r = cost_add(q_r, q_r, 0, False, adder_size, carry_size)
        dlat = max(dlat_l * width_r, dlat_r * width_l)
        cost = min(cost_l * width_r, cost_r * width_l)
        base_latency = max(lhs.latency, rhs.latency)
    else:  # 'vadd', 'min', 'max'
        assert len(self._from) == 2
        lhs, rhs = self._from
        q_l, q_r = lhs.qint, rhs.qint
        base_latency = max(lhs.latency, rhs.latency)
        dlat, cost = cost_add(q_l, q_r, 0, False, adder_size, carry_size)

    latency = dlat + base_latency
    if latency_cutoff > 0 and ceil(latency / latency_cutoff) > ceil(base_latency / latency_cutoff):
        # Crossed the latency cutoff boundary: start the operation at that boundary instead.
        assert dlat <= latency_cutoff, (
            f'Latency of an atomic operation {dlat} is larger than the pipelining latency cutoff {latency_cutoff}'
        )
        latency = ceil(base_latency / latency_cutoff) * latency_cutoff + dlat
    return cost, latency
|
|
340
|
+
|
|
341
|
+
@property
def unscaled(self):
    """This variable multiplied by ``1 / _factor``, removing its accumulated scale factor."""
    inverse = 1 / self._factor
    return self * inverse
|
|
344
|
+
|
|
345
|
+
@property
def qint(self) -> QInterval:
    """``(low, high, step)`` of this variable as a QInterval of Python floats."""
    lo, hi, st = (float(bound) for bound in (self.low, self.high, self.step))
    return QInterval(lo, hi, st)
|
|
348
|
+
|
|
349
|
+
@property
def kif(self) -> tuple[bool, int, int]:
    """(signed, integer bits, fractional bits) needed for ``[low, high]`` at ``step``.

    A zero step reports ``(False, 0, 0)``.
    """
    step = self.step
    if step == 0:
        return False, 0, 0
    frac_bits = -int(log2(step))
    int_bits = ceil(log2(max(-self.low, self.high + step)))
    return self.low < 0, int_bits, frac_bits
|
|
357
|
+
|
|
358
|
+
@classmethod
def from_const(cls, const: float | Decimal, hwconf: HWConfig, _factor: float | Decimal = 1):
    """Build a constant variable equal to ``const``.

    The step of -1 is only a placeholder: with ``low == high`` the constructor marks the
    variable ``'const'`` and derives the real step from the value.
    """
    placeholder_step = -1
    return cls(const, const, placeholder_step, hwconf=hwconf, opr='const', _factor=_factor)
|
|
361
|
+
|
|
362
|
+
def __repr__(self) -> str:
    body = f'FixedVariable({self.low}, {self.high}, {self.step})'
    # The scale factor is only shown when it is not the identity.
    return body if self._factor == 1 else f'({self._factor}) {body}'
|
|
366
|
+
|
|
367
|
+
def __neg__(self):
    """Return ``-self``.

    The bounds are swapped and negated and ``_factor`` flips sign; the id, parents, data,
    cost and latency are carried over unchanged (``opr`` becomes ``'const'`` for a
    degenerate interval).
    """
    flipped_opr = 'const' if self.low == self.high else self.opr
    return FixedVariable(
        -self.high,
        -self.low,
        self.step,
        latency=self.latency,
        hwconf=self.hwconf,
        opr=flipped_opr,
        cost=self.cost,
        _from=self._from,
        _factor=-self._factor,
        _data=self._data,
        _id=self.id,
    )
|
|
382
|
+
|
|
383
|
+
def __add__(self, other: 'FixedVariable|float|Decimal|int'):
    """Add a variable or a numeric constant.

    Numbers and degenerate (constant) variables go through ``_const_add``. Two true
    variables must share ``hwconf``; a negative left factor is normalized first (by
    swapping operands, or by adding the negations and negating the sum) before a
    ``'vadd'`` node is created.
    """
    if not isinstance(other, FixedVariable):
        return self._const_add(other)
    if other.high == other.low:
        return self._const_add(other.low)
    if self.high == self.low:
        return other._const_add(self.low)

    assert self.hwconf == other.hwconf, f'FixedVariable must have the same hwconf, got {self.hwconf} and {other.hwconf}'

    f_self, f_other = self._factor, other._factor
    if f_self < 0:
        return other + self if f_other > 0 else -((-self) + (-other))

    return FixedVariable(
        self.low + other.low,
        self.high + other.high,
        min(self.step, other.step),
        _from=(self, other),
        _factor=f_self,
        opr='vadd',
        hwconf=self.hwconf,
    )
|
|
409
|
+
|
|
410
|
+
def _const_add(self, other: float | Decimal | None) -> 'FixedVariable':
    """Add a numeric constant; ``None`` and zero return ``self`` unchanged.

    Chained constant additions are folded: when ``self`` is already a ``'cadd'`` node, its
    constant and ``other`` are merged and applied once to the parent, and the result is
    rescaled by ``self``'s factor relative to the parent's.
    """
    if other is None:
        return self
    if not isinstance(other, (int, float, Decimal)):
        other = float(other)  # direct numpy to decimal raises error
    other = Decimal(other)
    if other == 0:
        return self

    if self.opr == 'cadd':
        # cadd, combine the constant
        assert len(self._from) == 1
        parent = self._from[0]
        assert self._data is not None, 'cadd must have data'
        rel_scale = self._factor / parent._factor
        merged = (self._data * parent._factor) + other / rel_scale
        return (parent + merged) * rel_scale

    const_step = Decimal(2.0 ** -_const_f(other))
    return FixedVariable(
        self.low + other,
        self.high + other,
        min(self.step, const_step),
        _from=(self,),
        _factor=self._factor,
        _data=other / self._factor,
        opr='cadd',
        hwconf=self.hwconf,
    )
|
|
440
|
+
|
|
441
|
+
def __sub__(self, other: 'FixedVariable|int|float|Decimal'):
    """``self - other``, computed as ``self + (-other)``."""
    negated = -other
    return self + negated
|
|
443
|
+
|
|
444
|
+
def __truediv__(self, other: 'int|float|Decimal'):
    """Divide by a numeric constant (multiply by its reciprocal); variable divisors are rejected."""
    assert not isinstance(other, FixedVariable), 'Division by variable is not supported'
    reciprocal = 1 / other
    return self * reciprocal
|
|
447
|
+
|
|
448
|
+
def __mul__(self, other: 'FixedVariable|int|float|Decimal') -> 'FixedVariable':
    """Multiply by another variable or by a constant.

    * variable x variable -> ``_var_mul``;
    * zero -> constant 0; a (signed) power of two -> ``_pow2_mul`` (pure rescale);
    * any other constant -> shift-and-add over its CSD terms. Terms are combined from the
      least-significant end, and each partial sum's bounds are reset to the exact range
      of ``self`` times the accumulated coefficient.
    """
    if isinstance(other, FixedVariable):
        if self.high == self.low:
            return other * self.low
        if other.high > other.low:
            return self._var_mul(other)
        assert other.high == other.low
        other = float(other.low)

    if other == 0:
        return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')

    if log2(abs(other)) % 1 == 0:
        return self._pow2_mul(other)

    stack = [(self._pow2_mul(term), Decimal(term)) for term in to_csd_powers(float(other))]
    while len(stack) > 1:
        var_a, coef_a = stack.pop()
        var_b, coef_b = stack.pop()
        partial, coef = var_a + var_b, coef_a + coef_b
        # Endpoints of self * coef swap when the coefficient is negative.
        lo_end, hi_end = (self.low, self.high) if coef > 0 else (self.high, self.low)
        partial.high, partial.low = hi_end * coef, lo_end * coef
        stack.append((partial, coef))
    return stack[0][0]
|
|
475
|
+
|
|
476
|
+
def _var_mul(self, other: 'FixedVariable') -> 'FixedVariable':
    """Product of two variables as a ``'vmul'`` node, with interval-arithmetic bounds.

    When squaring (``other is self``) the tighter square rule applies: if the interval
    straddles zero, the lower bound is 0.
    """
    if other is self:
        sq_lo, sq_hi = self.low * other.low, self.high * other.high
        if self.low < 0 and self.high > 0:
            low, high = min(sq_lo, sq_hi, 0), max(sq_lo, sq_hi, 0)
        else:
            low, high = min(sq_lo, sq_hi), max(sq_lo, sq_hi)
    else:
        corners = (
            self.high * other.low,
            self.low * other.high,
            self.high * other.high,
            self.low * other.low,
        )
        low, high = min(corners), max(corners)

    return FixedVariable(
        low,
        high,
        self.step * other.step,
        _from=(self, other),
        hwconf=self.hwconf,
        _factor=self._factor * other._factor,
        opr='vmul',
    )
|
|
502
|
+
|
|
503
|
+
def _pow2_mul(
    self,
    other: float | Decimal,
):
    """Rescale by ``other`` (callers in this file pass signed powers of two).

    The result keeps this node's id, parents, opr, data, cost and latency; only the
    interval, the (absolute) step and ``_factor`` are scaled.
    """
    scale = Decimal(other)
    ends = (self.low * scale, self.high * scale)
    return FixedVariable(
        min(ends),
        max(ends),
        abs(self.step * scale),
        _from=self._from,
        _factor=self._factor * scale,
        opr=self.opr,
        latency=self.latency,
        cost=self.cost,
        _id=self.id,
        _data=self._data,
        hwconf=self.hwconf,
    )
|
|
527
|
+
|
|
528
|
+
def __lshift__(self, other: int):
    """``self << n`` equals ``self * 2.0**n``; ``n`` must be an int (negative shifts right)."""
    assert isinstance(other, int), 'Shift amount must be an integer'
    return self * 2.0**other
|
|
532
|
+
|
|
533
|
+
def __rshift__(self, other: int):
    """``self >> n`` equals ``self * 2.0**-n``; ``n`` must be an int (negative shifts left)."""
    assert isinstance(other, int), 'Shift amount must be an integer'
    return self * 2.0**-other
|
|
537
|
+
|
|
538
|
+
def __radd__(self, other: 'float|Decimal|int|FixedVariable'):
    """``other + self``: addition commutes, so this defers to ``self + other``."""
    result = self + other
    return result
|
|
540
|
+
|
|
541
|
+
def __rsub__(self, other: 'float|Decimal|int|FixedVariable'):
    """``other - self``, computed as ``(-self) + other``."""
    negated = -self
    return negated + other
|
|
543
|
+
|
|
544
|
+
def __rmul__(self, other: 'float|Decimal|int|FixedVariable'):
    """``other * self``: multiplication commutes, so this defers to ``self * other``."""
    result = self * other
    return result
|
|
546
|
+
|
|
547
|
+
def __pow__(self, other):
    """Raise to a non-negative integer power.

    ``x**0`` is the constant 1 and ``x**1`` is ``x`` itself. Even powers build ``x**(n/2)``
    once and square that single node. This avoids tracing two identical sub-products, and
    ``_var_mul``'s squaring rule (``other is self``) yields the same bounds the separate
    product plus clamp did. Odd powers multiply ``x**(n//2)`` by ``x**(n - n//2)``. For
    even powers the lower bound is additionally clamped at 0.

    Raises
    ------
    AssertionError
        If ``other`` is not a non-negative integer.
    """
    _power = int(other)
    assert _power == other, 'Power must be an integer'
    assert _power >= 0, 'Power must be non-negative'
    if _power == 0:
        return FixedVariable(1, 1, 1, hwconf=self.hwconf, opr='const')
    if _power == 1:
        return self

    pow0 = _power // 2
    if 2 * pow0 == _power:
        half = self**pow0
        ret = half * half
    else:
        ret = (self**pow0) * (self ** (_power - pow0))
    if _power % 2 == 0:
        # Clamp with a Decimal so the bound keeps the same type __init__ assigns.
        ret.low = max(ret.low, Decimal(0))
    return ret
|
|
561
|
+
|
|
562
|
+
def relu(self, i: int | None = None, f: int | None = None, round_mode: str = 'TRN'):
    """ReLU, optionally followed by quantization to ``i`` integer / ``f`` fractional bits.

    Parameters
    ----------
    i : int | None
        Integer bits of the (unsigned) result. Values that do not fit wrap around; the
        bounds widen to ``[0, 2**i - step]``. ``None`` keeps the input's range.
    f : int | None
        Fractional bits of the result. ``None`` keeps the input's step; for non-constant
        inputs the step never becomes finer than the input's.
    round_mode : str
        'TRN' (truncate) or 'RND' (round half up) when fractional bits are dropped.
    """
    round_mode = round_mode.upper()
    assert round_mode in ('TRN', 'RND')

    if self.opr == 'const':
        val = self.low * (self.low > 0)
        # Explicit None checks: i=0 / f=0 are valid widths (honoured by the non-constant
        # path below) and must not fall back to the constant's own minimal format.
        if f is None:
            f = _const_f(val)
        step = Decimal(2) ** -f
        if i is None:
            i = ceil(log2(val + step))
        # RND adds half an LSB *in value units* before truncating onto the step grid,
        # matching the non-constant path's `self + step / 2`.
        eps = step / 2 if round_mode == 'RND' else 0
        val = (floor((val + eps) / step) * step) % (Decimal(2) ** i)
        return self.from_const(val, hwconf=self.hwconf)

    step = max(Decimal(2) ** -f, self.step) if f is not None else self.step
    if step > self.step and round_mode == 'RND':
        return (self + step / 2).relu(i, f, 'TRN')
    low = max(Decimal(0), self.low)
    high = max(Decimal(0), self.high)
    if i is not None:
        _high = Decimal(2) ** i - step
        if _high < high:
            # overflows: the wrapped output may land anywhere in [0, 2**i - step]
            low = Decimal(0)
            high = _high
    _factor = self._factor

    if self.low == low and self.high == high and self.step == step:
        return self

    return FixedVariable(
        low,
        high,
        step,
        _from=(self,),
        _factor=abs(_factor),
        opr='relu',
        hwconf=self.hwconf,
        # NOTE(review): latency is not passed, so __init__ recomputes both cost and latency
        # via get_cost_and_latency and this value is discarded.
        cost=sum(self.kif) * (1 if _factor > 0 else 2),
    )
|
|
601
|
+
|
|
602
|
+
def quantize(
    self,
    k: int | bool,
    i: int,
    f: int,
    overflow_mode: str = 'WRAP',
    round_mode: str = 'TRN',
) -> 'FixedVariable':
    """Quantize the variable to the specified fixed-point format.

    Parameters
    ----------
    k : int | bool
        Sign bit (True for signed, False for unsigned)
    i : int
        Integer bits, excluding sign bit
    f : int
        Fraction bits
    overflow_mode : str, optional
        Overflow mode, one of 'WRAP', 'SAT', 'SAT_SYM', by default 'WRAP'
    round_mode : str, optional
        Rounding mode, one of 'TRN' (truncate), 'RND' (round to nearest, half up), by default 'TRN'

    Returns
    -------
    FixedVariable
        ``self`` when the target format already covers it, otherwise a constant or a new
        ``'wrap'`` node.
    """

    overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
    assert overflow_mode in ('WRAP', 'SAT', 'SAT_SYM')
    assert round_mode in ('TRN', 'RND')

    if k + i + f <= 0:
        return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
    _k, _i, _f = self.kif

    if k >= _k and i >= _i and f >= _f:
        if overflow_mode != 'SAT_SYM' or i > _i:
            return self

    if f < _f and round_mode == 'RND':
        return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')

    if overflow_mode in ('SAT', 'SAT_SYM'):
        step = Decimal(2) ** -f
        _high = Decimal(2) ** i
        high = _high - step
        low = -_high * k if overflow_mode == 'SAT' else -high * k
        ff = f + 1 if round_mode == 'RND' else f
        v = self.quantize(_k, _i, ff, 'WRAP', 'TRN') if _k + _i + ff > 0 else self
        return v.max_of(low).min_of(high).quantize(k, i, f, 'WRAP', round_mode)

    if self.low == self.high:
        # Wrap the truncated constant on integer codes of size `step`. The target format
        # holds 2**(i+f) codes when unsigned and 2**(i+f+1) when signed (starting at
        # -2**(i+f)). Python int `%` floors, so negative constants wrap into range;
        # Decimal `%` would keep the dividend's sign. Here i + f >= 0 because k + i + f > 0.
        step = Decimal(2) ** -f
        n_codes = 2 ** (i + f)
        lo_code = -n_codes if k else 0
        span = 2 * n_codes if k else n_codes
        code = floor(self.low / step)
        val = ((code - lo_code) % span + lo_code) * step
        return FixedVariable.from_const(val, hwconf=self.hwconf, _factor=1)

    f = min(f, _f)
    k = min(k, _k) if i >= _i else k

    step = Decimal(2) ** -f

    if self.low < 0:
        _low = floor(self.low / step) * step
        _i = max(_i, ceil(log2(-_low)))

    i = min(i, _i + (k == 0 and _k == 1))

    if i + k + f <= 0:
        return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')

    low = -k * Decimal(2) ** i

    high = Decimal(2) ** i - step
    _low, _high = self.low, self.high

    if _low >= low and _high <= high:
        # No overflow possible: keep the input's own range, truncated onto the new grid.
        low = floor(_low / step) * step
        high = floor(_high / step) * step

    return FixedVariable(
        low,
        high,
        step,
        _from=(self,),
        _factor=abs(self._factor),
        opr='wrap',
        latency=self.latency,
        hwconf=self.hwconf,
    )
|
|
691
|
+
|
|
692
|
+
@classmethod
|
|
693
|
+
def from_kif(cls, k: int | bool, i: int, f: int, **kwargs):
|
|
694
|
+
step = Decimal(2) ** -f
|
|
695
|
+
_high = Decimal(2) ** i
|
|
696
|
+
low, high = -k * _high, _high - step
|
|
697
|
+
return cls(low, high, step, **kwargs)
|
|
698
|
+
|
|
699
|
+
def msb_mux(
    self,
    a: 'FixedVariable|float|Decimal',
    b: 'FixedVariable|float|Decimal',
    qint: tuple[Decimal, Decimal, Decimal] | None = None,
):
    """Select between ``a`` and ``b`` using the most significant bit of this variable.

    Returns ``a`` when the MSB of ``self`` is 1 and ``b`` otherwise. For a signed
    variable the MSB is the sign bit: ``a`` is chosen for values < 0 and ``b``
    for values >= 0.

    Parameters
    ----------
    a, b : FixedVariable | float | Decimal
        The two candidates; plain numbers are wrapped via ``FixedVariable.from_const``.
    qint : tuple[Decimal, Decimal, Decimal] | None
        ``(low, high, step)`` of the result. When omitted, the hull of ``a`` and
        ``b`` is used (smallest low, largest high, finest step).

    Returns
    -------
    FixedVariable
        ``a`` or ``b`` directly when the selector is a constant, otherwise an
        ``'msb_mux'`` node (possibly negated, see the factor normalisation below).
    """
    if not isinstance(a, FixedVariable):
        a = FixedVariable.from_const(a, hwconf=self.hwconf, _factor=1)
    if not isinstance(b, FixedVariable):
        b = FixedVariable.from_const(b, hwconf=self.hwconf, _factor=1)

    # A negative scale factor on the selector swaps the roles of the two branches.
    if self._factor < 0:
        return (-self).msb_mux(b, a, qint)

    opr = self.opr
    if opr == 'const':
        # Constant selector: resolve the choice while tracing.
        if self.low >= 0:
            return b
        # NOTE(review): a negative constant whose magnitude is an exact power of two
        # selects ``b`` here although its sign bit is 1 (e.g. __abs__ on a constant -4
        # with positive factor would yield -4). Confirm this is intended.
        return b if log2(abs(self.low)) % 1 == 0 else a
    if opr == 'quantize':
        # If quantization kept k + i unchanged, the MSB sits at the same position as
        # in the source, so mux on the source directly.
        k, i, _ = self.kif
        src = self._from[0]
        pk, pi, _ = src.kif
        if k + i == pk + pi:
            return src.msb_mux(a, b, qint=qint)

    # Normalise so that ``a`` carries a non-negative factor; mirror the range to match.
    if a._factor < 0:
        flipped = (-qint[1], -qint[0], qint[2]) if qint else None
        return -(self.msb_mux(-a, -b, qint=flipped))

    if qint is None:
        qint = (min(a.low, b.low), max(a.high, b.high), min(a.step, b.step))

    # Latency/cost of the mux are estimated with the adder model over both inputs.
    dlat, dcost = cost_add(a.qint, b.qint, 0, False, self.hwconf.adder_size, self.hwconf.carry_size)
    return FixedVariable(
        *qint,
        _from=(self, a, b),
        _factor=a._factor,
        opr='msb_mux',
        latency=max(a.latency, b.latency, self.latency) + dlat,
        hwconf=self.hwconf,
        cost=dcost,
    )
|
|
745
|
+
|
|
746
|
+
def is_negative(self) -> 'FixedVariable|bool':
    """Return whether this variable is negative.

    Gives a plain ``bool`` when the range ``[low, high]`` already decides the
    answer. Otherwise the sign bit is extracted and returned as a variable that
    is 1 for negative values and 0 otherwise.
    """
    if self.low >= 0:
        return False
    if self.high < 0:
        return True
    _, int_bits, _ = self.kif
    # Requantize to an unsigned format with (i + 1) integer bits and step 2**i,
    # keeping only the bit of weight 2**i (the sign bit of the signed format).
    # This presumably relies on quantize's default wrap-around overflow.
    top_bit = self.quantize(0, int_bits + 1, -int_bits)
    # Move that bit down to position 0.
    return top_bit >> int_bits
|
|
754
|
+
|
|
755
|
+
def is_positive(self) -> 'FixedVariable|bool':
    """Return whether this variable is strictly positive.

    Evaluated as ``(-self).is_negative()``. The result is a ``bool`` when the
    range decides it, otherwise a 1-bit variable.
    """
    negated = -self
    return negated.is_negative()
|
|
757
|
+
|
|
758
|
+
def __abs__(self):
    """Return the absolute value of this variable.

    Ranges that are entirely non-negative or entirely non-positive are resolved
    while tracing (``self`` or ``-self``). Only a range straddling zero needs a
    sign-controlled multiplexer.
    """
    if self.low >= 0:
        return self
    if self.high <= 0:
        # Never positive: |x| == -x. Skipping the mux saves hardware and keeps
        # the exact range [-high, -low]. It also keeps negative constants out of
        # msb_mux's constant shortcut, which returns ``self`` for -2**n.
        return -self
    # Mixed sign: pick -self when the sign bit is set, else self.
    high = max(-self.low, self.high)
    return self.msb_mux(-self, self, (Decimal(0), high, self.step))
|
|
764
|
+
|
|
765
|
+
def abs(self):
    """Get the absolute value of this variable.

    Method-call spelling of the built-in ``abs()``; dispatches to the type's
    ``__abs__`` exactly as the built-in does.
    """
    return type(self).__abs__(self)
|
|
768
|
+
|
|
769
|
+
def __gt__(self, other: 'FixedVariable|float|Decimal|int'):
    """Get a variable that is 1 if this variable is greater than other, else 0.

    Computed as the strict positivity of ``self - other``. May be a plain
    ``bool`` when the operand ranges already decide the comparison.
    """
    difference = self - other
    return difference.is_positive()
|
|
772
|
+
|
|
773
|
+
def __lt__(self, other: 'FixedVariable|float|Decimal|int'):
    """Get a variable that is 1 if this variable is less than other, else 0.

    Computed as the strict positivity of ``other - self``; a plain number on the
    left is handled by this class's reflected subtraction. May be a ``bool``
    when the ranges already decide the comparison.
    """
    difference = other - self
    return difference.is_positive()
|
|
776
|
+
|
|
777
|
+
# def __ge__(self, other: 'FixedVariable|float|Decimal|int'):
|
|
778
|
+
# """Get a variable that is 1 if this variable is greater than or equal to other, else 0."""
|
|
779
|
+
# r = (other - self).is_negative()
|
|
780
|
+
# if isinstance(r, bool):
|
|
781
|
+
# return not r
|
|
782
|
+
# return ~r
|
|
783
|
+
|
|
784
|
+
# def __le__(self, other: 'FixedVariable|float|Decimal|int'):
|
|
785
|
+
# """Get a variable that is 1 if this variable is less than or equal to other, else 0."""
|
|
786
|
+
# r = (self - other).is_negative()
|
|
787
|
+
# if isinstance(r, bool):
|
|
788
|
+
# return not r
|
|
789
|
+
# return ~r
|
|
790
|
+
|
|
791
|
+
def max_of(self, other):
    """Get the maximum of this variable and another variable or constant.

    ``-inf`` is the identity and returns ``self``; ``+inf`` is rejected.
    Non-overlapping ranges are resolved while tracing. Comparison against an
    exact zero becomes ``relu``. Anything else is a mux driven by the sign of
    ``self - other``.

    Raises
    ------
    ValueError
        If ``other`` is ``+inf``.
    """
    inf = float('inf')
    if other == -inf:
        return self
    if other == inf:
        raise ValueError('Cannot apply max_of with inf')
    if not isinstance(other, FixedVariable):
        # NOTE(review): min_of wraps constants with the signed self._factor while this
        # uses abs(self._factor); confirm which is intended.
        other = FixedVariable.from_const(other, hwconf=self.hwconf, _factor=abs(self._factor))

    # Ranges that do not overlap (or only touch) decide the result statically.
    if self.low >= other.high:
        return self
    if self.high <= other.low:
        return other
    if other.high == other.low == 0:
        return self.relu()

    bounds = (
        max(self.low, other.low),
        max(self.high, other.high),
        min(self.step, other.step),
    )
    # self - other < 0 -> other is larger; otherwise keep self.
    return (self - other).msb_mux(other, self, qint=bounds)
|
|
809
|
+
|
|
810
|
+
def min_of(self, other):
    """Get the minimum of this variable and another variable or constant.

    ``+inf`` is the identity and returns ``self``; ``-inf`` is rejected.
    Non-overlapping ranges are resolved while tracing. Comparison against an
    exact zero uses ``min(x, 0) == -relu(-x)``. Anything else is a mux driven by
    the sign of ``self - other``.

    Raises
    ------
    ValueError
        If ``other`` is ``-inf``.
    """
    inf = float('inf')
    if other == inf:
        return self
    if other == -inf:
        raise ValueError('Cannot apply min_of with -inf')
    if not isinstance(other, FixedVariable):
        # NOTE(review): max_of wraps constants with abs(self._factor) while this uses
        # the signed factor; confirm which is intended.
        other = FixedVariable.from_const(other, hwconf=self.hwconf, _factor=self._factor)

    # Ranges that do not overlap (or only touch) decide the result statically.
    if self.high <= other.low:
        return self
    if self.low >= other.high:
        return other
    if other.high == other.low == 0:
        return -(-self).relu()

    bounds = (
        min(self.low, other.low),
        min(self.high, other.high),
        min(self.step, other.step),
    )
    # self - other < 0 -> self is smaller; otherwise take other.
    return (self - other).msb_mux(self, other, qint=bounds)
|
|
829
|
+
|
|
830
|
+
def lookup(self, table: LookupTable | np.ndarray) -> 'FixedVariable':
    """Use a lookup table to map the variable.

    When the table is a numpy array, the table starts at the lowest possible value of the variable.
    When the table is in LookupTable format, the table starts at the normalized lowest value of the
    variable (i.e., if the variable has negative _factor, the table is reversed).

    Parameters
    ----------
    table : LookupTable | np.ndarray
        Lookup table to use. It must have exactly one entry per representable
        input value, i.e. ``round((high - low) / step) + 1`` entries.

    Returns
    -------
    FixedVariable
        A constant when ``table`` is a single-entry array, otherwise a new
        ``'lookup'`` node whose range is the registered table's output range.

    Raises
    ------
    AssertionError
        If the table size does not match the input range. The check runs before
        the table is registered, so a rejected table never enters ``table_context``.
    """
    if isinstance(table, np.ndarray):
        if len(table) == 1:
            return self.from_const(float(table[0]), hwconf=self.hwconf)
        if self._factor < 0:
            table = table[::-1]  # Reverse the table for negative factor
        size = len(table)
    elif isinstance(table, LookupTable):
        size = len(table.table)
    else:
        size = len(table)

    n_inputs = round((self.high - self.low) / self.step) + 1
    if n_inputs != size:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # AssertionError is kept for compatibility with existing callers.
        raise AssertionError(f'Input variable size does not match lookup table size ({n_inputs} != {size})')

    # Register only after validation to avoid leaving unusable tables behind.
    _table, table_id = table_context.register_table(table)

    return FixedVariable(
        _table.spec.out_qint.min,
        _table.spec.out_qint.max,
        _table.spec.out_qint.step,
        _from=(self,),
        _factor=Decimal(1),
        opr='lookup',
        hwconf=self.hwconf,
        _data=Decimal(table_id),
    )
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
class FixedVariableInput(FixedVariable):
|
|
869
|
+
__normal__variable__ = False
|
|
870
|
+
|
|
871
|
+
def __init__(
    self,
    latency: float | None = None,
    hwconf: HWConfig | tuple[int, int, int] = HWConfig(-1, -1, -1),
    opr: str = 'new',
) -> None:
    """Create an input placeholder whose value range is not known yet.

    The range starts empty (``low = 1e10 > high = -1e10``) with a huge step.
    The input's ``quantize`` method widens it to cover each requested format.

    Parameters
    ----------
    latency : float | None
        Arrival latency of the input; ``None`` is treated as 0.0.
    hwconf : HWConfig | tuple[int, int, int]
        Hardware configuration; always normalised through ``HWConfig(*hwconf)``.
    opr : str
        Operation tag recorded on the node.
    """
    arrival = 0.0 if latency is None else latency
    super().__init__(
        low=Decimal(1e10),
        high=Decimal(-1e10),
        step=Decimal(1e10),
        latency=arrival,
        hwconf=HWConfig(*hwconf),
        opr=opr,
        cost=0.0,
        _factor=Decimal(1),
        _from=(),
        _data=None,
        _id=None,
    )
|
|
890
|
+
|
|
891
|
+
def __add__(self, other):
    """Allow only ``x + 0``, which returns ``x`` itself.

    Raises
    ------
    ValueError
        For any other operand: the input has not been quantized yet.
    """
    if other == 0:
        return self  # adding zero needs no range information
    raise ValueError('Cannot operate on unquantized input variable')
|
|
895
|
+
|
|
896
|
+
def __sub__(self, other):
    """Allow only ``x - 0``, which returns ``x`` itself.

    Raises
    ------
    ValueError
        For any other operand: the input has not been quantized yet.
    """
    if other == 0:
        return self  # subtracting zero needs no range information
    raise ValueError('Cannot operate on unquantized input variable')
|
|
900
|
+
|
|
901
|
+
def __neg__(self):
    """Reject negation: an unquantized input cannot be negated."""
    raise ValueError('Cannot negate unquantized input variable')
|
|
903
|
+
|
|
904
|
+
def __mul__(self, other):
    """Allow only ``x * 1``, which returns ``x`` itself.

    Raises
    ------
    ValueError
        For any other factor: the input has not been quantized yet.
    """
    if other == 1:
        return self  # multiplying by one is the identity
    raise ValueError('Cannot multiply unquantized input variable')
|
|
908
|
+
|
|
909
|
+
def __rmul__(self, other):
    """Allow only ``1 * x``, which returns ``x`` itself.

    Raises
    ------
    ValueError
        For any other factor: the input has not been quantized yet.
    """
    if other == 1:
        return self  # multiplying by one is the identity
    raise ValueError('Cannot multiply unquantized input variable')
|
|
913
|
+
|
|
914
|
+
def __radd__(self, other):
    """Allow only ``0 + x``, which returns ``x`` itself.

    Raises
    ------
    ValueError
        For any other operand: the input has not been quantized yet.
    """
    if other == 0:
        return self  # adding zero needs no range information
    raise ValueError('Cannot add unquantized input variable')
|
|
918
|
+
|
|
919
|
+
def __rsub__(self, other):
    """Reject ``other - x`` unconditionally (there is no shortcut for zero)."""
    raise ValueError('Cannot subtract unquantized input variable')
|
|
921
|
+
|
|
922
|
+
def relu(self, *args, **kwargs):
    """Reject ReLU on an unquantized input; all arguments are ignored."""
    raise ValueError('Cannot apply relu on unquantized input variable')
|
|
924
|
+
|
|
925
|
+
def max_of(self, other):
    """Reject ``max_of`` on an unquantized input, whatever ``other`` is."""
    raise ValueError('Cannot apply max_of on unquantized input variable')
|
|
927
|
+
|
|
928
|
+
def min_of(self, other):
    """Reject ``min_of`` on an unquantized input, whatever ``other`` is."""
    raise ValueError('Cannot apply min_of on unquantized input variable')
|
|
930
|
+
|
|
931
|
+
def quantize(
|
|
932
|
+
self,
|
|
933
|
+
k: int | bool,
|
|
934
|
+
i: int,
|
|
935
|
+
f: int,
|
|
936
|
+
overflow_mode: str = 'WRAP',
|
|
937
|
+
round_mode: str = 'TRN',
|
|
938
|
+
):
|
|
939
|
+
assert overflow_mode == 'WRAP'
|
|
940
|
+
|
|
941
|
+
if k + i + f <= 0:
|
|
942
|
+
return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
|
|
943
|
+
|
|
944
|
+
if round_mode == 'RND':
|
|
945
|
+
return (self.quantize(k, i, f + 1) + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
|
|
946
|
+
else:
|
|
947
|
+
round_mode = 'TRN'
|
|
948
|
+
|
|
949
|
+
step = Decimal(2) ** -f
|
|
950
|
+
_high = Decimal(2) ** i
|
|
951
|
+
low, high = -_high * k, _high - step
|
|
952
|
+
self.high = max(self.high, high)
|
|
953
|
+
self.low = min(self.low, low)
|
|
954
|
+
self.step = min(self.step, step)
|
|
955
|
+
|
|
956
|
+
return FixedVariable(
|
|
957
|
+
low,
|
|
958
|
+
high,
|
|
959
|
+
step,
|
|
960
|
+
_from=(self,),
|
|
961
|
+
_factor=self._factor,
|
|
962
|
+
opr='wrap',
|
|
963
|
+
latency=self.latency,
|
|
964
|
+
hwconf=self.hwconf,
|
|
965
|
+
)
|