da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +204 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +246 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.1.post1.dist-info/METADATA +85 -0
- da4ml-0.5.1.post1.dist-info/RECORD +96 -0
- da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
- da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
- da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/cmvm/types.py
ADDED
|
@@ -0,0 +1,739 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from decimal import Decimal
|
|
4
|
+
from functools import reduce, singledispatch
|
|
5
|
+
from math import ceil, floor, log2
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, NamedTuple, TypeVar
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from numba import jit
|
|
11
|
+
from numpy import float32, int8
|
|
12
|
+
from numpy.typing import NDArray
|
|
13
|
+
|
|
14
|
+
from .._binary import dais_interp_run
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from ..trace.fixed_variable import FixedVariable, LookupTable
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class QInterval(NamedTuple):
|
|
21
|
+
"""A class representing a quantized interval: [min, max] with a step size."""
|
|
22
|
+
|
|
23
|
+
min: float
|
|
24
|
+
max: float
|
|
25
|
+
step: float
|
|
26
|
+
|
|
27
|
+
@classmethod
|
|
28
|
+
def from_kif(cls, k: int | bool, i: int, f: int):
|
|
29
|
+
_high = 2.0**i
|
|
30
|
+
step = 2.0**-f
|
|
31
|
+
low, high = -k * step, _high - step
|
|
32
|
+
return cls(low, high, step)
|
|
33
|
+
|
|
34
|
+
@classmethod
|
|
35
|
+
def from_precision(cls, prec: 'Precision'):
|
|
36
|
+
return cls.from_kif(*prec)
|
|
37
|
+
|
|
38
|
+
@property
|
|
39
|
+
def precision(self):
|
|
40
|
+
return Precision.from_qint(self)
|
|
41
|
+
|
|
42
|
+
def __repr__(self):
|
|
43
|
+
return f'[{self.min}, {self.max}, {self.step}]'
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class Precision(NamedTuple):
    """Fixed-point precision: sign flag, integer bits and fractional bits."""

    keep_negative: bool
    integers: int
    fractional: int

    def __str__(self):
        # The local names k/i/f are part of the output: ``{name=}`` prints them.
        k, i, f = self
        return f'fixed({k=}, {i=}, {f=})'

    def __repr__(self):
        # repr and str share the same text.
        return self.__str__()

    @classmethod
    def from_qint(cls, qint: QInterval, symmetric: bool = False):
        """Return the minimal precision covering ``qint`` (see ``_minimal_kif``)."""
        return _minimal_kif(qint, symmetric=symmetric)

    @property
    def qint(self):
        """The ``QInterval`` spanned by this precision."""
        return QInterval.from_kif(self.keep_negative, self.integers, self.fractional)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class Op(NamedTuple):
    """A single operation producing one element of the data buffer.

    Parameters
    ----------
    id0: int
        Index of the first operand (for opcode -1: index into the external input).
    id1: int
        Index of the second operand; ``ref_count`` treats -1 as "no operand".
    opcode: int
        Operation selector, as interpreted by ``CombLogic.__call__``:

        - -1: copy from the (shifted) external input
        - 0 / 1: ``buf[id0] + / - buf[id1] * 2**data``
        - 2 / -2: ``relu(+/- buf[id0])``
        - 3 / -3: ``quantize(+/- buf[id0])``
        - 4: ``buf[id0] + data * qint.step`` (const addition)
        - 5: ``data * qint.step`` (const definition)
        - 6 / -6: MSB mux; low 32 bits of ``data`` index the condition, the
          high 32 bits hold a signed shift applied to ``+/- buf[id1]``
        - 7: ``buf[id0] * buf[id1]``
        - 8: lookup of ``buf[id0]`` in ``lookup_tables[data]``
    data: int
        Extra integer payload whose meaning depends on the opcode.
    qint: QInterval
        Quantization interval of the resultant buffer element.
    latency: float
        Latency of the data generated by this operation (t_available).
    cost: float
        Cost of the operation.
    """

    id0: int
    id1: int
    opcode: int
    data: int
    qint: QInterval
    latency: float
    cost: float
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class Pair(NamedTuple):
    """Two-term expression ``data[id0] +/- data[id1] * 2**shift``.

    ``sub`` selects subtraction (True) or addition (False).
    """

    id0: int
    id1: int
    sub: bool
    shift: int
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class DAState(NamedTuple):
    """Bundle of the working data carried by the DA algorithm.

    NOTE(review): the field meanings below are inferred from names and types
    only; confirm against the code that builds and updates the state.
    """

    # pair of int8 arrays of shifts
    shifts: tuple[NDArray[int8], NDArray[int8]]
    # list of int8 arrays, presumably one expression per entry
    expr: list[NDArray[int8]]
    # operations recorded so far
    ops: list[Op]
    # occurrence count per Pair
    freq_stat: dict[Pair, int]
    # float32 kernel array
    kernel: NDArray[float32]
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _minimal_kif(qi: QInterval, symmetric: bool = False) -> Precision:
    """Find the smallest (keep_negative, integers, fractional) covering ``qi``.

    Parameters
    ----------
    qi : QInterval
        Interval to be represented.
    symmetric : bool
        Only matters when ``qi`` may be negative. When True, the value -2**i is
        treated as forbidden, so the magnitude of the lower bound needs the same
        headroom as the upper bound. Default is False.

    Returns
    -------
    Precision
        Named tuple holding the KIF values.
    """

    # A degenerate all-zero interval needs no bits at all.
    if qi.min == qi.max == 0:
        return Precision(keep_negative=False, integers=0, fractional=0)

    keep_negative = qi.min < 0
    fractional = int(-log2(qi.step))
    # Bounds expressed in units of ``step``.
    int_min, int_max = round(qi.min / qi.step), round(qi.max / qi.step)
    if symmetric:
        n_levels = max(abs(int_min), int_max) + 1
    else:
        n_levels = max(abs(int_min), int_max + 1)
    bits = int(ceil(log2(n_levels)))
    return Precision(keep_negative=keep_negative, integers=bits - fractional, fractional=fractional)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
if not TYPE_CHECKING:
    # At runtime ``minimal_kif`` is ``_minimal_kif`` wrapped by numba's ``jit``.
    minimal_kif = jit(_minimal_kif)
else:
    # Type checkers only see this plain signature.
    def minimal_kif(qi: QInterval, symmetric: bool = False) -> Precision: ...


# Value types handled by the ``_relu`` / ``_quantize`` dispatchers below.
T = TypeVar('T', 'FixedVariable', float, int, np.float32, np.float64, Decimal)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@singledispatch
|
|
160
|
+
def _relu(v: 'T', i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN') -> 'T':
|
|
161
|
+
from ..trace.fixed_variable import FixedVariable
|
|
162
|
+
|
|
163
|
+
assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic relu'
|
|
164
|
+
if inv:
|
|
165
|
+
v = -v
|
|
166
|
+
return v.relu(i, f, round_mode=round_mode)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@_relu.register(float)
|
|
170
|
+
@_relu.register(int)
|
|
171
|
+
@_relu.register(np.float32)
|
|
172
|
+
@_relu.register(np.float64)
|
|
173
|
+
def _(v, i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN'):
|
|
174
|
+
if inv:
|
|
175
|
+
v = -v
|
|
176
|
+
v = max(0, v)
|
|
177
|
+
if f is not None:
|
|
178
|
+
if round_mode.upper() == 'RND':
|
|
179
|
+
v += 2.0 ** (-f - 1)
|
|
180
|
+
sf = 2.0**f
|
|
181
|
+
v = floor(v * sf) / sf
|
|
182
|
+
if i is not None:
|
|
183
|
+
v = v % 2.0**i
|
|
184
|
+
return v
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@_relu.register
|
|
188
|
+
def _(v: Decimal, i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN'):
|
|
189
|
+
if inv:
|
|
190
|
+
v = -v
|
|
191
|
+
v = max(Decimal(0), v)
|
|
192
|
+
if f is not None:
|
|
193
|
+
if round_mode.upper() == 'RND':
|
|
194
|
+
v += Decimal(2) ** (-f - 1)
|
|
195
|
+
sf = Decimal(2) ** f
|
|
196
|
+
v = floor(v * sf) / sf
|
|
197
|
+
if i is not None:
|
|
198
|
+
v = v % Decimal(2) ** i
|
|
199
|
+
return v
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@singledispatch
|
|
203
|
+
def _quantize(v: 'T', k: int | bool, i: int, f: int, round_mode: str = 'TRN') -> 'T':
|
|
204
|
+
from ..trace.fixed_variable import FixedVariable
|
|
205
|
+
|
|
206
|
+
assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic quantization'
|
|
207
|
+
return v.quantize(k, i, f, round_mode=round_mode)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
@_quantize.register(float)
|
|
211
|
+
@_quantize.register(int)
|
|
212
|
+
@_quantize.register(np.float32)
|
|
213
|
+
@_quantize.register(np.float64)
|
|
214
|
+
def _(v, k: int | bool, i: int, f: int, round_mode: str = 'TRN'):
|
|
215
|
+
if round_mode.upper() == 'RND':
|
|
216
|
+
v += 2.0 ** (-f - 1)
|
|
217
|
+
b = k + i + f
|
|
218
|
+
bias = 2.0 ** (b - 1) * k
|
|
219
|
+
eps = 2.0**-f
|
|
220
|
+
return eps * ((np.floor(v / eps) + bias) % 2**b - bias)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
@_quantize.register
|
|
224
|
+
def _(v: Decimal, k: int | bool, i: int, f: int, round_mode: str = 'TRN'):
|
|
225
|
+
if round_mode.upper() == 'RND':
|
|
226
|
+
v += Decimal(2) ** (-f - 1)
|
|
227
|
+
b = k + i + f
|
|
228
|
+
bias = Decimal(2) ** (b - 1) * k
|
|
229
|
+
eps = Decimal(2) ** -f
|
|
230
|
+
return eps * ((floor(v / eps) + bias) % Decimal(2) ** b - bias)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class JSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes any object exposing a ``to_dict()`` method."""

    def default(self, o):
        """Return ``o.to_dict()`` when available, else defer to the base encoder.

        Raises
        ------
        TypeError
            Raised by the base class for objects it cannot serialize.
        """
        if hasattr(o, 'to_dict'):
            return o.to_dict()
        # Propagate the base result explicitly instead of falling through to None.
        return super().default(o)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class CombLogic(NamedTuple):
|
|
241
|
+
"""A combinational logic that describes a series of operations on input data to produce output data.
|
|
242
|
+
|
|
243
|
+
Attributes
|
|
244
|
+
----------
|
|
245
|
+
shape: tuple[int, int]
|
|
246
|
+
#input, #output
|
|
247
|
+
inp_shifts: list[int]
|
|
248
|
+
The shifts that should be applied to the input data.
|
|
249
|
+
out_idxs: list[int]
|
|
250
|
+
The indices of the output data in the buffer.
|
|
251
|
+
out_shifts: list[int]
|
|
252
|
+
The shifts that should be applied to the output data.
|
|
253
|
+
out_negs: list[bool]
|
|
254
|
+
The signs of the output data.
|
|
255
|
+
ops: list[Op]
|
|
256
|
+
Core list of operations for generating each buffer element.
|
|
257
|
+
carry_size: int
|
|
258
|
+
Size of the carrier for the adder, used for cost and latency estimation.
|
|
259
|
+
adder_size: int
|
|
260
|
+
Elementary size of the adder, used for cost and latency estimation.
|
|
261
|
+
lookup_tables: tuple[LookupTable, ...] | None
|
|
262
|
+
Lookup table arrays for lookup operations, if any.
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
The core part of the comb logic is the operations in the ops list.
|
|
266
|
+
For the exact operations executed with Op, refer to the Op class.
|
|
267
|
+
After all operations are executed, the output data is read from data[op.out_idx] and multiplied by 2**out_shift.
|
|
268
|
+
|
|
269
|
+
"""
|
|
270
|
+
|
|
271
|
+
shape: tuple[int, int]
|
|
272
|
+
inp_shifts: list[int]
|
|
273
|
+
out_idxs: list[int]
|
|
274
|
+
out_shifts: list[int]
|
|
275
|
+
out_negs: list[bool]
|
|
276
|
+
ops: list[Op]
|
|
277
|
+
carry_size: int
|
|
278
|
+
adder_size: int
|
|
279
|
+
lookup_tables: 'tuple[LookupTable, ...] | None' = None
|
|
280
|
+
|
|
281
|
+
def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False, dump=False):
|
|
282
|
+
"""Executes the solution on the input data.
|
|
283
|
+
|
|
284
|
+
Parameters
|
|
285
|
+
----------
|
|
286
|
+
inp : list | np.ndarray | tuple
|
|
287
|
+
Input data to be processed. The input data should be a list or numpy array of objects.
|
|
288
|
+
quantize : bool
|
|
289
|
+
If True, the input data will be quantized to the output quantization intervals.
|
|
290
|
+
Only floating point data types are supported when quantize is True.
|
|
291
|
+
Default is False.
|
|
292
|
+
debug : bool
|
|
293
|
+
If True, the function will print debug information about the operations being performed.
|
|
294
|
+
Default is False.
|
|
295
|
+
dump : bool
|
|
296
|
+
If True, the return the whole buffer, without applying the output shifts and signs.
|
|
297
|
+
Default is False.
|
|
298
|
+
|
|
299
|
+
Returns
|
|
300
|
+
-------
|
|
301
|
+
np.ndarray
|
|
302
|
+
The output data after applying the operations defined in the solution.
|
|
303
|
+
|
|
304
|
+
"""
|
|
305
|
+
|
|
306
|
+
from ..trace.fixed_variable import FixedVariable
|
|
307
|
+
|
|
308
|
+
buf = np.empty(len(self.ops), dtype=object)
|
|
309
|
+
inp = np.asarray(inp)
|
|
310
|
+
|
|
311
|
+
inp_qint = [op.qint for op in self.ops if op.opcode == -1]
|
|
312
|
+
if quantize: # TRN and WRAP
|
|
313
|
+
k, i, f = map(np.array, zip(*map(minimal_kif, inp_qint)))
|
|
314
|
+
inp = [_quantize(*x, round_mode='TRN') for x in zip(inp, k, i, f)]
|
|
315
|
+
|
|
316
|
+
inp = inp * (2.0 ** np.array(self.inp_shifts))
|
|
317
|
+
for i, op in enumerate(self.ops):
|
|
318
|
+
match op.opcode:
|
|
319
|
+
case -1: # copy form external buffer
|
|
320
|
+
buf[i] = inp[op.id0]
|
|
321
|
+
case 0 | 1: # addition
|
|
322
|
+
v0, v1 = buf[op.id0], 2.0**op.data * buf[op.id1]
|
|
323
|
+
buf[i] = v0 + v1 if op.opcode == 0 else v0 - v1
|
|
324
|
+
case 2 | -2: # relu(+/-x)
|
|
325
|
+
v = buf[op.id0]
|
|
326
|
+
_, _i, _f = _minimal_kif(op.qint)
|
|
327
|
+
buf[i] = _relu(v, _i, _f, inv=op.opcode == -2, round_mode='TRN')
|
|
328
|
+
case 3 | -3: # quantize(+/-x)
|
|
329
|
+
v = buf[op.id0] if op.opcode == 3 else -buf[op.id0]
|
|
330
|
+
_k, _i, _f = _minimal_kif(op.qint)
|
|
331
|
+
buf[i] = _quantize(v, _k, _i, _f, round_mode='TRN')
|
|
332
|
+
case 4: # const addition
|
|
333
|
+
bias = op.data * op.qint.step
|
|
334
|
+
buf[i] = buf[op.id0] + bias
|
|
335
|
+
case 5: # const definition
|
|
336
|
+
buf[i] = op.data * op.qint.step # const definition
|
|
337
|
+
case 6 | -6: # MSB Mux
|
|
338
|
+
id_c = op.data & 0xFFFFFFFF
|
|
339
|
+
k, v0, v1 = buf[id_c], buf[op.id0], buf[op.id1]
|
|
340
|
+
shift = (op.data >> 32) & 0xFFFFFFFF
|
|
341
|
+
shift = shift if shift < 0x80000000 else shift - 0x100000000
|
|
342
|
+
if op.opcode == -6:
|
|
343
|
+
v1 = -v1
|
|
344
|
+
|
|
345
|
+
if isinstance(k, FixedVariable):
|
|
346
|
+
buf[i] = k.msb_mux(v0, v1 * 2**shift, op.qint)
|
|
347
|
+
else:
|
|
348
|
+
qint_k = self.ops[id_c].qint
|
|
349
|
+
if qint_k.min < 0:
|
|
350
|
+
buf[i] = v0 if k < 0 else v1 * 2.0**shift
|
|
351
|
+
else:
|
|
352
|
+
_k, _i, _f = _minimal_kif(qint_k)
|
|
353
|
+
buf[i] = v0 if k >= 2.0 ** (_i - 1) else v1 * 2.0**shift
|
|
354
|
+
case 7:
|
|
355
|
+
v0, v1 = buf[op.id0], buf[op.id1]
|
|
356
|
+
buf[i] = v0 * v1
|
|
357
|
+
case 8:
|
|
358
|
+
v0 = buf[op.id0]
|
|
359
|
+
tables = self.lookup_tables
|
|
360
|
+
assert tables is not None, 'No lookup table provided for lookup operation'
|
|
361
|
+
table = tables[op.data]
|
|
362
|
+
buf[i] = table.lookup(v0, self.ops[op.id0].qint)
|
|
363
|
+
case _:
|
|
364
|
+
raise ValueError(f'Unknown opcode {op.opcode} in {op}')
|
|
365
|
+
|
|
366
|
+
sf = 2.0 ** np.array(self.out_shifts, dtype=np.float64)
|
|
367
|
+
sign = np.where(self.out_negs, -1, 1)
|
|
368
|
+
out_idx = np.array(self.out_idxs, dtype=np.int32)
|
|
369
|
+
mask = np.where(out_idx < 0, 0, 1)
|
|
370
|
+
if debug:
|
|
371
|
+
operands = []
|
|
372
|
+
for i, v in enumerate(buf):
|
|
373
|
+
op = self.ops[i]
|
|
374
|
+
match op.opcode:
|
|
375
|
+
case -1:
|
|
376
|
+
op_str = 'inp'
|
|
377
|
+
case 0 | 1:
|
|
378
|
+
_sign = '-' if op.opcode == 1 else '+'
|
|
379
|
+
op_str = f'buf[{op.id0}] {_sign} buf[{op.id1}]<<{op.data}'
|
|
380
|
+
case 2 | -2:
|
|
381
|
+
_sign = '' if op.opcode == 2 else '-'
|
|
382
|
+
op_str = f'relu({_sign}buf[{op.id0}])'
|
|
383
|
+
case 3 | -3:
|
|
384
|
+
_sign = '' if op.opcode == 3 else '-'
|
|
385
|
+
op_str = f'quantize({_sign}buf[{op.id0}])'
|
|
386
|
+
case 4:
|
|
387
|
+
op_str = f'buf[{op.id0}] + {op.data * op.qint.step}'
|
|
388
|
+
case 5:
|
|
389
|
+
op_str = f'const {op.data * op.qint.step}'
|
|
390
|
+
case 6 | -6:
|
|
391
|
+
_sign = '-' if op.opcode == -6 else ''
|
|
392
|
+
op_str = f'msb(buf[{op.data}]) ? buf[{op.id0}] : {_sign}buf[{op.id1}]'
|
|
393
|
+
case 7:
|
|
394
|
+
op_str = f'buf[{op.id0}] * buf[{op.id1}]'
|
|
395
|
+
case 8:
|
|
396
|
+
op_str = f'tables[{int(op.data)}].lookup(buf[{op.id0}])'
|
|
397
|
+
case _:
|
|
398
|
+
raise ValueError(f'Unknown opcode {op.opcode} in {op}')
|
|
399
|
+
|
|
400
|
+
result = f'|-> buf[{i}] = {v}'
|
|
401
|
+
operands.append((op_str, result))
|
|
402
|
+
max_len = max(len(op[0]) for op in operands)
|
|
403
|
+
for op_str, result in operands:
|
|
404
|
+
print(f'{op_str:<{max_len}} {result}')
|
|
405
|
+
|
|
406
|
+
if dump:
|
|
407
|
+
return buf
|
|
408
|
+
return buf[out_idx] * sf * sign * mask
|
|
409
|
+
|
|
410
|
+
@property
|
|
411
|
+
def kernel(self):
|
|
412
|
+
"""the kernel represented by the solution, when applicable."""
|
|
413
|
+
kernel = np.empty(self.shape, dtype=np.float32)
|
|
414
|
+
for i, one_hot in enumerate(np.identity(self.shape[0])):
|
|
415
|
+
kernel[i] = self(one_hot)
|
|
416
|
+
return kernel
|
|
417
|
+
|
|
418
|
+
@property
|
|
419
|
+
def cost(self):
|
|
420
|
+
"""Total cost of the solution."""
|
|
421
|
+
return float(sum(op.cost for op in self.ops))
|
|
422
|
+
|
|
423
|
+
@property
|
|
424
|
+
def latency(self):
|
|
425
|
+
"""Minimum and maximum latency of the solution."""
|
|
426
|
+
latency = [self.ops[i].latency for i in self.out_idxs]
|
|
427
|
+
if len(latency) == 0:
|
|
428
|
+
return 0.0, 0.0
|
|
429
|
+
return min(latency), max(latency)
|
|
430
|
+
|
|
431
|
+
def __repr__(self):
|
|
432
|
+
n_in, n_out = self.shape
|
|
433
|
+
cost = self.cost
|
|
434
|
+
lat_min, lat_max = self.latency
|
|
435
|
+
return f'Solution([{n_in} -> {n_out}], cost={cost}, latency={lat_min}-{lat_max})'
|
|
436
|
+
|
|
437
|
+
@property
|
|
438
|
+
def out_latency(self):
|
|
439
|
+
"""Latencies of all output elements of the solution."""
|
|
440
|
+
return [self.ops[i].latency if i >= 0 else 0.0 for i in self.out_idxs]
|
|
441
|
+
|
|
442
|
+
@property
|
|
443
|
+
def out_qint(self):
|
|
444
|
+
"""Quantization intervals of the output elements."""
|
|
445
|
+
buf = []
|
|
446
|
+
for i, idx in enumerate(self.out_idxs):
|
|
447
|
+
_min, _max, _step = self.ops[idx].qint
|
|
448
|
+
sf = 2.0 ** self.out_shifts[i]
|
|
449
|
+
_min, _max, _step = _min * sf, _max * sf, _step * sf
|
|
450
|
+
if self.out_negs[i]:
|
|
451
|
+
_min, _max = -_max, -_min
|
|
452
|
+
buf.append(QInterval(_min, _max, _step))
|
|
453
|
+
return buf
|
|
454
|
+
|
|
455
|
+
@property
|
|
456
|
+
def out_kifs(self):
|
|
457
|
+
"""KIFs of all output elements of the solution."""
|
|
458
|
+
return np.array([_minimal_kif(qi) for qi in self.out_qint]).T
|
|
459
|
+
|
|
460
|
+
@property
|
|
461
|
+
def inp_latency(self):
|
|
462
|
+
"""Latencies of all input elements of the solution."""
|
|
463
|
+
return [op.latency for op in self.ops if op.opcode == -1]
|
|
464
|
+
|
|
465
|
+
@property
|
|
466
|
+
def inp_qint(self):
|
|
467
|
+
"""Quantization intervals of the input elements."""
|
|
468
|
+
qints = [QInterval(0.0, 0.0, 1.0) for _ in range(self.shape[0])]
|
|
469
|
+
for op in self.ops:
|
|
470
|
+
if op.opcode != -1:
|
|
471
|
+
continue
|
|
472
|
+
qints[op.id0] = op.qint
|
|
473
|
+
return qints
|
|
474
|
+
|
|
475
|
+
@property
|
|
476
|
+
def inp_kifs(self):
|
|
477
|
+
"""KIFs of all input elements of the solution."""
|
|
478
|
+
return np.array([_minimal_kif(qi) for qi in self.inp_qint]).T
|
|
479
|
+
|
|
480
|
+
def save(self, path: str | Path):
|
|
481
|
+
"""Save the solution to a file."""
|
|
482
|
+
with open(path, 'w') as f:
|
|
483
|
+
json.dump(self, f, cls=JSONEncoder)
|
|
484
|
+
|
|
485
|
+
@classmethod
|
|
486
|
+
def deserialize(cls, data: list):
|
|
487
|
+
"""Load the solution from a file."""
|
|
488
|
+
ops = []
|
|
489
|
+
for _op in data[5]:
|
|
490
|
+
op = Op(*_op[:4], QInterval(*_op[4]), *_op[5:]) # type: ignore
|
|
491
|
+
ops.append(op)
|
|
492
|
+
assert len(data) in (8, 9), f'{len(data)}'
|
|
493
|
+
lookup_tables = data[8] if len(data) > 8 else None
|
|
494
|
+
if lookup_tables is not None:
|
|
495
|
+
from ..trace.fixed_variable import LookupTable
|
|
496
|
+
|
|
497
|
+
lookup_tables = tuple(LookupTable.from_dict(tab) for tab in lookup_tables)
|
|
498
|
+
return cls(
|
|
499
|
+
shape=tuple(data[0]),
|
|
500
|
+
inp_shifts=data[1],
|
|
501
|
+
out_idxs=data[2],
|
|
502
|
+
out_shifts=data[3],
|
|
503
|
+
out_negs=data[4],
|
|
504
|
+
ops=ops,
|
|
505
|
+
carry_size=data[6],
|
|
506
|
+
adder_size=data[7],
|
|
507
|
+
lookup_tables=lookup_tables,
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
@classmethod
|
|
511
|
+
def load(cls, path: str | Path):
|
|
512
|
+
"""Load the solution from a file."""
|
|
513
|
+
with open(path) as f:
|
|
514
|
+
data = json.load(f)
|
|
515
|
+
return cls.deserialize(data)
|
|
516
|
+
|
|
517
|
+
@property
|
|
518
|
+
def ref_count(self) -> np.ndarray:
|
|
519
|
+
"""The number of references to the output elements in the solution."""
|
|
520
|
+
ref_count = np.zeros(len(self.ops), dtype=np.uint64)
|
|
521
|
+
for op in self.ops:
|
|
522
|
+
if op.opcode == -1:
|
|
523
|
+
continue
|
|
524
|
+
id0, id1 = op.id0, op.id1
|
|
525
|
+
if id0 != -1:
|
|
526
|
+
ref_count[id0] += 1
|
|
527
|
+
if id1 != -1:
|
|
528
|
+
ref_count[id1] += 1
|
|
529
|
+
if op.opcode in (6, -6):
|
|
530
|
+
# msb_mux operation
|
|
531
|
+
ref_count[op.data & 0xFFFFFFFF] += 1
|
|
532
|
+
for i in self.out_idxs:
|
|
533
|
+
if i < 0:
|
|
534
|
+
continue
|
|
535
|
+
ref_count[i] += 1
|
|
536
|
+
return ref_count
|
|
537
|
+
|
|
538
|
+
def to_binary(self, version: int = 0) -> NDArray[np.int32]:
|
|
539
|
+
n_in, n_out = self.shape
|
|
540
|
+
header_size_i32 = 6 + n_in + n_out * 3
|
|
541
|
+
n_tables = len(self.lookup_tables) if self.lookup_tables is not None else 0
|
|
542
|
+
|
|
543
|
+
header = np.concatenate(
|
|
544
|
+
[
|
|
545
|
+
[0, version, n_in, n_out, len(self.ops), n_tables],
|
|
546
|
+
self.inp_shifts,
|
|
547
|
+
self.out_idxs,
|
|
548
|
+
self.out_shifts,
|
|
549
|
+
self.out_negs,
|
|
550
|
+
],
|
|
551
|
+
axis=0,
|
|
552
|
+
dtype=np.int32,
|
|
553
|
+
)
|
|
554
|
+
assert len(header) == header_size_i32, f'Header size mismatch: {len(header)} != {header_size_i32}'
|
|
555
|
+
code = np.empty((len(self.ops), 8), dtype=np.int32)
|
|
556
|
+
for i, op in enumerate(self.ops):
|
|
557
|
+
buf = code[i]
|
|
558
|
+
buf[0] = op.opcode
|
|
559
|
+
buf[1] = op.id0
|
|
560
|
+
buf[2] = op.id1
|
|
561
|
+
buf[5:] = _minimal_kif(op.qint)
|
|
562
|
+
buf_i64 = buf[3:5].view(np.int64)
|
|
563
|
+
if op.opcode != 8:
|
|
564
|
+
buf_i64[0] = op.data
|
|
565
|
+
else:
|
|
566
|
+
assert self.lookup_tables is not None
|
|
567
|
+
pad_left = self.lookup_tables[op.data]._get_pads(self.ops[op.id0].qint)[0]
|
|
568
|
+
buf_i64[0] = (pad_left << 32) | op.data
|
|
569
|
+
data = np.concatenate([header, code.flatten()])
|
|
570
|
+
|
|
571
|
+
if self.lookup_tables is None:
|
|
572
|
+
return data
|
|
573
|
+
|
|
574
|
+
tables = [table.table for table in self.lookup_tables]
|
|
575
|
+
table_sizes = [len(tab) for tab in tables]
|
|
576
|
+
table_data = np.concatenate([table_sizes] + tables, axis=0, dtype=np.int32)
|
|
577
|
+
data = np.concatenate([data, table_data])
|
|
578
|
+
return data
|
|
579
|
+
|
|
580
|
+
def save_binary(self, path: str | Path, version: int = 0):
|
|
581
|
+
"""Dump the solution to a binary file."""
|
|
582
|
+
data = self.to_binary(version=version)
|
|
583
|
+
with open(path, 'wb') as f:
|
|
584
|
+
data.tofile(f)
|
|
585
|
+
|
|
586
|
+
def predict(
|
|
587
|
+
self,
|
|
588
|
+
data: NDArray | Sequence[NDArray],
|
|
589
|
+
n_threads: int = -1,
|
|
590
|
+
) -> NDArray[np.float64]:
|
|
591
|
+
"""Predict the output of the solution for a batch of input data with cpp backed DAIS interpreter.
|
|
592
|
+
Cannot be used if the binary interpreter is not installed.
|
|
593
|
+
|
|
594
|
+
Parameters
|
|
595
|
+
----------
|
|
596
|
+
data : NDArray|Sequence[NDArray]
|
|
597
|
+
Input data to the model. The shape is ignored, and the number of samples is
|
|
598
|
+
determined by the size of the data.
|
|
599
|
+
n_threads: int
|
|
600
|
+
Number of threads to use for prediction.
|
|
601
|
+
Negative or zero values will use maximum available threads. Default is -1.
|
|
602
|
+
If OpenMP is not supported, this parameter is ignored.
|
|
603
|
+
|
|
604
|
+
Returns
|
|
605
|
+
-------
|
|
606
|
+
NDArray[np.float64]
|
|
607
|
+
Output of the model in shape (n_samples, output_size).
|
|
608
|
+
"""
|
|
609
|
+
|
|
610
|
+
if isinstance(data, Sequence):
|
|
611
|
+
data = np.concatenate([a.reshape(a.shape[0], -1) for a in data], axis=-1)
|
|
612
|
+
|
|
613
|
+
if n_threads == 0:
|
|
614
|
+
n_threads = -1
|
|
615
|
+
|
|
616
|
+
bin_logic = self.to_binary()
|
|
617
|
+
return dais_interp_run(bin_logic, data, n_threads)
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
class Pipeline(NamedTuple):
    """A pipeline with II=1, with each stage represented by a CombLogic.

    Attributes
    ----------
    solutions: tuple[CombLogic, ...]
        The CombLogic objects of the individual stages, applied in order.

    Properties
    ----------
    kernel: NDArray[float32]
        Only useful when the pipeline describes a linear operation.
        The overall kernel matrix which the pipeline implements: vec @ kernel = pipeline(vec).
        This is calculated as the matrix product of all individual stage kernels.
    cost: float
        The total cost of the pipeline, computed as the sum of the costs of all stages.
    latency: tuple[float, float]
        The minimum and maximum latency of the pipeline, determined by the last stage.
    inp_qint: list[QInterval]
        Input quantization intervals (first stage)
    inp_latency: list[float]
        Input latencies (first stage)
    inp_shifts: list[int]
        Input shifts (first stage)
    out_qint: list[QInterval]
        Output quantization intervals (last stage)
    out_latencies: list[float]
        Output latencies (last stage)
    out_shift: list[int]
        Output shifts (last stage)
    out_neg: list[bool]
        Output signs (last stage)
    shape: tuple[int, int]
        The shape of the corresponding kernel matrix.
    reg_bits: int
        Total bit width of the pipeline input plus the outputs of every stage.
    """

    solutions: tuple[CombLogic, ...]

    def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False):
        """Run the input through every stage in order and return the final output."""
        out = np.asarray(inp)
        for sol in self.solutions:
            out = sol(out, quantize=quantize, debug=debug)
        return out

    @property
    def kernel(self):
        return reduce(lambda x, y: x @ y, [sol.kernel for sol in self.solutions])

    @property
    def cost(self):
        return sum(sol.cost for sol in self.solutions)

    @property
    def latency(self):
        return self.solutions[-1].latency

    @property
    def inp_qint(self):
        return self.solutions[0].inp_qint

    @property
    def inp_latency(self):
        return self.solutions[0].inp_latency

    @property
    def out_qint(self):
        return self.solutions[-1].out_qint

    @property
    def out_latencies(self):
        return self.solutions[-1].out_latency

    @property
    def shape(self):
        return self.solutions[0].shape[0], self.solutions[-1].shape[1]

    @property
    def inp_shifts(self):
        return self.solutions[0].inp_shifts

    @property
    def out_shift(self):
        return self.solutions[-1].out_shifts

    @property
    def out_neg(self):
        return self.solutions[-1].out_negs

    def __repr__(self) -> str:
        # Input width of every stage followed by the final output width.
        n_ins = [sol.shape[0] for sol in self.solutions] + [self.shape[1]]
        shape_str = ' -> '.join(map(str, n_ins))
        _cost = self.cost
        lat_min, lat_max = self.latency
        return f'CascatedSolution([{shape_str}], cost={_cost}, latency={lat_min}-{lat_max})'

    def save(self, path: str | Path):
        """Save the pipeline to a file."""
        with open(path, 'w') as f:
            json.dump(self, f, cls=JSONEncoder)

    @classmethod
    def deserialize(cls, data: list):
        """Build the pipeline from its decoded JSON form.

        ``data[0]`` holds the serialized stages, each in the form accepted by
        ``CombLogic.deserialize``.
        """
        return cls(solutions=tuple(CombLogic.deserialize(sol) for sol in data[0]))

    @classmethod
    def load(cls, path: str | Path):
        """Load the pipeline from a file."""
        with open(path) as f:
            data = json.load(f)
        return cls.deserialize(data)

    @property
    def reg_bits(self):
        """The number of bits used for the register in the solution.

        Sums k + i + f over the pipeline inputs and over the outputs of every stage.
        """
        bits = sum(map(sum, (_minimal_kif(qint) for qint in self.inp_qint)))
        for _sol in self.solutions:
            kifs = [_minimal_kif(qint) for qint in _sol.out_qint]
            _bits = sum(map(sum, kifs))
            bits += _bits
        return bits
|