da4ml 0.3.0.post1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of da4ml might be problematic.
- da4ml/_version.py +16 -3
- da4ml/cmvm/types.py +12 -2
- da4ml/codegen/cpp/cpp_codegen.py +4 -1
- da4ml/codegen/verilog/comb.py +19 -11
- da4ml/codegen/verilog/source/binder_util.hh +8 -6
- da4ml/codegen/verilog/source/build_prj.tcl +6 -8
- da4ml/codegen/verilog/source/ioutil.hh +2 -1
- da4ml/codegen/verilog/source/multiplier.v +37 -0
- da4ml/codegen/verilog/verilog_model.py +4 -5
- da4ml/converter/__init__.py +3 -0
- da4ml/converter/hgq2/__init__.py +3 -0
- da4ml/converter/hgq2/parser.py +60 -10
- da4ml/converter/hgq2/replica.py +125 -35
- da4ml/trace/fixed_variable.py +133 -20
- da4ml/trace/fixed_variable_array.py +55 -7
- da4ml/trace/ops/__init__.py +4 -4
- da4ml/trace/ops/einsum_utils.py +5 -2
- da4ml/trace/ops/reduce_utils.py +4 -2
- da4ml/trace/pipeline.py +6 -4
- da4ml/trace/tracer.py +27 -13
- da4ml-0.3.2.dist-info/METADATA +66 -0
- {da4ml-0.3.0.post1.dist-info → da4ml-0.3.2.dist-info}/RECORD +25 -23
- da4ml-0.3.0.post1.dist-info/METADATA +0 -107
- {da4ml-0.3.0.post1.dist-info → da4ml-0.3.2.dist-info}/WHEEL +0 -0
- {da4ml-0.3.0.post1.dist-info → da4ml-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.3.0.post1.dist-info → da4ml-0.3.2.dist-info}/top_level.txt +0 -0
da4ml/_version.py
CHANGED
@@ -1,7 +1,14 @@
 # file generated by setuptools-scm
 # don't change, don't track in version control

-__all__ = [
+__all__ = [
+    "__version__",
+    "__version_tuple__",
+    "version",
+    "version_tuple",
+    "__commit_id__",
+    "commit_id",
+]

 TYPE_CHECKING = False
 if TYPE_CHECKING:
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
     from typing import Union

     VERSION_TUPLE = Tuple[Union[int, str], ...]
+    COMMIT_ID = Union[str, None]
 else:
     VERSION_TUPLE = object
+    COMMIT_ID = object

 version: str
 __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
+commit_id: COMMIT_ID
+__commit_id__: COMMIT_ID

-__version__ = version = '0.3.
-__version_tuple__ = version_tuple = (0, 3,
+__version__ = version = '0.3.2'
+__version_tuple__ = version_tuple = (0, 3, 2)
+
+__commit_id__ = commit_id = None
da4ml/cmvm/types.py
CHANGED
@@ -321,7 +321,7 @@ class Solution(NamedTuple):
             case 4:  # const addition
                 bias = op.data * op.qint.step
                 buf[i] = buf[op.id0] + bias
-            case 5:
+            case 5:  # const definition
                 buf[i] = op.data * op.qint.step  # const definition
             case 6 | -6:  # MSB Mux
                 id_c = op.data & 0xFFFFFFFF
@@ -340,6 +340,9 @@
                 else:
                     _k, _i, _f = _minimal_kif(qint_k)
                     buf[i] = v0 if k >= 2.0 ** (_i - 1) else v1 * 2.0**shift
+            case 7:
+                v0, v1 = buf[op.id0], buf[op.id1]
+                buf[i] = v0 * v1
             case _:
                 raise ValueError(f'Unknown opcode {op.opcode} in {op}')

@@ -370,6 +373,8 @@
             case 6 | -6:
                 _sign = '-' if op.opcode == -6 else ''
                 op_str = f'msb(buf[{op.data}]) ? buf[{op.id0}] : {_sign}buf[{op.id1}]'
+            case 7:
+                op_str = f'buf[{op.id0}] * buf[{op.id1}]'
             case _:
                 raise ValueError(f'Unknown opcode {op.opcode} in {op}')

@@ -436,7 +441,12 @@
     @property
     def inp_qint(self):
         """Quantization intervals of the input elements."""
-
+        qints = [QInterval(0.0, 0.0, 1.0) for _ in range(self.shape[0])]
+        for op in self.ops:
+            if op.opcode != -1:
+                continue
+            qints[op.id0] = op.qint
+        return qints

     def save(self, path: str | Path):
         """Save the solution to a file."""
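Note: opcode 7 is new in this release and extends the Solution interpreter with element-wise multiplication. A minimal sketch of its semantics, using plain indices rather than the package's Op records:

def eval_mul(buf: list[float], id0: int, id1: int) -> float:
    # Mirrors the new `case 7` branch above: the product of two
    # previously computed buffer entries becomes the current value.
    return buf[id0] * buf[id1]

assert eval_mul([1.5, -2.0], 0, 1) == -3.0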
da4ml/codegen/cpp/cpp_codegen.py
CHANGED
@@ -86,7 +86,10 @@ def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int
                 sign = '-' if op.opcode == -6 else ''
                 ref1 = f'v{op.id1}' if shift == 0 else f'bit_shift<{shift}>(v{op.id1})'
                 val = f'{ref_k} ? {_type}({ref0}) : {_type}({sign}{ref1})'
-
+            case 7:
+                # Multiplication
+                ref1 = f'v{op.id1}'
+                val = f'{ref0} * {ref1}'
             case _:
                 raise ValueError(f'Unsupported opcode: {op.opcode}')

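For the same opcode, the C++ emitter now builds a plain product expression. A rough sketch of the string construction (operand indices here are hypothetical; in the real code they come from op.id0 and op.id1, and ref0 is defined by the surrounding branch logic):

id0, id1 = 3, 5  # hypothetical operand indices
ref0, ref1 = f'v{id0}', f'v{id1}'
val = f'{ref0} * {ref1}'
print(val)  # -> "v3 * v5", assigned to the destination SSA variable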
da4ml/codegen/verilog/comb.py
CHANGED
@@ -9,7 +9,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
     ops = sol.ops
     kifs = list(map(_minimal_kif, (op.qint for op in ops)))
     widths = list(map(sum, kifs))
-    inp_kifs = [_minimal_kif(
+    inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
     inp_widths = list(map(sum, inp_kifs))
     _inp_widths = np.cumsum([0] + inp_widths)
     inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
@@ -31,6 +31,17 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
             case -1:  # Input marker
                 i0, i1 = inp_idxs[op.id0]
                 line = f'{_def} assign {v} = inp[{i0}:{i1}];'
+
+            case 0 | 1:  # Common a+/-b<<shift oprs
+                p0, p1 = kifs[op.id0], kifs[op.id1]  # precision -> keep_neg, integers (no sign), fractional
+
+                bw0, bw1 = widths[op.id0], widths[op.id1]  # width
+                s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
+                shift = op.data + f0 - f1
+                v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
+
+                line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {op.opcode}) op_{i} ({v0}, {v1}, {v});'
+
             case 2 | -2:  # ReLU
                 lsb_bias = kifs[op.id0][2] - kifs[i][2]
                 i0, i1 = bw + lsb_bias - 1, lsb_bias
@@ -93,16 +104,6 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
                 num = 2**bw + num
                 line = f"{_def} assign {v} = '{bin(num)[1:]};"

-            case 0 | 1:  # Common a+/-b<<shift oprs
-                p0, p1 = kifs[op.id0], kifs[op.id1]  # precision -> keep_neg, integers (no sign), fractional
-
-                bw0, bw1 = widths[op.id0], widths[op.id1]  # width
-                s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
-                shift = op.data + f0 - f1
-                v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
-
-                line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {op.opcode}) op_{i} ({v0}, {v1}, {v});'
-
             case 6 | -6:  # MSB Muxing
                 k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
                 p0, p1 = kifs[a], kifs[b]
@@ -115,6 +116,13 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
                 vk, v0, v1 = f'v{k}[{bwk - 1}]', f'v{a}[{bw0 - 1}:0]', f'v{b}[{bw1 - 1}:0]'

                 line = f'{_def} mux #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {inv}) op_{i} ({vk}, {v0}, {v1}, {v});'
+            case 7:  # Multiplication
+                bw0, bw1 = widths[op.id0], widths[op.id1]  # width
+                s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
+                v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
+
+                line = f'{_def} multiplier #({bw0}, {bw1}, {s0}, {s1}, {bw}) op_{i} ({v0}, {v1}, {v});'
+
             case _:
                 raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')

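Opcode 7 is lowered to an instantiation of the new multiplier module (see multiplier.v below). A sketch of the emitted line with hypothetical widths and signedness flags; the `wire` declaration stands in for whatever the `_def` prefix expands to in the real generator:

i, bw0, bw1, s0, s1, bw = 7, 8, 6, 1, 0, 14  # hypothetical example values
v0, v1, v = f'v3[{bw0 - 1}:0]', f'v5[{bw1 - 1}:0]', f'v{i}'
line = f'wire [{bw - 1}:0] {v}; multiplier #({bw0}, {bw1}, {s0}, {s1}, {bw}) op_{i} ({v0}, {v1}, {v});'
print(line)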
da4ml/codegen/verilog/source/binder_util.hh
CHANGED
@@ -10,7 +10,7 @@ constexpr bool _openmp = false;

 template <typename CONFIG_T>
 std::enable_if_t<CONFIG_T::II != 0> _inference(int32_t *c_inp, int32_t *c_out, size_t n_samples) {
-
+    auto dut = std::make_unique<typename CONFIG_T::dut_t>();

     size_t clk_req = n_samples * CONFIG_T::II + CONFIG_T::latency + 1;

@@ -18,14 +18,18 @@ std::enable_if_t<CONFIG_T::II != 0> _inference(int32_t *c_inp, int32_t *c_out, s
     size_t t_out = t_inp - CONFIG_T::latency - 1;

     if (t_inp < n_samples * CONFIG_T::II && t_inp % CONFIG_T::II == 0) {
-        write_input<CONFIG_T::N_inp, CONFIG_T::max_inp_bw>(
+        write_input<CONFIG_T::N_inp, CONFIG_T::max_inp_bw>(
+            dut->inp, &c_inp[t_inp / CONFIG_T::II * CONFIG_T::N_inp]
+        );
     }

     dut->clk = 0;
     dut->eval();

     if (t_inp > CONFIG_T::latency && t_out % CONFIG_T::II == 0) {
-        read_output<CONFIG_T::N_out, CONFIG_T::max_out_bw>(
+        read_output<CONFIG_T::N_out, CONFIG_T::max_out_bw>(
+            dut->out, &c_out[t_out / CONFIG_T::II * CONFIG_T::N_out]
+        );
     }

     dut->clk = 1;
@@ -33,12 +37,11 @@ std::enable_if_t<CONFIG_T::II != 0> _inference(int32_t *c_inp, int32_t *c_out, s
     }

     dut->final();
-    delete dut;
 }

 template <typename CONFIG_T>
 std::enable_if_t<CONFIG_T::II == 0> _inference(int32_t *c_inp, int32_t *c_out, size_t n_samples) {
-
+    auto dut = std::make_unique<typename CONFIG_T::dut_t>();

     for (size_t i = 0; i < n_samples; ++i) {
         write_input<CONFIG_T::N_inp, CONFIG_T::max_inp_bw>(dut->inp, &c_inp[i * CONFIG_T::N_inp]);
@@ -47,7 +50,6 @@ std::enable_if_t<CONFIG_T::II == 0> _inference(int32_t *c_inp, int32_t *c_out, s
     }

     dut->final();
-    delete dut;
 }

 template <typename CONFIG_T> void batch_inference(int32_t *c_inp, int32_t *c_out, size_t n_samples) {
da4ml/codegen/verilog/source/build_prj.tcl
CHANGED
@@ -1,7 +1,7 @@
 set project_name "${PROJECT_NAME}"
 set device "${DEVICE}"

-set top_module "${project_name}
+set top_module "${project_name}"
 set output_dir "./output_${project_name}"

 create_project $project_name "${output_dir}/$project_name" -force -part $device
@@ -9,9 +9,10 @@ create_project $project_name "${output_dir}/$project_name" -force -part $device
 set_property TARGET_LANGUAGE Verilog [current_project]
 set_property DEFAULT_LIB work [current_project]

-read_verilog "${project_name}_wrapper.v"
 read_verilog "${project_name}.v"
 read_verilog "shift_adder.v"
+read_verilog "negative.v"
+read_verilog "mux.v"
 foreach file [glob -nocomplain "${project_name}_stage*.v"] {
     read_verilog $file
 }
@@ -25,8 +26,7 @@ file mkdir "${output_dir}/reports"

 # synth
 synth_design -top $top_module -mode out_of_context -retiming \
-    -flatten_hierarchy
-    -directive AlternateRoutability
+    -flatten_hierarchy full -resource_sharing auto

 write_checkpoint -force "${output_dir}/${project_name}_post_synth.dcp"

@@ -34,15 +34,13 @@ report_timing_summary -file "${output_dir}/reports/${project_name}_post_synth_ti
 report_power -file "${output_dir}/reports/${project_name}_post_synth_power.rpt"
 report_utilization -file "${output_dir}/reports/${project_name}_post_synth_util.rpt"

-#
-
-opt_design -directive ExploreSequentialArea
+# opt_design -directive ExploreSequentialArea
 opt_design -directive ExploreWithRemap

 report_design_analysis -congestion -file "${output_dir}/reports/${project_name}_post_opt_congestion.rpt"

 # place
-place_design -directive
+place_design -directive SSI_HighUtilSLRs -fanout_opt
 report_design_analysis -congestion -file "${output_dir}/reports/${project_name}_post_place_congestion_initial.rpt"

 phys_opt_design -directive AggressiveExplore
da4ml/codegen/verilog/source/ioutil.hh
CHANGED
@@ -68,7 +68,8 @@ template <size_t bw, size_t N_out> std::vector<int32_t> bitunpack(const std::vec
 }

 template <size_t bits_in, typename inp_buf_t>
-std::enable_if_t<std::is_integral_v<inp_buf_t>, void>
+std::enable_if_t<std::is_integral_v<inp_buf_t>, void>
+_write_input(inp_buf_t &inp_buf, const std::vector<int32_t> &input) {
     assert(input.size() == (bits_in + 31) / 32);
     inp_buf = input[0] & 0xFFFFFFFF;
     if (bits_in > 32) {
da4ml/codegen/verilog/source/multiplier.v
ADDED
@@ -0,0 +1,37 @@
+`timescale 1ns / 1ps
+
+
+module multiplier #(
+    parameter BW_INPUT0 = 32,
+    parameter BW_INPUT1 = 32,
+    parameter SIGNED0 = 0,
+    parameter SIGNED1 = 0,
+    parameter BW_OUT = 32
+) (
+    input [BW_INPUT0-1:0] in0,
+    input [BW_INPUT1-1:0] in1,
+    output [BW_OUT-1:0] out
+);
+
+    localparam BW_BUF = BW_INPUT0 + BW_INPUT1;
+
+    // verilator lint_off UNUSEDSIGNAL
+    wire [BW_BUF - 1:0] buffer;
+    // verilator lint_on UNUSEDSIGNAL
+
+    generate
+        if (SIGNED0 == 1 && SIGNED1 == 1) begin : signed_signed
+            assign buffer[BW_BUF-1:0] = $signed(in0) * $signed(in1);
+        end else if (SIGNED0 == 1 && SIGNED1 == 0) begin : signed_unsigned
+            assign buffer[BW_BUF-1:0] = $signed(in0) * $signed({{1'b0,in1}});
+            // assign buffer[BW_BUF-1] = in0[BW_INPUT0-1];
+        end else if (SIGNED0 == 0 && SIGNED1 == 1) begin : unsigned_signed
+            assign buffer[BW_BUF-1:0] = $signed({{1'b0,in0}}) * $signed(in1);
+            // assign buffer[BW_BUF-1] = in1[BW_INPUT1-1];
+        end else begin : unsigned_unsigned
+            assign buffer[BW_BUF-1:0] = in0 * in1;
+        end
+    endgenerate
+
+    assign out[BW_OUT-1:0] = buffer[BW_OUT-1:0];
+endmodule
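The module forms the full-width product (BW_INPUT0 + BW_INPUT1 bits, treating each operand as signed or unsigned per SIGNED0/SIGNED1) and then keeps only the BW_OUT least-significant bits. A small Python model of that assumed truncation behavior:

def multiplier_model(in0: int, in1: int, bw_out: int) -> int:
    # Full-precision product, then keep the BW_OUT LSBs, matching
    # `assign out[BW_OUT-1:0] = buffer[BW_OUT-1:0];`.
    return (in0 * in1) & ((1 << bw_out) - 1)

assert multiplier_model(-3, 5, 8) == 0xF1  # -15 in 8-bit two's complement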
da4ml/codegen/verilog/verilog_model.py
CHANGED
@@ -28,10 +28,10 @@ class VerilogModel:
         solution: Solution | CascadedSolution,
         prj_name: str,
         path: str | Path,
-        latency_cutoff:
+        latency_cutoff: float = -1,
         print_latency: bool = True,
         part_name: str = 'xcvu13p-flga2577-2-e',
-        clock_period:
+        clock_period: float = 5,
         clock_uncertainty: float = 0.1,
         io_delay_minmax: tuple[float, float] = (0.2, 0.4),
         register_layers: int = 1,
@@ -114,9 +114,8 @@ class VerilogModel:
             f.write(binder)

         # Common resource copy
-
-
-        shutil.copy(self.__src_root / 'verilog/source/negative.v', self._path)
+        for fname in self.__src_root.glob('verilog/source/*.v'):
+            shutil.copy(fname, self._path)
         shutil.copy(self.__src_root / 'verilog/source/build_binder.mk', self._path)
         shutil.copy(self.__src_root / 'verilog/source/ioutil.hh', self._path)
         shutil.copy(self.__src_root / 'verilog/source/binder_util.hh', self._path)
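A hypothetical instantiation relying only on the signature shown above (the import path and the `solution` placeholder are assumptions, not confirmed by this diff):

from da4ml.codegen.verilog.verilog_model import VerilogModel  # assumed import path

model = VerilogModel(
    solution,            # a Solution or CascadedSolution from an earlier trace
    prj_name='my_model',
    path='build/',
    latency_cutoff=-1,   # new default: no latency cutoff
    clock_period=5,      # new default: 5 ns target clock
)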
da4ml/converter/__init__.py
CHANGED
da4ml/converter/hgq2/parser.py
CHANGED
@@ -1,11 +1,12 @@
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Literal, overload

 import keras
+import numpy as np
 from keras import KerasTensor, Operation

-from ...trace import FixedVariableArray, HWConfig
+from ...trace import FixedVariableArray, HWConfig, comb_trace
 from ...trace.fixed_variable_array import FixedVariableArrayInput
 from .replica import _registry

@@ -20,6 +21,8 @@ class OpObj:


 def parse_model(model: keras.Model):
+    if isinstance(model, keras.Sequential):
+        model = model._functional
     operators: dict[int, list[OpObj]] = {}
     for depth, nodes in model._nodes_by_depth.items():
         _oprs = []
@@ -49,9 +52,22 @@ def replace_tensors(tensor_map: dict[KerasTensor, FixedVariableArray], obj: Any)
     return obj


+def _flatten_arr(args: Any) -> FixedVariableArray:
+    if isinstance(args, FixedVariableArray):
+        return np.ravel(args)  # type: ignore
+    if not isinstance(args, Sequence):
+        return None  # type: ignore
+    args = [_flatten_arr(a) for a in args]
+    args = [a for a in args if a is not None]
+    return np.concatenate(args)  # type: ignore
+
+
 def _apply_nn(
-    model: keras.Model,
-
+    model: keras.Model,
+    inputs: FixedVariableArray | Sequence[FixedVariableArray],
+    verbose: bool = False,
+    dump: bool = False,
+) -> tuple[FixedVariableArray, ...] | dict[str, FixedVariableArray]:
     """
     Apply a keras model to a fixed variable array or a sequence of fixed variable arrays.

@@ -73,6 +89,8 @@ def _apply_nn(
     assert len(model.inputs) == len(inputs), f'Model has {len(model.inputs)} inputs, got {len(inputs)}'
     tensor_map = {keras_tensor: da_tensor for keras_tensor, da_tensor in zip(model.inputs, inputs)}

+    _inputs = _flatten_arr(inputs)
+
     for ops in parse_model(model):
         for op in ops:
             assert all(t in tensor_map for t in op.requires)
@@ -82,24 +100,56 @@ def _apply_nn(
                 continue
             mirror_op = _registry[op.operation.__class__](op.operation)
             if verbose:
-                print(f'Processing operation {op.operation.name} ({op.operation.__class__.__name__})')
+                print(f'Processing operation {op.operation.name} ({op.operation.__class__.__name__})', end='')
             outputs = mirror_op(*args, **kwargs)
             for keras_tensor, da_tensor in zip(op.produces, outputs):
                 tensor_map[keras_tensor] = da_tensor
+            if verbose:
+                cost = comb_trace(_inputs, _flatten_arr(outputs)).cost
+                print(f' cumcost: {cost}')
+
+    if not dump:
+        return tuple(tensor_map[keras_tensor] for keras_tensor in model.outputs)
+    else:
+        return {k.name: v for k, v in tensor_map.items()}
+
+
+@overload
+def trace_model(  # type: ignore
+    model: keras.Model,
+    hwconf: HWConfig = HWConfig(1, -1, -1),
+    solver_options: dict[str, Any] | None = None,
+    verbose: bool = False,
+    inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
+    dump: Literal[False] = False,
+) -> tuple[FixedVariableArray, FixedVariableArray]: ...
+

-
+@overload
+def trace_model(  # type: ignore
+    model: keras.Model,
+    hwconf: HWConfig = HWConfig(1, -1, -1),
+    solver_options: dict[str, Any] | None = None,
+    verbose: bool = False,
+    inputs: tuple[FixedVariableArray, ...] | FixedVariableArray | None = None,
+    dump: Literal[True] = False,  # type: ignore
+) -> dict[str, FixedVariableArray]: ...


-def trace_model(
+def trace_model(  # type: ignore
     model: keras.Model,
     hwconf: HWConfig = HWConfig(1, -1, -1),
     solver_options: dict[str, Any] | None = None,
     verbose: bool = False,
     inputs: tuple[FixedVariableArray, ...] | None = None,
-
+    dump=False,
+):
     if inputs is None:
         inputs = tuple(
             FixedVariableArrayInput(inp.shape[1:], hwconf=hwconf, solver_options=solver_options) for inp in model.inputs
         )
-    outputs = _apply_nn(model, inputs, verbose=verbose)
-
+    outputs = _apply_nn(model, inputs, verbose=verbose, dump=dump)
+    if not dump:
+        return _flatten_arr(inputs), _flatten_arr(outputs)
+    else:
+        return {k: _flatten_arr(v) for k, v in outputs.items()}  # type: ignore
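A hypothetical use of the two new overloads (assuming `model` is a keras/HGQ2 model the converter supports; the import path mirrors the file shown above):

from da4ml.converter.hgq2.parser import trace_model  # assumed import path

inp, out = trace_model(model)            # dump=False: flattened (inputs, outputs) pair
tensors = trace_model(model, dump=True)  # dump=True: {tensor_name: flattened array}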