da4ml 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml might be problematic. Click here for more details.
- da4ml/__init__.py +16 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +3 -34
- da4ml/cmvm/api.py +239 -73
- da4ml/cmvm/core/__init__.py +222 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +569 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +11 -0
- da4ml/codegen/cpp/__init__.py +3 -0
- da4ml/codegen/cpp/cpp_codegen.py +148 -0
- da4ml/codegen/cpp/source/vitis.h +30 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
- da4ml/codegen/verilog/__init__.py +13 -0
- da4ml/codegen/verilog/comb.py +146 -0
- da4ml/codegen/verilog/io_wrapper.py +255 -0
- da4ml/codegen/verilog/pipeline.py +49 -0
- da4ml/codegen/verilog/source/build_binder.mk +27 -0
- da4ml/codegen/verilog/source/build_prj.tcl +75 -0
- da4ml/codegen/verilog/source/ioutils.hh +117 -0
- da4ml/codegen/verilog/source/shift_adder.v +56 -0
- da4ml/codegen/verilog/source/template.xdc +29 -0
- da4ml/codegen/verilog/verilog_model.py +265 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +358 -0
- da4ml/trace/fixed_variable_array.py +177 -0
- da4ml/trace/ops/__init__.py +55 -0
- da4ml/trace/ops/conv_utils.py +104 -0
- da4ml/trace/ops/einsum_utils.py +299 -0
- da4ml/trace/pipeline.py +155 -0
- da4ml/trace/tracer.py +120 -0
- da4ml-0.2.0.dist-info/METADATA +65 -0
- da4ml-0.2.0.dist-info/RECORD +39 -0
- {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
- da4ml/cmvm/balanced_reduction.py +0 -46
- da4ml/cmvm/cmvm.py +0 -328
- da4ml/cmvm/codegen.py +0 -159
- da4ml/cmvm/csd.py +0 -73
- da4ml/cmvm/fixed_variable.py +0 -205
- da4ml/cmvm/graph_compile.py +0 -85
- da4ml/cmvm/nb_fixed_precision.py +0 -98
- da4ml/cmvm/scoring.py +0 -55
- da4ml/cmvm/utils.py +0 -5
- da4ml-0.1.1.dist-info/METADATA +0 -121
- da4ml-0.1.1.dist-info/RECORD +0 -18
- {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info/licenses}/LICENSE +0 -0
- {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
import ctypes
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import shutil
|
|
5
|
+
import subprocess
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from uuid import uuid4
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from numpy.typing import NDArray
|
|
12
|
+
|
|
13
|
+
from ... import codegen
|
|
14
|
+
from ...cmvm.types import CascadedSolution, Solution, _minimal_kif
|
|
15
|
+
from ...trace.pipeline import to_pipeline
|
|
16
|
+
from . import comb_binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_binder_gen, pipeline_logic_gen
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_io_kifs(sol: Solution | CascadedSolution):
    """Gather the minimal kif description of a solution's inputs and outputs.

    ``_minimal_kif`` is applied to every entry of ``sol.inp_qint`` and
    ``sol.out_qint``; the per-element results are transposed so that each
    returned array holds one row per kif component and one column per element.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(inp_kifs, out_kifs)``, both of dtype ``int8``.
    """

    def _transposed(qints):
        per_element = [_minimal_kif(q) for q in qints]
        return np.array(tuple(zip(*per_element)), np.int8)

    return _transposed(sol.inp_qint), _transposed(sol.out_qint)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class VerilogModel:
|
|
26
|
+
def __init__(
    self,
    solution: Solution | CascadedSolution,
    prj_name: str,
    path: str | Path,
    latency_cutoff: int = -1,
    print_latency: bool = True,
    part_name: str = 'xcvu13p-flga2577-2-e',
    clock_period: int = 5,
    clock_uncertainty: float = 0.1,
    io_delay_minmax: tuple[float, float] = (0.2, 0.4),
):
    """Set up a Verilog project description for ``solution``.

    Parameters
    ----------
    solution : Solution | CascadedSolution
        Model to emit. A ``CascadedSolution`` is always treated as a pipeline.
    prj_name : str
        Project name, used for the generated file names.
    path : str | Path
        Output directory of the project.
    latency_cutoff : int, optional
        When positive and ``solution`` is a plain ``Solution``, the solution is
        converted with ``to_pipeline`` using this cutoff. By default -1.
    print_latency : bool, optional
        Forwarded to the logic generators, by default True
    part_name : str, optional
        Substituted for ``${DEVICE}`` in the build script.
    clock_period, clock_uncertainty, io_delay_minmax
        Values substituted into the XDC template.
    """
    # Project identity and locations
    self._solution = solution
    self._prj_name = prj_name
    self._path = Path(path)
    self.__src_root = Path(codegen.__file__).parent
    self._print_latency = print_latency

    # Build / timing-constraint settings
    self._part_name = part_name
    self._clock_period = clock_period
    self._clock_uncertainty = clock_uncertainty
    self._io_delay_minmax = io_delay_minmax

    # Decide whether a pipelined representation is used
    if isinstance(solution, CascadedSolution):
        pipe = solution
    elif latency_cutoff > 0:
        assert isinstance(solution, Solution)
        pipe = to_pipeline(solution, latency_cutoff, verbose=False)
    else:
        pipe = None
    self._pipe = pipe

    if pipe is not None:
        # get actual latency cutoff
        per_stage = [max(stage.latency) / (idx + 1) for idx, stage in enumerate(pipe.solutions)]
        latency_cutoff = int(max(per_stage))
    self._latency_cutoff = latency_cutoff

    # Handle of the compiled emulator; populated by _load_lib
    self._lib = None
|
|
60
|
+
|
|
61
|
+
def write(self):
    """Write the whole project into ``self._path``.

    For a pipelined model this emits one ``.v`` file per generated module,
    ``build_prj.tcl``, ``<prj_name>.xdc`` and ``pipeline.json``; for a
    combinational model it emits ``<prj_name>.v``. In both cases the IO
    wrapper, the C++ binder, the shared resources (``shift_adder.v``,
    ``build_binder.mk``, ``ioutils.hh``), ``model.json`` and ``misc.json``
    are written as well. The output directory is created if missing.
    """
    out_dir = self._path
    out_dir.mkdir(parents=True, exist_ok=True)
    src_dir = self.__src_root / 'verilog/source'
    wrapper_name = f'{self._prj_name}_wrapper'

    if self._pipe is not None:  # Pipeline
        # Main logic
        codes = pipeline_logic_gen(self._pipe, self._prj_name, self._print_latency)
        for name, code in codes.items():
            with open(out_dir / f'{name}.v', 'w') as fh:
                fh.write(code)

        # Build script
        with open(src_dir / 'build_prj.tcl') as fh:
            tcl = fh.read()
        tcl = tcl.replace('${DEVICE}', self._part_name)
        tcl = tcl.replace('${PROJECT_NAME}', self._prj_name)
        with open(out_dir / 'build_prj.tcl', 'w') as fh:
            fh.write(tcl)

        # XDC
        # io_delay_minmax is ordered (min, max): index 0 feeds DELAY_MIN and
        # index 1 feeds DELAY_MAX (the two were previously swapped).
        delay_min, delay_max = self._io_delay_minmax[0], self._io_delay_minmax[1]
        with open(src_dir / 'template.xdc') as fh:
            xdc = fh.read()
        substitutions = {
            '${CLOCK_PERIOD}': self._clock_period,
            '${UNCERTAINITY_SETUP}': self._clock_uncertainty,
            '${UNCERTAINITY_HOLD}': self._clock_uncertainty,
            '${DELAY_MAX}': delay_max,
            '${DELAY_MIN}': delay_min,
        }
        for placeholder, value in substitutions.items():
            xdc = xdc.replace(placeholder, str(value))
        with open(out_dir / f'{self._prj_name}.xdc', 'w') as fh:
            fh.write(xdc)

        # C++ binder
        binder = pipeline_binder_gen(self._pipe, wrapper_name, 1)

        # Verilog IO wrapper (non-uniform bw to uniform one, clk passthrough)
        io_wrapper = generate_io_wrapper(self._pipe, self._prj_name, True)

        self._pipe.save(out_dir / 'pipeline.json')
    else:  # Comb
        assert isinstance(self._solution, Solution)

        # Main logic
        code = comb_logic_gen(self._solution, self._prj_name, self._print_latency, '`timescale 1ns/1ps')
        with open(out_dir / f'{self._prj_name}.v', 'w') as fh:
            fh.write(code)

        # Verilog IO wrapper (non-uniform bw to uniform one, no clk)
        io_wrapper = generate_io_wrapper(self._solution, self._prj_name, False)
        binder = comb_binder_gen(self._solution, wrapper_name)

    with open(out_dir / f'{wrapper_name}.v', 'w') as fh:
        fh.write(io_wrapper)
    with open(out_dir / f'{wrapper_name}_binder.cc', 'w') as fh:
        fh.write(binder)

    # Common resource copy
    for resource in ('shift_adder.v', 'build_binder.mk', 'ioutils.hh'):
        shutil.copy(src_dir / resource, out_dir)
    self._solution.save(out_dir / 'model.json')
    with open(out_dir / 'misc.json', 'w') as fh:
        fh.write(f'{{"cost": {self._solution.cost}}}')
|
|
120
|
+
|
|
121
|
+
def _compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
    """Build the emulator from files already on disk and load it.

    Same as ``compile``, except that the project files are not (re)written
    first. ``make -f build_binder.mk`` is run inside the project directory
    with a fresh UUID as build stamp, then the resulting library is loaded.

    Parameters
    ----------
    verbose : bool, optional
        Verbose output, by default False. When True the build output goes
        straight to the terminal instead of being captured.
    openmp : bool, optional
        Enable openmp, by default True
    o3 : bool, optional
        Request the ``fast`` make target, by default False
    clean : bool, optional
        Remove obsolete shared object files, by default True

    Raises
    ------
    RuntimeError
        If compilation fails
    """

    self._uuid = str(uuid4())
    args = ['make', '-f', 'build_binder.mk']
    if o3:
        args.append('fast')

    env = os.environ.copy()
    env['VM_PREFIX'] = f'{self._prj_name}_wrapper'
    env['STAMP'] = self._uuid
    env['EXTRA_CXXFLAGS'] = '-fopenmp' if openmp else ''

    if clean:
        # Shared objects of earlier builds carry their own UUID stamp in the name
        stale = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
        for p in self._path.iterdir():
            if not p.is_dir() and stale.match(p.name):
                p.unlink()

    try:
        # check=True turns every non-zero exit status into CalledProcessError
        subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
    except subprocess.CalledProcessError as e:
        # When verbose is set nothing is captured and both streams are None;
        # decoding them unconditionally would mask the real failure.
        if e.stderr is not None:
            print(e.stderr.decode(errors='replace'), file=sys.stderr)
        if e.stdout is not None:
            print(e.stdout.decode(errors='replace'), file=sys.stdout)
        raise RuntimeError('Compilation failed!!') from e

    self._load_lib(self._uuid)
|
|
168
|
+
|
|
169
|
+
def _load_lib(self, uuid: str | None = None):
    """Load the emulator shared object identified by ``uuid``.

    Parameters
    ----------
    uuid : str | None, optional
        Build stamp of the library; when None the stamp stored on the
        instance is reused. The stamp in use is stored back on the instance.

    Raises
    ------
    RuntimeError
        If the expected library file does not exist
    """
    if uuid is None:
        uuid = self._uuid
    self._uuid = uuid

    lib_path = self._path / f'lib{self._prj_name}_wrapper_{uuid}.so'
    if lib_path.exists():
        self._lib = ctypes.CDLL(str(lib_path))
        return
    raise RuntimeError(f'Library {lib_path} does not exist')
|
|
176
|
+
|
|
177
|
+
def compile(self, verbose=False, openmp=True, o3: bool = False):
    """Write the project, build the emulator for logic simulation and load it.

    Parameters
    ----------
    verbose : bool, optional
        Verbose output, by default False
    openmp : bool, optional
        Enable openmp, by default True
    o3 : bool, optional
        Turn on -O3 flag, by default False

    Raises
    ------
    RuntimeError
        If compilation fails
    """
    self.write()
    build_options = {'verbose': verbose, 'openmp': openmp, 'o3': o3}
    self._compile(**build_options)
    self._load_lib()
|
|
197
|
+
|
|
198
|
+
def predict(self, data: NDArray[np.floating]):
    """Run the model on the input data.

    Inputs are scaled to int32 fixed-point words sharing the widest input
    fractional width, passed to the compiled ``inference`` entry point, and
    the int32 outputs are scaled back to floating point.

    Parameters
    ----------
    data : NDArray[np.floating]
        Input data to the model. The shape is ignored, and the number of samples is
        determined by the size of the data.

    Returns
    -------
    NDArray[np.float64]
        Output of the model in shape (n_samples, output_size).
    """
    assert self._lib is not None, 'Library not loaded, call .compile() first.'
    inp_size, out_size = self._solution.shape

    assert data.size % inp_size == 0, f'Input size {data.size} is not divisible by {inp_size}'
    n_sample = data.size // inp_size

    kifs_in, kifs_out = get_io_kifs(self._solution)
    # Plain Python ints: the kif arrays are int8 and their sum could wrap.
    k_in, i_in, f_in = (int(np.max(row)) for row in kifs_in)
    k_out, i_out, f_out = (int(np.max(row)) for row in kifs_out)
    assert k_in + i_in + f_in <= 32, "Padded inp bw doesn't fit in int32. Emulation not supported"
    assert k_out + i_out + f_out <= 32, "Padded out bw doesn't fit in int32. Emulation not supported"

    inp_data = np.empty(n_sample * inp_size, dtype=np.int32)
    out_data = np.empty(n_sample * out_size, dtype=np.int32)

    # Convert to int32 matching the LSB position
    inp_data[:] = data.ravel() * 2.0**f_in

    inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
    out_buf = out_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
    self._lib.inference(inp_buf, out_buf, n_sample)

    # Unscale the output int32 to recover fp values.
    # The word is wrapped into its k+i+f bit range. The half-range offset is
    # only applied for signed outputs; with k == 0 the range starts at zero
    # and an offset would shift every value below zero.
    modulus = 2.0 ** (k_out + i_out + f_out)
    offset = k_out * 2.0 ** (i_out + f_out)
    lsb = 2.0**-f_out
    return ((out_data.reshape(n_sample, out_size) + offset) % modulus - offset) * lsb
|
|
238
|
+
|
|
239
|
+
def __repr__(self):
    """Summarise the model: IO sizes and bit counts, structure, cost, emulator state."""
    inp_size, out_size = self._solution.shape
    cost = round(self._solution.cost)
    kifs_in, kifs_out = get_io_kifs(self._solution)
    in_bits, out_bits = np.sum(kifs_in), np.sum(kifs_out)

    lines = [
        f'Top Module: {self._prj_name}',
        '====================',
        f'{inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)',
    ]
    if self._pipe is not None:
        n_stage = len(self._pipe[0])
        lat_cutoff = self._latency_cutoff
        reg_bits = self._pipe.reg_bits
        lines.append(f'{n_stage} stages @ max_delay={lat_cutoff}')
        lines.append(f'Estimated cost: {cost} LUTs, {reg_bits} FFs')
    else:
        lines.append(f'combinational @ delay={self._solution.latency}')
        lines.append(f'Estimated cost: {cost} LUTs')
    spec = '\n'.join(lines)

    if self._lib is None:
        return spec + '\nEmulator is **not compiled**'

    # NOTE(review): the attribute is read, not called; a ctypes function
    # pointer is always truthy, so this may always report OpenMP - confirm
    # whether a call was intended.
    openmp = 'with OpenMP' if self._lib.openmp_enabled else ''  # type: ignore
    return spec + f'\nEmulator is compiled {openmp} ({self._uuid[-12:]})'
|
da4ml/trace/__init__.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
from decimal import Decimal
|
|
2
|
+
from math import ceil, floor, log2
|
|
3
|
+
from typing import NamedTuple
|
|
4
|
+
from uuid import UUID, uuid4
|
|
5
|
+
|
|
6
|
+
from ..cmvm.core import cost_add
|
|
7
|
+
from ..cmvm.types import QInterval
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class HWConfig(NamedTuple):
    """Hardware-model parameters stored on each ``FixedVariable`` as ``hwconf``.

    The fields are read when the cost and latency of an addition are estimated.
    """

    # Passed to ``cost_add`` when costing a variable + variable addition.
    adder_size: int
    # Passed to ``cost_add`` when costing a variable + variable addition.
    carry_size: int
    # Pipelining boundary for latencies; only applied when greater than zero.
    latency_cutoff: float
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _const_f(const: float | Decimal):
|
|
17
|
+
const = float(const)
|
|
18
|
+
_low, _high = -32, 32
|
|
19
|
+
while _high - _low > 1:
|
|
20
|
+
_mid = (_high + _low) // 2
|
|
21
|
+
_value = const * (2.0**_mid)
|
|
22
|
+
if _value == int(_value):
|
|
23
|
+
_high = _mid
|
|
24
|
+
else:
|
|
25
|
+
_low = _mid
|
|
26
|
+
return _high
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class FixedVariable:
|
|
30
|
+
def __init__(
    self,
    low: float | Decimal,
    high: float | Decimal,
    step: float | Decimal,
    latency: float | None = None,
    hwconf=HWConfig(-1, -1, -1),
    opr: str = 'new',
    cost: float | None = None,
    _from: tuple['FixedVariable', ...] = (),
    _factor: float | Decimal = 1.0,
    _data: Decimal | None = None,
    _id: UUID | None = None,
) -> None:
    """Create a fixed-point variable covering ``[low, high]`` with resolution ``step``.

    The bounds are snapped outwards onto the ``step`` grid. A variable whose
    bounds coincide is turned into a constant (``opr='const'``, unit factor,
    no parents). When either ``cost`` or ``latency`` is missing, both are
    derived with ``get_cost_and_latency``.
    """
    assert low <= high, f'low {low} must be less than high {high}'

    if low == high:
        # Degenerate interval: a constant has no producer and no scaling
        opr, _factor, _from = 'const', 1.0, ()

    low, high, step = Decimal(low), Decimal(high), Decimal(step)
    self.low = floor(low / step) * step
    self.high = ceil(high / step) * step
    self.step = step

    self._factor = Decimal(_factor)
    self._from: tuple[FixedVariable, ...] = _from
    self.opr = opr
    self._data = _data
    self.id = uuid4() if not _id else _id
    self.hwconf = hwconf

    if opr == 'cadd':
        assert _data is not None, 'cadd must have data'

    if cost is not None and latency is not None:
        self.cost, self.latency = cost, latency
    else:
        self.cost, self.latency = self.get_cost_and_latency()
|
|
74
|
+
|
|
75
|
+
def get_cost_and_latency(self):
    """Estimate ``(cost, latency)`` of this variable from its operation and parents.

    Raises
    ------
    NotImplementedError
        If ``self.opr`` is not one of const, new, vadd, cadd, relu or wrap.
    """
    opr = self.opr

    if opr == 'const':
        return 0.0, 0.0

    if opr == 'new':
        # new variable, no cost
        return 0.0, 0.0

    if opr in ('relu', 'wrap'):
        assert len(self._from) == 1
        source = self._from[0]
        cost = 0.0
        # Assume LUT5 used here (2 fan-out per LUT6, thus *1/2)
        if source._factor < 0:
            cost += sum(self.kif) / 2
        if opr == 'relu':
            cost += sum(self.kif) / 2
        return cost, source.latency

    if opr not in ('vadd', 'cadd'):
        raise NotImplementedError(f'Operation {self.opr} is unknown')

    hw = self.hwconf
    if opr == 'vadd':
        assert len(self._from) == 2
        lhs, rhs = self._from
        base_latency = max(lhs.latency, rhs.latency)
        dlat, cost = cost_add(lhs.qint, rhs.qint, 0, False, hw.adder_size, hw.carry_size)
    else:
        assert len(self._from) == 1
        assert self._data is not None, 'cadd must have data'
        frac = _const_f(self._data)
        cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -frac))) + frac
        base_latency = self._from[0].latency
        dlat = 0.0

    cutoff = hw.latency_cutoff
    latency = dlat + base_latency
    if cutoff > 0 and ceil(latency / cutoff) > ceil(base_latency / cutoff):
        # Crossed the latency cutoff boundary: restart from the next boundary
        assert (
            dlat <= cutoff
        ), f'Latency of an atomic operation {dlat} is larger than the pipelining latency cutoff {cutoff}'
        latency = ceil(base_latency / cutoff) * cutoff + dlat
    return cost, latency
|
|
123
|
+
|
|
124
|
+
@property
def unscaled(self):
    """This variable multiplied by the reciprocal of its accumulated ``_factor``."""
    inverse = 1 / self._factor
    return self * inverse
|
|
127
|
+
|
|
128
|
+
@property
def qint(self) -> QInterval:
    """The variable's range as a ``QInterval`` of floats ``(low, high, step)``."""
    low, high, step = float(self.low), float(self.high), float(self.step)
    return QInterval(low, high, step)
|
|
131
|
+
|
|
132
|
+
@property
def kif(self) -> tuple[bool, int, int]:
    """Return ``(k, i, f)``: sign flag, integer bits and fractional bits.

    A zero ``step`` yields ``(False, 0, 0)``.
    """
    step = self.step
    if step == 0:
        return False, 0, 0
    frac_bits = -int(log2(step))
    # largest magnitude to cover: -low on the negative side, one step past high on the positive side
    magnitude = max(-self.low, self.high + step)
    int_bits = ceil(log2(magnitude))
    return self.low < 0, int_bits, frac_bits
|
|
140
|
+
|
|
141
|
+
def __repr__(self) -> str:
    """Show the range and step, prefixed by the factor when it is not 1."""
    body = f'FixedVariable({self.low}, {self.high}, {self.step})'
    if self._factor == 1:
        return body
    return f'({self._factor}) {body}'
|
|
145
|
+
|
|
146
|
+
def __neg__(self):
    """Negate the variable: mirror the range and flip the sign of ``_factor``.

    Identity, operation, parents, data, cost and latency are carried over unchanged.
    """
    carried = {
        '_from': self._from,
        '_factor': -self._factor,
        'latency': self.latency,
        'cost': self.cost,
        'opr': self.opr,
        '_id': self.id,
        '_data': self._data,
        'hwconf': self.hwconf,
    }
    return FixedVariable(-self.high, -self.low, self.step, **carried)
|
|
160
|
+
|
|
161
|
+
def __add__(self, other: 'FixedVariable|float|Decimal|int'):
    """Add a constant or another variable.

    Non-variables and constant-valued variables go through ``_const_add``.
    For two real variables the operands are arranged so that the left one
    has a non-negative factor before the 'vadd' node is created.
    """
    if not isinstance(other, FixedVariable):
        return self._const_add(other)
    if other.high == other.low:
        return self._const_add(other.low)
    if self.high == self.low:
        return other._const_add(self.low)

    assert self.hwconf == other.hwconf, 'FixedVariable must have the same hwconf'

    f0, f1 = self._factor, other._factor
    if f0 < 0 and f1 > 0:
        # put the positively-scaled operand first
        return other + self
    if f0 < 0:
        # both negative: add the negations, then negate the sum
        return -((-self) + (-other))

    return FixedVariable(
        self.low + other.low,
        self.high + other.high,
        min(self.step, other.step),
        _from=(self, other),
        _factor=f0,
        opr='vadd',
        hwconf=self.hwconf,
    )
|
|
187
|
+
|
|
188
|
+
def _const_add(self, other: float | Decimal):
    """Add the constant ``other`` to this variable.

    Adding zero returns ``self``. When this variable is itself the result of
    a constant addition, the two constants are merged and applied to the
    parent, so constant additions do not stack up.
    """
    if not isinstance(other, (int, float, Decimal)):
        other = float(other)  # direct numpy to decimal raises error
    other = Decimal(other)
    if other == 0:
        return self

    if self.opr == 'cadd':
        # cadd, combine the constant
        assert len(self._from) == 1
        parent = self._from[0]
        assert self._data is not None, 'cadd must have data'
        sf = self._factor / parent._factor
        merged = (self._data * parent._factor) + other / sf
        return (parent + merged) * sf

    const_step = Decimal(2.0 ** -_const_f(other))
    return FixedVariable(
        self.low + other,
        self.high + other,
        min(self.step, const_step),
        _from=(self,),
        _factor=self._factor,
        _data=other / self._factor,
        opr='cadd',
        hwconf=self.hwconf,
    )
|
|
216
|
+
|
|
217
|
+
def __sub__(self, other: 'FixedVariable|int|float|Decimal'):
    """Subtract by adding the negation of ``other``."""
    negated = -other
    return self + negated
|
|
219
|
+
|
|
220
|
+
def __mul__(
    self,
    other: 'float|Decimal',
):
    """Scale by a power of two (either sign); multiplying by zero gives a zero constant.

    No new operation is recorded: the result keeps the id, operation,
    parents, data, cost and latency, and only its range, step and
    ``_factor`` are rescaled.
    """
    if other == 0:
        return FixedVariable(0, 0, 1, hwconf=self.hwconf)

    assert log2(abs(other)) % 1 == 0, 'Only support pow2 multiplication'

    scale = Decimal(other)
    bounds = (self.low * scale, self.high * scale)

    return FixedVariable(
        min(bounds),
        max(bounds),
        abs(self.step * scale),
        _from=self._from,
        _factor=self._factor * scale,
        opr=self.opr,
        latency=self.latency,
        cost=self.cost,
        _id=self.id,
        _data=self._data,
        hwconf=self.hwconf,
    )
|
|
249
|
+
|
|
250
|
+
def __radd__(self, other: 'float|Decimal|int|FixedVariable'):
    """Reflected addition; evaluated as ``self + other``."""
    total = self + other
    return total
|
|
252
|
+
|
|
253
|
+
def __rsub__(self, other: 'float|Decimal|int|FixedVariable'):
    """Reflected subtraction ``other - self``, evaluated as ``(-self) + other``."""
    negated = -self
    return negated + other
|
|
255
|
+
|
|
256
|
+
def __rmul__(self, other: 'float|Decimal|int|FixedVariable'):
    """Reflected multiplication; evaluated as ``self * other``."""
    product = self * other
    return product
|
|
258
|
+
|
|
259
|
+
def relu(self, i: int | None = None, f: int | None = None, round_mode: str = 'TRN'):
    """Apply ReLU, optionally limiting the result to ``i`` integer / ``f`` fractional bits.

    Parameters
    ----------
    i : int | None, optional
        Integer bits of the result; a range exceeding ``2**i`` is treated as
        overflowing. None keeps the natural range.
    f : int | None, optional
        Fractional bits of the result. None keeps the current step.
    round_mode : str, optional
        'TRN' (truncate) or 'RND' (round half up), case-insensitive.

    Returns
    -------
    FixedVariable
        A constant when this variable is a constant, otherwise a 'relu' node.
    """
    round_mode = round_mode.upper()
    assert round_mode in ('TRN', 'RND')

    if self.opr == 'const':
        val = self.low * (self.low > 0)
        # ``is None`` rather than truthiness: 0 is a legitimate bit count
        if f is None:
            f = _const_f(val)
        step = Decimal(2) ** -f
        if i is None:
            i = ceil(log2(val + step))
        # The offset is added to val/step, i.e. in units of steps, so half a step is 1/2
        eps = Decimal('0.5') if round_mode == 'RND' else 0
        # Quantise onto the step grid first, then wrap the *value* into [0, 2**i)
        val = (floor(val / step + eps) * step) % (Decimal(2) ** i)
        return FixedVariable(val, val, step, hwconf=self.hwconf)

    step = max(Decimal(2) ** -f, self.step) if f is not None else self.step
    if step > self.step and round_mode == 'RND':
        # Rounding = add half of the new step, then truncate
        return (self + step / 2).relu(i, f, 'TRN')
    low = max(Decimal(0), self.low)
    high = max(Decimal(0), self.high)
    if i is not None:
        _high = Decimal(2) ** i - step
        if _high < high:
            # overflows
            low = Decimal(0)
            high = _high
    _factor = self._factor
    return FixedVariable(
        low,
        high,
        step,
        _from=(self,),
        _factor=abs(_factor),
        opr='relu',
        hwconf=self.hwconf,
        cost=sum(self.kif) * (1 if _factor > 0 else 2),
    )
|
|
294
|
+
|
|
295
|
+
def quantize(
    self,
    k: int | bool,
    i: int,
    f: int,
    overflow_mode: str = 'WRAP',
    round_mode: str = 'TRN',
):
    """Quantize the variable to the fixed-point format ``(k, i, f)``.

    Parameters
    ----------
    k : int | bool
        Whether the target format is signed.
    i : int
        Integer bits of the target format.
    f : int
        Fractional bits of the target format.
    overflow_mode : str, optional
        'WRAP' or 'SAT', case-insensitive.
    round_mode : str, optional
        'TRN' (truncate) or 'RND' (round half up), case-insensitive.

    Returns
    -------
    FixedVariable
        ``self`` when the target format already covers this variable, a
        constant for constant inputs, otherwise a new 'wrap' / 'sat' node.
    """
    overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
    assert overflow_mode in ('WRAP', 'SAT')
    assert round_mode in ('TRN', 'RND')

    _k, _i, _f = self.kif

    if k >= _k and i >= _i and f >= _f:
        return self

    if f < _f and round_mode == 'RND':
        # Rounding = add half of the target step, then truncate
        return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')

    if self.low == self.high:
        # Constant: evaluate the wrapped, truncated value directly
        step = Decimal(2) ** -f
        _high = Decimal(2) ** i
        low = -_high * k
        # Representable span: 2**(i+1) when signed, 2**i when unsigned
        span = _high * (2 if k else 1)
        shifted = floor(self.low / step) * step - low
        # Explicit floor-modulo: Decimal's % keeps the sign of the dividend,
        # which would leave values below the range unwrapped
        val = shifted - floor(shifted / span) * span + low
        return FixedVariable(val, val, step, hwconf=self.hwconf)

    # TODO: corner cases exists (e.g., overflow to negative, or negative overflow to high value)
    # bit-exactness will be lost in these cases, but they should never happen (quantizers are used in a weird way)
    # Keeping this for now; change if absolutely necessary
    f = min(f, _f)
    k = min(k, _k)
    i = min(i, _i)

    step = max(Decimal(2) ** -f, self.step)

    low = -k * Decimal(2) ** i
    high = Decimal(2) ** i - step
    _low, _high = self.low, self.high

    if _low >= low and _high <= high:
        # No overflow possible: keep the tighter original range
        low, high = _low, _high

    if low > high:
        return FixedVariable(0, 0, 1, hwconf=self.hwconf)

    # NOTE(review): no cost is passed, so the constructor derives it through
    # get_cost_and_latency, which has no 'sat' branch - confirm SAT is meant
    # to be usable on non-constant variables.
    return FixedVariable(
        low,
        high,
        step,
        _from=(self,),
        _factor=abs(self._factor),
        opr='wrap' if overflow_mode == 'WRAP' else 'sat',
        latency=self.latency,
        hwconf=self.hwconf,
    )
|
|
352
|
+
|
|
353
|
+
@classmethod
def from_kif(cls, k: int | bool, i: int, f: int, **kwargs):
    """Build a variable spanning the full range of the fixed-point format ``(k, i, f)``.

    Parameters
    ----------
    k : int | bool
        Whether the format is signed.
    i : int
        Number of integer bits.
    f : int
        Number of fractional bits.
    **kwargs
        Forwarded to the constructor.

    Returns
    -------
    FixedVariable
        Variable with range ``[-k * 2**i, 2**i - 2**-f]`` and step ``2**-f``.
    """
    step = Decimal(2) ** -f
    _high = Decimal(2) ** i
    # The lower bound of a signed format is negative; without the minus sign
    # low would exceed high and the constructor's range assertion would fail.
    low, high = -k * _high, _high - step
    return cls(low, high, step, **kwargs)
|