da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +4 -0
- da4ml/_binary/__init__.py +15 -0
- da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
- da4ml/_binary/dais_bin.pyi +5 -0
- da4ml/_cli/__init__.py +30 -0
- da4ml/_cli/convert.py +204 -0
- da4ml/_cli/report.py +295 -0
- da4ml/_version.py +32 -0
- da4ml/cmvm/__init__.py +4 -0
- da4ml/cmvm/api.py +264 -0
- da4ml/cmvm/core/__init__.py +221 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +739 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +9 -0
- da4ml/codegen/hls/__init__.py +4 -0
- da4ml/codegen/hls/hls_codegen.py +196 -0
- da4ml/codegen/hls/hls_model.py +255 -0
- da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/hls/source/binder_util.hh +71 -0
- da4ml/codegen/hls/source/build_binder.mk +22 -0
- da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
- da4ml/codegen/rtl/__init__.py +15 -0
- da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
- da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
- da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +30 -0
- da4ml/codegen/rtl/rtl_model.py +486 -0
- da4ml/codegen/rtl/verilog/__init__.py +10 -0
- da4ml/codegen/rtl/verilog/comb.py +239 -0
- da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
- da4ml/codegen/rtl/verilog/pipeline.py +67 -0
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
- da4ml/codegen/rtl/verilog/source/mux.v +58 -0
- da4ml/codegen/rtl/verilog/source/negative.v +31 -0
- da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
- da4ml/codegen/rtl/vhdl/__init__.py +9 -0
- da4ml/codegen/rtl/vhdl/comb.py +206 -0
- da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
- da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
- da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
- da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
- da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
- da4ml/converter/__init__.py +63 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/layers/__init__.py +11 -0
- da4ml/converter/hgq2/layers/_base.py +132 -0
- da4ml/converter/hgq2/layers/activation.py +81 -0
- da4ml/converter/hgq2/layers/attn.py +148 -0
- da4ml/converter/hgq2/layers/batchnorm.py +15 -0
- da4ml/converter/hgq2/layers/conv.py +149 -0
- da4ml/converter/hgq2/layers/dense.py +39 -0
- da4ml/converter/hgq2/layers/ops.py +246 -0
- da4ml/converter/hgq2/layers/pool.py +107 -0
- da4ml/converter/hgq2/layers/table.py +176 -0
- da4ml/converter/hgq2/parser.py +161 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +965 -0
- da4ml/trace/fixed_variable_array.py +600 -0
- da4ml/trace/ops/__init__.py +13 -0
- da4ml/trace/ops/einsum_utils.py +305 -0
- da4ml/trace/ops/quantization.py +74 -0
- da4ml/trace/ops/reduce_utils.py +105 -0
- da4ml/trace/pipeline.py +181 -0
- da4ml/trace/tracer.py +186 -0
- da4ml/typing/__init__.py +3 -0
- da4ml-0.5.1.post1.dist-info/METADATA +85 -0
- da4ml-0.5.1.post1.dist-info/RECORD +96 -0
- da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
- da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
- da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
- da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from math import prod, sqrt
|
|
3
|
+
|
|
4
|
+
import keras
|
|
5
|
+
import numpy as np
|
|
6
|
+
from hgq.layers.table import QConvT1D, QConvT2D, QConvTBase, QDenseT
|
|
7
|
+
from hgq.quantizer.internal import FixedPointQuantizerBase
|
|
8
|
+
from keras import ops
|
|
9
|
+
|
|
10
|
+
from ....trace import FixedVariableArray
|
|
11
|
+
from ....trace.fixed_variable import FixedVariable
|
|
12
|
+
from ....trace.ops import _quantize
|
|
13
|
+
from ._base import ReplayOperationBase, mirror_quantizer, to_np_arr
|
|
14
|
+
from .conv import symbolic_extract_patches
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def keras_act_to_numpy(act: Callable) -> Callable:
|
|
18
|
+
match act:
|
|
19
|
+
case keras.activations.relu:
|
|
20
|
+
return lambda x: np.maximum(0, x)
|
|
21
|
+
case keras.activations.tanh:
|
|
22
|
+
return np.tanh
|
|
23
|
+
case keras.activations.softmax:
|
|
24
|
+
raise ValueError('Non-local activation must not be used')
|
|
25
|
+
case keras.activations.linear:
|
|
26
|
+
return lambda x: x
|
|
27
|
+
case keras.activations.sigmoid:
|
|
28
|
+
return lambda x: 1 / (1 + np.exp(-x))
|
|
29
|
+
case keras.activations.swish:
|
|
30
|
+
return lambda x: x / (1 + np.exp(-x))
|
|
31
|
+
case keras.activations.gelu:
|
|
32
|
+
return lambda x: 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
|
|
33
|
+
case keras.activations.elu:
|
|
34
|
+
return lambda x: np.where(x > 0, x, np.exp(x) - 1)
|
|
35
|
+
case keras.activations.selu:
|
|
36
|
+
alpha = 1.6732632423543772
|
|
37
|
+
scale = 1.0507009873554805
|
|
38
|
+
return lambda x: scale * np.where(x > 0, x, alpha * (np.exp(x) - 1))
|
|
39
|
+
case keras.activations.softplus:
|
|
40
|
+
return lambda x: np.log1p(np.exp(x))
|
|
41
|
+
case keras.activations.softsign:
|
|
42
|
+
return lambda x: x / (1 + np.abs(x))
|
|
43
|
+
case keras.activations.exponential:
|
|
44
|
+
return lambda x: np.exp(x)
|
|
45
|
+
case keras.activations.hard_silu:
|
|
46
|
+
return lambda x: x * np.minimum(1, np.maximum(0, (x + 1) / 2))
|
|
47
|
+
case _:
|
|
48
|
+
return lambda x: ops.convert_to_numpy(act(ops.convert_to_tensor(x)))
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def gather_weights_and_activation(model: keras.Sequential):
    """Collect kernels, biases and NumPy activations from each layer of ``model``.

    Parameters
    ----------
    model : keras.Sequential
        Sub-model whose layers expose ``get_weights()`` (kernel first, then at
        most one bias array) and an ``activation`` attribute.

    Returns
    -------
    tuple[list[np.ndarray], list[np.ndarray | None], list[Callable]]
        Per-layer kernels, biases (``None`` for bias-free layers) and NumPy
        activation functions. 3-D kernels, and their biases, get a trailing
        unit axis appended.
    """
    kernels: list[np.ndarray] = []
    biases: list[np.ndarray | None] = []
    activations: list[Callable[[np.ndarray], np.ndarray]] = []

    # NOTE(review): layers are presumably keras.layers.EinsumDense — confirm.
    for layer in model.layers:
        kernel, *extra = layer.get_weights()
        activation = keras_act_to_numpy(layer.activation)
        if extra:
            assert len(extra) == 1
        bias = extra[0] if extra else None

        if kernel.ndim == 3:
            # Append a unit axis (presumably to match the rank the einsum in the
            # table replay expects — TODO confirm).
            kernel = kernel[..., None]
            bias = bias if bias is None else bias[..., None]

        kernels.append(kernel)
        biases.append(bias)
        activations.append(activation)

    return kernels, biases, activations
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ReplayDenseTable(ReplayOperationBase):
    """Replay an HGQ ``QDenseT`` layer as per-element lookup tables.

    Every (input element, output channel) pair is tabulated separately. All
    representable values of the quantized input are enumerated, then pushed
    through the layers of ``op.module`` in NumPy. The optional batch
    normalization and the output quantizer ``op.toq`` are applied, and one
    lookup is emitted per pair. The results are summed over the input axis.
    """

    handles = (QDenseT,)

    # NOTE(review): presumably tells ReplayOperationBase not to apply the input
    # quantizer itself, since ``call`` applies ``op.iq`` explicitly — confirm in _base.
    __input_quantizer_handled__ = True

    def call(self, inputs: FixedVariableArray) -> FixedVariableArray:
        """Trace the table-based dense layer on ``inputs``.

        Parameters
        ----------
        inputs : FixedVariableArray
            Layer input; its last axis is the one reduced over.

        Returns
        -------
        FixedVariableArray
            Array of shape ``inputs.shape[:-1] + (op.n_out,)``.
        """
        op: QDenseT = self.op  # type: ignore

        # Replicate every input element once per output channel, then apply the
        # layer's input quantizer: shape inputs.shape + (n_out,).
        out = np.broadcast_to(inputs[..., None], inputs.shape + (op.n_out,))  # type: ignore
        out = mirror_quantizer(op.iq, out)

        # Per-element low bound, high bound and step of the representable inputs;
        # each element's table has one entry per representable value.
        l, h, s = out.lhs
        table_sizes: np.ndarray = np.round((h - l) / s).astype(np.uint32) + 1

        ws, bs, acts = gather_weights_and_activation(op.module)

        out_shape: tuple[int, ...] = inputs.shape + (op.n_out,)
        # Flat (C-order) index over out_shape -> table; output channel is innermost.
        tables: list[np.ndarray] = [None] * prod(out_shape)  # type: ignore

        # Group elements by table size so each group is evaluated in one einsum chain.
        n, loc = np.unique(table_sizes, return_inverse=True)
        # NumPy < 2 returns a flat inverse index, NumPy >= 2 one shaped like the input.
        # Normalize the shape so ``loc == i`` is a mask aligned with ``l`` / ``h``.
        loc = np.reshape(loc, table_sizes.shape)

        for i in range(n.size):
            mask: np.ndarray = loc == i
            # The n[i] representable inputs of every element in this group:
            # shape (n[i], n_elements_in_group).
            inp = np.linspace(l[mask], h[mask], n[i])
            _out = inp[..., None]

            idxs = np.where(mask.ravel())[0]
            # Collapse leading axes so each remaining (n_in, n_out) mask selects the
            # matching per-element weights from the leading axes of every kernel.
            mask = mask.reshape(-1, *mask.shape[-2:])

            for w, b, act in zip(ws, bs, acts):
                w = np.concatenate([w[_mask] for _mask in mask], axis=0)
                b = np.concatenate([b[_mask] for _mask in mask], axis=0) if b is not None else 0
                _out = act(np.einsum('...ni,nij->...nj', _out, w, optimize='optimal') + b)
            _out = _out[..., 0]

            for j, idx in enumerate(idxs):
                tables[idx] = _out[..., j]

        if op.enable_bn:
            bn = op.bn_module
            # Without ``center`` batch normalization adds no offset (beta == 0);
            # without ``scale`` it applies no gain (gamma == 1).
            beta: np.ndarray = ops.convert_to_numpy(bn.beta) if bn.center else 0  # type: ignore
            gamma: np.ndarray = ops.convert_to_numpy(bn.gamma) if bn.scale else 1  # type: ignore
            m_mean: np.ndarray = ops.convert_to_numpy(bn.moving_mean)  # type: ignore
            m_var: np.ndarray = ops.convert_to_numpy(bn.moving_variance)  # type: ignore
            scaler = gamma / np.sqrt(m_var + bn.epsilon)
            offset = beta - m_mean * scaler

            # The output channel of flat table ``idx`` is ``idx % n_out``.
            # NOTE(review): the 1/sqrt(n_in) factor mirrors HGQ's QDenseT normalization — confirm.
            for idx, table in enumerate(tables):
                ch = idx % op.n_out
                table[:] = (table * scaler[ch] + offset[ch]) / sqrt(op.n_in)

        assert all(v is not None for v in tables), tables

        # Quantize every table with the output quantizer's per-element (k, i, f).
        toq_internal: FixedPointQuantizerBase = op.toq.quantizer
        kk, ki, kf = toq_internal.kif

        _shape = (1,) + out.shape
        kk = toq_internal.bw_mapper.bw_to_x(kk, _shape)
        ki = toq_internal.bw_mapper.bw_to_x(ki, _shape)
        kf = toq_internal.bw_mapper.bw_to_x(kf, _shape)

        k_arr, i_arr, f_arr = (to_np_arr(x).astype(np.int32).ravel() for x in (kk, ki, kf))

        round_mode, overflow_mode = toq_internal.round_mode, toq_internal.overflow_mode
        # Strip an ``S_`` prefix from the rounding mode before passing it to _quantize.
        round_mode = round_mode[2:] if round_mode.startswith('S_') else round_mode
        for arr, _k, _i, _f in zip(tables, k_arr, i_arr, f_arr):
            arr[:] = _quantize(arr, _k, _i, _f, overflow_mode, round_mode)

        # One lookup per (input element, output channel), then reduce over the input axis.
        _vars = out.ravel()._vars
        ret_vars: list[FixedVariable] = [_vars[idx].lookup(table) for idx, table in enumerate(tables)]
        out = FixedVariableArray(np.array(ret_vars).reshape(out_shape), solver_options=out.solver_options)
        out = np.sum(out, axis=-2)  # type: ignore
        return out
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class ReplayConvTable(ReplayDenseTable):
    """Replay HGQ table-based convolutions by lowering them to the dense table replay.

    Input patches are extracted with ``symbolic_extract_patches`` using the
    layer's ``im2col_params``. The patch array is then handed to
    ``ReplayDenseTable.call``. For rank-1 convolutions a singleton axis is
    inserted at position 1 before patch extraction and removed afterwards.
    """

    handles = (QConvT2D, QConvT1D, QConvTBase)

    def call(self, inputs: FixedVariableArray):
        op: QConvTBase = self.op
        is_1d = op.rank == 1

        lifted = inputs[:, None] if is_1d else inputs
        patches = symbolic_extract_patches(lifted, **op.im2col_params)
        if is_1d:
            patches = patches[:, 0]

        return super().call(patches)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# Replay operations exported by this module.
__all__ = ['ReplayDenseTable', 'ReplayConvTable']
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import keras
|
|
6
|
+
import numpy as np
|
|
7
|
+
from keras import KerasTensor, Operation
|
|
8
|
+
|
|
9
|
+
from ...cmvm.api import solver_options_t
|
|
10
|
+
from ...trace import FixedVariableArray, FixedVariableArrayInput, HWConfig, comb_trace
|
|
11
|
+
from ...trace.fixed_variable import FixedVariable
|
|
12
|
+
from .layers import _registry
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class OpObj:
    """One node of a Keras functional graph, as collected by ``parse_model``."""

    # The layer / operation invoked at this node.
    operation: Operation
    # Positional call arguments; may contain KerasTensors to be substituted.
    # NOTE(review): Keras may store these as a tuple rather than a list — confirm.
    args: list
    # Keyword call arguments; may contain KerasTensors to be substituted.
    kwargs: dict
    # Tensors output by this node.
    produces: tuple[KerasTensor, ...]
    # KerasTensors consumed by this node; they must be traced before it is replayed.
    requires: tuple[KerasTensor, ...]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def parse_model(model: keras.Model):
    """Collect the nodes of ``model`` as ``OpObj`` records grouped by depth.

    Sequential models are replaced by their functional form first.

    Parameters
    ----------
    model : keras.Model
        Model whose ``_nodes_by_depth`` graph is read.

    Returns
    -------
    list[list[OpObj]]
        One list per depth, ordered from the largest depth down to depth 0.
        This is the order in which ``_apply_nn`` replays them.
    """
    if isinstance(model, keras.Sequential):
        model = model._functional

    def to_opobj(node) -> OpObj:
        assert isinstance(node.operation, keras.Operation)
        call_args = node.arguments
        return OpObj(
            operation=node.operation,
            args=call_args.args,
            kwargs=call_args.kwargs,
            produces=node.outputs,
            requires=call_args.keras_tensors,
        )

    by_depth: dict[int, list[OpObj]] = {
        depth: [to_opobj(node) for node in nodes] for depth, nodes in model._nodes_by_depth.items()
    }
    deepest = max(by_depth.keys())
    return [by_depth[depth] for depth in range(deepest, -1, -1)]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def replace_tensors(tensor_map: dict[KerasTensor, FixedVariableArray], obj: Any) -> Any:
|
|
45
|
+
if isinstance(obj, KerasTensor):
|
|
46
|
+
return tensor_map[obj]
|
|
47
|
+
if isinstance(obj, list):
|
|
48
|
+
return [replace_tensors(tensor_map, o) for o in obj]
|
|
49
|
+
if isinstance(obj, tuple):
|
|
50
|
+
return tuple(replace_tensors(tensor_map, o) for o in obj)
|
|
51
|
+
if isinstance(obj, dict):
|
|
52
|
+
return {k: replace_tensors(tensor_map, v) for k, v in obj.items()}
|
|
53
|
+
return obj
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _flatten_arr(args: Any) -> FixedVariableArray:
    """Flatten a (possibly nested) structure of fixed-point values into one 1-D array.

    Parameters
    ----------
    args : Any
        A ``FixedVariableArray``, a ``FixedVariable``, or a sequence nesting them.
        Other leaves (including strings and bytes) are ignored.

    Returns
    -------
    FixedVariableArray
        The raveled, concatenated values. Returns ``None`` when ``args`` is itself a
        non-sequence leaf that is not a fixed-point value.

    Raises
    ------
    ValueError
        From ``np.concatenate`` if a sequence contains no fixed-point values at all.
    """
    if isinstance(args, FixedVariableArray):
        return np.ravel(args)  # type: ignore
    if isinstance(args, FixedVariable):
        return FixedVariableArray(np.array([args]))
    # ``str`` is a Sequence whose items are again ``str``, so recursing into it never
    # terminates; treat strings (and bytes) as ignorable leaves.
    if not isinstance(args, Sequence) or isinstance(args, (str, bytes)):
        return None  # type: ignore
    parts = [_flatten_arr(a) for a in args]
    parts = [p for p in parts if p is not None]
    return np.concatenate(parts)  # type: ignore
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _apply_nn(
    model: keras.Model,
    inputs: FixedVariableArray | Sequence[FixedVariableArray],
    verbose: bool = False,
    dump: bool = False,
    n_nested: int = 0,
) -> tuple[FixedVariableArray, ...] | dict[str, FixedVariableArray]:
    """
    Apply a keras model to a fixed variable array or a sequence of fixed variable arrays.

    Nested ``keras.Model`` layers are traced recursively. Every other operation is
    replayed through the class registered for it in ``_registry``.

    Parameters
    ----------
    model : keras.Model
        The keras model to apply.
    inputs : FixedVariableArray or Sequence[FixedVariableArray]
        The input fixed variable array, or one array per entry of ``model.inputs``.
    verbose : bool, optional
        If True, print each operation together with the cumulative cost and latency
        of the graph traced so far.
    dump : bool, optional
        If True, return every traced tensor keyed by its Keras tensor name instead of
        only the model outputs.
    n_nested : int, optional
        Nesting depth of ``model``; only used to indent verbose output.

    Returns
    -------
    tuple of FixedVariableArray, or dict of str to FixedVariableArray
        The model outputs in ``model.outputs`` order. When ``dump`` is True, all
        traced tensors keyed by name.

    Raises
    ------
    AssertionError
        If the number of inputs does not match ``model.inputs``.
    """
    if isinstance(inputs, FixedVariableArray):
        inputs = (inputs,)

    assert len(model.inputs) == len(inputs), f'Model has {len(model.inputs)} inputs, got {len(inputs)}'
    tensor_map = {keras_tensor: da_tensor for keras_tensor, da_tensor in zip(model.inputs, inputs)}

    _inputs = _flatten_arr(inputs)

    if verbose and n_nested:
        print(' -> enter:')

    for level in parse_model(model):
        for op in level:
            assert all(t in tensor_map for t in op.requires)
            args = replace_tensors(tensor_map, op.args)
            kwargs: dict[str, Any] = replace_tensors(tensor_map, op.kwargs)
            if op.operation.__class__ is keras.layers.InputLayer:
                continue

            if verbose:
                indent = ' ' * n_nested
                print(f'{indent}{op.operation.name} ({op.operation.__class__.__name__})', end='')

            if isinstance(op.operation, keras.Model):
                sub_model = op.operation._functional if isinstance(op.operation, keras.Sequential) else op.operation
                # A sub-model called as ``sub_model([a, b])`` records a single positional
                # argument holding all of its inputs; unwrap it so the inputs line up
                # one-to-one with ``sub_model.inputs``.
                if len(args) == 1 and isinstance(args[0], (list, tuple)):
                    sub_inputs = args[0]
                else:
                    sub_inputs = args
                outputs: tuple[FixedVariableArray, ...] = _apply_nn(
                    sub_model,
                    sub_inputs,
                    verbose=verbose,
                    dump=False,
                    n_nested=n_nested + 1,
                )  # type: ignore
            else:
                mirror_op = _registry[op.operation.__class__](op.operation)
                outputs = mirror_op(*args, **kwargs)
            if verbose:
                comb = comb_trace(_inputs, _flatten_arr(outputs))
                print(f' cumcost: {comb.cost}, latency: {comb.latency[1]}')

            for keras_tensor, da_tensor in zip(op.produces, outputs):
                tensor_map[keras_tensor] = da_tensor

    if verbose and n_nested:
        indent = ' ' * (n_nested - 1)
        print(f'{indent}<- exit', end='')

    if not dump:
        return tuple(tensor_map[keras_tensor] for keras_tensor in model.outputs)
    else:
        return {k.name: v for k, v in tensor_map.items()}
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def trace_model(  # type: ignore
    model: keras.Model,
    hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
    solver_options: solver_options_t | None = None,
    verbose: bool = False,
    inputs: tuple[FixedVariableArray, ...] | None = None,
    inputs_kif: tuple[int, int, int] | None = None,
    dump=False,
):
    """Trace ``model`` symbolically on fixed-point inputs.

    Parameters
    ----------
    model : keras.Model
        Model to trace.
    hwconf : HWConfig or tuple of int
        Hardware configuration for auto-created inputs.
    solver_options : solver_options_t, optional
        Solver options for auto-created inputs.
    verbose : bool
        Print per-operation cost/latency while tracing.
    inputs : tuple of FixedVariableArray, optional
        Pre-built inputs. When omitted, one ``FixedVariableArrayInput`` is created per
        model input, with the batch axis dropped from its shape.
    inputs_kif : tuple of int, optional
        If given, every input is quantized with ``quantize(*inputs_kif)`` first.
    dump : bool
        Return all traced tensors instead of only inputs/outputs.

    Returns
    -------
    tuple of FixedVariableArray, or dict of str to FixedVariableArray
        ``(flat_inputs, flat_outputs)``. When ``dump`` is True, every traced tensor,
        flattened and keyed by name.
    """
    if inputs is None:
        inputs = tuple(
            FixedVariableArrayInput(keras_input.shape[1:], hwconf=hwconf, solver_options=solver_options)
            for keras_input in model.inputs
        )
    if inputs_kif is not None:
        inputs = tuple(arr.quantize(*inputs_kif) for arr in inputs)

    traced = _apply_nn(model, inputs, verbose=verbose, dump=dump)
    if dump:
        return {name: _flatten_arr(tensor) for name, tensor in traced.items()}  # type: ignore
    return _flatten_arr(inputs), _flatten_arr(traced)
|
da4ml/trace/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from .fixed_variable import FixedVariable, HWConfig
|
|
2
|
+
from .fixed_variable_array import FixedVariableArray, FixedVariableArrayInput
|
|
3
|
+
from .pipeline import to_pipeline
|
|
4
|
+
from .tracer import comb_trace
|
|
5
|
+
|
|
6
|
+
# Public API of the ``da4ml.trace`` package.
__all__ = ['to_pipeline', 'comb_trace', 'FixedVariableArray', 'FixedVariable', 'HWConfig', 'FixedVariableArrayInput']
|