da4ml 0.5.1.post1__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +204 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +246 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.1.post1.dist-info/METADATA +85 -0
  92. da4ml-0.5.1.post1.dist-info/RECORD +96 -0
  93. da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
  94. da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from . import cmvm, codegen, converter, trace, typing
2
+ from ._version import *
3
+
4
+ __all__ = ['cmvm', 'codegen', 'converter', 'trace', 'typing']
@@ -0,0 +1,15 @@
1
+ import numpy as np
2
+ from numpy.typing import NDArray
3
+
4
+ from .dais_bin import run_interp
5
+
6
+
7
def dais_interp_run(bin_logic: NDArray[np.int32], data: NDArray, n_threads: int = 1):
    """Flatten the inputs and hand them to the compiled ``run_interp`` routine.

    Element 2 of ``bin_logic`` is read as the per-sample input size; the total
    number of elements in ``data`` must be a multiple of it.

    Args:
        bin_logic: Integer program passed to the interpreter (converted to a
            contiguous, flat ``int32`` array).
        data: Input values (converted to a contiguous, flat ``float64`` array).
        n_threads: Forwarded unchanged to ``run_interp``.

    Returns:
        Whatever ``run_interp`` returns for the flattened arguments.
    """
    inp_size = int(bin_logic[2])

    assert data.size % inp_size == 0, f'Input size {data.size} is not divisible by {inp_size}'

    flat_logic = np.ascontiguousarray(np.ravel(bin_logic), dtype=np.int32)
    flat_inputs = np.ascontiguousarray(np.ravel(data), dtype=np.float64)
    return run_interp(flat_logic, flat_inputs, n_threads)
@@ -0,0 +1,5 @@
1
+ import numpy
2
+ from numpy.typing import NDArray
3
+
4
+
5
def run_interp(
    bin_logic: NDArray[numpy.int32],
    data: NDArray[numpy.float64],
    n_threads: int = 1,
) -> NDArray[numpy.float64]:
    """Type stub for ``run_interp`` from the compiled ``dais_bin`` extension."""
    ...
da4ml/_cli/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ import argparse
2
+
3
+ from .. import _version
4
+ from .convert import _add_convert_args, convert_main
5
+ from .report import _add_report_args, report_main
6
+
7
+
8
def main():
    """Entry point of the da4ml command line interface.

    Registers the ``convert`` and ``report`` sub-commands, parses ``sys.argv``
    and dispatches to the matching handler.  When no sub-command is given the
    help text is printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description='Welcome to the da4ml command line interface')
    subparsers = parser.add_subparsers(dest='command')

    convert_parser = subparsers.add_parser('convert', help='Convert a Keras model to RTL project')
    report_parser = subparsers.add_parser('report', help='Generate report from an existing RTL projects')
    _add_convert_args(convert_parser)
    _add_report_args(report_parser)
    parser.add_argument('--version', '-v', action='version', version=f'%(prog)s {_version.__version__}')
    args = parser.parse_args()

    handlers = {'convert': convert_main, 'report': report_main}
    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
        # The ``exit`` builtin is injected by ``site`` and is not guaranteed to
        # exist (e.g. ``python -S``); raising SystemExit always works.
        raise SystemExit(1)
    handler(args)


if __name__ == '__main__':
    main()
da4ml/_cli/convert.py ADDED
@@ -0,0 +1,204 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+
7
+
8
def to_da4ml(
    model_path: Path,
    path: Path,
    n_test_sample: int,
    period: float,
    unc: float,
    flavor: str,
    latency_cutoff: float,
    part_name: str,
    verbose: int = 1,
    rtl_validation: bool = False,
    hwconf: tuple[int, int, int] = (1, -1, -1),
    hard_dc: int = 2,
    openmp: bool = True,
    n_threads: int = 4,
    metadata=None,
    inputs_kif: tuple[int, int, int] | None = None,
):
    """Convert a model file into an RTL project and optionally validate it.

    Args:
        model_path: ``.h5``/``.keras`` Keras model, or a ``.json`` file loaded
            with ``CombLogic.load``.
        path: Output directory of the RTL project; ``mismatches.json`` is also
            written there when a Keras model is compared.
        n_test_sample: Number of random samples used for validation; a falsy
            value skips every validation step.
        period: Clock period handed to ``RTLModel``.
        unc: Clock uncertainty in percent (divided by 100 before use).
        flavor: RTL flavor handed to ``RTLModel``.
        latency_cutoff: Latency cutoff handed to ``RTLModel``.
        part_name: Part name handed to ``RTLModel``.
        verbose: ``0`` silences the success messages, values above ``1`` add
            the model summary, tracing output and the RTL model dump.
        rtl_validation: When true, compile the RTL model and compare its
            predictions with the ``CombLogic`` predictions.
        hwconf: Positional arguments of ``HWConfig`` used while tracing.
        hard_dc: Value of the ``hard_dc`` solver option used while tracing.
        openmp: Forwarded to ``rtl_model._compile``.
        n_threads: Thread count for prediction and compilation.
        metadata: Object forwarded to ``rtl_model.write``.
        inputs_kif: Optional input precision forwarded to ``trace_model``.

    Raises:
        ValueError: If ``model_path`` has an unsupported suffix.
        RuntimeError: If the RTL model cannot be compiled in three attempts, or
            if the RTL predictions differ from the ``CombLogic`` predictions.
    """
    from da4ml.cmvm.types import CombLogic
    from da4ml.codegen import RTLModel
    from da4ml.converter import trace_model
    from da4ml.trace import HWConfig, comb_trace

    if model_path.suffix in {'.h5', '.keras'}:
        # hgq is imported only for its import-time side effects
        # (presumably so Keras can deserialize its layers -- confirm).
        import hgq  # noqa: F401
        import keras

        model: keras.Model = keras.models.load_model(model_path, compile=False)  # type: ignore
        if verbose > 1:
            model.summary()
        inp, out = trace_model(model, HWConfig(*hwconf), {'hard_dc': hard_dc}, verbose > 1, inputs_kif=inputs_kif)
        comb = comb_trace(inp, out)

    elif model_path.suffix == '.json':
        comb = CombLogic.load(model_path)
        model = None  # type: ignore

    else:
        raise ValueError(f'Unsupported model file format: {model_path}')

    rtl_model = RTLModel(
        comb,
        'model',
        path,
        flavor=flavor,
        latency_cutoff=latency_cutoff,
        print_latency=True,
        clock_uncertainty=unc / 100,
        clock_period=period,
        part_name=part_name,
    )
    rtl_model.write(metadata)
    if verbose > 1:
        print(rtl_model)
    print('Model written')
    if not n_test_sample:
        return

    if model is not None:
        # Uniform random inputs in [-32, 32), one array per model input.
        data_in = [np.random.rand(n_test_sample, *tensor.shape[1:]).astype(np.float32) * 64 - 32 for tensor in model.inputs]
        if len(data_in) == 1:
            data_in = data_in[0]
        y_keras = model.predict(data_in, batch_size=16384, verbose=0)  # type: ignore

        if isinstance(y_keras, list):
            y_keras = np.concatenate([y.reshape(n_test_sample, -1) for y in y_keras], axis=1)
        else:
            y_keras = y_keras.reshape(n_test_sample, -1)
        y_comb = comb.predict(data_in, n_threads=n_threads)

        total = y_comb.size
        mask = y_comb != y_keras
        ndiff = np.sum(mask)
        if ndiff:
            n_nonzero = np.sum(y_keras != 0)
            abs_diff = np.abs(y_comb - y_keras)[mask]
            # The small epsilon keeps the relative error finite for zero references.
            rel_diff = abs_diff / (np.abs(y_keras[mask]) + 1e-6)

            max_diff, max_rel_diff = np.max(abs_diff), np.max(rel_diff)
            mean_diff, mean_rel_diff = np.mean(abs_diff), np.mean(rel_diff)
            print(
                f'[WARNING] {ndiff}/{total} ({n_nonzero}) mismatches ({max_diff=}, {max_rel_diff=}, {mean_diff=}, {mean_rel_diff=})'
            )
        else:
            max_diff = max_rel_diff = mean_diff = mean_rel_diff = 0.0
            if verbose:
                print(f'[INFO] DAIS simulation passed: [0/{total}] mismatches.')
        with open(path / 'mismatches.json', 'w') as f:
            json.dump(
                {
                    'n_total': int(total),
                    'n_mismatch': int(ndiff),
                    'max_diff': float(max_diff),
                    'max_rel_diff': float(max_rel_diff),
                    'mean_diff': float(mean_diff),
                    'mean_rel_diff': float(mean_rel_diff),
                },
                f,
            )

    if not rtl_validation:
        return

    if model is None:
        # No Keras reference available: only the RTL comparison below applies.
        data_in = np.random.rand(n_test_sample, comb.shape[0]).astype(np.float32) * 64 - 32
        y_comb = comb.predict(data_in, n_threads=n_threads)
        total = y_comb.size

    if verbose > 1:
        print('Verilating...')
    last_error = None
    for _ in range(3):
        try:
            rtl_model._compile(nproc=n_threads, openmp=openmp)
        except RuntimeError as err:
            last_error = err
        else:
            break
    else:
        # Every attempt failed: stop here instead of predicting with a model
        # that was never compiled.
        raise RuntimeError('RTL model compilation failed after 3 attempts') from last_error

    y_da4ml = rtl_model.predict(data_in)
    if not np.all(y_comb == y_da4ml):
        raise RuntimeError(f'[CRITICAL ERROR] RTL validation failed: {np.sum(y_comb != y_da4ml)}/{total} mismatches!')
    if verbose:
        print(f'[INFO] RTL validation passed: [0/{total}] mismatches.')
130
+
131
+
132
def convert_main(args):
    """Run the ``convert`` sub-command from parsed command line arguments.

    Creates the output directory, loads the optional metadata JSON file and
    forwards every option to ``to_da4ml``.
    """
    args.outdir.mkdir(parents=True, exist_ok=True)
    hw_conf = tuple(args.hw_config)

    metadata = None
    if args.metadata is not None:
        with open(args.metadata) as f:
            metadata = json.load(f)

    options = {
        'latency_cutoff': args.latency_cutoff,
        'part_name': args.part_name,
        'flavor': args.flavor,
        'verbose': args.verbose,
        'rtl_validation': args.validate_rtl,
        'hwconf': hw_conf,
        'hard_dc': args.delay_constraint,
        'openmp': not args.no_openmp,
        'n_threads': args.n_threads,
        'metadata': metadata,
        'inputs_kif': args.inputs_kif,
    }
    to_da4ml(
        args.model,
        args.outdir,
        args.n_test_sample,
        args.clock_period,
        args.clock_uncertainty,
        **options,
    )
159
+
160
+
161
def _add_convert_args(parser: argparse.ArgumentParser):
    """Register the arguments of the ``convert`` command on ``parser``."""
    parser.add_argument(
        'model',
        type=Path,
        # ``.json`` is accepted as well: ``to_da4ml`` loads it with ``CombLogic.load``.
        help='Path to the model file: Keras model (.h5 or .keras) or saved CombLogic (.json)',
    )
    parser.add_argument('outdir', type=Path, help='Output directory')
    parser.add_argument('--n-test-sample', '-n', type=int, default=131072, help='Number of test samples for validation')
    parser.add_argument('--clock-period', '-c', type=float, default=5.0, help='Clock period in ns')
    parser.add_argument('--clock-uncertainty', '-unc', type=float, default=10.0, help='Clock uncertainty in percent')
    parser.add_argument('--flavor', type=str, default='verilog', help='Flavor for DA4ML (verilog/vhdl)')
    parser.add_argument('--latency-cutoff', '-lc', type=float, default=5, help='Latency cutoff for pipelining')
    parser.add_argument('--part-name', '-p', type=str, default='xcvu13p-flga2577-2-e', help='FPGA part name')
    parser.add_argument('--verbose', '-v', default=1, type=int, help='Set verbosity level (0: silent, 1: info, 2: debug)')
    parser.add_argument('--validate-rtl', '-vr', action='store_true', help='Validate RTL by Verilator (and GHDL)')
    parser.add_argument('--n-threads', '-j', type=int, default=4, help='Number of threads for compilation and DAIS simulation')
    parser.add_argument('--metadata', '-meta', type=str, default=None, help='Path to metadata JSON file to be included')
    parser.add_argument(
        '--hw-config',
        '-hc',
        type=int,
        nargs=3,
        metavar=('ACCUM_SIZE', 'ADDER_SIZE', 'CUTOFF'),
        default=[1, -1, -1],
        help='Size of accumulator and adder, and cutoff threshold during tracing. No need to modify unless you know what you are doing.',
    )
    parser.add_argument('--delay-constraint', '-dc', type=int, default=2, help='Delay constraint for each CMVM block')
    parser.add_argument(
        '--no-openmp',
        '--no-omp',
        action='store_true',
        help='Disable OpenMP in RTL simulation; no effect if --validate-rtl is not set',
    )
    parser.add_argument(
        '--inputs-kif',
        '-ikif',
        type=int,
        nargs=3,
        default=None,
        help='Input precision in KIF format (keep_neg, int bits, frac bits), if known.',
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert Keras model to da4ml RTL model with random input test vectors')
    _add_convert_args(parser)
    args = parser.parse_args()
    convert_main(args)
da4ml/_cli/report.py ADDED
@@ -0,0 +1,295 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ from math import ceil, log10
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+
10
def parse_timing_summary(timing_summary: str):
    """Extract the "Design Timing Summary" table from a timing report.

    The table is expected within lines 3-9 after the section title as a header
    row, a dash-only separator row and one row of values.

    Args:
        timing_summary: Full text of the timing report.

    Returns:
        Dict mapping each column title to its value; values containing a ``.``
        are parsed as ``float``, all others as ``int``.

    Raises:
        ValueError: If the section is missing, the table layout is not the
            expected one, or a value is not numeric.
    """
    loc0 = timing_summary.find('Design Timing Summary')
    if loc0 < 0:
        raise ValueError("'Design Timing Summary' section not found in timing report")

    lines = [line for line in timing_summary[loc0:].split('\n')[3:10] if line.strip()]
    if len(lines) < 3 or set(lines[1]) != {' ', '-'}:
        raise ValueError('Unexpected layout of the Design Timing Summary table')

    # Column titles may contain single spaces (e.g. "TNS Failing Endpoints"),
    # so only runs of two or more whitespace characters separate columns.
    keys = re.split(r'\s{2,}', lines[0].strip())
    vals = [float(v) if '.' in v else int(v) for v in lines[2].split()]
    if len(keys) != len(vals):
        raise ValueError(f'Timing summary has {len(keys)} columns but {len(vals)} values')
    return dict(zip(keys, vals))
21
+
22
+
23
# Row names looked up in the utilization report table.
track = [
    'DSPs',
    'LUT as Logic',
    'LUT as Memory',
    'CLB Registers',
    'CARRY8',
    'Register as Latch',
    'Register as Flip Flop',
    'RAMB18',
    'URAM',
    'RAMB36/FIFO*',
]

# One compiled pattern per entry of ``track`` (same order), capturing the four
# integer columns that follow the row name.  The name is escaped so that regex
# metacharacters in it (the ``*`` of ``RAMB36/FIFO*``) are matched literally.
mms = [
    re.compile(
        rf'\|\s*{re.escape(name)}\s*\|\s*(?P<Used>\d+)\s*\|\s*(?P<Fixed>\d+)\s*\|\s*(?P<Prohibited>\d+)\s*\|\s*(?P<Available>\d+)\s*\|'
    )
    for name in track
]
42
+
43
+
44
def parse_utilization(utilization: str):
    """Parse the utilization report into a flat dict.

    For every row name in ``track`` the first matching table row is used and
    stored as ``name`` (used), ``name_fixed``, ``name_prohibited`` and
    ``name_available``.  The aggregated keys ``FF``, ``LUT``, ``DSP``,
    ``FF_available`` and ``LUT_available`` are added on top.

    Args:
        utilization: Full text of the utilization report.

    Returns:
        Dict mapping resource names to integer counts.

    Raises:
        IndexError: If a tracked row is not present in the report.
    """
    dd = {}
    for name, pattern in zip(track, mms):
        found = pattern.findall(utilization)
        if not found:
            # Same exception type as indexing the empty result would raise,
            # but naming the missing row.
            raise IndexError(f'{name} not found in utilization report')
        used, fixed, prohibited, available = map(int, found[0])
        dd[name] = used
        dd[f'{name}_fixed'] = fixed
        dd[f'{name}_prohibited'] = prohibited
        dd[f'{name}_available'] = available

    dd['FF'] = dd['Register as Flip Flop'] + dd['Register as Latch']
    dd['LUT'] = dd['LUT as Logic'] + dd['LUT as Memory']
    dd['LUT_available'] = max(dd['LUT as Logic_available'], dd['LUT as Memory_available'])
    dd['FF_available'] = max(dd['Register as Flip Flop_available'], dd['Register as Latch_available'])
    dd['DSP'] = dd['DSPs']

    return dd
66
+
67
+
68
+ def _load_project(path: str | Path) -> dict[str, Any]:
69
+ path = Path(path)
70
+ build_tcl_path = path / 'build_vivado_prj.tcl'
71
+ assert build_tcl_path.exists(), f'build_vivado_prj.tcl not found in {path}'
72
+ top_name = build_tcl_path.read_text().split('"', 2)[1]
73
+
74
+ with open(path / f'src/{top_name}.xdc') as f:
75
+ target_clock_period = float(f.readline().strip().split()[2])
76
+ with open(path / 'metadata.json') as f:
77
+ metadata = json.load(f)
78
+
79
+ if metadata['flavor'] == 'vhdl':
80
+ with open(path / f'src/{top_name}.vhd') as f: # type: ignore
81
+ latency = f.read().count('register') // 2
82
+ else:
83
+ with open(path / f'src/{top_name}.v') as f: # type: ignore
84
+ latency = f.read().count('reg') - 1
85
+
86
+ d = {**metadata, 'clock_period': target_clock_period, 'latency': latency}
87
+
88
+ if (path / f'output_{top_name}/reports/{top_name}_post_route_util.rpt').exists():
89
+ with open(path / f'output_{top_name}/reports/{top_name}_post_route_util.rpt') as f:
90
+ util_rpt = f.read()
91
+ util = parse_utilization(util_rpt)
92
+
93
+ with open(path / f'output_{top_name}/reports/{top_name}_post_route_timing.rpt') as f:
94
+ timing_rpt = f.read()
95
+ timing = parse_timing_summary(timing_rpt)
96
+ d.update(timing)
97
+ d.update(util)
98
+
99
+ d['actual_period'] = d['clock_period'] - d['WNS(ns)']
100
+ d['Fmax(MHz)'] = 1000.0 / d['actual_period']
101
+ d['latency(ns)'] = d['latency'] * d['actual_period']
102
+
103
+ return d
104
+
105
+
106
def load_project(path: str | Path) -> dict[str, Any] | None:
    """Best-effort wrapper around ``_load_project``.

    Any exception raised while loading is printed and turned into ``None``.
    """
    try:
        project = _load_project(path)
    except Exception as e:
        print(e)
        project = None
    return project
112
+
113
+
114
def extra_info_from_fname(fname: str):
    """Collect ``key=value`` pairs from a ``-``-separated name.

    Parts without ``=`` are ignored.  Each value is converted to ``int`` if
    possible, otherwise to ``float``, otherwise kept as a string.
    """
    info = {}
    for token in fname.split('-'):
        key, sep, raw = token.partition('=')
        if not sep:
            continue
        for cast in (int, float):
            try:
                info[key] = cast(raw)
            except ValueError:
                continue
            break
        else:
            # Neither numeric conversion succeeded.
            info[key] = raw
    return info
134
+
135
+
136
def pretty_print(arr: list[list]):
    """Print ``arr`` as a text table; the first row is the header.

    Column widths follow the widest cell (floats count for at most six
    characters) and are capped when the table would not fit the terminal.
    Floats that do not fit their column are printed with fewer decimals; all
    other cells are left-aligned and truncated to the column width.

    Args:
        arr: Rows of the table, ``arr[0]`` being the header row.
    """
    import shutil

    n_cols = len(arr[0])
    # Unlike os.get_terminal_size(), this falls back to a default size instead
    # of raising OSError when stdout is not a terminal (pipe, file, CI log).
    terminal_width = shutil.get_terminal_size().columns
    default_width = [
        max(min(6, len(str(row[j]))) if isinstance(row[j], float) else len(str(row[j])) for row in arr)
        for j in range(n_cols)
    ]
    if sum(default_width) + 2 * n_cols + 1 <= terminal_width:
        col_width = default_width
    else:
        th = max(8, (terminal_width - 2 * n_cols - 1) // n_cols)
        col_width = [min(w, th) for w in default_width]

    def _format_cell(v, w):
        """Render one body cell for a column of width ``w``."""
        if type(v) is not float:
            return str(v).ljust(w)[:w]
        # Number of characters taken by the integer part of the value.
        n_int = ceil(log10(abs(v) + 1)) if v != 0 else 1 + (v < 0)
        v = round(v, 10 - n_int)
        text = str(v)
        if len(text) > w:
            # Too wide: keep only as many decimals as the column can hold.
            text = f'{v:.{max(w - n_int - 1, 0)}f}'
        return text.ljust(w)

    header = [
        '| ' + ' | '.join(f'{str(arr[0][i]).ljust(col_width[i])[: col_width[i]]}' for i in range(n_cols)) + ' |',
        '|-' + '-|-'.join('-' * col_width[i] for i in range(n_cols)) + '-|',
    ]
    content = [
        '| ' + ' | '.join(_format_cell(v, col_width[i]) for i, v in enumerate(row)) + ' |'
        for row in arr[1:]
    ]
    print('\n'.join(header + content))
176
+
177
+
178
+ def stdout_print(arr: list[list], full: bool, columns: list[str] | None):
179
+ whitelist = [
180
+ 'epoch',
181
+ 'flavor',
182
+ 'actual_period',
183
+ 'clock_period',
184
+ 'ebops',
185
+ 'cost',
186
+ 'latency',
187
+ 'DSP',
188
+ 'LUT',
189
+ 'FF',
190
+ 'comb_metric',
191
+ 'Fmax(MHz)',
192
+ 'latency(ns)',
193
+ ]
194
+ if columns is None:
195
+ columns = whitelist
196
+
197
+ if not full:
198
+ idx_row = arr[0]
199
+ keep_cols = [idx_row.index(col) for col in columns if col in idx_row]
200
+ arr = [[row[i] for i in keep_cols] for row in arr]
201
+
202
+ if len(arr) == 2: # One sample
203
+ k_width = max(len(str(h)) for h in arr[0])
204
+ for k, v in zip(arr[0], arr[1]):
205
+ print(f'{str(k).ljust(k_width)} : {v}')
206
+ else:
207
+ pretty_print(arr)
208
+
209
+
210
def report_main(args):
    """Run the ``report`` sub-command from parsed command line arguments.

    Loads every project in ``args.paths`` (projects that fail to load are
    skipped), fills in ``key=value`` pairs found in the directory name without
    overriding loaded values, sorts by ``args.sort_by`` in descending order and
    writes the table to stdout or to ``args.output``.

    Raises:
        ValueError: If the output file extension is not one of ``.json``,
            ``.csv``, ``.tsv``, ``.md`` or ``.html``.
    """
    vals = []
    for path in args.paths:
        val = load_project(Path(path))
        if val is None:
            continue
        # Pair each project with its own path here; zipping the paths with the
        # filtered list would shift the pairing after a failed load.
        for k, v in extra_info_from_fname(Path(path).name).items():
            val.setdefault(k, v)
        vals.append(val)

    # Descending order; projects without the sort key are treated as +inf.
    vals.sort(key=lambda x: -x.get(args.sort_by, float('inf')))

    _attrs: set[str] = set()
    for v in vals:
        _attrs.update(v.keys())
    attrs = sorted(_attrs)
    arr: list[list] = [attrs]
    for v in vals:
        arr.append([v.get(a, '') for a in attrs])

    output = args.output
    if output == 'stdout':
        stdout_print(arr, args.full, args.columns)
        return

    ext = Path(output).suffix
    if ext not in ('.json', '.tsv', '.csv', '.md', '.html'):
        # Checked before opening so that no empty file is left behind.
        raise ValueError(f'Unsupported output format: {ext}')

    with open(output, 'w') as f:
        if ext == '.json':
            json.dump(vals, f)
        elif ext in ('.tsv', '.csv'):
            sep = ',' if ext == '.csv' else '\t'
            # CSV cells containing the separator are wrapped in double quotes.
            op = (lambda x: str(x) if ',' not in str(x) else f'"{str(x)}"') if ext == '.csv' else lambda x: str(x)
            for row in arr:
                f.write(sep.join(map(op, row)) + '\n')  # type: ignore
        elif ext == '.md':
            f.write('| ' + ' | '.join(map(str, arr[0])) + ' |\n')
            f.write('|' + '|'.join(['---'] * len(arr[0])) + '|\n')
            for row in arr[1:]:
                f.write('| ' + ' | '.join(map(str, row)) + ' |\n')
        else:  # '.html'
            f.write('<table>\n')
            f.write(' <tr>' + ''.join([f'<th>{a}</th>' for a in arr[0]]) + '</tr>\n')
            for row in arr[1:]:
                f.write(' <tr>' + ''.join([f'<td>{a}</td>' for a in row]) + '</tr>\n')
            f.write('</table>\n')
257
+
258
+
259
def _add_report_args(parser: argparse.ArgumentParser):
    """Register the arguments of the ``report`` command on ``parser``."""
    parser.add_argument('paths', type=str, nargs='+', help='Paths to the directories containing HDL summaries')
    parser.add_argument(
        '--output',
        '-o',
        type=str,
        default='stdout',
        help='Output file name for the summary. Can be stdout, .json, .csv, .tsv, .md, .html',
    )
    parser.add_argument(
        '--sort-by',
        '-s',
        type=str,
        default='comb_metric',
        # The help text used to name "cost", which is not the actual default.
        help='Attribute to sort the summary by. Default is comb_metric.',
    )
    parser.add_argument(
        '--full',
        '-f',
        action='store_true',
        help='Include full information for stdout output. For file output, all information will always be included.',
    )
    parser.add_argument(
        '--columns',
        '-c',
        type=str,
        nargs='+',
        default=None,
        help='Specify columns to include in the report. Only applicable for stdout output. Ignored if --full is set.',
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Load HDL summaries')
    _add_report_args(parser)
    args = parser.parse_args()
    report_main(args)
da4ml/_version.py ADDED
@@ -0,0 +1,32 @@
1
# Names exported by ``from ._version import *``.
__all__ = [
    '__version__',
    '__version_tuple__',
    'version',
    'version_tuple',
    '__commit_id__',
    'commit_id',
]

# Kept as a plain module constant; the precise aliases are only meant for
# static type checkers.
TYPE_CHECKING = False
if not TYPE_CHECKING:
    VERSION_TUPLE = object
    COMMIT_ID = object
else:
    VERSION_TUPLE = tuple[int | str, ...]
    COMMIT_ID = str | None

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__full_version = "0.5.1post1"

# The version is everything before the first '-'.
version = __full_version.partition('-')[0]
__version__ = version

# Dot-separated parts; purely numeric parts become ints.
version_tuple = tuple(int(part) if part.isdigit() else part for part in version.split('.'))
__version_tuple__ = version_tuple

# The commit id is everything after the last '-', if there is one.
commit_id = __full_version.rpartition('-')[2] if '-' in __full_version else None
__commit_id__ = commit_id
da4ml/cmvm/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .api import minimal_latency, solve
2
+ from .types import CombLogic, Op, QInterval
3
+
4
+ __all__ = ['minimal_latency', 'solve', 'QInterval', 'Op', 'CombLogic']