da4ml 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +16 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +3 -34
- da4ml/cmvm/api.py +235 -73
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +569 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +11 -0
- da4ml/codegen/cpp/__init__.py +3 -0
- da4ml/codegen/cpp/cpp_codegen.py +148 -0
- da4ml/codegen/cpp/source/vitis.h +30 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
- da4ml/codegen/verilog/__init__.py +13 -0
- da4ml/codegen/verilog/comb.py +146 -0
- da4ml/codegen/verilog/io_wrapper.py +255 -0
- da4ml/codegen/verilog/pipeline.py +67 -0
- da4ml/codegen/verilog/source/build_binder.mk +27 -0
- da4ml/codegen/verilog/source/build_prj.tcl +74 -0
- da4ml/codegen/verilog/source/ioutils.hh +117 -0
- da4ml/codegen/verilog/source/shift_adder.v +56 -0
- da4ml/codegen/verilog/source/template.xdc +29 -0
- da4ml/codegen/verilog/verilog_model.py +268 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +358 -0
- da4ml/trace/fixed_variable_array.py +187 -0
- da4ml/trace/ops/__init__.py +55 -0
- da4ml/trace/ops/conv_utils.py +104 -0
- da4ml/trace/ops/einsum_utils.py +299 -0
- da4ml/trace/pipeline.py +155 -0
- da4ml/trace/tracer.py +122 -0
- da4ml-0.2.1.dist-info/METADATA +65 -0
- da4ml-0.2.1.dist-info/RECORD +39 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/WHEEL +1 -1
- da4ml/cmvm/balanced_reduction.py +0 -46
- da4ml/cmvm/cmvm.py +0 -328
- da4ml/cmvm/codegen.py +0 -159
- da4ml/cmvm/csd.py +0 -73
- da4ml/cmvm/fixed_variable.py +0 -205
- da4ml/cmvm/graph_compile.py +0 -85
- da4ml/cmvm/nb_fixed_precision.py +0 -98
- da4ml/cmvm/scoring.py +0 -55
- da4ml/cmvm/utils.py +0 -5
- da4ml-0.1.2.dist-info/METADATA +0 -122
- da4ml-0.1.2.dist-info/RECORD +0 -18
- {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/top_level.txt +0 -0
da4ml/cmvm/types.py
ADDED
@@ -0,0 +1,569 @@
import json
from decimal import Decimal
from functools import reduce, singledispatch
from math import ceil, floor, log2
from pathlib import Path
from typing import TYPE_CHECKING, NamedTuple, TypeVar

import numpy as np
from numba import jit
from numpy import float32, int8
from numpy.typing import NDArray

if TYPE_CHECKING:
    from ..trace.tracer import FixedVariable


class QInterval(NamedTuple):
    """A class representing a quantized interval: [min, max] with a step size."""

    min: float
    max: float
    step: float

    @classmethod
    def from_kif(cls, k: int | bool, i: int, f: int):
        _high = 2.0**i
        step = 2.0**-f
        low, high = -k * _high, _high - step
        return cls(low, high, step)

    @classmethod
    def from_precision(cls, prec: 'Precision'):
        return cls.from_kif(*prec)

    @property
    def precision(self):
        return Precision.from_qint(self)

    def __repr__(self):
        return f'[{self.min}, {self.max}, {self.step}]'


class Precision(NamedTuple):
    """A class representing the precision of a quantized interval."""

    keep_negative: bool
    integers: int
    fractional: int

    def __str__(self):
        k, i, f = self.keep_negative, self.integers, self.fractional
        k, B, I = k, i + f + k, i + k
        return f'fixed({k}, {B}, {I})'

    def __repr__(self):
        return str(self)

    @classmethod
    def from_qint(cls, qint: QInterval, symmetric: bool = False):
        return _minimal_kif(qint, symmetric=symmetric)

    @property
    def qint(self):
        return QInterval.from_kif(*self)


class Op(NamedTuple):
    """One single operation on the data buffer.

    Parameters
    ----------
    id0: int
        index of the first operand
    id1: int
        index of the second operand, or special opcode if negative
    opcode: int
        -1: input copy, 0: addition, 1: subtraction, 2/-2: relu(+/-x), 3/-3: quantize(+/-x), 4: const addition, 5: const definition
    data: int
        Data to be used in the operation
    qint: QInterval
        Quantization interval of the resultant buffer
    latency: float
        Latency of the data generated by this operation (t_available)
    cost: float
        Cost of the operation
    """

    id0: int
    id1: int
    opcode: int
    data: int
    qint: QInterval
    latency: float
    cost: float


class Pair(NamedTuple):
    """An operation representing data[id0] +/- data[id1] * 2**shift."""

    id0: int
    id1: int
    sub: bool
    shift: int


class DAState(NamedTuple):
    """Internal state of the DA algorithm."""

    shifts: tuple[NDArray[int8], NDArray[int8]]
    expr: list[NDArray[int8]]
    ops: list[Op]
    freq_stat: dict[Pair, int]
    kernel: NDArray[float32]


def _minimal_kif(qi: QInterval, symmetric: bool = False) -> Precision:
    """Calculate the minimal KIF for a given QInterval.

    Parameters
    ----------
    qi : QInterval
        The QInterval to calculate the KIF for.
    symmetric : bool
        Only relevant if qi may be negative. If True, -2**i will be regarded as forbidden.
        May be useful in special cases only.
        Default is False.

    Returns
    -------
    Precision
        A named tuple with the KIF values.
    """

    if qi.min == qi.max == 0:
        return Precision(keep_negative=False, integers=0, fractional=0)
    keep_negative = qi.min < 0
    fractional = int(-log2(qi.step))
    int_min, int_max = round(qi.min / qi.step), round(qi.max / qi.step)
    if symmetric:
        bits = int(ceil(log2(max(abs(int_min), int_max) + 1)))
    else:
        bits = int(ceil(log2(max(abs(int_min), int_max + 1))))
    integers = bits - fractional
    return Precision(keep_negative=keep_negative, integers=integers, fractional=fractional)


if TYPE_CHECKING:

    def minimal_kif(qi: QInterval, symmetric: bool = False) -> Precision: ...
else:
    minimal_kif = jit(_minimal_kif)


T = TypeVar('T', 'FixedVariable', float, int, np.float32, np.float64, Decimal)


@singledispatch
def _relu(v: 'T', i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN') -> 'T':
    from ..trace.fixed_variable import FixedVariable

    assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic relu'
    if inv:
        v = -v
    return v.relu(i, f, round_mode=round_mode)


@_relu.register(float)
@_relu.register(int)
@_relu.register(np.float32)
@_relu.register(np.float64)
def _(v, i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN'):
    if inv:
        v = -v
    v = max(0, v)
    if f is not None:
        if round_mode.upper() == 'RND':
            v += 2.0 ** (-f - 1)
        sf = 2.0**f
        v = floor(v * sf) / sf
    if i is not None:
        v = v % 2.0**i
    return v


@_relu.register
def _(v: Decimal, i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN'):
    if inv:
        v = -v
    v = max(Decimal(0), v)
    if f is not None:
        if round_mode.upper() == 'RND':
            v += Decimal(2) ** (-f - 1)
        sf = Decimal(2) ** f
        v = floor(v * sf) / sf
    if i is not None:
        v = v % Decimal(2) ** i
    return v


@singledispatch
def _quantize(v: 'T', k: int | bool, i: int, f: int, round_mode: str = 'TRN') -> 'T':
    from ..trace.fixed_variable import FixedVariable

    assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic quantization'
    return v.quantize(k, i, f, round_mode=round_mode)


@_quantize.register(float)
@_quantize.register(int)
@_quantize.register(np.float32)
@_quantize.register(np.float64)
def _(v, k: int | bool, i: int, f: int, round_mode: str = 'TRN'):
    if round_mode.upper() == 'RND':
        v += 2.0 ** (-f - 1)
    b = k + i + f
    bias = 2.0 ** (b - 1) * k
    eps = 2.0**-f
    return eps * ((np.floor(v / eps) + bias) % 2**b - bias)


@_quantize.register
def _(v: Decimal, k: int | bool, i: int, f: int, round_mode: str = 'TRN'):
    if round_mode.upper() == 'RND':
        v += Decimal(2) ** (-f - 1)
    b = k + i + f
    bias = Decimal(2) ** (b - 1) * k
    eps = Decimal(2) ** -f
    return eps * ((floor(v / eps) + bias) % Decimal(2) ** b - bias)


class Solution(NamedTuple):
    """Represents a series of operations that can be applied to a vector of data.
    May represent a CMVM solution or a general neural network.

    Attributes
    ----------
    shape: tuple[int, int]
        #input, #output
    inp_shift: list[int]
        The shifts that should be applied to the input data.
    out_idxs: list[int]
        The indices of the output data in the buffer.
    out_shifts: list[int]
        The shifts that should be applied to the output data.
    out_negs: list[bool]
        The signs of the output data.
    ops: list[Op]
        Core list of operations for generating each buffer element.
    carry_size: int
        Carry size of the adder.
    adder_size: int
        Elementary size of the adder.


    The core part of the solution is the operations in the ops list.
    For the exact operations executed with Op, refer to the Op class.
    After all operations are executed, the output data is read from buf[out_idxs[i]], multiplied by 2**out_shifts[i], and negated where out_negs[i] is set.

    """

    shape: tuple[int, int]
    inp_shift: list[int]
    out_idxs: list[int]
    out_shifts: list[int]
    out_negs: list[bool]
    ops: list[Op]
    carry_size: int
    adder_size: int

    def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False, dump=False):
        """Executes the solution on the input data.

        Parameters
        ----------
        inp : list | np.ndarray | tuple
            Input data to be processed. The input data should be a list or numpy array of objects.
        quantize : bool
            If True, the input data will be quantized to the input quantization intervals.
            Only floating point data types are supported when quantize is True.
            Default is False.
        debug : bool
            If True, the function will print debug information about the operations being performed.
            Default is False.
        dump : bool
            If True, return the whole buffer, without applying the output shifts and signs.
            Default is False.

        Returns
        -------
        np.ndarray
            The output data after applying the operations defined in the solution.

        """
        buf = np.empty(len(self.ops), dtype=object)
        inp = np.asarray(inp)

        inp_qint = [op.qint for op in self.ops if op.opcode == -1]
        if quantize:  # TRN and WRAP
            k, i, f = map(np.array, zip(*map(minimal_kif, inp_qint)))
            inp = [_quantize(*x, round_mode='TRN') for x in zip(inp, k, i, f)]

        inp = inp * (2.0 ** np.array(self.inp_shift))
        for i, op in enumerate(self.ops):
            match op.opcode:
                case -1:  # copy from external buffer
                    buf[i] = inp[op.id0]
                case 0 | 1:  # addition / subtraction
                    v0, v1 = buf[op.id0], 2.0**op.data * buf[op.id1]
                    buf[i] = v0 + v1 if op.opcode == 0 else v0 - v1
                case 2 | -2:  # relu(+/-x)
                    v = buf[op.id0]
                    _, _i, _f = _minimal_kif(op.qint)
                    buf[i] = _relu(v, _i, _f, inv=op.opcode == -2, round_mode='TRN')
                case 3 | -3:  # quantize(+/-x)
                    v = buf[op.id0] if op.opcode == 3 else -buf[op.id0]
                    _k, _i, _f = _minimal_kif(op.qint)
                    buf[i] = _quantize(v, _k, _i, _f, round_mode='TRN')
                case 4:  # const addition
                    bias = op.data * op.qint.step
                    buf[i] = buf[op.id0] + bias
                case 5:
                    buf[i] = op.data * op.qint.step  # const definition
                case _:
                    raise ValueError(f'Unknown opcode {op.opcode} in {op}')

        sf = 2.0 ** np.array(self.out_shifts)
        sign = np.where(self.out_negs, -1, 1)
        out_idx = np.array(self.out_idxs)
        mask = np.where(out_idx < 0, 0, 1)
        if debug:
            for i, v in enumerate(buf):
                op = self.ops[i]
                match op.opcode:
                    case -1:
                        op_str = 'inp'
                    case 0:
                        op_str = f'buf[{op.id0}] + buf[{op.id1}]<<{op.data}'
                    case 1:
                        op_str = f'buf[{op.id0}] - buf[{op.id1}]<<{op.data}'
                    case 2:
                        op_str = f'relu(buf[{op.id0}])'
                    case -2:
                        op_str = f'relu(-buf[{op.id0}])'
                    case 3:
                        op_str = f'quantize(buf[{op.id0}])'
                    case -3:
                        op_str = f'quantize(-buf[{op.id0}])'
                    case 4:
                        op_str = f'buf[{op.id0}] + {op.data * op.qint.step}'
                    case 5:
                        op_str = f'const {op.data * op.qint.step}'
                    case _:
                        raise ValueError(f'Unknown opcode {op.opcode} in {op}')

                print(f'{op_str:24} |-> buf[{i}] = {v}')

        if dump:
            return buf
        return buf[out_idx] * sf * sign * mask

    @property
    def kernel(self):
        """The kernel represented by the solution, when applicable."""
        kernel = np.empty(self.shape, dtype=np.float32)
        for i, one_hot in enumerate(np.identity(self.shape[0])):
            kernel[i] = self(one_hot)
        return kernel

    @property
    def cost(self):
        """Total cost of the solution."""
        return float(sum(op.cost for op in self.ops))

    @property
    def latency(self):
        """Minimum and maximum latency of the solution."""
        latency = [self.ops[i].latency for i in self.out_idxs]
        if len(latency) == 0:
            return 0.0, 0.0
        return min(latency), max(latency)

    def __repr__(self):
        n_in, n_out = self.shape
        cost = self.cost
        lat_min, lat_max = self.latency
        return f'Solution([{n_in} -> {n_out}], cost={cost}, latency={lat_min}-{lat_max})'

    @property
    def out_latency(self):
        """Latencies of all output elements of the solution."""
        return [self.ops[i].latency if i >= 0 else 0.0 for i in self.out_idxs]

    @property
    def out_qint(self):
        """Quantization intervals of the output elements."""
        buf = []
        for i, idx in enumerate(self.out_idxs):
            _min, _max, _step = self.ops[idx].qint
            sf = 2.0 ** self.out_shifts[i]
            _min, _max, _step = _min * sf, _max * sf, _step * sf
            if self.out_negs[i]:
                _min, _max = -_max, -_min
            buf.append(QInterval(_min, _max, _step))
        return buf

    @property
    def inp_latency(self):
        """Latencies of all input elements of the solution."""
        return [op.latency for op in self.ops if op.opcode == -1]

    @property
    def inp_qint(self):
        """Quantization intervals of the input elements."""
        return [op.qint for op in self.ops if op.opcode == -1]

    def save(self, path: str | Path):
        """Save the solution to a file."""
        with open(path, 'w') as f:
            json.dump(self, f)

    @classmethod
    def deserialize(cls, data: list):
        """Reconstruct the solution from deserialized JSON data."""
        ops = []
        for _op in data[5]:
            op = Op(*_op[:4], QInterval(*_op[4]), *_op[5:])  # type: ignore
            ops.append(op)
        return cls(
            shape=tuple(data[0]),
            inp_shift=data[1],
            out_idxs=data[2],
            out_shifts=data[3],
            out_negs=data[4],
            ops=ops,
            carry_size=data[6],
            adder_size=data[7],
        )

    @classmethod
    def load(cls, path: str | Path):
        """Load the solution from a file."""
        with open(path) as f:
            data = json.load(f)
        return cls.deserialize(data)


class CascadedSolution(NamedTuple):
    """A solution that implements cascaded matrix-vector multiplications through multiple CMVM stages.

    CascadedSolution represents a sequence of Solution objects where the output of each stage
    is fed as input to the next stage.

    Attributes
    ----------
    solutions: tuple[Solution, ...]
        A tuple containing the individual Solution objects for each stage of the cascade.

    Properties
    ----------
    kernel: NDArray[float32]
        The overall kernel matrix which the cascaded solution implements: vec @ kernel = solution(vec).
        This is calculated as the matrix product of all individual solution kernels.
    cost: float
        The total cost of the cascaded solution, computed as the sum of the costs of all stages.
    latency: tuple[float, float]
        The minimum and maximum latency of the cascaded solution.
    inp_qint: list[QInterval]
        Input quantization intervals
    inp_latency: list[float]
        Input latencies
    inp_shift: list[int]
        Input shifts
    out_qint: list[QInterval]
        Output quantization intervals
    out_latencies: list[float]
        Output latencies
    out_shift: list[int]
        Output shifts
    out_neg: list[bool]
        Output signs
    shape: tuple[int, int]
        The shape of the corresponding kernel matrix.
    """

    solutions: tuple[Solution, ...]

    def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False):
        out = np.asarray(inp)
        for sol in self.solutions:
            out = sol(out, quantize=quantize, debug=debug)
        return out

    @property
    def kernel(self):
        return reduce(lambda x, y: x @ y, [sol.kernel for sol in self.solutions])

    @property
    def cost(self):
        return sum(sol.cost for sol in self.solutions)

    @property
    def latency(self):
        return self.solutions[-1].latency

    @property
    def inp_qint(self):
        return self.solutions[0].inp_qint

    @property
    def inp_latency(self):
        return self.solutions[0].inp_latency

    @property
    def out_qint(self):
        return self.solutions[-1].out_qint

    @property
    def out_latencies(self):
        return self.solutions[-1].out_latency

    @property
    def shape(self):
        return self.solutions[0].shape[0], self.solutions[-1].shape[1]

    @property
    def inp_shift(self):
        return self.solutions[0].inp_shift

    @property
    def out_shift(self):
        return self.solutions[-1].out_shifts

    @property
    def out_neg(self):
        return self.solutions[-1].out_negs

    def __repr__(self) -> str:
        n_ins = [sol.shape[0] for sol in self.solutions] + [self.shape[1]]
        shape_str = ' -> '.join(map(str, n_ins))
        _cost = self.cost
        lat_min, lat_max = self.latency
        return f'CascadedSolution([{shape_str}], cost={_cost}, latency={lat_min}-{lat_max})'

    def save(self, path: str | Path):
        """Save the solution to a file."""
        with open(path, 'w') as f:
            json.dump(self, f)

    @classmethod
    def deserialize(cls, data: list):
        """Reconstruct the solution from deserialized JSON data."""
        return cls(solutions=tuple(Solution.deserialize(sol) for sol in data[0]))

    @classmethod
    def load(cls, path: str | Path):
        """Load the solution from a file."""
        with open(path) as f:
            data = json.load(f)
        return cls.deserialize(data)

    @property
    def reg_bits(self):
        """The number of bits used for the register in the solution."""
        bits = sum(map(sum, (_minimal_kif(qint) for qint in self.inp_qint)))
        for _sol in self.solutions:
            kifs = [_minimal_kif(qint) for qint in _sol.out_qint]
            _bits = sum(map(sum, kifs))
            bits += _bits
        return bits
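
The QInterval/Precision pair above is designed to round-trip: minimal_kif recovers the (keep_negative, integers, fractional) triple of the interval that from_kif produces. A minimal sketch of the intended semantics (not part of the diff; assumes da4ml 0.2.1 with the layout shown above):

from da4ml.cmvm.types import Precision, QInterval

qi = QInterval.from_kif(k=True, i=3, f=2)  # signed, 3 integer + 2 fractional bits
print(qi)                              # [-8.0, 7.75, 0.25]
prec = qi.precision                    # Precision(keep_negative=True, integers=3, fractional=2)
print(prec)                            # fixed(True, 6, 4): 6 total bits, 4 integer bits incl. sign
print(QInterval.from_precision(prec))  # [-8.0, 7.75, 0.25] again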
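The Op opcodes and the Solution buffer semantics documented above can be exercised by hand. A sketch (not part of the diff) of a 2-input, 2-output solution computing [x0 + x1, x0 - x1]; the carry_size/adder_size values are placeholders, and real solutions come out of the cmvm API rather than being written out like this:

import numpy as np

from da4ml.cmvm.types import Op, QInterval, Solution

qi = QInterval(-8.0, 7.75, 0.25)
ops = [
    Op(0, -1, -1, 0, qi, 0.0, 0.0),  # opcode -1: buf[0] = inp[0]
    Op(1, -1, -1, 0, qi, 0.0, 0.0),  # opcode -1: buf[1] = inp[1]
    Op(0, 1, 0, 0, QInterval(-16.0, 15.5, 0.25), 1.0, 1.0),    # opcode 0: buf[2] = buf[0] + buf[1]
    Op(0, 1, 1, 0, QInterval(-15.75, 15.75, 0.25), 1.0, 1.0),  # opcode 1: buf[3] = buf[0] - buf[1]
]
sol = Solution(
    shape=(2, 2),
    inp_shift=[0, 0],
    out_idxs=[2, 3],
    out_shifts=[0, 0],
    out_negs=[False, False],
    ops=ops,
    carry_size=32,  # placeholder
    adder_size=2,   # placeholder
)
print(sol([1.5, 0.5]))  # [2.0 1.0]
print(sol.kernel)       # [[1. 1.] [1. -1.]]; sol(v) == v @ sol.kernel
print(sol)              # Solution([2 -> 2], cost=2.0, latency=1.0-1.0)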
da4ml/cmvm/util/bit_decompose.py
ADDED
@@ -0,0 +1,86 @@
import numpy as np
from numba import jit
from numpy.typing import NDArray


@jit
def _volatile_int_arr_to_csd(x: NDArray) -> NDArray[np.int8]:
    # NB: x is consumed (modified in place), hence "volatile"
    N = np.max(np.ceil(np.log2(np.abs(x) * 1.5 + 1e-19)))
    N = int(max(N, 1))
    buf = np.zeros((*np.shape(x), N), dtype=np.int8)

    for n in range(N - 1, -1, -1):
        _2pn = 2**n
        thres = _2pn / 1.5
        bit = (x > thres).astype(np.int8)
        bit -= (x < -thres).astype(np.int8)
        x -= _2pn * bit
        buf[..., n] = bit
    return buf


@jit(error_model='numpy')
def _shift_centering(arr: NDArray):
    low, high = -64, 64
    if np.all(arr == 0):
        high = low = 0
    while high - low > 1:
        mid = (high + low) // 2
        xs = arr * (2.0**mid)
        if np.all(xs == np.floor(xs)):
            high = mid
        else:
            low = mid
    return -high


@jit(error_model='numpy')
def shift_centering(arr: NDArray, axis: int):
    n = arr.shape[axis]
    shifts = np.empty(n, dtype=np.int8)
    for i in range(n):
        shifts[i] = _shift_centering(arr.take(i, axis=axis))
    return shifts


@jit
def _center(arr: NDArray):
    shift1 = shift_centering(arr, 1)  # d_out
    arr = arr * (2.0**-shift1)
    shift0 = shift_centering(arr, 0)  # d_in
    arr = arr * (2.0 ** -shift0[:, None])
    return arr, shift0, shift1


@jit
def csd_decompose(arr: NDArray, center=True):
    """
    Convert a 2D array to CSD representation.

    Parameters
    ----------
    arr : ndarray
        Input array to be converted.
    center : bool, optional
        If True, the array is centered before conversion. Default is True.
        If False, the function may accept non-2D arrays.

    Returns
    -------
    csd : ndarray
        CSD representation of the input array after centering, if center is True.
    shift0 : ndarray
        Shift values for the first axis.
    shift1 : ndarray
        Shift values for the second axis.
    """

    if center:
        arr, shift0, shift1 = _center(arr)
    else:
        shift0 = np.zeros(arr.shape[0], dtype=np.int8)
        shift1 = np.zeros(arr.shape[1], dtype=np.int8)
        arr = arr.copy()
    csd = _volatile_int_arr_to_csd(arr)
    return csd, shift0, shift1
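
csd_decompose above factors a matrix into canonical signed digits (in {-1, 0, +1}, digit n weighted by 2**n) plus per-row (shift0) and per-column (shift1) power-of-two shifts. A quick reconstruction check (not part of the diff; assumes the jitted helpers compile under the installed numba version):

import numpy as np

from da4ml.cmvm.util.bit_decompose import csd_decompose

arr = np.array([[6.0, -3.0], [1.5, 10.0]])
csd, shift0, shift1 = csd_decompose(arr.copy())  # copy: the CSD step consumes its input
weights = 2.0 ** np.arange(csd.shape[-1])   # digit n carries weight 2**n
recon = (csd * weights).sum(axis=-1)        # centered matrix
recon = recon * 2.0 ** shift0[:, None] * 2.0 ** shift1[None, :]
assert np.allclose(recon, arr)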