da4ml 0.2.1__py3-none-any.whl → 0.3.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of da4ml has been flagged by the registry; see its release page for details.
- da4ml/_version.py +2 -2
- da4ml/cmvm/types.py +95 -15
- da4ml/codegen/__init__.py +5 -4
- da4ml/codegen/cpp/__init__.py +2 -1
- da4ml/codegen/cpp/cpp_codegen.py +56 -23
- da4ml/codegen/cpp/hls_model.py +252 -0
- da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/cpp/source/binder_util.hh +56 -0
- da4ml/codegen/cpp/source/build_binder.mk +24 -0
- da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
- da4ml/codegen/verilog/__init__.py +2 -3
- da4ml/codegen/verilog/comb.py +65 -24
- da4ml/codegen/verilog/io_wrapper.py +36 -141
- da4ml/codegen/verilog/source/binder_util.hh +72 -0
- da4ml/codegen/verilog/source/mux.v +58 -0
- da4ml/codegen/verilog/source/negative.v +28 -0
- da4ml/codegen/verilog/source/shift_adder.v +4 -1
- da4ml/codegen/verilog/source/template.xdc +3 -0
- da4ml/codegen/verilog/verilog_model.py +36 -12
- da4ml/converter/__init__.py +0 -0
- da4ml/converter/hgq2/parser.py +105 -0
- da4ml/converter/hgq2/replica.py +383 -0
- da4ml/trace/__init__.py +2 -2
- da4ml/trace/fixed_variable.py +175 -16
- da4ml/trace/fixed_variable_array.py +109 -4
- da4ml/trace/ops/__init__.py +22 -6
- da4ml/trace/ops/conv_utils.py +147 -15
- da4ml/trace/ops/einsum_utils.py +9 -6
- da4ml/trace/ops/reduce_utils.py +103 -0
- da4ml/trace/pipeline.py +36 -34
- da4ml/trace/tracer.py +37 -7
- da4ml-0.3.0.post1.dist-info/METADATA +107 -0
- da4ml-0.3.0.post1.dist-info/RECORD +64 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
- da4ml-0.2.1.dist-info/METADATA +0 -65
- da4ml-0.2.1.dist-info/RECORD +0 -39
- /da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
- {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/WHEEL +0 -0
- {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/top_level.txt +0 -0
da4ml/converter/hgq2/replica.py
ADDED
@@ -0,0 +1,383 @@
+import typing
+from collections.abc import Sequence
+from math import prod
+from typing import Any
+
+import hgq
+import keras
+import numpy as np
+from hgq.layers import (
+    QBatchNormalization,
+    QBatchNormDense,
+    QConv1D,
+    QConv2D,
+    QConv3D,
+    QDense,
+    QEinsumDense,
+    QEinsumDenseBatchnorm,
+    QSum,
+)
+from hgq.layers.core.base import MultipleQuantizers, Quantizer
+from hgq.quantizer.internal import FixedPointQuantizerBase
+from keras.layers import ReLU
+from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling
+from keras.src.layers.pooling.base_pooling import BasePooling
+from keras.src.ops.numpy import (
+    Add,
+    Concatenate,
+    Divide,
+    GetItem,
+    Moveaxis,
+    Multiply,
+    Ravel,
+    Repeat,
+    Reshape,
+    Subtract,
+    Sum,
+    Transpose,
+    TrueDivide,
+)
+
+from ...trace import FixedVariableArray
+from ...trace.ops import conv, einsum, pool, quantize, relu
+
+
+def mirror_quantizer(q: Quantizer, v: FixedVariableArray) -> FixedVariableArray:
+    q_internal: FixedPointQuantizerBase = q.quantizer
+    k, i, f = (np.array(x, dtype=np.int8)[0] for x in q_internal.kif)
+    round_mode, overflow_mode = q_internal.round_mode, q_internal.overflow_mode
+    return quantize(v, k, i, f, overflow_mode=overflow_mode, round_mode=round_mode)
+
+
+_registry: dict[type, 'type[MirrorOperationBase]'] = {}
+
+
+class MirrorOperationMeta(type):
+    def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]):
+        cls = super().__new__(mcs, name, bases, namespace)
+        if name == 'MirrorOperationBase':
+            return cls
+
+        handles: type | tuple[type, ...] = namespace['handles']
+        if not isinstance(handles, tuple):
+            handles = (handles,)
+
+        for handle in handles:
+            _registry[handle] = cls # type: ignore
+        return cls
+
+
+class MirrorOperationBase(metaclass=MirrorOperationMeta):
+    handles: tuple[type, ...] = ()
+
+    def __init__(self, layer: 'keras.Operation'):
+        assert isinstance(layer, self.handles)
+        self.op: Any = layer
+
+    def call(self, *args, **kwargs) -> tuple[FixedVariableArray, ...] | FixedVariableArray: ...
+
+    def __call__(self, *args, **kwargs) -> tuple[FixedVariableArray, ...]:
+        assert all(not isinstance(a, FixedVariableArray) for a in kwargs.values())
+        assert all(isinstance(a, FixedVariableArray) or isinstance(a, Sequence) for a in args)
+        inputs = args[0] if len(args) == 1 else args
+
+        if not isinstance(self.op, hgq.layers.QLayerBase):
+            r = self.call(*args, **kwargs)
+            return r if isinstance(r, tuple) else (r,)
+
+        layer: hgq.layers.QLayerBase = self.op
+        assert kwargs.pop('training', False) is False, 'Training mode is not supported in mirror operation'
+        assert kwargs.pop('mask', None) is None, 'Masking is not supported in mirror operation'
+
+        if layer.enable_iq:
+            if isinstance(inputs, Sequence):
+                assert isinstance(layer.iq, MultipleQuantizers)
+                inputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.iq.quantizers, inputs))
+            else:
+                assert isinstance(layer.iq, Quantizer), f'Expected iq to be a Quantizer, got {type(layer.iq)}'
+                inputs = mirror_quantizer(layer.iq, inputs)
+
+        outputs = self.call(inputs, **kwargs)
+
+        activation = getattr(layer, 'activation', keras.activations.linear)
+        if activation is not keras.activations.linear:
+            if activation is keras.activations.relu:
+                if isinstance(outputs, tuple):
+                    assert len(outputs) == 1, 'ReLU activation is expected to have a single output'
+                    outputs = (relu(outputs[0]),)
+                else:
+                    outputs = relu(outputs)
+            else:
+                raise NotImplementedError(f'Activation {activation} is not supported in mirror operation')
+
+        if layer.enable_oq:
+            if isinstance(outputs, tuple):
+                assert isinstance(layer.oq, MultipleQuantizers)
+                outputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.oq.quantizers, outputs))
+            else:
+                assert isinstance(layer.oq, Quantizer)
+                outputs = mirror_quantizer(layer.oq, outputs)
+
+        if isinstance(outputs, FixedVariableArray):
+            outputs = (outputs,)
+
+        return outputs
+
+
+class MirrorQuantizer(MirrorOperationBase):
+    handles = (Quantizer,)
+
+    def __init__(self, op: 'Quantizer'):
+        super().__init__(op)
+        assert isinstance(op.quantizer, FixedPointQuantizerBase)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        return mirror_quantizer(self.op, inputs)
+
+
+class MirrorQDense(MirrorOperationBase):
+    handles = (QDense, QEinsumDense, QEinsumDenseBatchnorm, QBatchNormDense, QBatchNormalization, keras.layers.EinsumDense)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        op = self.op
+        if isinstance(op, (QDense, QBatchNormDense)):
+            qkernel = op.qkernel
+            qbias = op.qbias
+            eq = '...c,cC->...C'
+        elif isinstance(op, (QEinsumDense, QEinsumDenseBatchnorm)):
+            qkernel = op.qkernel
+            qbias = op.qbias
+            eq = op.equation
+        elif isinstance(op, keras.layers.EinsumDense):
+            qkernel = op.kernel
+            qbias = op.bias
+            eq = op.equation
+        elif isinstance(op, QBatchNormalization):
+            qkernel, qbias = op.qscaler_and_qoffset
+            dim = inputs._vars.ndim
+            axis = op.axis
+            assert axis != 0, 'Cannot normalizing on batch axis'
+            axis -= 1
+            idx = ''.join(chr(ord('a') + i) for i in range(dim))
+            eq = f'...{idx},{idx[axis]}->...{idx}'
+        else:
+            raise TypeError(f'Unsupported layer type: {type(op)}')
+
+        qkernel = np.array(qkernel)
+        qbias = np.array(qbias) if qbias is not None else None
+        return (einsum(eq, inputs[None], qkernel) + qbias)[0]
+
+
+class MirrorQConv(MirrorOperationBase):
+    handles = (QConv1D, QConv2D, QConv3D)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        layer: QConv1D | QConv2D | QConv3D = self.op
+        qkernel = np.array(layer.qkernel)
+        qbias = np.array(layer.qbias) if layer.qbias is not None else None
+        strides = layer.strides
+        padding = layer.padding
+        dilation_rate = layer.dilation_rate
+        groups = layer.groups
+
+        assert dilation_rate == 1 or all(d == 1 for d in dilation_rate), 'Dilation rate is not supported in mirror operation'
+        if layer.data_format == 'channels_first':
+            shape = (0,) + tuple(range(2, len(inputs.shape))) + (1,)
+            inputs = inputs.transpose(shape)
+
+        outputs = conv(inputs, qkernel, qbias, strides=strides, padding=padding, format=layer.data_format, groups=groups)
+
+        return outputs
+
+
+class MirrorReLU(MirrorOperationBase):
+    handles = (ReLU,)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        return relu(inputs)
+
+
+class MirrorReshape(MirrorOperationBase):
+    handles = (keras.layers.Reshape, keras.layers.Flatten, Reshape, Ravel)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        if isinstance(self.op, (keras.layers.Flatten, Ravel)):
+            return inputs.ravel()
+        elif isinstance(self.op, keras.layers.Reshape):
+            return inputs.reshape(self.op.target_shape)
+        elif isinstance(self.op, Reshape):
+            return inputs.reshape(self.op.newshape[1:])
+        else:
+            raise TypeError(f'Unsupported layer type: {type(self.op)}')
+
+
+class MirrorMerge(MirrorOperationBase):
+    handles = (keras.layers.Add, keras.layers.Concatenate, hgq.layers.QAdd)
+
+    def call(self, inputs: tuple[FixedVariableArray, FixedVariableArray]) -> FixedVariableArray:
+        op: keras.Operation = self.op
+        if isinstance(op, (keras.layers.Add, hgq.layers.QAdd)):
+            return inputs[0] + inputs[1]
+        elif isinstance(op, keras.layers.Concatenate):
+            axis = op.axis
+            data = np.concatenate([v._vars for v in inputs], axis=axis)
+            return FixedVariableArray(data, inputs[0].solver_options)
+        else:
+            raise TypeError(f'Unsupported layer type: {type(op)}')
+
+
+class MirrorPool(MirrorOperationBase):
+    handles = (
+        hgq.layers.QAvgPool1D,
+        hgq.layers.QAvgPool2D,
+        hgq.layers.QAvgPool3D,
+        hgq.layers.QMaxPool1D,
+        hgq.layers.QMaxPool2D,
+        hgq.layers.QMaxPool3D,
+        hgq.layers.QGlobalAveragePooling1D,
+        hgq.layers.QGlobalMaxPooling1D,
+        hgq.layers.QGlobalAveragePooling2D,
+        hgq.layers.QGlobalMaxPooling2D,
+        hgq.layers.QGlobalAveragePooling3D,
+        hgq.layers.QGlobalMaxPooling3D,
+        keras.layers.AveragePooling1D,
+        keras.layers.AveragePooling2D,
+        keras.layers.AveragePooling3D,
+        keras.layers.MaxPooling1D,
+        keras.layers.MaxPooling2D,
+        keras.layers.MaxPooling3D,
+        keras.layers.GlobalAveragePooling1D,
+        keras.layers.GlobalMaxPooling1D,
+        keras.layers.GlobalAveragePooling2D,
+        keras.layers.GlobalMaxPooling2D,
+        keras.layers.GlobalAveragePooling3D,
+        keras.layers.GlobalMaxPooling3D,
+    )
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        cname = self.op.__class__.__name__
+        if 'Max' in cname:
+            op = 'max'
+        else:
+            assert 'Average' in cname, f'Unsupported global pooling layer: {cname}'
+            op = 'avg'
+
+        data_format = self.op.data_format
+        if data_format == 'channels_first':
+            inputs = np.moveaxis(inputs, 1, -1) # type: ignore
+
+        if isinstance(self.op, BaseGlobalPooling):
+            pool_dim = self.op.input_spec.ndim - 2 # type: ignore
+            axis = tuple(range(pool_dim))
+            keepdims = self.op.keepdims
+
+            if op == 'max':
+                out = np.amax(inputs, axis=axis, keepdims=keepdims) # type: ignore
+            elif op == 'avg':
+                pool_size = prod(inputs.shape[:-1])
+                out = np.sum(inputs, axis=axis, keepdims=keepdims) / pool_size # type: ignore
+        else:
+            assert isinstance(self.op, BasePooling), f'Unsupported pooling layer: {type(self.op)}'
+            pool_size = self.op.pool_size
+            strides = self.op.strides
+            padding = self.op.padding
+            pool_dim = len(pool_size)
+            out = pool(
+                inputs,
+                pool_size=pool_size,
+                strides=strides,
+                padding=padding,
+                pool_type=op,
+            )
+        if data_format == 'channels_first':
+            out = np.moveaxis(out, -1, 1) # type: ignore
+
+        return out # type: ignore
+
+
+class MirrorRepeatVector(MirrorOperationBase):
+    handles = (keras.layers.RepeatVector,)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        layer: keras.layers.RepeatVector = self.op
+        if layer.n == 1:
+            return inputs
+        # return FixedVariableArray(np.repeat(inputs._vars, layer.n, axis=0), inputs.solver_options)
+        return np.repeat(inputs[None], layer.n, axis=0)[0] # type: ignore
+
+
+class MirrorGetItem(MirrorOperationBase):
+    handles = (GetItem,)
+
+    def call(self, x: FixedVariableArray, key):
+        if isinstance(key, list):
+            key = tuple(key)
+        return x[None][key][0]
+
+
+class MirrorSum(MirrorOperationBase):
+    handles = (Sum,)
+
+    def call(self, x: FixedVariableArray, axis=None, keepdims=False):
+        return np.sum(x[None], axis=axis, keepdims=keepdims)[0] # type: ignore
+
+
+class MirrorQSum(MirrorOperationBase):
+    handles = (QSum,)
+
+    def call(self, x: FixedVariableArray):
+        layer: QSum = self.op
+        axes, scale, keepdims = layer.axes, layer.scale, layer.keepdims
+        return np.sum(x[None], axis=axes, keepdims=keepdims)[0] * scale # type: ignore
+
+
+class MirrorArithmetic(MirrorOperationBase):
+    handles = (Add, Subtract, Multiply, TrueDivide, Divide)
+
+    def call(self, x1: FixedVariableArray, x2: FixedVariableArray):
+        match self.op.__class__.__name__:
+            case 'Add':
+                return x1 + x2
+            case 'Subtract':
+                return x1 - x2
+            case 'Multiply':
+                return x1 * x2
+            case 'TrueDivide' | 'Divide':
+                return x1 / x2
+            case _:
+                raise TypeError(f'Unsupported arithmetic operation: {type(self.op)}')
+
+
+class MirrorConcatenate(MirrorOperationBase):
+    handles = (Concatenate,)
+
+    def call(self, xs: Sequence[FixedVariableArray]):
+        axis = self.op.axis
+        # return backend.numpy.concatenate(xs, axis=self.axis)
+        # return FixedVariableArray(np.concatenate([x._vars[None] for x in xs], axis=axis)[0], xs[0].solver_options)
+        return np.concatenate([x[None] for x in xs], axis=axis)[0] # type: ignore
+
+
+class MirrorRepeat(MirrorOperationBase):
+    handles = (Repeat,)
+
+    def call(self, x: FixedVariableArray):
+        repeats, axis = self.op.repeats, self.op.axis
+        # return FixedVariableArray(np.repeat(x._vars[None], repeats, axis=axis)[0], x.solver_options)
+        return np.repeat(x[None], repeats, axis=axis)[0] # type: ignore
+
+
+class MirrorTranspose(MirrorOperationBase):
+    handles = (Transpose,)
+
+    def call(self, x: FixedVariableArray):
+        axes = self.op.axes
+        return np.transpose(x, axes) # type: ignore
+
+
+class MirrorMoveaxis(MirrorOperationBase):
+    handles = (Moveaxis,)
+
+    def call(self, x: FixedVariableArray):
+        source, destination = self.op.source, self.op.destination
+        return np.moveaxis(x[None], source, destination)[0] # type: ignore
da4ml/trace/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from .fixed_variable import HWConfig
-from .fixed_variable_array import FixedVariableArray
+from .fixed_variable_array import FixedVariableArray, FixedVariableArrayInput
 from .pipeline import to_pipeline
 from .tracer import comb_trace
 
-__all__ = ['to_pipeline', 'comb_trace', 'FixedVariableArray', 'HWConfig']
+__all__ = ['to_pipeline', 'comb_trace', 'FixedVariableArray', 'HWConfig', 'FixedVariableArrayInput']
da4ml/trace/fixed_variable.py
CHANGED
@@ -43,9 +43,9 @@ class FixedVariable:
     ) -> None:
         assert low <= high, f'low {low} must be less than high {high}'
 
-        if low == high:
+        if low == high and opr != 'new':
             opr = 'const'
-            _factor =
+            _factor = _factor
             _from = ()
 
         low, high, step = Decimal(low), Decimal(high), Decimal(step)
@@ -72,15 +72,21 @@ class FixedVariable:
         self.latency = _latency
         self.cost = _cost
 
+        # Update latency for constant variables to match the current variable for piplining
+
+        for v in self._from:
+            if v.opr == 'const':
+                v.latency = self.latency
+
     def get_cost_and_latency(self):
         if self.opr == 'const':
             return 0.0, 0.0
-        if self.opr in ('vadd', 'cadd'):
+        if self.opr in ('vadd', 'cadd', 'min', 'max'):
             adder_size = self.hwconf.adder_size
             carry_size = self.hwconf.carry_size
             latency_cutoff = self.hwconf.latency_cutoff
 
-            if self.opr
+            if self.opr in ('min', 'max', 'vadd'):
                 assert len(self._from) == 2
                 v0, v1 = self._from
                 int0, int1 = v0.qint, v1.qint
@@ -89,8 +95,6 @@ class FixedVariable:
             else:
                 assert len(self._from) == 1
                 assert self._data is not None, 'cadd must have data'
-                # int0 = self._from[0].qint
-                # int1 = QInterval(float(self._data), float(self._data), float(self.step))
                 _f = _const_f(self._data)
                 _cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -_f))) + _f
                 base_latency = self._from[0].latency
@@ -138,6 +142,12 @@ class FixedVariable:
         k = self.low < 0
         return k, i, f
 
+    @classmethod
+    def from_const(cls, const: float | Decimal, hwconf: HWConfig, latency: float, _factor: float | Decimal):
+        f = _const_f(const)
+        step = Decimal(2) ** -f
+        return cls(const, const, step, hwconf=hwconf, opr='const', _factor=_factor, latency=latency)
+
     def __repr__(self) -> str:
         if self._factor == 1:
             return f'FixedVariable({self.low}, {self.high}, {self.step})'
@@ -185,7 +195,9 @@ class FixedVariable:
             hwconf=self.hwconf,
         )
 
-    def _const_add(self, other: float | Decimal):
+    def _const_add(self, other: float | Decimal | None):
+        if other is None:
+            return self
         if not isinstance(other, (int, float, Decimal)):
             other = float(other) # direct numpy to decimal raises error
         other = Decimal(other)
@@ -222,7 +234,7 @@ class FixedVariable:
         other: 'float|Decimal',
     ):
         if other == 0:
-            return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
 
         assert log2(abs(other)) % 1 == 0, 'Only support pow2 multiplication'
 
@@ -267,7 +279,7 @@ class FixedVariable:
             i = ceil(log2(val + step)) if not i else i
             eps = step / 2 if round_mode == 'RND' else 0
             val = (floor(val / step + eps) * step) % (Decimal(2) ** i)
-            return FixedVariable(val, val, step, hwconf=self.hwconf)
+            return FixedVariable(val, val, step, hwconf=self.hwconf, opr='const')
 
         step = max(Decimal(2) ** -f, self.step) if f is not None else self.step
         if step > self.step and round_mode == 'RND':
@@ -281,6 +293,10 @@ class FixedVariable:
             low = Decimal(0)
            high = _high
             _factor = self._factor
+
+        if self.low == low and self.high == high and self.step == step:
+            return self
+
         return FixedVariable(
             low,
             high,
@@ -301,7 +317,7 @@ class FixedVariable:
         round_mode: str = 'TRN',
     ):
         overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
-        assert overflow_mode in ('WRAP', 'SAT')
+        assert overflow_mode in ('WRAP', 'SAT', 'SAT_SM')
         assert round_mode in ('TRN', 'RND')
 
         _k, _i, _f = self.kif
@@ -312,13 +328,20 @@ class FixedVariable:
         if f < _f and round_mode == 'RND':
             return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
 
+        if overflow_mode in ('SAT', 'SAT_SM'):
+            step = Decimal(2) ** -f
+            _high = Decimal(2) ** i
+            high = _high - step
+            low = -_high * k if overflow_mode == 'SAT' else -high * k
+            return self.max_of(low).min_of(high).quantize(k, i, f, 'WRAP', round_mode)
+
         if self.low == self.high:
             val = self.low
             step = Decimal(2) ** -f
             _high = Decimal(2) ** i
             high, low = _high - step, -_high * k
             val = (floor(val / step) * step - low) % (2 * _high) + low
-            return FixedVariable(val, val, step, hwconf=self.hwconf)
+            return FixedVariable(val, val, step, hwconf=self.hwconf, opr='const')
 
         # TODO: corner cases exists (e.g., overflow to negative, or negative overflow to high value)
         # bit-exactness will be lost in these cases, but they should never happen (quantizers are used in a weird way)
@@ -327,17 +350,20 @@ class FixedVariable:
         k = min(k, _k) if i >= _i else k
         i = min(i, _i)
 
-
+        if i + k + f <= 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
+
+        step = Decimal(2) ** -f
 
         low = -k * Decimal(2) ** i
+
         high = Decimal(2) ** i - step
         _low, _high = self.low, self.high
 
         if _low >= low and _high <= high:
             low, high = _low, _high
-
-
-            return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+            low = floor(low / step) * step
+            high = ceil(high / step) * step
 
         return FixedVariable(
             low,
@@ -345,7 +371,7 @@ class FixedVariable:
             step,
             _from=(self,),
             _factor=abs(self._factor),
-            opr='wrap'
+            opr='wrap',
             latency=self.latency,
             hwconf=self.hwconf,
         )
@@ -356,3 +382,136 @@ class FixedVariable:
         _high = Decimal(2) ** i
         low, high = k * _high, _high - step
         return cls(low, high, step, **kwargs)
+
+    def msb_mux(self, a: 'FixedVariable', b: 'FixedVariable', qint: tuple[Decimal, Decimal, Decimal] | None = None):
+        assert isinstance(a, FixedVariable) and isinstance(b, FixedVariable), 'msb_mux requires two FixedVariables'
+        if self._factor < 0:
+            return (-self).msb_mux(b, a, qint)
+
+        if a._factor < 0:
+            qint = (-qint[1], -qint[0], qint[2]) if qint else None
+            return -(self.msb_mux(-a, -b, qint=qint))
+
+        _factor = a._factor
+
+        if qint is None:
+            qint = (min(a.low, b.low), max(a.high, b.high), min(a.step, b.step))
+
+        dlat, dcost = cost_add(a.qint, b.qint, 0, False, self.hwconf.adder_size, self.hwconf.carry_size)
+        return FixedVariable(
+            *qint,
+            _from=(self, a, b),
+            _factor=_factor,
+            opr='msb_mux',
+            latency=max(a.latency, b.latency, self.latency) + dlat,
+            hwconf=self.hwconf,
+            cost=dcost,
+        )
+
+    def max_of(self, other):
+        if other == 0:
+            return self.relu()
+        if other == -float('inf'):
+            return self
+        if other == float('inf'):
+            raise ValueError('Cannot apply max_of with inf')
+        if not isinstance(other, FixedVariable):
+            other = FixedVariable.from_const(other, hwconf=self.hwconf, latency=self.latency, _factor=abs(self._factor))
+
+        if self.low >= other.high:
+            return self
+        if self.high <= other.low:
+            return other
+
+        qint = (max(self.low, other.low), max(self.high, other.high), min(self.step, other.step))
+        return (self - other).msb_mux(other, self, qint=qint)
+
+    def min_of(self, other):
+        if other == 0:
+            return (-self).relu()
+        if other == float('inf'):
+            return self
+        if other == -float('inf'):
+            raise ValueError('Cannot apply min_of with -inf')
+        if not isinstance(other, FixedVariable):
+            other = FixedVariable.from_const(other, hwconf=self.hwconf, latency=self.latency, _factor=(self._factor))
+
+        if self.high <= other.low:
+            return self
+        if self.low >= other.high:
+            return other
+
+        qint = (min(self.low, other.low), min(self.high, other.high), min(self.step, other.step))
+        return (self - other).msb_mux(self, other, qint=qint)
+
+
+class FixedVariableInput(FixedVariable):
+    def __init__(
+        self,
+        latency: float | None = None,
+        hwconf=HWConfig(-1, -1, -1),
+    ) -> None:
+        self.low = Decimal(1e10)
+        self.high = Decimal(-1e10)
+        self.step = Decimal(1e10)
+        self._factor = Decimal(1)
+        self._from: tuple[FixedVariable, ...] = ()
+        self.opr = 'new'
+        self._data = None
+        self.id = uuid4()
+        self.hwconf = hwconf
+
+        self.latency = latency if latency is not None else 0.0
+        self.cost = 0.0
+
+    def __add__(self, other):
+        raise ValueError('Cannot operate on unquantized input variable')
+
+    def __sub__(self, other):
+        raise ValueError('Cannot operate on unquantized input variable')
+
+    def __neg__(self):
+        raise ValueError('Cannot negate unquantized input variable')
+
+    def relu(self, *args, **kwargs):
+        raise ValueError('Cannot apply relu on unquantized input variable')
+
+    def max_of(self, other):
+        raise ValueError('Cannot apply max_of on unquantized input variable')
+
+    def min_of(self, other):
+        raise ValueError('Cannot apply min_of on unquantized input variable')
+
+    def quantize(
+        self,
+        k: int | bool,
+        i: int,
+        f: int,
+        overflow_mode: str = 'WRAP',
+        round_mode: str = 'TRN',
+    ):
+        assert overflow_mode == 'WRAP'
+
+        if k + i + f <= 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
+
+        if round_mode == 'RND':
+            return (self.quantize(k, i, f + 1) + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
+
+        step = Decimal(2) ** -f
+        _high = Decimal(2) ** i
+        low, high = -_high * k, _high - step
+        self.high = max(self.high, high)
+        self.low = min(self.low, low)
+        self.step = min(self.step, step)
+
+        return FixedVariable(
+            low,
+            high,
+            step,
+            _from=(self,),
+            _factor=self._factor,
+            opr='wrap',
+            latency=self.latency,
+            hwconf=self.hwconf,
+        )