da4ml 0.3.0.post1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -7,14 +7,21 @@ import hgq
 import keras
 import numpy as np
 from hgq.layers import (
+    QAdd,
     QBatchNormalization,
     QBatchNormDense,
     QConv1D,
     QConv2D,
     QConv3D,
     QDense,
+    QDot,
+    QEinsum,
     QEinsumDense,
     QEinsumDenseBatchnorm,
+    QMaximum,
+    QMeanPow2,
+    QMinimum,
+    QSubtract,
     QSum,
 )
 from hgq.layers.core.base import MultipleQuantizers, Quantizer
@@ -23,10 +30,19 @@ from keras.layers import ReLU
 from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling
 from keras.src.layers.pooling.base_pooling import BasePooling
 from keras.src.ops.numpy import (
+    Abs,
+    Absolute,
     Add,
     Concatenate,
     Divide,
+    Dot,
+    Einsum,
     GetItem,
+    Matmul,
+    Max,
+    Maximum,
+    Min,
+    Minimum,
     Moveaxis,
     Multiply,
     Ravel,
@@ -49,13 +65,13 @@ def mirror_quantizer(q: Quantizer, v: FixedVariableArray) -> FixedVariableArray:
     return quantize(v, k, i, f, overflow_mode=overflow_mode, round_mode=round_mode)


-_registry: dict[type, 'type[MirrorOperationBase]'] = {}
+_registry: dict[type, 'type[ReplayOperationBase]'] = {}


-class MirrorOperationMeta(type):
+class ReplayOperationMeta(type):
     def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]):
         cls = super().__new__(mcs, name, bases, namespace)
-        if name == 'MirrorOperationBase':
+        if name == 'ReplayOperationBase':
             return cls

         handles: type | tuple[type, ...] = namespace['handles']
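
For context, the Replay classes below self-register through this metaclass: each concrete subclass lists the op types it handles, and the registry maps those types to the wrapper class for type-based dispatch. A minimal sketch of the pattern follows; the registration loop is an assumption, since the lines past `namespace['handles']` fall outside this hunk:

    import typing

    _registry: dict[type, type] = {}

    class ReplayOperationMeta(type):
        def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]):
            cls = super().__new__(mcs, name, bases, namespace)
            if name == 'ReplayOperationBase':
                return cls
            handles = namespace['handles']
            handles = handles if isinstance(handles, tuple) else (handles,)
            for h in handles:  # assumed: one registry entry per handled op type
                _registry[h] = cls
            return cls

    class ReplayOperationBase(metaclass=ReplayOperationMeta):
        handles: tuple[type, ...] = ()

        def __init__(self, op):
            self.op = op

    class ReplayFloat(ReplayOperationBase):
        handles = (float,)  # stand-in for a keras op type

    assert _registry[float] is ReplayFloat  # look up the replay wrapper by op type
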
@@ -67,7 +83,7 @@ class MirrorOperationMeta(type):
         return cls


-class MirrorOperationBase(metaclass=MirrorOperationMeta):
+class ReplayOperationBase(metaclass=ReplayOperationMeta):
     handles: tuple[type, ...] = ()

     def __init__(self, layer: 'keras.Operation'):
@@ -124,7 +140,7 @@ class MirrorOperationBase(metaclass=MirrorOperationMeta):
         return outputs


-class MirrorQuantizer(MirrorOperationBase):
+class ReplayQuantizer(ReplayOperationBase):
     handles = (Quantizer,)

     def __init__(self, op: 'Quantizer'):
@@ -135,8 +151,8 @@ class MirrorQuantizer(MirrorOperationBase):
         return mirror_quantizer(self.op, inputs)


-class MirrorQDense(MirrorOperationBase):
-    handles = (QDense, QEinsumDense, QEinsumDenseBatchnorm, QBatchNormDense, QBatchNormalization, keras.layers.EinsumDense)
+class ReplayQDense(ReplayOperationBase):
+    handles = (QDense, QEinsumDense, QEinsumDenseBatchnorm, QBatchNormDense, keras.layers.EinsumDense)

     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
         op = self.op
@@ -152,14 +168,6 @@ class MirrorQDense(MirrorOperationBase):
             qkernel = op.kernel
             qbias = op.bias
             eq = op.equation
-        elif isinstance(op, QBatchNormalization):
-            qkernel, qbias = op.qscaler_and_qoffset
-            dim = inputs._vars.ndim
-            axis = op.axis
-            assert axis != 0, 'Cannot normalizing on batch axis'
-            axis -= 1
-            idx = ''.join(chr(ord('a') + i) for i in range(dim))
-            eq = f'...{idx},{idx[axis]}->...{idx}'
         else:
             raise TypeError(f'Unsupported layer type: {type(op)}')

@@ -168,7 +176,28 @@ class MirrorQDense(MirrorOperationBase):
         return (einsum(eq, inputs[None], qkernel) + qbias)[0]


-class MirrorQConv(MirrorOperationBase):
+class ReplayQDot(ReplayOperationBase):
+    handles = (QDot, keras.layers.Dot)
+
+    def call(self, inputs: tuple[FixedVariableArray, FixedVariableArray]) -> FixedVariableArray:
+        layer: QDot | keras.layers.Dot = self.op
+        assert not layer.normalize, 'normalize is not supported in mirror operation'
+
+        axes = layer.axes
+        return np.dot(inputs[0][None], inputs[1][None], axes=axes)[0]  # type: ignore
+
+
+class ReplayQBatchNormalization(ReplayOperationBase):
+    handles = (QBatchNormalization,)
+
+    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
+        layer: QBatchNormalization = self.op
+        scale, bias = map(np.array, layer.qscaler_and_qoffset)
+        shape = layer._shape
+        return inputs * scale.reshape(shape) + bias.reshape(shape)
+
+
+class ReplayQConv(ReplayOperationBase):
     handles = (QConv1D, QConv2D, QConv3D)

     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
@@ -190,14 +219,14 @@ class MirrorQConv(MirrorOperationBase):
         return outputs


-class MirrorReLU(MirrorOperationBase):
+class ReplayReLU(ReplayOperationBase):
     handles = (ReLU,)

     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
         return relu(inputs)


-class MirrorReshape(MirrorOperationBase):
+class ReplayReshape(ReplayOperationBase):
     handles = (keras.layers.Reshape, keras.layers.Flatten, Reshape, Ravel)

     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
@@ -211,8 +240,8 @@ class MirrorReshape(MirrorOperationBase):
             raise TypeError(f'Unsupported layer type: {type(self.op)}')


-class MirrorMerge(MirrorOperationBase):
-    handles = (keras.layers.Add, keras.layers.Concatenate, hgq.layers.QAdd)
+class ReplayMerge(ReplayOperationBase):
+    handles = (keras.layers.Add, keras.layers.Concatenate, QAdd)

     def call(self, inputs: tuple[FixedVariableArray, FixedVariableArray]) -> FixedVariableArray:
         op: keras.Operation = self.op
@@ -226,7 +255,7 @@ class MirrorMerge(MirrorOperationBase):
             raise TypeError(f'Unsupported layer type: {type(op)}')


-class MirrorPool(MirrorOperationBase):
+class ReplayPool(ReplayOperationBase):
     handles = (
         hgq.layers.QAvgPool1D,
         hgq.layers.QAvgPool2D,
@@ -295,7 +324,7 @@ class MirrorPool(MirrorOperationBase):
         return out  # type: ignore


-class MirrorRepeatVector(MirrorOperationBase):
+class ReplayRepeatVector(ReplayOperationBase):
     handles = (keras.layers.RepeatVector,)

     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
@@ -306,7 +335,7 @@ class MirrorRepeatVector(MirrorOperationBase):
         return np.repeat(inputs[None], layer.n, axis=0)[0]  # type: ignore


-class MirrorGetItem(MirrorOperationBase):
+class ReplayGetItem(ReplayOperationBase):
     handles = (GetItem,)

     def call(self, x: FixedVariableArray, key):
@@ -315,15 +344,21 @@ class MirrorGetItem(MirrorOperationBase):
         return x[None][key][0]


-class MirrorSum(MirrorOperationBase):
-    handles = (Sum,)
+class ReplayReduction(ReplayOperationBase):
+    handles = (Sum, Max, Min)

     def call(self, x: FixedVariableArray, axis=None, keepdims=False):
-        return np.sum(x[None], axis=axis, keepdims=keepdims)[0]  # type: ignore
+        if isinstance(self.op, Sum):
+            op = np.sum
+        elif isinstance(self.op, Max):
+            op = np.amax
+        elif isinstance(self.op, Min):
+            op = np.amin
+        return op(x[None], axis=axis, keepdims=keepdims)[0]  # type: ignore


-class MirrorQSum(MirrorOperationBase):
-    handles = (QSum,)
+class ReplayQReduction(ReplayOperationBase):
+    handles = (QSum, QMeanPow2)

     def call(self, x: FixedVariableArray):
         layer: QSum = self.op
@@ -331,11 +366,14 @@ class MirrorQSum(MirrorOperationBase):
         return np.sum(x[None], axis=axes, keepdims=keepdims)[0] * scale  # type: ignore


-class MirrorArithmetic(MirrorOperationBase):
-    handles = (Add, Subtract, Multiply, TrueDivide, Divide)
+class ReplayArithmetic(ReplayOperationBase):
+    handles = (Add, Subtract, Multiply, TrueDivide, Divide, QSubtract, QMaximum, QMinimum, Maximum, Minimum)

     def call(self, x1: FixedVariableArray, x2: FixedVariableArray):
-        match self.op.__class__.__name__:
+        name = self.op.__class__.__name__
+        if name.startswith('Q'):
+            name = name[1:]
+        match name:
             case 'Add':
                 return x1 + x2
             case 'Subtract':
@@ -344,11 +382,15 @@ class MirrorArithmetic(MirrorOperationBase):
                 return x1 * x2
             case 'TrueDivide' | 'Divide':
                 return x1 / x2
+            case 'Maximum':
+                return np.maximum(x1, x2)  # type: ignore
+            case 'Minimum':
+                return np.minimum(x1, x2)  # type: ignore
             case _:
                 raise TypeError(f'Unsupported arithmetic operation: {type(self.op)}')


-class MirrorConcatenate(MirrorOperationBase):
+class ReplayConcatenate(ReplayOperationBase):
     handles = (Concatenate,)

     def call(self, xs: Sequence[FixedVariableArray]):
@@ -358,7 +400,7 @@ class MirrorConcatenate(MirrorOperationBase):
         return np.concatenate([x[None] for x in xs], axis=axis)[0]  # type: ignore


-class MirrorRepeat(MirrorOperationBase):
+class ReplayRepeat(ReplayOperationBase):
     handles = (Repeat,)

     def call(self, x: FixedVariableArray):
@@ -367,7 +409,7 @@ class MirrorRepeat(MirrorOperationBase):
         return np.repeat(x[None], repeats, axis=axis)[0]  # type: ignore


-class MirrorTranspose(MirrorOperationBase):
+class ReplayTranspose(ReplayOperationBase):
     handles = (Transpose,)

     def call(self, x: FixedVariableArray):
@@ -375,9 +417,57 @@ class MirrorTranspose(MirrorOperationBase):
         return np.transpose(x, axes)  # type: ignore


-class MirrorMoveaxis(MirrorOperationBase):
+class ReplayMoveaxis(ReplayOperationBase):
     handles = (Moveaxis,)

     def call(self, x: FixedVariableArray):
         source, destination = self.op.source, self.op.destination
         return np.moveaxis(x[None], source, destination)[0]  # type: ignore
+
+
+noop_layers = []
+for k, v in keras.layers.__dict__.items():
+    name = k.lower()
+    if 'dropout' in name or 'random' in name or 'noise' in name:
+        noop_layers.append(v)
+
+
+class ReplayNoOp(ReplayOperationBase):
+    handles = tuple(noop_layers)
+
+    def call(self, x: FixedVariableArray, training=False) -> FixedVariableArray:
+        assert not training, 'Training mode is not supported in mirror operation'
+        return x
+
+
+class ReplayQEinsum(ReplayOperationBase):
+    handles = (QEinsum,)
+
+    def call(self, inputs: tuple[FixedVariableArray, ...]) -> FixedVariableArray:
+        layer: QEinsum = self.op
+        eq = layer.equation
+        return einsum(eq, *inputs)
+
+
+class ReplayEinsum(ReplayOperationBase):
+    handles = (Einsum,)
+
+    def call(self, *operands: FixedVariableArray) -> FixedVariableArray:
+        layer: Einsum = self.op
+        eq = layer.subscripts
+        operands = [operand[None] for operand in operands]  # type: ignore
+        return einsum(eq, *operands)[0]
+
+
+class ReplayMatmul(ReplayOperationBase):
+    handles = (Matmul, Dot)
+
+    def call(self, x1: FixedVariableArray, x2: FixedVariableArray) -> FixedVariableArray:
+        return x1 @ x2
+
+
+class ReplayAbs(ReplayOperationBase):
+    handles = (Absolute, Abs)
+
+    def call(self, x: FixedVariableArray) -> FixedVariableArray:
+        return np.abs(x)  # type: ignore
@@ -1,11 +1,15 @@
+import random
+from collections.abc import Generator
 from decimal import Decimal
 from math import ceil, floor, log2
 from typing import NamedTuple
-from uuid import UUID, uuid4
+from uuid import UUID

 from ..cmvm.core import cost_add
 from ..cmvm.types import QInterval

+rd = random.SystemRandom()
+

 class HWConfig(NamedTuple):
     adder_size: int
@@ -26,6 +30,23 @@ def _const_f(const: float | Decimal):
     return _high


+def to_csd_powers(x: float) -> Generator[float, None, None]:
+    if x == 0:
+        return
+    f = _const_f(abs(x))
+    x = x * 2**f
+    s = 2**-f
+    N = ceil(log2(abs(x) * 1.5 + 1e-19))
+    for n in range(N - 1, -1, -1):
+        _2pn = 2**n
+        thres = _2pn / 1.5
+        bit = int(x > thres) - int(x < -thres)
+        v = _2pn * bit
+        x -= v
+        if v != 0:
+            yield v * s
+
+
 class FixedVariable:
     def __init__(
         self,
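
The new `to_csd_powers` expands a constant into signed power-of-two terms (canonical signed-digit form) so that a constant multiplication can be built from shifts and adds. A standalone sketch of the same logic, with a simplified stand-in for `_const_f`, whose body is not shown in this hunk:

    from math import ceil, log2

    def frac_bits(x: float) -> int:
        # simplified stand-in for da4ml's _const_f: smallest f such that
        # x * 2**f is an integer; assumes x has a finite binary expansion
        f = 0
        while x * 2**f != int(x * 2**f):
            f += 1
        return f

    def csd_powers(x: float):
        # same logic as the to_csd_powers hunk above, on plain floats
        if x == 0:
            return
        f = frac_bits(abs(x))
        x, s = x * 2**f, 2.0**-f
        N = ceil(log2(abs(x) * 1.5 + 1e-19))
        for n in range(N - 1, -1, -1):
            thres = 2**n / 1.5
            bit = int(x > thres) - int(x < -thres)
            v = 2**n * bit
            x -= v
            if v:
                yield v * s

    print(list(csd_powers(5.5)))  # [8.0, -2.0, -0.5]  ->  5.5 == 8 - 2 - 0.5
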
@@ -43,13 +64,14 @@ class FixedVariable:
     ) -> None:
         assert low <= high, f'low {low} must be less than high {high}'

-        if low == high and opr != 'new':
+        if low != high and opr == 'const':
+            raise ValueError('Constant variable must have low == high')
+
+        if low == high:
             opr = 'const'
-            _factor = _factor
             _from = ()

         low, high, step = Decimal(low), Decimal(high), Decimal(step)
-        low, high = floor(low / step) * step, ceil(high / step) * step
         self.low = low
         self.high = high
         self.step = step
@@ -58,7 +80,7 @@ class FixedVariable:
         opr = opr
         self.opr = opr
         self._data = _data
-        self.id = _id or uuid4()
+        self.id = _id or UUID(int=rd.getrandbits(128), version=4)
         self.hwconf = hwconf

         if opr == 'cadd':
@@ -81,7 +103,7 @@ class FixedVariable:
     def get_cost_and_latency(self):
         if self.opr == 'const':
             return 0.0, 0.0
-        if self.opr in ('vadd', 'cadd', 'min', 'max'):
+        if self.opr in ('vadd', 'cadd', 'min', 'max', 'vmul'):
            adder_size = self.hwconf.adder_size
            carry_size = self.hwconf.carry_size
            latency_cutoff = self.hwconf.latency_cutoff
@@ -92,13 +114,25 @@ class FixedVariable:
                 int0, int1 = v0.qint, v1.qint
                 base_latency = max(v0.latency, v1.latency)
                 dlat, _cost = cost_add(int0, int1, 0, False, adder_size, carry_size)
-            else:
+            elif self.opr == 'cadd':
                 assert len(self._from) == 1
                 assert self._data is not None, 'cadd must have data'
                 _f = _const_f(self._data)
                 _cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -_f))) + _f
                 base_latency = self._from[0].latency
                 dlat = 0.0
+            elif self.opr == 'vmul':
+                assert len(self._from) == 2
+                v0, v1 = self._from
+                b0, b1 = sum(v0.kif), sum(v1.kif)
+                int0, int1 = v0.qint, v1.qint
+                dlat0, _cost0 = cost_add(int0, int0, 0, False, adder_size, carry_size)
+                dlat1, _cost1 = cost_add(int1, int1, 0, False, adder_size, carry_size)
+                dlat = max(dlat0 * b1, dlat1 * b0)
+                _cost = min(_cost0 * b1, _cost1 * b0)
+                base_latency = max(v0.latency, v1.latency)
+            else:
+                raise NotImplementedError(f'Operation {self.opr} is unknown')

             _latency = dlat + base_latency
             if latency_cutoff > 0 and ceil(_latency / latency_cutoff) > ceil(base_latency / latency_cutoff):
@@ -107,6 +141,7 @@ class FixedVariable:
                     dlat <= latency_cutoff
                 ), f'Latency of an atomic operation {dlat} is larger than the pipelining latency cutoff {latency_cutoff}'
                 _latency = ceil(base_latency / latency_cutoff) * latency_cutoff + dlat
+
         elif self.opr in ('relu', 'wrap'):
             assert len(self._from) == 1
             _latency = self._from[0].latency
@@ -154,6 +189,7 @@ class FixedVariable:
         return f'({self._factor}) FixedVariable({self.low}, {self.high}, {self.step})'

     def __neg__(self):
+        opr = self.opr if self.low != self.high else 'const'
         return FixedVariable(
             -self.high,
             -self.low,
@@ -162,7 +198,7 @@ class FixedVariable:
             _factor=-self._factor,
             latency=self.latency,
             cost=self.cost,
-            opr=self.opr,
+            opr=opr,
             _id=self.id,
             _data=self._data,
             hwconf=self.hwconf,
@@ -195,7 +231,7 @@ class FixedVariable:
             hwconf=self.hwconf,
         )

-    def _const_add(self, other: float | Decimal | None):
+    def _const_add(self, other: float | Decimal | None) -> 'FixedVariable':
         if other is None:
             return self
         if not isinstance(other, (int, float, Decimal)):
@@ -229,29 +265,66 @@ class FixedVariable:
     def __sub__(self, other: 'FixedVariable|int|float|Decimal'):
         return self + (-other)

-    def __mul__(
-        self,
-        other: 'float|Decimal',
-    ):
+    def __mul__(self, other: 'FixedVariable|int|float|Decimal') -> 'FixedVariable':
         if other == 0:
             return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')

-        assert log2(abs(other)) % 1 == 0, 'Only support pow2 multiplication'
+        if isinstance(other, FixedVariable):
+            return self._var_mul(other)
+
+        if log2(abs(other)) % 1 == 0:
+            return self._pow2_mul(other)
+
+        variables = [self._pow2_mul(v) for v in to_csd_powers(float(other))]
+        while len(variables) > 1:
+            v = variables.pop() + variables.pop()
+            variables.append(v)
+        return variables[0]
+
+    def _var_mul(self, other: 'FixedVariable') -> 'FixedVariable':
+        if other is not self:
+            a, b, c, d = self.high * other.low, self.low * other.high, self.high * other.high, self.low * other.low
+            low = min(a, b, c, d)
+            high = max(a, b, c, d)
+        else:
+            a, b = self.low * other.low, self.high * other.high
+            if self.low < 0 and self.high > 0:
+                low = min(a, b, 0)
+                high = max(a, b, 0)
+            else:
+                low = min(a, b)
+                high = max(a, b)

+        step = self.step * other.step
+        _factor = self._factor * other._factor
+        opr = 'vmul'
+        return FixedVariable(
+            low,
+            high,
+            step,
+            _from=(self, other),
+            _factor=_factor,
+            opr=opr,
+        )
+
+    def _pow2_mul(
+        self,
+        other: float | Decimal,
+    ):
         other = Decimal(other)

         low = min(self.low * other, self.high * other)
         high = max(self.low * other, self.high * other)
         step = abs(self.step * other)
         _factor = self._factor * other
-
+        opr = self.opr
         return FixedVariable(
             low,
             high,
             step,
             _from=self._from,
             _factor=_factor,
-            opr=self.opr,
+            opr=opr,
             latency=self.latency,
             cost=self.cost,
             _id=self.id,
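
With this change, `__mul__` covers three cases: a variable-variable product (`_var_mul`), an exact power-of-two constant (`_pow2_mul`), and any other constant, which is decomposed into power-of-two terms via `to_csd_powers` and summed back together. A small worked illustration of the interval arithmetic in `_var_mul`, using plain floats rather than FixedVariable objects:

    # general case: the bounds of x*y are the extremes of the four endpoint products
    x_lo, x_hi = -3.0, 2.0
    y_lo, y_hi = 1.0, 4.0
    prods = (x_hi * y_lo, x_lo * y_hi, x_hi * y_hi, x_lo * y_lo)
    assert (min(prods), max(prods)) == (-12.0, 8.0)

    # self-product (a square): when the range straddles zero, the low bound clamps to 0
    a, b = x_lo * x_lo, x_hi * x_hi
    low = min(a, b, 0) if x_lo < 0 < x_hi else min(a, b)
    high = max(a, b, 0) if x_lo < 0 < x_hi else max(a, b)
    assert (low, high) == (0.0, 9.0)

    # non-power-of-two constants: 5.5 == 8 - 2 - 0.5, so x * 5.5 becomes three
    # shifted copies of x combined by additions (see to_csd_powers above)
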
@@ -268,6 +341,21 @@ class FixedVariable:
     def __rmul__(self, other: 'float|Decimal|int|FixedVariable'):
         return self * other

+    def __pow__(self, other):
+        _power = int(other)
+        assert _power == other, 'Power must be an integer'
+        assert _power >= 0, 'Power must be non-negative'
+        if _power == 0:
+            return FixedVariable(1, 1, 1, hwconf=self.hwconf, opr='const')
+        if _power == 1:
+            return self
+
+        pow0 = _power // 2
+        ret = (self**pow0) * (self ** (_power - pow0))
+        if other % 2 == 0:
+            ret.low = max(ret.low, 0)
+        return ret
+
     def relu(self, i: int | None = None, f: int | None = None, round_mode: str = 'TRN'):
         round_mode = round_mode.upper()
         assert round_mode in ('TRN', 'RND')
@@ -317,18 +405,21 @@ class FixedVariable:
         round_mode: str = 'TRN',
     ):
         overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
-        assert overflow_mode in ('WRAP', 'SAT', 'SAT_SM')
+        assert overflow_mode in ('WRAP', 'SAT', 'SAT_SYM')
         assert round_mode in ('TRN', 'RND')

+        if k + i + f <= 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
         _k, _i, _f = self.kif

         if k >= _k and i >= _i and f >= _f:
-            return self
+            if overflow_mode != 'SAT_SYM' or i > _i:
+                return self

         if f < _f and round_mode == 'RND':
             return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')

-        if overflow_mode in ('SAT', 'SAT_SM'):
+        if overflow_mode in ('SAT', 'SAT_SYM'):
            step = Decimal(2) ** -f
            _high = Decimal(2) ** i
            high = _high - step
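
This hunk renames the symmetric-saturation mode from SAT_SM to SAT_SYM and adds two early exits: a zero-width type (k + i + f <= 0) collapses to the constant 0, and a request that only widens the type returns self unless symmetric saturation could still clip the most negative value. A rough numeric illustration of the saturation bounds, assuming the usual ap_fixed-style convention in which SAT_SYM mirrors the positive limit (the lower bounds are not shown in this hunk):

    from decimal import Decimal

    i, f = 3, 2                      # integer bits (excluding sign) and fractional bits
    step = Decimal(2) ** -f          # Decimal('0.25')
    high = Decimal(2) ** i - step    # Decimal('7.75'), as in the SAT/SAT_SYM branch above
    low_sat = -(Decimal(2) ** i)     # assumed plain-SAT lower bound: Decimal('-8')
    low_sat_sym = -high              # assumed SAT_SYM lower bound: Decimal('-7.75')
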
@@ -458,21 +549,43 @@ class FixedVariableInput(FixedVariable):
         self._from: tuple[FixedVariable, ...] = ()
         self.opr = 'new'
         self._data = None
-        self.id = uuid4()
+        self.id = UUID(int=rd.getrandbits(128), version=4)
         self.hwconf = hwconf

         self.latency = latency if latency is not None else 0.0
         self.cost = 0.0

     def __add__(self, other):
+        if other == 0:
+            return self
         raise ValueError('Cannot operate on unquantized input variable')

     def __sub__(self, other):
+        if other == 0:
+            return self
         raise ValueError('Cannot operate on unquantized input variable')

     def __neg__(self):
         raise ValueError('Cannot negate unquantized input variable')

+    def __mul__(self, other):
+        if other == 1:
+            return self
+        raise ValueError('Cannot multiply unquantized input variable')
+
+    def __rmul__(self, other):
+        if other == 1:
+            return self
+        raise ValueError('Cannot multiply unquantized input variable')
+
+    def __radd__(self, other):
+        if other == 0:
+            return self
+        raise ValueError('Cannot add unquantized input variable')
+
+    def __rsub__(self, other):
+        raise ValueError('Cannot subtract unquantized input variable')
+
     def relu(self, *args, **kwargs):
         raise ValueError('Cannot apply relu on unquantized input variable')