da4ml 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of da4ml has been flagged as potentially problematic by the registry.
Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +239 -73
  5. da4ml/cmvm/core/__init__.py +222 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +49 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +75 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +265 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +177 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +120 -0
  35. da4ml-0.2.0.dist-info/METADATA +65 -0
  36. da4ml-0.2.0.dist-info/RECORD +39 -0
  37. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.2.dist-info/METADATA +0 -122
  48. da4ml-0.1.2.dist-info/RECORD +0 -18
  49. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
  50. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/codegen/verilog/verilog_model.py (new file)
@@ -0,0 +1,265 @@
+import ctypes
+import os
+import re
+import shutil
+import subprocess
+import sys
+from pathlib import Path
+from uuid import uuid4
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ... import codegen
+from ...cmvm.types import CascadedSolution, Solution, _minimal_kif
+from ...trace.pipeline import to_pipeline
+from . import comb_binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_binder_gen, pipeline_logic_gen
+
+
+def get_io_kifs(sol: Solution | CascadedSolution):
+    inp_kifs = tuple(zip(*map(_minimal_kif, sol.inp_qint)))
+    out_kifs = tuple(zip(*map(_minimal_kif, sol.out_qint)))
+    return np.array(inp_kifs, np.int8), np.array(out_kifs, np.int8)
+
+
+class VerilogModel:
+    def __init__(
+        self,
+        solution: Solution | CascadedSolution,
+        prj_name: str,
+        path: str | Path,
+        latency_cutoff: int = -1,
+        print_latency: bool = True,
+        part_name: str = 'xcvu13p-flga2577-2-e',
+        clock_period: int = 5,
+        clock_uncertainty: float = 0.1,
+        io_delay_minmax: tuple[float, float] = (0.2, 0.4),
+    ):
+        self._solution = solution
+        self._path = Path(path)
+        self._prj_name = prj_name
+        self._latency_cutoff = latency_cutoff
+        self._print_latency = print_latency
+        self.__src_root = Path(codegen.__file__).parent
+        self._part_name = part_name
+        self._clock_period = clock_period
+        self._clock_uncertainty = clock_uncertainty
+        self._io_delay_minmax = io_delay_minmax
+
+        self._pipe = solution if isinstance(solution, CascadedSolution) else None
+        if latency_cutoff > 0 and self._pipe is None:
+            assert isinstance(solution, Solution)
+            self._pipe = to_pipeline(solution, latency_cutoff, verbose=False)
+
+        if self._pipe is not None:
+            # Recover the actual latency cutoff from the pipelined stages
+            latency_cutoff = int(max(max(st.latency) / (i + 1) for i, st in enumerate(self._pipe.solutions)))
+            self._latency_cutoff = latency_cutoff
+
+        self._lib = None
+
+    def write(self):
+        self._path.mkdir(parents=True, exist_ok=True)
+        if self._pipe is not None:  # Pipeline
+            # Main logic
+            codes = pipeline_logic_gen(self._pipe, self._prj_name, self._print_latency)
+            for k, v in codes.items():
+                with open(self._path / f'{k}.v', 'w') as f:
+                    f.write(v)
+
+            # Build script
+            with open(self.__src_root / 'verilog/source/build_prj.tcl') as f:
+                tcl = f.read()
+            tcl = tcl.replace('${DEVICE}', self._part_name)
+            tcl = tcl.replace('${PROJECT_NAME}', self._prj_name)
+            with open(self._path / 'build_prj.tcl', 'w') as f:
+                f.write(tcl)
+
+            # XDC
+            with open(self.__src_root / 'verilog/source/template.xdc') as f:
+                xdc = f.read()
+            xdc = xdc.replace('${CLOCK_PERIOD}', str(self._clock_period))
+            xdc = xdc.replace('${UNCERTAINITY_SETUP}', str(self._clock_uncertainty))
+            xdc = xdc.replace('${UNCERTAINITY_HOLD}', str(self._clock_uncertainty))
+            xdc = xdc.replace('${DELAY_MAX}', str(self._io_delay_minmax[1]))
+            xdc = xdc.replace('${DELAY_MIN}', str(self._io_delay_minmax[0]))
+            with open(self._path / f'{self._prj_name}.xdc', 'w') as f:
+                f.write(xdc)
+
+            # C++ binder
+            binder = pipeline_binder_gen(self._pipe, f'{self._prj_name}_wrapper', 1)
+
+            # Verilog IO wrapper (non-uniform bw to uniform one, clk passthrough)
+            io_wrapper = generate_io_wrapper(self._pipe, self._prj_name, True)
+
+            self._pipe.save(self._path / 'pipeline.json')
+        else:  # Comb
+            assert isinstance(self._solution, Solution)
+
+            # Main logic
+            code = comb_logic_gen(self._solution, self._prj_name, self._print_latency, '`timescale 1ns/1ps')
+            with open(self._path / f'{self._prj_name}.v', 'w') as f:
+                f.write(code)
+
+            # Verilog IO wrapper (non-uniform bw to uniform one, no clk)
+            io_wrapper = generate_io_wrapper(self._solution, self._prj_name, False)
+            binder = comb_binder_gen(self._solution, f'{self._prj_name}_wrapper')
+
+        with open(self._path / f'{self._prj_name}_wrapper.v', 'w') as f:
+            f.write(io_wrapper)
+        with open(self._path / f'{self._prj_name}_wrapper_binder.cc', 'w') as f:
+            f.write(binder)
+
+        # Common resource copy
+        shutil.copy(self.__src_root / 'verilog/source/shift_adder.v', self._path)
+        shutil.copy(self.__src_root / 'verilog/source/build_binder.mk', self._path)
+        shutil.copy(self.__src_root / 'verilog/source/ioutils.hh', self._path)
+        self._solution.save(self._path / 'model.json')
+        with open(self._path / 'misc.json', 'w') as f:
+            f.write(f'{{"cost": {self._solution.cost}}}')
+
+    def _compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
+        """Same as `compile`, but does not write the generated sources first.
+
+        Parameters
+        ----------
+        verbose : bool, optional
+            Verbose output, by default False
+        openmp : bool, optional
+            Enable OpenMP, by default True
+        o3 : bool, optional
+            Turn on the -O3 flag, by default False
+        clean : bool, optional
+            Remove obsolete shared object files, by default True
+
+        Raises
+        ------
+        RuntimeError
+            If compilation fails
+        """
+
+        self._uuid = str(uuid4())
+        args = ['make', '-f', 'build_binder.mk']
+        env = os.environ.copy()
+        env['VM_PREFIX'] = f'{self._prj_name}_wrapper'
+        env['STAMP'] = self._uuid
+        env['EXTRA_CXXFLAGS'] = '-fopenmp' if openmp else ''
+        if o3:
+            args.append('fast')
+
+        if clean:
+            m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
+            for p in self._path.iterdir():
+                if not p.is_dir() and m.match(p.name):
+                    p.unlink()
+
+        try:
+            r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
+        except subprocess.CalledProcessError as e:
+            if e.stderr:
+                print(e.stderr.decode(), file=sys.stderr)
+            if e.stdout:
+                print(e.stdout.decode(), file=sys.stdout)
+            raise RuntimeError('Compilation failed!!') from e
+        if r.returncode != 0:
+            print(r.stderr.decode(), file=sys.stderr)
+            print(r.stdout.decode(), file=sys.stdout)
+            raise RuntimeError('Compilation failed!!')
+
+        self._load_lib(self._uuid)
+
+    def _load_lib(self, uuid: str | None = None):
+        uuid = uuid if uuid is not None else self._uuid
+        self._uuid = uuid
+        lib_path = self._path / f'lib{self._prj_name}_wrapper_{uuid}.so'
+        if not lib_path.exists():
+            raise RuntimeError(f'Library {lib_path} does not exist')
+        self._lib = ctypes.CDLL(str(lib_path))
+
+    def compile(self, verbose=False, openmp=True, o3: bool = False):
+        """Compile the generated code to an emulator for logic simulation.
+
+        Parameters
+        ----------
+        verbose : bool, optional
+            Verbose output, by default False
+        openmp : bool, optional
+            Enable OpenMP, by default True
+        o3 : bool, optional
+            Turn on the -O3 flag, by default False
+
+        Raises
+        ------
+        RuntimeError
+            If compilation fails
+        """
+        self.write()
+        self._compile(verbose=verbose, openmp=openmp, o3=o3)
+        self._load_lib()
+
+    def predict(self, data: NDArray[np.floating]):
+        """Run the model on the input data.
+
+        Parameters
+        ----------
+        data : NDArray[np.floating]
+            Input data to the model. The shape is ignored, and the number of samples is
+            determined by the size of the data.
+
+        Returns
+        -------
+        NDArray[np.float64]
+            Output of the model in shape (n_samples, output_size).
+        """
+        assert self._lib is not None, 'Library not loaded, call .compile() first.'
+        inp_size, out_size = self._solution.shape
+
+        assert data.size % inp_size == 0, f'Input size {data.size} is not divisible by {inp_size}'
+        n_sample = data.size // inp_size
+
+        kifs_in, kifs_out = get_io_kifs(self._solution)
+        k_in, i_in, f_in = map(np.max, kifs_in)
+        k_out, i_out, f_out = map(np.max, kifs_out)
+        assert k_in + i_in + f_in <= 32, "Padded inp bw doesn't fit in int32. Emulation not supported"
+        assert k_out + i_out + f_out <= 32, "Padded out bw doesn't fit in int32. Emulation not supported"
+
+        inp_data = np.empty(n_sample * inp_size, dtype=np.int32)
+        out_data = np.empty(n_sample * out_size, dtype=np.int32)
+
+        # Convert to int32 matching the LSB position
+        inp_data[:] = data.ravel() * 2.0 ** np.max(f_in)
+
+        inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
+        out_buf = out_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
+        self._lib.inference(inp_buf, out_buf, n_sample)
+
+        # Unscale the output int32 to recover fp values
+        k, i, f = np.max(k_out), np.max(i_out), np.max(f_out)
+        a, b, c = 2.0 ** (k + i + f), 2.0 ** (i + f), 2.0**-f
+        return ((out_data.reshape(n_sample, out_size) + b) % a - b) * c
+
+    def __repr__(self):
+        inp_size, out_size = self._solution.shape
+        cost = round(self._solution.cost)
+        kifs_in, kifs_out = get_io_kifs(self._solution)
+        in_bits, out_bits = np.sum(kifs_in), np.sum(kifs_out)
+        if self._pipe is not None:
+            n_stage = len(self._pipe[0])
+            lat_cutoff = self._latency_cutoff
+            reg_bits = self._pipe.reg_bits
+            spec = f"""Top Module: {self._prj_name}\n====================
+{inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)
+{n_stage} stages @ max_delay={lat_cutoff}
+Estimated cost: {cost} LUTs, {reg_bits} FFs"""
+
+        else:
+            spec = f"""Top Module: {self._prj_name}\n====================
+{inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)
+combinational @ delay={self._solution.latency}
+Estimated cost: {cost} LUTs"""
+
+        is_compiled = self._lib is not None
+        if is_compiled:
+            openmp = 'with OpenMP' if self._lib.openmp_enabled else ''  # type: ignore
+            spec += f'\nEmulator is compiled {openmp} ({self._uuid[-12:]})'
+        else:
+            spec += '\nEmulator is **not compiled**'
+        return spec
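
The class above is the user-facing entry point for this backend: write() emits the Verilog sources, XDC constraints, and build scripts; compile() builds a shared-object emulator via build_binder.mk and loads it with ctypes; predict() scales inputs to int32 by the fractional bit count, calls the emulator's inference symbol, and unscales the outputs with a modular wrap. A minimal usage sketch follows; it is not taken from the package documentation, and the solution object is assumed to come from da4ml's cmvm or trace APIs.

    import numpy as np

    from da4ml.codegen.verilog.verilog_model import VerilogModel

    solution = ...  # assumed: a Solution or CascadedSolution from da4ml's cmvm/trace APIs

    model = VerilogModel(
        solution,
        prj_name='mvm',      # top-level module name
        path='build/mvm',    # output directory for the generated sources
        latency_cutoff=10,   # >0 splits the combinational logic into pipeline stages
    )
    model.compile()          # write(), then make -f build_binder.mk, then ctypes.CDLL(...)

    x = np.random.rand(128, solution.shape[0])
    y = model.predict(x)     # shape (128, output_size), bit-accurate emulation
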
da4ml/trace/__init__.py (new file)
@@ -0,0 +1,6 @@
+from .fixed_variable import HWConfig
+from .fixed_variable_array import FixedVariableArray
+from .pipeline import to_pipeline
+from .tracer import comb_trace
+
+__all__ = ['to_pipeline', 'comb_trace', 'FixedVariableArray', 'HWConfig']
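
HWConfig, re-exported here, is the three-field NamedTuple defined in fixed_variable.py below. A short illustration; the field semantics are inferred from how get_cost_and_latency uses them, and the numbers are placeholders rather than recommended settings.

    from da4ml.trace import HWConfig

    # adder_size and carry_size parameterize the adder cost model (see the
    # cost_add call in fixed_variable.py); latency_cutoff > 0 enables
    # pipelining with that per-stage delay budget. (-1, -1, -1) is the
    # sentinel default that FixedVariable uses.
    conf = HWConfig(adder_size=6, carry_size=8, latency_cutoff=-1)
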
da4ml/trace/fixed_variable.py (new file)
@@ -0,0 +1,358 @@
+from decimal import Decimal
+from math import ceil, floor, log2
+from typing import NamedTuple
+from uuid import UUID, uuid4
+
+from ..cmvm.core import cost_add
+from ..cmvm.types import QInterval
+
+
+class HWConfig(NamedTuple):
+    adder_size: int
+    carry_size: int
+    latency_cutoff: float
+
+
+def _const_f(const: float | Decimal):
+    """Smallest f in [-32, 32] such that const * 2**f is an integer."""
+    const = float(const)
+    _low, _high = -32, 32
+    while _high - _low > 1:
+        _mid = (_high + _low) // 2
+        _value = const * (2.0**_mid)
+        if _value == int(_value):
+            _high = _mid
+        else:
+            _low = _mid
+    return _high
+
+
+class FixedVariable:
+    def __init__(
+        self,
+        low: float | Decimal,
+        high: float | Decimal,
+        step: float | Decimal,
+        latency: float | None = None,
+        hwconf=HWConfig(-1, -1, -1),
+        opr: str = 'new',
+        cost: float | None = None,
+        _from: tuple['FixedVariable', ...] = (),
+        _factor: float | Decimal = 1.0,
+        _data: Decimal | None = None,
+        _id: UUID | None = None,
+    ) -> None:
+        assert low <= high, f'low {low} must not be greater than high {high}'
+
+        if low == high:
+            opr = 'const'
+            _factor = 1.0
+            _from = ()
+
+        low, high, step = Decimal(low), Decimal(high), Decimal(step)
+        low, high = floor(low / step) * step, ceil(high / step) * step
+        self.low = low
+        self.high = high
+        self.step = step
+        self._factor = Decimal(_factor)
+        self._from: tuple[FixedVariable, ...] = _from
+        self.opr = opr
+        self._data = _data
+        self.id = _id or uuid4()
+        self.hwconf = hwconf
+
+        if opr == 'cadd':
+            assert _data is not None, 'cadd must have data'
+
+        if cost is None or latency is None:
+            _cost, _latency = self.get_cost_and_latency()
+        else:
+            _cost, _latency = cost, latency
+
+        self.latency = _latency
+        self.cost = _cost
+
+    def get_cost_and_latency(self):
+        if self.opr == 'const':
+            return 0.0, 0.0
+        if self.opr in ('vadd', 'cadd'):
+            adder_size = self.hwconf.adder_size
+            carry_size = self.hwconf.carry_size
+            latency_cutoff = self.hwconf.latency_cutoff
+
+            if self.opr == 'vadd':
+                assert len(self._from) == 2
+                v0, v1 = self._from
+                int0, int1 = v0.qint, v1.qint
+                base_latency = max(v0.latency, v1.latency)
+                dlat, _cost = cost_add(int0, int1, 0, False, adder_size, carry_size)
+            else:
+                assert len(self._from) == 1
+                assert self._data is not None, 'cadd must have data'
+                # int0 = self._from[0].qint
+                # int1 = QInterval(float(self._data), float(self._data), float(self.step))
+                _f = _const_f(self._data)
+                _cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -_f))) + _f
+                base_latency = self._from[0].latency
+                dlat = 0.0
+
+            _latency = dlat + base_latency
+            if latency_cutoff > 0 and ceil(_latency / latency_cutoff) > ceil(base_latency / latency_cutoff):
+                # Crossed the latency cutoff boundary
+                assert (
+                    dlat <= latency_cutoff
+                ), f'Latency of an atomic operation {dlat} is larger than the pipelining latency cutoff {latency_cutoff}'
+                _latency = ceil(base_latency / latency_cutoff) * latency_cutoff + dlat
+        elif self.opr in ('relu', 'wrap'):
+            assert len(self._from) == 1
+            _latency = self._from[0].latency
+            _cost = 0.0
+            # Assume LUT5 used here (2 fan-out per LUT6, thus *1/2)
+            if self._from[0]._factor < 0:
+                _cost += sum(self.kif) / 2
+            if self.opr == 'relu':
+                _cost += sum(self.kif) / 2
+
+        elif self.opr == 'new':
+            # new variable, no cost
+            _latency = 0.0
+            _cost = 0.0
+        else:
+            raise NotImplementedError(f'Operation {self.opr} is unknown')
+        return _cost, _latency
+
+    @property
+    def unscaled(self):
+        return self * (1 / self._factor)
+
+    @property
+    def qint(self) -> QInterval:
+        return QInterval(float(self.low), float(self.high), float(self.step))
+
+    @property
+    def kif(self) -> tuple[bool, int, int]:
+        if self.step == 0:
+            return False, 0, 0
+        f = -int(log2(self.step))
+        i = ceil(log2(max(-self.low, self.high + self.step)))
+        k = self.low < 0
+        return k, i, f
+
+    def __repr__(self) -> str:
+        if self._factor == 1:
+            return f'FixedVariable({self.low}, {self.high}, {self.step})'
+        return f'({self._factor}) FixedVariable({self.low}, {self.high}, {self.step})'
+
+    def __neg__(self):
+        return FixedVariable(
+            -self.high,
+            -self.low,
+            self.step,
+            _from=self._from,
+            _factor=-self._factor,
+            latency=self.latency,
+            cost=self.cost,
+            opr=self.opr,
+            _id=self.id,
+            _data=self._data,
+            hwconf=self.hwconf,
+        )
+
+    def __add__(self, other: 'FixedVariable|float|Decimal|int'):
+        if not isinstance(other, FixedVariable):
+            return self._const_add(other)
+        if other.high == other.low:
+            return self._const_add(other.low)
+        if self.high == self.low:
+            return other._const_add(self.low)
+
+        assert self.hwconf == other.hwconf, 'FixedVariable must have the same hwconf'
+
+        f0, f1 = self._factor, other._factor
+        if f0 < 0:
+            if f1 > 0:
+                return other + self
+            else:
+                return -((-self) + (-other))
+
+        return FixedVariable(
+            self.low + other.low,
+            self.high + other.high,
+            min(self.step, other.step),
+            _from=(self, other),
+            _factor=f0,
+            opr='vadd',
+            hwconf=self.hwconf,
+        )
+
+    def _const_add(self, other: float | Decimal):
+        if not isinstance(other, (int, float, Decimal)):
+            other = float(other)  # direct numpy to Decimal raises an error
+        other = Decimal(other)
+        if other == 0:
+            return self
+
+        if self.opr != 'cadd':
+            cstep = Decimal(2.0 ** -_const_f(other))
+
+            return FixedVariable(
+                self.low + other,
+                self.high + other,
+                min(self.step, cstep),
+                _from=(self,),
+                _factor=self._factor,
+                _data=other / self._factor,
+                opr='cadd',
+                hwconf=self.hwconf,
+            )
+
+        # Already a cadd: fold the new constant into the existing one
+        assert len(self._from) == 1
+        parent = self._from[0]
+        assert self._data is not None, 'cadd must have data'
+        sf = self._factor / parent._factor
+        other1 = (self._data * parent._factor) + other / sf
+        return (parent + other1) * sf
+
+    def __sub__(self, other: 'FixedVariable|int|float|Decimal'):
+        return self + (-other)
+
+    def __mul__(
+        self,
+        other: 'float|Decimal',
+    ):
+        if other == 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+
+        assert log2(abs(other)) % 1 == 0, 'Only support pow2 multiplication'
+
+        other = Decimal(other)
+
+        low = min(self.low * other, self.high * other)
+        high = max(self.low * other, self.high * other)
+        step = abs(self.step * other)
+        _factor = self._factor * other
+
+        return FixedVariable(
+            low,
+            high,
+            step,
+            _from=self._from,
+            _factor=_factor,
+            opr=self.opr,
+            latency=self.latency,
+            cost=self.cost,
+            _id=self.id,
+            _data=self._data,
+            hwconf=self.hwconf,
+        )
+
+    def __radd__(self, other: 'float|Decimal|int|FixedVariable'):
+        return self + other
+
+    def __rsub__(self, other: 'float|Decimal|int|FixedVariable'):
+        return (-self) + other
+
+    def __rmul__(self, other: 'float|Decimal|int|FixedVariable'):
+        return self * other
+
+    def relu(self, i: int | None = None, f: int | None = None, round_mode: str = 'TRN'):
+        round_mode = round_mode.upper()
+        assert round_mode in ('TRN', 'RND')
+
+        if self.opr == 'const':
+            val = self.low * (self.low > 0)
+            f = _const_f(val) if not f else f
+            step = Decimal(2) ** -f
+            i = ceil(log2(val + step)) if not i else i
+            eps = step / 2 if round_mode == 'RND' else 0
+            val = floor(val / step + eps) % Decimal(2) ** i * step
+            return FixedVariable(val, val, step, hwconf=self.hwconf)
+
+        step = max(Decimal(2) ** -f, self.step) if f is not None else self.step
+        if step > self.step and round_mode == 'RND':
+            return (self + step / 2).relu(i, f, 'TRN')
+        low = max(Decimal(0), self.low)
+        high = max(Decimal(0), self.high)
+        if i is not None:
+            _high = Decimal(2) ** i - step
+            if _high < high:
+                # overflows
+                low = Decimal(0)
+                high = _high
+        _factor = self._factor
+        return FixedVariable(
+            low,
+            high,
+            step,
+            _from=(self,),
+            _factor=abs(_factor),
+            opr='relu',
+            hwconf=self.hwconf,
+            cost=sum(self.kif) * (1 if _factor > 0 else 2),
+        )
+
+    def quantize(
+        self,
+        k: int | bool,
+        i: int,
+        f: int,
+        overflow_mode: str = 'WRAP',
+        round_mode: str = 'TRN',
+    ):
+        overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
+        assert overflow_mode in ('WRAP', 'SAT')
+        assert round_mode in ('TRN', 'RND')
+
+        _k, _i, _f = self.kif
+
+        if k >= _k and i >= _i and f >= _f:
+            return self
+
+        if f < _f and round_mode == 'RND':
+            return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
+
+        if self.low == self.high:
+            val = self.low
+            step = Decimal(2) ** -f
+            _high = Decimal(2) ** i
+            high, low = _high - step, -_high * k
+            val = (floor(val / step) * step - low) % (2 * _high) + low
+            return FixedVariable(val, val, step, hwconf=self.hwconf)
+
+        # TODO: corner cases exist (e.g., overflow to negative, or negative overflow to a high value);
+        # bit-exactness will be lost in these cases, but they should never happen (quantizers used in a weird way).
+        # Keeping this for now; change if absolutely necessary.
+        f = min(f, _f)
+        k = min(k, _k)
+        i = min(i, _i)
+
+        step = max(Decimal(2) ** -f, self.step)
+
+        low = -k * Decimal(2) ** i
+        high = Decimal(2) ** i - step
+        _low, _high = self.low, self.high
+
+        if _low >= low and _high <= high:
+            low, high = _low, _high
+
+        if low > high:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+
+        return FixedVariable(
+            low,
+            high,
+            step,
+            _from=(self,),
+            _factor=abs(self._factor),
+            opr='wrap' if overflow_mode == 'WRAP' else 'sat',
+            latency=self.latency,
+            hwconf=self.hwconf,
+        )
+
+    @classmethod
+    def from_kif(cls, k: int | bool, i: int, f: int, **kwargs):
+        step = Decimal(2) ** -f
+        _high = Decimal(2) ** i
+        low, high = -k * _high, _high - step
+        return cls(low, high, step, **kwargs)
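
To make the bookkeeping above concrete, a few worked examples. They use only definitions visible in this diff, and assume the sentinel HWConfig defaults are accepted by cost_add, as the default hwconf suggests.

    from decimal import Decimal

    from da4ml.trace.fixed_variable import FixedVariable, _const_f

    # _const_f binary-searches the smallest exponent f with const * 2**f integral
    assert _const_f(0.375) == 3   # 0.375 = 3 / 2**3
    assert _const_f(6) == -1      # 6 = 3 * 2**1

    # kif = (sign, integer, fraction) bits implied by (low, high, step)
    v = FixedVariable(-4, 3.5, Decimal('0.5'))
    assert v.kif == (True, 2, 1)  # 4 bits cover [-4, 3.5] in steps of 0.5

    w = v + v                     # a 'vadd' node; cost/latency come from cost_add
    print(w.qint)                 # QInterval(-8.0, 7.0, 0.5)
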