da4ml-0.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +194 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +240 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.0.dist-info/METADATA +85 -0
- da4ml-0.5.0.dist-info/RECORD +96 -0
- da4ml-0.5.0.dist-info/WHEEL +6 -0
- da4ml-0.5.0.dist-info/entry_points.txt +3 -0
- da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0

da4ml/trace/ops/einsum_utils.py
ADDED
@@ -0,0 +1,305 @@
+from math import prod
+from typing import TYPE_CHECKING, TypedDict, overload
+
+import numpy as np
+from numpy.typing import NDArray
+
+if TYPE_CHECKING:
+    from ..fixed_variable_array import FixedVariableArray
+
+
+class EinsumRecipe(TypedDict):
+    direct_sum_axis: tuple[tuple[int, ...], tuple[int, ...]]
+    in_transpose_idxs: tuple[tuple[int, ...], tuple[int, ...]]
+    L0: int
+    L1: int
+    I: int
+    C: int
+    out_interpert_shape: tuple[int, ...]
+    out_transpose_idxs: tuple[int, ...]
+
+
+def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]):
+    """Validate, resolve broadcasting, and compute output shape for einsum string
+
+    Parameters
+    ----------
+    fn : str
+        einsum string, e.g. 'ij,jk->ik'
+    shape0 : tuple[int,...]
+        shape of input0
+    shape1 : tuple[int,...]
+        shape of input1
+
+    Returns
+    -------
+    tuple[str, tuple[int,...]]
+        einsum string w/o broadcasting, and output shape
+
+    Raises
+    ------
+    ValueError
+        If the einsum string is invalid, or if it is incompatible with the input shapes
+    """
+    inp, out = map(str.strip, fn.split('->'))
+    in0, in1 = map(str.strip, inp.split(','))
+    alphabets = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+    s_alphabets = set(alphabets)
+
+    # Invalid characters
+    if not (s_alphabets >= set(in0.replace('...', '') + in1.replace('...', '') + out.replace('...', ''))):
+        raise ValueError(f"einsum string {fn} is invalid: subscripts should be in [a-zA-Z] and '...' only")
+
+    in0 = in0.replace('...', '0')
+    in1 = in1.replace('...', '0')
+    out = out.replace('...', '0')
+    ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out)
+    sax_in0, sax_in1, sax_out = set(ax_in0), set(ax_in1), set(ax_out)
+    free_indices = ''.join(sorted(s_alphabets - sax_in0 - sax_in1 - sax_out))
+
+    # Repeated indices
+    if len(sax_in0) != len(ax_in0):
+        for a in in0:
+            if in0.count(a) == 1:
+                continue
+            a = a if a != '0' else '...'
+            raise ValueError(f"einsum string {fn} is invalid: input0 subscripts includes '{a}' multiple times")
+    if len(sax_in1) != len(ax_in1):
+        for a in in1:
+            if in1.count(a) == 1:
+                continue
+            a = a if a != '0' else '...'
+            raise ValueError(f"einsum string {fn} is invalid: input1 subscripts includes '{a}' multiple times")
+    if len(sax_out) != len(ax_out):
+        for a in out:
+            if out.count(a) == 1:
+                continue
+            a = a if a != '0' else '...'
+            raise ValueError(f"einsum string {fn} is invalid: output subscripts includes '{a}' multiple times")
+
+    # Invalid broadcasting
+    if '0' in sax_in0 or '0' in sax_in1 or '0' in sax_out:
+        if '0' not in sax_out:
+            raise ValueError(f'einsum string {fn} is invalid: output does not allow broadcasting, but inputs do')
+        if '0' not in sax_in0 and '0' not in sax_in1:
+            raise ValueError(f'einsum string {fn} is invalid: output allows broadcasting, but inputs do not')
+
+    # Output index out of nowhere
+    if remaining := sax_out - sax_in0 - sax_in1:
+        raise ValueError(f'einsum string {fn} is invalid: output subscripts {remaining} not found in inputs')
+
+    _common_in = sax_in0 & sax_in1
+
+    if '0' in sax_in0 and '0' in sax_in1:
+        # Simultaneous axes expansion in both inputs
+        n_boardcast0 = len(shape0) - len(sax_in0) + 1
+        n_boardcast1 = len(shape1) - len(sax_in1) + 1
+        assert n_boardcast0 == n_boardcast1, f'... expands to {n_boardcast0} and {n_boardcast1}-axis in input0 and input1.'
+        # Replace expansion indices with free indices
+        in0 = in0.replace('0', free_indices[:n_boardcast0])
+        in1 = in1.replace('0', free_indices[:n_boardcast1])
+        out = out.replace('0', free_indices[:n_boardcast0])
+        ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out)
+        _common_in = set(ax_in0) & set(ax_in1)
+
+    else:
+        # Axes expansion in input0 or input1 only
+        if '0' in sax_in0:
+            if len(sax_in0) - 1 > len(shape0):
+                raise ValueError(f'Input0 requires at least {len(sax_in0) - 1} dimensions, but only {len(shape0)} given')
+            # Replace auto expansion indices with free indices
+            n_broadcast = len(shape0) - len(sax_in0) + 1
+            in0 = in0.replace('0', free_indices[:n_broadcast])
+            out = out.replace('0', free_indices[:n_broadcast])
+            ax_in0 = list(in0)
+            ax_out = list(out)
+        else:
+            if len(sax_in0) != len(shape0):
+                raise ValueError(f'Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given')
+
+        if '0' in sax_in1:
+            if len(sax_in1) - 1 > len(shape1):
+                raise ValueError(f'Input1 requires at least {len(sax_in1) - 1} dimensions, but only {len(shape1)} given')
+            # Replace expansion indices with free indices
+            n_broadcast = len(shape1) - len(sax_in1) + 1
+            in1 = in1.replace('0', free_indices[:n_broadcast])
+            out = out.replace('0', free_indices[:n_broadcast])
+            ax_in1 = list(in1)
+            ax_out = list(out)
+        else:
+            if len(sax_in1) != len(shape1):
+                raise ValueError(f'Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given')
+
+    # Input dimension mismatch
+    for a in _common_in:
+        ax_0 = ax_in0.index(a)
+        ax_1 = ax_in1.index(a)
+        if shape0[ax_0] != shape1[ax_1]:
+            raise ValueError(f"Input dimension size mismatches for common subscript '{a}': {shape0[ax_0]} and {shape1[ax_1]}")
+
+    out_shape = tuple(shape0[ax_in0.index(a)] if a in ax_in0 else shape1[ax_in1.index(a)] for a in ax_out)
+    return f'{in0},{in1}->{out}', out_shape
+
+
+def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe:
+    """Parse einsum operation on two input arrays, return a recipe for execution
+
+    Parameters
+    ----------
+    fn : str
+        einsum string, e.g. 'ij,jk->ik'
+    input_shape0 : tuple[int, ...]
+        shape of the first input array
+    input_shape1 : tuple[int, ...]
+        shape of the second input array
+
+    Returns
+    -------
+    EinsumRecipe
+        einsum recipe; executed by _exec_einsum
+    """
+
+    fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1)
+
+    _in, _out = fn.split('->')
+    _in0, _in1 = _in.split(',')
+
+    in0, in1, out = list(_in0), list(_in1), list(_out)
+    s_in0, s_in1, s_out = set(in0), set(in1), set(out)
+    _common = s_in0 & s_in1
+    _contract = _common - s_out
+    _inplace = _common & s_out
+    contract = sorted(_contract, key=lambda x: in1.index(x))
+    inplace = sorted(_inplace, key=lambda x: in1.index(x))
+    invariant0 = sorted((s_out - _common) & s_in0, key=lambda x: in0.index(x))
+    invariant1 = sorted((s_out - _common) & s_in1, key=lambda x: in1.index(x))
+    direct_sum0 = s_in0 - s_out - _common
+    direct_sum1 = s_in1 - s_out - _common
+    direct_sum_axis = (
+        tuple(sorted(in0.index(x) for x in direct_sum0)),
+        tuple(sorted(in1.index(x) for x in direct_sum1)),
+    )
+
+    contract_idxs = tuple(map(in0.index, contract)), tuple(map(in1.index, contract))
+    inplace_idxs = tuple(map(in0.index, inplace)), tuple(map(in1.index, inplace))
+    invariant_idxs = tuple(map(in0.index, invariant0)), tuple(map(in1.index, invariant1))
+
+    inplace_shape = tuple(input_shape0[i] for i in inplace_idxs[0])
+    inplace_size = prod(inplace_shape)
+    contract_size = prod(input_shape0[i] for i in contract_idxs[0])
+    invariant_shape0 = tuple(input_shape0[i] for i in invariant_idxs[0])
+    invariant_shape1 = tuple(input_shape1[i] for i in invariant_idxs[1])
+    invariant_size0, invariant_size1 = prod(invariant_shape0), prod(invariant_shape1)
+
+    transpose_idx0 = inplace_idxs[0] + invariant_idxs[0] + contract_idxs[0]
+    transpose_idx1 = inplace_idxs[1] + invariant_idxs[1] + contract_idxs[1]
+
+    out_shape_pretranspose = inplace_shape + invariant_shape0 + invariant_shape1
+    _out_transpose_idx = np.argsort(tuple(map(out.index, inplace + invariant0 + invariant1)))
+    out_transpose_idx = tuple(int(i) for i in _out_transpose_idx)
+
+    return EinsumRecipe(
+        direct_sum_axis=direct_sum_axis,
+        in_transpose_idxs=(transpose_idx0, transpose_idx1),
+        out_interpert_shape=out_shape_pretranspose,
+        out_transpose_idxs=out_transpose_idx,
+        L0=invariant_size0,
+        L1=invariant_size1,
+        I=inplace_size,
+        C=contract_size,
+    )
+
+
+def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray:
+    """Execute einsum operation on two input arrays
+
+    Parameters
+    ----------
+    recipe : EinsumRecipe
+        einsum recipe
+    input0 : np.ndarray
+        input0, the first input array
+    input1 : np.ndarray
+        input1, the second input array
+
+    Returns
+    -------
+    np.ndarray
+        output array
+    """
+    sum_axis0, sum_axis1 = recipe['direct_sum_axis']
+    if sum_axis0:
+        input0 = np.sum(input0, axis=sum_axis0)
+    if sum_axis1:
+        input1 = np.sum(input1, axis=sum_axis1)
+    input0 = input0.transpose(recipe['in_transpose_idxs'][0]).ravel()
+    input1 = input1.transpose(recipe['in_transpose_idxs'][1]).ravel()
+    out_dtype = object if input0.dtype == object or input1.dtype == object else np.float64
+    output = np.zeros(recipe['L0'] * recipe['L1'] * recipe['I'], dtype=out_dtype)
+
+    L0, L1, I, C = recipe['L0'], recipe['L1'], recipe['I'], recipe['C']
+
+    for l0 in range(L0):
+        for i in range(I):
+            A = input1[i * L1 * C : (i + 1) * L1 * C].reshape((L1, C))
+            B = input0[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C]
+            output[(i * L0 + l0) * L1 : (i * L0 + l0 + 1) * L1] = A @ B
+
+    return output.reshape(recipe['out_interpert_shape']).transpose(recipe['out_transpose_idxs'])
+
+
+def _einsum(fn: str, input0, input1) -> np.ndarray:
+    """Execute einsum operation on two input arrays.
+
+    WARNING: Order of multiplication is reversed -- watch out if you are using non-commutative operators
+
+    Parameters
+    ----------
+    fn : str
+        einsum string, e.g. 'ij,jk->ik'
+    input0 : np.ndarray
+        input0, the first input array
+    input1 : np.ndarray
+        input1, the second input array
+
+    Returns
+    -------
+    np.ndarray
+        output array
+    """
+    recipe = parse_einsum(fn, input0.shape, input1.shape)
+    return _exec_einsum(recipe, input0, input1)
+
+
+@overload
+def einsum(fn: str, input0: 'FixedVariableArray', input1: 'FixedVariableArray') -> 'FixedVariableArray': ...
+
+
+@overload
+def einsum(fn: str, input0: 'FixedVariableArray', input1: NDArray[np.integer | np.floating]) -> 'FixedVariableArray': ...
+
+
+@overload
+def einsum(fn: str, input0: NDArray[np.integer | np.floating], input1: 'FixedVariableArray') -> 'FixedVariableArray': ...
+
+
+@overload
+def einsum(
+    fn: str, input0: NDArray[np.integer | np.floating], input1: NDArray[np.integer | np.floating]
+) -> NDArray[np.integer | np.floating]: ...
+
+
+def einsum(fn: str, input0, input1):
+    from ..fixed_variable_array import FixedVariableArray
+
+    fg0 = isinstance(input0, FixedVariableArray)
+    fg1 = isinstance(input1, FixedVariableArray)
+
+    r = _einsum(fn, input0, input1)
+
+    if fg0:
+        return FixedVariableArray(r, input0.solver_options)
+    elif fg1:
+        return FixedVariableArray(r, input1.solver_options)
+    else:
+        return r
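
For plain numeric inputs the recipe machinery above is just a re-laid-out matrix product, so it can be sanity-checked against np.einsum. A minimal sketch, assuming the wheel is installed so that da4ml.trace.ops.einsum_utils (listed in the RECORD above) is importable:

    import numpy as np

    from da4ml.trace.ops.einsum_utils import _einsum, parse_einsum

    a, b = np.random.rand(3, 4), np.random.rand(4, 5)

    # 'j' is contracted (C=4); 'i' and 'k' are invariant axes of input0 and
    # input1 (L0=3, L1=5); there are no shared output ('inplace') axes, so I=1.
    recipe = parse_einsum('ij,jk->ik', a.shape, b.shape)
    assert (recipe['L0'], recipe['L1'], recipe['I'], recipe['C']) == (3, 5, 1, 4)

    # The recipe-driven execution agrees with np.einsum for commutative inputs.
    np.testing.assert_allclose(_einsum('ij,jk->ik', a, b), np.einsum('ij,jk->ik', a, b))
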
da4ml/trace/ops/quantization.py
ADDED
@@ -0,0 +1,74 @@
+from typing import TYPE_CHECKING, TypeVar
+
+import numpy as np
+from numpy.typing import NDArray
+from quantizers.fixed_point.fixed_point_ops_np import get_fixed_quantizer_np
+
+from ..fixed_variable_array import FixedVariable
+
+if TYPE_CHECKING:
+    from ..fixed_variable_array import FixedVariableArray
+
+T = TypeVar('T', 'FixedVariableArray', NDArray[np.floating], list[FixedVariable])
+
+
+def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN') -> T:
+    from ..fixed_variable_array import FixedVariableArray
+
+    if isinstance(x, FixedVariableArray):
+        return x.relu(i=i, f=f, round_mode=round_mode)
+    elif isinstance(x, list):
+        return [xx.relu(i=ii, f=ff, round_mode=round_mode) for xx, ii, ff in zip(x, i, f)]  # type: ignore
+    else:
+        round_mode = round_mode.upper()
+        assert round_mode in ('TRN', 'RND')
+        x = np.maximum(x, 0)
+        if f is not None:
+            if round_mode == 'RND':
+                x += 2.0 ** (-f - 1)
+            sf = 2.0**f
+            x = np.floor(x * sf) / sf
+        if i is not None:
+            x = x % 2.0**i
+        return x
+
+
+def _quantize(
+    x: NDArray[np.floating],
+    k: NDArray[np.integer] | np.integer | int,
+    i: NDArray[np.integer] | np.integer | int,
+    f: NDArray[np.integer] | np.integer | int,
+    overflow_mode: str = 'WRAP',
+    round_mode: str = 'TRN',
+) -> NDArray[np.floating]:
+    q = get_fixed_quantizer_np(round_mode=round_mode, overflow_mode=overflow_mode)
+    return np.where(k + i + f <= 0, 0, q(x, k=k, i=i, f=f))  # type: ignore
+
+
+def quantize(
+    x: T,
+    k: NDArray[np.integer] | np.integer | int,
+    i: NDArray[np.integer] | np.integer | int,
+    f: NDArray[np.integer] | np.integer | int,
+    overflow_mode: str = 'WRAP',
+    round_mode: str = 'TRN',
+) -> T:
+    from ..fixed_variable_array import FixedVariableArray
+
+    if isinstance(x, (FixedVariableArray, FixedVariable)):
+        return x.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode)
+    elif isinstance(x, list):
+        ret: list[FixedVariable] = []
+        for idx in range(len(x)):  # 'idx', not 'i': 'i' already names the integer-bits argument
+            ret.append(
+                x[idx].quantize(
+                    k=int(k[idx] if isinstance(k, (list, np.ndarray)) else k),
+                    i=int(i[idx] if isinstance(i, (list, np.ndarray)) else i),
+                    f=int(f[idx] if isinstance(f, (list, np.ndarray)) else f),
+                    overflow_mode=overflow_mode,
+                    round_mode=round_mode,
+                )
+            )
+        return ret  # type: ignore
+    else:
+        return _quantize(x, k, i, f, overflow_mode, round_mode)
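
The ndarray branch of relu above implements the fixed-point semantics inline: TRN truncates toward negative infinity after scaling by 2**f, RND first adds half an LSB, and a finite i wraps the result modulo 2**i. A small numeric check (same installation assumption as above):

    import numpy as np

    from da4ml.trace.ops.quantization import relu

    x = np.array([-1.2, 0.3, 1.7])
    f = np.array([1, 1, 1])  # keep one fractional bit

    print(relu(x, f=f))                    # TRN: floor to 0.5 steps   -> [0.  0.  1.5]
    print(relu(x, f=f, round_mode='RND'))  # RND: add 2**-2, then floor -> [0.  0.5 1.5]
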
da4ml/trace/ops/reduce_utils.py
ADDED
@@ -0,0 +1,105 @@
+import heapq
+import typing
+from collections.abc import Callable, Sequence
+from math import prod
+from typing import TypeVar
+
+import numpy as np
+from numpy.typing import NDArray
+
+if typing.TYPE_CHECKING:
+    from ..fixed_variable import FixedVariable
+    from ..fixed_variable_array import FixedVariableArray
+
+
+T = typing.TypeVar('T', 'FixedVariable', float, np.floating)
+TA = TypeVar('TA', 'FixedVariableArray', NDArray[np.integer | np.floating])
+
+
+class Packet:
+    def __init__(self, v):
+        self.value = v
+
+    def __gt__(self, other: 'Packet') -> bool:  # type: ignore
+        from ..fixed_variable_array import FixedVariable
+
+        a, b = self.value, other.value
+
+        if isinstance(a, FixedVariable):
+            if isinstance(b, FixedVariable):
+                if b.latency > a.latency:
+                    return False
+                if b.latency < a.latency:
+                    return True
+                if b._factor > 0 and a._factor < 0:
+                    return False
+                if b._factor < 0 and a._factor > 0:
+                    return True
+                return sum(a.kif[:2]) > sum(b.kif[:2])
+            return True
+
+        return False
+
+    def __lt__(self, other: 'Packet') -> bool:  # type: ignore
+        return not self.__gt__(other)
+
+
+def _reduce(operator: Callable[[T, T], T], arr: Sequence[T]) -> T:
+    from ..fixed_variable_array import FixedVariable
+
+    if isinstance(arr, np.ndarray):
+        arr = list(arr.ravel())
+    assert len(arr) > 0, 'Array must not be empty'
+    if len(arr) == 1:
+        return arr[0]
+    dtype = arr[0].__class__
+    if not issubclass(dtype, FixedVariable):
+        r = operator(arr[0], arr[1])
+        for i in range(2, len(arr)):
+            r = operator(r, arr[i])
+        return r
+
+    heap = [Packet(v) for v in arr]  # type: ignore
+    heapq.heapify(heap)
+    while len(heap) > 1:
+        v1 = heapq.heappop(heap).value
+        v2 = heapq.heappop(heap).value
+        v = operator(v1, v2)
+        heapq.heappush(heap, Packet(v))  # type: ignore
+    return heap[0].value
+
+
+def reduce(operator: Callable[[T, T], T], x: TA, axis: int | Sequence[int] | None = None, keepdims: bool = False) -> TA:
+    """
+    Reduce the array by folding `operator` over the specified axis.
+    """
+    from ..fixed_variable_array import FixedVariableArray
+
+    if isinstance(x, FixedVariableArray):
+        solver_config = x.solver_options
+        arr = x._vars
+    else:
+        solver_config = None
+        arr = x
+    all_axis = tuple(range(arr.ndim))
+    axis = axis if axis is not None else all_axis
+    axis = (axis,) if isinstance(axis, int) else tuple(axis)
+    axis = tuple(a if a >= 0 else a + arr.ndim for a in axis)
+
+    xpose_axis = sorted(all_axis, key=lambda a: (a in axis) * 1000 + a)
+    if keepdims:
+        target_shape = tuple(d if ax not in axis else 1 for ax, d in enumerate(arr.shape))
+    else:
+        target_shape = tuple(d for ax, d in enumerate(arr.shape) if ax not in axis)
+
+    dim_contract = prod(arr.shape[a] for a in axis)
+    arr = np.transpose(arr, xpose_axis)  # type: ignore
+    _arr = arr.reshape(-1, dim_contract)
+    _arr = np.array([_reduce(operator, _arr[i]) for i in range(_arr.shape[0])])
+    r = _arr.reshape(target_shape)  # type: ignore
+
+    if isinstance(x, FixedVariableArray):
+        r = FixedVariableArray(r, solver_config)
+        if r.shape == ():
+            return r._vars.item()  # type: ignore
+    return r if r.shape != () or keepdims else r.item()  # type: ignore
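
On plain ndarrays, reduce above matches the corresponding NumPy reduction; with FixedVariable elements, _reduce instead pops the two lowest-latency operands from the heap each step, so the reduction tree stays latency-balanced. A quick check of the ndarray path (same installation assumption as above):

    import operator

    import numpy as np

    from da4ml.trace.ops.reduce_utils import reduce

    x = np.arange(12, dtype=np.float64).reshape(3, 4)
    print(reduce(operator.add, x, axis=1))                       # [ 6. 22. 38.], i.e. x.sum(axis=1)
    print(reduce(operator.add, x, axis=1, keepdims=True).shape)  # (3, 1)
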
da4ml/trace/pipeline.py
ADDED
@@ -0,0 +1,181 @@
+from math import ceil, floor
+
+from ..cmvm.types import CombLogic, Op, Pipeline
+from .fixed_variable import FixedVariable, HWConfig
+from .tracer import comb_trace
+
+
+def retime_pipeline(csol: Pipeline, verbose=True):
+    n_stages = len(csol[0])
+    cutoff_high = ceil(max(max(sol.out_latency) / (i + 1) for i, sol in enumerate(csol[0])))
+    cutoff_low = 0
+    adder_size, carry_size = csol[0][0].adder_size, csol[0][0].carry_size
+    best = csol
+    while cutoff_high - cutoff_low > 1:
+        cutoff = (cutoff_high + cutoff_low) // 2
+        _hwconf = HWConfig(adder_size, carry_size, cutoff)
+        inp = [FixedVariable(*qint, hwconf=_hwconf) for qint in csol.inp_qint]
+        try:
+            out = list(csol(inp))
+        except AssertionError:
+            cutoff_low = cutoff
+            continue
+        _sol = to_pipeline(comb_trace(inp, out), cutoff, retiming=False)
+        if len(_sol[0]) > n_stages:
+            cutoff_low = cutoff
+        else:
+            cutoff_high = cutoff
+            best = _sol
+    if verbose:
+        print(f'actual cutoff: {cutoff_high}')
+    return best
+
+
+def _get_new_idx(
+    idx: int,
+    locator: list[dict[int, int]],
+    opd: dict[int, list[Op]],
+    out_idxd: dict[int, list[int]],
+    ops: list[Op],
+    stage: int,
+    latency_cutoff: float,
+):
+    if idx < 0:
+        return idx
+    p0_stages = locator[idx].keys()
+    if stage not in p0_stages:
+        # Need to copy parent to later stages
+        p0_stage = max(p0_stages)
+        p0_idx = locator[idx][p0_stage]
+        for j in range(p0_stage, stage):
+            op0 = ops[idx]
+            latency = float(latency_cutoff * (j + 1))
+            out_idxd.setdefault(j, []).append(locator[idx][j])
+            _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op0.qint, latency, 0.0)
+            opd.setdefault(j + 1, []).append(_copy_op)
+            p0_idx = len(opd[j + 1]) - 1
+            locator[idx][j + 1] = p0_idx
+    else:
+        p0_idx = locator[idx][stage]
+    return p0_idx
+
+
+def to_pipeline(comb: CombLogic, latency_cutoff: float, retiming=True, verbose=True) -> Pipeline:
+    """Split the record into multiple stages based on the latency of the operations.
+    Only useful for HDL generation.
+
+    Parameters
+    ----------
+    comb : CombLogic
+        The combinational logic to be pipelined into multiple stages.
+    latency_cutoff : float
+        The latency cutoff for splitting the operations.
+    retiming : bool
+        Whether to retime the solution after splitting. Default is True.
+        If False, new stages are created when the propagation latency exceeds the cutoff.
+        If True, after the first round of splitting, the solution is retimed to balance the delay within each stage.
+    verbose : bool
+        Whether to print the actual cutoff used for splitting. Only used if retiming is True.
+        Default is True.
+
+    Returns
+    -------
+    Pipeline
+        The pipelined solution with multiple stages.
+    """
+    assert len(comb.ops) > 0, 'No operations in the record'
+    for i, op in enumerate(comb.ops):
+        if op.id1 != -1:
+            break
+
+    def get_stage(op: Op):
+        return floor(op.latency / (latency_cutoff + 1e-9)) if latency_cutoff > 0 else 0
+
+    opd: dict[int, list[Op]] = {}
+    out_idxd: dict[int, list[int]] = {}
+
+    locator: list[dict[int, int]] = []
+
+    ops = comb.ops.copy()
+    lat = max(ops[i].latency for i in comb.out_idxs)
+    for i in comb.out_idxs:
+        op_out = ops[i]
+        ops.append(Op(i, -1001, -1001, 0, op_out.qint, lat, 0.0))
+
+    for i, op in enumerate(ops):
+        stage = get_stage(op)
+        if op.opcode == -1:
+            # Copy from external buffer
+            opd.setdefault(stage, []).append(op)
+            locator.append({stage: len(opd[stage]) - 1})
+            continue
+
+        p0_idx = _get_new_idx(op.id0, locator, opd, out_idxd, ops, stage, latency_cutoff)
+        p1_idx = _get_new_idx(op.id1, locator, opd, out_idxd, ops, stage, latency_cutoff)
+        if op.opcode in (6, -6):
+            k = op.data & 0xFFFFFFFF
+            _shift = (op.data >> 32) & 0xFFFFFFFF
+            k = _get_new_idx(k, locator, opd, out_idxd, ops, stage, latency_cutoff)
+            data = _shift << 32 | k
+        else:
+            data = op.data
+
+        if p1_idx == -1001:
+            # Output to external buffer
+            out_idxd.setdefault(stage, []).append(p0_idx)
+        else:
+            _Op = Op(p0_idx, p1_idx, op.opcode, data, op.qint, op.latency, op.cost)
+            opd.setdefault(stage, []).append(_Op)
+            locator.append({stage: len(opd[stage]) - 1})
+    sols = []
+    max_stage = max(opd.keys())
+    n_in = comb.shape[0]
+    for i, stage in enumerate(opd.keys()):
+        _ops = opd[stage]
+        _out_idx = out_idxd[stage]
+        n_out = len(_out_idx)
+
+        if i == max_stage:
+            out_shifts = comb.out_shifts
+            out_negs = comb.out_negs
+        else:
+            out_shifts = [0] * len(_out_idx)
+            out_negs = [False] * len(_out_idx)
+
+        if comb.lookup_tables is not None:
+            _ops, lookup_tables = remap_table_idxs(comb, _ops)
+        else:
+            lookup_tables = None
+        _sol = CombLogic(
+            shape=(n_in, n_out),
+            inp_shifts=[0] * n_in,
+            out_idxs=_out_idx,
+            out_shifts=out_shifts,
+            out_negs=out_negs,
+            ops=_ops,
+            carry_size=comb.carry_size,
+            adder_size=comb.adder_size,
+            lookup_tables=lookup_tables,
+        )
+        sols.append(_sol)
+
+        n_in = n_out
+    csol = Pipeline(tuple(sols))
+
+    if retiming:
+        csol = retime_pipeline(csol, verbose=verbose)
+    return csol
+
+
+def remap_table_idxs(comb: CombLogic, _ops):
+    assert comb.lookup_tables is not None
+    table_idxs = sorted(list({op.data for op in _ops if op.opcode == 8}))
+    remap = {j: i for i, j in enumerate(table_idxs)}
+    _ops_remap = []
+    for op in _ops:
+        if op.opcode == 8:
+            op = Op(op.id0, op.id1, op.opcode, remap[op.data], op.qint, op.latency, op.cost)
+        _ops_remap.append(op)
+    _ops = _ops_remap
+    lookup_tables = tuple(comb.lookup_tables[i] for i in table_idxs)
+    return _ops, lookup_tables
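
Stage assignment in to_pipeline is the one-liner in get_stage: floor(op.latency / latency_cutoff), with a 1e-9 epsilon in the denominator so that an op whose latency sits exactly on a stage boundary stays in the earlier stage; operands needed by a later stage are then re-emitted through zero-cost copy ops (opcode -1) in _get_new_idx. A standalone sketch of just the boundary arithmetic (hypothetical latency values, not taken from the package):

    from math import floor

    def get_stage(latency: float, latency_cutoff: float) -> int:
        # mirrors to_pipeline's stage assignment, epsilon guard included
        return floor(latency / (latency_cutoff + 1e-9)) if latency_cutoff > 0 else 0

    for lat in (0.0, 4.9, 5.0, 5.1, 12.5):
        print(lat, '->', get_stage(lat, 5.0))  # stages 0, 0, 0, 1, 2
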