da4ml-0.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +194 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +240 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.0.dist-info/METADATA +85 -0
  92. da4ml-0.5.0.dist-info/RECORD +96 -0
  93. da4ml-0.5.0.dist-info/WHEEL +6 -0
  94. da4ml-0.5.0.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/codegen/rtl/vhdl/source/shift_adder.vhd
@@ -0,0 +1,101 @@
+ library ieee;
+ use ieee.std_logic_1164.all;
+ use ieee.numeric_std.all;
+
+ entity shift_adder is
+     generic (
+         BW_INPUT0 : integer := 32;
+         BW_INPUT1 : integer := 32;
+         SIGNED0 : integer := 0;
+         SIGNED1 : integer := 0;
+         BW_OUT : integer := 32;
+         SHIFT1 : integer := 0;
+         IS_SUB : integer := 0
+     );
+     port (
+         in0 : in std_logic_vector(BW_INPUT0-1 downto 0);
+         in1 : in std_logic_vector(BW_INPUT1-1 downto 0);
+         result : out std_logic_vector(BW_OUT-1 downto 0)
+     );
+ end entity shift_adder;
+
+ architecture rtl of shift_adder is
+     function max(L, R: integer) return integer is
+     begin
+         if L > R then
+             return L;
+         else
+             return R;
+         end if;
+     end function;
+
+     function if_then_else(cond: boolean; val_true: integer; val_false: integer) return integer is
+     begin
+         if cond then
+             return val_true;
+         else
+             return val_false;
+         end if;
+     end function;
+
+     constant IN0_NEED_BITS : integer := if_then_else(SHIFT1 < 0, BW_INPUT0 - SHIFT1, BW_INPUT0);
+     constant IN1_NEED_BITS : integer := if_then_else(SHIFT1 > 0, BW_INPUT1 + SHIFT1, BW_INPUT1);
+     constant EXTRA_PAD : integer := if_then_else(SIGNED0 /= SIGNED1, IS_SUB + 1, IS_SUB);
+     constant BW_ADD : integer := max(IN0_NEED_BITS, IN1_NEED_BITS) + EXTRA_PAD + 1;
+
+     signal in0_ext : std_logic_vector(BW_ADD-1 downto 0);
+     signal in1_ext : std_logic_vector(BW_ADD-1 downto 0);
+     signal accum : std_logic_vector(BW_ADD-1 downto 0);
+
+ begin
+
+     -- Extension and shifting for input 0
+     gen_in0_shift_neg: if SHIFT1 < 0 generate
+         gen_in0_signed: if SIGNED0 = 1 generate
+             in0_ext <= std_logic_vector(resize(signed(in0), BW_ADD)) sll (-SHIFT1);
+         end generate;
+         gen_in0_unsigned: if SIGNED0 = 0 generate
+             in0_ext <= std_logic_vector(resize(unsigned(in0), BW_ADD)) sll (-SHIFT1);
+         end generate;
+     end generate;
+
+     gen_in0_shift_pos: if SHIFT1 >= 0 generate
+         gen_in0_signed: if SIGNED0 = 1 generate
+             in0_ext <= std_logic_vector(resize(signed(in0), BW_ADD));
+         end generate;
+         gen_in0_unsigned: if SIGNED0 = 0 generate
+             in0_ext <= std_logic_vector(resize(unsigned(in0), BW_ADD));
+         end generate;
+     end generate;
+
+     -- Extension and shifting for input 1
+     gen_in1_shift_pos: if SHIFT1 > 0 generate
+         gen_in1_signed: if SIGNED1 = 1 generate
+             in1_ext <= std_logic_vector(resize(signed(in1), BW_ADD)) sll SHIFT1;
+         end generate;
+         gen_in1_unsigned: if SIGNED1 = 0 generate
+             in1_ext <= std_logic_vector(resize(unsigned(in1), BW_ADD)) sll SHIFT1;
+         end generate;
+     end generate;
+
+     gen_in1_shift_neg: if SHIFT1 <= 0 generate
+         gen_in1_signed: if SIGNED1 = 1 generate
+             in1_ext <= std_logic_vector(resize(signed(in1), BW_ADD));
+         end generate;
+         gen_in1_unsigned: if SIGNED1 = 0 generate
+             in1_ext <= std_logic_vector(resize(unsigned(in1), BW_ADD));
+         end generate;
+     end generate;
+
+     -- Addition/subtraction logic
+     gen_sub: if IS_SUB = 1 generate
+         accum <= std_logic_vector(signed(in0_ext) - signed(in1_ext));
+     end generate;
+
+     gen_add: if IS_SUB = 0 generate
+         accum <= std_logic_vector(signed(in0_ext) + signed(in1_ext));
+     end generate;
+
+     result <= accum(BW_OUT-1 downto 0);
+
+ end architecture rtl;
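The shift_adder entity widens both operands to a common internal width before adding: BW_ADD covers the shifted operand, adds one bit for the sum, and adds extra padding when the operands differ in signedness or a subtraction is requested; the result is then truncated to BW_OUT. As an illustrative cross-check only (a sketch mirroring the constants above, not code shipped in the wheel), the same width arithmetic in Python:

    def adder_width(bw_in0: int, bw_in1: int, signed0: int, signed1: int, shift1: int, is_sub: int) -> int:
        """Mirror of the BW_ADD constant computed in shift_adder.vhd (sketch)."""
        in0_need = bw_in0 - shift1 if shift1 < 0 else bw_in0
        in1_need = bw_in1 + shift1 if shift1 > 0 else bw_in1
        extra_pad = is_sub + 1 if signed0 != signed1 else is_sub
        return max(in0_need, in1_need) + extra_pad + 1

    # e.g. an unsigned 8-bit operand plus a signed 6-bit operand pre-shifted left by 2
    print(adder_width(8, 6, 0, 1, 2, 0))  # -> 10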
da4ml/converter/__init__.py
@@ -0,0 +1,63 @@
+ from collections.abc import Callable
+ from typing import Literal, overload
+
+ from ..cmvm.api import solver_options_t
+ from ..trace import FixedVariableArray, HWConfig
+
+ __all__ = ['trace_model']
+
+
+ @overload
+ def trace_model(  # type: ignore
+     model: Callable,
+     hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
+     solver_options: solver_options_t | None = None,
+     verbose: bool = False,
+     inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
+     inputs_kif: tuple[int, int, int] | None = None,
+     dump: Literal[False] = False,
+ ) -> tuple[FixedVariableArray, FixedVariableArray]: ...
+
+
+ @overload
+ def trace_model(  # type: ignore
+     model: Callable,
+     hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
+     solver_options: solver_options_t | None = None,
+     verbose: bool = False,
+     inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
+     inputs_kif: tuple[int, int, int] | None = None,
+     dump: Literal[True] = False,  # type: ignore
+ ) -> dict[str, FixedVariableArray]: ...
+
+
+ def trace_model(  # type: ignore
+     model: Callable,
+     hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
+     solver_options: solver_options_t | None = None,
+     verbose: bool = False,
+     inputs: tuple[FixedVariableArray, ...] | None = None,
+     inputs_kif: tuple[int, int, int] | None = None,
+     dump=False,
+ ):
+     hwconf = HWConfig(*hwconf) if isinstance(hwconf, tuple) else hwconf
+
+     module = type(model).__module__
+     if module.startswith('keras.'):
+         import keras
+
+         from .hgq2 import trace_model as keras_trace_model
+
+         assert isinstance(model, keras.Model)
+
+         return keras_trace_model(
+             model,
+             hwconf,
+             solver_options=solver_options,
+             verbose=verbose,
+             inputs=inputs,
+             inputs_kif=inputs_kif,
+             dump=dump,
+         )
+     else:
+         raise ValueError(f'Unsupported model type: {type(model)}')
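trace_model dispatches on the model's module: Keras/HGQ2 models are forwarded to the hgq2 tracer, anything else raises. A hypothetical usage sketch based only on the signature above (the quantized Keras model itself is assumed to exist and is not provided by this package):

    from da4ml.converter import trace_model
    from da4ml.trace import HWConfig

    # `model` is assumed to be an HGQ2-quantized keras.Model built elsewhere.
    # With dump=False (the default) the first overload applies and the symbolic
    # input/output FixedVariableArray pair is returned.
    inp, out = trace_model(model, hwconf=HWConfig(1, -1, -1), verbose=True)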
da4ml/converter/hgq2/__init__.py
@@ -0,0 +1,3 @@
+ from .parser import trace_model
+
+ __all__ = ['trace_model']
da4ml/converter/hgq2/layers/__init__.py
@@ -0,0 +1,11 @@
+ from ._base import _registry
+ from .activation import *
+ from .attn import *
+ from .batchnorm import *
+ from .conv import *
+ from .dense import *
+ from .ops import *
+ from .pool import *
+ from .table import *
+
+ __all__ = ['_registry']
da4ml/converter/hgq2/layers/_base.py
@@ -0,0 +1,132 @@
+ import typing
+ from collections.abc import Sequence
+ from typing import Any
+
+ import hgq
+ import keras
+ import numpy as np
+ from hgq.layers.core.base import MultipleQuantizers, Quantizer
+ from hgq.quantizer.internal import FixedPointQuantizerBase
+ from keras.ops import convert_to_numpy
+
+ from ....trace import FixedVariable, FixedVariableArray
+ from ....trace.ops import quantize, relu
+
+
+ def to_np_arr(x: Any) -> np.ndarray:
+     return np.asarray(convert_to_numpy(x))
+
+
+ def mirror_quantizer(q: Quantizer, v: FixedVariableArray) -> FixedVariableArray:
+     if q.scaler is not None:
+         v = v * (1.0 / q.scaler)
+     q_internal: FixedPointQuantizerBase = q.quantizer
+     kk, ki, kf = q_internal.kif
+     shape = (1,) + v.shape
+     kk = q_internal.bw_mapper.bw_to_x(kk, shape)
+     ki = q_internal.bw_mapper.bw_to_x(ki, shape)
+     kf = q_internal.bw_mapper.bw_to_x(kf, shape)
+     k, i, f = (to_np_arr(x).astype(np.int8)[0] for x in (kk, ki, kf))
+     round_mode, overflow_mode = q_internal.round_mode, q_internal.overflow_mode
+     rq = quantize(v, k, i, f, overflow_mode=overflow_mode, round_mode=round_mode)
+     if q.affine:
+         rq = rq * q.affine[0] + q.affine[1]
+     return rq
+
+
+ _registry: dict[type, 'type[ReplayOperationBase]'] = {}
+
+
+ class HandlerRegMeta(type):
+     """Metaclass for automatic registration of handler classes."""
+
+     def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]):
+         cls = super().__new__(mcs, name, bases, namespace)
+         if name == 'ReplayOperationBase':
+             return cls
+
+         handles: type | tuple[type, ...] = namespace['handles']
+         if not isinstance(handles, tuple):
+             handles = (handles,)
+
+         for handle in handles:
+             _registry[handle] = cls  # type: ignore
+         return cls
+
+
+ class ReplayOperationBase(metaclass=HandlerRegMeta):
+     handles: tuple[type, ...] = ()
+     __activation_handled__ = False
+     __input_quantizer_handled__ = False
+     __output_quantizer_handled__ = False
+
+     def __init__(self, layer: 'keras.Operation'):
+         assert isinstance(layer, self.handles)
+         self.op: Any = layer
+
+     def call(self, *args, **kwargs) -> tuple[FixedVariableArray, ...] | FixedVariableArray: ...
+
+     def __call__(self, *args, **kwargs) -> tuple[FixedVariableArray, ...]:
+         assert all(not isinstance(a, FixedVariableArray) for a in kwargs.values())
+
+         if not isinstance(self.op, hgq.layers.QLayerBase):
+             r = self.call(*args, **kwargs)
+             return r if isinstance(r, tuple) else (r,)
+
+         layer: hgq.layers.QLayerBase = self.op
+         assert kwargs.pop('training', False) is False, 'Training mode is not supported in mirror operation'
+         assert kwargs.pop('mask', None) is None, 'Masking is not supported in mirror operation'
+
+         if not self.__input_quantizer_handled__:
+             assert len(args) == 1
+             inputs = args[0]
+
+             if layer.enable_iq:
+                 if isinstance(inputs, Sequence):
+                     assert isinstance(layer.iq, MultipleQuantizers)
+                     inputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.iq.quantizers, inputs))
+                 else:
+                     assert isinstance(layer.iq, Quantizer), f'Expected iq to be a Quantizer, got {type(layer.iq)}'
+                     inputs = mirror_quantizer(layer.iq, inputs)
+
+             outputs = self.call(inputs, **kwargs)
+         else:
+             outputs = self.call(*args, **kwargs)
+         if isinstance(outputs, FixedVariable):
+             outputs = FixedVariableArray(np.array([outputs]))
+
+         if not self.__activation_handled__:
+             activation = getattr(layer, 'activation', keras.activations.linear)
+             if activation is not keras.activations.linear:
+                 if activation is keras.activations.relu:
+                     if isinstance(outputs, tuple):
+                         assert len(outputs) == 1, 'ReLU activation is expected to have a single output'
+                         outputs = (relu(outputs[0]),)
+                     else:
+                         outputs = relu(outputs)
+                 else:
+                     raise NotImplementedError(f'Activation {activation} is not supported in mirror operation')
+
+         if layer.enable_oq and not self.__output_quantizer_handled__:
+             if isinstance(outputs, tuple):
+                 assert isinstance(layer.oq, MultipleQuantizers)
+                 outputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.oq.quantizers, outputs))
+             else:
+                 assert isinstance(layer.oq, Quantizer)
+                 outputs = mirror_quantizer(layer.oq, outputs)
+
+         if isinstance(outputs, (FixedVariableArray, np.ndarray)):
+             outputs = (outputs,)
+
+         return outputs
+
+
+ class ReplayQuantizer(ReplayOperationBase):
+     handles = (Quantizer,)
+
+     def __init__(self, op: 'Quantizer'):
+         super().__init__(op)
+         assert isinstance(op.quantizer, FixedPointQuantizerBase)
+
+     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+         return mirror_quantizer(self.op, inputs)
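HandlerRegMeta gives _base.py a plugin-style dispatch table: every concrete subclass of ReplayOperationBase that declares a handles attribute is recorded in _registry under the layer types it replays, so handlers can later be looked up by layer class. A minimal standalone sketch of that registration pattern (toy names, no hgq or keras dependency, not code from the wheel):

    _registry: dict[type, type] = {}

    class _RegMeta(type):
        # Record each concrete subclass under every type listed in `handles`.
        def __new__(mcs, name, bases, namespace):
            cls = super().__new__(mcs, name, bases, namespace)
            for handle in namespace.get('handles', ()):
                _registry[handle] = cls
            return cls

    class HandlerBase(metaclass=_RegMeta):
        handles: tuple[type, ...] = ()

    class IntHandler(HandlerBase):
        handles = (int,)

    print(_registry[int] is IntHandler)  # -> True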
da4ml/converter/hgq2/layers/activation.py
@@ -0,0 +1,81 @@
+ import keras
+ import numpy as np
+ from hgq.layers import (
+     QSoftmax,
+     QUnaryFunctionLUT,
+ )
+ from keras.layers import LeakyReLU, PReLU, ReLU
+
+ from ....trace import FixedVariableArray
+ from ....trace.ops import relu
+ from ._base import ReplayOperationBase, to_np_arr
+
+
+ class ReplayReLU(ReplayOperationBase):
+     handles = (ReLU, LeakyReLU, PReLU)
+
+     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+         op = self.op
+         if isinstance(op, ReLU):
+             th, neg, maxv = op.threshold, op.negative_slope, op.max_value
+         elif isinstance(op, LeakyReLU):
+             th, neg, maxv = 0, op.negative_slope, None
+         elif isinstance(op, PReLU):
+             th, neg, maxv = 0, to_np_arr(op.alpha), None
+         else:
+             raise TypeError(f'Unsupported activation layer: {type(op)}')
+
+         if th == 0 and np.all(neg == 0) and maxv is None:
+             return relu(inputs)
+
+         pos_part = inputs if maxv is None else np.minimum(inputs, maxv)  # type: ignore
+         pos_part = pos_part._vars.ravel()
+
+         if th != 0:
+             z_cond = (inputs - (th + 2.0 ** (-inputs.kif[2] - 1)))._vars.ravel()
+         else:
+             z_cond = inputs._vars.ravel()
+
+         neg_part = ((inputs[None] - th) * neg)._vars.ravel()
+         out = np.array([c.msb_mux(n, p) if c.low < 0 else p for c, n, p in zip(z_cond, neg_part, pos_part)])
+
+         return FixedVariableArray(out.reshape(inputs.shape), inputs.solver_options)
+
+
+ class ReplayQFunctionLUT(ReplayOperationBase):
+     __activation_handled__ = True
+     handles = (QUnaryFunctionLUT,)
+
+     def call(self, x: FixedVariableArray) -> FixedVariableArray:
+         op: QUnaryFunctionLUT = self.op
+
+         def activation(x) -> np.ndarray:
+             kx = keras.ops.convert_to_tensor(x[None])
+             kx = op.activation(kx)
+             return keras.ops.convert_to_numpy(kx[0])  # type: ignore
+
+         return x.apply(activation)
+
+
+ class ReplayQSoftmax(ReplayOperationBase):
+     handles = (QSoftmax,)
+
+     def call(self, inputs: FixedVariableArray, mask: None | FixedVariableArray = None) -> FixedVariableArray:
+         op: QSoftmax = self.op
+         inputs = inputs[None]
+
+         if op.stable:
+             inputs = np.amax(inputs, axis=op.axes, keepdims=True) - inputs  # type: ignore
+
+         exp_inp = ReplayQFunctionLUT(op.exp_table)(inputs[0])[0]
+
+         if mask is not None:
+             exp_inp = mask[0] * exp_inp
+
+         sums = np.sum(exp_inp[None], axis=op.axes, keepdims=True)[0]  # type: ignore
+         divisor = ReplayQFunctionLUT(op.inv_table)(sums)[0]
+
+         return exp_inp * divisor
+
+
+ __all__ = ['ReplayReLU', 'ReplayQFunctionLUT', 'ReplayQSoftmax']
da4ml/converter/hgq2/layers/attn.py
@@ -0,0 +1,148 @@
+ import numpy as np
+ from hgq.layers import (
+     QLinformerAttention,
+     QMultiHeadAttention,
+ )
+
+ from ....trace import FixedVariableArray
+ from ....trace.ops import einsum
+ from ._base import ReplayOperationBase, mirror_quantizer
+ from .activation import ReplayQSoftmax
+ from .dense import ReplayQDense
+
+
+ def _compute_attention_mask(
+     query,
+     value,
+     query_mask=None,
+     value_mask=None,
+     key_mask=None,
+     attention_mask=None,
+     use_causal_mask=False,
+ ):
+     masks = []
+     if query_mask is not None:
+         masks.append(np.expand_dims(query_mask, -1))  # [Q, 1]
+     if value_mask is not None:
+         masks.append(np.expand_dims(value_mask, -2))  # [1, V]
+     if key_mask is not None:
+         masks.append(np.expand_dims(key_mask, -2))  # [1, V]
+     if use_causal_mask:
+         q = query.shape[0]
+         v = q if value is None else value.shape[0]
+         masks.append(np.tril(np.ones((q, v), dtype='uint8')))  # [Q, V]
+     masks.append(attention_mask)
+     if not masks:
+         return None
+
+     if any(isinstance(m, FixedVariableArray) for m in masks):
+         return np.prod(np.stack(masks, axis=0), axis=0)
+     else:
+         return None
+
+
+ def _masked_softmax(op, attention_scores, attention_mask=None):
+     # Normalize the attention scores to probabilities.
+     # attention_scores = [B, N, T, S]
+     if attention_mask is not None:
+         # The expand dim happens starting from the `num_heads` dimension,
+         # (<batch_dims>, num_heads, <query_attention_dims,
+         # key_attention_dims>)
+         mask_expansion_axis = -len(op._attention_axes) * 2 - 1
+         for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
+             attention_mask = np.expand_dims(attention_mask, axis=mask_expansion_axis)
+     return ReplayQSoftmax(op._softmax)(attention_scores[0], mask=attention_mask)[0][None]
+
+
+ def _compute_attention(op: QMultiHeadAttention, query, key, value, attention_mask=None, training=None):
+     # Take the dot product between "query" and "key" to get the raw
+     # attention scores.
+     attention_scores = einsum(op._dot_product_equation, key, query)
+
+     attention_scores = _masked_softmax(op, attention_scores, attention_mask)
+
+     # `context_layer` = [B, T, N, H]
+     attention_output = einsum(op._combine_equation, attention_scores, value)
+     return attention_output, attention_scores
+
+
+ class ReplayMHA(ReplayOperationBase):
+     handles = (QMultiHeadAttention,)
+     __input_quantizer_handled__ = True
+     __output_quantizer_handled__ = True
+
+     def call(
+         self,
+         query: FixedVariableArray,
+         value: FixedVariableArray,
+         key=None,
+         query_mask=None,
+         value_mask=None,
+         key_mask=None,
+         attention_mask=None,
+         return_attention_scores=False,
+         use_causal_mask=False,
+     ):
+         op: QMultiHeadAttention = self.op
+
+         if key is None:
+             key = value
+
+         _attention_mask = _compute_attention_mask(
+             query,
+             value,
+             query_mask=query_mask,
+             value_mask=value_mask,
+             key_mask=key_mask,
+             attention_mask=attention_mask,
+             use_causal_mask=use_causal_mask,
+         )
+
+         query = ReplayQDense(op._query_dense)(query)[0][None]
+         key = ReplayQDense(op._key_dense)(key)[0][None]
+         value = ReplayQDense(op._value_dense)(value)[0][None]
+
+         attention_output, attention_scores = _compute_attention(op, query, key, value, _attention_mask)
+         attention_output = ReplayQDense(op._output_dense)(attention_output[0])[0]
+
+         if op.enable_oq:
+             attention_output = mirror_quantizer(op.oq, attention_output)
+
+         if return_attention_scores:
+             return attention_output, attention_scores[0]
+         return attention_output
+
+
+ class ReplayQLinformerAttention(ReplayMHA):
+     handles = (QLinformerAttention,)
+
+     def call(
+         self,
+         query,
+         value,
+         key=None,
+         query_mask=None,
+         value_mask=None,
+         key_mask=None,
+         attention_mask=None,
+         return_attention_scores=False,
+         use_causal_mask=False,
+     ):
+         assert use_causal_mask is False, 'Causal mask is not supported in QLinformerAttention.'
+         key = key if key is not None else value
+         op: QLinformerAttention = self.op
+         key = ReplayQDense(op._lin_k_proj)(key)[0]
+         value = ReplayQDense(op._lin_v_proj)(value)[0]
+         return super().call(
+             query,
+             value,
+             key,
+             query_mask=query_mask,
+             value_mask=value_mask,
+             key_mask=key_mask,
+             attention_mask=attention_mask,
+             return_attention_scores=return_attention_scores,
+         )
+
+
+ __all__ = ['ReplayMHA', 'ReplayQLinformerAttention']
da4ml/converter/hgq2/layers/batchnorm.py
@@ -0,0 +1,15 @@
+ import numpy as np
+ from hgq.layers import QBatchNormalization
+
+ from ....trace import FixedVariableArray
+ from ._base import ReplayOperationBase
+
+
+ class ReplayQBatchNormalization(ReplayOperationBase):
+     handles = (QBatchNormalization,)
+
+     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+         layer: QBatchNormalization = self.op
+         scale, bias = map(np.array, layer.qscaler_and_qoffset)
+         shape = layer._shape[1:]
+         return inputs * scale.reshape(shape) + bias.reshape(shape)
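ReplayQBatchNormalization works because inference-time batch normalization collapses to a per-channel affine transform; layer.qscaler_and_qoffset supplies a scale and offset that the replay applies as inputs * scale + bias. A standalone numpy illustration of that folding (textbook formula only, independent of hgq and not code from the wheel):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 3))
    gamma, beta = rng.normal(size=3), rng.normal(size=3)
    mean, var, eps = rng.normal(size=3), rng.uniform(0.5, 2.0, size=3), 1e-3

    bn = gamma * (x - mean) / np.sqrt(var + eps) + beta  # batchnorm at inference
    scale = gamma / np.sqrt(var + eps)                   # folded per-channel scale
    bias = beta - mean * scale                           # folded per-channel offset
    assert np.allclose(bn, x * scale + bias)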