da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +204 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +246 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.1.post1.dist-info/METADATA +85 -0
- da4ml-0.5.1.post1.dist-info/RECORD +96 -0
- da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
- da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
- da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Shift-and-add primitive.
--   Both inputs are sign-extended (SIGNEDx = 1) or zero-extended (SIGNEDx = 0) to an
--   internal width BW_ADD. If SHIFT1 > 0, in1 is shifted left by SHIFT1; if SHIFT1 < 0,
--   in0 is shifted left by -SHIFT1 instead, so neither operand is ever shifted right.
--   The extended operands are then added (IS_SUB = 0) or subtracted (IS_SUB = 1), and
--   result is the low BW_OUT bits of that sum.
entity shift_adder is
    generic (
        BW_INPUT0 : integer := 32;  -- width of in0
        BW_INPUT1 : integer := 32;  -- width of in1
        SIGNED0   : integer := 0;   -- 1: in0 is sign-extended, 0: zero-extended
        SIGNED1   : integer := 0;   -- 1: in1 is sign-extended, 0: zero-extended
        BW_OUT    : integer := 32;  -- width of result (low bits of the internal sum)
        SHIFT1    : integer := 0;   -- left shift of in1 relative to in0 (may be negative)
        IS_SUB    : integer := 0    -- 1: in0 - in1, 0: in0 + in1
    );
    port (
        in0    : in  std_logic_vector(BW_INPUT0-1 downto 0);
        in1    : in  std_logic_vector(BW_INPUT1-1 downto 0);
        result : out std_logic_vector(BW_OUT-1 downto 0)
    );
end entity shift_adder;

architecture rtl of shift_adder is
    function max(L, R: integer) return integer is
    begin
        if L > R then
            return L;
        else
            return R;
        end if;
    end function;

    function if_then_else(cond: boolean; val_true: integer; val_false: integer) return integer is
    begin
        if cond then
            return val_true;
        else
            return val_false;
        end if;
    end function;

    -- Bits each operand occupies after its (possible) left shift.
    constant IN0_NEED_BITS : integer := if_then_else(SHIFT1 < 0, BW_INPUT0 - SHIFT1, BW_INPUT0);
    constant IN1_NEED_BITS : integer := if_then_else(SHIFT1 > 0, BW_INPUT1 + SHIFT1, BW_INPUT1);
    -- Headroom for mixed signedness and for subtraction.
    constant EXTRA_PAD : integer := if_then_else(SIGNED0 /= SIGNED1, IS_SUB + 1, IS_SUB);
    constant BW_ADD : integer := max(IN0_NEED_BITS, IN1_NEED_BITS) + EXTRA_PAD + 1;

    signal in0_ext : std_logic_vector(BW_ADD-1 downto 0);
    signal in1_ext : std_logic_vector(BW_ADD-1 downto 0);
    signal accum : std_logic_vector(BW_ADD-1 downto 0);

begin

    -- Extension and shifting for input 0.
    -- numeric_std shift_left is used instead of "sll" on std_logic_vector, because the
    -- latter only exists from VHDL-2008 on. Both fill vacated bits with '0'.
    gen_in0_shift_neg: if SHIFT1 < 0 generate
        gen_in0_signed: if SIGNED0 = 1 generate
            in0_ext <= std_logic_vector(shift_left(resize(signed(in0), BW_ADD), -SHIFT1));
        end generate;
        gen_in0_unsigned: if SIGNED0 = 0 generate
            in0_ext <= std_logic_vector(shift_left(resize(unsigned(in0), BW_ADD), -SHIFT1));
        end generate;
    end generate;

    gen_in0_shift_pos: if SHIFT1 >= 0 generate
        gen_in0_signed: if SIGNED0 = 1 generate
            in0_ext <= std_logic_vector(resize(signed(in0), BW_ADD));
        end generate;
        gen_in0_unsigned: if SIGNED0 = 0 generate
            in0_ext <= std_logic_vector(resize(unsigned(in0), BW_ADD));
        end generate;
    end generate;

    -- Extension and shifting for input 1.
    gen_in1_shift_pos: if SHIFT1 > 0 generate
        gen_in1_signed: if SIGNED1 = 1 generate
            in1_ext <= std_logic_vector(shift_left(resize(signed(in1), BW_ADD), SHIFT1));
        end generate;
        gen_in1_unsigned: if SIGNED1 = 0 generate
            in1_ext <= std_logic_vector(shift_left(resize(unsigned(in1), BW_ADD), SHIFT1));
        end generate;
    end generate;

    gen_in1_shift_neg: if SHIFT1 <= 0 generate
        gen_in1_signed: if SIGNED1 = 1 generate
            in1_ext <= std_logic_vector(resize(signed(in1), BW_ADD));
        end generate;
        gen_in1_unsigned: if SIGNED1 = 0 generate
            in1_ext <= std_logic_vector(resize(unsigned(in1), BW_ADD));
        end generate;
    end generate;

    -- Addition/subtraction logic. Both operands are already extended according to their
    -- own signedness, so a signed add/sub on BW_ADD bits gives correct low-order bits.
    gen_sub: if IS_SUB = 1 generate
        accum <= std_logic_vector(signed(in0_ext) - signed(in1_ext));
    end generate;

    gen_add: if IS_SUB = 0 generate
        accum <= std_logic_vector(signed(in0_ext) + signed(in1_ext));
    end generate;

    -- Truncate to the requested output width.
    result <= accum(BW_OUT-1 downto 0);

end architecture rtl;
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Literal, overload
|
|
3
|
+
|
|
4
|
+
from ..cmvm.api import solver_options_t
|
|
5
|
+
from ..trace import FixedVariableArray, HWConfig
|
|
6
|
+
|
|
7
|
+
__all__ = ['trace_model']


@overload
def trace_model(  # type: ignore
    model: Callable,
    hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
    solver_options: solver_options_t | None = None,
    verbose: bool = False,
    inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
    inputs_kif: tuple[int, int, int] | None = None,
    dump: Literal[False] = False,
) -> tuple[FixedVariableArray, FixedVariableArray]: ...


@overload
def trace_model(  # type: ignore
    model: Callable,
    hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
    solver_options: solver_options_t | None = None,
    verbose: bool = False,
    inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
    inputs_kif: tuple[int, int, int] | None = None,
    dump: Literal[True] = False,  # type: ignore
) -> dict[str, FixedVariableArray]: ...


def trace_model(  # type: ignore
    model: Callable,
    hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
    solver_options: solver_options_t | None = None,
    verbose: bool = False,
    inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
    inputs_kif: tuple[int, int, int] | None = None,
    dump: bool = False,
):
    """Trace ``model`` into fixed-point variable arrays.

    Only Keras models are supported. They are forwarded, together with every other argument,
    to the HGQ2 converter (``.hgq2.trace_model``).

    Args:
        model: The model to trace.
        hwconf: Hardware configuration. A plain 3-tuple is converted to ``HWConfig``.
        solver_options: Forwarded unchanged to the converter.
        verbose: Forwarded unchanged to the converter.
        inputs: Forwarded unchanged to the converter.
        inputs_kif: Forwarded unchanged to the converter.
        dump: Per the overloads, ``True`` yields a ``dict[str, FixedVariableArray]``;
            ``False`` yields a pair of ``FixedVariableArray``.

    Raises:
        ValueError: If ``model`` is not derived from any Keras class.
        AssertionError: If it is a Keras object but not a ``keras.Model``.
    """
    hwconf = HWConfig(*hwconf) if isinstance(hwconf, tuple) else hwconf

    # Inspect the whole MRO, not only type(model). A user-defined subclass of keras.Model
    # lives in the user's own module, but still derives from classes in ``keras.*``.
    is_keras = any(cls.__module__.startswith('keras.') for cls in type(model).__mro__)
    if not is_keras:
        raise ValueError(f'Unsupported model type: {type(model)}')

    import keras

    from .hgq2 import trace_model as keras_trace_model

    assert isinstance(model, keras.Model)

    return keras_trace_model(
        model,
        hwconf,
        solver_options=solver_options,
        verbose=verbose,
        inputs=inputs,
        inputs_kif=inputs_kif,
        dump=dump,
    )
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import hgq
|
|
6
|
+
import keras
|
|
7
|
+
import numpy as np
|
|
8
|
+
from hgq.layers.core.base import MultipleQuantizers, Quantizer
|
|
9
|
+
from hgq.quantizer.internal import FixedPointQuantizerBase
|
|
10
|
+
from keras.ops import convert_to_numpy
|
|
11
|
+
|
|
12
|
+
from ....trace import FixedVariable, FixedVariableArray
|
|
13
|
+
from ....trace.ops import quantize, relu
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def to_np_arr(x: Any) -> np.ndarray:
    """Return ``x`` (a Keras tensor, variable or array-like) as a NumPy ndarray."""
    as_numpy = convert_to_numpy(x)
    return np.asarray(as_numpy)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def mirror_quantizer(q: Quantizer, v: FixedVariableArray) -> FixedVariableArray:
    """Apply the behaviour of an HGQ ``Quantizer`` layer to traced variables.

    The steps, in order:
    1. Divide by ``q.scaler`` when it is set.
    2. Quantize with the internal quantizer's ``kif`` bit-widths. These are mapped to ``v``'s
       shape plus a leading batch axis, which is then dropped. The quantizer's round and
       overflow modes are used.
    3. Apply ``q.affine`` as ``x * affine[0] + affine[1]`` when it is truthy.
    """
    if q.scaler is not None:
        v = v * (1.0 / q.scaler)

    inner: FixedPointQuantizerBase = q.quantizer
    mapper = inner.bw_mapper
    batched_shape = (1,) + v.shape

    k_bw, i_bw, f_bw = inner.kif
    mapped = [mapper.bw_to_x(bw, batched_shape) for bw in (k_bw, i_bw, f_bw)]
    # Cast to small integers and strip the artificial batch axis.
    k, i, f = [to_np_arr(m).astype(np.int8)[0] for m in mapped]

    round_mode = inner.round_mode
    overflow_mode = inner.overflow_mode
    quantized = quantize(v, k, i, f, overflow_mode=overflow_mode, round_mode=round_mode)

    if not q.affine:
        return quantized
    scale, offset = q.affine[0], q.affine[1]
    return quantized * scale + offset
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Maps every handled operation type to the replay class that mirrors it.
# Filled in automatically by ``HandlerRegMeta`` when replay classes are defined.
_registry: dict[type, 'type[ReplayOperationBase]'] = {}


class HandlerRegMeta(type):
    """Metaclass that registers each replay class under the operation types it handles.

    A class declares the operation type(s) it mirrors with a ``handles`` attribute in its own
    class body, either as a single type or as a tuple of types. Each of those types is mapped
    to the class in ``_registry``; a later class handling the same type replaces the earlier
    entry. A class that does not declare ``handles`` itself (e.g. an intermediate helper base)
    is created without being registered, rather than failing with ``KeyError``. The class named
    ``ReplayOperationBase`` is never registered.
    """

    def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]):
        cls = super().__new__(mcs, name, bases, namespace)
        if name == 'ReplayOperationBase':
            return cls

        # Only a declaration in this class body counts. Re-registering an inherited
        # ``handles`` would silently redirect the parent's operation types to the subclass.
        handles: type | tuple[type, ...] = namespace.get('handles', ())
        if not isinstance(handles, tuple):
            handles = (handles,)

        for handle in handles:
            _registry[handle] = cls  # type: ignore
        return cls
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ReplayOperationBase(metaclass=HandlerRegMeta):
    """Base class for objects that replay a Keras/HGQ operation on traced fixed-point variables.

    Subclasses set ``handles`` to the operation type(s) they mirror and implement ``call``.

    For HGQ ``QLayerBase`` layers, ``__call__`` also mirrors three steps around ``call``:
    the layer's input quantizer, its activation (only linear and ReLU are supported), and its
    output quantizer. A subclass that performs one of these steps itself sets the matching
    ``__*_handled__`` flag to ``True``, and the base class then skips that step.
    """

    handles: tuple[type, ...] = ()
    __activation_handled__ = False
    __input_quantizer_handled__ = False
    __output_quantizer_handled__ = False

    def __init__(self, layer: 'keras.Operation'):
        assert isinstance(layer, self.handles)
        self.op: Any = layer

    def call(self, *args, **kwargs) -> tuple[FixedVariableArray, ...] | FixedVariableArray: ...

    def __call__(self, *args, **kwargs) -> tuple[FixedVariableArray, ...]:
        """Replay the wrapped operation and return its outputs, always packed in a tuple."""
        # Traced arrays must be passed positionally.
        assert all(not isinstance(a, FixedVariableArray) for a in kwargs.values())

        if not isinstance(self.op, hgq.layers.QLayerBase):
            r = self.call(*args, **kwargs)
            return r if isinstance(r, tuple) else (r,)

        layer: hgq.layers.QLayerBase = self.op
        # ``training=None`` is the Keras default for layer ``call`` and means inference,
        # so only a truthy value is rejected here (``is False`` wrongly rejected None).
        assert not kwargs.pop('training', False), 'Training mode is not supported in mirror operation'
        assert kwargs.pop('mask', None) is None, 'Masking is not supported in mirror operation'

        if not self.__input_quantizer_handled__:
            assert len(args) == 1
            inputs = args[0]

            if layer.enable_iq:
                if isinstance(inputs, Sequence):
                    # Multi-input layer: one quantizer per input.
                    assert isinstance(layer.iq, MultipleQuantizers)
                    inputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.iq.quantizers, inputs))
                else:
                    assert isinstance(layer.iq, Quantizer), f'Expected iq to be a Quantizer, got {type(layer.iq)}'
                    inputs = mirror_quantizer(layer.iq, inputs)

            outputs = self.call(inputs, **kwargs)
        else:
            outputs = self.call(*args, **kwargs)

        # A scalar FixedVariable result is promoted to a 1-element array.
        if isinstance(outputs, FixedVariable):
            outputs = FixedVariableArray(np.array([outputs]))

        if not self.__activation_handled__:
            activation = getattr(layer, 'activation', keras.activations.linear)
            if activation is not keras.activations.linear:
                if activation is keras.activations.relu:
                    if isinstance(outputs, tuple):
                        assert len(outputs) == 1, 'ReLU activation is expected to have a single output'
                        outputs = (relu(outputs[0]),)
                    else:
                        outputs = relu(outputs)
                else:
                    raise NotImplementedError(f'Activation {activation} is not supported in mirror operation')

        if layer.enable_oq and not self.__output_quantizer_handled__:
            if isinstance(outputs, tuple):
                assert isinstance(layer.oq, MultipleQuantizers)
                outputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.oq.quantizers, outputs))
            else:
                assert isinstance(layer.oq, Quantizer)
                outputs = mirror_quantizer(layer.oq, outputs)

        if isinstance(outputs, (FixedVariableArray, np.ndarray)):
            outputs = (outputs,)

        return outputs
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class ReplayQuantizer(ReplayOperationBase):
    """Replay a standalone HGQ ``Quantizer`` layer on traced variables."""

    handles = (Quantizer,)

    def __init__(self, op: 'Quantizer'):
        super().__init__(op)
        # mirror_quantizer relies on the fixed-point internal quantizer's interface.
        assert isinstance(op.quantizer, FixedPointQuantizerBase)

    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
        quantizer_layer = self.op
        return mirror_quantizer(quantizer_layer, inputs)
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
import numpy as np
|
|
3
|
+
from hgq.layers import (
|
|
4
|
+
QSoftmax,
|
|
5
|
+
QUnaryFunctionLUT,
|
|
6
|
+
)
|
|
7
|
+
from keras.layers import LeakyReLU, PReLU, ReLU
|
|
8
|
+
|
|
9
|
+
from ....trace import FixedVariableArray
|
|
10
|
+
from ....trace.ops import relu
|
|
11
|
+
from ._base import ReplayOperationBase, to_np_arr
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ReplayReLU(ReplayOperationBase):
    """Replay Keras ``ReLU``, ``LeakyReLU`` and ``PReLU`` layers on traced variables.

    Each layer is reduced to a (threshold, negative slope, max value) triple.

    A zero threshold with zero slope and no max value uses the plain ``relu`` op. Otherwise
    every element chooses between two branches via ``msb_mux`` on a per-element comparison
    value:
    - the positive branch: the input, clipped to ``max_value`` when one is set;
    - the negative branch: ``(x - threshold) * slope``.
    """

    handles = (ReLU, LeakyReLU, PReLU)

    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
        op = self.op
        if isinstance(op, ReLU):
            th, neg, maxv = op.threshold, op.negative_slope, op.max_value
        elif isinstance(op, LeakyReLU):
            th, neg, maxv = 0, op.negative_slope, None
        elif isinstance(op, PReLU):
            # PReLU's slope is a learned array, converted to NumPy.
            th, neg, maxv = 0, to_np_arr(op.alpha), None
        else:
            raise TypeError(f'Unsupported activation layer: {type(op)}')

        # Plain ReLU: nothing beyond max(x, 0) is needed.
        if th == 0 and np.all(neg == 0) and maxv is None:
            return relu(inputs)

        # Positive branch, clipped to max_value when one is given.
        pos_part = inputs if maxv is None else np.minimum(inputs, maxv)  # type: ignore
        pos_part = pos_part._vars.ravel()

        # Comparison value whose sign selects the branch. With a non-zero threshold, half of an
        # LSB (from the input's fractional bits) is also subtracted. For an on-grid x == th the
        # value is then negative, so the positive branch is only chosen for x > th.
        if th != 0:
            z_cond = (inputs - (th + 2.0 ** (-inputs.kif[2] - 1)))._vars.ravel()
        else:
            z_cond = inputs._vars.ravel()

        # Negative branch (x - th) * slope; the extra leading axis from [None] is flattened
        # away again by ravel().
        neg_part = ((inputs[None] - th) * neg)._vars.ravel()
        # NOTE(review): msb_mux presumably selects ``n`` when c's sign bit is set, else ``p``.
        # Elements whose comparison value can never be negative (c.low >= 0) take ``p`` directly.
        out = np.array([c.msb_mux(n, p) if c.low < 0 else p for c, n, p in zip(z_cond, neg_part, pos_part)])

        return FixedVariableArray(out.reshape(inputs.shape), inputs.solver_options)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ReplayQFunctionLUT(ReplayOperationBase):
    """Replay an HGQ ``QUnaryFunctionLUT`` through ``FixedVariableArray.apply``.

    The layer's Keras activation is evaluated here, so the base class must not apply an
    activation again (hence ``__activation_handled__``).
    """

    __activation_handled__ = True
    handles = (QUnaryFunctionLUT,)

    def call(self, x: FixedVariableArray) -> FixedVariableArray:
        lut_layer: QUnaryFunctionLUT = self.op

        def _eval_activation(values) -> np.ndarray:
            # Run the Keras activation on a batch of one, then strip the batch axis.
            batched = keras.ops.convert_to_tensor(values[None])
            activated = lut_layer.activation(batched)
            return keras.ops.convert_to_numpy(activated[0])  # type: ignore

        return x.apply(_eval_activation)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ReplayQSoftmax(ReplayOperationBase):
    """Replay an HGQ ``QSoftmax`` on traced variables.

    The steps: exponent LUT, optional multiplicative mask, sum over ``op.axes``, reciprocal
    LUT, then the element-wise product of the exponents and the reciprocal.
    """

    handles = (QSoftmax,)

    def call(self, inputs: FixedVariableArray, mask: None | FixedVariableArray = None) -> FixedVariableArray:
        op: QSoftmax = self.op
        # Add a batch axis; op.axes are presumably expressed for batched tensors.
        batched = inputs[None]

        if op.stable:
            # Feed (max - x) over the softmax axes into the exponent table.
            batched = np.amax(batched, axis=op.axes, keepdims=True) - batched  # type: ignore

        exp_vals = ReplayQFunctionLUT(op.exp_table)(batched[0])[0]

        if mask is not None:
            exp_vals = mask[0] * exp_vals

        denom = np.sum(exp_vals[None], axis=op.axes, keepdims=True)[0]  # type: ignore
        inv_denom = ReplayQFunctionLUT(op.inv_table)(denom)[0]

        return exp_vals * inv_denom


__all__ = ['ReplayReLU', 'ReplayQFunctionLUT', 'ReplayQSoftmax']
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from hgq.layers import (
|
|
3
|
+
QLinformerAttention,
|
|
4
|
+
QMultiHeadAttention,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
from ....trace import FixedVariableArray
|
|
8
|
+
from ....trace.ops import einsum
|
|
9
|
+
from ._base import ReplayOperationBase, mirror_quantizer
|
|
10
|
+
from .activation import ReplayQSoftmax
|
|
11
|
+
from .dense import ReplayQDense
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _compute_attention_mask(
    query,
    value,
    query_mask=None,
    value_mask=None,
    key_mask=None,
    attention_mask=None,
    use_causal_mask=False,
):
    """Combine all given attention masks into a single mask broadcastable to [Q, V].

    Every mask that is provided contributes: the query/value/key padding masks, the causal
    mask, and an explicit ``attention_mask``. They are combined by element-wise
    multiplication, which is a logical AND for 0/1 masks, with NumPy broadcasting between the
    differently shaped parts.

    Inputs are unbatched, so axis 0 of ``query`` / ``value`` is the sequence axis.

    Returns:
        The combined mask, or ``None`` when no mask applies.
    """
    masks = []
    if query_mask is not None:
        masks.append(np.expand_dims(query_mask, -1))  # [Q, 1]
    if value_mask is not None:
        masks.append(np.expand_dims(value_mask, -2))  # [1, V]
    if key_mask is not None:
        masks.append(np.expand_dims(key_mask, -2))  # [1, V]
    if use_causal_mask:
        q = query.shape[0]
        v = q if value is None else value.shape[0]
        masks.append(np.tril(np.ones((q, v), dtype='uint8')))  # [Q, V]
    # Only a mask that was actually given takes part. Appending None made the "no mask"
    # check below unreachable and broke the combination step.
    if attention_mask is not None:
        masks.append(attention_mask)
    if not masks:
        return None

    # Combine by broadcasting multiplication: np.stack cannot join [Q, 1] with [1, V].
    # Constant masks (e.g. the causal mask) must be kept too, not only traced ones.
    combined = masks[0]
    for m in masks[1:]:
        combined = combined * m
    return combined
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _masked_softmax(op, attention_scores, attention_mask=None):
    """Run the attention layer's softmax over the scores, with an optional mask.

    ``attention_scores`` carries a leading batch axis of size one, which is dropped for the
    softmax replay and restored afterwards. The mask is first expanded to the scores' rank.
    """
    if attention_mask is not None:
        # Insert the missing axes just before <query_attention_dims, key_attention_dims>,
        # i.e. at the heads axis of (<batch_dims>, num_heads, <query dims>, <key dims>).
        insert_at = -2 * len(op._attention_axes) - 1
        missing = len(attention_scores.shape) - len(attention_mask.shape)
        for _ in range(missing):
            attention_mask = np.expand_dims(attention_mask, axis=insert_at)

    softmax_out = ReplayQSoftmax(op._softmax)(attention_scores[0], mask=attention_mask)
    return softmax_out[0][None]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _compute_attention(op: QMultiHeadAttention, query, key, value, attention_mask=None, training=None):
    """Compute attention with the layer's own einsum equations.

    ``training`` is accepted only for signature compatibility and is ignored.

    Returns:
        ``(attention_output, attention_scores)``, where the scores are the softmax output.
    """
    # Raw scores from the key/query dot product.
    raw_scores = einsum(op._dot_product_equation, key, query)
    probs = _masked_softmax(op, raw_scores, attention_mask)
    # Weighted combination of the values (presumably [B, T, N, H], as in Keras).
    output = einsum(op._combine_equation, probs, value)
    return output, probs
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class ReplayMHA(ReplayOperationBase):
    """Replay an HGQ ``QMultiHeadAttention`` layer on unbatched traced inputs.

    The layer's own input quantizer is not applied here; the q/k/v projections are replayed
    through ``ReplayQDense``. The output quantizer is applied inside ``call``. Both base-class
    quantizer steps are therefore disabled.
    """

    handles = (QMultiHeadAttention,)
    __input_quantizer_handled__ = True
    __output_quantizer_handled__ = True

    def call(
        self,
        query: FixedVariableArray,
        value: FixedVariableArray,
        key=None,
        query_mask=None,
        value_mask=None,
        key_mask=None,
        attention_mask=None,
        return_attention_scores=False,
        use_causal_mask=False,
    ):
        mha: QMultiHeadAttention = self.op
        key = value if key is None else key

        combined_mask = _compute_attention_mask(
            query,
            value,
            query_mask=query_mask,
            value_mask=value_mask,
            key_mask=key_mask,
            attention_mask=attention_mask,
            use_causal_mask=use_causal_mask,
        )

        # Project query, key and value (in that order); [None] re-adds the batch axis that the
        # einsum equations expect.
        q_proj, k_proj, v_proj = (
            ReplayQDense(dense)(x)[0][None]
            for dense, x in ((mha._query_dense, query), (mha._key_dense, key), (mha._value_dense, value))
        )

        attn_out, attn_scores = _compute_attention(mha, q_proj, k_proj, v_proj, combined_mask)
        attn_out = ReplayQDense(mha._output_dense)(attn_out[0])[0]

        if mha.enable_oq:
            attn_out = mirror_quantizer(mha.oq, attn_out)

        if not return_attention_scores:
            return attn_out
        return attn_out, attn_scores[0]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class ReplayQLinformerAttention(ReplayMHA):
    """Replay an HGQ ``QLinformerAttention`` layer.

    Key and value are first passed through the Linformer projections; the multi-head attention
    replay then runs on the projected tensors. Causal masking is not supported.
    """

    handles = (QLinformerAttention,)

    def call(
        self,
        query,
        value,
        key=None,
        query_mask=None,
        value_mask=None,
        key_mask=None,
        attention_mask=None,
        return_attention_scores=False,
        use_causal_mask=False,
    ):
        assert use_causal_mask is False, 'Causal mask is not supported in QLinformerAttention.'
        if key is None:
            key = value
        linformer: QLinformerAttention = self.op

        projected_key = ReplayQDense(linformer._lin_k_proj)(key)[0]
        projected_value = ReplayQDense(linformer._lin_v_proj)(value)[0]

        return super().call(
            query,
            projected_value,
            projected_key,
            query_mask=query_mask,
            value_mask=value_mask,
            key_mask=key_mask,
            attention_mask=attention_mask,
            return_attention_scores=return_attention_scores,
        )


__all__ = ['ReplayMHA', 'ReplayQLinformerAttention']
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from hgq.layers import QBatchNormalization
|
|
3
|
+
|
|
4
|
+
from ....trace import FixedVariableArray
|
|
5
|
+
from ._base import ReplayOperationBase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ReplayQBatchNormalization(ReplayOperationBase):
    """Replay an HGQ ``QBatchNormalization`` layer as ``inputs * scale + bias``.

    ``scale`` and ``bias`` come from the layer's ``qscaler_and_qoffset``.
    """

    handles = (QBatchNormalization,)

    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
        from ._base import to_np_arr

        layer: QBatchNormalization = self.op
        # Convert with keras.ops.convert_to_numpy (via to_np_arr), as other replay ops do.
        # Plain np.array cannot convert some backend tensors, e.g. PyTorch tensors on a GPU.
        scale, bias = (to_np_arr(t) for t in layer.qscaler_and_qoffset)
        # Drop the leading (presumably batch) axis of the layer's recorded shape so that the
        # parameters broadcast against the unbatched traced inputs.
        shape = layer._shape[1:]
        return inputs * scale.reshape(shape) + bias.reshape(shape)
|