da4ml 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of da4ml might be problematic.

Files changed (59)
  1. da4ml/_version.py +2 -2
  2. da4ml/cmvm/api.py +2 -6
  3. da4ml/cmvm/core/__init__.py +0 -1
  4. da4ml/cmvm/types.py +99 -19
  5. da4ml/codegen/__init__.py +5 -4
  6. da4ml/codegen/cpp/__init__.py +2 -1
  7. da4ml/codegen/cpp/cpp_codegen.py +58 -25
  8. da4ml/codegen/cpp/hls_model.py +252 -0
  9. da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
  10. da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
  11. da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
  12. da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
  13. da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
  14. da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
  15. da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
  16. da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
  17. da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
  18. da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
  19. da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
  20. da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
  21. da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
  22. da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
  23. da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
  24. da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
  25. da4ml/codegen/cpp/source/binder_util.hh +56 -0
  26. da4ml/codegen/cpp/source/build_binder.mk +24 -0
  27. da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
  28. da4ml/codegen/verilog/__init__.py +2 -3
  29. da4ml/codegen/verilog/comb.py +65 -24
  30. da4ml/codegen/verilog/io_wrapper.py +36 -141
  31. da4ml/codegen/verilog/pipeline.py +21 -3
  32. da4ml/codegen/verilog/source/binder_util.hh +72 -0
  33. da4ml/codegen/verilog/source/build_prj.tcl +0 -1
  34. da4ml/codegen/verilog/source/mux.v +58 -0
  35. da4ml/codegen/verilog/source/negative.v +28 -0
  36. da4ml/codegen/verilog/source/shift_adder.v +4 -1
  37. da4ml/codegen/verilog/source/template.xdc +3 -0
  38. da4ml/codegen/verilog/verilog_model.py +42 -15
  39. da4ml/converter/__init__.py +0 -0
  40. da4ml/converter/hgq2/parser.py +105 -0
  41. da4ml/converter/hgq2/replica.py +383 -0
  42. da4ml/trace/__init__.py +2 -2
  43. da4ml/trace/fixed_variable.py +177 -18
  44. da4ml/trace/fixed_variable_array.py +124 -9
  45. da4ml/trace/ops/__init__.py +22 -6
  46. da4ml/trace/ops/conv_utils.py +146 -14
  47. da4ml/trace/ops/einsum_utils.py +9 -6
  48. da4ml/trace/ops/reduce_utils.py +103 -0
  49. da4ml/trace/pipeline.py +36 -34
  50. da4ml/trace/tracer.py +37 -5
  51. da4ml-0.3.0.dist-info/METADATA +107 -0
  52. da4ml-0.3.0.dist-info/RECORD +64 -0
  53. da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
  54. da4ml-0.2.0.dist-info/METADATA +0 -65
  55. da4ml-0.2.0.dist-info/RECORD +0 -39
  56. da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
  57. {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/WHEEL +0 -0
  58. {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/licenses/LICENSE +0 -0
  59. {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/top_level.txt +0 -0
da4ml/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.2.0'
-__version_tuple__ = version_tuple = (0, 2, 0)
+__version__ = version = '0.3.0'
+__version_tuple__ = version_tuple = (0, 3, 0)
da4ml/cmvm/api.py CHANGED
@@ -140,10 +140,6 @@ def jit_solve(
         if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
             decompose_dc -= 1
             continue
-        if sum([op.cost for op in sol1.ops]) * 4 > sum([op.cost for op in sol0.ops]) and decompose_dc > 0:
-            # If the second stage is too expensive, the decomposition usually doesn't worth it
-            decompose_dc -= 1
-            continue
         break
     if max(latencies1) > latency_allowed:
         # When latency depends on the bw, may happen
@@ -158,8 +154,8 @@ def solve(
     method1: str = 'auto',
     hard_dc: int = -1,
     decompose_dc: int = -2,
-    qintervals: tuple[QInterval, ...] | None = None,
-    latencies: tuple[float, ...] | None = None,
+    qintervals: list[QInterval] | None = None,
+    latencies: list[float] | None = None,
     adder_size: int = -1,
     carry_size: int = -1,
     search_all_decompose_dc: bool = True,
@@ -131,7 +131,6 @@ def to_solution(
 
     _global_id = len(ops)
     for i_out in range(n_out):
-        heap = []
         idx, shifts = np.where(expr[:, i_out] != 0)
         sub = np.empty(len(idx), dtype=np.int64)
         for i, (i_in, shift) in enumerate(zip(idx, shifts)):
da4ml/cmvm/types.py CHANGED
@@ -159,6 +159,8 @@ def _relu(v: 'T', i: int | None = None, f: int | None = None, inv: bool = False,
     from ..trace.fixed_variable import FixedVariable
 
     assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic relu'
+    if inv:
+        v = -v
     return v.relu(i, f, round_mode=round_mode)
 
 
@@ -289,15 +291,16 @@ class Solution(NamedTuple):
             The output data after applying the operations defined in the solution.
 
         """
+
+        from ..trace.fixed_variable import FixedVariable
+
         buf = np.empty(len(self.ops), dtype=object)
         inp = np.asarray(inp)
 
         inp_qint = [op.qint for op in self.ops if op.opcode == -1]
         if quantize:  # TRN and WRAP
             k, i, f = map(np.array, zip(*map(minimal_kif, inp_qint)))
-            eps = 2.0**-f
-            _low, _high = -(2.0 ** (i + f)) * k, 2.0 ** (i + f) - 1
-            inp = eps * ((np.floor(inp / eps) - _low) % 2.0 ** (k + i + f) + _low)
+            inp = [_quantize(*x, round_mode='TRN') for x in zip(inp, k, i, f)]
 
         inp = inp * (2.0 ** np.array(self.inp_shift))
         for i, op in enumerate(self.ops):
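
Note: the removed inline arithmetic and the new `_quantize(..., round_mode='TRN')` call implement the same truncate-then-wrap fixed-point quantization. A standalone sketch of that behavior (helper name hypothetical, logic copied from the removed lines):

import numpy as np

def trn_wrap_quantize(x, k, i, f):
    """Truncate (TRN) to a 2**-f grid, then wrap (WRAP) into the k/i/f
    fixed-point range; mirrors the removed inline expression."""
    eps = 2.0**-f                    # quantization step
    low = -(2.0 ** (i + f)) * k      # lower bound, measured in steps
    q = np.floor(x / eps)            # truncate towards -inf
    return eps * ((q - low) % 2.0 ** (k + i + f) + low)

assert trn_wrap_quantize(0.8, k=1, i=0, f=2) == 0.75   # in range: truncated
assert trn_wrap_quantize(1.0, k=1, i=0, f=2) == -1.0   # overflow: wrapped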
@@ -320,39 +323,61 @@ class Solution(NamedTuple):
                     buf[i] = buf[op.id0] + bias
                 case 5:
                     buf[i] = op.data * op.qint.step  # const definition
+                case 6 | -6:  # MSB Mux
+                    id_c = op.data & 0xFFFFFFFF
+                    k, v0, v1 = buf[id_c], buf[op.id0], buf[op.id1]
+                    shift = (op.data >> 32) & 0xFFFFFFFF
+                    shift = shift if shift < 0x80000000 else shift - 0x100000000
+                    if op.opcode == -6:
+                        v1 = -v1
+
+                    if isinstance(k, FixedVariable):
+                        buf[i] = k.msb_mux(v0, v1 * 2**shift)
+                    else:
+                        qint_k = self.ops[id_c].qint
+                        if qint_k.min < 0:
+                            buf[i] = v0 if k < 0 else v1 * 2.0**shift
+                        else:
+                            _k, _i, _f = _minimal_kif(qint_k)
+                            buf[i] = v0 if k >= 2.0 ** (_i - 1) else v1 * 2.0**shift
                 case _:
                     raise ValueError(f'Unknown opcode {op.opcode} in {op}')
 
-        sf = 2.0 ** np.array(self.out_shifts)
+        sf = 2.0 ** np.array(self.out_shifts, dtype=np.float64)
         sign = np.where(self.out_negs, -1, 1)
-        out_idx = np.array(self.out_idxs)
+        out_idx = np.array(self.out_idxs, dtype=np.int32)
         mask = np.where(out_idx < 0, 0, 1)
         if debug:
+            operands = []
             for i, v in enumerate(buf):
                 op = self.ops[i]
                 match op.opcode:
                     case -1:
                         op_str = 'inp'
-                    case 0:
-                        op_str = f'buf[{op.id0}] + buf[{op.id1}]<<{op.data}'
-                    case 1:
-                        op_str = f'buf[{op.id0}] - buf[{op.id1}]<<{op.data}'
-                    case 2:
-                        op_str = f'relu(buf[{op.id0}])'
-                    case -2:
-                        op_str = f'relu(-buf[{op.id0}])'
-                    case 3:
-                        op_str = f'quantize(buf[{op.id0}])'
-                    case -3:
-                        op_str = f'quantize(-buf[{op.id0}])'
+                    case 0 | 1:
+                        _sign = '-' if op.opcode == 1 else '+'
+                        op_str = f'buf[{op.id0}] {_sign} buf[{op.id1}]<<{op.data}'
+                    case 2 | -2:
+                        _sign = '' if op.opcode == 2 else '-'
+                        op_str = f'relu({_sign}buf[{op.id0}])'
+                    case 3 | -3:
+                        _sign = '' if op.opcode == 3 else '-'
+                        op_str = f'quantize({_sign}buf[{op.id0}])'
                     case 4:
                         op_str = f'buf[{op.id0}] + {op.data * op.qint.step}'
                     case 5:
                         op_str = f'const {op.data * op.qint.step}'
+                    case 6 | -6:
+                        _sign = '-' if op.opcode == -6 else ''
+                        op_str = f'msb(buf[{op.data}]) ? buf[{op.id0}] : {_sign}buf[{op.id1}]'
                     case _:
                         raise ValueError(f'Unknown opcode {op.opcode} in {op}')
 
-                print(f'{op_str:24} |-> buf[{i}] = {v}')
+                result = f'|-> buf[{i}] = {v}'
+                operands.append((op_str, result))
+            max_len = max(len(op[0]) for op in operands)
+            for op_str, result in operands:
+                print(f'{op_str:<{max_len}} {result}')
 
         if dump:
             return buf
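
Note: for the new MSB-mux opcodes 6/-6, `op.data` packs two fields into a single 64-bit integer: the condition operand's index in the low 32 bits and a signed shift in the high 32 bits. A minimal sketch of that convention as implied by the decode above (helper names hypothetical):

def pack_msb_mux_data(id_c: int, shift: int) -> int:
    """Pack condition index (low 32 bits) and signed shift (high 32 bits)."""
    return ((shift & 0xFFFFFFFF) << 32) | (id_c & 0xFFFFFFFF)

def unpack_msb_mux_data(data: int) -> tuple[int, int]:
    id_c = data & 0xFFFFFFFF
    shift = (data >> 32) & 0xFFFFFFFF
    # reinterpret the high word as a signed 32-bit value
    shift = shift if shift < 0x80000000 else shift - 0x100000000
    return id_c, shift

assert unpack_msb_mux_data(pack_msb_mux_data(42, -3)) == (42, -3)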
@@ -443,6 +468,61 @@ class Solution(NamedTuple):
             data = json.load(f)
         return cls.deserialize(data)
 
+    @property
+    def ref_count(self) -> np.ndarray:
+        """The number of references to the output elements in the solution."""
+        ref_count = np.zeros(len(self.ops), dtype=np.uint64)
+        for op in self.ops:
+            if op.opcode == -1:
+                continue
+            id0, id1 = op.id0, op.id1
+            if id0 != -1:
+                ref_count[id0] += 1
+            if id1 != -1:
+                ref_count[id1] += 1
+            if op.opcode in (6, -6):
+                # msb_mux operation
+                ref_count[op.data & 0xFFFFFFFF] += 1
+        for i in self.out_idxs:
+            if i < 0:
+                continue
+            ref_count[i] += 1
+        return ref_count
+
+    def to_binary(self):
+        n_in, n_out = self.shape
+        header_size_i32 = 2 + n_in + n_out * 3 + 1
+
+        header = np.concatenate(
+            [
+                [n_in, n_out, len(self.ops)],
+                self.inp_shift,
+                self.out_idxs,
+                self.out_shifts,
+                self.out_negs,
+            ],
+            axis=0,
+            dtype=np.int32,
+        )
+        assert len(header) == header_size_i32, f'Header size mismatch: {len(header)} != {header_size_i32}'
+        code = np.empty((len(self.ops), 8), dtype=np.int32)
+        for i, op in enumerate(self.ops):
+            buf = code[i]
+            buf[0] = op.opcode
+            buf[1] = op.id0
+            buf[2] = op.id1
+            buf[5:] = _minimal_kif(op.qint)
+            buf_i64 = buf[3:5].view(np.int64)
+            buf_i64[0] = op.data
+        data = np.concatenate([header, code.flatten()])
+        return data
+
+    def save_binary(self, path: str | Path):
+        """Dump the solution to a binary file."""
+        data = self.to_binary()
+        with open(path, 'wb') as f:
+            data.tofile(f)
+
 
 class CascadedSolution(NamedTuple):
     """A solution that implements cascaded matrix-vector multiplications through multiple CMVM stages.
@@ -561,7 +641,7 @@ class CascadedSolution(NamedTuple):
     @property
     def reg_bits(self):
         """The number of bits used for the register in the solution."""
-        bits = 0
+        bits = sum(map(sum, (_minimal_kif(qint) for qint in self.inp_qint)))
         for _sol in self.solutions:
             kifs = [_minimal_kif(qint) for qint in _sol.out_qint]
             _bits = sum(map(sum, kifs))
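
Note: the new `to_binary` emits a flat int32 stream: a header `[n_in, n_out, n_ops]`, then `inp_shift` (n_in words), `out_idxs`, `out_shifts`, `out_negs` (n_out words each), then one 8-word record per op holding opcode, id0, id1, `op.data` as an int64 in words 3-4, and the qint's k/i/f in words 5-7. A reader sketch under those assumptions (not an official da4ml API):

import numpy as np

def load_solution_binary(path):
    """Sketch of a reader for the Solution.to_binary layout inferred above."""
    raw = np.fromfile(path, dtype=np.int32)
    n_in, n_out, n_ops = raw[:3]
    off = 3
    inp_shift = raw[off : off + n_in]; off += n_in
    out_idxs = raw[off : off + n_out]; off += n_out
    out_shifts = raw[off : off + n_out]; off += n_out
    out_negs = raw[off : off + n_out]; off += n_out
    code = raw[off:].reshape(n_ops, 8)
    opcodes, id0, id1 = code[:, 0], code[:, 1], code[:, 2]
    data = code[:, 3:5].copy().view(np.int64).ravel()  # words 3-4 hold int64 op.data
    kif = code[:, 5:8]                                 # k, i, f of each op's qint
    return n_in, n_out, inp_shift, out_idxs, out_shifts, out_negs, opcodes, id0, id1, data, kif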
da4ml/codegen/__init__.py CHANGED
@@ -1,11 +1,12 @@
-from .cpp import cpp_logic_and_bridge_gen
-from .verilog import comb_binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_binder_gen, pipeline_logic_gen
+from .cpp import HLSModel, cpp_logic_and_bridge_gen
+from .verilog import VerilogModel, binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_logic_gen
 
 __all__ = [
     'cpp_logic_and_bridge_gen',
     'comb_logic_gen',
     'generate_io_wrapper',
-    'comb_binder_gen',
     'pipeline_logic_gen',
-    'pipeline_binder_gen',
+    'binder_gen',
+    'HLSModel',
+    'VerilogModel',
 ]
da4ml/codegen/cpp/__init__.py CHANGED
@@ -1,3 +1,4 @@
 from .cpp_codegen import cpp_logic_and_bridge_gen
+from .hls_model import HLSModel
 
-__all__ = ['cpp_logic_and_bridge_gen']
+__all__ = ['cpp_logic_and_bridge_gen', 'HLSModel']
da4ml/codegen/cpp/cpp_codegen.py CHANGED
@@ -1,19 +1,19 @@
 from collections.abc import Callable
 
-from ...cmvm.types import Op, QInterval, Solution, _minimal_kif
+from ...cmvm.types import QInterval, Solution, _minimal_kif
 from ...trace.fixed_variable import _const_f
 
 
-def kif_to_vitis_type(k: bool | int, i: int, f: int):
+def kif_to_vitis_type(k: bool | int = 1, i: int = 0, f: int = 0):
     if k == i == f == 0:
         f = 1
-    return f'ap_{"" if k else "u"}fixed<{k+i+f},{k+i}>'
+    return f'ap_{"" if k else "u"}fixed<{k + i + f},{k + i}>'
 
 
-def kif_to_hlslib_type(k: bool | int, i: int, f: int):
+def kif_to_hlslib_type(k: bool | int = 1, i: int = 0, f: int = 0):
     if k == i == f == 0:
         f = 1
-    return f'ac_fixed<{int(k)},{k+i+f},{k+i}>'
+    return f'ac_fixed<{int(k)},{k + i + f},{k + i}>'
 
 
 def get_typestr_fn(flavor: str):
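
Note: these helpers map a (sign, integer, fraction) bit split onto HLS fixed-point type strings; the new defaults allow zero-argument calls. For example (outputs derived from the format strings above):

from da4ml.codegen.cpp.cpp_codegen import kif_to_hlslib_type, kif_to_vitis_type

# k = sign bit, i = integer bits, f = fraction bits
print(kif_to_vitis_type(1, 3, 4))   # -> ap_fixed<8,4>   (signed, 8 bits total)
print(kif_to_vitis_type(0, 3, 4))   # -> ap_ufixed<7,3>  (unsigned)
print(kif_to_hlslib_type(1, 3, 4))  # -> ac_fixed<1,8,4>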
@@ -27,13 +27,18 @@ def get_typestr_fn(flavor: str):
     return typestr_fn
 
 
-def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
-    all_kifs = map(_minimal_kif, (op.qint for op in ops))
+def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
+    ops = sol.ops
+    all_kifs = list(map(_minimal_kif, (op.qint for op in ops)))
     all_types = list(map(lambda x: typestr_fn(*x), all_kifs))
 
     lines = []
-
+    ref_count = sol.ref_count
     for i, op in enumerate(ops):
+        if ref_count[i] == 0:
+            # Skip unused ops
+            continue
+
         _type = all_types[i]
 
         ref0 = f'v{op.id0}'
@@ -42,12 +47,10 @@ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int
             case -1:
                 # Input marker
                 val = f'inp[{ops[op.id0].id0}]'
-
             case 0 | 1:
                 # Common a+/-b<<shift op
                 ref1 = f'bit_shift<{op.data}>(v{op.id1})' if op.data != 0 else f'v{op.id1}'
                 val = f'{ref0} {"-" if op.opcode == 1 else "+"} {ref1}'
-
             case 2 | -2:
                 if op.opcode == 2:  # relu(inp)
                     if ops[op.id0].qint.min < 0:
@@ -59,11 +62,9 @@ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int
                     val = f'{ref0} > 0 ? {_type}(0) : {_type}(-{ref0})'
                 else:
                     val = f'-{ref0}'
-
             case 3 | -3:
                 # Explicit quantization op, done implicitly via assignment
                 val = ref0 if op.opcode == 3 else f'-{ref0}'
-
             case 4:
                 # Constant addition
                 _number = op.data * op.qint.step
@@ -71,10 +72,20 @@ def ssa_gen(ops: list[Op], print_latency: bool, typestr_fn: Callable[[bool | int
                 f = _const_f(mag)
                 const_type_str = typestr_fn(*_minimal_kif(QInterval(mag, mag, 2.0**-f)))
                 val = f'{ref0} {sign} {const_type_str}({mag})'
-
             case 5:
+                # Define constant
                 _number = op.data * op.qint.step
                 val = f'{_number}'
+            case 6 | -6:
+                # MSB Mux
+                id_c = op.data & 0xFFFFFFFF
+                bw_k = sum(all_kifs[id_c])
+                shift = (op.data >> 32) & 0xFFFFFFFF
+                shift = shift if shift < 0x80000000 else shift - 0x100000000
+                ref_k = f'v{id_c}[{bw_k - 1}]'
+                sign = '-' if op.opcode == -6 else ''
+                ref1 = f'v{op.id1}' if shift == 0 else f'bit_shift<{shift}>(v{op.id1})'
+                val = f'{ref_k} ? {_type}({ref0}) : {_type}({sign}{ref1})'
 
             case _:
                 raise ValueError(f'Unsupported opcode: {op.opcode}')
@@ -103,6 +114,15 @@ def output_gen(sol: Solution, typestr_fn: Callable[[bool | int, int, int], str])
     return lines
 
 
+def get_io_types(sol: Solution, flavor: str):
+    typestr_fn = get_typestr_fn(flavor)
+    in_kif = map(max, zip(*map(_minimal_kif, sol.inp_qint)))
+    inp_type = typestr_fn(*in_kif)
+    out_kif = map(max, zip(*map(_minimal_kif, sol.out_qint)))
+    out_type = typestr_fn(*out_kif)
+    return inp_type, out_type
+
+
 def cpp_logic_and_bridge_gen(
     sol: Solution,
     fn_name: str,
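
Note: `get_io_types` reduces the per-element (k, i, f) widths with an element-wise max so that one type covers every input (and likewise every output). A small illustration with hypothetical widths:

from da4ml.codegen.cpp.cpp_codegen import kif_to_vitis_type

# Hypothetical per-channel (k, i, f) widths for two inputs:
kifs = [(1, 2, 3), (0, 4, 1)]
common = tuple(map(max, zip(*kifs)))   # element-wise max -> (1, 4, 3)
print(kif_to_vitis_type(*common))      # -> ap_fixed<8,5>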
@@ -113,36 +133,49 @@ def cpp_logic_and_bridge_gen(
     print_latency: bool = False,
 ):
     typestr_fn = get_typestr_fn(flavor)
-    in_kif = map(max, zip(*map(_minimal_kif, sol.inp_qint)))
-    inp_type = typestr_fn(*in_kif)
-    out_kif = map(max, zip(*map(_minimal_kif, sol.out_qint)))
-    out_type = typestr_fn(*out_kif)
+    inp_t, out_t = get_io_types(sol, flavor)
 
     n_in, n_out = sol.shape
     template_def = 'template <typename inp_t, typename out_t>'
     fn_signature = f'void {fn_name}(inp_t inp[{n_in}], out_t out[{n_out}])'
     pragmas = pragmas or []
 
-    ssa_lines = ssa_gen(sol.ops, print_latency=print_latency, typestr_fn=typestr_fn)
+    ssa_lines = ssa_gen(sol, print_latency=print_latency, typestr_fn=typestr_fn)
     output_lines = output_gen(sol, typestr_fn=typestr_fn)
 
     indent = ' ' * n_indent
     base_indent = indent * n_base_indent
     body_indent = '\n' + base_indent + indent
     code = f"""{base_indent}{template_def}
-{base_indent}{fn_signature} {{ // {inp_type} -> {out_type}
-{body_indent}{body_indent.join(pragmas)}
+{base_indent}{fn_signature} {{ // {inp_t} -> {out_t}
+{base_indent + indent}{body_indent.join(pragmas)}
 {body_indent}{body_indent.join(ssa_lines)}
 {body_indent}{body_indent.join(output_lines)}
 {base_indent}}}
 """
-    bridge = f"""#include "bridge.h"
-#include "fn.h"
+    bridge = f"""#include "binder_util.hh"
+#include "{fn_name}.hh"
+
+struct {fn_name}_config {{
+    static const size_t N_inp = {n_in};
+    static const size_t N_out = {n_out};
+    typedef {inp_t} inp_t;
+    typedef {out_t} out_t;
+    constexpr static auto f = {fn_name}<inp_t, out_t>;
+}};
 
 extern "C" {{
-void bridge(double *inp, double *out, int size) {{
-    auto fn = {fn_name}<{inp_type}, {out_type}>;
-    vitis_bridge<{inp_type}, {out_type}, {n_in}, {n_out}>(fn, inp, out, size);
+
+bool openmp_enabled() {{
+    return _openmp;
+}}
+
+void inference_f64(double *inp, double *out, size_t size) {{
+    batch_inference<{fn_name}_config, double>(inp, out, size);
+}}
+
+void inference_f32(float *inp, float *out, size_t size) {{
+    batch_inference<{fn_name}_config, float>(inp, out, size);
 }}
 }}"""
     return code, bridge
da4ml/codegen/cpp/hls_model.py ADDED
@@ -0,0 +1,252 @@
+import ctypes
+import os
+import re
+import shutil
+import subprocess
+import sys
+from collections.abc import Sequence
+from pathlib import Path
+from typing import TypeVar
+from uuid import uuid4
+
+import numpy as np
+from numpy.typing import NDArray
+
+from da4ml.cmvm.types import Solution
+from da4ml.codegen.cpp.cpp_codegen import cpp_logic_and_bridge_gen, get_io_types
+
+from ... import codegen
+from ...cmvm.types import _minimal_kif
+
+T = TypeVar('T', bound=np.floating)
+
+
+class HLSModel:
+    def __init__(
+        self,
+        solution: Solution,
+        prj_name: str,
+        path: str | Path,
+        flavor: str = 'vitis',
+        print_latency: bool = True,
+        part_name: str = 'xcvu13p-flga2577-2-e',
+        pragma: Sequence[str] | None = None,
+        clock_period: int = 5,
+        clock_uncertainty: float = 0.1,
+        io_delay_minmax: tuple[float, float] = (0.2, 0.4),
+    ):
+        self._solution = solution
+        self._prj_name = prj_name
+        self._path = Path(path)
+        self._flavor = flavor.lower()
+        assert self._flavor in ('vitis', 'hlslib'), f'Unsupported HLS flavor: {self._flavor}'
+        self._print_latency = print_latency
+        self._part_name = part_name
+        self._clock_period = clock_period
+        self._clock_uncertainty = clock_uncertainty
+        self._io_delay_minmax = io_delay_minmax
+        self.__src_root = Path(codegen.__file__).parent
+        self._lib = None
+        self._uuid = None
+
+        if pragma is None:
+            if self._flavor == 'vitis':
+                self._pragma = (
+                    '#pragma HLS ARRAY_PARTITION variable=inp complete',
+                    '#pragma HLS ARRAY_PARTITION variable=out complete',
+                    '#pragma HLS PIPELINE II=1',
+                )
+            else:
+                self._pragma = ()
+        else:
+            self._pragma = tuple(pragma)
+
+    def write(self):
+        if not self._path.exists():
+            self._path.mkdir(parents=True, exist_ok=True)
+        template_def, bridge = cpp_logic_and_bridge_gen(
+            self._solution,
+            self._prj_name,
+            self._flavor,
+            ['#pragma HLS INLINE'],
+            4,
+            0,
+            self._print_latency,
+        )
+
+        headers = ['#pragma once', '#include "bitshift.hh"']
+
+        inp_type, out_type = get_io_types(self._solution, self._flavor)
+        n_in, n_out = len(self._solution.inp_qint), len(self._solution.out_qint)
+        template_signature = (
+            f'template <typename inp_t, typename out_t>\nvoid {self._prj_name}(inp_t inp[{n_in}], out_t out[{n_out}]);'
+        )
+        fn_signature = f'void {self._prj_name}_fn({inp_type} inp[{n_in}], {out_type} out[{n_out}])'
+
+        with open(self._path / f'{self._prj_name}.hh', 'w') as f:
+            f.write('\n'.join(headers) + '\n\n')
+            f.write(f'{template_signature}\n\n{fn_signature};\n')
+
+        pragma_str = '\n'.join(self._pragma)
+        cpp_def = f"""
+#include "{self._prj_name}.hh"
+
+{template_def}
+
+{fn_signature} {{
+{pragma_str}
+    {self._prj_name}<{inp_type}, {out_type}>(inp, out);
+}}
+"""
+        with open(self._path / f'{self._prj_name}.cc', 'w') as f:
+            f.write(cpp_def)
+
+        with open(self._path / f'{self._prj_name}_bridge.cc', 'w') as f:
+            f.write(bridge)
+
+        shutil.copy(self.__src_root / 'cpp/source/binder_util.hh', self._path)
+        shutil.copy(self.__src_root / f'cpp/source/{self._flavor}_bitshift.hh', self._path / 'bitshift.hh')
+        shutil.copy(self.__src_root / 'cpp/source/build_binder.mk', self._path)
+        if self._flavor == 'vitis':
+            shutil.copytree(self.__src_root / 'cpp/source/ap_types', self._path / 'ap_types', dirs_exist_ok=True)
+        else:
+            pass
+
+        self._solution.save(self._path / 'project.json')
+
+    def _compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
+        """Same as compile, but without writing the project files first.
+
+        Parameters
+        ----------
+        verbose : bool, optional
+            Verbose output, by default False
+        openmp : bool, optional
+            Enable OpenMP, by default True
+        o3 : bool, optional
+            Turn on the -O3 flag, by default False
+        clean : bool, optional
+            Remove obsolete shared object files, by default True
+
+        Raises
+        ------
+        RuntimeError
+            If compilation fails
+        """
+
+        self._uuid = str(uuid4())
+        args = ['make', '-f', 'build_binder.mk']
+        env = os.environ.copy()
+        env['PRJ_NAME'] = self._prj_name
+        env['STAMP'] = self._uuid
+        env['EXTRA_CXXFLAGS'] = '-fopenmp' if openmp else ''
+        if o3:
+            args.append('fast')
+
+        if clean:
+            m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
+            for p in self._path.iterdir():
+                if not p.is_dir() and m.match(p.name):
+                    p.unlink()
+
+        try:
+            r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
+        except subprocess.CalledProcessError as e:
+            print(e.stderr.decode(), file=sys.stderr)
+            print(e.stdout.decode(), file=sys.stdout)
+            raise RuntimeError('Compilation failed!!') from e
+        if r.returncode != 0:
+            print(r.stderr.decode(), file=sys.stderr)
+            print(r.stdout.decode(), file=sys.stderr)
+            raise RuntimeError('Compilation failed!!')
+
+        self._load_lib(self._uuid)
+
+    def _load_lib(self, uuid: str | None = None):
+        uuid = uuid if uuid is not None else self._uuid
+        self._uuid = uuid
+        lib_path = self._path / f'lib{self._prj_name}_{uuid}.so'
+        if not lib_path.exists():
+            raise RuntimeError(f'Library {lib_path} does not exist')
+        self._lib = ctypes.CDLL(str(lib_path))
+
+    def compile(self, verbose=False, openmp=True, o3: bool = False, clean=True):
+        """Compile the model to a shared object file.
+
+        Parameters
+        ----------
+        verbose : bool, optional
+            Verbose output, by default False
+        openmp : bool, optional
+            Enable OpenMP, by default True
+        o3 : bool, optional
+            Turn on the -O3 flag, by default False
+        clean : bool, optional
+            Remove obsolete shared object files, by default True
+
+        Raises
+        ------
+        RuntimeError
+            If compilation fails
+        """
+        self.write()
+        self._compile(verbose, openmp, o3, clean)
+
+    def predict(self, data: NDArray[T]) -> NDArray[T]:
+        """Run the model on the input data.
+
+        Parameters
+        ----------
+        data : NDArray[np.floating]
+            Input data to the model. The shape is ignored, and the number of samples is
+            determined by the size of the data.
+
+        Returns
+        -------
+        NDArray[np.floating]
+            Output of the model in shape (n_samples, output_size).
+        """
+        assert self._lib is not None, 'Library not loaded, call .compile() first.'
+        inp_size, out_size = self._solution.shape
+
+        dtype = data.dtype
+        if dtype not in (np.float32, np.float64):
+            raise TypeError(f'Unsupported input data type: {dtype}. Expected float32 or float64.')
+        c_dtype = ctypes.c_float if dtype == np.float32 else ctypes.c_double
+
+        assert data.size % inp_size == 0, f'Input size {data.size} is not divisible by {inp_size}'
+        n_sample = data.size // inp_size
+
+        inp_data = np.ascontiguousarray(data)
+        out_data = np.empty(n_sample * out_size, dtype=dtype)
+
+        inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(c_dtype))
+        out_buf = out_data.ctypes.data_as(ctypes.POINTER(c_dtype))
+        if dtype == np.float32:
+            self._lib.inference_f32(inp_buf, out_buf, n_sample)
+        else:
+            self._lib.inference_f64(inp_buf, out_buf, n_sample)
+
+        return out_data.reshape(n_sample, out_size)  # type: ignore
+
+    def __repr__(self):
+        inp_size, out_size = self._solution.shape
+        cost = round(self._solution.cost)
+        inp_kifs = tuple(zip(*map(_minimal_kif, self._solution.inp_qint)))
+        out_kifs = tuple(zip(*map(_minimal_kif, self._solution.out_qint)))
+        in_bits, out_bits = np.sum(inp_kifs), np.sum(out_kifs)
+
+        spec = f"""Top Function: {self._prj_name}\n====================
+{inp_size} ({in_bits} bits) -> {out_size} ({out_bits} bits)
+combinational @ delay={self._solution.latency}
+Estimated cost: {cost} LUTs"""
+
+        is_compiled = self._lib is not None
+        if is_compiled:
+            assert self._uuid is not None
+            openmp = 'with OpenMP' if self._lib.openmp_enabled() else ''  # type: ignore
+            spec += f'\nEmulator is compiled {openmp} ({self._uuid[-12:]})'
+        else:
+            spec += '\nEmulator is **not compiled**'
+        return spec
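
Note: a minimal usage sketch of the new HLSModel emulator flow; `solution` is assumed to be a `Solution` produced by the da4ml solver (e.g. via da4ml.cmvm.api.solve), and the project name and path are arbitrary:

import numpy as np
from da4ml.codegen import HLSModel

# Emit the Vitis HLS project and build the ctypes emulator binder.
model = HLSModel(solution, prj_name='mvm', path='prj_mvm', flavor='vitis')
model.compile()               # write() + make -f build_binder.mk, then dlopen

n_in, n_out = solution.shape
x = np.random.rand(1024, n_in).astype(np.float32)
y = model.predict(x)          # shape (1024, n_out); dispatches to inference_f32
print(model)                  # cost/latency summary and emulator status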