da4ml 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of da4ml has been flagged as potentially problematic.

Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +235 -73
  5. da4ml/cmvm/core/__init__.py +221 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +67 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +74 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +268 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +187 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +122 -0
  35. da4ml-0.2.1.dist-info/METADATA +65 -0
  36. da4ml-0.2.1.dist-info/RECORD +39 -0
  37. {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.2.dist-info/METADATA +0 -122
  48. da4ml-0.1.2.dist-info/RECORD +0 -18
  49. {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
  50. {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/top_level.txt +0 -0
da4ml/codegen/verilog/verilog_model.py (new file)
@@ -0,0 +1,268 @@
+ import ctypes
+ import os
+ import re
+ import shutil
+ import subprocess
+ import sys
+ from pathlib import Path
+ from uuid import uuid4
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ... import codegen
+ from ...cmvm.types import CascadedSolution, Solution, _minimal_kif
+ from ...trace.pipeline import to_pipeline
+ from . import comb_binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_binder_gen, pipeline_logic_gen
+
+
+ def get_io_kifs(sol: Solution | CascadedSolution):
+     inp_kifs = tuple(zip(*map(_minimal_kif, sol.inp_qint)))
+     out_kifs = tuple(zip(*map(_minimal_kif, sol.out_qint)))
+     return np.array(inp_kifs, np.int8), np.array(out_kifs, np.int8)
+
+
+ class VerilogModel:
+     def __init__(
+         self,
+         solution: Solution | CascadedSolution,
+         prj_name: str,
+         path: str | Path,
+         latency_cutoff: int = -1,
+         print_latency: bool = True,
+         part_name: str = 'xcvu13p-flga2577-2-e',
+         clock_period: int = 5,
+         clock_uncertainty: float = 0.1,
+         io_delay_minmax: tuple[float, float] = (0.2, 0.4),
+         register_layers: int = 1,
+     ):
+         self._solution = solution
+         self._path = Path(path)
+         self._prj_name = prj_name
+         self._latency_cutoff = latency_cutoff
+         self._print_latency = print_latency
+         self.__src_root = Path(codegen.__file__).parent
+         self._part_name = part_name
+         self._clock_period = clock_period
+         self._clock_uncertainty = clock_uncertainty
+         self._io_delay_minmax = io_delay_minmax
+         self._register_layers = register_layers
+
+         self._pipe = solution if isinstance(solution, CascadedSolution) else None
+         if latency_cutoff > 0 and self._pipe is None:
+             assert isinstance(solution, Solution)
+             self._pipe = to_pipeline(solution, latency_cutoff, verbose=False)
+
+         if self._pipe is not None:
+             # get actual latency cutoff
+             latency_cutoff = int(max(max(st.latency) / (i + 1) for i, st in enumerate(self._pipe.solutions)))
+             self._latency_cutoff = latency_cutoff
+
+         self._lib = None
+
+     def write(self):
+         self._path.mkdir(parents=True, exist_ok=True)
+         if self._pipe is not None: # Pipeline
+             # Main logic
+             codes = pipeline_logic_gen(self._pipe, self._prj_name, self._print_latency, register_layers=self._register_layers)
+             for k, v in codes.items():
+                 with open(self._path / f'{k}.v', 'w') as f:
+                     f.write(v)
+
+             # Build script
+             with open(self.__src_root / 'verilog/source/build_prj.tcl') as f:
+                 tcl = f.read()
+             tcl = tcl.replace('${DEVICE}', self._part_name)
+             tcl = tcl.replace('${PROJECT_NAME}', self._prj_name)
+             with open(self._path / 'build_prj.tcl', 'w') as f:
+                 f.write(tcl)
+
+             # XDC
+             with open(self.__src_root / 'verilog/source/template.xdc') as f:
+                 xdc = f.read()
+             xdc = xdc.replace('${CLOCK_PERIOD}', str(self._clock_period))
+             xdc = xdc.replace('${UNCERTAINITY_SETUP}', str(self._clock_uncertainty))
+             xdc = xdc.replace('${UNCERTAINITY_HOLD}', str(self._clock_uncertainty))
+             xdc = xdc.replace('${DELAY_MAX}', str(self._io_delay_minmax[0]))
+             xdc = xdc.replace('${DELAY_MIN}', str(self._io_delay_minmax[1]))
+             with open(self._path / f'{self._prj_name}.xdc', 'w') as f:
+                 f.write(xdc)
+
+             # C++ binder w/ verilog wrapper for uniform bw
+             binder = pipeline_binder_gen(self._pipe, f'{self._prj_name}_wrapper', 1, self._register_layers)
+
+             # Verilog IO wrapper (non-uniform bw to uniform one, clk passthrough)
+             io_wrapper = generate_io_wrapper(self._pipe, self._prj_name, True)
+
+             self._pipe.save(self._path / 'pipeline.json')
+         else: # Comb
+             assert isinstance(self._solution, Solution)
+
+             # Main logic
+             code = comb_logic_gen(self._solution, self._prj_name, self._print_latency, '`timescale 1ns/1ps')
+             with open(self._path / f'{self._prj_name}.v', 'w') as f:
+                 f.write(code)
+
+             # Verilog IO wrapper (non-uniform bw to uniform one, no clk)
+             io_wrapper = generate_io_wrapper(self._solution, self._prj_name, False)
+             binder = comb_binder_gen(self._solution, f'{self._prj_name}_wrapper')
+
+         with open(self._path / f'{self._prj_name}_wrapper.v', 'w') as f:
+             f.write(io_wrapper)
+         with open(self._path / f'{self._prj_name}_wrapper_binder.cc', 'w') as f:
+             f.write(binder)
+
+         # Common resource copy
+         shutil.copy(self.__src_root / 'verilog/source/shift_adder.v', self._path)
+         shutil.copy(self.__src_root / 'verilog/source/build_binder.mk', self._path)
+         shutil.copy(self.__src_root / 'verilog/source/ioutils.hh', self._path)
+         self._solution.save(self._path / 'model.json')
+         with open(self._path / 'misc.json', 'w') as f:
+             f.write(f'{{"cost": {self._solution.cost}}}')
+
+     def _compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
+         """Same as `compile`, but does not write the generated sources first.
+
+         Parameters
+         ----------
+         verbose : bool, optional
+             Verbose output, by default False
+         openmp : bool, optional
+             Enable openmp, by default True
+         o3 : bool, optional
+             Turn on -O3 flag, by default False
+         clean : bool, optional
+             Remove obsolete shared object files, by default True
+
+         Raises
+         ------
+         RuntimeError
+             If compilation fails
+         """
+
+         self._uuid = str(uuid4())
+         args = ['make', '-f', 'build_binder.mk']
+         env = os.environ.copy()
+         env['VM_PREFIX'] = f'{self._prj_name}_wrapper'
+         env['STAMP'] = self._uuid
+         env['EXTRA_CXXFLAGS'] = '-fopenmp' if openmp else ''
+         if o3:
+             args.append('fast')
+
+         if clean:
+             m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
+             for p in self._path.iterdir():
+                 if not p.is_dir() and m.match(p.name):
+                     p.unlink()
+
+         try:
+             r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
+         except subprocess.CalledProcessError as e:
+             print(e.stderr.decode(), file=sys.stderr)
+             print(e.stdout.decode(), file=sys.stdout)
+             raise RuntimeError('Compilation failed!!') from e
+         if r.returncode != 0:
+             print(r.stderr.decode(), file=sys.stderr)
+             print(r.stdout.decode(), file=sys.stderr)
+             raise RuntimeError('Compilation failed!!')
+
+         self._load_lib(self._uuid)
+
+     def _load_lib(self, uuid: str | None = None):
+         uuid = uuid if uuid is not None else self._uuid
+         self._uuid = uuid
+         lib_path = self._path / f'lib{self._prj_name}_wrapper_{uuid}.so'
+         if not lib_path.exists():
+             raise RuntimeError(f'Library {lib_path} does not exist')
+         self._lib = ctypes.CDLL(str(lib_path))
+
+     def compile(self, verbose=False, openmp=True, o3: bool = False):
+         """Compile the generated code to an emulator for logic simulation.
+
+         Parameters
+         ----------
+         verbose : bool, optional
+             Verbose output, by default False
+         openmp : bool, optional
+             Enable openmp, by default True
+         o3 : bool, optional
+             Turn on -O3 flag, by default False
+
+         Raises
+         ------
+         RuntimeError
+             If compilation fails
+         """
+         self.write()
+         self._compile(verbose=verbose, openmp=openmp, o3=o3)
+         self._load_lib()
+
+     def predict(self, data: NDArray[np.floating]):
+         """Run the model on the input data.
+
+         Parameters
+         ----------
+         data : NDArray[np.floating]
+             Input data to the model. The shape is ignored, and the number of samples is
+             determined by the size of the data.
+
+         Returns
+         -------
+         NDArray[np.float64]
+             Output of the model in shape (n_samples, output_size).
+         """
+         assert self._lib is not None, 'Library not loaded, call .compile() first.'
+         inp_size, out_size = self._solution.shape
+
+         assert data.size % inp_size == 0, f'Input size {data.size} is not divisible by {inp_size}'
+         n_sample = data.size // inp_size
+
+         kifs_in, kifs_out = get_io_kifs(self._solution)
+         k_in, i_in, f_in = map(np.max, kifs_in)
+         k_out, i_out, f_out = map(np.max, kifs_out)
+         assert k_in + i_in + f_in <= 32, "Padded inp bw doesn't fit in int32. Emulation not supported"
+         assert k_out + i_out + f_out <= 32, "Padded out bw doesn't fit in int32. Emulation not supported"
+
+         inp_data = np.empty(n_sample * inp_size, dtype=np.int32)
+         out_data = np.empty(n_sample * out_size, dtype=np.int32)
+
+         # Convert to int32 matching the LSB position
+         inp_data[:] = data.ravel() * 2.0 ** np.max(f_in)
+
+         inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
+         out_buf = out_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
+         self._lib.inference(inp_buf, out_buf, n_sample)
+
+         # Unscale the output int32 to recover fp values
+         k, i, f = np.max(k_out), np.max(i_out), np.max(f_out)
+         a, b, c = 2.0 ** (k + i + f), 2.0 ** (i + f), 2.0**-f
+         return ((out_data.reshape(n_sample, out_size) + b) % a - b) * c
+
+     def __repr__(self):
+         inp_size, out_size = self._solution.shape
+         cost = round(self._solution.cost)
+         kifs_in, kifs_out = get_io_kifs(self._solution)
+         in_bits, out_bits = np.sum(kifs_in), np.sum(kifs_out)
+         if self._pipe is not None:
+             n_stage = len(self._pipe[0])
+             delay_suffix = '' if self._register_layers == 1 else f'x {self._register_layers} '
+             lat_cutoff = self._latency_cutoff
+             reg_bits = self._pipe.reg_bits
+             spec = f"""Top Module: {self._prj_name}\n====================
+ {inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)
+ {n_stage} {delay_suffix}stages @ max_delay={lat_cutoff}
+ Estimated cost: {cost} LUTs, {reg_bits} FFs"""
+
+         else:
+             spec = f"""Top Module: {self._prj_name}\n====================
+ {inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)
+ combinational @ delay={self._solution.latency}
+ Estimated cost: {cost} LUTs"""
+
+         is_compiled = self._lib is not None
+         if is_compiled:
+             openmp = 'with OpenMP' if self._lib.openmp_enabled else '' # type: ignore
+             spec += f'\nEmulator is compiled {openmp} ({self._uuid[-12:]})'
+         else:
+             spec += '\nEmulator is **not compiled**'
+         return spec
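
For context, a minimal usage sketch of the VerilogModel API added above, assuming a Solution or CascadedSolution has already been produced by da4ml's CMVM/tracing layer (how that object is obtained is not shown in this hunk); the project name and build path are hypothetical examples:

import numpy as np
from da4ml.cmvm.types import CascadedSolution, Solution
from da4ml.codegen.verilog.verilog_model import VerilogModel

def build_and_emulate(solution: Solution | CascadedSolution, test_in: np.ndarray) -> np.ndarray:
    """Emit Verilog for `solution`, compile the C++ emulator, and run it on `test_in`."""
    model = VerilogModel(
        solution,
        prj_name='example_top',   # hypothetical project name
        path='build/example_top', # hypothetical output directory
        latency_cutoff=5,         # >0: pipeline a combinational Solution via to_pipeline
        clock_period=5,
        register_layers=1,
    )
    model.compile()               # compile() calls write() first, then builds and loads the .so
    print(model)                  # __repr__ summarizes I/O widths, cost, and compile state
    return model.predict(test_in) # (n_samples, output_size), bit-accurate emulation

The sketch relies only on the constructor parameters and the write/compile/predict methods shown in the diff; toolchain requirements (make, the bundled build_binder.mk flow) are those implied by _compile.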
da4ml/trace/__init__.py (new file)
@@ -0,0 +1,6 @@
+ from .fixed_variable import HWConfig
+ from .fixed_variable_array import FixedVariableArray
+ from .pipeline import to_pipeline
+ from .tracer import comb_trace
+
+ __all__ = ['to_pipeline', 'comb_trace', 'FixedVariableArray', 'HWConfig']
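
A small sketch of the re-exported names whose signatures are visible elsewhere in this diff; the HWConfig values are arbitrary examples, and the commented to_pipeline call mirrors how VerilogModel invokes it:

from da4ml.trace import HWConfig, to_pipeline

# HWConfig (a NamedTuple defined in fixed_variable.py below) bundles the adder/carry
# chunk sizes and the pipelining latency cutoff used for cost/latency bookkeeping.
hwconf = HWConfig(adder_size=8, carry_size=8, latency_cutoff=5.0)

# to_pipeline(solution, latency_cutoff, verbose=False) splits a combinational Solution
# into a CascadedSolution whose stages fit within the cutoff; VerilogModel calls it this
# way when latency_cutoff > 0.
# pipe = to_pipeline(solution, 5, verbose=False)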
da4ml/trace/fixed_variable.py (new file)
@@ -0,0 +1,358 @@
+ from decimal import Decimal
+ from math import ceil, floor, log2
+ from typing import NamedTuple
+ from uuid import UUID, uuid4
+
+ from ..cmvm.core import cost_add
+ from ..cmvm.types import QInterval
+
+
+ class HWConfig(NamedTuple):
+     adder_size: int
+     carry_size: int
+     latency_cutoff: float
+
+
+ def _const_f(const: float | Decimal):
+     const = float(const)
+     _low, _high = -32, 32
+     while _high - _low > 1:
+         _mid = (_high + _low) // 2
+         _value = const * (2.0**_mid)
+         if _value == int(_value):
+             _high = _mid
+         else:
+             _low = _mid
+     return _high
+
+
+ class FixedVariable:
+     def __init__(
+         self,
+         low: float | Decimal,
+         high: float | Decimal,
+         step: float | Decimal,
+         latency: float | None = None,
+         hwconf=HWConfig(-1, -1, -1),
+         opr: str = 'new',
+         cost: float | None = None,
+         _from: tuple['FixedVariable', ...] = (),
+         _factor: float | Decimal = 1.0,
+         _data: Decimal | None = None,
+         _id: UUID | None = None,
+     ) -> None:
+         assert low <= high, f'low {low} must be less than high {high}'
+
+         if low == high:
+             opr = 'const'
+             _factor = 1.0
+             _from = ()
+
+         low, high, step = Decimal(low), Decimal(high), Decimal(step)
+         low, high = floor(low / step) * step, ceil(high / step) * step
+         self.low = low
+         self.high = high
+         self.step = step
+         self._factor = Decimal(_factor)
+         self._from: tuple[FixedVariable, ...] = _from
+         opr = opr
+         self.opr = opr
+         self._data = _data
+         self.id = _id or uuid4()
+         self.hwconf = hwconf
+
+         if opr == 'cadd':
+             assert _data is not None, 'cadd must have data'
+
+         if cost is None or latency is None:
+             _cost, _latency = self.get_cost_and_latency()
+         else:
+             _cost, _latency = cost, latency
+
+         self.latency = _latency
+         self.cost = _cost
+
+     def get_cost_and_latency(self):
+         if self.opr == 'const':
+             return 0.0, 0.0
+         if self.opr in ('vadd', 'cadd'):
+             adder_size = self.hwconf.adder_size
+             carry_size = self.hwconf.carry_size
+             latency_cutoff = self.hwconf.latency_cutoff
+
+             if self.opr == 'vadd':
+                 assert len(self._from) == 2
+                 v0, v1 = self._from
+                 int0, int1 = v0.qint, v1.qint
+                 base_latency = max(v0.latency, v1.latency)
+                 dlat, _cost = cost_add(int0, int1, 0, False, adder_size, carry_size)
+             else:
+                 assert len(self._from) == 1
+                 assert self._data is not None, 'cadd must have data'
+                 # int0 = self._from[0].qint
+                 # int1 = QInterval(float(self._data), float(self._data), float(self.step))
+                 _f = _const_f(self._data)
+                 _cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -_f))) + _f
+                 base_latency = self._from[0].latency
+                 dlat = 0.0
+
+             _latency = dlat + base_latency
+             if latency_cutoff > 0 and ceil(_latency / latency_cutoff) > ceil(base_latency / latency_cutoff):
+                 # Crossed the latency cutoff boundary
+                 assert (
+                     dlat <= latency_cutoff
+                 ), f'Latency of an atomic operation {dlat} is larger than the pipelining latency cutoff {latency_cutoff}'
+                 _latency = ceil(base_latency / latency_cutoff) * latency_cutoff + dlat
+         elif self.opr in ('relu', 'wrap'):
+             assert len(self._from) == 1
+             _latency = self._from[0].latency
+             _cost = 0.0
+             # Assume LUT5 used here (2 fan-out per LUT6, thus *1/2)
+             if self._from[0]._factor < 0:
+                 _cost += sum(self.kif) / 2
+             if self.opr == 'relu':
+                 _cost += sum(self.kif) / 2
+
+         elif self.opr == 'new':
+             # new variable, no cost
+             _latency = 0.0
+             _cost = 0.0
+         else:
+             raise NotImplementedError(f'Operation {self.opr} is unknown')
+         return _cost, _latency
+
+     @property
+     def unscaled(self):
+         return self * (1 / self._factor)
+
+     @property
+     def qint(self) -> QInterval:
+         return QInterval(float(self.low), float(self.high), float(self.step))
+
+     @property
+     def kif(self) -> tuple[bool, int, int]:
+         if self.step == 0:
+             return False, 0, 0
+         f = -int(log2(self.step))
+         i = ceil(log2(max(-self.low, self.high + self.step)))
+         k = self.low < 0
+         return k, i, f
+
+     def __repr__(self) -> str:
+         if self._factor == 1:
+             return f'FixedVariable({self.low}, {self.high}, {self.step})'
+         return f'({self._factor}) FixedVariable({self.low}, {self.high}, {self.step})'
+
+     def __neg__(self):
+         return FixedVariable(
+             -self.high,
+             -self.low,
+             self.step,
+             _from=self._from,
+             _factor=-self._factor,
+             latency=self.latency,
+             cost=self.cost,
+             opr=self.opr,
+             _id=self.id,
+             _data=self._data,
+             hwconf=self.hwconf,
+         )
+
+     def __add__(self, other: 'FixedVariable|float|Decimal|int'):
+         if not isinstance(other, FixedVariable):
+             return self._const_add(other)
+         if other.high == other.low:
+             return self._const_add(other.low)
+         if self.high == self.low:
+             return other._const_add(self.low)
+
+         assert self.hwconf == other.hwconf, 'FixedVariable must have the same hwconf'
+
+         f0, f1 = self._factor, other._factor
+         if f0 < 0:
+             if f1 > 0:
+                 return other + self
+             else:
+                 return -((-self) + (-other))
+
+         return FixedVariable(
+             self.low + other.low,
+             self.high + other.high,
+             min(self.step, other.step),
+             _from=(self, other),
+             _factor=f0,
+             opr='vadd',
+             hwconf=self.hwconf,
+         )
+
+     def _const_add(self, other: float | Decimal):
+         if not isinstance(other, (int, float, Decimal)):
+             other = float(other) # direct numpy to decimal raises error
+         other = Decimal(other)
+         if other == 0:
+             return self
+
+         if self.opr != 'cadd':
+             cstep = Decimal(2.0 ** -_const_f(other))
+
+             return FixedVariable(
+                 self.low + other,
+                 self.high + other,
+                 min(self.step, cstep),
+                 _from=(self,),
+                 _factor=self._factor,
+                 _data=other / self._factor,
+                 opr='cadd',
+                 hwconf=self.hwconf,
+             )
+
+         # cadd, combine the constant
+         assert len(self._from) == 1
+         parent = self._from[0]
+         assert self._data is not None, 'cadd must have data'
+         sf = self._factor / parent._factor
+         other1 = (self._data * parent._factor) + other / sf
+         return (parent + other1) * sf
+
+     def __sub__(self, other: 'FixedVariable|int|float|Decimal'):
+         return self + (-other)
+
+     def __mul__(
+         self,
+         other: 'float|Decimal',
+     ):
+         if other == 0:
+             return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+
+         assert log2(abs(other)) % 1 == 0, 'Only support pow2 multiplication'
+
+         other = Decimal(other)
+
+         low = min(self.low * other, self.high * other)
+         high = max(self.low * other, self.high * other)
+         step = abs(self.step * other)
+         _factor = self._factor * other
+
+         return FixedVariable(
+             low,
+             high,
+             step,
+             _from=self._from,
+             _factor=_factor,
+             opr=self.opr,
+             latency=self.latency,
+             cost=self.cost,
+             _id=self.id,
+             _data=self._data,
+             hwconf=self.hwconf,
+         )
+
+     def __radd__(self, other: 'float|Decimal|int|FixedVariable'):
+         return self + other
+
+     def __rsub__(self, other: 'float|Decimal|int|FixedVariable'):
+         return (-self) + other
+
+     def __rmul__(self, other: 'float|Decimal|int|FixedVariable'):
+         return self * other
+
+     def relu(self, i: int | None = None, f: int | None = None, round_mode: str = 'TRN'):
+         round_mode = round_mode.upper()
+         assert round_mode in ('TRN', 'RND')
+
+         if self.opr == 'const':
+             val = self.low * (self.low > 0)
+             f = _const_f(val) if not f else f
+             step = Decimal(2) ** -f
+             i = ceil(log2(val + step)) if not i else i
+             eps = step / 2 if round_mode == 'RND' else 0
+             val = (floor(val / step + eps) * step) % (Decimal(2) ** i)
+             return FixedVariable(val, val, step, hwconf=self.hwconf)
+
+         step = max(Decimal(2) ** -f, self.step) if f is not None else self.step
+         if step > self.step and round_mode == 'RND':
+             return (self + step / 2).relu(i, f, 'TRN')
+         low = max(Decimal(0), self.low)
+         high = max(Decimal(0), self.high)
+         if i is not None:
+             _high = Decimal(2) ** i - step
+             if _high < high:
+                 # overflows
+                 low = Decimal(0)
+                 high = _high
+         _factor = self._factor
+         return FixedVariable(
+             low,
+             high,
+             step,
+             _from=(self,),
+             _factor=abs(_factor),
+             opr='relu',
+             hwconf=self.hwconf,
+             cost=sum(self.kif) * (1 if _factor > 0 else 2),
+         )
+
+     def quantize(
+         self,
+         k: int | bool,
+         i: int,
+         f: int,
+         overflow_mode: str = 'WRAP',
+         round_mode: str = 'TRN',
+     ):
+         overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
+         assert overflow_mode in ('WRAP', 'SAT')
+         assert round_mode in ('TRN', 'RND')
+
+         _k, _i, _f = self.kif
+
+         if k >= _k and i >= _i and f >= _f:
+             return self
+
+         if f < _f and round_mode == 'RND':
+             return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
+
+         if self.low == self.high:
+             val = self.low
+             step = Decimal(2) ** -f
+             _high = Decimal(2) ** i
+             high, low = _high - step, -_high * k
+             val = (floor(val / step) * step - low) % (2 * _high) + low
+             return FixedVariable(val, val, step, hwconf=self.hwconf)
+
+         # TODO: corner cases exist (e.g., overflow to negative, or negative overflow to high value)
+         # bit-exactness will be lost in these cases, but they should never happen unless quantizers are used in a weird way
+         # Keeping this for now; change if absolutely necessary
+         f = min(f, _f)
+         k = min(k, _k) if i >= _i else k
+         i = min(i, _i)
+
+         step = max(Decimal(2) ** -f, self.step)
+
+         low = -k * Decimal(2) ** i
+         high = Decimal(2) ** i - step
+         _low, _high = self.low, self.high
+
+         if _low >= low and _high <= high:
+             low, high = _low, _high
+
+         if low > high:
+             return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+
+         return FixedVariable(
+             low,
+             high,
+             step,
+             _from=(self,),
+             _factor=abs(self._factor),
+             opr='wrap' if overflow_mode == 'WRAP' else 'sat',
+             latency=self.latency,
+             hwconf=self.hwconf,
+         )
+
+     @classmethod
+     def from_kif(cls, k: int | bool, i: int, f: int, **kwargs):
+         step = Decimal(2) ** -f
+         _high = Decimal(2) ** i
+         low, high = k * _high, _high - step
+         return cls(low, high, step, **kwargs)
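
As an illustration of the bookkeeping FixedVariable performs, the following sketch exercises only the constructors and operators defined above; the HWConfig values are arbitrary, and the actual cost/latency figures depend on da4ml.cmvm.core.cost_add:

from decimal import Decimal
from da4ml.trace.fixed_variable import FixedVariable, HWConfig

hwconf = HWConfig(adder_size=8, carry_size=8, latency_cutoff=-1)  # arbitrary example values

# Two fresh fixed-point variables: from_kif derives (low, high, step) from a
# (sign, integer, fraction) bit split; the plain constructor takes the range directly.
a = FixedVariable.from_kif(k=False, i=3, f=4, hwconf=hwconf)
b = FixedVariable(Decimal('0'), Decimal('7.5'), Decimal('0.5'), hwconf=hwconf)

c = a + b          # 'vadd': range and step propagate, cost/latency come from cost_add
d = (c * 0.5) - 1  # only power-of-two scaling is allowed; the constant folds into a 'cadd'
e = d.relu(f=3)    # ReLU with an optional output precision
q = e.quantize(k=False, i=4, f=2, overflow_mode='WRAP', round_mode='RND')

print(c.kif, c.cost, c.latency)
print(q)           # repr shows (low, high, step), prefixed by _factor when it is not 1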