da4ml-0.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +194 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +240 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.0.dist-info/METADATA +85 -0
  92. da4ml-0.5.0.dist-info/RECORD +96 -0
  93. da4ml-0.5.0.dist-info/WHEEL +6 -0
  94. da4ml-0.5.0.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/converter/hgq2/layers/conv.py
@@ -0,0 +1,149 @@
+ from typing import TypeVar
+
+ import numpy as np
+ from hgq.layers import (
+     QConv1D,
+     QConv2D,
+     QConv3D,
+ )
+ from keras import ops
+ from keras.src.ops.image import ExtractPatches, extract_patches_3d
+
+ from ....trace import FixedVariableArray
+ from ._base import ReplayOperationBase, to_np_arr
+
+ T = TypeVar('T', FixedVariableArray, np.ndarray)
+
+
+ def symbolic_extract_patches_3d(
+     images: T,
+     size: tuple[int, int, int],
+     strides: tuple[int, int, int],
+     dilation_rate: tuple[int, int, int],
+     padding: str,
+     data_format: str,
+     pad_value: float = 0,
+ ) -> T:
+     img_tensor = ops.cast(ops.reshape(ops.arange(images.size), images.shape), dtype='float32')
+     img_tensor = -img_tensor - 1
+     out_tensor = extract_patches_3d(
+         img_tensor[None],
+         size=size,
+         strides=strides,
+         dilation_rate=dilation_rate,  # type: ignore
+         padding=padding,
+         data_format=data_format,
+     )[0]
+     out_index: np.ndarray = ops.convert_to_numpy(out_tensor).round().astype(np.int32)  # type: ignore
+     mask = out_index == 0
+     out_index = np.where(mask, 0, -out_index - 1)
+     images = images.ravel()[out_index]
+
+     if isinstance(images, FixedVariableArray):
+         _vars = images._vars
+         _vars = np.where(mask, pad_value, _vars)
+         images = FixedVariableArray(_vars, images.solver_options)  # type: ignore
+     else:
+         images = np.where(mask, pad_value, images)
+
+     return images
+
+
+ def symbolic_extract_patches(
+     images: T,
+     size: tuple[int, ...] | int,
+     strides: tuple[int, ...] | int,
+     dilation_rate: tuple[int, ...] | int,
+     padding: str,
+     data_format: str,
+     pad_value: float = 0,
+ ) -> T:
+     rank = images.ndim - 1
+     size = (size,) * rank if isinstance(size, int) else size
+     strides = (strides,) * rank if isinstance(strides, int) else strides
+     dilation_rate = (dilation_rate,) * rank if isinstance(dilation_rate, int) else dilation_rate
+
+     assert rank == len(size) == len(strides) == len(dilation_rate), (
+         f'Invalid rank {rank} for size {size}, strides {strides}, dilation_rate {dilation_rate}'
+     )
+
+     pad_rank = 3 - rank
+     _size: tuple[int, int, int] = (1,) * pad_rank + size  # type: ignore
+     _strides: tuple[int, int, int] = (1,) * pad_rank + strides  # type: ignore
+     _dilation_rate: tuple[int, int, int] = (1,) * pad_rank + dilation_rate  # type: ignore
+
+     _pad = (1,) * pad_rank
+     if data_format == 'channels_first':
+         images = np.moveaxis(images, 0, -1)  # type: ignore
+
+     *spa, ch = images.shape
+     images = images.reshape(*_pad, *spa, ch)
+
+     r = symbolic_extract_patches_3d(
+         images,
+         size=_size,
+         strides=_strides,
+         dilation_rate=_dilation_rate,
+         padding=padding,
+         data_format='channels_last',
+         pad_value=pad_value,
+     )
+
+     return r.reshape(r.shape[pad_rank:])
+
+
+ class ReplayExtractPatches(ReplayOperationBase):
+     handles = (ExtractPatches,)
+
+     def call(self, images: FixedVariableArray) -> FixedVariableArray:
+         op: ExtractPatches = self.op
+         pixel_shape = op.size
+         strides = op.strides
+         dilation_rate: int | tuple[int, int] = op.dilation_rate
+         padding = op.padding
+         data_format = op.data_format
+
+         if strides is None:
+             strides = 1
+
+         return symbolic_extract_patches(images, pixel_shape, strides, dilation_rate, padding, data_format)
+
+
+ class ReplayQConv(ReplayOperationBase):
+     handles = (QConv1D, QConv2D, QConv3D)
+
+     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+         layer: QConv1D | QConv2D | QConv3D = self.op
+         qkernel = to_np_arr(layer.qkernel)
+         qbias = to_np_arr(layer.qbias) if layer.qbias is not None else None
+         strides = layer.strides
+         padding = layer.padding
+         dilation_rate = layer.dilation_rate
+         groups = layer.groups
+
+         if layer.data_format == 'channels_first':
+             inputs = np.moveaxis(inputs, 0, -1)  # type: ignore
+
+         x = symbolic_extract_patches(
+             inputs,
+             size=layer.kernel_size,
+             strides=strides,
+             dilation_rate=dilation_rate,
+             padding=padding,
+             data_format=layer.data_format,
+         )
+
+         ch_out = qkernel.shape[-1]
+
+         _ch_out = ch_out // groups
+
+         x = x.reshape(*x.shape[:-1], -1, groups)
+         kernel = qkernel.reshape(-1, groups, _ch_out)
+
+         outputs = np.einsum('...ig,igo->...go', x, kernel)  # type: ignore
+         outputs = outputs.reshape(*outputs.shape[:-2], -1) + qbias
+
+         if layer.data_format == 'channels_first':
+             outputs: FixedVariableArray = np.moveaxis(outputs, -1, 0)  # type: ignore
+
+         return outputs
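
The heavy lifting in conv.py is the index-tracing trick in symbolic_extract_patches_3d: rather than reimplementing patch extraction for symbolic arrays, it runs Keras's real extract_patches_3d over a tensor of encoded flat indices (each index i stored as -(i + 1), so the op's zero padding can never collide with a real index), then gathers the symbolic elements by the decoded indices and substitutes pad_value wherever the mask marks padding. Below is a minimal self-contained sketch of the same idea, with NumPy's sliding_window_view standing in for the Keras op; trace_gather_1d is our illustration, not a package function:

import numpy as np

# Sketch of the index-tracing trick: run the *real* windowing op on encoded
# flat indices, then gather the elements we must not recompute.
def trace_gather_1d(values: np.ndarray, window: int) -> np.ndarray:
    # Encode flat index i as -(i + 1): strictly negative, so a pad value of 0
    # is always distinguishable from a real index.
    idx = -(np.arange(values.size, dtype=np.int64) + 1)
    patched = np.lib.stride_tricks.sliding_window_view(idx, window)
    mask = patched == 0                      # padded positions (none here)
    flat = np.where(mask, 0, -patched - 1)   # decode back to flat indices
    return values.ravel()[flat]              # gather by traced index

x = np.array([10.0, 11.0, 12.0, 13.0])
print(trace_gather_1d(x, 3))
# [[10. 11. 12.]
#  [11. 12. 13.]]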
da4ml/converter/hgq2/layers/dense.py
@@ -0,0 +1,39 @@
+ import keras
+ from hgq.layers import (
+     QBatchNormDense,
+     QDense,
+     QEinsumDense,
+     QEinsumDenseBatchnorm,
+ )
+
+ from ....trace import FixedVariableArray
+ from ....trace.ops import einsum
+ from ._base import ReplayOperationBase, to_np_arr
+
+
+ class ReplayQDense(ReplayOperationBase):
+     handles = (QDense, QEinsumDense, QEinsumDenseBatchnorm, QBatchNormDense, keras.layers.EinsumDense)
+
+     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+         op = self.op
+         if isinstance(op, (QDense, QBatchNormDense)):
+             qkernel = op.qkernel
+             qbias = op.qbias
+             eq = '...c,cC->...C'
+         elif isinstance(op, (QEinsumDense, QEinsumDenseBatchnorm)):
+             qkernel = op.qkernel
+             qbias = op.qbias
+             eq = op.equation
+         elif isinstance(op, keras.layers.EinsumDense):
+             qkernel = op.kernel
+             qbias = op.bias
+             eq = op.equation
+         else:
+             raise TypeError(f'Unsupported layer type: {type(op)}')
+
+         qkernel = to_np_arr(qkernel)
+         qbias = to_np_arr(qbias) if qbias is not None else None
+         return (einsum(eq, inputs[None], qkernel) + qbias)[0]
+
+
+ __all__ = ['ReplayQDense']
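
Note the wrap/unwrap convention in the return statement: the layer equation describes batched tensors, so the replayed single sample is given a dummy batch axis with inputs[None] before the contraction and stripped with [0] afterwards. A plain-NumPy sketch of the same convention, using np.einsum in place of da4ml's traced einsum:

import numpy as np

inputs = np.arange(6.0).reshape(2, 3)   # one sample: (tokens, c)
qkernel = np.ones((3, 4))               # (c, C)
eq = '...c,cC->...C'                    # the QDense equation above

# Add a dummy batch axis, contract, then drop the batch axis again.
out = np.einsum(eq, inputs[None], qkernel)[0]
print(out.shape)  # (2, 4)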
da4ml/converter/hgq2/layers/ops.py
@@ -0,0 +1,240 @@
+ from collections.abc import Sequence
+
+ import keras
+ import numpy as np
+ from hgq.layers import (
+     QAdd,
+     QAveragePow2,
+     QDot,
+     QEinsum,
+     QMaximum,
+     QMeanPow2,
+     QMinimum,
+     QMultiply,
+     QSubtract,
+     QSum,
+ )
+ from keras.src.ops.numpy import (
+     Abs,
+     Absolute,
+     Add,
+     Concatenate,
+     Divide,
+     Dot,
+     Einsum,
+     GetItem,
+     Matmul,
+     Max,
+     Maximum,
+     Min,
+     Minimum,
+     Moveaxis,
+     Multiply,
+     Ravel,
+     Repeat,
+     Reshape,
+     Subtract,
+     Sum,
+     Transpose,
+     TrueDivide,
+ )
+
+ from ....trace import FixedVariableArray
+ from ....trace.ops import einsum
+ from ._base import ReplayOperationBase
+
+
+ class ReplayReshape(ReplayOperationBase):
+     handles = (keras.layers.Reshape, keras.layers.Flatten, Reshape, Ravel)
+
+     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+         if isinstance(self.op, (keras.layers.Flatten, Ravel)):
+             return inputs.ravel()
+         elif isinstance(self.op, keras.layers.Reshape):
+             return inputs.reshape(self.op.target_shape)
+         elif isinstance(self.op, Reshape):
+             return inputs.reshape(self.op.newshape[1:])
+         else:
+             raise TypeError(f'Unsupported layer type: {type(self.op)}')
+
+
+ class ReplayMerge(ReplayOperationBase):
+     handles = (keras.layers.Add, keras.layers.Concatenate, QAdd, QMultiply, QSubtract, QMaximum, QMinimum, QAveragePow2)
+
+     def call(self, inputs: tuple[FixedVariableArray, ...]) -> FixedVariableArray:
+         op = self.op
+         name = op.__class__.__name__
+         if name.startswith('Q'):
+             name = name[1:]
+         _inputs: FixedVariableArray = np.stack(np.broadcast_arrays(*inputs), axis=0)  # type: ignore
+         match name:
+             case 'Add':
+                 return np.sum(_inputs, axis=0)  # type: ignore
+             case 'AveragePow2':
+                 return np.sum(_inputs, axis=0) * op._scale  # type: ignore
+             case 'Subtract':
+                 assert len(_inputs) == 2, 'Subtract operation requires exactly two inputs'
+                 return _inputs[0] - _inputs[1]
+             case 'Multiply':
+                 return np.prod(_inputs, axis=0)  # type: ignore
+             case 'Maximum':
+                 return np.amax(_inputs, axis=0)  # type: ignore
+             case 'Minimum':
+                 return np.amin(_inputs, axis=0)  # type: ignore
+             case 'Concatenate':
+                 return np.concatenate(_inputs, axis=op.axis)  # type: ignore
+
+             case _:
+                 raise TypeError(f'Unsupported layer type: {type(op)}')
+
+
+ class ReplayRepeatVector(ReplayOperationBase):
+     handles = (keras.layers.RepeatVector,)
+
+     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+         layer: keras.layers.RepeatVector = self.op
+         if layer.n == 1:
+             return inputs
+         # return FixedVariableArray(np.repeat(inputs._vars, layer.n, axis=0), inputs.solver_options)
+         return np.repeat(inputs[None], layer.n, axis=0)[0]  # type: ignore
+
+
+ class ReplayGetItem(ReplayOperationBase):
+     handles = (GetItem,)
+
+     def call(self, x: FixedVariableArray, key):
+         if isinstance(key, list):
+             key = tuple(key)
+         return x[None][key][0]
+
+
+ class ReplayReduction(ReplayOperationBase):
+     handles = (Sum, Max, Min)
+
+     def call(self, x: FixedVariableArray, axis=None, keepdims=False):
+         if isinstance(self.op, Sum):
+             op = np.sum
+         elif isinstance(self.op, Max):
+             op = np.amax
+         elif isinstance(self.op, Min):
+             op = np.amin
+         return op(x[None], axis=axis, keepdims=keepdims)[0]  # type: ignore
+
+
+ class ReplayQReduction(ReplayOperationBase):
+     handles = (QSum, QMeanPow2)
+
+     def call(self, x: FixedVariableArray):
+         layer: QSum = self.op
+         axes, scale, keepdims = layer.axes, layer.scale, layer.keepdims
+         return np.sum(x[None], axis=axes, keepdims=keepdims)[0] * scale  # type: ignore
+
+
+ class ReplayArithmetic(ReplayOperationBase):
+     handles = (Add, Subtract, Multiply, TrueDivide, Divide, Maximum, Minimum)
+
+     def call(self, x1: FixedVariableArray, x2: FixedVariableArray):
+         name = self.op.__class__.__name__
+         match name:
+             case 'Add':
+                 return x1 + x2
+             case 'Subtract':
+                 return x1 - x2
+             case 'Multiply':
+                 return x1 * x2
+             case 'TrueDivide' | 'Divide':
+                 return x1 / x2
+             case 'Maximum':
+                 return np.maximum(x1, x2)  # type: ignore
+             case 'Minimum':
+                 return np.minimum(x1, x2)  # type: ignore
+             case _:
+                 raise TypeError(f'Unsupported arithmetic operation: {type(self.op)}')
+
+
+ class ReplayConcatenate(ReplayOperationBase):
+     handles = (Concatenate,)
+
+     def call(self, xs: Sequence[FixedVariableArray]):
+         axis = self.op.axis
+         # return backend.numpy.concatenate(xs, axis=self.axis)
+         # return FixedVariableArray(np.concatenate([x._vars[None] for x in xs], axis=axis)[0], xs[0].solver_options)
+         return np.concatenate([x[None] for x in xs], axis=axis)[0]  # type: ignore
+
+
+ class ReplayRepeat(ReplayOperationBase):
+     handles = (Repeat,)
+
+     def call(self, x: FixedVariableArray):
+         repeats, axis = self.op.repeats, self.op.axis
+         # return FixedVariableArray(np.repeat(x._vars[None], repeats, axis=axis)[0], x.solver_options)
+         return np.repeat(x[None], repeats, axis=axis)[0]  # type: ignore
+
+
+ class ReplayTranspose(ReplayOperationBase):
+     handles = (Transpose,)
+
+     def call(self, x: FixedVariableArray):
+         axes = self.op.axes
+         return np.transpose(x, axes)  # type: ignore
+
+
+ class ReplayMoveaxis(ReplayOperationBase):
+     handles = (Moveaxis,)
+
+     def call(self, x: FixedVariableArray):
+         source, destination = self.op.source, self.op.destination
+         return np.moveaxis(x[None], source, destination)[0]  # type: ignore
+
+
+ class ReplayNoOp(ReplayOperationBase):
+     __noop_layers = []
+     for k, v in keras.layers.__dict__.items():
+         name = k.lower()
+         if 'dropout' in name or 'random' in name or 'noise' in name:
+             __noop_layers.append(v)
+
+     handles = tuple(__noop_layers)
+
+     def call(self, x: FixedVariableArray, training=False) -> FixedVariableArray:
+         assert not training, 'Training mode is not supported in mirror operation'
+         return x
+
+
+ class ReplayEinsum(ReplayOperationBase):
+     handles = (QEinsum, Einsum, QDot, keras.layers.Dot)
+
+     def call(self, inputs: tuple[FixedVariableArray, FixedVariableArray]) -> FixedVariableArray:
+         op = self.op
+         if isinstance(op, QEinsum):
+             eq = op.equation
+         elif isinstance(op, Einsum):
+             eq = op.subscripts
+         else:  # QDot/Dot
+             dim0, dim1 = inputs[0].ndim + 1, inputs[1].ndim + 1
+             letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'[0 : dim0 + dim1]
+             sub0, sub1 = letters[:dim0], letters[dim0 : dim0 + dim1]
+             axes = list(op.axes) if not isinstance(op.axes, int) else [op.axes, op.axes]
+             idx0, idx1 = axes[0] if axes[0] >= 0 else axes[0] % dim0, axes[1] if axes[1] >= 0 else axes[1] % dim1
+             sub1 = sub1[:idx1] + sub0[idx0] + sub1[idx1 + 1 :]
+             sub_out = list(sub0 + sub1)
+             sub_out.remove(sub0[idx0])
+             sub_out.remove(sub0[idx0])
+             sub_out = ''.join(sub_out)
+             eq = f'{sub0},{sub1}->{sub_out}'
+         assert len(inputs) == 2, 'Only (Q)Einsum operations with exactly two inputs are supported'
+         return einsum(eq, inputs[0][None], inputs[1][None])[0]
+
+
+ class ReplayMatmul(ReplayOperationBase):
+     handles = (Matmul, Dot)
+
+     def call(self, x1: FixedVariableArray, x2: FixedVariableArray) -> FixedVariableArray:
+         return x1 @ x2
+
+
+ class ReplayAbs(ReplayOperationBase):
+     handles = (Absolute, Abs)
+
+     def call(self, x: FixedVariableArray) -> FixedVariableArray:
+         return np.abs(x)  # type: ignore
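
The least obvious branch in ops.py is the QDot/Dot fallback in ReplayEinsum, which synthesizes an einsum equation: one letter per axis of each (batched) operand, the contracted axis of the second operand rewritten to reuse the first operand's letter, and that letter removed from the output subscripts. A standalone sketch of the derivation, checked against np.einsum; dot_equation is our illustration, not a package function:

import numpy as np

def dot_equation(ndim0: int, ndim1: int, axes: int | tuple[int, int]) -> str:
    dim0, dim1 = ndim0 + 1, ndim1 + 1   # +1 for the batch axis the replay adds
    letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'[: dim0 + dim1]
    sub0, sub1 = letters[:dim0], letters[dim0 : dim0 + dim1]
    ax0, ax1 = (axes, axes) if isinstance(axes, int) else axes
    ax0, ax1 = ax0 % dim0, ax1 % dim1   # normalize negative axes
    sub1 = sub1[:ax1] + sub0[ax0] + sub1[ax1 + 1 :]  # share the contracted letter
    out = (sub0 + sub1).replace(sub0[ax0], '')       # drop both occurrences
    return f'{sub0},{sub1}->{out}'

x1, x2 = np.arange(3.0), np.arange(3.0)      # unbatched operands, as replayed
eq = dot_equation(x1.ndim, x2.ndim, axes=-1)
print(eq)                                    # ab,cb->ac
print(np.einsum(eq, x1[None], x2[None])[0])  # [5.]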
da4ml/converter/hgq2/layers/pool.py
@@ -0,0 +1,107 @@
+ from math import prod
+
+ import hgq
+ import keras
+ import numpy as np
+ from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling
+ from keras.src.layers.pooling.base_pooling import BasePooling
+
+ from ....trace import FixedVariableArray
+ from ._base import ReplayOperationBase
+ from .conv import symbolic_extract_patches
+
+
+ class ReplayPool(ReplayOperationBase):
+     handles = (
+         hgq.layers.QAvgPool1D,
+         hgq.layers.QAvgPool2D,
+         hgq.layers.QAvgPool3D,
+         hgq.layers.QMaxPool1D,
+         hgq.layers.QMaxPool2D,
+         hgq.layers.QMaxPool3D,
+         hgq.layers.QGlobalAveragePooling1D,
+         hgq.layers.QGlobalMaxPooling1D,
+         hgq.layers.QGlobalAveragePooling2D,
+         hgq.layers.QGlobalMaxPooling2D,
+         hgq.layers.QGlobalAveragePooling3D,
+         hgq.layers.QGlobalMaxPooling3D,
+         keras.layers.AveragePooling1D,
+         keras.layers.AveragePooling2D,
+         keras.layers.AveragePooling3D,
+         keras.layers.MaxPooling1D,
+         keras.layers.MaxPooling2D,
+         keras.layers.MaxPooling3D,
+         keras.layers.GlobalAveragePooling1D,
+         keras.layers.GlobalMaxPooling1D,
+         keras.layers.GlobalAveragePooling2D,
+         keras.layers.GlobalMaxPooling2D,
+         keras.layers.GlobalAveragePooling3D,
+         keras.layers.GlobalMaxPooling3D,
+     )
+
+     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+         cname = self.op.__class__.__name__
+         if 'Max' in cname:
+             op = 'max'
+         else:
+             assert 'Average' in cname, f'Unsupported global pooling layer: {cname}'
+             op = 'avg'
+
+         data_format = self.op.data_format
+         if data_format == 'channels_first':
+             inputs = np.moveaxis(inputs, 1, -1)  # type: ignore
+
+         if isinstance(self.op, BaseGlobalPooling):
+             pool_dim = self.op.input_spec.ndim - 2  # type: ignore
+             axis = tuple(range(pool_dim))
+             keepdims = self.op.keepdims
+
+             if op == 'max':
+                 out = np.amax(inputs, axis=axis, keepdims=keepdims)  # type: ignore
+             elif op == 'avg':
+                 pool_size = prod(inputs.shape[:-1])
+                 out = np.sum(inputs, axis=axis, keepdims=keepdims) / pool_size  # type: ignore
+         else:
+             assert isinstance(self.op, BasePooling), f'Unknown pooling layer: {type(self.op)}'
+             pool_size = self.op.pool_size
+             strides = self.op.strides
+             padding = self.op.padding
+             pool_dim = len(pool_size)
+             ch = inputs.shape[-1]
+             x = symbolic_extract_patches(
+                 inputs,
+                 pool_size,
+                 strides,
+                 dilation_rate=1,
+                 padding=padding,
+                 data_format='channels_last',
+             )
+             x = x.reshape(x.shape[:-1] + (-1, ch))
+
+             if padding == 'same':
+                 mask = symbolic_extract_patches(
+                     np.ones(inputs.shape, dtype=np.int32),
+                     pool_size,
+                     strides,
+                     dilation_rate=1,
+                     padding=padding,
+                     data_format='channels_last',
+                 ).reshape(x.shape)
+             elif padding == 'valid':
+                 mask = np.ones(x.shape, dtype=np.int32)
+             else:
+                 raise ValueError(f'Unknown padding type: {padding}')
+
+             if op == 'max':
+                 _vars = np.where(mask, x._vars, -(65535**2))
+                 x = FixedVariableArray(_vars, x.solver_options)
+                 out = np.max(x, axis=-2)  # type: ignore
+             elif op == 'avg':
+                 out = np.sum(x, axis=-2) / np.sum(mask, axis=-2)  # type: ignore
+             else:
+                 raise ValueError(f'Unknown pooling operation: {op}')
+
+         if data_format == 'channels_first':
+             out = np.moveaxis(out, -1, 1)  # type: ignore
+
+         return out  # type: ignore
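
For pooling with 'same' padding, the mask is the key: padded positions contribute zero to a window's sum, so average pooling divides each window by the number of valid elements (obtained by extracting patches from an all-ones array) rather than by the window size, and max pooling substitutes a large negative sentinel at padded positions. A plain-NumPy sketch of the average-pool case under assumed parameters (window 3, stride 1, one zero-pad per side):

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
xp = np.pad(x, 1)                 # padded entries add 0 to each window sum
mp = np.pad(np.ones_like(x), 1)   # 1 where real data, 0 where padding

win = np.lib.stride_tricks.sliding_window_view
# Divide by the count of valid elements per window, not the window size.
avg = win(xp, 3).sum(axis=-1) / win(mp, 3).sum(axis=-1)
print(avg)  # [1.5 2.  3.  3.5] -- edge windows average over 2 valid values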