da4ml 0.4.1__py3-none-any.whl → 0.5.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of da4ml might be problematic.

Files changed (40)
  1. da4ml/__init__.py +2 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +2 -2
  4. da4ml/cmvm/api.py +15 -4
  5. da4ml/cmvm/core/__init__.py +2 -2
  6. da4ml/cmvm/types.py +32 -18
  7. da4ml/cmvm/util/bit_decompose.py +2 -2
  8. da4ml/codegen/hls/hls_codegen.py +10 -5
  9. da4ml/codegen/hls/hls_model.py +7 -4
  10. da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
  11. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  12. da4ml/codegen/rtl/common_source/{build_prj.tcl → build_vivado_prj.tcl} +39 -18
  13. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  14. da4ml/codegen/rtl/common_source/template.xdc +11 -13
  15. da4ml/codegen/rtl/rtl_model.py +105 -54
  16. da4ml/codegen/rtl/verilog/__init__.py +2 -1
  17. da4ml/codegen/rtl/verilog/comb.py +47 -7
  18. da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
  19. da4ml/codegen/rtl/verilog/pipeline.py +12 -12
  20. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  21. da4ml/codegen/rtl/vhdl/comb.py +27 -21
  22. da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
  23. da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
  24. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  25. da4ml/converter/__init__.py +57 -1
  26. da4ml/converter/hgq2/parser.py +4 -25
  27. da4ml/converter/hgq2/replica.py +208 -22
  28. da4ml/trace/fixed_variable.py +239 -29
  29. da4ml/trace/fixed_variable_array.py +276 -48
  30. da4ml/trace/ops/__init__.py +31 -15
  31. da4ml/trace/ops/reduce_utils.py +3 -3
  32. da4ml/trace/pipeline.py +40 -18
  33. da4ml/trace/tracer.py +33 -8
  34. da4ml/typing/__init__.py +3 -0
  35. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/METADATA +2 -1
  36. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/RECORD +39 -35
  37. da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
  38. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/WHEEL +0 -0
  39. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/licenses/LICENSE +0 -0
  40. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/top_level.txt +0 -0
da4ml/converter/hgq2/replica.py
@@ -18,11 +18,16 @@ from hgq.layers import (
     QEinsum,
     QEinsumDense,
     QEinsumDenseBatchnorm,
+    QLinformerAttention,
     QMaximum,
     QMeanPow2,
     QMinimum,
+    QMultiHeadAttention,
+    QMultiply,
+    QSoftmax,
     QSubtract,
     QSum,
+    QUnaryFunctionLUT,
 )
 from hgq.layers.core.base import MultipleQuantizers, Quantizer
 from hgq.quantizer.internal import FixedPointQuantizerBase
@@ -68,7 +73,9 @@ def mirror_quantizer(q: Quantizer, v: FixedVariableArray) -> FixedVariableArray:
 _registry: dict[type, 'type[ReplayOperationBase]'] = {}


-class ReplayOperationMeta(type):
+class HandlerRegMeta(type):
+    """Metaclass for automatic registration of handler classes."""
+
     def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]):
         cls = super().__new__(mcs, name, bases, namespace)
         if name == 'ReplayOperationBase':
@@ -83,8 +90,11 @@ class ReplayOperationMeta(type):
         return cls


-class ReplayOperationBase(metaclass=ReplayOperationMeta):
+class ReplayOperationBase(metaclass=HandlerRegMeta):
     handles: tuple[type, ...] = ()
+    __activation_handled__ = False
+    __input_quantizer_handled__ = False
+    __output_quantizer_handled__ = False

     def __init__(self, layer: 'keras.Operation'):
         assert isinstance(layer, self.handles)
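For reference, the renamed HandlerRegMeta keeps the same auto-registration idea: defining a concrete subclass of ReplayOperationBase records it in _registry keyed by the layer types listed in handles, and the new __*_handled__ flags let a handler opt out of the generic input-quantizer, activation and output-quantizer wrapping in __call__. A minimal sketch of that registry pattern (simplified names, not the package's exact code):

    # Minimal sketch of a metaclass-based handler registry (illustrative only).
    _registry: dict[type, type] = {}


    class HandlerRegMeta(type):
        def __new__(mcs, name, bases, namespace):
            cls = super().__new__(mcs, name, bases, namespace)
            if bases:  # skip the abstract base class itself
                for handled in namespace.get('handles', ()):
                    _registry[handled] = cls  # layer type -> handler class
            return cls


    class HandlerBase(metaclass=HandlerRegMeta):
        handles: tuple[type, ...] = ()


    class FakeLayer:  # stand-in for a keras/hgq layer class
        pass


    class FakeLayerHandler(HandlerBase):
        handles = (FakeLayer,)


    assert _registry[FakeLayer] is FakeLayerHandler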
@@ -94,7 +104,6 @@ class ReplayOperationBase(metaclass=ReplayOperationMeta):

     def __call__(self, *args, **kwargs) -> tuple[FixedVariableArray, ...]:
         assert all(not isinstance(a, FixedVariableArray) for a in kwargs.values())
-        inputs = args[0] if len(args) == 1 else args

         if not isinstance(self.op, hgq.layers.QLayerBase):
             r = self.call(*args, **kwargs)
@@ -104,28 +113,35 @@ class ReplayOperationBase(metaclass=ReplayOperationMeta):
         assert kwargs.pop('training', False) is False, 'Training mode is not supported in mirror operation'
         assert kwargs.pop('mask', None) is None, 'Masking is not supported in mirror operation'

-        if layer.enable_iq:
-            if isinstance(inputs, Sequence):
-                assert isinstance(layer.iq, MultipleQuantizers)
-                inputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.iq.quantizers, inputs))
-            else:
-                assert isinstance(layer.iq, Quantizer), f'Expected iq to be a Quantizer, got {type(layer.iq)}'
-                inputs = mirror_quantizer(layer.iq, inputs)
+        if not self.__input_quantizer_handled__:
+            assert len(args) == 1
+            inputs = args[0]

-        outputs = self.call(inputs, **kwargs)
+            if layer.enable_iq:
+                if isinstance(inputs, Sequence):
+                    assert isinstance(layer.iq, MultipleQuantizers)
+                    inputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.iq.quantizers, inputs))
+                else:
+                    assert isinstance(layer.iq, Quantizer), f'Expected iq to be a Quantizer, got {type(layer.iq)}'
+                    inputs = mirror_quantizer(layer.iq, inputs)

-        activation = getattr(layer, 'activation', keras.activations.linear)
-        if activation is not keras.activations.linear:
-            if activation is keras.activations.relu:
-                if isinstance(outputs, tuple):
-                    assert len(outputs) == 1, 'ReLU activation is expected to have a single output'
-                    outputs = (relu(outputs[0]),)
+            outputs = self.call(inputs, **kwargs)
+        else:
+            outputs = self.call(*args, **kwargs)
+
+        if not self.__activation_handled__:
+            activation = getattr(layer, 'activation', keras.activations.linear)
+            if activation is not keras.activations.linear:
+                if activation is keras.activations.relu:
+                    if isinstance(outputs, tuple):
+                        assert len(outputs) == 1, 'ReLU activation is expected to have a single output'
+                        outputs = (relu(outputs[0]),)
+                    else:
+                        outputs = relu(outputs)
                 else:
-                    outputs = relu(outputs)
-            else:
-                raise NotImplementedError(f'Activation {activation} is not supported in mirror operation')
+                    raise NotImplementedError(f'Activation {activation} is not supported in mirror operation')

-        if layer.enable_oq:
+        if layer.enable_oq and not self.__output_quantizer_handled__:
             if isinstance(outputs, tuple):
                 assert isinstance(layer.oq, MultipleQuantizers)
                 outputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.oq.quantizers, outputs))
@@ -366,7 +382,7 @@ class ReplayQReduction(ReplayOperationBase):


 class ReplayArithmetic(ReplayOperationBase):
-    handles = (Add, Subtract, Multiply, TrueDivide, Divide, QSubtract, QMaximum, QMinimum, Maximum, Minimum)
+    handles = (Add, Subtract, Multiply, QMultiply, TrueDivide, Divide, QSubtract, QMaximum, QMinimum, Maximum, Minimum)

     def call(self, x1: FixedVariableArray, x2: FixedVariableArray):
         name = self.op.__class__.__name__
@@ -470,3 +486,173 @@ class ReplayAbs(ReplayOperationBase):

     def call(self, x: FixedVariableArray) -> FixedVariableArray:
         return np.abs(x)  # type: ignore
+
+
+class ReplayQFunctionLUT(ReplayOperationBase):
+    __activation_handled__ = True
+    handles = (QUnaryFunctionLUT,)
+
+    def call(self, x: FixedVariableArray) -> FixedVariableArray:
+        op: QUnaryFunctionLUT = self.op
+
+        def activation(x) -> np.ndarray:
+            kx = keras.ops.convert_to_tensor(x[None])
+            kx = op.activation(kx)
+            return keras.ops.convert_to_numpy(kx[0])  # type: ignore
+
+        return x.apply(activation)
+
+
+class ReplayQSoftmax(ReplayOperationBase):
+    handles = (QSoftmax,)
+
+    def call(self, inputs: FixedVariableArray, mask: None | FixedVariableArray = None) -> FixedVariableArray:
+        op: QSoftmax = self.op
+        inputs = inputs[None]
+
+        if op.stable:
+            inputs = np.amax(inputs, axis=op.axes, keepdims=True) - inputs  # type: ignore
+
+        exp_inp = ReplayQFunctionLUT(op.exp_table)(inputs[0])[0]
+
+        if mask is not None:
+            exp_inp = mask[0] * exp_inp
+
+        sums = np.sum(exp_inp[None], axis=op.axes, keepdims=True)[0]  # type: ignore
+        divisor = ReplayQFunctionLUT(op.inv_table)(sums)[0]
+
+        return exp_inp * divisor
+
+
+def _compute_attention_mask(
+    query,
+    value,
+    query_mask=None,
+    value_mask=None,
+    key_mask=None,
+    attention_mask=None,
+    use_causal_mask=False,
+):
+    masks = []
+    if query_mask is not None:
+        masks.append(np.expand_dims(query_mask, -1))  # [Q, 1]
+    if value_mask is not None:
+        masks.append(np.expand_dims(value_mask, -2))  # [1, V]
+    if key_mask is not None:
+        masks.append(np.expand_dims(key_mask, -2))  # [1, V]
+    if use_causal_mask:
+        q = query.shape[0]
+        v = q if value is None else value.shape[0]
+        masks.append(np.tril(np.ones((q, v), dtype='uint8')))  # [Q, V]
+    masks.append(attention_mask)
+    if not masks:
+        return None
+
+    if any(isinstance(m, FixedVariableArray) for m in masks):
+        return np.prod(np.stack(masks, axis=0), axis=0)
+    else:
+        return None
+
+
+def _masked_softmax(op, attention_scores, attention_mask=None):
+    # Normalize the attention scores to probabilities.
+    # attention_scores = [B, N, T, S]
+    if attention_mask is not None:
+        # The expand dim happens starting from the `num_heads` dimension,
+        # (<batch_dims>, num_heads, <query_attention_dims,
+        # key_attention_dims>)
+        mask_expansion_axis = -len(op._attention_axes) * 2 - 1
+        for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
+            attention_mask = np.expand_dims(attention_mask, axis=mask_expansion_axis)
+    return ReplayQSoftmax(op._softmax)(attention_scores[0], mask=attention_mask)[0][None]
+
+
+def _compute_attention(op: QMultiHeadAttention, query, key, value, attention_mask=None, training=None):
+    # Take the dot product between "query" and "key" to get the raw
+    # attention scores.
+    attention_scores = einsum(op._dot_product_equation, key, query)
+
+    attention_scores = _masked_softmax(op, attention_scores, attention_mask)
+
+    # `context_layer` = [B, T, N, H]
+    attention_output = einsum(op._combine_equation, attention_scores, value)
+    return attention_output, attention_scores
+
+
+class ReplayMHA(ReplayOperationBase):
+    handles = (QMultiHeadAttention,)
+    __input_quantizer_handled__ = True
+    __output_quantizer_handled__ = True
+
+    def call(
+        self,
+        query: FixedVariableArray,
+        value: FixedVariableArray,
+        key=None,
+        query_mask=None,
+        value_mask=None,
+        key_mask=None,
+        attention_mask=None,
+        return_attention_scores=False,
+        use_causal_mask=False,
+    ):
+        op: QMultiHeadAttention = self.op
+
+        if key is None:
+            key = value
+
+        _attention_mask = _compute_attention_mask(
+            query,
+            value,
+            query_mask=query_mask,
+            value_mask=value_mask,
+            key_mask=key_mask,
+            attention_mask=attention_mask,
+            use_causal_mask=use_causal_mask,
+        )
+
+        query = ReplayQDense(op._query_dense)(query)[0][None]
+        key = ReplayQDense(op._key_dense)(key)[0][None]
+        value = ReplayQDense(op._value_dense)(value)[0][None]
+
+        attention_output, attention_scores = _compute_attention(op, query, key, value, _attention_mask)
+        attention_output = ReplayQDense(op._output_dense)(attention_output[0])[0]
+
+        if op.enable_oq:
+            attention_output = mirror_quantizer(op.oq, attention_output)
+
+        if return_attention_scores:
+            return attention_output, attention_scores[0]
+        return attention_output
+
+
+class ReplayQLinformerAttention(ReplayMHA):
+    handles = (QLinformerAttention,)
+
+    def call(
+        self,
+        query,
+        value,
+        key=None,
+        query_mask=None,
+        value_mask=None,
+        key_mask=None,
+        attention_mask=None,
+        return_attention_scores=False,
+        use_causal_mask=False,
+    ):
+        assert use_causal_mask is False, 'Causal mask is not supported in QLinformerAttention.'
+        key = key if key is not None else value
+        op: QLinformerAttention = self.op
+        key = ReplayQDense(op._lin_k_proj)(key)[0]
+        value = ReplayQDense(op._lin_v_proj)(value)[0]
+        return super().call(
+            query,
+            value,
+            key,
+            query_mask=query_mask,
+            value_mask=value_mask,
+            key_mask=key_mask,
+            attention_mask=attention_mask,
+            return_attention_scores=return_attention_scores,
+        )
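For reference, ReplayQSoftmax mirrors the table-based softmax of QSoftmax: exp() is replayed through the layer's exp_table and the reciprocal of the sum through its inv_table, so the whole operation reduces to two unary lookups and a multiply. A rough NumPy sketch of that decomposition (the lambdas stand in for the lookup tables; exact table contents and quantization are an assumption here):

    import numpy as np

    def lut_softmax(x, exp_lut, inv_lut, stable=True):
        # Illustrative only: softmax from two unary "table" evaluations, as in ReplayQSoftmax.
        if stable:
            x = np.amax(x, keepdims=True) - x          # mirrors the `op.stable` branch (max - x)
        e = exp_lut(x)                                 # element-wise exp via lookup
        return e * inv_lut(np.sum(e, keepdims=True))   # multiply by tabulated 1/sum

    # With the stable transform, the exp table effectively stores exp(-t).
    y = lut_softmax(np.array([0.5, 1.0, -0.25]),
                    exp_lut=lambda t: np.exp(-t),
                    inv_lut=lambda s: 1.0 / s)
    assert np.isclose(y.sum(), 1.0)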
da4ml/trace/fixed_variable.py
@@ -1,14 +1,24 @@
 import random
-from collections.abc import Generator
+import typing
+from collections.abc import Callable, Generator
+from dataclasses import dataclass
 from decimal import Decimal
+from hashlib import sha256
 from math import ceil, floor, log2
-from typing import NamedTuple
+from typing import NamedTuple, overload
 from uuid import UUID

+import numpy as np
+from numpy.typing import NDArray
+
 from ..cmvm.core import cost_add
-from ..cmvm.types import QInterval
+from ..cmvm.types import QInterval, _minimal_kif
+from ..cmvm.util.bit_decompose import _shift_centering
+
+rd = random.Random()

-rd = random.SystemRandom()
+if typing.TYPE_CHECKING:
+    pass


 class HWConfig(NamedTuple):
@@ -17,7 +27,154 @@ class HWConfig(NamedTuple):
     latency_cutoff: float


+ufunc_t = Callable[[NDArray[np.floating]], NDArray[np.floating]]
+
+
+class TraceContext:
+    _tables: 'dict[str, tuple[LookupTable, int]]' = {}
+    hwconf: HWConfig = HWConfig(1, -1, -1)
+    _table_counter = 0
+
+    def register_table(self, table: 'LookupTable|np.ndarray'):
+        if isinstance(table, np.ndarray):
+            table = LookupTable(table)
+        if table.spec.hash in self._tables:
+            return self._tables[table.spec.hash]
+        self._tables[table.spec.hash] = (table, self._table_counter)
+
+        self._table_counter += 1
+        return self._tables[table.spec.hash]
+
+    def index_table(self, hash: str) -> int:
+        return self._tables[hash][1]
+
+    def get_table_from_index(self, index: int) -> 'LookupTable':
+        for table, idx in self._tables.values():
+            if idx == index:
+                return table
+        raise KeyError(f'No table found with index {index}')
+
+
+table_context = TraceContext()
+
+
+@dataclass
+class TableSpec:
+    hash: str
+    out_qint: QInterval
+    inp_width: int
+
+    @property
+    def out_kif(self) -> tuple[bool, int, int]:
+        return _minimal_kif(self.out_qint)
+
+
+def to_spec(table: NDArray[np.floating]) -> tuple[TableSpec, NDArray[np.int32]]:
+    f_out = -_shift_centering(np.array(table))
+    int_table = (table * 2**f_out).astype(np.int32)
+    h = sha256(int_table.data)
+    h.update(f'{f_out}'.encode())
+    inp_width = ceil(log2(table.size))
+    out_qint = QInterval(float(np.min(table)), float(np.max(table)), float(2**-f_out))
+    return TableSpec(hash=h.hexdigest(), inp_width=inp_width, out_qint=out_qint), int_table
+
+
+def interpret_as(
+    x: int | NDArray[np.integer],
+    k: int,
+    i: int,
+    f: int,
+) -> float | NDArray[np.floating]:
+    b = k + i + f
+    bias = 2.0 ** (b - 1) * k
+    eps = 2.0**-f
+    floor_fn = np.floor if isinstance(x, np.ndarray) else floor
+    return eps * (floor_fn(x + bias) % 2.0**b - bias)
+
+
+class LookupTable:
+    def __init__(self, values: NDArray, spec: TableSpec | None = None):
+        assert values.ndim == 1, 'Lookup table values must be 1-dimensional'
+        if spec is not None:
+            assert values.dtype is np.int32
+            self.spec = spec
+            self.table = values
+        else:
+            self.spec, self.table = to_spec(values)
+
+    @overload
+    def lookup(self, var: 'FixedVariable', qint_in: QInterval) -> 'FixedVariable': ...
+
+    @overload
+    def lookup(self, var: np.floating | float, qint_in: QInterval | tuple[float, float, float]) -> float: ...
+
+    def lookup(self, var, qint_in: QInterval | tuple[float, float, float]):
+        if isinstance(var, FixedVariable):
+            return var.lookup(self)
+        else:
+            _min, _max, _step = qint_in
+            assert _min <= var <= _max, f'Value {var} out of range [{_min}, {_max}]'
+            index = round((var - _min) / _step)
+            return interpret_as(int(self.table[index]), *self.spec.out_kif)
+
+    @property
+    def float_table(self) -> NDArray[np.floating]:
+        k, i, f = self.spec.out_kif
+        return interpret_as(self.table, k, i, f)  # type: ignore
+
+    def to_dict(self) -> dict:
+        return {
+            'spec': {
+                'hash': self.spec.hash,
+                'out_qint': {
+                    'min': self.spec.out_qint.min,
+                    'max': self.spec.out_qint.max,
+                    'step': self.spec.out_qint.step,
+                },
+                'inp_width': self.spec.inp_width,
+            },
+            'table': self.table.tolist(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> 'LookupTable':
+        spec_data = data['spec']
+        out_qint_data = spec_data['out_qint']
+        spec = TableSpec(
+            hash=spec_data['hash'],
+            out_qint=QInterval(out_qint_data['min'], out_qint_data['max'], out_qint_data['step']),
+            inp_width=spec_data['inp_width'],
+        )
+        table = np.array(data['table'], dtype=np.int32)
+        return cls(table, spec=spec)
+
+    def _get_pads(self, qint: QInterval) -> tuple[int, int]:
+        k, i, f = _minimal_kif(qint)
+        if k:
+            pad_left = round((qint.min + 2**i) / qint.step)
+        else:
+            pad_left = round(qint.min / qint.step)
+        size = 2 ** (k + i + f)
+        pad_right = size - len(self.table) - pad_left
+        return pad_left, pad_right
+
+    def padded_table(self, qint: QInterval) -> NDArray[np.int32]:
+        pad_left, pad_right = self._get_pads(qint)
+        data = np.pad(self.table, (pad_left, pad_right), mode='constant', constant_values=0)
+        if qint.min < 0:
+            size = len(data)
+            # data = np.concatenate((data[size // 2 :], data[: size // 2]))
+            data = np.roll(data, size // 2)
+        return data
+
+    def get_uuid(self, qint: QInterval) -> UUID:
+        pad_left, _ = self._get_pads(qint)
+        _int = int(self.spec.hash[:32], 16) ^ pad_left
+        return UUID(int=_int, version=4)
+
+
 def _const_f(const: float | Decimal):
+    """Get the minimum f such that const * 2^f is an integer."""
     const = float(const)
     _low, _high = -32, 32
     while _high - _low > 1:
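For reference, interpret_as above re-reads a raw integer code as a fixed-point value with a sign bit k, i integer bits and f fractional bits, wrapping modulo 2**(k+i+f) like two's complement. Two hand-checked values:

    # k=1 sign bit, i=2, f=1  ->  b=4 bits, bias=8, eps=0.5
    assert interpret_as(5, 1, 2, 1) == 2.5     # 0b0101 -> 5,  5 * 0.5
    assert interpret_as(13, 1, 2, 1) == -1.5   # 0b1101 -> -3 (two's complement), -3 * 0.5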
@@ -31,6 +188,7 @@ def _const_f(const: float | Decimal):


 def to_csd_powers(x: float) -> Generator[float, None, None]:
+    """Convert a float to a list of +/- powers of two in CSD representation."""
     if x == 0:
         return
     f = _const_f(abs(x))
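For reference, canonical signed digit (CSD) form expresses a constant as a short sum of signed powers of two; __mul__ further below turns each term into a shift and combines them with adders. Two hand-worked decompositions (the term order the generator yields may differ):

    #   7    = 8 - 1         -> two terms instead of 4 + 2 + 1
    #   2.75 = 2 + 1 - 0.25  (binary 10.11 rewritten as +1 +1 . 0 -1)
    assert sum(to_csd_powers(7.0)) == 7.0
    assert sum(to_csd_powers(2.75)) == 2.75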
@@ -48,6 +206,8 @@ def to_csd_powers(x: float) -> Generator[float, None, None]:


 class FixedVariable:
+    __normal__variable__ = True
+
     def __init__(
         self,
         low: float | Decimal,
@@ -62,7 +222,8 @@ class FixedVariable:
         _data: Decimal | None = None,
         _id: UUID | None = None,
     ) -> None:
-        assert low <= high, f'low {low} must be less than high {high}'
+        if self.__normal__variable__:
+            assert low <= high, f'low {low} must be less than high {high}'

         if low != high and opr == 'const':
             raise ValueError('Constant variable must have low == high')
@@ -100,9 +261,19 @@ class FixedVariable:
             if v.opr == 'const':
                 v.latency = self.latency

-    def get_cost_and_latency(self):
+    def get_cost_and_latency(self) -> tuple[float, float]:
         if self.opr == 'const':
             return 0.0, 0.0
+
+        if self.opr == 'lookup':
+            assert len(self._from) == 1
+            b_in = sum(self._from[0].kif)
+            b_out = sum(self.kif)
+            _latency = max(b_in - 6, 1) + self._from[0].latency
+            _cost = 2 ** max(b_in - 5, 0) * ceil(b_out / 2)
+            # Assume LUT6 with extra o5 output
+            return _cost, _latency
+
         if self.opr in ('vadd', 'cadd', 'min', 'max', 'vmul'):
             adder_size = self.hwconf.adder_size
             carry_size = self.hwconf.carry_size
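For reference, the new 'lookup' branch above prices a table against 6-input LUTs: roughly one LUT6 (using its extra O5 output) per two output bits, duplicated 2**(b_in - 5) times once the input no longer fits a single LUT, with the latency estimate growing with input width. Plugging in an 8-bit input and 10-bit output:

    from math import ceil

    # Worked example of the lookup cost/latency model added above.
    b_in, b_out, input_latency = 8, 10, 0.0
    cost = 2 ** max(b_in - 5, 0) * ceil(b_out / 2)   # 8 * 5 = 40 LUT6s
    latency = max(b_in - 6, 1) + input_latency       # 2 + 0.0
    assert (cost, latency) == (40, 2.0)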
@@ -212,7 +383,7 @@ class FixedVariable:
         if self.high == self.low:
             return other._const_add(self.low)

-        assert self.hwconf == other.hwconf, 'FixedVariable must have the same hwconf'
+        assert self.hwconf == other.hwconf, f'FixedVariable must have the same hwconf, got {self.hwconf} and {other.hwconf}'

         f0, f1 = self._factor, other._factor
         if f0 < 0:
@@ -270,20 +441,32 @@ class FixedVariable:
         return self * (1 / other)

     def __mul__(self, other: 'FixedVariable|int|float|Decimal') -> 'FixedVariable':
+        if isinstance(other, FixedVariable):
+            if self.high == self.low:
+                return other * self.low
+            if other.high > other.low:
+                return self._var_mul(other)
+            assert other.high == other.low
+            other = float(other.low)
+
         if other == 0:
             return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')

-        if isinstance(other, FixedVariable):
-            return self._var_mul(other)
-
         if log2(abs(other)) % 1 == 0:
             return self._pow2_mul(other)

-        variables = [self._pow2_mul(v) for v in to_csd_powers(float(other))]
+        variables = [(self._pow2_mul(v), Decimal(v)) for v in to_csd_powers(float(other))]
         while len(variables) > 1:
-            v = variables.pop() + variables.pop()
-            variables.append(v)
-        return variables[0]
+            v1, p1 = variables.pop()
+            v2, p2 = variables.pop()
+            v, p = v1 + v2, p1 + p2
+            if p > 0:
+                high, low = self.high * p, self.low * p
+            else:
+                high, low = self.low * p, self.high * p
+            v.high, v.low = high, low
+            variables.append((v, p))
+        return variables[0][0]

     def _var_mul(self, other: 'FixedVariable') -> 'FixedVariable':
         if other is not self:
@@ -307,6 +490,7 @@ class FixedVariable:
             high,
             step,
             _from=(self, other),
+            hwconf=self.hwconf,
             _factor=_factor,
             opr=opr,
         )
@@ -407,7 +591,7 @@ class FixedVariable:
         f: int,
         overflow_mode: str = 'WRAP',
         round_mode: str = 'TRN',
-    ):
+    ) -> 'FixedVariable':
         overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
         assert overflow_mode in ('WRAP', 'SAT', 'SAT_SYM')
         assert round_mode in ('TRN', 'RND')
@@ -428,7 +612,9 @@ class FixedVariable:
             _high = Decimal(2) ** i
             high = _high - step
             low = -_high * k if overflow_mode == 'SAT' else -high * k
-            return self.max_of(low).min_of(high).quantize(k, i, f, 'WRAP', round_mode)
+            ff = f + 1 if round_mode == 'RND' else f
+            v = self.quantize(_k, _i, ff, 'WRAP', 'TRN')
+            return v.max_of(low).min_of(high).quantize(k, i, f, 'WRAP', round_mode)

         if self.low == self.high:
             val = self.low
@@ -539,25 +725,47 @@ class FixedVariable:
         qint = (min(self.low, other.low), min(self.high, other.high), min(self.step, other.step))
         return (self - other).msb_mux(self, other, qint=qint)

+    def lookup(self, table: LookupTable | np.ndarray) -> 'FixedVariable':
+        _table, table_id = table_context.register_table(table)
+        size = len(table.table) if isinstance(table, LookupTable) else len(table)
+        assert (
+            round((self.high - self.low) / self.step) + 1 == size
+        ), f'Input variable size does not match lookup table size ({round((self.high - self.low) / self.step) + 1} != {size})'
+
+        return FixedVariable(
+            _table.spec.out_qint.min,
+            _table.spec.out_qint.max,
+            _table.spec.out_qint.step,
+            _from=(self,),
+            _factor=Decimal(1),
+            opr='lookup',
+            hwconf=self.hwconf,
+            _data=Decimal(table_id),
+        )
+

 class FixedVariableInput(FixedVariable):
+    __normal__variable__ = False
+
     def __init__(
         self,
         latency: float | None = None,
-        hwconf=HWConfig(-1, -1, -1),
+        hwconf: HWConfig | tuple[int, int, int] = HWConfig(-1, -1, -1),
+        opr: str = 'new',
     ) -> None:
-        self.low = Decimal(1e10)
-        self.high = Decimal(-1e10)
-        self.step = Decimal(1e10)
-        self._factor = Decimal(1)
-        self._from: tuple[FixedVariable, ...] = ()
-        self.opr = 'new'
-        self._data = None
-        self.id = UUID(int=rd.getrandbits(128), version=4)
-        self.hwconf = hwconf
-
-        self.latency = latency if latency is not None else 0.0
-        self.cost = 0.0
+        super().__init__(
+            low=Decimal(1e10),
+            high=Decimal(-1e10),
+            step=Decimal(1e10),
+            latency=latency if latency is not None else 0.0,
+            hwconf=HWConfig(*hwconf),
+            opr=opr,
+            cost=0.0,
+            _factor=Decimal(1),
+            _from=(),
+            _data=None,
+            _id=None,
+        )

     def __add__(self, other):
         if other == 0:
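For reference, FixedVariable.lookup registers the table once in the module-level table_context and stores only its index in _data; the table itself can also be evaluated directly through the float overload of LookupTable.lookup. A small sketch (the 8-entry square table and its input range are made up for illustration, and the exact output quantization is an assumption):

    import numpy as np

    # Hypothetical 8-entry table of f(x) = x**2 over the input grid {0, 0.125, ..., 0.875}.
    xs = np.arange(8) * 0.125
    table = LookupTable(xs**2)                    # derives a TableSpec (hash, output QInterval, input width)

    # Float overload: (min, max, step) of the input grid is used to pick the table index.
    y = table.lookup(0.25, (0.0, 0.875, 0.125))
    assert abs(y - 0.0625) <= table.spec.out_qint.step   # 0.25**2, up to output quantization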
@@ -614,6 +822,8 @@ class FixedVariableInput(FixedVariable):

         if round_mode == 'RND':
             return (self.quantize(k, i, f + 1) + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
+        else:
+            round_mode = 'TRN'

         step = Decimal(2) ** -f
         _high = Decimal(2) ** i