da4ml 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/_version.py +2 -2
- da4ml/cmvm/api.py +2 -6
- da4ml/cmvm/core/__init__.py +0 -1
- da4ml/cmvm/types.py +99 -19
- da4ml/codegen/__init__.py +5 -4
- da4ml/codegen/cpp/__init__.py +2 -1
- da4ml/codegen/cpp/cpp_codegen.py +58 -25
- da4ml/codegen/cpp/hls_model.py +252 -0
- da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/cpp/source/binder_util.hh +56 -0
- da4ml/codegen/cpp/source/build_binder.mk +24 -0
- da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
- da4ml/codegen/verilog/__init__.py +2 -3
- da4ml/codegen/verilog/comb.py +65 -24
- da4ml/codegen/verilog/io_wrapper.py +36 -141
- da4ml/codegen/verilog/pipeline.py +21 -3
- da4ml/codegen/verilog/source/binder_util.hh +72 -0
- da4ml/codegen/verilog/source/build_prj.tcl +0 -1
- da4ml/codegen/verilog/source/mux.v +58 -0
- da4ml/codegen/verilog/source/negative.v +28 -0
- da4ml/codegen/verilog/source/shift_adder.v +4 -1
- da4ml/codegen/verilog/source/template.xdc +3 -0
- da4ml/codegen/verilog/verilog_model.py +42 -15
- da4ml/converter/__init__.py +0 -0
- da4ml/converter/hgq2/parser.py +105 -0
- da4ml/converter/hgq2/replica.py +383 -0
- da4ml/trace/__init__.py +2 -2
- da4ml/trace/fixed_variable.py +177 -18
- da4ml/trace/fixed_variable_array.py +124 -9
- da4ml/trace/ops/__init__.py +22 -6
- da4ml/trace/ops/conv_utils.py +146 -14
- da4ml/trace/ops/einsum_utils.py +9 -6
- da4ml/trace/ops/reduce_utils.py +103 -0
- da4ml/trace/pipeline.py +36 -34
- da4ml/trace/tracer.py +37 -5
- da4ml-0.3.0.dist-info/METADATA +107 -0
- da4ml-0.3.0.dist-info/RECORD +64 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
- da4ml-0.2.0.dist-info/METADATA +0 -65
- da4ml-0.2.0.dist-info/RECORD +0 -39
- /da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
- {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/WHEEL +0 -0
- {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/top_level.txt +0 -0
da4ml/_version.py
CHANGED
da4ml/cmvm/api.py
CHANGED
```diff
@@ -140,10 +140,6 @@ def jit_solve(
             if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
                 decompose_dc -= 1
                 continue
-            if sum([op.cost for op in sol1.ops]) * 4 > sum([op.cost for op in sol0.ops]) and decompose_dc > 0:
-                # If the second stage is too expensive, the decomposition usually doesn't worth it
-                decompose_dc -= 1
-                continue
             break
         if max(latencies1) > latency_allowed:
             # When latency depends on the bw, may happen
@@ -158,8 +154,8 @@ def solve(
     method1: str = 'auto',
     hard_dc: int = -1,
     decompose_dc: int = -2,
-    qintervals:
-    latencies:
+    qintervals: list[QInterval] | None = None,
+    latencies: list[float] | None = None,
     adder_size: int = -1,
     carry_size: int = -1,
     search_all_decompose_dc: bool = True,
```
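For reference, a minimal sketch of how the two new optional keywords might be populated. It assumes `QInterval` is the `(min, max, step)` tuple from `da4ml.cmvm.types`; the remaining `solve` arguments are not shown in this hunk and are omitted here, and the numbers are illustrative only.

```python
from da4ml.cmvm.types import QInterval

# One QInterval(min, max, step) and one arrival latency per input (assumed convention).
# QInterval(-8.0, 7.75, 0.25) describes a signed fixed-point input with 2 fractional bits.
qintervals = [QInterval(-8.0, 7.75, 0.25) for _ in range(4)]
latencies = [0.0] * 4

# solve(..., qintervals=qintervals, latencies=latencies)  # other arguments omitted
```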
da4ml/cmvm/core/__init__.py
CHANGED
da4ml/cmvm/types.py
CHANGED
```diff
@@ -159,6 +159,8 @@ def _relu(v: 'T', i: int | None = None, f: int | None = None, inv: bool = False,
     from ..trace.fixed_variable import FixedVariable

     assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic relu'
+    if inv:
+        v = -v
     return v.relu(i, f, round_mode=round_mode)


@@ -289,15 +291,16 @@ class Solution(NamedTuple):
            The output data after applying the operations defined in the solution.

        """
+
+        from ..trace.fixed_variable import FixedVariable
+
        buf = np.empty(len(self.ops), dtype=object)
        inp = np.asarray(inp)

        inp_qint = [op.qint for op in self.ops if op.opcode == -1]
        if quantize:  # TRN and WRAP
            k, i, f = map(np.array, zip(*map(minimal_kif, inp_qint)))
-
-            _low, _high = -(2.0 ** (i + f)) * k, 2.0 ** (i + f) - 1
-            inp = eps * ((np.floor(inp / eps) - _low) % 2.0 ** (k + i + f) + _low)
+            inp = [_quantize(*x, round_mode='TRN') for x in zip(inp, k, i, f)]

        inp = inp * (2.0 ** np.array(self.inp_shift))
        for i, op in enumerate(self.ops):
@@ -320,39 +323,61 @@ class Solution(NamedTuple):
                    buf[i] = buf[op.id0] + bias
                case 5:
                    buf[i] = op.data * op.qint.step  # const definition
+                case 6 | -6:  # MSB Mux
+                    id_c = op.data & 0xFFFFFFFF
+                    k, v0, v1 = buf[id_c], buf[op.id0], buf[op.id1]
+                    shift = (op.data >> 32) & 0xFFFFFFFF
+                    shift = shift if shift < 0x80000000 else shift - 0x100000000
+                    if op.opcode == -6:
+                        v1 = -v1
+
+                    if isinstance(k, FixedVariable):
+                        buf[i] = k.msb_mux(v0, v1 * 2**shift)
+                    else:
+                        qint_k = self.ops[id_c].qint
+                        if qint_k.min < 0:
+                            buf[i] = v0 if k < 0 else v1 * 2.0**shift
+                        else:
+                            _k, _i, _f = _minimal_kif(qint_k)
+                            buf[i] = v0 if k >= 2.0 ** (_i - 1) else v1 * 2.0**shift
                case _:
                    raise ValueError(f'Unknown opcode {op.opcode} in {op}')

-        sf = 2.0 ** np.array(self.out_shifts)
+        sf = 2.0 ** np.array(self.out_shifts, dtype=np.float64)
        sign = np.where(self.out_negs, -1, 1)
-        out_idx = np.array(self.out_idxs)
+        out_idx = np.array(self.out_idxs, dtype=np.int32)
        mask = np.where(out_idx < 0, 0, 1)
        if debug:
+            operands = []
            for i, v in enumerate(buf):
                op = self.ops[i]
                match op.opcode:
                    case -1:
                        op_str = 'inp'
-                    case 0:
-
-
-
-
-                        op_str = f'relu(buf[{op.id0}])'
-                    case -
-
-
-                        op_str = f'quantize(buf[{op.id0}])'
-                    case -3:
-                        op_str = f'quantize(-buf[{op.id0}])'
+                    case 0 | 1:
+                        _sign = '-' if op.opcode == 1 else '+'
+                        op_str = f'buf[{op.id0}] {_sign} buf[{op.id1}]<<{op.data}'
+                    case 2 | -2:
+                        _sign = '' if op.opcode == 2 else '-'
+                        op_str = f'relu({_sign}buf[{op.id0}])'
+                    case 3 | -3:
+                        _sign = '' if op.opcode == 3 else '-'
+                        op_str = f'quantize({_sign}buf[{op.id0}])'
                    case 4:
                        op_str = f'buf[{op.id0}] + {op.data * op.qint.step}'
                    case 5:
                        op_str = f'const {op.data * op.qint.step}'
+                    case 6 | -6:
+                        _sign = '-' if op.opcode == -6 else ''
+                        op_str = f'msb(buf[{op.data}]) ? buf[{op.id0}] : {_sign}buf[{op.id1}]'
                    case _:
                        raise ValueError(f'Unknown opcode {op.opcode} in {op}')

-
+                result = f'|-> buf[{i}] = {v}'
+                operands.append((op_str, result))
+            max_len = max(len(op[0]) for op in operands)
+            for op_str, result in operands:
+                print(f'{op_str:<{max_len}} {result}')

        if dump:
            return buf
```
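The new opcode 6/-6 packs two fields into `op.data`: the id of the operation whose MSB selects the branch in the low 32 bits, and a signed shift in the high 32 bits. A small sketch of that unpacking, mirroring the lines above (the packed value here is made up):

```python
# Made-up packed value: shift = 3 in the high 32 bits, select id = 17 in the low 32 bits.
packed = (3 << 32) | 17

id_c = packed & 0xFFFFFFFF                                    # operation whose MSB drives the mux
shift = (packed >> 32) & 0xFFFFFFFF
shift = shift if shift < 0x80000000 else shift - 0x100000000  # reinterpret as signed 32-bit

assert (id_c, shift) == (17, 3)
```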
```diff
@@ -443,6 +468,61 @@ class Solution(NamedTuple):
            data = json.load(f)
        return cls.deserialize(data)

+    @property
+    def ref_count(self) -> np.ndarray:
+        """The number of references to the output elements in the solution."""
+        ref_count = np.zeros(len(self.ops), dtype=np.uint64)
+        for op in self.ops:
+            if op.opcode == -1:
+                continue
+            id0, id1 = op.id0, op.id1
+            if id0 != -1:
+                ref_count[id0] += 1
+            if id1 != -1:
+                ref_count[id1] += 1
+            if op.opcode in (6, -6):
+                # msb_mux operation
+                ref_count[op.data & 0xFFFFFFFF] += 1
+        for i in self.out_idxs:
+            if i < 0:
+                continue
+            ref_count[i] += 1
+        return ref_count
+
+    def to_binary(self):
+        n_in, n_out = self.shape
+        header_size_i32 = 2 + n_in + n_out * 3 + 1
+
+        header = np.concatenate(
+            [
+                [n_in, n_out, len(self.ops)],
+                self.inp_shift,
+                self.out_idxs,
+                self.out_shifts,
+                self.out_negs,
+            ],
+            axis=0,
+            dtype=np.int32,
+        )
+        assert len(header) == header_size_i32, f'Header size mismatch: {len(header)} != {header_size_i32}'
+        code = np.empty((len(self.ops), 8), dtype=np.int32)
+        for i, op in enumerate(self.ops):
+            buf = code[i]
+            buf[0] = op.opcode
+            buf[1] = op.id0
+            buf[2] = op.id1
+            buf[5:] = _minimal_kif(op.qint)
+            buf_i64 = buf[3:5].view(np.int64)
+            buf_i64[0] = op.data
+        data = np.concatenate([header, code.flatten()])
+        return data
+
+    def save_binary(self, path: str | Path):
+        """Dump the solution to a binary file."""
+        data = self.to_binary()
+        with open(path, 'wb') as f:
+            data.tofile(f)
+

 class CascadedSolution(NamedTuple):
    """A solution that implements cascaded matrix-vector multiplications through multiple CMVM stages.
@@ -561,7 +641,7 @@ class CascadedSolution(NamedTuple):
    @property
    def reg_bits(self):
        """The number of bits used for the register in the solution."""
-        bits =
+        bits = sum(map(sum, (_minimal_kif(qint) for qint in self.inp_qint)))
        for _sol in self.solutions:
            kifs = [_minimal_kif(qint) for qint in _sol.out_qint]
            _bits = sum(map(sum, kifs))
```
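The flat int32 layout written by `to_binary` (header: counts, input shifts, output indices/shifts/negations; then eight int32 per op with `data` stored as an int64 in slots 3:5) can be parsed back with plain NumPy. A hypothetical reader, not part of da4ml, that only mirrors the layout shown above:

```python
import numpy as np


def read_solution_binary(path):
    """Parse the header and per-op records written by Solution.to_binary()."""
    raw = np.fromfile(path, dtype=np.int32)
    n_in, n_out, n_ops = raw[:3]
    off = 3
    inp_shift = raw[off:off + n_in]; off += n_in
    out_idxs = raw[off:off + n_out]; off += n_out
    out_shifts = raw[off:off + n_out]; off += n_out
    out_negs = raw[off:off + n_out]; off += n_out
    # Each op record: opcode, id0, id1, data (int64 stored in slots 3:5), k, i, f.
    code = raw[off:].reshape(n_ops, 8)
    opcodes, id0, id1 = code[:, 0], code[:, 1], code[:, 2]
    data = code[:, 3:5].copy().view(np.int64).ravel()
    kif = code[:, 5:8]
    return dict(
        n_in=n_in, n_out=n_out, inp_shift=inp_shift, out_idxs=out_idxs,
        out_shifts=out_shifts, out_negs=out_negs,
        opcodes=opcodes, id0=id0, id1=id1, data=data, kif=kif,
    )
```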
da4ml/codegen/__init__.py
CHANGED
```diff
@@ -1,11 +1,12 @@
-from .cpp import cpp_logic_and_bridge_gen
-from .verilog import
+from .cpp import HLSModel, cpp_logic_and_bridge_gen
+from .verilog import VerilogModel, binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_logic_gen

 __all__ = [
     'cpp_logic_and_bridge_gen',
     'comb_logic_gen',
     'generate_io_wrapper',
-    'comb_binder_gen',
     'pipeline_logic_gen',
-    '
+    'binder_gen',
+    'HLSModel',
+    'VerilogModel',
 ]
```
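With the updated `__all__`, the 0.3.0 top-level codegen imports look like this (note that `comb_binder_gen` is gone and `binder_gen`, `HLSModel`, and `VerilogModel` are new):

```python
from da4ml.codegen import (
    HLSModel,
    VerilogModel,
    binder_gen,
    comb_logic_gen,
    cpp_logic_and_bridge_gen,
    generate_io_wrapper,
    pipeline_logic_gen,
)
```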
da4ml/codegen/cpp/__init__.py
CHANGED
da4ml/codegen/cpp/cpp_codegen.py
CHANGED
```diff
@@ -1,19 +1,19 @@
 from collections.abc import Callable

-from ...cmvm.types import
+from ...cmvm.types import QInterval, Solution, _minimal_kif
 from ...trace.fixed_variable import _const_f


-def kif_to_vitis_type(k: bool | int, i: int, f: int):
+def kif_to_vitis_type(k: bool | int = 1, i: int = 0, f: int = 0):
     if k == i == f == 0:
         f = 1
-    return f'ap_{"" if k else "u"}fixed<{k+i+f},{k+i}>'
+    return f'ap_{"" if k else "u"}fixed<{k + i + f},{k + i}>'


-def kif_to_hlslib_type(k: bool | int, i: int, f: int):
+def kif_to_hlslib_type(k: bool | int = 1, i: int = 0, f: int = 0):
     if k == i == f == 0:
         f = 1
-    return f'ac_fixed<{int(k)},{k+i+f},{k+i}>'
+    return f'ac_fixed<{int(k)},{k + i + f},{k + i}>'


 def get_typestr_fn(flavor: str):
@@ -27,13 +27,18 @@ def get_typestr_fn(flavor: str):
     return typestr_fn


-def ssa_gen(
-
+def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
+    ops = sol.ops
+    all_kifs = list(map(_minimal_kif, (op.qint for op in ops)))
     all_types = list(map(lambda x: typestr_fn(*x), all_kifs))

     lines = []
-
+    ref_count = sol.ref_count
     for i, op in enumerate(ops):
+        if ref_count[i] == 0:
+            # Skip unused ops
+            continue
+
         _type = all_types[i]

         ref0 = f'v{op.id0}'
@@ -42,12 +47,10 @@ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int
            case -1:
                # Input marker
                val = f'inp[{ops[op.id0].id0}]'
-
            case 0 | 1:
                # Common a+/-b<<shift op
                ref1 = f'bit_shift<{op.data}>(v{op.id1})' if op.data != 0 else f'v{op.id1}'
                val = f'{ref0} {"-" if op.opcode == 1 else "+"} {ref1}'
-
            case 2 | -2:
                if op.opcode == 2:  # relu(inp)
                    if ops[op.id0].qint.min < 0:
@@ -59,11 +62,9 @@ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int
                        val = f'{ref0} > 0 ? {_type}(0) : {_type}(-{ref0})'
                    else:
                        val = f'-{ref0}'
-
            case 3 | -3:
                # Explicit quantization op, done implicitly via assignment
                val = ref0 if op.opcode == 3 else f'-{ref0}'
-
            case 4:
                # Constant addition
                _number = op.data * op.qint.step
@@ -71,10 +72,20 @@ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int
                f = _const_f(mag)
                const_type_str = typestr_fn(*_minimal_kif(QInterval(mag, mag, 2.0**-f)))
                val = f'{ref0} {sign} {const_type_str}({mag})'
-
            case 5:
+                # Define constant
                _number = op.data * op.qint.step
                val = f'{_number}'
+            case 6 | -6:
+                # MSB Mux
+                id_c = op.data & 0xFFFFFFFF
+                bw_k = sum(all_kifs[id_c])
+                shift = (op.data >> 32) & 0xFFFFFFFF
+                shift = shift if shift < 0x80000000 else shift - 0x100000000
+                ref_k = f'v{id_c}[{bw_k - 1}]'
+                sign = '-' if op.opcode == -6 else ''
+                ref1 = f'v{op.id1}' if shift == 0 else f'bit_shift<{shift}>(v{op.id1})'
+                val = f'{ref_k} ? {_type}({ref0}) : {_type}({sign}{ref1})'

            case _:
                raise ValueError(f'Unsupported opcode: {op.opcode}')
@@ -103,6 +114,15 @@ def output_gen(sol: Solution, typestr_fn: Callable[[bool | int, int, int], str])
     return lines


+def get_io_types(sol: Solution, flavor: str):
+    typestr_fn = get_typestr_fn(flavor)
+    in_kif = map(max, zip(*map(_minimal_kif, sol.inp_qint)))
+    inp_type = typestr_fn(*in_kif)
+    out_kif = map(max, zip(*map(_minimal_kif, sol.out_qint)))
+    out_type = typestr_fn(*out_kif)
+    return inp_type, out_type
+
+
 def cpp_logic_and_bridge_gen(
     sol: Solution,
     fn_name: str,
@@ -113,36 +133,49 @@
     print_latency: bool = False,
 ):
     typestr_fn = get_typestr_fn(flavor)
-
-    inp_type = typestr_fn(*in_kif)
-    out_kif = map(max, zip(*map(_minimal_kif, sol.out_qint)))
-    out_type = typestr_fn(*out_kif)
+    inp_t, out_t = get_io_types(sol, flavor)

     n_in, n_out = sol.shape
     template_def = 'template <typename inp_t, typename out_t>'
     fn_signature = f'void {fn_name}(inp_t inp[{n_in}], out_t out[{n_out}])'
     pragmas = pragmas or []

-    ssa_lines = ssa_gen(sol
+    ssa_lines = ssa_gen(sol, print_latency=print_latency, typestr_fn=typestr_fn)
     output_lines = output_gen(sol, typestr_fn=typestr_fn)

     indent = ' ' * n_indent
     base_indent = indent * n_base_indent
     body_indent = '\n' + base_indent + indent
     code = f"""{base_indent}{template_def}
-{base_indent}{fn_signature} {{ // {
-{
+{base_indent}{fn_signature} {{ // {inp_t} -> {out_t}
+{base_indent + indent}{body_indent.join(pragmas)}
 {body_indent}{body_indent.join(ssa_lines)}
 {body_indent}{body_indent.join(output_lines)}
 {base_indent}}}
 """
-    bridge = f"""#include "
-#include "
+    bridge = f"""#include "binder_util.hh"
+#include "{fn_name}.hh"
+
+struct {fn_name}_config {{
+    static const size_t N_inp = {n_in};
+    static const size_t N_out = {n_out};
+    typedef {inp_t} inp_t;
+    typedef {out_t} out_t;
+    constexpr static auto f = {fn_name}<inp_t, out_t>;
+}};

 extern "C" {{
-
-
-
+
+bool openmp_enabled() {{
+    return _openmp;
+}}
+
+void inference_f64(double *inp, double *out, size_t size) {{
+    batch_inference<{fn_name}_config, double>(inp, out, size);
+}}
+
+void inference_f32(float *inp, float *out, size_t size) {{
+    batch_inference<{fn_name}_config, float>(inp, out, size);
 }}
 }}"""
     return code, bridge
```
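A quick sanity check of the type-string helpers, with expected outputs worked out from the f-strings above; the module path matches the import used by `hls_model.py`:

```python
from da4ml.codegen.cpp.cpp_codegen import kif_to_hlslib_type, kif_to_vitis_type

assert kif_to_vitis_type(1, 3, 4) == 'ap_fixed<8,4>'    # signed, 3 integer + 4 fractional bits
assert kif_to_vitis_type(0, 3, 4) == 'ap_ufixed<7,3>'   # unsigned variant
assert kif_to_hlslib_type(1, 3, 4) == 'ac_fixed<1,8,4>'
assert kif_to_vitis_type(0, 0, 0) == 'ap_ufixed<1,0>'   # zero-width inputs are padded to 1 fractional bit
```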
da4ml/codegen/cpp/hls_model.py
ADDED
@@ -0,0 +1,252 @@
```python
import ctypes
import os
import re
import shutil
import subprocess
import sys
from collections.abc import Sequence
from pathlib import Path
from typing import TypeVar
from uuid import uuid4

import numpy as np
from numpy.typing import NDArray

from da4ml.cmvm.types import Solution
from da4ml.codegen.cpp.cpp_codegen import cpp_logic_and_bridge_gen, get_io_types

from ... import codegen
from ...cmvm.types import _minimal_kif

T = TypeVar('T', bound=np.floating)


class HLSModel:
    def __init__(
        self,
        solution: Solution,
        prj_name: str,
        path: str | Path,
        flavor: str = 'vitis',
        print_latency: bool = True,
        part_name: str = 'xcvu13p-flga2577-2-e',
        pragma: Sequence[str] | None = None,
        clock_period: int = 5,
        clock_uncertainty: float = 0.1,
        io_delay_minmax: tuple[float, float] = (0.2, 0.4),
    ):
        self._solution = solution
        self._prj_name = prj_name
        self._path = Path(path)
        self._flavor = flavor.lower()
        assert self._flavor in ('vitis', 'hlslib'), f'Unsupported HLS flavor: {self._flavor}'
        self._print_latency = print_latency
        self._part_name = part_name
        self._clock_period = clock_period
        self._clock_uncertainty = clock_uncertainty
        self._io_delay_minmax = io_delay_minmax
        self.__src_root = Path(codegen.__file__).parent
        self._lib = None
        self._uuid = None

        if pragma is None:
            if self._flavor == 'vitis':
                self._pragma = (
                    '#pragma HLS ARRAY_PARTITION variable=inp complete',
                    '#pragma HLS ARRAY_PARTITION variable=out complete',
                    '#pragma HLS PIPELINE II=1',
                )
            else:
                self._pragma = ()
        else:
            self._pragma = tuple(pragma)

    def write(self):
        if not self._path.exists():
            self._path.mkdir(parents=True, exist_ok=True)
        template_def, bridge = cpp_logic_and_bridge_gen(
            self._solution,
            self._prj_name,
            self._flavor,
            ['#pragma HLS INLINE'],
            4,
            0,
            self._print_latency,
        )

        headers = ['#pragma once', '#include "bitshift.hh"']

        inp_type, out_type = get_io_types(self._solution, self._flavor)
        n_in, n_out = len(self._solution.inp_qint), len(self._solution.out_qint)
        template_signature = (
            f'template <typename inp_t, typename out_t>\nvoid {self._prj_name}(inp_t inp[{n_in}], out_t out[{n_out}]);'
        )
        fn_signature = f'void {self._prj_name}_fn({inp_type} inp[{n_in}], {out_type} out[{n_out}])'

        with open(self._path / f'{self._prj_name}.hh', 'w') as f:
            f.write('\n'.join(headers) + '\n\n')
            f.write(f'{template_signature}\n\n{fn_signature};\n')

        pragma_str = '\n'.join(self._pragma)
        cpp_def = f"""
#include "{self._prj_name}.hh"

{template_def}

{fn_signature} {{
{pragma_str}
{self._prj_name}<{inp_type}, {out_type}>(inp, out);
}}
"""
        with open(self._path / f'{self._prj_name}.cc', 'w') as f:
            f.write(cpp_def)

        with open(self._path / f'{self._prj_name}_bridge.cc', 'w') as f:
            f.write(bridge)

        shutil.copy(self.__src_root / 'cpp/source/binder_util.hh', self._path)
        shutil.copy(self.__src_root / f'cpp/source/{self._flavor}_bitshift.hh', self._path / 'bitshift.hh')
        shutil.copy(self.__src_root / 'cpp/source/build_binder.mk', self._path)
        if self._flavor == 'vitis':
            shutil.copytree(self.__src_root / 'cpp/source/ap_types', self._path / 'ap_types', dirs_exist_ok=True)
        else:
            pass

        self._solution.save(self._path / 'project.json')

    def _compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
        """Same as compile, but will not write to the library

        Parameters
        ----------
        verbose : bool, optional
            Verbose output, by default False
        openmp : bool, optional
            Enable openmp, by default True
        o3 : bool | None, optional
            Turn on -O3 flag, by default False
        clean : bool, optional
            Remove obsolete shared object files, by default True

        Raises
        ------
        RuntimeError
            If compilation fails
        """

        self._uuid = str(uuid4())
        args = ['make', '-f', 'build_binder.mk']
        env = os.environ.copy()
        env['PRJ_NAME'] = self._prj_name
        env['STAMP'] = self._uuid
        env['EXTRA_CXXFLAGS'] = '-fopenmp' if openmp else ''
        if o3:
            args.append('fast')

        if clean:
            m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
            for p in self._path.iterdir():
                if not p.is_dir() and m.match(p.name):
                    p.unlink()

        try:
            r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
        except subprocess.CalledProcessError as e:
            print(e.stderr.decode(), file=sys.stderr)
            print(e.stdout.decode(), file=sys.stdout)
            raise RuntimeError('Compilation failed!!') from e
        if r.returncode != 0:
            print(r.stderr.decode(), file=sys.stderr)
            print(r.stdout.decode(), file=sys.stderr)
            raise RuntimeError('Compilation failed!!')

        self._load_lib(self._uuid)

    def _load_lib(self, uuid: str | None = None):
        uuid = uuid if uuid is not None else self._uuid
        self._uuid = uuid
        lib_path = self._path / f'lib{self._prj_name}_{uuid}.so'
        if not lib_path.exists():
            raise RuntimeError(f'Library {lib_path} does not exist')
        self._lib = ctypes.CDLL(str(lib_path))

    def compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
        """Compile the model to a shared object file

        Parameters
        ----------
        verbose : bool, optional
            Verbose output, by default False
        openmp : bool, optional
            Enable openmp, by default True
        o3 : bool | None, optional
            Turn on -O3 flag, by default False
        clean : bool, optional
            Remove obsolete shared object files, by default True

        Raises
        ------
        RuntimeError
            If compilation fails
        """
        self.write()
        self._compile(verbose, openmp, o3, clean)

    def predict(self, data: NDArray[T]) -> NDArray[T]:
        """Run the model on the input data.

        Parameters
        ----------
        data : NDArray[np.floating]
            Input data to the model. The shape is ignored, and the number of samples is
            determined by the size of the data.

        Returns
        -------
        NDArray[np.floating]
            Output of the model in shape (n_samples, output_size).
        """
        assert self._lib is not None, 'Library not loaded, call .compile() first.'
        inp_size, out_size = self._solution.shape

        dtype = data.dtype
        if dtype not in (np.float32, np.float64):
            raise TypeError(f'Unsupported input data type: {dtype}. Expected float32 or float64.')
        c_dtype = ctypes.c_float if dtype == np.float32 else ctypes.c_double

        assert data.size % inp_size == 0, f'Input size {data.size} is not divisible by {inp_size}'
        n_sample = data.size // inp_size

        inp_data = np.ascontiguousarray(data)
        out_data = np.empty(n_sample * out_size, dtype=dtype)

        inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(c_dtype))
        out_buf = out_data.ctypes.data_as(ctypes.POINTER(c_dtype))
        if dtype == np.float32:
            self._lib.inference_f32(inp_buf, out_buf, n_sample)
        else:
            self._lib.inference_f64(inp_buf, out_buf, n_sample)

        return out_data.reshape(n_sample, out_size)  # type: ignore

    def __repr__(self):
        inp_size, out_size = self._solution.shape
        cost = round(self._solution.cost)
        inp_kifs = tuple(zip(*map(_minimal_kif, self._solution.inp_qint)))
        out_kifs = tuple(zip(*map(_minimal_kif, self._solution.out_qint)))
        in_bits, out_bits = np.sum(inp_kifs), np.sum(out_kifs)

        spec = f"""Top Function: {self._prj_name}\n====================
{inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)
combinational @ delay={self._solution.latency}
Estimated cost: {cost} LUTs"""

        is_compiled = self._lib is not None
        if is_compiled:
            assert self._uuid is not None
            openmp = 'with OpenMP' if self._lib.openmp_enabled() else ''  # type: ignore
            spec += f'\nEmulator is compiled {openmp} ({self._uuid[-12:]})'
        else:
            spec += '\nEmulator is **not compiled**'
        return spec
```
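A minimal usage sketch of the new `HLSModel` wrapper, assuming `sol` is a `da4ml.cmvm.types.Solution` produced elsewhere by the CMVM solver, and that a C++ compiler and `make` are available for the ctypes emulator build:

```python
import numpy as np
from da4ml.codegen import HLSModel

# sol: a Solution obtained from da4ml's CMVM solver (not constructed here).
model = HLSModel(sol, prj_name='mvm', path='build/mvm', flavor='vitis')
model.compile(verbose=False, openmp=True)   # writes the HLS project, then builds the shared-object emulator
print(model)                                # reports I/O widths, delay, LUT estimate, and compile status

x = np.random.rand(128, sol.shape[0]).astype(np.float32)
y = model.predict(x)                        # shape (128, sol.shape[1])
```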