da4ml 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the differences between the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

da4ml/_version.py CHANGED
@@ -1,7 +1,14 @@
1
1
  # file generated by setuptools-scm
2
2
  # don't change, don't track in version control
3
3
 
4
- __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
5
12
 
6
13
  TYPE_CHECKING = False
7
14
  if TYPE_CHECKING:
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
9
16
  from typing import Union
10
17
 
11
18
  VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
12
20
  else:
13
21
  VERSION_TUPLE = object
22
+ COMMIT_ID = object
14
23
 
15
24
  version: str
16
25
  __version__: str
17
26
  __version_tuple__: VERSION_TUPLE
18
27
  version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
19
30
 
20
- __version__ = version = '0.3.1'
21
- __version_tuple__ = version_tuple = (0, 3, 1)
31
+ __version__ = version = '0.3.3'
32
+ __version_tuple__ = version_tuple = (0, 3, 3)
33
+
34
+ __commit_id__ = commit_id = None
da4ml/cmvm/types.py CHANGED
@@ -321,7 +321,7 @@ class Solution(NamedTuple):
321
321
  case 4: # const addition
322
322
  bias = op.data * op.qint.step
323
323
  buf[i] = buf[op.id0] + bias
324
- case 5:
324
+ case 5: # const definition
325
325
  buf[i] = op.data * op.qint.step # const definition
326
326
  case 6 | -6: # MSB Mux
327
327
  id_c = op.data & 0xFFFFFFFF
@@ -340,6 +340,9 @@ class Solution(NamedTuple):
340
340
  else:
341
341
  _k, _i, _f = _minimal_kif(qint_k)
342
342
  buf[i] = v0 if k >= 2.0 ** (_i - 1) else v1 * 2.0**shift
343
+ case 7:
344
+ v0, v1 = buf[op.id0], buf[op.id1]
345
+ buf[i] = v0 * v1
343
346
  case _:
344
347
  raise ValueError(f'Unknown opcode {op.opcode} in {op}')
345
348
 
@@ -370,6 +373,8 @@ class Solution(NamedTuple):
370
373
  case 6 | -6:
371
374
  _sign = '-' if op.opcode == -6 else ''
372
375
  op_str = f'msb(buf[{op.data}]) ? buf[{op.id0}] : {_sign}buf[{op.id1}]'
376
+ case 7:
377
+ op_str = f'buf[{op.id0}] * buf[{op.id1}]'
373
378
  case _:
374
379
  raise ValueError(f'Unknown opcode {op.opcode} in {op}')
375
380
 
@@ -436,7 +441,12 @@ class Solution(NamedTuple):
436
441
  @property
437
442
  def inp_qint(self):
438
443
  """Quantization intervals of the input elements."""
439
- return [op.qint for op in self.ops if op.opcode == -1]
444
+ qints = [QInterval(0.0, 0.0, 1.0) for _ in range(self.shape[0])]
445
+ for op in self.ops:
446
+ if op.opcode != -1:
447
+ continue
448
+ qints[op.id0] = op.qint
449
+ return qints
440
450
 
441
451
  def save(self, path: str | Path):
442
452
  """Save the solution to a file."""
@@ -46,7 +46,7 @@ def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int
46
46
  match op.opcode:
47
47
  case -1:
48
48
  # Input marker
49
- val = f'inp[{ops[op.id0].id0}]'
49
+ val = f'inp[{op.id0}]'
50
50
  case 0 | 1:
51
51
  # Common a+/-b<<shift op
52
52
  ref1 = f'bit_shift<{op.data}>(v{op.id1})' if op.data != 0 else f'v{op.id1}'
@@ -86,7 +86,10 @@ def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int
86
86
  sign = '-' if op.opcode == -6 else ''
87
87
  ref1 = f'v{op.id1}' if shift == 0 else f'bit_shift<{shift}>(v{op.id1})'
88
88
  val = f'{ref_k} ? {_type}({ref0}) : {_type}({sign}{ref1})'
89
-
89
+ case 7:
90
+ # Multiplication
91
+ ref1 = f'v{op.id1}'
92
+ val = f'{ref0} * {ref1}'
90
93
  case _:
91
94
  raise ValueError(f'Unsupported opcode: {op.opcode}')
92
95
 
@@ -9,31 +9,26 @@ constexpr bool _openmp = true;
9
9
  constexpr bool _openmp = false;
10
10
  #endif
11
11
 
12
- template <typename CONFIG_T, typename T> void _inference(T *c_inp, T *c_out, size_t n_samples)
13
- {
12
+ template <typename CONFIG_T, typename T> void _inference(T *c_inp, T *c_out, size_t n_samples) {
14
13
  typename CONFIG_T::inp_t in_fixed_buf[CONFIG_T::N_inp];
15
14
  typename CONFIG_T::out_t out_fixed_buf[CONFIG_T::N_out];
16
15
 
17
- for(size_t i = 0; i < n_samples; ++i)
18
- {
19
- size_t offset_in = i * CONFIG_T::N_inp;
20
- size_t offset_out = i * CONFIG_T::N_out;
21
- for(size_t j = 0; j < CONFIG_T::N_inp; ++j)
22
- {
23
- in_fixed_buf[j] = c_inp[offset_in + j];
24
- }
16
+ for (size_t i = 0; i < n_samples; ++i) {
17
+ size_t offset_in = i * CONFIG_T::N_inp;
18
+ size_t offset_out = i * CONFIG_T::N_out;
19
+ for (size_t j = 0; j < CONFIG_T::N_inp; ++j) {
20
+ in_fixed_buf[j] = c_inp[offset_in + j];
21
+ }
25
22
 
26
- CONFIG_T::f(in_fixed_buf, out_fixed_buf);
23
+ CONFIG_T::f(in_fixed_buf, out_fixed_buf);
27
24
 
28
- for(size_t j = 0; j < CONFIG_T::N_out; ++j)
29
- {
30
- c_out[offset_out + j] = out_fixed_buf[j];
31
- }
25
+ for (size_t j = 0; j < CONFIG_T::N_out; ++j) {
26
+ c_out[offset_out + j] = out_fixed_buf[j];
32
27
  }
28
+ }
33
29
  }
34
30
 
35
- template <typename CONFIG_T, typename T> void batch_inference(T *c_inp, T *c_out, size_t n_samples)
36
- {
31
+ template <typename CONFIG_T, typename T> void batch_inference(T *c_inp, T *c_out, size_t n_samples) {
37
32
  #ifdef _OPENMP
38
33
  size_t n_max_threads = omp_get_max_threads();
39
34
  size_t n_samples_per_thread = std::max<size_t>(n_samples / n_max_threads, 32);
@@ -41,15 +36,14 @@ template <typename CONFIG_T, typename T> void batch_inference(T *c_inp, T *c_out
41
36
  n_thread += (n_samples % n_samples_per_thread) ? 1 : 0;
42
37
 
43
38
  #pragma omp parallel for num_threads(n_thread) schedule(static)
44
- for(size_t i = 0; i < n_thread; ++i)
45
- {
46
- size_t start = i * n_samples_per_thread;
47
- size_t end = std::min<size_t>(start + n_samples_per_thread, n_samples);
48
- size_t n_samples_this_thread = end - start;
49
- size_t offset_in = start * CONFIG_T::N_inp;
50
- size_t offset_out = start * CONFIG_T::N_out;
51
- _inference<CONFIG_T, T>(&c_inp[offset_in], &c_out[offset_out], n_samples_this_thread);
52
- }
39
+ for (size_t i = 0; i < n_thread; ++i) {
40
+ size_t start = i * n_samples_per_thread;
41
+ size_t end = std::min<size_t>(start + n_samples_per_thread, n_samples);
42
+ size_t n_samples_this_thread = end - start;
43
+ size_t offset_in = start * CONFIG_T::N_inp;
44
+ size_t offset_out = start * CONFIG_T::N_out;
45
+ _inference<CONFIG_T, T>(&c_inp[offset_in], &c_out[offset_out], n_samples_this_thread);
46
+ }
53
47
  #else
54
48
  _inference<CONFIG_T, T>(c_inp, c_out, n_samples);
55
49
  #endif
@@ -1,14 +1,16 @@
1
1
  #pragma once
2
- #include "ap_types/ap_fixed.h"
2
+ #include "ap_fixed.h"
3
3
 
4
- template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N> ap_fixed<b, i + s> bit_shift(ap_fixed<b, i, Q, O, N> x) {
4
+ template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N>
5
+ ap_fixed<b, i + s> bit_shift(ap_fixed<b, i, Q, O, N> x) {
5
6
  #pragma HLS INLINE
6
7
  ap_fixed<b, i + s> r;
7
8
  r.range() = x.range();
8
9
  return r;
9
10
  };
10
11
 
11
- template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N> ap_ufixed<b, i + s> bit_shift(ap_ufixed<b, i, Q, O, N> x) {
12
+ template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N>
13
+ ap_ufixed<b, i + s> bit_shift(ap_ufixed<b, i, Q, O, N> x) {
12
14
  #pragma HLS INLINE
13
15
  ap_ufixed<b, i + s> r;
14
16
  r.range() = x.range();
@@ -9,7 +9,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
9
9
  ops = sol.ops
10
10
  kifs = list(map(_minimal_kif, (op.qint for op in ops)))
11
11
  widths = list(map(sum, kifs))
12
- inp_kifs = [_minimal_kif(op.qint) for op in ops if op.opcode == -1]
12
+ inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
13
13
  inp_widths = list(map(sum, inp_kifs))
14
14
  _inp_widths = np.cumsum([0] + inp_widths)
15
15
  inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
@@ -31,6 +31,17 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
31
31
  case -1: # Input marker
32
32
  i0, i1 = inp_idxs[op.id0]
33
33
  line = f'{_def} assign {v} = inp[{i0}:{i1}];'
34
+
35
+ case 0 | 1: # Common a+/-b<<shift oprs
36
+ p0, p1 = kifs[op.id0], kifs[op.id1] # precision -> keep_neg, integers (no sign), fractional
37
+
38
+ bw0, bw1 = widths[op.id0], widths[op.id1] # width
39
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
40
+ shift = op.data + f0 - f1
41
+ v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
42
+
43
+ line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {op.opcode}) op_{i} ({v0}, {v1}, {v});'
44
+
34
45
  case 2 | -2: # ReLU
35
46
  lsb_bias = kifs[op.id0][2] - kifs[i][2]
36
47
  i0, i1 = bw + lsb_bias - 1, lsb_bias
@@ -93,16 +104,6 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
93
104
  num = 2**bw + num
94
105
  line = f"{_def} assign {v} = '{bin(num)[1:]};"
95
106
 
96
- case 0 | 1: # Common a+/-b<<shift oprs
97
- p0, p1 = kifs[op.id0], kifs[op.id1] # precision -> keep_neg, integers (no sign), fractional
98
-
99
- bw0, bw1 = widths[op.id0], widths[op.id1] # width
100
- s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
101
- shift = op.data + f0 - f1
102
- v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
103
-
104
- line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {op.opcode}) op_{i} ({v0}, {v1}, {v});'
105
-
106
107
  case 6 | -6: # MSB Muxing
107
108
  k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
108
109
  p0, p1 = kifs[a], kifs[b]
@@ -115,6 +116,13 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
115
116
  vk, v0, v1 = f'v{k}[{bwk - 1}]', f'v{a}[{bw0 - 1}:0]', f'v{b}[{bw1 - 1}:0]'
116
117
 
117
118
  line = f'{_def} mux #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {inv}) op_{i} ({vk}, {v0}, {v1}, {v});'
119
+ case 7: # Multiplication
120
+ bw0, bw1 = widths[op.id0], widths[op.id1] # width
121
+ s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
122
+ v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
123
+
124
+ line = f'{_def} multiplier #({bw0}, {bw1}, {s0}, {s1}, {bw}) op_{i} ({v0}, {v1}, {v});'
125
+
118
126
  case _:
119
127
  raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
120
128
 
@@ -10,7 +10,7 @@ constexpr bool _openmp = false;
10
10
 
11
11
  template <typename CONFIG_T>
12
12
  std::enable_if_t<CONFIG_T::II != 0> _inference(int32_t *c_inp, int32_t *c_out, size_t n_samples) {
13
- typename CONFIG_T::dut_t *dut = new typename CONFIG_T::dut_t;
13
+ auto dut = std::make_unique<typename CONFIG_T::dut_t>();
14
14
 
15
15
  size_t clk_req = n_samples * CONFIG_T::II + CONFIG_T::latency + 1;
16
16
 
@@ -18,14 +18,18 @@ std::enable_if_t<CONFIG_T::II != 0> _inference(int32_t *c_inp, int32_t *c_out, s
18
18
  size_t t_out = t_inp - CONFIG_T::latency - 1;
19
19
 
20
20
  if (t_inp < n_samples * CONFIG_T::II && t_inp % CONFIG_T::II == 0) {
21
- write_input<CONFIG_T::N_inp, CONFIG_T::max_inp_bw>(dut->inp, &c_inp[t_inp / CONFIG_T::II * CONFIG_T::N_inp]);
21
+ write_input<CONFIG_T::N_inp, CONFIG_T::max_inp_bw>(
22
+ dut->inp, &c_inp[t_inp / CONFIG_T::II * CONFIG_T::N_inp]
23
+ );
22
24
  }
23
25
 
24
26
  dut->clk = 0;
25
27
  dut->eval();
26
28
 
27
29
  if (t_inp > CONFIG_T::latency && t_out % CONFIG_T::II == 0) {
28
- read_output<CONFIG_T::N_out, CONFIG_T::max_out_bw>(dut->out, &c_out[t_out / CONFIG_T::II * CONFIG_T::N_out]);
30
+ read_output<CONFIG_T::N_out, CONFIG_T::max_out_bw>(
31
+ dut->out, &c_out[t_out / CONFIG_T::II * CONFIG_T::N_out]
32
+ );
29
33
  }
30
34
 
31
35
  dut->clk = 1;
@@ -33,12 +37,11 @@ std::enable_if_t<CONFIG_T::II != 0> _inference(int32_t *c_inp, int32_t *c_out, s
33
37
  }
34
38
 
35
39
  dut->final();
36
- delete dut;
37
40
  }
38
41
 
39
42
  template <typename CONFIG_T>
40
43
  std::enable_if_t<CONFIG_T::II == 0> _inference(int32_t *c_inp, int32_t *c_out, size_t n_samples) {
41
- typename CONFIG_T::dut_t *dut = new typename CONFIG_T::dut_t;
44
+ auto dut = std::make_unique<typename CONFIG_T::dut_t>();
42
45
 
43
46
  for (size_t i = 0; i < n_samples; ++i) {
44
47
  write_input<CONFIG_T::N_inp, CONFIG_T::max_inp_bw>(dut->inp, &c_inp[i * CONFIG_T::N_inp]);
@@ -47,7 +50,6 @@ std::enable_if_t<CONFIG_T::II == 0> _inference(int32_t *c_inp, int32_t *c_out, s
47
50
  }
48
51
 
49
52
  dut->final();
50
- delete dut;
51
53
  }
52
54
 
53
55
  template <typename CONFIG_T> void batch_inference(int32_t *c_inp, int32_t *c_out, size_t n_samples) {
@@ -68,7 +68,8 @@ template <size_t bw, size_t N_out> std::vector<int32_t> bitunpack(const std::vec
68
68
  }
69
69
 
70
70
  template <size_t bits_in, typename inp_buf_t>
71
- std::enable_if_t<std::is_integral_v<inp_buf_t>, void> _write_input(inp_buf_t &inp_buf, const std::vector<int32_t> &input) {
71
+ std::enable_if_t<std::is_integral_v<inp_buf_t>, void>
72
+ _write_input(inp_buf_t &inp_buf, const std::vector<int32_t> &input) {
72
73
  assert(input.size() == (bits_in + 31) / 32);
73
74
  inp_buf = input[0] & 0xFFFFFFFF;
74
75
  if (bits_in > 32) {
@@ -0,0 +1,37 @@
1
+ `timescale 1ns / 1ps
2
+
3
+
4
+ module multiplier #(
5
+ parameter BW_INPUT0 = 32,
6
+ parameter BW_INPUT1 = 32,
7
+ parameter SIGNED0 = 0,
8
+ parameter SIGNED1 = 0,
9
+ parameter BW_OUT = 32
10
+ ) (
11
+ input [BW_INPUT0-1:0] in0,
12
+ input [BW_INPUT1-1:0] in1,
13
+ output [BW_OUT-1:0] out
14
+ );
15
+
16
+ localparam BW_BUF = BW_INPUT0 + BW_INPUT1;
17
+
18
+ // verilator lint_off UNUSEDSIGNAL
19
+ wire [BW_BUF - 1:0] buffer;
20
+ // verilator lint_on UNUSEDSIGNAL
21
+
22
+ generate
23
+ if (SIGNED0 == 1 && SIGNED1 == 1) begin : signed_signed
24
+ assign buffer[BW_BUF-1:0] = $signed(in0) * $signed(in1);
25
+ end else if (SIGNED0 == 1 && SIGNED1 == 0) begin : signed_unsigned
26
+ assign buffer[BW_BUF-1:0] = $signed(in0) * $signed({{1'b0,in1}});
27
+ // assign buffer[BW_BUF-1] = in0[BW_INPUT0-1];
28
+ end else if (SIGNED0 == 0 && SIGNED1 == 1) begin : unsigned_signed
29
+ assign buffer[BW_BUF-1:0] = $signed({{1'b0,in0}}) * $signed(in1);
30
+ // assign buffer[BW_BUF-1] = in1[BW_INPUT1-1];
31
+ end else begin : unsigned_unsigned
32
+ assign buffer[BW_BUF-1:0] = in0 * in1;
33
+ end
34
+ endgenerate
35
+
36
+ assign out[BW_OUT-1:0] = buffer[BW_OUT-1:0];
37
+ endmodule
@@ -114,9 +114,8 @@ class VerilogModel:
114
114
  f.write(binder)
115
115
 
116
116
  # Common resource copy
117
- shutil.copy(self.__src_root / 'verilog/source/shift_adder.v', self._path)
118
- shutil.copy(self.__src_root / 'verilog/source/mux.v', self._path)
119
- shutil.copy(self.__src_root / 'verilog/source/negative.v', self._path)
117
+ for fname in self.__src_root.glob('verilog/source/*.v'):
118
+ shutil.copy(fname, self._path)
120
119
  shutil.copy(self.__src_root / 'verilog/source/build_binder.mk', self._path)
121
120
  shutil.copy(self.__src_root / 'verilog/source/ioutil.hh', self._path)
122
121
  shutil.copy(self.__src_root / 'verilog/source/binder_util.hh', self._path)
@@ -0,0 +1,3 @@
1
+ from .hgq2 import trace_model
2
+
3
+ __all__ = ['trace_model']
@@ -0,0 +1,3 @@
1
+ from .parser import trace_model
2
+
3
+ __all__ = ['trace_model']
@@ -1,12 +1,13 @@
1
1
  from collections.abc import Sequence
2
2
  from dataclasses import dataclass
3
- from typing import Any
3
+ from typing import Any, Literal, overload
4
4
 
5
5
  import keras
6
+ import numpy as np
6
7
  from keras import KerasTensor, Operation
7
8
 
8
- from ...trace import FixedVariableArray, HWConfig
9
- from ...trace.fixed_variable_array import FixedVariableArrayInput
9
+ from ...trace import FixedVariableArray, FixedVariableArrayInput, HWConfig, comb_trace
10
+ from ...trace.fixed_variable import FixedVariable
10
11
  from .replica import _registry
11
12
 
12
13
 
@@ -20,6 +21,8 @@ class OpObj:
20
21
 
21
22
 
22
23
  def parse_model(model: keras.Model):
24
+ if isinstance(model, keras.Sequential):
25
+ model = model._functional
23
26
  operators: dict[int, list[OpObj]] = {}
24
27
  for depth, nodes in model._nodes_by_depth.items():
25
28
  _oprs = []
@@ -49,9 +52,24 @@ def replace_tensors(tensor_map: dict[KerasTensor, FixedVariableArray], obj: Any)
49
52
  return obj
50
53
 
51
54
 
55
+ def _flatten_arr(args: Any) -> FixedVariableArray:
56
+ if isinstance(args, FixedVariableArray):
57
+ return np.ravel(args) # type: ignore
58
+ if isinstance(args, FixedVariable):
59
+ return FixedVariableArray(np.array([args]))
60
+ if not isinstance(args, Sequence):
61
+ return None # type: ignore
62
+ args = [_flatten_arr(a) for a in args]
63
+ args = [a for a in args if a is not None]
64
+ return np.concatenate(args) # type: ignore
65
+
66
+
52
67
  def _apply_nn(
53
- model: keras.Model, inputs: FixedVariableArray | Sequence[FixedVariableArray], verbose: bool = False
54
- ) -> tuple[FixedVariableArray, ...]:
68
+ model: keras.Model,
69
+ inputs: FixedVariableArray | Sequence[FixedVariableArray],
70
+ verbose: bool = False,
71
+ dump: bool = False,
72
+ ) -> tuple[FixedVariableArray, ...] | dict[str, FixedVariableArray]:
55
73
  """
56
74
  Apply a keras model to a fixed variable array or a sequence of fixed variable arrays.
57
75
 
@@ -73,6 +91,8 @@ def _apply_nn(
73
91
  assert len(model.inputs) == len(inputs), f'Model has {len(model.inputs)} inputs, got {len(inputs)}'
74
92
  tensor_map = {keras_tensor: da_tensor for keras_tensor, da_tensor in zip(model.inputs, inputs)}
75
93
 
94
+ _inputs = _flatten_arr(inputs)
95
+
76
96
  for ops in parse_model(model):
77
97
  for op in ops:
78
98
  assert all(t in tensor_map for t in op.requires)
@@ -82,24 +102,56 @@ def _apply_nn(
82
102
  continue
83
103
  mirror_op = _registry[op.operation.__class__](op.operation)
84
104
  if verbose:
85
- print(f'Processing operation {op.operation.name} ({op.operation.__class__.__name__})')
105
+ print(f'Processing operation {op.operation.name} ({op.operation.__class__.__name__})', end='')
86
106
  outputs = mirror_op(*args, **kwargs)
87
107
  for keras_tensor, da_tensor in zip(op.produces, outputs):
88
108
  tensor_map[keras_tensor] = da_tensor
109
+ if verbose:
110
+ cost = comb_trace(_inputs, _flatten_arr(outputs)).cost
111
+ print(f' cumcost: {cost}')
112
+
113
+ if not dump:
114
+ return tuple(tensor_map[keras_tensor] for keras_tensor in model.outputs)
115
+ else:
116
+ return {k.name: v for k, v in tensor_map.items()}
117
+
118
+
119
+ @overload
120
+ def trace_model( # type: ignore
121
+ model: keras.Model,
122
+ hwconf: HWConfig = HWConfig(1, -1, -1),
123
+ solver_options: dict[str, Any] | None = None,
124
+ verbose: bool = False,
125
+ inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
126
+ dump: Literal[False] = False,
127
+ ) -> tuple[FixedVariableArray, FixedVariableArray]: ...
128
+
89
129
 
90
- return tuple(tensor_map[keras_tensor] for keras_tensor in model.outputs)
130
+ @overload
131
+ def trace_model( # type: ignore
132
+ model: keras.Model,
133
+ hwconf: HWConfig = HWConfig(1, -1, -1),
134
+ solver_options: dict[str, Any] | None = None,
135
+ verbose: bool = False,
136
+ inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
137
+ dump: Literal[True] = False, # type: ignore
138
+ ) -> dict[str, FixedVariableArray]: ...
91
139
 
92
140
 
93
- def trace_model(
141
+ def trace_model( # type: ignore
94
142
  model: keras.Model,
95
143
  hwconf: HWConfig = HWConfig(1, -1, -1),
96
144
  solver_options: dict[str, Any] | None = None,
97
145
  verbose: bool = False,
98
146
  inputs: tuple[FixedVariableArray, ...] | None = None,
99
- ) -> tuple[tuple[FixedVariableArray, ...], tuple[FixedVariableArray, ...]]:
147
+ dump=False,
148
+ ):
100
149
  if inputs is None:
101
150
  inputs = tuple(
102
151
  FixedVariableArrayInput(inp.shape[1:], hwconf=hwconf, solver_options=solver_options) for inp in model.inputs
103
152
  )
104
- outputs = _apply_nn(model, inputs, verbose=verbose)
105
- return inputs, outputs
153
+ outputs = _apply_nn(model, inputs, verbose=verbose, dump=dump)
154
+ if not dump:
155
+ return _flatten_arr(inputs), _flatten_arr(outputs)
156
+ else:
157
+ return {k: _flatten_arr(v) for k, v in outputs.items()} # type: ignore