da4ml-0.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +194 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +240 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.0.dist-info/METADATA +85 -0
- da4ml-0.5.0.dist-info/RECORD +96 -0
- da4ml-0.5.0.dist-info/WHEEL +6 -0
- da4ml-0.5.0.dist-info/entry_points.txt +3 -0
- da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/converter/hgq2/layers/conv.py
@@ -0,0 +1,149 @@
+from typing import TypeVar
+
+import numpy as np
+from hgq.layers import (
+    QConv1D,
+    QConv2D,
+    QConv3D,
+)
+from keras import ops
+from keras.src.ops.image import ExtractPatches, extract_patches_3d
+
+from ....trace import FixedVariableArray
+from ._base import ReplayOperationBase, to_np_arr
+
+T = TypeVar('T', FixedVariableArray, np.ndarray)
+
+
+def symbolic_extract_patches_3d(
+    images: T,
+    size: tuple[int, int, int],
+    strides: tuple[int, int, int],
+    dilation_rate: tuple[int, int, int],
+    padding: str,
+    data_format: str,
+    pad_value: float = 0,
+) -> T:
+    img_tensor = ops.cast(ops.reshape(ops.arange(images.size), images.shape), dtype='float32')
+    img_tensor = -img_tensor - 1
+    out_tensor = extract_patches_3d(
+        img_tensor[None],
+        size=size,
+        strides=strides,
+        dilation_rate=dilation_rate,  # type: ignore
+        padding=padding,
+        data_format=data_format,
+    )[0]
+    out_index: np.ndarray = ops.convert_to_numpy(out_tensor).round().astype(np.int32)  # type: ignore
+    mask = out_index == 0
+    out_index = np.where(mask, 0, -out_index - 1)
+    images = images.ravel()[out_index]
+
+    if isinstance(images, FixedVariableArray):
+        _vars = images._vars
+        _vars = np.where(mask, pad_value, _vars)
+        images = FixedVariableArray(_vars, images.solver_options)  # type: ignore
+    else:
+        images = np.where(mask, pad_value, images)
+
+    return images
+
+
+def symbolic_extract_patches(
+    images: T,
+    size: tuple[int, ...] | int,
+    strides: tuple[int, ...] | int,
+    dilation_rate: tuple[int, ...] | int,
+    padding: str,
+    data_format: str,
+    pad_value: float = 0,
+) -> T:
+    rank = images.ndim - 1
+    size = (size,) * rank if isinstance(size, int) else size
+    strides = (strides,) * rank if isinstance(strides, int) else strides
+    dilation_rate = (dilation_rate,) * rank if isinstance(dilation_rate, int) else dilation_rate
+
+    assert rank == len(size) == len(strides) == len(dilation_rate), (
+        f'Invalid rank {rank} for size {size}, strides {strides}, dilation_rate {dilation_rate}'
+    )
+
+    pad_rank = 3 - rank
+    _size: tuple[int, int, int] = (1,) * pad_rank + size  # type: ignore
+    _strides: tuple[int, int, int] = (1,) * pad_rank + strides  # type: ignore
+    _dilation_rate: tuple[int, int, int] = (1,) * pad_rank + dilation_rate  # type: ignore
+
+    _pad = (1,) * pad_rank
+    if data_format == 'channels_first':
+        images = np.moveaxis(images, 0, -1)  # type: ignore
+
+    *spa, ch = images.shape
+    images = images.reshape(*_pad, *spa, ch)
+
+    r = symbolic_extract_patches_3d(
+        images,
+        size=_size,
+        strides=_strides,
+        dilation_rate=_dilation_rate,
+        padding=padding,
+        data_format='channels_last',
+        pad_value=pad_value,
+    )
+
+    return r.reshape(r.shape[pad_rank:])
+
+
+class ReplayExtractPatches(ReplayOperationBase):
+    handles = (ExtractPatches,)
+
+    def call(self, images: FixedVariableArray) -> FixedVariableArray:
+        op: ExtractPatches = self.op
+        pixel_shape = op.size
+        strides = op.strides
+        dilation_rate: int | tuple[int, int] = op.dilation_rate
+        padding = op.padding
+        data_format = op.data_format
+
+        if strides is None:
+            strides = 1
+
+        return symbolic_extract_patches(images, pixel_shape, strides, dilation_rate, padding, data_format)
+
+
+class ReplayQConv(ReplayOperationBase):
+    handles = (QConv1D, QConv2D, QConv3D)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        layer: QConv1D | QConv2D | QConv3D = self.op
+        qkernel = to_np_arr(layer.qkernel)
+        qbias = to_np_arr(layer.qbias) if layer.qbias is not None else None
+        strides = layer.strides
+        padding = layer.padding
+        dilation_rate = layer.dilation_rate
+        groups = layer.groups
+
+        if layer.data_format == 'channels_first':
+            inputs = np.moveaxis(inputs, 0, -1)  # type: ignore
+
+        x = symbolic_extract_patches(
+            inputs,
+            size=layer.kernel_size,
+            strides=strides,
+            dilation_rate=dilation_rate,
+            padding=padding,
+            data_format=layer.data_format,
+        )
+
+        ch_out = qkernel.shape[-1]
+
+        _ch_out = ch_out // groups
+
+        x = x.reshape(*x.shape[:-1], -1, groups)
+        kernel = qkernel.reshape(-1, groups, _ch_out)
+
+        outputs = np.einsum('...ig,igo->...go', x, kernel)  # type: ignore
+        outputs = outputs.reshape(*outputs.shape[:-2], -1) + qbias
+
+        if layer.data_format == 'channels_first':
+            outputs: FixedVariableArray = np.moveaxis(outputs, -1, 0)  # type: ignore
+
+        return outputs
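The two helpers above exist because a symbolic FixedVariableArray cannot flow through keras ops directly: symbolic_extract_patches_3d runs extract_patches_3d on a dummy tensor that encodes each pixel's flat index as -(index + 1), leaving the value 0 reserved for padding, then uses the recovered integer indices for a plain fancy-indexing gather on the symbolic array and overwrites the masked pad slots with pad_value. Below is a minimal 1-D re-enactment of that index trick (a sketch with a hand-rolled sliding window standing in for keras; all names and shapes here are illustrative, not part of the package):

import numpy as np

x = np.array([10.0, 11.0, 12.0, 13.0])          # stand-in for the symbolic input
enc = -np.arange(x.size, dtype=np.float32) - 1  # encode index i as -(i + 1); never 0

# Emulate a size-3, stride-1, 'same'-padded extraction on the encoded tensor
# (the real code delegates this step to keras' extract_patches_3d).
padded = np.concatenate([[0.0], enc, [0.0]])    # pad slots take the reserved value 0
patches = np.stack([padded[i : i + 3] for i in range(x.size)])

idx = patches.round().astype(np.int32)
mask = idx == 0                            # True exactly at padded positions
idx = np.where(mask, 0, -idx - 1)          # decode back to flat indices
out = np.where(mask, 0.0, x.ravel()[idx])  # gather, then overwrite pad slots

print(out)  # [[ 0. 10. 11.] [10. 11. 12.] [11. 12. 13.] [12. 13.  0.]]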
da4ml/converter/hgq2/layers/dense.py
@@ -0,0 +1,39 @@
+import keras
+from hgq.layers import (
+    QBatchNormDense,
+    QDense,
+    QEinsumDense,
+    QEinsumDenseBatchnorm,
+)
+
+from ....trace import FixedVariableArray
+from ....trace.ops import einsum
+from ._base import ReplayOperationBase, to_np_arr
+
+
+class ReplayQDense(ReplayOperationBase):
+    handles = (QDense, QEinsumDense, QEinsumDenseBatchnorm, QBatchNormDense, keras.layers.EinsumDense)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        op = self.op
+        if isinstance(op, (QDense, QBatchNormDense)):
+            qkernel = op.qkernel
+            qbias = op.qbias
+            eq = '...c,cC->...C'
+        elif isinstance(op, (QEinsumDense, QEinsumDenseBatchnorm)):
+            qkernel = op.qkernel
+            qbias = op.qbias
+            eq = op.equation
+        elif isinstance(op, keras.layers.EinsumDense):
+            qkernel = op.kernel
+            qbias = op.bias
+            eq = op.equation
+        else:
+            raise TypeError(f'Unsupported layer type: {type(op)}')
+
+        qkernel = to_np_arr(qkernel)
+        qbias = to_np_arr(qbias) if qbias is not None else None
+        return (einsum(eq, inputs[None], qkernel) + qbias)[0]
+
+
+__all__ = ['ReplayQDense']
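ReplayQDense funnels every dense variant through one einsum call: QDense and QBatchNormDense get the fixed equation '...c,cC->...C', while the EinsumDense variants reuse the layer's own equation, and the inputs[None] / [0] pair adds and strips a dummy batch axis because traced arrays carry no batch dimension. A quick sanity check (illustrative, not from the package) that the fixed equation is exactly a matrix product over the trailing channel axis:

import numpy as np

rng = np.random.default_rng(0)
inputs = rng.normal(size=(5, 8))   # (..., c) activations
qkernel = rng.normal(size=(8, 3))  # (c, C) quantized kernel

# '...c,cC->...C' contracts the shared channel axis c, i.e. a plain matmul.
assert np.allclose(np.einsum('...c,cC->...C', inputs, qkernel), inputs @ qkernel)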
da4ml/converter/hgq2/layers/ops.py
@@ -0,0 +1,240 @@
+from collections.abc import Sequence
+
+import keras
+import numpy as np
+from hgq.layers import (
+    QAdd,
+    QAveragePow2,
+    QDot,
+    QEinsum,
+    QMaximum,
+    QMeanPow2,
+    QMinimum,
+    QMultiply,
+    QSubtract,
+    QSum,
+)
+from keras.src.ops.numpy import (
+    Abs,
+    Absolute,
+    Add,
+    Concatenate,
+    Divide,
+    Dot,
+    Einsum,
+    GetItem,
+    Matmul,
+    Max,
+    Maximum,
+    Min,
+    Minimum,
+    Moveaxis,
+    Multiply,
+    Ravel,
+    Repeat,
+    Reshape,
+    Subtract,
+    Sum,
+    Transpose,
+    TrueDivide,
+)
+
+from ....trace import FixedVariableArray
+from ....trace.ops import einsum
+from ._base import ReplayOperationBase
+
+
+class ReplayReshape(ReplayOperationBase):
+    handles = (keras.layers.Reshape, keras.layers.Flatten, Reshape, Ravel)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        if isinstance(self.op, (keras.layers.Flatten, Ravel)):
+            return inputs.ravel()
+        elif isinstance(self.op, keras.layers.Reshape):
+            return inputs.reshape(self.op.target_shape)
+        elif isinstance(self.op, Reshape):
+            return inputs.reshape(self.op.newshape[1:])
+        else:
+            raise TypeError(f'Unsupported layer type: {type(self.op)}')
+
+
+class ReplayMerge(ReplayOperationBase):
+    handles = (keras.layers.Add, keras.layers.Concatenate, QAdd, QMultiply, QSubtract, QMaximum, QMinimum, QAveragePow2)
+
+    def call(self, inputs: tuple[FixedVariableArray, ...]) -> FixedVariableArray:
+        op = self.op
+        name = op.__class__.__name__
+        if name.startswith('Q'):
+            name = name[1:]
+        _inputs: FixedVariableArray = np.stack(np.broadcast_arrays(*inputs), axis=0)  # type: ignore
+        match name:
+            case 'Add':
+                return np.sum(_inputs, axis=0)  # type: ignore
+            case 'AveragePow2':
+                return np.sum(_inputs, axis=0) * op._scale  # type: ignore
+            case 'Subtract':
+                assert len(_inputs) == 2, 'Subtract operation requires exactly two inputs'
+                return _inputs[0] - _inputs[1]
+            case 'Multiply':
+                return np.prod(_inputs, axis=0)  # type: ignore
+            case 'Maximum':
+                return np.amax(_inputs, axis=0)  # type: ignore
+            case 'Minimum':
+                return np.amin(_inputs, axis=0)  # type: ignore
+            case 'Concatenate':
+                return np.concatenate(_inputs, axis=op.axis)  # type: ignore
+
+            case _:
+                raise TypeError(f'Unsupported layer type: {type(op)}')
+
+
+class ReplayRepeatVector(ReplayOperationBase):
+    handles = (keras.layers.RepeatVector,)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        layer: keras.layers.RepeatVector = self.op
+        if layer.n == 1:
+            return inputs
+        # return FixedVariableArray(np.repeat(inputs._vars, layer.n, axis=0), inputs.solver_options)
+        return np.repeat(inputs[None], layer.n, axis=0)[0]  # type: ignore
+
+
+class ReplayGetItem(ReplayOperationBase):
+    handles = (GetItem,)
+
+    def call(self, x: FixedVariableArray, key):
+        if isinstance(key, list):
+            key = tuple(key)
+        return x[None][key][0]
+
+
+class ReplayReduction(ReplayOperationBase):
+    handles = (Sum, Max, Min)
+
+    def call(self, x: FixedVariableArray, axis=None, keepdims=False):
+        if isinstance(self.op, Sum):
+            op = np.sum
+        elif isinstance(self.op, Max):
+            op = np.amax
+        elif isinstance(self.op, Min):
+            op = np.amin
+        return op(x[None], axis=axis, keepdims=keepdims)[0]  # type: ignore
+
+
+class ReplayQReduction(ReplayOperationBase):
+    handles = (QSum, QMeanPow2)
+
+    def call(self, x: FixedVariableArray):
+        layer: QSum = self.op
+        axes, scale, keepdims = layer.axes, layer.scale, layer.keepdims
+        return np.sum(x[None], axis=axes, keepdims=keepdims)[0] * scale  # type: ignore
+
+
+class ReplayArithmetic(ReplayOperationBase):
+    handles = (Add, Subtract, Multiply, TrueDivide, Divide, Maximum, Minimum)
+
+    def call(self, x1: FixedVariableArray, x2: FixedVariableArray):
+        name = self.op.__class__.__name__
+        match name:
+            case 'Add':
+                return x1 + x2
+            case 'Subtract':
+                return x1 - x2
+            case 'Multiply':
+                return x1 * x2
+            case 'TrueDivide' | 'Divide':
+                return x1 / x2
+            case 'Maximum':
+                return np.maximum(x1, x2)  # type: ignore
+            case 'Minimum':
+                return np.minimum(x1, x2)  # type: ignore
+            case _:
+                raise TypeError(f'Unsupported arithmetic operation: {type(self.op)}')
+
+
+class ReplayConcatenate(ReplayOperationBase):
+    handles = (Concatenate,)
+
+    def call(self, xs: Sequence[FixedVariableArray]):
+        axis = self.op.axis
+        # return backend.numpy.concatenate(xs, axis=self.axis)
+        # return FixedVariableArray(np.concatenate([x._vars[None] for x in xs], axis=axis)[0], xs[0].solver_options)
+        return np.concatenate([x[None] for x in xs], axis=axis)[0]  # type: ignore
+
+
+class ReplayRepeat(ReplayOperationBase):
+    handles = (Repeat,)
+
+    def call(self, x: FixedVariableArray):
+        repeats, axis = self.op.repeats, self.op.axis
+        # return FixedVariableArray(np.repeat(x._vars[None], repeats, axis=axis)[0], x.solver_options)
+        return np.repeat(x[None], repeats, axis=axis)[0]  # type: ignore
+
+
+class ReplayTranspose(ReplayOperationBase):
+    handles = (Transpose,)
+
+    def call(self, x: FixedVariableArray):
+        axes = self.op.axes
+        return np.transpose(x, axes)  # type: ignore
+
+
+class ReplayMoveaxis(ReplayOperationBase):
+    handles = (Moveaxis,)
+
+    def call(self, x: FixedVariableArray):
+        source, destination = self.op.source, self.op.destination
+        return np.moveaxis(x[None], source, destination)[0]  # type: ignore
+
+
+class ReplayNoOp(ReplayOperationBase):
+    __noop_layers = []
+    for k, v in keras.layers.__dict__.items():
+        name = k.lower()
+        if 'dropout' in name or 'random' in name or 'noise' in name:
+            __noop_layers.append(v)
+
+    handles = tuple(__noop_layers)
+
+    def call(self, x: FixedVariableArray, training=False) -> FixedVariableArray:
+        assert not training, 'Training mode is not supported in mirror operation'
+        return x
+
+
+class ReplayEinsum(ReplayOperationBase):
+    handles = (QEinsum, Einsum, QDot, keras.layers.Dot)
+
+    def call(self, inputs: tuple[FixedVariableArray, FixedVariableArray]) -> FixedVariableArray:
+        op = self.op
+        if isinstance(op, QEinsum):
+            eq = op.equation
+        elif isinstance(op, Einsum):
+            eq = op.subscripts
+        else:  # QDot/Dot
+            dim0, dim1 = inputs[0].ndim + 1, inputs[1].ndim + 1
+            letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'[0 : dim0 + dim1]
+            sub0, sub1 = letters[:dim0], letters[dim0 : dim0 + dim1]
+            axes = list(op.axes) if not isinstance(op.axes, int) else [op.axes, op.axes]
+            idx0, idx1 = axes[0] if axes[0] >= 0 else axes[0] % dim0, axes[1] if axes[1] >= 0 else axes[1] % dim1
+            sub1 = sub1[:idx1] + sub0[idx0] + sub1[idx1 + 1 :]
+            sub_out = list(sub0 + sub1)
+            sub_out.remove(sub0[idx0])
+            sub_out.remove(sub0[idx0])
+            sub_out = ''.join(sub_out)
+            eq = f'{sub0},{sub1}->{sub_out}'
+        assert len(inputs) == 2, 'Only (Q)Einsum operations with exactly two inputs are supported'
+        return einsum(eq, inputs[0][None], inputs[1][None])[0]
+
+
+class ReplayMatmul(ReplayOperationBase):
+    handles = (Matmul, Dot)
+
+    def call(self, x1: FixedVariableArray, x2: FixedVariableArray) -> FixedVariableArray:
+        return x1 @ x2
+
+
+class ReplayAbs(ReplayOperationBase):
+    handles = (Absolute, Abs)
+
+    def call(self, x: FixedVariableArray) -> FixedVariableArray:
+        return np.abs(x)  # type: ignore
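The (Q)Dot branch of ReplayEinsum builds an einsum equation on the fly: one letter per axis of each operand, the contracted letter of the first operand substituted into the second, and both occurrences of that letter removed from the output subscripts. The sketch below lifts that construction into a standalone helper to show what it produces (the helper name dot_equation is ours, for illustration; the body mirrors the branch above):

import numpy as np

def dot_equation(dim0: int, dim1: int, axes) -> str:
    letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'[: dim0 + dim1]
    sub0, sub1 = letters[:dim0], letters[dim0 : dim0 + dim1]
    ax = list(axes) if not isinstance(axes, int) else [axes, axes]
    idx0, idx1 = ax[0] % dim0, ax[1] % dim1              # normalize negative axes
    sub1 = sub1[:idx1] + sub0[idx0] + sub1[idx1 + 1 :]   # share the contracted letter
    out = list(sub0 + sub1)
    out.remove(sub0[idx0])                               # shared letter appears twice;
    out.remove(sub0[idx0])                               # drop both from the output
    return f'{sub0},{sub1}->{"".join(out)}'

eq = dot_equation(2, 2, axes=1)   # rank-2 operands contracted over axis 1
assert eq == 'ab,cb->ac'          # i.e. x1 @ x2.T
assert np.einsum(eq, np.ones((4, 3)), np.ones((5, 3))).shape == (4, 5)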
da4ml/converter/hgq2/layers/pool.py
@@ -0,0 +1,107 @@
+from math import prod
+
+import hgq
+import keras
+import numpy as np
+from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling
+from keras.src.layers.pooling.base_pooling import BasePooling
+
+from ....trace import FixedVariableArray
+from ._base import ReplayOperationBase
+from .conv import symbolic_extract_patches
+
+
+class ReplayPool(ReplayOperationBase):
+    handles = (
+        hgq.layers.QAvgPool1D,
+        hgq.layers.QAvgPool2D,
+        hgq.layers.QAvgPool3D,
+        hgq.layers.QMaxPool1D,
+        hgq.layers.QMaxPool2D,
+        hgq.layers.QMaxPool3D,
+        hgq.layers.QGlobalAveragePooling1D,
+        hgq.layers.QGlobalMaxPooling1D,
+        hgq.layers.QGlobalAveragePooling2D,
+        hgq.layers.QGlobalMaxPooling2D,
+        hgq.layers.QGlobalAveragePooling3D,
+        hgq.layers.QGlobalMaxPooling3D,
+        keras.layers.AveragePooling1D,
+        keras.layers.AveragePooling2D,
+        keras.layers.AveragePooling3D,
+        keras.layers.MaxPooling1D,
+        keras.layers.MaxPooling2D,
+        keras.layers.MaxPooling3D,
+        keras.layers.GlobalAveragePooling1D,
+        keras.layers.GlobalMaxPooling1D,
+        keras.layers.GlobalAveragePooling2D,
+        keras.layers.GlobalMaxPooling2D,
+        keras.layers.GlobalAveragePooling3D,
+        keras.layers.GlobalMaxPooling3D,
+    )
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        cname = self.op.__class__.__name__
+        if 'Max' in cname:
+            op = 'max'
+        else:
+            assert 'Average' in cname, f'Unsupported global pooling layer: {cname}'
+            op = 'avg'
+
+        data_format = self.op.data_format
+        if data_format == 'channels_first':
+            inputs = np.moveaxis(inputs, 1, -1)  # type: ignore
+
+        if isinstance(self.op, BaseGlobalPooling):
+            pool_dim = self.op.input_spec.ndim - 2  # type: ignore
+            axis = tuple(range(pool_dim))
+            keepdims = self.op.keepdims
+
+            if op == 'max':
+                out = np.amax(inputs, axis=axis, keepdims=keepdims)  # type: ignore
+            elif op == 'avg':
+                pool_size = prod(inputs.shape[:-1])
+                out = np.sum(inputs, axis=axis, keepdims=keepdims) / pool_size  # type: ignore
+        else:
+            assert isinstance(self.op, BasePooling), f'Unknown pooling layer: {type(self.op)}'
+            pool_size = self.op.pool_size
+            strides = self.op.strides
+            padding = self.op.padding
+            pool_dim = len(pool_size)
+            ch = inputs.shape[-1]
+            x = symbolic_extract_patches(
+                inputs,
+                pool_size,
+                strides,
+                dilation_rate=1,
+                padding=padding,
+                data_format='channels_last',
+            )
+            x = x.reshape(x.shape[:-1] + (-1, ch))
+
+            if padding == 'same':
+                mask = symbolic_extract_patches(
+                    np.ones(inputs.shape, dtype=np.int32),
+                    pool_size,
+                    strides,
+                    dilation_rate=1,
+                    padding=padding,
+                    data_format='channels_last',
+                ).reshape(x.shape)
+            elif padding == 'valid':
+                mask = np.ones(x.shape, dtype=np.int32)
+            else:
+                raise ValueError(f'Unknown padding type: {padding}')
+
+            if op == 'max':
+                _vars = np.where(mask, x._vars, -(65535**2))
+                x = FixedVariableArray(_vars, x.solver_options)
+                out = np.max(x, axis=-2)  # type: ignore
+            elif op == 'avg':
+                out = np.sum(x, axis=-2) / np.sum(mask, axis=-2)  # type: ignore
+            else:
+                raise ValueError(f'Unknown pooling operation: {op}')
+
+        if data_format == 'channels_first':
+            out = np.moveaxis(out, -1, 1)  # type: ignore
+
+        return out  # type: ignore
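For windowed pooling with 'same' padding, ReplayPool runs symbolic_extract_patches over an all-ones array as well: the resulting mask records, per window, which slots are real and which are padding, so average pooling divides each window sum by the count of valid entries rather than the full window size, while max pooling first replaces padded slots with the large negative sentinel -(65535**2) before reducing. A 1-D illustration of the count-corrected average (a sketch with a hand-rolled window in place of symbolic_extract_patches; values are made up):

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])

# pool_size=3, stride=1, 'same' padding: one pad slot on each side.
pad = np.concatenate([[0.0], x, [0.0]])
windows = np.stack([pad[i : i + 3] for i in range(x.size)])

# The same extraction on an all-ones array yields the per-window validity mask.
ones = np.concatenate([[0], np.ones(x.size, dtype=np.int32), [0]])
mask = np.stack([ones[i : i + 3] for i in range(x.size)])

avg = windows.sum(axis=-1) / mask.sum(axis=-1)  # divide by 2 at borders, 3 inside
print(avg)  # [1.5 2.  3.  3.5]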