da4ml 0.2.1__py3-none-any.whl → 0.3.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (55)
  1. da4ml/_version.py +2 -2
  2. da4ml/cmvm/types.py +95 -15
  3. da4ml/codegen/__init__.py +5 -4
  4. da4ml/codegen/cpp/__init__.py +2 -1
  5. da4ml/codegen/cpp/cpp_codegen.py +56 -23
  6. da4ml/codegen/cpp/hls_model.py +252 -0
  7. da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
  8. da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
  9. da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
  10. da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
  11. da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
  12. da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
  13. da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
  14. da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
  15. da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
  16. da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
  17. da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
  18. da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
  19. da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
  20. da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
  21. da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
  22. da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
  23. da4ml/codegen/cpp/source/binder_util.hh +56 -0
  24. da4ml/codegen/cpp/source/build_binder.mk +24 -0
  25. da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
  26. da4ml/codegen/verilog/__init__.py +2 -3
  27. da4ml/codegen/verilog/comb.py +65 -24
  28. da4ml/codegen/verilog/io_wrapper.py +36 -141
  29. da4ml/codegen/verilog/source/binder_util.hh +72 -0
  30. da4ml/codegen/verilog/source/mux.v +58 -0
  31. da4ml/codegen/verilog/source/negative.v +28 -0
  32. da4ml/codegen/verilog/source/shift_adder.v +4 -1
  33. da4ml/codegen/verilog/source/template.xdc +3 -0
  34. da4ml/codegen/verilog/verilog_model.py +36 -12
  35. da4ml/converter/__init__.py +0 -0
  36. da4ml/converter/hgq2/parser.py +105 -0
  37. da4ml/converter/hgq2/replica.py +383 -0
  38. da4ml/trace/__init__.py +2 -2
  39. da4ml/trace/fixed_variable.py +175 -16
  40. da4ml/trace/fixed_variable_array.py +109 -4
  41. da4ml/trace/ops/__init__.py +22 -6
  42. da4ml/trace/ops/conv_utils.py +147 -15
  43. da4ml/trace/ops/einsum_utils.py +9 -6
  44. da4ml/trace/ops/reduce_utils.py +103 -0
  45. da4ml/trace/pipeline.py +36 -34
  46. da4ml/trace/tracer.py +37 -7
  47. da4ml-0.3.0.post1.dist-info/METADATA +107 -0
  48. da4ml-0.3.0.post1.dist-info/RECORD +64 -0
  49. da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
  50. da4ml-0.2.1.dist-info/METADATA +0 -65
  51. da4ml-0.2.1.dist-info/RECORD +0 -39
  52. /da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
  53. {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/WHEEL +0 -0
  54. {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/licenses/LICENSE +0 -0
  55. {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/top_level.txt +0 -0
da4ml/converter/hgq2/replica.py ADDED
@@ -0,0 +1,383 @@
+import typing
+from collections.abc import Sequence
+from math import prod
+from typing import Any
+
+import hgq
+import keras
+import numpy as np
+from hgq.layers import (
+    QBatchNormalization,
+    QBatchNormDense,
+    QConv1D,
+    QConv2D,
+    QConv3D,
+    QDense,
+    QEinsumDense,
+    QEinsumDenseBatchnorm,
+    QSum,
+)
+from hgq.layers.core.base import MultipleQuantizers, Quantizer
+from hgq.quantizer.internal import FixedPointQuantizerBase
+from keras.layers import ReLU
+from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling
+from keras.src.layers.pooling.base_pooling import BasePooling
+from keras.src.ops.numpy import (
+    Add,
+    Concatenate,
+    Divide,
+    GetItem,
+    Moveaxis,
+    Multiply,
+    Ravel,
+    Repeat,
+    Reshape,
+    Subtract,
+    Sum,
+    Transpose,
+    TrueDivide,
+)
+
+from ...trace import FixedVariableArray
+from ...trace.ops import conv, einsum, pool, quantize, relu
+
+
+def mirror_quantizer(q: Quantizer, v: FixedVariableArray) -> FixedVariableArray:
+    q_internal: FixedPointQuantizerBase = q.quantizer
+    k, i, f = (np.array(x, dtype=np.int8)[0] for x in q_internal.kif)
+    round_mode, overflow_mode = q_internal.round_mode, q_internal.overflow_mode
+    return quantize(v, k, i, f, overflow_mode=overflow_mode, round_mode=round_mode)
+
+
+_registry: dict[type, 'type[MirrorOperationBase]'] = {}
+
+
+class MirrorOperationMeta(type):
+    def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]):
+        cls = super().__new__(mcs, name, bases, namespace)
+        if name == 'MirrorOperationBase':
+            return cls
+
+        handles: type | tuple[type, ...] = namespace['handles']
+        if not isinstance(handles, tuple):
+            handles = (handles,)
+
+        for handle in handles:
+            _registry[handle] = cls # type: ignore
+        return cls
+
+
+class MirrorOperationBase(metaclass=MirrorOperationMeta):
+    handles: tuple[type, ...] = ()
+
+    def __init__(self, layer: 'keras.Operation'):
+        assert isinstance(layer, self.handles)
+        self.op: Any = layer
+
+    def call(self, *args, **kwargs) -> tuple[FixedVariableArray, ...] | FixedVariableArray: ...
+
+    def __call__(self, *args, **kwargs) -> tuple[FixedVariableArray, ...]:
+        assert all(not isinstance(a, FixedVariableArray) for a in kwargs.values())
+        assert all(isinstance(a, FixedVariableArray) or isinstance(a, Sequence) for a in args)
+        inputs = args[0] if len(args) == 1 else args
+
+        if not isinstance(self.op, hgq.layers.QLayerBase):
+            r = self.call(*args, **kwargs)
+            return r if isinstance(r, tuple) else (r,)
+
+        layer: hgq.layers.QLayerBase = self.op
+        assert kwargs.pop('training', False) is False, 'Training mode is not supported in mirror operation'
+        assert kwargs.pop('mask', None) is None, 'Masking is not supported in mirror operation'
+
+        if layer.enable_iq:
+            if isinstance(inputs, Sequence):
+                assert isinstance(layer.iq, MultipleQuantizers)
+                inputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.iq.quantizers, inputs))
+            else:
+                assert isinstance(layer.iq, Quantizer), f'Expected iq to be a Quantizer, got {type(layer.iq)}'
+                inputs = mirror_quantizer(layer.iq, inputs)
+
+        outputs = self.call(inputs, **kwargs)
+
+        activation = getattr(layer, 'activation', keras.activations.linear)
+        if activation is not keras.activations.linear:
+            if activation is keras.activations.relu:
+                if isinstance(outputs, tuple):
+                    assert len(outputs) == 1, 'ReLU activation is expected to have a single output'
+                    outputs = (relu(outputs[0]),)
+                else:
+                    outputs = relu(outputs)
+            else:
+                raise NotImplementedError(f'Activation {activation} is not supported in mirror operation')
+
+        if layer.enable_oq:
+            if isinstance(outputs, tuple):
+                assert isinstance(layer.oq, MultipleQuantizers)
+                outputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.oq.quantizers, outputs))
+            else:
+                assert isinstance(layer.oq, Quantizer)
+                outputs = mirror_quantizer(layer.oq, outputs)
+
+        if isinstance(outputs, FixedVariableArray):
+            outputs = (outputs,)
+
+        return outputs
+
+
+class MirrorQuantizer(MirrorOperationBase):
+    handles = (Quantizer,)
+
+    def __init__(self, op: 'Quantizer'):
+        super().__init__(op)
+        assert isinstance(op.quantizer, FixedPointQuantizerBase)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        return mirror_quantizer(self.op, inputs)
+
+
+class MirrorQDense(MirrorOperationBase):
+    handles = (QDense, QEinsumDense, QEinsumDenseBatchnorm, QBatchNormDense, QBatchNormalization, keras.layers.EinsumDense)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        op = self.op
+        if isinstance(op, (QDense, QBatchNormDense)):
+            qkernel = op.qkernel
+            qbias = op.qbias
+            eq = '...c,cC->...C'
+        elif isinstance(op, (QEinsumDense, QEinsumDenseBatchnorm)):
+            qkernel = op.qkernel
+            qbias = op.qbias
+            eq = op.equation
+        elif isinstance(op, keras.layers.EinsumDense):
+            qkernel = op.kernel
+            qbias = op.bias
+            eq = op.equation
+        elif isinstance(op, QBatchNormalization):
+            qkernel, qbias = op.qscaler_and_qoffset
+            dim = inputs._vars.ndim
+            axis = op.axis
+            assert axis != 0, 'Cannot normalizing on batch axis'
+            axis -= 1
+            idx = ''.join(chr(ord('a') + i) for i in range(dim))
+            eq = f'...{idx},{idx[axis]}->...{idx}'
+        else:
+            raise TypeError(f'Unsupported layer type: {type(op)}')
+
+        qkernel = np.array(qkernel)
+        qbias = np.array(qbias) if qbias is not None else None
+        return (einsum(eq, inputs[None], qkernel) + qbias)[0]
+
+
+class MirrorQConv(MirrorOperationBase):
+    handles = (QConv1D, QConv2D, QConv3D)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        layer: QConv1D | QConv2D | QConv3D = self.op
+        qkernel = np.array(layer.qkernel)
+        qbias = np.array(layer.qbias) if layer.qbias is not None else None
+        strides = layer.strides
+        padding = layer.padding
+        dilation_rate = layer.dilation_rate
+        groups = layer.groups
+
+        assert dilation_rate == 1 or all(d == 1 for d in dilation_rate), 'Dilation rate is not supported in mirror operation'
+        if layer.data_format == 'channels_first':
+            shape = (0,) + tuple(range(2, len(inputs.shape))) + (1,)
+            inputs = inputs.transpose(shape)
+
+        outputs = conv(inputs, qkernel, qbias, strides=strides, padding=padding, format=layer.data_format, groups=groups)
+
+        return outputs
+
+
+class MirrorReLU(MirrorOperationBase):
+    handles = (ReLU,)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        return relu(inputs)
+
+
+class MirrorReshape(MirrorOperationBase):
+    handles = (keras.layers.Reshape, keras.layers.Flatten, Reshape, Ravel)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        if isinstance(self.op, (keras.layers.Flatten, Ravel)):
+            return inputs.ravel()
+        elif isinstance(self.op, keras.layers.Reshape):
+            return inputs.reshape(self.op.target_shape)
+        elif isinstance(self.op, Reshape):
+            return inputs.reshape(self.op.newshape[1:])
+        else:
+            raise TypeError(f'Unsupported layer type: {type(self.op)}')
+
+
+class MirrorMerge(MirrorOperationBase):
+    handles = (keras.layers.Add, keras.layers.Concatenate, hgq.layers.QAdd)
+
+    def call(self, inputs: tuple[FixedVariableArray, FixedVariableArray]) -> FixedVariableArray:
+        op: keras.Operation = self.op
+        if isinstance(op, (keras.layers.Add, hgq.layers.QAdd)):
+            return inputs[0] + inputs[1]
+        elif isinstance(op, keras.layers.Concatenate):
+            axis = op.axis
+            data = np.concatenate([v._vars for v in inputs], axis=axis)
+            return FixedVariableArray(data, inputs[0].solver_options)
+        else:
+            raise TypeError(f'Unsupported layer type: {type(op)}')
+
+
+class MirrorPool(MirrorOperationBase):
+    handles = (
+        hgq.layers.QAvgPool1D,
+        hgq.layers.QAvgPool2D,
+        hgq.layers.QAvgPool3D,
+        hgq.layers.QMaxPool1D,
+        hgq.layers.QMaxPool2D,
+        hgq.layers.QMaxPool3D,
+        hgq.layers.QGlobalAveragePooling1D,
+        hgq.layers.QGlobalMaxPooling1D,
+        hgq.layers.QGlobalAveragePooling2D,
+        hgq.layers.QGlobalMaxPooling2D,
+        hgq.layers.QGlobalAveragePooling3D,
+        hgq.layers.QGlobalMaxPooling3D,
+        keras.layers.AveragePooling1D,
+        keras.layers.AveragePooling2D,
+        keras.layers.AveragePooling3D,
+        keras.layers.MaxPooling1D,
+        keras.layers.MaxPooling2D,
+        keras.layers.MaxPooling3D,
+        keras.layers.GlobalAveragePooling1D,
+        keras.layers.GlobalMaxPooling1D,
+        keras.layers.GlobalAveragePooling2D,
+        keras.layers.GlobalMaxPooling2D,
+        keras.layers.GlobalAveragePooling3D,
+        keras.layers.GlobalMaxPooling3D,
+    )
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        cname = self.op.__class__.__name__
+        if 'Max' in cname:
+            op = 'max'
+        else:
+            assert 'Average' in cname, f'Unsupported global pooling layer: {cname}'
+            op = 'avg'
+
+        data_format = self.op.data_format
+        if data_format == 'channels_first':
+            inputs = np.moveaxis(inputs, 1, -1) # type: ignore
+
+        if isinstance(self.op, BaseGlobalPooling):
+            pool_dim = self.op.input_spec.ndim - 2 # type: ignore
+            axis = tuple(range(pool_dim))
+            keepdims = self.op.keepdims
+
+            if op == 'max':
+                out = np.amax(inputs, axis=axis, keepdims=keepdims) # type: ignore
+            elif op == 'avg':
+                pool_size = prod(inputs.shape[:-1])
+                out = np.sum(inputs, axis=axis, keepdims=keepdims) / pool_size # type: ignore
+        else:
+            assert isinstance(self.op, BasePooling), f'Unsupported pooling layer: {type(self.op)}'
+            pool_size = self.op.pool_size
+            strides = self.op.strides
+            padding = self.op.padding
+            pool_dim = len(pool_size)
+            out = pool(
+                inputs,
+                pool_size=pool_size,
+                strides=strides,
+                padding=padding,
+                pool_type=op,
+            )
+        if data_format == 'channels_first':
+            out = np.moveaxis(out, -1, 1) # type: ignore
+
+        return out # type: ignore
+
+
+class MirrorRepeatVector(MirrorOperationBase):
+    handles = (keras.layers.RepeatVector,)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        layer: keras.layers.RepeatVector = self.op
+        if layer.n == 1:
+            return inputs
+        # return FixedVariableArray(np.repeat(inputs._vars, layer.n, axis=0), inputs.solver_options)
+        return np.repeat(inputs[None], layer.n, axis=0)[0] # type: ignore
+
+
+class MirrorGetItem(MirrorOperationBase):
+    handles = (GetItem,)
+
+    def call(self, x: FixedVariableArray, key):
+        if isinstance(key, list):
+            key = tuple(key)
+        return x[None][key][0]
+
+
+class MirrorSum(MirrorOperationBase):
+    handles = (Sum,)
+
+    def call(self, x: FixedVariableArray, axis=None, keepdims=False):
+        return np.sum(x[None], axis=axis, keepdims=keepdims)[0] # type: ignore
+
+
+class MirrorQSum(MirrorOperationBase):
+    handles = (QSum,)
+
+    def call(self, x: FixedVariableArray):
+        layer: QSum = self.op
+        axes, scale, keepdims = layer.axes, layer.scale, layer.keepdims
+        return np.sum(x[None], axis=axes, keepdims=keepdims)[0] * scale # type: ignore
+
+
+class MirrorArithmetic(MirrorOperationBase):
+    handles = (Add, Subtract, Multiply, TrueDivide, Divide)
+
+    def call(self, x1: FixedVariableArray, x2: FixedVariableArray):
+        match self.op.__class__.__name__:
+            case 'Add':
+                return x1 + x2
+            case 'Subtract':
+                return x1 - x2
+            case 'Multiply':
+                return x1 * x2
+            case 'TrueDivide' | 'Divide':
+                return x1 / x2
+            case _:
+                raise TypeError(f'Unsupported arithmetic operation: {type(self.op)}')
+
+
+class MirrorConcatenate(MirrorOperationBase):
+    handles = (Concatenate,)
+
+    def call(self, xs: Sequence[FixedVariableArray]):
+        axis = self.op.axis
+        # return backend.numpy.concatenate(xs, axis=self.axis)
+        # return FixedVariableArray(np.concatenate([x._vars[None] for x in xs], axis=axis)[0], xs[0].solver_options)
+        return np.concatenate([x[None] for x in xs], axis=axis)[0] # type: ignore


+class MirrorRepeat(MirrorOperationBase):
+    handles = (Repeat,)
+
+    def call(self, x: FixedVariableArray):
+        repeats, axis = self.op.repeats, self.op.axis
+        # return FixedVariableArray(np.repeat(x._vars[None], repeats, axis=axis)[0], x.solver_options)
+        return np.repeat(x[None], repeats, axis=axis)[0] # type: ignore
+
+
+class MirrorTranspose(MirrorOperationBase):
+    handles = (Transpose,)
+
+    def call(self, x: FixedVariableArray):
+        axes = self.op.axes
+        return np.transpose(x, axes) # type: ignore
+
+
+class MirrorMoveaxis(MirrorOperationBase):
+    handles = (Moveaxis,)
+
+    def call(self, x: FixedVariableArray):
+        source, destination = self.op.source, self.op.destination
+        return np.moveaxis(x[None], source, destination)[0] # type: ignore
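
Note: the companion parser.py added in this release (listed above, its diff not shown in this section) presumably dispatches through the `_registry` populated by `MirrorOperationMeta`. A minimal sketch of such a dispatch, assuming handlers are looked up by exact operation type (`mirror_call` is a hypothetical helper, not part of the package):

    def mirror_call(op, *inputs):
        # Look up the Mirror* handler registered for this Keras/HGQ operation type
        # and replay it on FixedVariableArray inputs; handlers return a tuple.
        handler_cls = _registry.get(type(op))
        if handler_cls is None:
            raise NotImplementedError(f'No mirror handler for {type(op).__name__}')
        return handler_cls(op)(*inputs)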
da4ml/trace/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from .fixed_variable import HWConfig
-from .fixed_variable_array import FixedVariableArray
+from .fixed_variable_array import FixedVariableArray, FixedVariableArrayInput
 from .pipeline import to_pipeline
 from .tracer import comb_trace
 
-__all__ = ['to_pipeline', 'comb_trace', 'FixedVariableArray', 'HWConfig']
+__all__ = ['to_pipeline', 'comb_trace', 'FixedVariableArray', 'HWConfig', 'FixedVariableArrayInput']
da4ml/trace/fixed_variable.py CHANGED
@@ -43,9 +43,9 @@ class FixedVariable:
     ) -> None:
         assert low <= high, f'low {low} must be less than high {high}'
 
-        if low == high:
+        if low == high and opr != 'new':
             opr = 'const'
-            _factor = 1.0
+            _factor = _factor
             _from = ()
 
         low, high, step = Decimal(low), Decimal(high), Decimal(step)
@@ -72,15 +72,21 @@ class FixedVariable:
         self.latency = _latency
         self.cost = _cost
 
+        # Update latency for constant variables to match the current variable for piplining
+
+        for v in self._from:
+            if v.opr == 'const':
+                v.latency = self.latency
+
     def get_cost_and_latency(self):
         if self.opr == 'const':
             return 0.0, 0.0
-        if self.opr in ('vadd', 'cadd'):
+        if self.opr in ('vadd', 'cadd', 'min', 'max'):
            adder_size = self.hwconf.adder_size
            carry_size = self.hwconf.carry_size
            latency_cutoff = self.hwconf.latency_cutoff
 
-            if self.opr == 'vadd':
+            if self.opr in ('min', 'max', 'vadd'):
                assert len(self._from) == 2
                v0, v1 = self._from
                int0, int1 = v0.qint, v1.qint
@@ -89,8 +95,6 @@
            else:
                assert len(self._from) == 1
                assert self._data is not None, 'cadd must have data'
-                # int0 = self._from[0].qint
-                # int1 = QInterval(float(self._data), float(self._data), float(self.step))
                _f = _const_f(self._data)
                _cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -_f))) + _f
                base_latency = self._from[0].latency
@@ -138,6 +142,12 @@
        k = self.low < 0
        return k, i, f
 
+    @classmethod
+    def from_const(cls, const: float | Decimal, hwconf: HWConfig, latency: float, _factor: float | Decimal):
+        f = _const_f(const)
+        step = Decimal(2) ** -f
+        return cls(const, const, step, hwconf=hwconf, opr='const', _factor=_factor, latency=latency)
+
    def __repr__(self) -> str:
        if self._factor == 1:
            return f'FixedVariable({self.low}, {self.high}, {self.step})'
@@ -185,7 +195,9 @@
            hwconf=self.hwconf,
        )
 
-    def _const_add(self, other: float | Decimal):
+    def _const_add(self, other: float | Decimal | None):
+        if other is None:
+            return self
        if not isinstance(other, (int, float, Decimal)):
            other = float(other) # direct numpy to decimal raises error
        other = Decimal(other)
@@ -222,7 +234,7 @@
        other: 'float|Decimal',
    ):
        if other == 0:
-            return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
 
        assert log2(abs(other)) % 1 == 0, 'Only support pow2 multiplication'
 
@@ -267,7 +279,7 @@
            i = ceil(log2(val + step)) if not i else i
            eps = step / 2 if round_mode == 'RND' else 0
            val = (floor(val / step + eps) * step) % (Decimal(2) ** i)
-            return FixedVariable(val, val, step, hwconf=self.hwconf)
+            return FixedVariable(val, val, step, hwconf=self.hwconf, opr='const')
 
        step = max(Decimal(2) ** -f, self.step) if f is not None else self.step
        if step > self.step and round_mode == 'RND':
@@ -281,6 +293,10 @@
        low = Decimal(0)
        high = _high
        _factor = self._factor
+
+        if self.low == low and self.high == high and self.step == step:
+            return self
+
        return FixedVariable(
            low,
            high,
@@ -301,7 +317,7 @@
        round_mode: str = 'TRN',
    ):
        overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
-        assert overflow_mode in ('WRAP', 'SAT')
+        assert overflow_mode in ('WRAP', 'SAT', 'SAT_SM')
        assert round_mode in ('TRN', 'RND')
 
        _k, _i, _f = self.kif
@@ -312,13 +328,20 @@
        if f < _f and round_mode == 'RND':
            return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
 
+        if overflow_mode in ('SAT', 'SAT_SM'):
+            step = Decimal(2) ** -f
+            _high = Decimal(2) ** i
+            high = _high - step
+            low = -_high * k if overflow_mode == 'SAT' else -high * k
+            return self.max_of(low).min_of(high).quantize(k, i, f, 'WRAP', round_mode)
+
        if self.low == self.high:
            val = self.low
            step = Decimal(2) ** -f
            _high = Decimal(2) ** i
            high, low = _high - step, -_high * k
            val = (floor(val / step) * step - low) % (2 * _high) + low
-            return FixedVariable(val, val, step, hwconf=self.hwconf)
+            return FixedVariable(val, val, step, hwconf=self.hwconf, opr='const')
 
        # TODO: corner cases exists (e.g., overflow to negative, or negative overflow to high value)
        # bit-exactness will be lost in these cases, but they should never happen (quantizers are used in a weird way)
@@ -327,17 +350,20 @@
        k = min(k, _k) if i >= _i else k
        i = min(i, _i)
 
-        step = max(Decimal(2) ** -f, self.step)
+        if i + k + f <= 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
+
+        step = Decimal(2) ** -f
 
        low = -k * Decimal(2) ** i
+
        high = Decimal(2) ** i - step
        _low, _high = self.low, self.high
 
        if _low >= low and _high <= high:
            low, high = _low, _high
-
-        if low > high:
-            return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+        low = floor(low / step) * step
+        high = ceil(high / step) * step
 
        return FixedVariable(
            low,
@@ -345,7 +371,7 @@
            step,
            _from=(self,),
            _factor=abs(self._factor),
-            opr='wrap' if overflow_mode == 'WRAP' else 'sat',
+            opr='wrap',
            latency=self.latency,
            hwconf=self.hwconf,
        )
@@ -356,3 +382,136 @@
        _high = Decimal(2) ** i
        low, high = k * _high, _high - step
        return cls(low, high, step, **kwargs)
+
+    def msb_mux(self, a: 'FixedVariable', b: 'FixedVariable', qint: tuple[Decimal, Decimal, Decimal] | None = None):
+        assert isinstance(a, FixedVariable) and isinstance(b, FixedVariable), 'msb_mux requires two FixedVariables'
+        if self._factor < 0:
+            return (-self).msb_mux(b, a, qint)
+
+        if a._factor < 0:
+            qint = (-qint[1], -qint[0], qint[2]) if qint else None
+            return -(self.msb_mux(-a, -b, qint=qint))
+
+        _factor = a._factor
+
+        if qint is None:
+            qint = (min(a.low, b.low), max(a.high, b.high), min(a.step, b.step))
+
+        dlat, dcost = cost_add(a.qint, b.qint, 0, False, self.hwconf.adder_size, self.hwconf.carry_size)
+        return FixedVariable(
+            *qint,
+            _from=(self, a, b),
+            _factor=_factor,
+            opr='msb_mux',
+            latency=max(a.latency, b.latency, self.latency) + dlat,
+            hwconf=self.hwconf,
+            cost=dcost,
+        )
+
+    def max_of(self, other):
+        if other == 0:
+            return self.relu()
+        if other == -float('inf'):
+            return self
+        if other == float('inf'):
+            raise ValueError('Cannot apply max_of with inf')
+        if not isinstance(other, FixedVariable):
+            other = FixedVariable.from_const(other, hwconf=self.hwconf, latency=self.latency, _factor=abs(self._factor))
+
+        if self.low >= other.high:
+            return self
+        if self.high <= other.low:
+            return other
+
+        qint = (max(self.low, other.low), max(self.high, other.high), min(self.step, other.step))
+        return (self - other).msb_mux(other, self, qint=qint)
+
+    def min_of(self, other):
+        if other == 0:
+            return (-self).relu()
+        if other == float('inf'):
+            return self
+        if other == -float('inf'):
+            raise ValueError('Cannot apply min_of with -inf')
+        if not isinstance(other, FixedVariable):
+            other = FixedVariable.from_const(other, hwconf=self.hwconf, latency=self.latency, _factor=(self._factor))
+
+        if self.high <= other.low:
+            return self
+        if self.low >= other.high:
+            return other
+
+        qint = (min(self.low, other.low), min(self.high, other.high), min(self.step, other.step))
+        return (self - other).msb_mux(self, other, qint=qint)
+
+
+class FixedVariableInput(FixedVariable):
+    def __init__(
+        self,
+        latency: float | None = None,
+        hwconf=HWConfig(-1, -1, -1),
+    ) -> None:
+        self.low = Decimal(1e10)
+        self.high = Decimal(-1e10)
+        self.step = Decimal(1e10)
+        self._factor = Decimal(1)
+        self._from: tuple[FixedVariable, ...] = ()
+        self.opr = 'new'
+        self._data = None
+        self.id = uuid4()
+        self.hwconf = hwconf
+
+        self.latency = latency if latency is not None else 0.0
+        self.cost = 0.0
+
+    def __add__(self, other):
+        raise ValueError('Cannot operate on unquantized input variable')
+
+    def __sub__(self, other):
+        raise ValueError('Cannot operate on unquantized input variable')
+
+    def __neg__(self):
+        raise ValueError('Cannot negate unquantized input variable')
+
+    def relu(self, *args, **kwargs):
+        raise ValueError('Cannot apply relu on unquantized input variable')
+
+    def max_of(self, other):
+        raise ValueError('Cannot apply max_of on unquantized input variable')
+
+    def min_of(self, other):
+        raise ValueError('Cannot apply min_of on unquantized input variable')
+
+    def quantize(
+        self,
+        k: int | bool,
+        i: int,
+        f: int,
+        overflow_mode: str = 'WRAP',
+        round_mode: str = 'TRN',
+    ):
+        assert overflow_mode == 'WRAP'
+
+        if k + i + f <= 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
+
+        if round_mode == 'RND':
+            return (self.quantize(k, i, f + 1) + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
+
+        step = Decimal(2) ** -f
+        _high = Decimal(2) ** i
+        low, high = -_high * k, _high - step
+        self.high = max(self.high, high)
+        self.low = min(self.low, low)
+        self.step = min(self.step, step)
+
+        return FixedVariable(
+            low,
+            high,
+            step,
+            _from=(self,),
+            _factor=self._factor,
+            opr='wrap',
+            latency=self.latency,
+            hwconf=self.hwconf,
+        )
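
For reference, the new `SAT`/`SAT_SM` branch in `FixedVariable.quantize` first clamps the value into the representable range via `max_of`/`min_of` and then falls back to `WRAP`. A standalone sketch of just that bound arithmetic (plain Python mirroring the lines above; `sat_bounds` is an illustrative helper, not part of da4ml):

    from decimal import Decimal

    def sat_bounds(k: int, i: int, f: int, overflow_mode: str = 'SAT') -> tuple[Decimal, Decimal]:
        # Representable [low, high] for a fixed-point value with sign bit k,
        # i integer bits (excluding sign) and f fractional bits.
        step = Decimal(2) ** -f      # quantization step, 2^-f
        _high = Decimal(2) ** i
        high = _high - step          # largest representable value
        # SAT keeps the full negative range -2^i; SAT_SM is symmetric about zero
        low = -_high * k if overflow_mode == 'SAT' else -high * k
        return low, high

    # k=1, i=3, f=2 -> SAT: (-8, 7.75); SAT_SM: (-7.75, 7.75)
    print(sat_bounds(1, 3, 2, 'SAT'), sat_bounds(1, 3, 2, 'SAT_SM'))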