da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +204 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +246 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.1.post1.dist-info/METADATA +85 -0
  92. da4ml-0.5.1.post1.dist-info/RECORD +96 -0
  93. da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
  94. da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
@@ -0,0 +1,965 @@
1
+ import random
2
+ import typing
3
+ from collections.abc import Callable, Generator
4
+ from copy import copy
5
+ from dataclasses import dataclass
6
+ from decimal import Decimal
7
+ from hashlib import sha256
8
+ from math import ceil, floor, log2
9
+ from typing import NamedTuple, overload
10
+ from uuid import UUID
11
+
12
+ import numpy as np
13
+ from numpy.typing import NDArray
14
+
15
+ from ..cmvm.core import cost_add
16
+ from ..cmvm.types import QInterval, _minimal_kif
17
+ from ..cmvm.util.bit_decompose import _shift_centering
18
+
19
# Module-private random generator; FixedVariable draws 128 random bits from it
# to mint version-4 UUIDs for new nodes.
rd = random.Random()

if typing.TYPE_CHECKING:
    # Reserved for annotation-only imports; none are currently required.
    pass
23
+
24
+
25
class HWConfig(NamedTuple):
    """Hardware parameters consumed by the cost/latency model.

    ``adder_size`` and ``carry_size`` are forwarded to ``cost_add``; a positive
    ``latency_cutoff`` enables alignment of operations to pipeline-stage
    boundaries (non-positive disables it).
    """

    adder_size: int
    carry_size: int
    latency_cutoff: float
30
+
31
# Type of a function mapping a floating-point array to a floating-point array
# (presumably numpy-ufunc-like / element-wise, as the name suggests).
ufunc_t = Callable[[NDArray[np.floating]], NDArray[np.floating]]
32
+
33
+
34
class TraceContext:
    """Registry that deduplicates lookup tables and assigns them integer indices.

    Tables are keyed by ``LookupTable.spec.hash``. The first registration of a
    hash receives the next index (0, 1, 2, ...); registering an equal table again
    returns the stored ``(table, index)`` entry unchanged.
    """

    # Registry state lives on the instance (set in __init__) so independent
    # contexts never share one dict while counting indices separately, which
    # would make them hand out colliding indices.
    _tables: 'dict[str, tuple[LookupTable, int]]'
    _table_counter: int = 0
    hwconf: HWConfig = HWConfig(1, -1, -1)

    def __init__(self) -> None:
        self._tables = {}
        self._table_counter = 0

    def register_table(self, table: 'LookupTable|np.ndarray'):
        """Register ``table`` and return its ``(LookupTable, index)`` entry.

        A raw ``np.ndarray`` is first converted with ``LookupTable(table)``.
        """
        if isinstance(table, np.ndarray):
            table = LookupTable(table)
        key = table.spec.hash
        entry = self._tables.get(key)
        if entry is None:
            entry = (table, self._table_counter)
            self._tables[key] = entry
            self._table_counter += 1
        return entry

    def index_table(self, hash: str) -> int:
        """Return the index assigned to the table with spec hash ``hash``.

        Raises ``KeyError`` if the hash was never registered.
        """
        return self._tables[hash][1]

    def get_table_from_index(self, index: int) -> 'LookupTable':
        """Return the registered table whose index is ``index`` (linear scan).

        Raises
        ------
        KeyError
            If no registered table has that index.
        """
        for table, idx in self._tables.values():
            if idx == index:
                return table
        raise KeyError(f'No table found with index {index}')
57
+
58
+
59
# Module-wide table registry; FixedVariable.lookup registers every table here.
table_context = TraceContext()
60
+
61
+
62
@dataclass
class TableSpec:
    """Identity and output format of a lookup table."""

    # Content hash; to_spec produces a sha256 hex digest of the codes and scale.
    hash: str
    # Interval (min, max, step) of the decoded output values.
    out_qint: QInterval
    # Address width in bits (to_spec uses ceil(log2(number of entries))).
    inp_width: int

    @property
    def out_kif(self) -> tuple[bool, int, int]:
        """``(k, i, f)`` output format, as computed by ``_minimal_kif(out_qint)``."""
        qint = self.out_qint
        return _minimal_kif(qint)
71
+
72
+
73
def to_spec(table: NDArray[np.floating]) -> tuple[TableSpec, NDArray[np.int32]]:
    """Encode a real-valued table as int32 codes and build its ``TableSpec``.

    The scale exponent is ``f_out = -_shift_centering(table)``; codes are
    ``table * 2**f_out`` cast to int32 (the cast truncates). The spec hash covers
    the raw code bytes followed by the textual ``f_out``.

    Returns
    -------
    tuple[TableSpec, NDArray[np.int32]]
        The spec and the int32 code table.
    """
    frac_bits = -_shift_centering(np.array(table))
    codes = (table * 2**frac_bits).astype(np.int32)

    digest = sha256(codes.data)
    digest.update(f'{frac_bits}'.encode())

    address_bits = ceil(log2(table.size))
    output_range = QInterval(float(np.min(table)), float(np.max(table)), float(2**-frac_bits))
    spec = TableSpec(hash=digest.hexdigest(), inp_width=address_bits, out_qint=output_range)
    return spec, codes
81
+
82
+
83
+ def interpret_as(
84
+ x: int | NDArray[np.integer],
85
+ k: int,
86
+ i: int,
87
+ f: int,
88
+ ) -> float | NDArray[np.floating]:
89
+ b = k + i + f
90
+ bias = 2.0 ** (b - 1) * k
91
+ eps = 2.0**-f
92
+ floor_fn = np.floor if isinstance(x, np.ndarray) else floor
93
+ return eps * (floor_fn(x + bias) % 2.0**b - bias)
94
+
95
+
96
class LookupTable:
    """Integer-coded lookup table plus the ``TableSpec`` describing it.

    ``table`` holds int32 codes; ``spec.out_kif`` gives the fixed-point format
    used to decode them (see ``interpret_as``).
    """

    def __init__(self, values: NDArray, spec: TableSpec | None = None):
        """Wrap pre-encoded int32 ``values`` with ``spec``, or encode raw values.

        Without ``spec``, ``values`` are treated as real numbers and encoded by
        ``to_spec``.
        """
        assert values.ndim == 1, 'Lookup table values must be 1-dimensional'
        if spec is None:
            self.spec, self.table = to_spec(values)
            return
        assert values.dtype == np.int32, f'{values.dtype}'
        self.spec = spec
        self.table = values

    @overload
    def lookup(self, var: 'FixedVariable', qint_in: QInterval) -> 'FixedVariable': ...

    @overload
    def lookup(self, var: np.floating | float, qint_in: QInterval | tuple[float, float, float]) -> float: ...

    def lookup(self, var, qint_in: QInterval | tuple[float, float, float]):
        """Map ``var`` through the table.

        A ``FixedVariable`` is delegated to ``var.lookup(self)`` (``qint_in`` is
        unused). A number is located on the input grid ``qint_in = (min, max,
        step)`` and the decoded output value is returned.
        """
        if isinstance(var, FixedVariable):
            return var.lookup(self)
        lo, hi, step = qint_in
        assert lo <= var <= hi, f'Value {var} out of range [{lo}, {hi}]'
        position = round((var - lo) / step)
        code = int(self.table[position])
        return interpret_as(code, *self.spec.out_kif)

    @property
    def float_table(self) -> NDArray[np.floating]:
        """All entries decoded to real values."""
        return interpret_as(self.table, *self.spec.out_kif)  # type: ignore

    def to_dict(self) -> dict:
        """Serialize to plain Python types (inverse of ``from_dict``)."""
        spec = self.spec
        qint = spec.out_qint
        spec_dict = {
            'hash': spec.hash,
            'out_qint': {'min': qint.min, 'max': qint.max, 'step': qint.step},
            'inp_width': spec.inp_width,
        }
        return {'spec': spec_dict, 'table': self.table.tolist()}

    @classmethod
    def from_dict(cls, data: dict) -> 'LookupTable':
        """Rebuild a table from the mapping produced by ``to_dict``."""
        spec_data = data['spec']
        q = spec_data['out_qint']
        spec = TableSpec(
            hash=spec_data['hash'],
            out_qint=QInterval(q['min'], q['max'], q['step']),
            inp_width=spec_data['inp_width'],
        )
        return cls(np.array(data['table'], dtype=np.int32), spec=spec)

    def _get_pads(self, qint: QInterval) -> tuple[int, int]:
        """Zero-padding needed to fill the full ``2**(k+i+f)`` input code space of ``qint``."""
        k, i, f = _minimal_kif(qint)
        # Offset of qint.min from the bottom of the code space (-2**i when signed, 0 otherwise).
        origin = qint.min + 2**i if k else qint.min
        pad_left = round(origin / qint.step)
        pad_right = 2 ** (k + i + f) - len(self.table) - pad_left
        return pad_left, pad_right

    def padded_table(self, qint: QInterval) -> NDArray[np.int32]:
        """Return the table zero-padded to the whole input code space of ``qint``."""
        pad_left, pad_right = self._get_pads(qint)
        padded = np.pad(self.table, (pad_left, pad_right), mode='constant', constant_values=0)
        if qint.min >= 0:
            return padded
        # Signed input: rotate by half so the array is indexed by the two's-complement
        # code of the input rather than by its offset from the minimum.
        return np.roll(padded, len(padded) // 2)
170
+
171
+
172
+ def _const_f(const: float | Decimal):
173
+ """Get the minimum f such that const * 2^f is an integer."""
174
+ const = float(const)
175
+ if const == 0:
176
+ return 0
177
+ _low, _high = -32, 32
178
+ while _high - _low > 1:
179
+ _mid = (_high + _low) // 2
180
+ _value = const * (2.0**_mid)
181
+ if _value == int(_value):
182
+ _high = _mid
183
+ else:
184
+ _low = _mid
185
+ return _high
186
+
187
+
188
def to_csd_powers(x: float) -> Generator[float, None, None]:
    """Yield the signed power-of-two terms of ``x`` in CSD form, most significant first.

    ``x`` is scaled by ``2**_const_f(abs(x))`` to an integer-valued residual
    and digits are chosen greedily from the top down, with a threshold of
    ``2**n / 1.5`` per position. When the scaled value is not integral (see the
    saturation in ``_const_f``), the residual left below the lowest digit is
    dropped. ``x == 0`` yields nothing.
    """
    if x == 0:
        return
    frac_bits = _const_f(abs(x))
    residual = x * 2**frac_bits
    unit = 2**-frac_bits
    n_digits = ceil(log2(abs(residual) * 1.5 + 1e-19))
    for n in reversed(range(n_digits)):
        weight = 2**n
        threshold = weight / 1.5
        digit = int(residual > threshold) - int(residual < -threshold)
        term = weight * digit
        residual -= term
        if term != 0:
            yield term * unit
204
+
205
+
206
class FixedVariable:
    """Symbolic fixed-point value recorded while tracing a computation.

    A variable is described by the interval ``[low, high]`` and grid spacing
    ``step`` (stored as ``Decimal``), the operation ``opr`` that produced it and
    its operands ``_from``. ``cost`` and ``latency`` are estimates from
    :meth:`get_cost_and_latency` unless both are passed explicitly.

    ``_factor`` is a scale relative to the producing node: negation and
    power-of-two scaling (``__neg__``, ``_pow2_mul``) keep ``id``, ``opr``,
    ``_from`` and ``_data`` and only update the bounds and ``_factor``.
    """

    # Subclasses set this to False to skip the ``low <= high`` sanity check.
    __normal__variable__ = True

    def __init__(
        self,
        low: float | Decimal,
        high: float | Decimal,
        step: float | Decimal,
        latency: float | None = None,
        hwconf: HWConfig | tuple[int, int, int] = HWConfig(-1, -1, -1),
        opr: str = 'new',
        cost: float | None = None,
        _from: tuple['FixedVariable', ...] = (),
        _factor: float | Decimal = 1.0,
        _data: Decimal | None = None,
        _id: UUID | None = None,
    ) -> None:
        """Create a variable.

        Parameters
        ----------
        low, high, step : float | Decimal
            Value interval and grid spacing. ``low == high`` always produces a
            ``'const'`` variable without operands, whatever ``opr`` says.
        latency, cost : float | None
            Used as given only when both are provided; otherwise both are
            recomputed by :meth:`get_cost_and_latency`.
        hwconf : HWConfig | tuple[int, int, int]
            Hardware parameters, normalised to ``HWConfig``.
        opr : str
            Name of the producing operation.
        _from : tuple[FixedVariable, ...]
            Operands of ``opr``.
        _factor : float | Decimal
            Scale relative to the producing node.
        _data : Decimal | None
            Operation payload: the unscaled constant for ``'cadd'``, the table
            index for ``'lookup'``.
        _id : UUID | None
            Node identity; a random UUID4 is drawn when omitted.

        Raises
        ------
        ValueError
            If ``opr == 'const'`` but ``low != high``.
        """
        if self.__normal__variable__:
            assert low <= high, f'low {low} must be less than high {high}'

        if low != high and opr == 'const':
            raise ValueError('Constant variable must have low == high')

        if low == high:
            # Degenerate interval: a constant on the coarsest exact power-of-two grid.
            opr = 'const'
            _from = ()
            step = 2.0 ** -_const_f(low)

        low, high, step = Decimal(low), Decimal(high), Decimal(step)
        self.low = low
        self.high = high
        self.step = step
        self._factor = Decimal(_factor)
        self._from: tuple[FixedVariable, ...] = _from
        self.opr = opr
        self._data = _data
        self.id = _id or UUID(int=rd.getrandbits(128), version=4)
        self.hwconf = HWConfig(*hwconf)

        if opr == 'cadd':
            assert _data is not None, 'cadd must have data'

        if cost is None or latency is None:
            _cost, _latency = self.get_cost_and_latency()
        else:
            _cost, _latency = cost, latency

        self.latency = _latency
        self.cost = _cost

        # Constant operands cost nothing; stamp them with this node's latency.
        self._from = tuple(v if v.opr != 'const' else v._with(latency=self.latency) for v in self._from)

    def _with(self, renew_id=True, **kwargs):
        """Return a shallow copy with the given attributes overridden.

        With no overrides, ``self`` is returned as is. Otherwise the copy gets a
        fresh random id unless ``renew_id`` is False.
        """
        if not kwargs:
            return self
        _var = copy(self)
        for k, v in kwargs.items():
            setattr(_var, k, v)
        if renew_id:
            _var.id = UUID(int=rd.getrandbits(128), version=4)
        return _var

    def get_cost_and_latency(self) -> tuple[float, float]:
        """Estimate ``(cost, latency)`` of producing this variable from its operands.

        Raises
        ------
        NotImplementedError
            For an ``opr`` the model does not know.
        """
        if self.opr == 'const':
            return 0.0, 0.0

        if self.opr == 'lookup':
            assert len(self._from) == 1
            b_in = sum(self._from[0].kif)
            b_out = sum(self.kif)
            _latency = max(b_in - 6, 1) + self._from[0].latency
            _cost = 2 ** max(b_in - 5, 0) * ceil(b_out / 2)
            if b_in < 5:
                _cost *= b_in / 5
            # Assume LUT6 with extra o5 output
            return _cost, _latency

        if self.opr in ('vadd', 'cadd', 'min', 'max', 'vmul'):
            adder_size = self.hwconf.adder_size
            carry_size = self.hwconf.carry_size
            latency_cutoff = self.hwconf.latency_cutoff

            if self.opr in ('min', 'max', 'vadd'):
                assert len(self._from) == 2
                v0, v1 = self._from
                int0, int1 = v0.qint, v1.qint
                base_latency = max(v0.latency, v1.latency)
                dlat, _cost = cost_add(int0, int1, 0, False, adder_size, carry_size)
            elif self.opr == 'cadd':
                assert len(self._from) == 1
                assert self._data is not None, 'cadd must have data'
                _f = _const_f(self._data)
                _cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -_f))) + _f
                base_latency = self._from[0].latency
                dlat = 0.0
            elif self.opr == 'vmul':
                assert len(self._from) == 2
                v0, v1 = self._from
                b0, b1 = sum(v0.kif), sum(v1.kif)
                int0, int1 = v0.qint, v1.qint
                # Modelled as a shift-and-add array over the narrower operand's bits.
                dlat0, _cost0 = cost_add(int0, int0, 0, False, adder_size, carry_size)
                dlat1, _cost1 = cost_add(int1, int1, 0, False, adder_size, carry_size)
                dlat = max(dlat0 * b1, dlat1 * b0)
                _cost = min(_cost0 * b1, _cost1 * b0)
                base_latency = max(v0.latency, v1.latency)
            else:
                raise NotImplementedError(f'Operation {self.opr} is unknown')

            _latency = dlat + base_latency
            if latency_cutoff > 0 and ceil(_latency / latency_cutoff) > ceil(base_latency / latency_cutoff):
                # Crossed a latency-cutoff boundary: start the operation at the next stage.
                assert dlat <= latency_cutoff, (
                    f'Latency of an atomic operation {dlat} is larger than the pipelining latency cutoff {latency_cutoff}'
                )
                _latency = ceil(base_latency / latency_cutoff) * latency_cutoff + dlat

        elif self.opr in ('relu', 'wrap'):
            assert len(self._from) == 1
            _latency = self._from[0].latency
            _cost = 0.0
            # Assume LUT5 used here (2 fan-out per LUT6, thus *1/2)
            if self._from[0]._factor < 0:
                _cost += sum(self.kif) / 2
            if self.opr == 'relu':
                _cost += sum(self.kif) / 2

        elif self.opr == 'new':
            # new variable, no cost
            _latency = 0.0
            _cost = 0.0
        else:
            raise NotImplementedError(f'Operation {self.opr} is unknown')
        return _cost, _latency

    @property
    def unscaled(self):
        """This variable divided by its ``_factor``."""
        return self * (1 / self._factor)

    @property
    def qint(self) -> QInterval:
        """``(low, high, step)`` as a float ``QInterval``."""
        return QInterval(float(self.low), float(self.high), float(self.step))

    @property
    def kif(self) -> tuple[bool, int, int]:
        """``(signed, integer bits, fraction bits)`` implied by the bounds; ``(False, 0, 0)`` if ``step == 0``."""
        if self.step == 0:
            return False, 0, 0
        f = -int(log2(self.step))
        i = ceil(log2(max(-self.low, self.high + self.step)))
        k = self.low < 0
        return k, i, f

    @classmethod
    def from_const(cls, const: float | Decimal, hwconf: HWConfig, _factor: float | Decimal = 1):
        """Build a ``'const'`` variable (the placeholder step is replaced in ``__init__``)."""
        return cls(const, const, -1, hwconf=hwconf, opr='const', _factor=_factor)

    def __repr__(self) -> str:
        if self._factor == 1:
            return f'FixedVariable({self.low}, {self.high}, {self.step})'
        return f'({self._factor}) FixedVariable({self.low}, {self.high}, {self.step})'

    def __neg__(self):
        """Negate by flipping bounds and ``_factor``; the node identity is kept."""
        opr = self.opr if self.low != self.high else 'const'
        return FixedVariable(
            -self.high,
            -self.low,
            self.step,
            _from=self._from,
            _factor=-self._factor,
            latency=self.latency,
            cost=self.cost,
            opr=opr,
            _id=self.id,
            _data=self._data,
            hwconf=self.hwconf,
        )

    def __add__(self, other: 'FixedVariable|float|Decimal|int'):
        """Add a variable or number; constant operands become ``'cadd'`` nodes."""
        if not isinstance(other, FixedVariable):
            return self._const_add(other)
        if other.high == other.low:
            return self._const_add(other.low)
        if self.high == self.low:
            return other._const_add(self.low)

        assert self.hwconf == other.hwconf, f'FixedVariable must have the same hwconf, got {self.hwconf} and {other.hwconf}'

        f0, f1 = self._factor, other._factor
        # Keep a positive-factor operand on the left; negate out if both are negative.
        if f0 < 0:
            if f1 > 0:
                return other + self
            else:
                return -((-self) + (-other))

        return FixedVariable(
            self.low + other.low,
            self.high + other.high,
            min(self.step, other.step),
            _from=(self, other),
            _factor=f0,
            opr='vadd',
            hwconf=self.hwconf,
        )

    def _const_add(self, other: float | Decimal | None) -> 'FixedVariable':
        """Add a constant; chained constant additions are folded into one ``'cadd'``."""
        if other is None:
            return self
        if not isinstance(other, (int, float, Decimal)):
            other = float(other)  # direct numpy to decimal raises error
        other = Decimal(other)
        if other == 0:
            return self

        if self.opr != 'cadd':
            cstep = Decimal(2.0 ** -_const_f(other))

            return FixedVariable(
                self.low + other,
                self.high + other,
                min(self.step, cstep),
                _from=(self,),
                _factor=self._factor,
                _data=other / self._factor,
                opr='cadd',
                hwconf=self.hwconf,
            )

        # Already a cadd: value == sf * (parent + _data * parent._factor), so fold
        # the new constant into a single addition on the parent.
        assert len(self._from) == 1
        parent = self._from[0]
        assert self._data is not None, 'cadd must have data'
        sf = self._factor / parent._factor
        other1 = (self._data * parent._factor) + other / sf
        return (parent + other1) * sf

    def __sub__(self, other: 'FixedVariable|int|float|Decimal'):
        return self + (-other)

    def __truediv__(self, other: 'int|float|Decimal'):
        """Divide by a number (implemented as multiplication by its reciprocal)."""
        assert not isinstance(other, FixedVariable), 'Division by variable is not supported'
        return self * (1 / other)

    def __mul__(self, other: 'FixedVariable|int|float|Decimal') -> 'FixedVariable':
        """Multiply by a variable or number.

        Powers of two only rescale; other constants are expanded into a sum of
        shifted copies following ``to_csd_powers``.
        """
        if isinstance(other, FixedVariable):
            if self.high == self.low:
                return other * self.low
            if other.high > other.low:
                return self._var_mul(other)
            assert other.high == other.low
            other = float(other.low)

        if other == 0:
            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')

        if log2(abs(other)) % 1 == 0:
            return self._pow2_mul(other)

        variables = [(self._pow2_mul(v), Decimal(v)) for v in to_csd_powers(float(other))]
        while len(variables) > 1:
            v1, p1 = variables.pop()
            v2, p2 = variables.pop()
            v, p = v1 + v2, p1 + p2
            # The partial sum equals self * p, so tighten its bounds accordingly.
            if p > 0:
                high, low = self.high * p, self.low * p
            else:
                high, low = self.low * p, self.high * p
            v.high, v.low = high, low
            variables.append((v, p))
        return variables[0][0]

    def _var_mul(self, other: 'FixedVariable') -> 'FixedVariable':
        """Product of two non-constant variables (``'vmul'``); squaring uses tighter bounds."""
        if other is not self:
            a, b, c, d = self.high * other.low, self.low * other.high, self.high * other.high, self.low * other.low
            low = min(a, b, c, d)
            high = max(a, b, c, d)
        else:
            a, b = self.low * other.low, self.high * other.high
            if self.low < 0 and self.high > 0:
                low = min(a, b, 0)
                high = max(a, b, 0)
            else:
                low = min(a, b)
                high = max(a, b)

        step = self.step * other.step
        _factor = self._factor * other._factor
        opr = 'vmul'
        return FixedVariable(
            low,
            high,
            step,
            _from=(self, other),
            hwconf=self.hwconf,
            _factor=_factor,
            opr=opr,
        )

    def _pow2_mul(
        self,
        other: float | Decimal,
    ):
        """Scale by ``other`` (expected to be a signed power of two) without a new node."""
        other = Decimal(other)

        low = min(self.low * other, self.high * other)
        high = max(self.low * other, self.high * other)
        step = abs(self.step * other)
        _factor = self._factor * other
        opr = self.opr
        return FixedVariable(
            low,
            high,
            step,
            _from=self._from,
            _factor=_factor,
            opr=opr,
            latency=self.latency,
            cost=self.cost,
            _id=self.id,
            _data=self._data,
            hwconf=self.hwconf,
        )

    def __lshift__(self, other: int):
        assert isinstance(other, int), 'Shift amount must be an integer'
        shift_amount = 2.0**other
        return self * shift_amount

    def __rshift__(self, other: int):
        assert isinstance(other, int), 'Shift amount must be an integer'
        shift_amount = 2.0**-other
        return self * shift_amount

    def __radd__(self, other: 'float|Decimal|int|FixedVariable'):
        return self + other

    def __rsub__(self, other: 'float|Decimal|int|FixedVariable'):
        return (-self) + other

    def __rmul__(self, other: 'float|Decimal|int|FixedVariable'):
        return self * other

    def __pow__(self, other):
        """Raise to a non-negative integer power by repeated squaring.

        The half power is traced once and squared, so ``x ** n`` creates
        O(log n) multiplication nodes instead of re-tracing identical sub-powers.
        """
        _power = int(other)
        assert _power == other, 'Power must be an integer'
        assert _power >= 0, 'Power must be non-negative'
        if _power == 0:
            return FixedVariable(1, 1, 1, hwconf=self.hwconf, opr='const')
        if _power == 1:
            return self

        half = self ** (_power // 2)
        # ``half * half`` takes the squaring path of _var_mul (non-negative bounds).
        ret = half * half
        if _power % 2:
            ret = self * ret
        if _power % 2 == 0 and ret.low < 0:
            ret.low = Decimal(0)
        return ret

    def relu(self, i: int | None = None, f: int | None = None, round_mode: str = 'TRN'):
        """Return ``max(self, 0)``, optionally re-quantized to ``i`` integer / ``f`` fraction bits.

        ``None`` for ``i``/``f`` keeps the input's range/precision; an explicit
        0 is honoured. Overflow beyond ``i`` integer bits wraps. ``'RND'`` adds
        half an output step before truncating.
        """
        round_mode = round_mode.upper()
        assert round_mode in ('TRN', 'RND')

        if self.opr == 'const':
            val = self.low * (self.low > 0)
            f = _const_f(val) if f is None else f
            step = Decimal(2) ** -f
            i = ceil(log2(val + step)) if i is None else i
            eps = step / 2 if round_mode == 'RND' else 0
            # eps is in value units, so add it before dividing by step.
            val = (floor((val + eps) / step) * step) % (Decimal(2) ** i)
            return self.from_const(val, hwconf=self.hwconf)

        step = max(Decimal(2) ** -f, self.step) if f is not None else self.step
        if step > self.step and round_mode == 'RND':
            return (self + step / 2).relu(i, f, 'TRN')
        low = max(Decimal(0), self.low)
        high = max(Decimal(0), self.high)
        if i is not None:
            _high = Decimal(2) ** i - step
            if _high < high:
                # overflows: the wrapped result may take any value in range
                low = Decimal(0)
                high = _high
        _factor = self._factor

        if self.low == low and self.high == high and self.step == step:
            return self

        # NOTE(review): no latency is passed, so __init__ recomputes cost via
        # get_cost_and_latency and this explicit cost is discarded — confirm intent.
        return FixedVariable(
            low,
            high,
            step,
            _from=(self,),
            _factor=abs(_factor),
            opr='relu',
            hwconf=self.hwconf,
            cost=sum(self.kif) * (1 if _factor > 0 else 2),
        )

    def quantize(
        self,
        k: int | bool,
        i: int,
        f: int,
        overflow_mode: str = 'WRAP',
        round_mode: str = 'TRN',
    ) -> 'FixedVariable':
        """Quantize the variable to the specified fixed-point format.

        Parameters
        ----------
        k : int | bool
            Sign bit (True for signed, False for unsigned)
        i : int
            Integer bits, excluding sign bit
        f : int
            Fraction bits
        overflow_mode : str, optional
            Overflow mode, one of 'WRAP', 'SAT', 'SAT_SYM', by default 'WRAP'
        round_mode : str, optional
            Rounding mode, one of 'TRN' (truncate), 'RND' (round to nearest, half up), by default 'TRN'

        Returns
        -------
        FixedVariable
            ``self`` when already representable, a constant, or a new ``'wrap'`` node.
        """

        overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
        assert overflow_mode in ('WRAP', 'SAT', 'SAT_SYM')
        assert round_mode in ('TRN', 'RND')

        if k + i + f <= 0:
            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
        _k, _i, _f = self.kif

        if k >= _k and i >= _i and f >= _f:
            if overflow_mode != 'SAT_SYM' or i > _i:
                return self

        if f < _f and round_mode == 'RND':
            return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')

        if overflow_mode in ('SAT', 'SAT_SYM'):
            step = Decimal(2) ** -f
            _high = Decimal(2) ** i
            high = _high - step
            low = -_high * k if overflow_mode == 'SAT' else -high * k
            ff = f + 1 if round_mode == 'RND' else f
            v = self.quantize(_k, _i, ff, 'WRAP', 'TRN') if _k + _i + ff > 0 else self
            return v.max_of(low).min_of(high).quantize(k, i, f, 'WRAP', round_mode)

        if self.low == self.high:
            val = self.low
            step = Decimal(2) ** -f
            _high = Decimal(2) ** i
            low = -_high * k
            # Wrap period is 2**(k + i): [-2**i, 2**i) when signed, [0, 2**i) otherwise.
            span = _high * 2 if k else _high
            shifted = floor(val / step) * step - low
            # Floored modulo: Decimal's % keeps the dividend's sign, which would
            # leave negative overflows outside the target range.
            val = shifted - floor(shifted / span) * span + low
            return FixedVariable.from_const(val, hwconf=self.hwconf, _factor=1)

        f = min(f, _f)
        k = min(k, _k) if i >= _i else k

        step = Decimal(2) ** -f

        if self.low < 0:
            _low = floor(self.low / step) * step
            _i = max(_i, ceil(log2(-_low)))

        i = min(i, _i + (k == 0 and _k == 1))

        if i + k + f <= 0:
            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')

        low = -k * Decimal(2) ** i

        high = Decimal(2) ** i - step
        _low, _high = self.low, self.high

        if _low >= low and _high <= high:
            # No overflow possible: keep the (truncated) original bounds.
            low = floor(_low / step) * step
            high = floor(_high / step) * step

        return FixedVariable(
            low,
            high,
            step,
            _from=(self,),
            _factor=abs(self._factor),
            opr='wrap',
            latency=self.latency,
            hwconf=self.hwconf,
        )

    @classmethod
    def from_kif(cls, k: int | bool, i: int, f: int, **kwargs):
        """Create a variable spanning the full range of the ``(k, i, f)`` format."""
        step = Decimal(2) ** -f
        _high = Decimal(2) ** i
        low, high = -k * _high, _high - step
        return cls(low, high, step, **kwargs)

    def msb_mux(
        self,
        a: 'FixedVariable|float|Decimal',
        b: 'FixedVariable|float|Decimal',
        qint: tuple[Decimal, Decimal, Decimal] | None = None,
    ):
        """If the MSB of this variable is 1, return a, else return b.
        When the variable is signed, the MSB is determined by the sign bit (1 for <0, 0 for >=0)

        ``qint`` optionally overrides the ``(low, high, step)`` of the result;
        by default it is the hull of ``a`` and ``b``.
        """
        if not isinstance(a, FixedVariable):
            a = FixedVariable.from_const(a, hwconf=self.hwconf, _factor=1)
        if not isinstance(b, FixedVariable):
            b = FixedVariable.from_const(b, hwconf=self.hwconf, _factor=1)
        if self._factor < 0:
            return (-self).msb_mux(b, a, qint)

        if self.opr == 'const':
            if self.low >= 0:
                return b
            else:
                return b if log2(abs(self.low)) % 1 == 0 else a
        elif self.opr == 'quantize':
            # NOTE(review): nothing in this module creates opr='quantize'
            # (quantize() emits 'wrap'), so this branch looks unreachable — confirm.
            k, i, _ = self.kif
            pk, pi, _ = self._from[0].kif
            if k + i == pk + pi:
                return self._from[0].msb_mux(a, b, qint=qint)

        if a._factor < 0:
            qint = (-qint[1], -qint[0], qint[2]) if qint else None
            return -(self.msb_mux(-a, -b, qint=qint))

        _factor = a._factor

        if qint is None:
            qint = (min(a.low, b.low), max(a.high, b.high), min(a.step, b.step))

        dlat, dcost = cost_add(a.qint, b.qint, 0, False, self.hwconf.adder_size, self.hwconf.carry_size)
        return FixedVariable(
            *qint,
            _from=(self, a, b),
            _factor=_factor,
            opr='msb_mux',
            latency=max(a.latency, b.latency, self.latency) + dlat,
            hwconf=self.hwconf,
            cost=dcost,
        )

    def is_negative(self) -> 'FixedVariable|bool':
        """Return a bool when the bounds decide the sign, else a 0/1 variable from the sign bit."""
        if self.low >= 0:
            return False
        if self.high < 0:
            return True
        _, i, _ = self.kif
        # Keep only the bit of weight 2**i (the sign bit) and shift it down to 0/1.
        sign_bit = self.quantize(0, i + 1, -i) >> i
        return sign_bit

    def is_positive(self) -> 'FixedVariable|bool':
        return (-self).is_negative()

    def __abs__(self):
        if self.low >= 0:
            return self
        step = self.step
        high = max(-self.low, self.high)
        return self.msb_mux(-self, self, (Decimal(0), high, step))

    def abs(self):
        """Get the absolute value of this variable."""
        return abs(self)

    def __gt__(self, other: 'FixedVariable|float|Decimal|int'):
        """Get a variable that is 1 if this variable is greater than other, else 0."""
        return (self - other).is_positive()

    def __lt__(self, other: 'FixedVariable|float|Decimal|int'):
        """Get a variable that is 1 if this variable is less than other, else 0."""
        return (other - self).is_positive()

    # __ge__ / __le__ are intentionally not defined (bitwise inversion of the
    # 0/1 result is not supported).

    def max_of(self, other):
        """Get the maximum of this variable and another variable or constant."""
        if other == -float('inf'):
            return self
        if other == float('inf'):
            raise ValueError('Cannot apply max_of with inf')
        if not isinstance(other, FixedVariable):
            other = FixedVariable.from_const(other, hwconf=self.hwconf, _factor=abs(self._factor))

        if self.low >= other.high:
            return self
        if self.high <= other.low:
            return other
        if other.high == other.low == 0:
            return self.relu()

        qint = (max(self.low, other.low), max(self.high, other.high), min(self.step, other.step))
        return (self - other).msb_mux(other, self, qint=qint)

    def min_of(self, other):
        """Get the minimum of this variable and another variable or constant."""

        if other == float('inf'):
            return self
        if other == -float('inf'):
            raise ValueError('Cannot apply min_of with -inf')
        if not isinstance(other, FixedVariable):
            other = FixedVariable.from_const(other, hwconf=self.hwconf, _factor=(self._factor))

        if self.high <= other.low:
            return self
        if self.low >= other.high:
            return other
        if other.high == other.low == 0:
            return -(-self).relu()

        qint = (min(self.low, other.low), min(self.high, other.high), min(self.step, other.step))
        return (self - other).msb_mux(self, other, qint=qint)

    def lookup(self, table: LookupTable | np.ndarray) -> 'FixedVariable':
        """Use a lookup table to map the variable.
        When the table is a numpy array, the table starts at the lowest possible value of the variable
        When the table is in LookupTable format, the table starts at the normalized lowest value of the variable. (i.e., if the variable has negative _factor, the table is reversed)

        Parameters
        ----------
        table : LookupTable | np.ndarray
            Lookup table to use

        Returns
        -------
        FixedVariable
            A constant for a single-entry array, otherwise a ``'lookup'`` node
            whose ``_data`` is the table's registry index.
        """
        if isinstance(table, np.ndarray):
            if len(table) == 1:
                return self.from_const(float(table[0]), hwconf=self.hwconf)
            if self._factor < 0:
                table = table[::-1]  # Reverse the table for negative factor

        _table, table_id = table_context.register_table(table)
        size = len(table.table) if isinstance(table, LookupTable) else len(table)
        assert round((self.high - self.low) / self.step) + 1 == size, (
            f'Input variable size does not match lookup table size ({round((self.high - self.low) / self.step) + 1} != {size})'
        )

        return FixedVariable(
            _table.spec.out_qint.min,
            _table.spec.out_qint.max,
            _table.spec.out_qint.step,
            _from=(self,),
            _factor=Decimal(1),
            opr='lookup',
            hwconf=self.hwconf,
            _data=Decimal(table_id),
        )
866
+
867
+
868
+ class FixedVariableInput(FixedVariable):
869
+ __normal__variable__ = False
870
+
871
+ def __init__(
872
+ self,
873
+ latency: float | None = None,
874
+ hwconf: HWConfig | tuple[int, int, int] = HWConfig(-1, -1, -1),
875
+ opr: str = 'new',
876
+ ) -> None:
877
+ super().__init__(
878
+ low=Decimal(1e10),
879
+ high=Decimal(-1e10),
880
+ step=Decimal(1e10),
881
+ latency=latency if latency is not None else 0.0,
882
+ hwconf=HWConfig(*hwconf),
883
+ opr=opr,
884
+ cost=0.0,
885
+ _factor=Decimal(1),
886
+ _from=(),
887
+ _data=None,
888
+ _id=None,
889
+ )
890
+
891
+ def __add__(self, other):
892
+ if other == 0:
893
+ return self
894
+ raise ValueError('Cannot operate on unquantized input variable')
895
+
896
+ def __sub__(self, other):
897
+ if other == 0:
898
+ return self
899
+ raise ValueError('Cannot operate on unquantized input variable')
900
+
901
+ def __neg__(self):
902
+ raise ValueError('Cannot negate unquantized input variable')
903
+
904
+ def __mul__(self, other):
905
+ if other == 1:
906
+ return self
907
+ raise ValueError('Cannot multiply unquantized input variable')
908
+
909
+ def __rmul__(self, other):
910
+ if other == 1:
911
+ return self
912
+ raise ValueError('Cannot multiply unquantized input variable')
913
+
914
+ def __radd__(self, other):
915
+ if other == 0:
916
+ return self
917
+ raise ValueError('Cannot add unquantized input variable')
918
+
919
+ def __rsub__(self, other):
920
+ raise ValueError('Cannot subtract unquantized input variable')
921
+
922
+ def relu(self, *args, **kwargs):
923
+ raise ValueError('Cannot apply relu on unquantized input variable')
924
+
925
+ def max_of(self, other):
926
+ raise ValueError('Cannot apply max_of on unquantized input variable')
927
+
928
+ def min_of(self, other):
929
+ raise ValueError('Cannot apply min_of on unquantized input variable')
930
+
931
+ def quantize(
932
+ self,
933
+ k: int | bool,
934
+ i: int,
935
+ f: int,
936
+ overflow_mode: str = 'WRAP',
937
+ round_mode: str = 'TRN',
938
+ ):
939
+ assert overflow_mode == 'WRAP'
940
+
941
+ if k + i + f <= 0:
942
+ return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
943
+
944
+ if round_mode == 'RND':
945
+ return (self.quantize(k, i, f + 1) + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
946
+ else:
947
+ round_mode = 'TRN'
948
+
949
+ step = Decimal(2) ** -f
950
+ _high = Decimal(2) ** i
951
+ low, high = -_high * k, _high - step
952
+ self.high = max(self.high, high)
953
+ self.low = min(self.low, low)
954
+ self.step = min(self.step, step)
955
+
956
+ return FixedVariable(
957
+ low,
958
+ high,
959
+ step,
960
+ _from=(self,),
961
+ _factor=self._factor,
962
+ opr='wrap',
963
+ latency=self.latency,
964
+ hwconf=self.hwconf,
965
+ )