da4ml 0.4.0__py3-none-any.whl → 0.5.0b0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release.
This version of da4ml might be problematic.
- da4ml/__init__.py +2 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +2 -2
- da4ml/cmvm/api.py +15 -4
- da4ml/cmvm/core/__init__.py +2 -2
- da4ml/cmvm/types.py +32 -18
- da4ml/cmvm/util/bit_decompose.py +2 -2
- da4ml/codegen/hls/hls_codegen.py +10 -5
- da4ml/codegen/hls/hls_model.py +7 -4
- da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/{build_prj.tcl → build_vivado_prj.tcl} +39 -18
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +11 -13
- da4ml/codegen/rtl/rtl_model.py +105 -53
- da4ml/codegen/rtl/verilog/__init__.py +2 -1
- da4ml/codegen/rtl/verilog/comb.py +47 -7
- da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
- da4ml/codegen/rtl/verilog/pipeline.py +12 -12
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/vhdl/comb.py +27 -21
- da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
- da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/converter/__init__.py +57 -1
- da4ml/converter/hgq2/parser.py +4 -25
- da4ml/converter/hgq2/replica.py +210 -25
- da4ml/trace/fixed_variable.py +239 -29
- da4ml/trace/fixed_variable_array.py +276 -48
- da4ml/trace/ops/__init__.py +31 -15
- da4ml/trace/ops/reduce_utils.py +3 -3
- da4ml/trace/pipeline.py +40 -18
- da4ml/trace/tracer.py +33 -8
- da4ml/typing/__init__.py +3 -0
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/METADATA +2 -1
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/RECORD +39 -35
- da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/WHEEL +0 -0
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/top_level.txt +0 -0
da4ml/converter/hgq2/replica.py
CHANGED
@@ -18,11 +18,16 @@ from hgq.layers import (
     QEinsum,
     QEinsumDense,
     QEinsumDenseBatchnorm,
+    QLinformerAttention,
     QMaximum,
     QMeanPow2,
     QMinimum,
+    QMultiHeadAttention,
+    QMultiply,
+    QSoftmax,
     QSubtract,
     QSum,
+    QUnaryFunctionLUT,
 )
 from hgq.layers.core.base import MultipleQuantizers, Quantizer
 from hgq.quantizer.internal import FixedPointQuantizerBase
@@ -68,7 +73,9 @@ def mirror_quantizer(q: Quantizer, v: FixedVariableArray) -> FixedVariableArray:
 _registry: dict[type, 'type[ReplayOperationBase]'] = {}


-class ReplayOperationMeta(type):
+class HandlerRegMeta(type):
+    """Metaclass for automatic registration of handler classes."""
+
     def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]):
         cls = super().__new__(mcs, name, bases, namespace)
         if name == 'ReplayOperationBase':
@@ -83,8 +90,11 @@ class ReplayOperationMeta(type):
         return cls


-class ReplayOperationBase(metaclass=ReplayOperationMeta):
+class ReplayOperationBase(metaclass=HandlerRegMeta):
     handles: tuple[type, ...] = ()
+    __activation_handled__ = False
+    __input_quantizer_handled__ = False
+    __output_quantizer_handled__ = False

     def __init__(self, layer: 'keras.Operation'):
         assert isinstance(layer, self.handles)
@@ -94,8 +104,6 @@ class ReplayOperationBase(metaclass=ReplayOperationMeta):

     def __call__(self, *args, **kwargs) -> tuple[FixedVariableArray, ...]:
         assert all(not isinstance(a, FixedVariableArray) for a in kwargs.values())
-        assert all(isinstance(a, FixedVariableArray) or isinstance(a, Sequence) for a in args)
-        inputs = args[0] if len(args) == 1 else args

         if not isinstance(self.op, hgq.layers.QLayerBase):
             r = self.call(*args, **kwargs)
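The hunks above rename `ReplayOperationMeta` to `HandlerRegMeta`, but the diff only shows the `__new__` signature and the early return for the base class; the registration body itself is not visible here. Below is a minimal runnable sketch of the assumed pattern. The loop over `handles` and the `ReplayDemo` handler are illustrative assumptions, not da4ml's verbatim code:

import typing

_registry: dict[type, 'type[ReplayOperationBase]'] = {}


class HandlerRegMeta(type):
    """Metaclass for automatic registration of handler classes."""

    def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]):
        cls = super().__new__(mcs, name, bases, namespace)
        if name == 'ReplayOperationBase':
            return cls  # the abstract base itself is not registered
        # Assumed registration body: map every layer type a subclass
        # declares in `handles` to that subclass, so replay code can
        # dispatch with _registry[type(layer)].
        for handled in namespace.get('handles', ()):
            _registry[handled] = cls
        return cls


class ReplayOperationBase(metaclass=HandlerRegMeta):
    handles: tuple[type, ...] = ()


class ReplayDemo(ReplayOperationBase):
    handles = (int,)  # hypothetical stand-in for an hgq layer type


assert _registry[int] is ReplayDemo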
@@ -105,28 +113,35 @@ class ReplayOperationBase(metaclass=ReplayOperationMeta):
         assert kwargs.pop('training', False) is False, 'Training mode is not supported in mirror operation'
         assert kwargs.pop('mask', None) is None, 'Masking is not supported in mirror operation'

-        if layer.enable_iq:
-            if isinstance(inputs, Sequence):
-                assert isinstance(layer.iq, MultipleQuantizers)
-                inputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.iq.quantizers, inputs))
-            else:
-                assert isinstance(layer.iq, Quantizer), f'Expected iq to be a Quantizer, got {type(layer.iq)}'
-                inputs = mirror_quantizer(layer.iq, inputs)
+        if not self.__input_quantizer_handled__:
+            assert len(args) == 1
+            inputs = args[0]

-        outputs = self.call(inputs, **kwargs)
+            if layer.enable_iq:
+                if isinstance(inputs, Sequence):
+                    assert isinstance(layer.iq, MultipleQuantizers)
+                    inputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.iq.quantizers, inputs))
+                else:
+                    assert isinstance(layer.iq, Quantizer), f'Expected iq to be a Quantizer, got {type(layer.iq)}'
+                    inputs = mirror_quantizer(layer.iq, inputs)

-        activation = getattr(layer, 'activation', keras.activations.linear)
-        if activation is not keras.activations.linear:
-            if activation is keras.activations.relu:
-                if isinstance(outputs, tuple):
-                    assert len(outputs) == 1, 'ReLU activation is expected to have a single output'
-                    outputs = (relu(outputs[0]),)
+            outputs = self.call(inputs, **kwargs)
+        else:
+            outputs = self.call(*args, **kwargs)
+
+        if not self.__activation_handled__:
+            activation = getattr(layer, 'activation', keras.activations.linear)
+            if activation is not keras.activations.linear:
+                if activation is keras.activations.relu:
+                    if isinstance(outputs, tuple):
+                        assert len(outputs) == 1, 'ReLU activation is expected to have a single output'
+                        outputs = (relu(outputs[0]),)
+                    else:
+                        outputs = relu(outputs)
                 else:
-                    outputs = relu(outputs)
-            else:
-                raise NotImplementedError(f'Activation {activation} is not supported in mirror operation')
+                    raise NotImplementedError(f'Activation {activation} is not supported in mirror operation')

-        if layer.enable_oq:
+        if layer.enable_oq and not self.__output_quantizer_handled__:
             if isinstance(outputs, tuple):
                 assert isinstance(layer.oq, MultipleQuantizers)
                 outputs = tuple(mirror_quantizer(q, v) for q, v in zip(layer.oq.quantizers, outputs))
@@ -134,7 +149,7 @@ class ReplayOperationBase(metaclass=ReplayOperationMeta):
             assert isinstance(layer.oq, Quantizer)
             outputs = mirror_quantizer(layer.oq, outputs)

-        if isinstance(outputs, FixedVariableArray):
+        if isinstance(outputs, (FixedVariableArray, np.ndarray)):
             outputs = (outputs,)

         return outputs
@@ -193,7 +208,7 @@ class ReplayQBatchNormalization(ReplayOperationBase):
     def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
         layer: QBatchNormalization = self.op
         scale, bias = map(np.array, layer.qscaler_and_qoffset)
-        shape = layer._shape
+        shape = layer._shape[1:]
         return inputs * scale.reshape(shape) + bias.reshape(shape)


@@ -367,7 +382,7 @@ class ReplayQReduction(ReplayOperationBase):


 class ReplayArithmetic(ReplayOperationBase):
-    handles = (Add, Subtract, Multiply, TrueDivide, Divide, QSubtract, QMaximum, QMinimum, Maximum, Minimum)
+    handles = (Add, Subtract, Multiply, QMultiply, TrueDivide, Divide, QSubtract, QMaximum, QMinimum, Maximum, Minimum)

     def call(self, x1: FixedVariableArray, x2: FixedVariableArray):
         name = self.op.__class__.__name__
@@ -471,3 +486,173 @@ class ReplayAbs(ReplayOperationBase):

     def call(self, x: FixedVariableArray) -> FixedVariableArray:
         return np.abs(x)  # type: ignore
+
+
+class ReplayQFunctionLUT(ReplayOperationBase):
+    __activation_handled__ = True
+    handles = (QUnaryFunctionLUT,)
+
+    def call(self, x: FixedVariableArray) -> FixedVariableArray:
+        op: QUnaryFunctionLUT = self.op
+
+        def activation(x) -> np.ndarray:
+            kx = keras.ops.convert_to_tensor(x[None])
+            kx = op.activation(kx)
+            return keras.ops.convert_to_numpy(kx[0])  # type: ignore
+
+        return x.apply(activation)
+
+
+class ReplayQSoftmax(ReplayOperationBase):
+    handles = (QSoftmax,)
+
+    def call(self, inputs: FixedVariableArray, mask: None | FixedVariableArray = None) -> FixedVariableArray:
+        op: QSoftmax = self.op
+        inputs = inputs[None]
+
+        if op.stable:
+            inputs = np.amax(inputs, axis=op.axes, keepdims=True) - inputs  # type: ignore
+
+        exp_inp = ReplayQFunctionLUT(op.exp_table)(inputs[0])[0]
+
+        if mask is not None:
+            exp_inp = mask[0] * exp_inp
+
+        sums = np.sum(exp_inp[None], axis=op.axes, keepdims=True)[0]  # type: ignore
+        divisor = ReplayQFunctionLUT(op.inv_table)(sums)[0]
+
+        return exp_inp * divisor
+
+
+def _compute_attention_mask(
+    query,
+    value,
+    query_mask=None,
+    value_mask=None,
+    key_mask=None,
+    attention_mask=None,
+    use_causal_mask=False,
+):
+    masks = []
+    if query_mask is not None:
+        masks.append(np.expand_dims(query_mask, -1))  # [Q, 1]
+    if value_mask is not None:
+        masks.append(np.expand_dims(value_mask, -2))  # [1, V]
+    if key_mask is not None:
+        masks.append(np.expand_dims(key_mask, -2))  # [1, V]
+    if use_causal_mask:
+        q = query.shape[0]
+        v = q if value is None else value.shape[0]
+        masks.append(np.tril(np.ones((q, v), dtype='uint8')))  # [Q, V]
+    masks.append(attention_mask)
+    if not masks:
+        return None
+
+    if any(isinstance(m, FixedVariableArray) for m in masks):
+        return np.prod(np.stack(masks, axis=0), axis=0)
+    else:
+        return None
+
+
+def _masked_softmax(op, attention_scores, attention_mask=None):
+    # Normalize the attention scores to probabilities.
+    # attention_scores = [B, N, T, S]
+    if attention_mask is not None:
+        # The expand dim happens starting from the `num_heads` dimension,
+        # (<batch_dims>, num_heads, <query_attention_dims,
+        # key_attention_dims>)
+        mask_expansion_axis = -len(op._attention_axes) * 2 - 1
+        for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
+            attention_mask = np.expand_dims(attention_mask, axis=mask_expansion_axis)
+    return ReplayQSoftmax(op._softmax)(attention_scores[0], mask=attention_mask)[0][None]
+
+
+def _compute_attention(op: QMultiHeadAttention, query, key, value, attention_mask=None, training=None):
+    # Take the dot product between "query" and "key" to get the raw
+    # attention scores.
+    attention_scores = einsum(op._dot_product_equation, key, query)
+
+    attention_scores = _masked_softmax(op, attention_scores, attention_mask)
+
+    # `context_layer` = [B, T, N, H]
+    attention_output = einsum(op._combine_equation, attention_scores, value)
+    return attention_output, attention_scores
+
+
+class ReplayMHA(ReplayOperationBase):
+    handles = (QMultiHeadAttention,)
+    __input_quantizer_handled__ = True
+    __output_quantizer_handled__ = True
+
+    def call(
+        self,
+        query: FixedVariableArray,
+        value: FixedVariableArray,
+        key=None,
+        query_mask=None,
+        value_mask=None,
+        key_mask=None,
+        attention_mask=None,
+        return_attention_scores=False,
+        use_causal_mask=False,
+    ):
+        op: QMultiHeadAttention = self.op
+
+        if key is None:
+            key = value
+
+        _attention_mask = _compute_attention_mask(
+            query,
+            value,
+            query_mask=query_mask,
+            value_mask=value_mask,
+            key_mask=key_mask,
+            attention_mask=attention_mask,
+            use_causal_mask=use_causal_mask,
+        )
+
+        query = ReplayQDense(op._query_dense)(query)[0][None]
+        key = ReplayQDense(op._key_dense)(key)[0][None]
+        value = ReplayQDense(op._value_dense)(value)[0][None]
+
+        attention_output, attention_scores = _compute_attention(op, query, key, value, _attention_mask)
+        attention_output = ReplayQDense(op._output_dense)(attention_output[0])[0]
+
+        if op.enable_oq:
+            attention_output = mirror_quantizer(op.oq, attention_output)
+
+        if return_attention_scores:
+            return attention_output, attention_scores[0]
+        return attention_output
+
+
+class ReplayQLinformerAttention(ReplayMHA):
+    handles = (QLinformerAttention,)
+
+    def call(
+        self,
+        query,
+        value,
+        key=None,
+        query_mask=None,
+        value_mask=None,
+        key_mask=None,
+        attention_mask=None,
+        return_attention_scores=False,
+        use_causal_mask=False,
+    ):
+        assert use_causal_mask is False, 'Causal mask is not supported in QLinformerAttention.'
+        key = key if key is not None else value
+        op: QLinformerAttention = self.op
+        key = ReplayQDense(op._lin_k_proj)(key)[0]
+        value = ReplayQDense(op._lin_v_proj)(value)[0]
+        return super().call(
+            query,
+            value,
+            key,
+            query_mask=query_mask,
+            value_mask=value_mask,
+            key_mask=key_mask,
+            attention_mask=attention_mask,
+            return_attention_scores=return_attention_scores,
+        )
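The new `ReplayQSoftmax` above computes softmax without a divider: an exp lookup table is applied to max-shifted inputs when `op.stable` is set, and an inverse table applied to the sum replaces the division with a multiply. A plain-numpy sketch of that structure, with rounded toy tables standing in for `op.exp_table` and `op.inv_table` (which da4ml looks up through `ReplayQFunctionLUT` instead):

import numpy as np


def lut_softmax(x: np.ndarray, frac_bits: int = 8) -> np.ndarray:
    step = 2.0**-frac_bits
    d = np.max(x) - x                                  # stable form: exp inputs are <= 0
    exp_d = np.round(np.exp(-d) / step) * step         # quantized exp-table output
    inv_s = np.round(1.0 / exp_d.sum() / step) * step  # quantized 1/sum-table output
    return exp_d * inv_s                               # multiply instead of divide


p = lut_softmax(np.array([0.5, 1.0, -0.25]))
assert abs(p.sum() - 1.0) < 0.05  # sums close to, but not exactly, 1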
da4ml/trace/fixed_variable.py
CHANGED
@@ -1,14 +1,24 @@
 import random
-
+import typing
+from collections.abc import Callable, Generator
+from dataclasses import dataclass
 from decimal import Decimal
+from hashlib import sha256
 from math import ceil, floor, log2
-from typing import NamedTuple
+from typing import NamedTuple, overload
 from uuid import UUID

+import numpy as np
+from numpy.typing import NDArray
+
 from ..cmvm.core import cost_add
-from ..cmvm.types import QInterval
+from ..cmvm.types import QInterval, _minimal_kif
+from ..cmvm.util.bit_decompose import _shift_centering
+
+rd = random.Random()

-
+if typing.TYPE_CHECKING:
+    pass


 class HWConfig(NamedTuple):
@@ -17,7 +27,154 @@ class HWConfig(NamedTuple):
     latency_cutoff: float


+ufunc_t = Callable[[NDArray[np.floating]], NDArray[np.floating]]
+
+
+class TraceContext:
+    _tables: 'dict[str, tuple[LookupTable, int]]' = {}
+    hwconf: HWConfig = HWConfig(1, -1, -1)
+    _table_counter = 0
+
+    def register_table(self, table: 'LookupTable|np.ndarray'):
+        if isinstance(table, np.ndarray):
+            table = LookupTable(table)
+        if table.spec.hash in self._tables:
+            return self._tables[table.spec.hash]
+        self._tables[table.spec.hash] = (table, self._table_counter)
+
+        self._table_counter += 1
+        return self._tables[table.spec.hash]
+
+    def index_table(self, hash: str) -> int:
+        return self._tables[hash][1]
+
+    def get_table_from_index(self, index: int) -> 'LookupTable':
+        for table, idx in self._tables.values():
+            if idx == index:
+                return table
+        raise KeyError(f'No table found with index {index}')
+
+
+table_context = TraceContext()
+
+
+@dataclass
+class TableSpec:
+    hash: str
+    out_qint: QInterval
+    inp_width: int
+
+    @property
+    def out_kif(self) -> tuple[bool, int, int]:
+        return _minimal_kif(self.out_qint)
+
+
+def to_spec(table: NDArray[np.floating]) -> tuple[TableSpec, NDArray[np.int32]]:
+    f_out = -_shift_centering(np.array(table))
+    int_table = (table * 2**f_out).astype(np.int32)
+    h = sha256(int_table.data)
+    h.update(f'{f_out}'.encode())
+    inp_width = ceil(log2(table.size))
+    out_qint = QInterval(float(np.min(table)), float(np.max(table)), float(2**-f_out))
+    return TableSpec(hash=h.hexdigest(), inp_width=inp_width, out_qint=out_qint), int_table
+
+
+def interpret_as(
+    x: int | NDArray[np.integer],
+    k: int,
+    i: int,
+    f: int,
+) -> float | NDArray[np.floating]:
+    b = k + i + f
+    bias = 2.0 ** (b - 1) * k
+    eps = 2.0**-f
+    floor_fn = np.floor if isinstance(x, np.ndarray) else floor
+    return eps * (floor_fn(x + bias) % 2.0**b - bias)
+
+
+class LookupTable:
+    def __init__(self, values: NDArray, spec: TableSpec | None = None):
+        assert values.ndim == 1, 'Lookup table values must be 1-dimensional'
+        if spec is not None:
+            assert values.dtype is np.int32
+            self.spec = spec
+            self.table = values
+        else:
+            self.spec, self.table = to_spec(values)
+
+    @overload
+    def lookup(self, var: 'FixedVariable', qint_in: QInterval) -> 'FixedVariable': ...
+
+    @overload
+    def lookup(self, var: np.floating | float, qint_in: QInterval | tuple[float, float, float]) -> float: ...
+
+    def lookup(self, var, qint_in: QInterval | tuple[float, float, float]):
+        if isinstance(var, FixedVariable):
+            return var.lookup(self)
+        else:
+            _min, _max, _step = qint_in
+            assert _min <= var <= _max, f'Value {var} out of range [{_min}, {_max}]'
+            index = round((var - _min) / _step)
+            return interpret_as(int(self.table[index]), *self.spec.out_kif)
+
+    @property
+    def float_table(self) -> NDArray[np.floating]:
+        k, i, f = self.spec.out_kif
+        return interpret_as(self.table, k, i, f)  # type: ignore
+
+    def to_dict(self) -> dict:
+        return {
+            'spec': {
+                'hash': self.spec.hash,
+                'out_qint': {
+                    'min': self.spec.out_qint.min,
+                    'max': self.spec.out_qint.max,
+                    'step': self.spec.out_qint.step,
+                },
+                'inp_width': self.spec.inp_width,
+            },
+            'table': self.table.tolist(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> 'LookupTable':
+        spec_data = data['spec']
+        out_qint_data = spec_data['out_qint']
+        spec = TableSpec(
+            hash=spec_data['hash'],
+            out_qint=QInterval(out_qint_data['min'], out_qint_data['max'], out_qint_data['step']),
+            inp_width=spec_data['inp_width'],
+        )
+        table = np.array(data['table'], dtype=np.int32)
+        return cls(table, spec=spec)
+
+    def _get_pads(self, qint: QInterval) -> tuple[int, int]:
+        k, i, f = _minimal_kif(qint)
+        if k:
+            pad_left = round((qint.min + 2**i) / qint.step)
+        else:
+            pad_left = round(qint.min / qint.step)
+        size = 2 ** (k + i + f)
+        pad_right = size - len(self.table) - pad_left
+        return pad_left, pad_right
+
+    def padded_table(self, qint: QInterval) -> NDArray[np.int32]:
+        pad_left, pad_right = self._get_pads(qint)
+        data = np.pad(self.table, (pad_left, pad_right), mode='constant', constant_values=0)
+        if qint.min < 0:
+            size = len(data)
+            # data = np.concatenate((data[size // 2 :], data[: size // 2]))
+            data = np.roll(data, size // 2)
+        return data
+
+    def get_uuid(self, qint: QInterval) -> UUID:
+        pad_left, _ = self._get_pads(qint)
+        _int = int(self.spec.hash[:32], 16) ^ pad_left
+        return UUID(int=_int, version=4)
+
+
 def _const_f(const: float | Decimal):
+    """Get the minimum f such that const * 2^f is an integer."""
     const = float(const)
     _low, _high = -32, 32
     while _high - _low > 1:
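The `interpret_as` helper added above reinterprets a raw integer code as a fixed-point value with sign bit `k`, `i` integer bits and `f` fraction bits; the modulo implements two's-complement wrap-around. A worked check of the scalar path, with the arithmetic copied from the hunk:

from math import floor


def interpret_as(x, k, i, f):
    # scalar path of the function added in the hunk above
    b = k + i + f
    bias = 2.0 ** (b - 1) * k
    eps = 2.0**-f
    return eps * (floor(x + bias) % 2.0**b - bias)


# the 4-bit code 0b1101 (13) read as signed with 3 fraction bits is -3 * 2**-3
assert interpret_as(13, k=1, i=0, f=3) == -0.375
# the same code read as unsigned with one integer bit is 13 * 2**-3
assert interpret_as(13, k=0, i=1, f=3) == 1.625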
@@ -31,6 +188,7 @@ def _const_f(const: float | Decimal):


 def to_csd_powers(x: float) -> Generator[float, None, None]:
+    """Convert a float to a list of +/- powers of two in CSD representation."""
    if x == 0:
         return
     f = _const_f(abs(x))
@@ -48,6 +206,8 @@ def to_csd_powers(x: float) -> Generator[float, None, None]:


 class FixedVariable:
+    __normal__variable__ = True
+
     def __init__(
         self,
         low: float | Decimal,
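`to_csd_powers` is what lets `__mul__` (in the hunks below) lower a constant multiply to a short chain of shift-and-adds; its body is only partially visible in this diff. The sketch below is a simplified stand-in, not da4ml's exact algorithm: greedily emitting the nearest signed power of two reproduces canonical-signed-digit (CSD) recoding for terminating binary constants:

from math import log2


def csd_powers(x: float):
    # simplified stand-in for to_csd_powers; not da4ml's exact algorithm
    while x != 0.0:
        p = 2.0 ** round(log2(abs(x)))
        p = p if x > 0 else -p
        yield p
        x -= p


assert list(csd_powers(7.0)) == [8.0, -1.0]      # x * 7 == (x << 3) - x
assert list(csd_powers(0.375)) == [0.5, -0.125]  # fractional constants work too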
@@ -62,7 +222,8 @@ class FixedVariable:
         _data: Decimal | None = None,
         _id: UUID | None = None,
     ) -> None:
-        assert low <= high, f'low {low} must be less than high {high}'
+        if self.__normal__variable__:
+            assert low <= high, f'low {low} must be less than high {high}'

         if low != high and opr == 'const':
             raise ValueError('Constant variable must have low == high')
@@ -100,9 +261,19 @@ class FixedVariable:
         if v.opr == 'const':
             v.latency = self.latency

-    def get_cost_and_latency(self):
+    def get_cost_and_latency(self) -> tuple[float, float]:
         if self.opr == 'const':
             return 0.0, 0.0
+
+        if self.opr == 'lookup':
+            assert len(self._from) == 1
+            b_in = sum(self._from[0].kif)
+            b_out = sum(self.kif)
+            _latency = max(b_in - 6, 1) + self._from[0].latency
+            _cost = 2 ** max(b_in - 5, 0) * ceil(b_out / 2)
+            # Assume LUT6 with extra o5 output
+            return _cost, _latency
+
         if self.opr in ('vadd', 'cadd', 'min', 'max', 'vmul'):
             adder_size = self.hwconf.adder_size
             carry_size = self.hwconf.carry_size
@@ -212,7 +383,7 @@ class FixedVariable:
         if self.high == self.low:
             return other._const_add(self.low)

-        assert self.hwconf == other.hwconf, 'FixedVariable must have the same hwconf'
+        assert self.hwconf == other.hwconf, f'FixedVariable must have the same hwconf, got {self.hwconf} and {other.hwconf}'

         f0, f1 = self._factor, other._factor
         if f0 < 0:
@@ -270,20 +441,32 @@ class FixedVariable:
         return self * (1 / other)

     def __mul__(self, other: 'FixedVariable|int|float|Decimal') -> 'FixedVariable':
+        if isinstance(other, FixedVariable):
+            if self.high == self.low:
+                return other * self.low
+            if other.high > other.low:
+                return self._var_mul(other)
+            assert other.high == other.low
+            other = float(other.low)
+
         if other == 0:
             return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')

-        if isinstance(other, FixedVariable):
-            return self._var_mul(other)
-
         if log2(abs(other)) % 1 == 0:
             return self._pow2_mul(other)

-        variables = [self._pow2_mul(v) for v in to_csd_powers(float(other))]
+        variables = [(self._pow2_mul(v), Decimal(v)) for v in to_csd_powers(float(other))]
         while len(variables) > 1:
-
-            variables.
-
+            v1, p1 = variables.pop()
+            v2, p2 = variables.pop()
+            v, p = v1 + v2, p1 + p2
+            if p > 0:
+                high, low = self.high * p, self.low * p
+            else:
+                high, low = self.low * p, self.high * p
+            v.high, v.low = high, low
+            variables.append((v, p))
+        return variables[0][0]

     def _var_mul(self, other: 'FixedVariable') -> 'FixedVariable':
         if other is not self:
@@ -307,6 +490,7 @@ class FixedVariable:
             high,
             step,
             _from=(self, other),
+            hwconf=self.hwconf,
             _factor=_factor,
             opr=opr,
         )
@@ -407,7 +591,7 @@ class FixedVariable:
         f: int,
         overflow_mode: str = 'WRAP',
         round_mode: str = 'TRN',
-    ):
+    ) -> 'FixedVariable':
         overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
         assert overflow_mode in ('WRAP', 'SAT', 'SAT_SYM')
         assert round_mode in ('TRN', 'RND')
@@ -428,7 +612,9 @@ class FixedVariable:
             _high = Decimal(2) ** i
             high = _high - step
             low = -_high * k if overflow_mode == 'SAT' else -high * k
-
+            ff = f + 1 if round_mode == 'RND' else f
+            v = self.quantize(_k, _i, ff, 'WRAP', 'TRN')
+            return v.max_of(low).min_of(high).quantize(k, i, f, 'WRAP', round_mode)

         if self.low == self.high:
             val = self.low
@@ -539,25 +725,47 @@ class FixedVariable:
         qint = (min(self.low, other.low), min(self.high, other.high), min(self.step, other.step))
         return (self - other).msb_mux(self, other, qint=qint)

+    def lookup(self, table: LookupTable | np.ndarray) -> 'FixedVariable':
+        _table, table_id = table_context.register_table(table)
+        size = len(table.table) if isinstance(table, LookupTable) else len(table)
+        assert (
+            round((self.high - self.low) / self.step) + 1 == size
+        ), f'Input variable size does not match lookup table size ({round((self.high - self.low) / self.step) + 1} != {size})'
+
+        return FixedVariable(
+            _table.spec.out_qint.min,
+            _table.spec.out_qint.max,
+            _table.spec.out_qint.step,
+            _from=(self,),
+            _factor=Decimal(1),
+            opr='lookup',
+            hwconf=self.hwconf,
+            _data=Decimal(table_id),
+        )
+

 class FixedVariableInput(FixedVariable):
+    __normal__variable__ = False
+
     def __init__(
         self,
         latency: float | None = None,
-        hwconf=HWConfig(-1, -1, -1),
+        hwconf: HWConfig | tuple[int, int, int] = HWConfig(-1, -1, -1),
+        opr: str = 'new',
     ) -> None:
-
-
-
-
-
-
-
-
-
-
-
-
+        super().__init__(
+            low=Decimal(1e10),
+            high=Decimal(-1e10),
+            step=Decimal(1e10),
+            latency=latency if latency is not None else 0.0,
+            hwconf=HWConfig(*hwconf),
+            opr=opr,
+            cost=0.0,
+            _factor=Decimal(1),
+            _from=(),
+            _data=None,
+            _id=None,
+        )

     def __add__(self, other):
         if other == 0:
@@ -614,6 +822,8 @@ class FixedVariableInput(FixedVariable):

         if round_mode == 'RND':
             return (self.quantize(k, i, f + 1) + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
+        else:
+            round_mode = 'TRN'

         step = Decimal(2) ** -f
         _high = Decimal(2) ** i