da4ml 0.2.1__py3-none-any.whl → 0.3.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml might be problematic. Click here for more details.
- da4ml/_version.py +2 -2
- da4ml/cmvm/types.py +95 -15
- da4ml/codegen/__init__.py +5 -4
- da4ml/codegen/cpp/__init__.py +2 -1
- da4ml/codegen/cpp/cpp_codegen.py +56 -23
- da4ml/codegen/cpp/hls_model.py +252 -0
- da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/cpp/source/binder_util.hh +56 -0
- da4ml/codegen/cpp/source/build_binder.mk +24 -0
- da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
- da4ml/codegen/verilog/__init__.py +2 -3
- da4ml/codegen/verilog/comb.py +65 -24
- da4ml/codegen/verilog/io_wrapper.py +36 -141
- da4ml/codegen/verilog/source/binder_util.hh +72 -0
- da4ml/codegen/verilog/source/mux.v +58 -0
- da4ml/codegen/verilog/source/negative.v +28 -0
- da4ml/codegen/verilog/source/shift_adder.v +4 -1
- da4ml/codegen/verilog/source/template.xdc +3 -0
- da4ml/codegen/verilog/verilog_model.py +36 -12
- da4ml/converter/__init__.py +0 -0
- da4ml/converter/hgq2/parser.py +105 -0
- da4ml/converter/hgq2/replica.py +383 -0
- da4ml/trace/__init__.py +2 -2
- da4ml/trace/fixed_variable.py +175 -16
- da4ml/trace/fixed_variable_array.py +109 -4
- da4ml/trace/ops/__init__.py +22 -6
- da4ml/trace/ops/conv_utils.py +147 -15
- da4ml/trace/ops/einsum_utils.py +9 -6
- da4ml/trace/ops/reduce_utils.py +103 -0
- da4ml/trace/pipeline.py +36 -34
- da4ml/trace/tracer.py +37 -7
- da4ml-0.3.0.post1.dist-info/METADATA +107 -0
- da4ml-0.3.0.post1.dist-info/RECORD +64 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
- da4ml-0.2.1.dist-info/METADATA +0 -65
- da4ml-0.2.1.dist-info/RECORD +0 -39
- /da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
- {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/WHEEL +0 -0
- {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/top_level.txt +0 -0
da4ml/_version.py
CHANGED
|
@@ -17,5 +17,5 @@ __version__: str
|
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
|
18
18
|
version_tuple: VERSION_TUPLE
|
|
19
19
|
|
|
20
|
-
__version__ = version = '0.
|
|
21
|
-
__version_tuple__ = version_tuple = (0,
|
|
20
|
+
__version__ = version = '0.3.0.post1'
|
|
21
|
+
__version_tuple__ = version_tuple = (0, 3, 0, 'post1')
|
da4ml/cmvm/types.py
CHANGED
|
@@ -291,6 +291,9 @@ class Solution(NamedTuple):
|
|
|
291
291
|
The output data after applying the operations defined in the solution.
|
|
292
292
|
|
|
293
293
|
"""
|
|
294
|
+
|
|
295
|
+
from ..trace.fixed_variable import FixedVariable
|
|
296
|
+
|
|
294
297
|
buf = np.empty(len(self.ops), dtype=object)
|
|
295
298
|
inp = np.asarray(inp)
|
|
296
299
|
|
|
@@ -320,39 +323,61 @@ class Solution(NamedTuple):
|
|
|
320
323
|
buf[i] = buf[op.id0] + bias
|
|
321
324
|
case 5:
|
|
322
325
|
buf[i] = op.data * op.qint.step # const definition
|
|
326
|
+
case 6 | -6: # MSB Mux
|
|
327
|
+
id_c = op.data & 0xFFFFFFFF
|
|
328
|
+
k, v0, v1 = buf[id_c], buf[op.id0], buf[op.id1]
|
|
329
|
+
shift = (op.data >> 32) & 0xFFFFFFFF
|
|
330
|
+
shift = shift if shift < 0x80000000 else shift - 0x100000000
|
|
331
|
+
if op.opcode == -6:
|
|
332
|
+
v1 = -v1
|
|
333
|
+
|
|
334
|
+
if isinstance(k, FixedVariable):
|
|
335
|
+
buf[i] = k.msb_mux(v0, v1 * 2**shift)
|
|
336
|
+
else:
|
|
337
|
+
qint_k = self.ops[id_c].qint
|
|
338
|
+
if qint_k.min < 0:
|
|
339
|
+
buf[i] = v0 if k < 0 else v1 * 2.0**shift
|
|
340
|
+
else:
|
|
341
|
+
_k, _i, _f = _minimal_kif(qint_k)
|
|
342
|
+
buf[i] = v0 if k >= 2.0 ** (_i - 1) else v1 * 2.0**shift
|
|
323
343
|
case _:
|
|
324
344
|
raise ValueError(f'Unknown opcode {op.opcode} in {op}')
|
|
325
345
|
|
|
326
|
-
sf = 2.0 ** np.array(self.out_shifts)
|
|
346
|
+
sf = 2.0 ** np.array(self.out_shifts, dtype=np.float64)
|
|
327
347
|
sign = np.where(self.out_negs, -1, 1)
|
|
328
|
-
out_idx = np.array(self.out_idxs)
|
|
348
|
+
out_idx = np.array(self.out_idxs, dtype=np.int32)
|
|
329
349
|
mask = np.where(out_idx < 0, 0, 1)
|
|
330
350
|
if debug:
|
|
351
|
+
operands = []
|
|
331
352
|
for i, v in enumerate(buf):
|
|
332
353
|
op = self.ops[i]
|
|
333
354
|
match op.opcode:
|
|
334
355
|
case -1:
|
|
335
356
|
op_str = 'inp'
|
|
336
|
-
case 0:
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
op_str = f'relu(buf[{op.id0}])'
|
|
342
|
-
case -
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
op_str = f'quantize(buf[{op.id0}])'
|
|
346
|
-
case -3:
|
|
347
|
-
op_str = f'quantize(-buf[{op.id0}])'
|
|
357
|
+
case 0 | 1:
|
|
358
|
+
_sign = '-' if op.opcode == 1 else '+'
|
|
359
|
+
op_str = f'buf[{op.id0}] {_sign} buf[{op.id1}]<<{op.data}'
|
|
360
|
+
case 2 | -2:
|
|
361
|
+
_sign = '' if op.opcode == 2 else '-'
|
|
362
|
+
op_str = f'relu({_sign}buf[{op.id0}])'
|
|
363
|
+
case 3 | -3:
|
|
364
|
+
_sign = '' if op.opcode == 3 else '-'
|
|
365
|
+
op_str = f'quantize({_sign}buf[{op.id0}])'
|
|
348
366
|
case 4:
|
|
349
367
|
op_str = f'buf[{op.id0}] + {op.data * op.qint.step}'
|
|
350
368
|
case 5:
|
|
351
369
|
op_str = f'const {op.data * op.qint.step}'
|
|
370
|
+
case 6 | -6:
|
|
371
|
+
_sign = '-' if op.opcode == -6 else ''
|
|
372
|
+
op_str = f'msb(buf[{op.data}]) ? buf[{op.id0}] : {_sign}buf[{op.id1}]'
|
|
352
373
|
case _:
|
|
353
374
|
raise ValueError(f'Unknown opcode {op.opcode} in {op}')
|
|
354
375
|
|
|
355
|
-
|
|
376
|
+
result = f'|-> buf[{i}] = {v}'
|
|
377
|
+
operands.append((op_str, result))
|
|
378
|
+
max_len = max(len(op[0]) for op in operands)
|
|
379
|
+
for op_str, result in operands:
|
|
380
|
+
print(f'{op_str:<{max_len}} {result}')
|
|
356
381
|
|
|
357
382
|
if dump:
|
|
358
383
|
return buf
|
|
@@ -443,6 +468,61 @@ class Solution(NamedTuple):
|
|
|
443
468
|
data = json.load(f)
|
|
444
469
|
return cls.deserialize(data)
|
|
445
470
|
|
|
471
|
+
@property
|
|
472
|
+
def ref_count(self) -> np.ndarray:
|
|
473
|
+
"""The number of references to the output elements in the solution."""
|
|
474
|
+
ref_count = np.zeros(len(self.ops), dtype=np.uint64)
|
|
475
|
+
for op in self.ops:
|
|
476
|
+
if op.opcode == -1:
|
|
477
|
+
continue
|
|
478
|
+
id0, id1 = op.id0, op.id1
|
|
479
|
+
if id0 != -1:
|
|
480
|
+
ref_count[id0] += 1
|
|
481
|
+
if id1 != -1:
|
|
482
|
+
ref_count[id1] += 1
|
|
483
|
+
if op.opcode in (6, -6):
|
|
484
|
+
# msb_mux operation
|
|
485
|
+
ref_count[op.data & 0xFFFFFFFF] += 1
|
|
486
|
+
for i in self.out_idxs:
|
|
487
|
+
if i < 0:
|
|
488
|
+
continue
|
|
489
|
+
ref_count[i] += 1
|
|
490
|
+
return ref_count
|
|
491
|
+
|
|
492
|
+
def to_binary(self):
|
|
493
|
+
n_in, n_out = self.shape
|
|
494
|
+
header_size_i32 = 2 + n_in + n_out * 3 + 1
|
|
495
|
+
|
|
496
|
+
header = np.concatenate(
|
|
497
|
+
[
|
|
498
|
+
[n_in, n_out, len(self.ops)],
|
|
499
|
+
self.inp_shift,
|
|
500
|
+
self.out_idxs,
|
|
501
|
+
self.out_shifts,
|
|
502
|
+
self.out_negs,
|
|
503
|
+
],
|
|
504
|
+
axis=0,
|
|
505
|
+
dtype=np.int32,
|
|
506
|
+
)
|
|
507
|
+
assert len(header) == header_size_i32, f'Header size mismatch: {len(header)} != {header_size_i32}'
|
|
508
|
+
code = np.empty((len(self.ops), 8), dtype=np.int32)
|
|
509
|
+
for i, op in enumerate(self.ops):
|
|
510
|
+
buf = code[i]
|
|
511
|
+
buf[0] = op.opcode
|
|
512
|
+
buf[1] = op.id0
|
|
513
|
+
buf[2] = op.id1
|
|
514
|
+
buf[5:] = _minimal_kif(op.qint)
|
|
515
|
+
buf_i64 = buf[3:5].view(np.int64)
|
|
516
|
+
buf_i64[0] = op.data
|
|
517
|
+
data = np.concatenate([header, code.flatten()])
|
|
518
|
+
return data
|
|
519
|
+
|
|
520
|
+
def save_binary(self, path: str | Path):
|
|
521
|
+
"""Dump the solution to a binary file."""
|
|
522
|
+
data = self.to_binary()
|
|
523
|
+
with open(path, 'wb') as f:
|
|
524
|
+
data.tofile(f)
|
|
525
|
+
|
|
446
526
|
|
|
447
527
|
class CascadedSolution(NamedTuple):
|
|
448
528
|
"""A solution that implements cascaded matrix-vector multiplications through multiple CMVM stages.
|
da4ml/codegen/__init__.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
|
-
from .cpp import cpp_logic_and_bridge_gen
|
|
2
|
-
from .verilog import
|
|
1
|
+
from .cpp import HLSModel, cpp_logic_and_bridge_gen
|
|
2
|
+
from .verilog import VerilogModel, binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_logic_gen
|
|
3
3
|
|
|
4
4
|
__all__ = [
|
|
5
5
|
'cpp_logic_and_bridge_gen',
|
|
6
6
|
'comb_logic_gen',
|
|
7
7
|
'generate_io_wrapper',
|
|
8
|
-
'comb_binder_gen',
|
|
9
8
|
'pipeline_logic_gen',
|
|
10
|
-
'
|
|
9
|
+
'binder_gen',
|
|
10
|
+
'HLSModel',
|
|
11
|
+
'VerilogModel',
|
|
11
12
|
]
|
da4ml/codegen/cpp/__init__.py
CHANGED
da4ml/codegen/cpp/cpp_codegen.py
CHANGED
|
@@ -1,19 +1,19 @@
|
|
|
1
1
|
from collections.abc import Callable
|
|
2
2
|
|
|
3
|
-
from ...cmvm.types import
|
|
3
|
+
from ...cmvm.types import QInterval, Solution, _minimal_kif
|
|
4
4
|
from ...trace.fixed_variable import _const_f
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def kif_to_vitis_type(k: bool | int = 1, i: int = 0, f: int = 0):
|
|
8
8
|
if k == i == f == 0:
|
|
9
9
|
f = 1
|
|
10
|
-
return f'ap_{"" if k else "u"}fixed<{k+i+f},{k+i}>'
|
|
10
|
+
return f'ap_{"" if k else "u"}fixed<{k + i + f},{k + i}>'
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
def kif_to_hlslib_type(k: bool | int = 1, i: int = 0, f: int = 0):
|
|
14
14
|
if k == i == f == 0:
|
|
15
15
|
f = 1
|
|
16
|
-
return f'ac_fixed<{int(k)},{k+i+f},{k+i}>'
|
|
16
|
+
return f'ac_fixed<{int(k)},{k + i + f},{k + i}>'
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def get_typestr_fn(flavor: str):
|
|
@@ -27,13 +27,18 @@ def get_typestr_fn(flavor: str):
|
|
|
27
27
|
return typestr_fn
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
def ssa_gen(
|
|
31
|
-
|
|
30
|
+
def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
|
|
31
|
+
ops = sol.ops
|
|
32
|
+
all_kifs = list(map(_minimal_kif, (op.qint for op in ops)))
|
|
32
33
|
all_types = list(map(lambda x: typestr_fn(*x), all_kifs))
|
|
33
34
|
|
|
34
35
|
lines = []
|
|
35
|
-
|
|
36
|
+
ref_count = sol.ref_count
|
|
36
37
|
for i, op in enumerate(ops):
|
|
38
|
+
if ref_count[i] == 0:
|
|
39
|
+
# Skip unused ops
|
|
40
|
+
continue
|
|
41
|
+
|
|
37
42
|
_type = all_types[i]
|
|
38
43
|
|
|
39
44
|
ref0 = f'v{op.id0}'
|
|
@@ -42,12 +47,10 @@ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int
|
|
|
42
47
|
case -1:
|
|
43
48
|
# Input marker
|
|
44
49
|
val = f'inp[{ops[op.id0].id0}]'
|
|
45
|
-
|
|
46
50
|
case 0 | 1:
|
|
47
51
|
# Common a+/-b<<shift op
|
|
48
52
|
ref1 = f'bit_shift<{op.data}>(v{op.id1})' if op.data != 0 else f'v{op.id1}'
|
|
49
53
|
val = f'{ref0} {"-" if op.opcode == 1 else "+"} {ref1}'
|
|
50
|
-
|
|
51
54
|
case 2 | -2:
|
|
52
55
|
if op.opcode == 2: # relu(inp)
|
|
53
56
|
if ops[op.id0].qint.min < 0:
|
|
@@ -59,11 +62,9 @@ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int
|
|
|
59
62
|
val = f'{ref0} > 0 ? {_type}(0) : {_type}(-{ref0})'
|
|
60
63
|
else:
|
|
61
64
|
val = f'-{ref0}'
|
|
62
|
-
|
|
63
65
|
case 3 | -3:
|
|
64
66
|
# Explicit quantization op, done implicitly via assignment
|
|
65
67
|
val = ref0 if op.opcode == 3 else f'-{ref0}'
|
|
66
|
-
|
|
67
68
|
case 4:
|
|
68
69
|
# Constant addition
|
|
69
70
|
_number = op.data * op.qint.step
|
|
@@ -71,10 +72,20 @@ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int
|
|
|
71
72
|
f = _const_f(mag)
|
|
72
73
|
const_type_str = typestr_fn(*_minimal_kif(QInterval(mag, mag, 2.0**-f)))
|
|
73
74
|
val = f'{ref0} {sign} {const_type_str}({mag})'
|
|
74
|
-
|
|
75
75
|
case 5:
|
|
76
|
+
# Define constant
|
|
76
77
|
_number = op.data * op.qint.step
|
|
77
78
|
val = f'{_number}'
|
|
79
|
+
case 6 | -6:
|
|
80
|
+
# MSB Mux
|
|
81
|
+
id_c = op.data & 0xFFFFFFFF
|
|
82
|
+
bw_k = sum(all_kifs[id_c])
|
|
83
|
+
shift = (op.data >> 32) & 0xFFFFFFFF
|
|
84
|
+
shift = shift if shift < 0x80000000 else shift - 0x100000000
|
|
85
|
+
ref_k = f'v{id_c}[{bw_k - 1}]'
|
|
86
|
+
sign = '-' if op.opcode == -6 else ''
|
|
87
|
+
ref1 = f'v{op.id1}' if shift == 0 else f'bit_shift<{shift}>(v{op.id1})'
|
|
88
|
+
val = f'{ref_k} ? {_type}({ref0}) : {_type}({sign}{ref1})'
|
|
78
89
|
|
|
79
90
|
case _:
|
|
80
91
|
raise ValueError(f'Unsupported opcode: {op.opcode}')
|
|
@@ -103,6 +114,15 @@ def output_gen(sol: Solution, typestr_fn: Callable[[bool | int, int, int], str])
|
|
|
103
114
|
return lines
|
|
104
115
|
|
|
105
116
|
|
|
117
|
+
def get_io_types(sol: Solution, flavor: str):
|
|
118
|
+
typestr_fn = get_typestr_fn(flavor)
|
|
119
|
+
in_kif = map(max, zip(*map(_minimal_kif, sol.inp_qint)))
|
|
120
|
+
inp_type = typestr_fn(*in_kif)
|
|
121
|
+
out_kif = map(max, zip(*map(_minimal_kif, sol.out_qint)))
|
|
122
|
+
out_type = typestr_fn(*out_kif)
|
|
123
|
+
return inp_type, out_type
|
|
124
|
+
|
|
125
|
+
|
|
106
126
|
def cpp_logic_and_bridge_gen(
|
|
107
127
|
sol: Solution,
|
|
108
128
|
fn_name: str,
|
|
@@ -113,36 +133,49 @@ def cpp_logic_and_bridge_gen(
|
|
|
113
133
|
print_latency: bool = False,
|
|
114
134
|
):
|
|
115
135
|
typestr_fn = get_typestr_fn(flavor)
|
|
116
|
-
|
|
117
|
-
inp_type = typestr_fn(*in_kif)
|
|
118
|
-
out_kif = map(max, zip(*map(_minimal_kif, sol.out_qint)))
|
|
119
|
-
out_type = typestr_fn(*out_kif)
|
|
136
|
+
inp_t, out_t = get_io_types(sol, flavor)
|
|
120
137
|
|
|
121
138
|
n_in, n_out = sol.shape
|
|
122
139
|
template_def = 'template <typename inp_t, typename out_t>'
|
|
123
140
|
fn_signature = f'void {fn_name}(inp_t inp[{n_in}], out_t out[{n_out}])'
|
|
124
141
|
pragmas = pragmas or []
|
|
125
142
|
|
|
126
|
-
ssa_lines = ssa_gen(sol
|
|
143
|
+
ssa_lines = ssa_gen(sol, print_latency=print_latency, typestr_fn=typestr_fn)
|
|
127
144
|
output_lines = output_gen(sol, typestr_fn=typestr_fn)
|
|
128
145
|
|
|
129
146
|
indent = ' ' * n_indent
|
|
130
147
|
base_indent = indent * n_base_indent
|
|
131
148
|
body_indent = '\n' + base_indent + indent
|
|
132
149
|
code = f"""{base_indent}{template_def}
|
|
133
|
-
{base_indent}{fn_signature} {{ // {
|
|
134
|
-
{
|
|
150
|
+
{base_indent}{fn_signature} {{ // {inp_t} -> {out_t}
|
|
151
|
+
{base_indent + indent}{body_indent.join(pragmas)}
|
|
135
152
|
{body_indent}{body_indent.join(ssa_lines)}
|
|
136
153
|
{body_indent}{body_indent.join(output_lines)}
|
|
137
154
|
{base_indent}}}
|
|
138
155
|
"""
|
|
139
|
-
bridge = f"""#include "
|
|
140
|
-
#include "
|
|
156
|
+
bridge = f"""#include "binder_util.hh"
|
|
157
|
+
#include "{fn_name}.hh"
|
|
158
|
+
|
|
159
|
+
struct {fn_name}_config {{
|
|
160
|
+
static const size_t N_inp = {n_in};
|
|
161
|
+
static const size_t N_out = {n_out};
|
|
162
|
+
typedef {inp_t} inp_t;
|
|
163
|
+
typedef {out_t} out_t;
|
|
164
|
+
constexpr static auto f = {fn_name}<inp_t, out_t>;
|
|
165
|
+
}};
|
|
141
166
|
|
|
142
167
|
extern "C" {{
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
168
|
+
|
|
169
|
+
bool openmp_enabled() {{
|
|
170
|
+
return _openmp;
|
|
171
|
+
}}
|
|
172
|
+
|
|
173
|
+
void inference_f64(double *inp, double *out, size_t size) {{
|
|
174
|
+
batch_inference<{fn_name}_config, double>(inp, out, size);
|
|
175
|
+
}}
|
|
176
|
+
|
|
177
|
+
void inference_f32(float *inp, float *out, size_t size) {{
|
|
178
|
+
batch_inference<{fn_name}_config, float>(inp, out, size);
|
|
146
179
|
}}
|
|
147
180
|
}}"""
|
|
148
181
|
return code, bridge
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
import ctypes
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import shutil
|
|
5
|
+
import subprocess
|
|
6
|
+
import sys
|
|
7
|
+
from collections.abc import Sequence
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TypeVar
|
|
10
|
+
from uuid import uuid4
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
from numpy.typing import NDArray
|
|
14
|
+
|
|
15
|
+
from da4ml.cmvm.types import Solution
|
|
16
|
+
from da4ml.codegen.cpp.cpp_codegen import cpp_logic_and_bridge_gen, get_io_types
|
|
17
|
+
|
|
18
|
+
from ... import codegen
|
|
19
|
+
from ...cmvm.types import _minimal_kif
|
|
20
|
+
|
|
21
|
+
T = TypeVar('T', bound=np.floating)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class HLSModel:
|
|
25
|
+
def __init__(
|
|
26
|
+
self,
|
|
27
|
+
solution: Solution,
|
|
28
|
+
prj_name: str,
|
|
29
|
+
path: str | Path,
|
|
30
|
+
flavor: str = 'vitis',
|
|
31
|
+
print_latency: bool = True,
|
|
32
|
+
part_name: str = 'xcvu13p-flga2577-2-e',
|
|
33
|
+
pragma: Sequence[str] | None = None,
|
|
34
|
+
clock_period: int = 5,
|
|
35
|
+
clock_uncertainty: float = 0.1,
|
|
36
|
+
io_delay_minmax: tuple[float, float] = (0.2, 0.4),
|
|
37
|
+
):
|
|
38
|
+
self._solution = solution
|
|
39
|
+
self._prj_name = prj_name
|
|
40
|
+
self._path = Path(path)
|
|
41
|
+
self._flavor = flavor.lower()
|
|
42
|
+
assert self._flavor in ('vitis', 'hlslib'), f'Unsupported HLS flavor: {self._flavor}'
|
|
43
|
+
self._print_latency = print_latency
|
|
44
|
+
self._part_name = part_name
|
|
45
|
+
self._clock_period = clock_period
|
|
46
|
+
self._clock_uncertainty = clock_uncertainty
|
|
47
|
+
self._io_delay_minmax = io_delay_minmax
|
|
48
|
+
self.__src_root = Path(codegen.__file__).parent
|
|
49
|
+
self._lib = None
|
|
50
|
+
self._uuid = None
|
|
51
|
+
|
|
52
|
+
if pragma is None:
|
|
53
|
+
if self._flavor == 'vitis':
|
|
54
|
+
self._pragma = (
|
|
55
|
+
'#pragma HLS ARRAY_PARTITION variable=inp complete',
|
|
56
|
+
'#pragma HLS ARRAY_PARTITION variable=out complete',
|
|
57
|
+
'#pragma HLS PIPELINE II=1',
|
|
58
|
+
)
|
|
59
|
+
else:
|
|
60
|
+
self._pragma = ()
|
|
61
|
+
else:
|
|
62
|
+
self._pragma = tuple(pragma)
|
|
63
|
+
|
|
64
|
+
def write(self):
|
|
65
|
+
if not self._path.exists():
|
|
66
|
+
self._path.mkdir(parents=True, exist_ok=True)
|
|
67
|
+
template_def, bridge = cpp_logic_and_bridge_gen(
|
|
68
|
+
self._solution,
|
|
69
|
+
self._prj_name,
|
|
70
|
+
self._flavor,
|
|
71
|
+
['#pragma HLS INLINE'],
|
|
72
|
+
4,
|
|
73
|
+
0,
|
|
74
|
+
self._print_latency,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
headers = ['#pragma once', '#include "bitshift.hh"']
|
|
78
|
+
|
|
79
|
+
inp_type, out_type = get_io_types(self._solution, self._flavor)
|
|
80
|
+
n_in, n_out = len(self._solution.inp_qint), len(self._solution.out_qint)
|
|
81
|
+
template_signature = (
|
|
82
|
+
f'template <typename inp_t, typename out_t>\nvoid {self._prj_name}(inp_t inp[{n_in}], out_t out[{n_out}]);'
|
|
83
|
+
)
|
|
84
|
+
fn_signature = f'void {self._prj_name}_fn({inp_type} inp[{n_in}], {out_type} out[{n_out}])'
|
|
85
|
+
|
|
86
|
+
with open(self._path / f'{self._prj_name}.hh', 'w') as f:
|
|
87
|
+
f.write('\n'.join(headers) + '\n\n')
|
|
88
|
+
f.write(f'{template_signature}\n\n{fn_signature};\n')
|
|
89
|
+
|
|
90
|
+
pragma_str = '\n'.join(self._pragma)
|
|
91
|
+
cpp_def = f"""
|
|
92
|
+
#include "{self._prj_name}.hh"
|
|
93
|
+
|
|
94
|
+
{template_def}
|
|
95
|
+
|
|
96
|
+
{fn_signature} {{
|
|
97
|
+
{pragma_str}
|
|
98
|
+
{self._prj_name}<{inp_type}, {out_type}>(inp, out);
|
|
99
|
+
}}
|
|
100
|
+
"""
|
|
101
|
+
with open(self._path / f'{self._prj_name}.cc', 'w') as f:
|
|
102
|
+
f.write(cpp_def)
|
|
103
|
+
|
|
104
|
+
with open(self._path / f'{self._prj_name}_bridge.cc', 'w') as f:
|
|
105
|
+
f.write(bridge)
|
|
106
|
+
|
|
107
|
+
shutil.copy(self.__src_root / 'cpp/source/binder_util.hh', self._path)
|
|
108
|
+
shutil.copy(self.__src_root / f'cpp/source/{self._flavor}_bitshift.hh', self._path / 'bitshift.hh')
|
|
109
|
+
shutil.copy(self.__src_root / 'cpp/source/build_binder.mk', self._path)
|
|
110
|
+
if self._flavor == 'vitis':
|
|
111
|
+
shutil.copytree(self.__src_root / 'cpp/source/ap_types', self._path / 'ap_types', dirs_exist_ok=True)
|
|
112
|
+
else:
|
|
113
|
+
pass
|
|
114
|
+
|
|
115
|
+
self._solution.save(self._path / 'project.json')
|
|
116
|
+
|
|
117
|
+
def _compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
|
|
118
|
+
"""Same as compile, but will not write to the library
|
|
119
|
+
|
|
120
|
+
Parameters
|
|
121
|
+
----------
|
|
122
|
+
verbose : bool, optional
|
|
123
|
+
Verbose output, by default False
|
|
124
|
+
openmp : bool, optional
|
|
125
|
+
Enable openmp, by default True
|
|
126
|
+
o3 : bool | None, optional
|
|
127
|
+
Turn on -O3 flag, by default False
|
|
128
|
+
clean : bool, optional
|
|
129
|
+
Remove obsolete shared object files, by default True
|
|
130
|
+
|
|
131
|
+
Raises
|
|
132
|
+
------
|
|
133
|
+
RuntimeError
|
|
134
|
+
If compilation fails
|
|
135
|
+
"""
|
|
136
|
+
|
|
137
|
+
self._uuid = str(uuid4())
|
|
138
|
+
args = ['make', '-f', 'build_binder.mk']
|
|
139
|
+
env = os.environ.copy()
|
|
140
|
+
env['PRJ_NAME'] = self._prj_name
|
|
141
|
+
env['STAMP'] = self._uuid
|
|
142
|
+
env['EXTRA_CXXFLAGS'] = '-fopenmp' if openmp else ''
|
|
143
|
+
if o3:
|
|
144
|
+
args.append('fast')
|
|
145
|
+
|
|
146
|
+
if clean:
|
|
147
|
+
m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
|
|
148
|
+
for p in self._path.iterdir():
|
|
149
|
+
if not p.is_dir() and m.match(p.name):
|
|
150
|
+
p.unlink()
|
|
151
|
+
|
|
152
|
+
try:
|
|
153
|
+
r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
|
|
154
|
+
except subprocess.CalledProcessError as e:
|
|
155
|
+
print(e.stderr.decode(), file=sys.stderr)
|
|
156
|
+
print(e.stdout.decode(), file=sys.stdout)
|
|
157
|
+
raise RuntimeError('Compilation failed!!') from e
|
|
158
|
+
if r.returncode != 0:
|
|
159
|
+
print(r.stderr.decode(), file=sys.stderr)
|
|
160
|
+
print(r.stdout.decode(), file=sys.stderr)
|
|
161
|
+
raise RuntimeError('Compilation failed!!')
|
|
162
|
+
|
|
163
|
+
self._load_lib(self._uuid)
|
|
164
|
+
|
|
165
|
+
def _load_lib(self, uuid: str | None = None):
|
|
166
|
+
uuid = uuid if uuid is not None else self._uuid
|
|
167
|
+
self._uuid = uuid
|
|
168
|
+
lib_path = self._path / f'lib{self._prj_name}_{uuid}.so'
|
|
169
|
+
if not lib_path.exists():
|
|
170
|
+
raise RuntimeError(f'Library {lib_path} does not exist')
|
|
171
|
+
self._lib = ctypes.CDLL(str(lib_path))
|
|
172
|
+
|
|
173
|
+
def compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
|
|
174
|
+
"""Compile the model to a shared object file
|
|
175
|
+
|
|
176
|
+
Parameters
|
|
177
|
+
----------
|
|
178
|
+
verbose : bool, optional
|
|
179
|
+
Verbose output, by default False
|
|
180
|
+
openmp : bool, optional
|
|
181
|
+
Enable openmp, by default True
|
|
182
|
+
o3 : bool | None, optional
|
|
183
|
+
Turn on -O3 flag, by default False
|
|
184
|
+
clean : bool, optional
|
|
185
|
+
Remove obsolete shared object files, by default True
|
|
186
|
+
|
|
187
|
+
Raises
|
|
188
|
+
------
|
|
189
|
+
RuntimeError
|
|
190
|
+
If compilation fails
|
|
191
|
+
"""
|
|
192
|
+
self.write()
|
|
193
|
+
self._compile(verbose, openmp, o3, clean)
|
|
194
|
+
|
|
195
|
+
def predict(self, data: NDArray[T]) -> NDArray[T]:
|
|
196
|
+
"""Run the model on the input data.
|
|
197
|
+
|
|
198
|
+
Parameters
|
|
199
|
+
----------
|
|
200
|
+
data : NDArray[np.floating]
|
|
201
|
+
Input data to the model. The shape is ignored, and the number of samples is
|
|
202
|
+
determined by the size of the data.
|
|
203
|
+
|
|
204
|
+
Returns
|
|
205
|
+
-------
|
|
206
|
+
NDArray[np.floating]
|
|
207
|
+
Output of the model in shape (n_samples, output_size).
|
|
208
|
+
"""
|
|
209
|
+
assert self._lib is not None, 'Library not loaded, call .compile() first.'
|
|
210
|
+
inp_size, out_size = self._solution.shape
|
|
211
|
+
|
|
212
|
+
dtype = data.dtype
|
|
213
|
+
if dtype not in (np.float32, np.float64):
|
|
214
|
+
raise TypeError(f'Unsupported input data type: {dtype}. Expected float32 or float64.')
|
|
215
|
+
c_dtype = ctypes.c_float if dtype == np.float32 else ctypes.c_double
|
|
216
|
+
|
|
217
|
+
assert data.size % inp_size == 0, f'Input size {data.size} is not divisible by {inp_size}'
|
|
218
|
+
n_sample = data.size // inp_size
|
|
219
|
+
|
|
220
|
+
inp_data = np.ascontiguousarray(data)
|
|
221
|
+
out_data = np.empty(n_sample * out_size, dtype=dtype)
|
|
222
|
+
|
|
223
|
+
inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(c_dtype))
|
|
224
|
+
out_buf = out_data.ctypes.data_as(ctypes.POINTER(c_dtype))
|
|
225
|
+
if dtype == np.float32:
|
|
226
|
+
self._lib.inference_f32(inp_buf, out_buf, n_sample)
|
|
227
|
+
else:
|
|
228
|
+
self._lib.inference_f64(inp_buf, out_buf, n_sample)
|
|
229
|
+
|
|
230
|
+
return out_data.reshape(n_sample, out_size) # type: ignore
|
|
231
|
+
|
|
232
|
+
def __repr__(self):
|
|
233
|
+
inp_size, out_size = self._solution.shape
|
|
234
|
+
inp_size, out_size = self._solution.shape
|
|
235
|
+
cost = round(self._solution.cost)
|
|
236
|
+
inp_kifs = tuple(zip(*map(_minimal_kif, self._solution.inp_qint)))
|
|
237
|
+
out_kifs = tuple(zip(*map(_minimal_kif, self._solution.out_qint)))
|
|
238
|
+
in_bits, out_bits = np.sum(inp_kifs), np.sum(out_kifs)
|
|
239
|
+
|
|
240
|
+
spec = f"""Top Function: {self._prj_name}\n====================
|
|
241
|
+
{inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)
|
|
242
|
+
combinational @ delay={self._solution.latency}
|
|
243
|
+
Estimated cost: {cost} LUTs"""
|
|
244
|
+
|
|
245
|
+
is_compiled = self._lib is not None
|
|
246
|
+
if is_compiled:
|
|
247
|
+
assert self._uuid is not None
|
|
248
|
+
openmp = 'with OpenMP' if self._lib.openmp_enabled() else '' # type: ignore
|
|
249
|
+
spec += f'\nEmulator is compiled {openmp} ({self._uuid[-12:]})'
|
|
250
|
+
else:
|
|
251
|
+
spec += '\nEmulator is **not compiled**'
|
|
252
|
+
return spec
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright 2024-2024 Chang Sun
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#ifndef __AP_BINARY_H__
|
|
18
|
+
#define __AP_BINARY_H__
|
|
19
|
+
|
|
20
|
+
#include <ap_fixed.h>
|
|
21
|
+
#include <cassert>
|
|
22
|
+
|
|
23
|
+
struct ap_binary {
|
|
24
|
+
|
|
25
|
+
bool is_one;
|
|
26
|
+
|
|
27
|
+
INLINE ap_binary() {}
|
|
28
|
+
|
|
29
|
+
INLINE ap_binary(const bool value) : is_one(value) {}
|
|
30
|
+
INLINE ap_binary(const ap_binary &value) : is_one(value.is_one) {}
|
|
31
|
+
|
|
32
|
+
INLINE operator int() const { return is_one ? 1 : -1; }
|
|
33
|
+
INLINE operator float() const { return is_one ? 1.0 : -1.0; }
|
|
34
|
+
|
|
35
|
+
template <typename T> INLINE ap_binary(T value) : is_one(value >= 0) {}
|
|
36
|
+
|
|
37
|
+
template <typename T>
|
|
38
|
+
INLINE auto operator=(T value) -> decltype(std::enable_if_t<std::is_same<T, ap_binary>::value, int>()) {
|
|
39
|
+
is_one = value.is_one;
|
|
40
|
+
return 0;
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
template <typename T>
|
|
44
|
+
INLINE auto operator=(T value) -> decltype(std::enable_if_t<!std::is_same<T, ap_binary>::value, int>()) {
|
|
45
|
+
is_one = value >= 0;
|
|
46
|
+
return 0;
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
INLINE ap_fixed<2, 1> value() const { return is_one ? 1 : -1; }
|
|
50
|
+
|
|
51
|
+
template <typename T> INLINE bool operator==(T value) const { return value() == value; }
|
|
52
|
+
|
|
53
|
+
template <typename T> INLINE bool operator!=(T value) const { return value() != value; }
|
|
54
|
+
|
|
55
|
+
template <typename T> INLINE bool operator<(T value) const { return value() < value; }
|
|
56
|
+
|
|
57
|
+
template <typename T> INLINE bool operator<=(T value) const { return value() <= value; }
|
|
58
|
+
|
|
59
|
+
template <typename T> INLINE bool operator>(T value) const { return value() > value; }
|
|
60
|
+
|
|
61
|
+
template <typename T> INLINE bool operator>=(T value) const { return value() >= value; }
|
|
62
|
+
|
|
63
|
+
template <typename T> INLINE ap_binary operator+(T value) const { return ap_binary(is_one || value.is_one); }
|
|
64
|
+
|
|
65
|
+
template <typename T> INLINE ap_binary operator*(T value) const { return ap_binary(is_one && value.is_one); }
|
|
66
|
+
|
|
67
|
+
template <typename T> INLINE ap_binary operator-(T value) const { return ap_binary(is_one && !value.is_one); }
|
|
68
|
+
|
|
69
|
+
template <typename T> INLINE T operator+(T value) { return value + value(); }
|
|
70
|
+
|
|
71
|
+
template <typename T> INLINE T operator*(T value) { return value * value(); }
|
|
72
|
+
|
|
73
|
+
template <typename T> INLINE T operator-(T value) { return value - value(); }
|
|
74
|
+
};
|
|
75
|
+
|
|
76
|
+
typedef ap_fixed<2, 1, AP_RND_CONV, AP_SAT_SYM> ap_ternary;
|
|
77
|
+
|
|
78
|
+
#endif
|