da4ml 0.4.1__py3-none-any.whl → 0.5.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml might be problematic. Click here for more details.
- da4ml/__init__.py +2 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +2 -2
- da4ml/cmvm/api.py +15 -4
- da4ml/cmvm/core/__init__.py +2 -2
- da4ml/cmvm/types.py +32 -18
- da4ml/cmvm/util/bit_decompose.py +2 -2
- da4ml/codegen/hls/hls_codegen.py +10 -5
- da4ml/codegen/hls/hls_model.py +7 -4
- da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/{build_prj.tcl → build_vivado_prj.tcl} +39 -18
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +11 -13
- da4ml/codegen/rtl/rtl_model.py +105 -54
- da4ml/codegen/rtl/verilog/__init__.py +2 -1
- da4ml/codegen/rtl/verilog/comb.py +47 -7
- da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
- da4ml/codegen/rtl/verilog/pipeline.py +12 -12
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/vhdl/comb.py +27 -21
- da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
- da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/converter/__init__.py +57 -1
- da4ml/converter/hgq2/parser.py +4 -25
- da4ml/converter/hgq2/replica.py +208 -22
- da4ml/trace/fixed_variable.py +239 -29
- da4ml/trace/fixed_variable_array.py +276 -48
- da4ml/trace/ops/__init__.py +31 -15
- da4ml/trace/ops/reduce_utils.py +3 -3
- da4ml/trace/pipeline.py +40 -18
- da4ml/trace/tracer.py +33 -8
- da4ml/typing/__init__.py +3 -0
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/METADATA +2 -1
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/RECORD +39 -35
- da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/WHEEL +0 -0
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/top_level.txt +0 -0
da4ml/codegen/rtl/rtl_model.py
CHANGED
|
@@ -1,30 +1,45 @@
|
|
|
1
1
|
import ctypes
|
|
2
|
+
import json
|
|
2
3
|
import os
|
|
3
4
|
import re
|
|
4
5
|
import shutil
|
|
5
6
|
import subprocess
|
|
6
7
|
import sys
|
|
8
|
+
from collections.abc import Sequence
|
|
7
9
|
from pathlib import Path
|
|
8
10
|
from uuid import uuid4
|
|
9
11
|
|
|
10
12
|
import numpy as np
|
|
11
13
|
from numpy.typing import NDArray
|
|
12
14
|
|
|
13
|
-
from ...cmvm.types import
|
|
15
|
+
from ...cmvm.types import CombLogic, Pipeline, _minimal_kif
|
|
14
16
|
from ...trace.pipeline import to_pipeline
|
|
15
17
|
from .. import rtl
|
|
16
18
|
|
|
17
19
|
|
|
18
|
-
def get_io_kifs(sol:
|
|
20
|
+
def get_io_kifs(sol: CombLogic | Pipeline):
    """Collect the per-element `_minimal_kif` results for a solution's inputs and outputs.

    Parameters
    ----------
    sol : CombLogic | Pipeline
        Object exposing `inp_qint` and `out_qint` iterables.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Two int8 arrays (inputs first, outputs second). Each array is the
        transpose of the sequence of `_minimal_kif` results, i.e. one row per
        component returned by `_minimal_kif`.
        NOTE(review): presumably the rows are (keep_negative, integer, fractional)
        bit counts - confirm against `_minimal_kif`.
    """
    per_input = map(_minimal_kif, sol.inp_qint)
    inp_columns = tuple(zip(*per_input))

    per_output = map(_minimal_kif, sol.out_qint)
    out_columns = tuple(zip(*per_output))

    inp_arr = np.array(inp_columns, np.int8)
    out_arr = np.array(out_columns, np.int8)
    return inp_arr, out_arr
|
|
22
24
|
|
|
23
25
|
|
|
26
|
+
class at_path:
|
|
27
|
+
def __init__(self, path: str | Path):
|
|
28
|
+
self._path = Path(path)
|
|
29
|
+
self._orig_cwd = None
|
|
30
|
+
|
|
31
|
+
def __enter__(self):
|
|
32
|
+
self._orig_cwd = Path.cwd()
|
|
33
|
+
os.chdir(self._path)
|
|
34
|
+
|
|
35
|
+
def __exit__(self, exc_type, exc_value, traceback):
|
|
36
|
+
os.chdir(self._orig_cwd) # type: ignore
|
|
37
|
+
|
|
38
|
+
|
|
24
39
|
class RTLModel:
|
|
25
40
|
def __init__(
|
|
26
41
|
self,
|
|
27
|
-
solution:
|
|
42
|
+
solution: CombLogic | Pipeline,
|
|
28
43
|
prj_name: str,
|
|
29
44
|
path: str | Path,
|
|
30
45
|
flavor: str = 'verilog',
|
|
@@ -51,9 +66,9 @@ class RTLModel:
|
|
|
51
66
|
|
|
52
67
|
assert self._flavor in ('vhdl', 'verilog'), f'Unsupported flavor {flavor}, only vhdl and verilog are supported.'
|
|
53
68
|
|
|
54
|
-
self._pipe = solution if isinstance(solution,
|
|
69
|
+
self._pipe = solution if isinstance(solution, Pipeline) else None
|
|
55
70
|
if latency_cutoff > 0 and self._pipe is None:
|
|
56
|
-
assert isinstance(solution,
|
|
71
|
+
assert isinstance(solution, CombLogic)
|
|
57
72
|
self._pipe = to_pipeline(solution, latency_cutoff, verbose=False)
|
|
58
73
|
|
|
59
74
|
if self._pipe is not None:
|
|
@@ -72,33 +87,46 @@ class RTLModel:
|
|
|
72
87
|
else: # verilog
|
|
73
88
|
from .verilog import binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_logic_gen
|
|
74
89
|
|
|
75
|
-
|
|
90
|
+
from .verilog.comb import table_mem_gen
|
|
91
|
+
|
|
92
|
+
(self._path / 'src/static').mkdir(parents=True, exist_ok=True)
|
|
93
|
+
(self._path / 'sim').mkdir(exist_ok=True)
|
|
94
|
+
(self._path / 'model').mkdir(exist_ok=True)
|
|
95
|
+
(self._path / 'src/memfiles').mkdir(exist_ok=True)
|
|
76
96
|
if self._pipe is not None: # Pipeline
|
|
77
97
|
# Main logic
|
|
78
98
|
codes = pipeline_logic_gen(self._pipe, self._prj_name, self._print_latency, register_layers=self._register_layers)
|
|
99
|
+
|
|
100
|
+
# Table memory files
|
|
101
|
+
memfiles: dict[str, str] = {}
|
|
102
|
+
for comb in self._pipe.solutions:
|
|
103
|
+
memfiles.update(table_mem_gen(comb))
|
|
104
|
+
|
|
79
105
|
for k, v in codes.items():
|
|
80
|
-
with open(self._path / f'{k}.{suffix}', 'w') as f:
|
|
106
|
+
with open(self._path / f'src/{k}.{suffix}', 'w') as f:
|
|
81
107
|
f.write(v)
|
|
82
108
|
|
|
83
|
-
# Build
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
109
|
+
# Build scripts
|
|
110
|
+
for path in (self.__src_root).glob('common_source/build_*_prj.tcl'):
|
|
111
|
+
with open(path) as f:
|
|
112
|
+
tcl = f.read()
|
|
113
|
+
tcl = tcl.replace('$::env(DEVICE)', self._part_name)
|
|
114
|
+
tcl = tcl.replace('$::env(PROJECT_NAME)', self._prj_name)
|
|
115
|
+
tcl = tcl.replace('$::env(SOURCE_TYPE)', flavor)
|
|
116
|
+
with open(self._path / path.name, 'w') as f:
|
|
117
|
+
f.write(tcl)
|
|
118
|
+
|
|
119
|
+
# Timing constraint
|
|
120
|
+
for fmt in ('xdc', 'sdc'):
|
|
121
|
+
with open(self.__src_root / f'common_source/template.{fmt}') as f:
|
|
122
|
+
constraint = f.read()
|
|
123
|
+
constraint = constraint.replace('$::env(CLOCK_PERIOD)', str(self._clock_period))
|
|
124
|
+
constraint = constraint.replace('$::env(UNCERTAINITY_SETUP)', str(self._clock_uncertainty))
|
|
125
|
+
constraint = constraint.replace('$::env(UNCERTAINITY_HOLD)', str(self._clock_uncertainty))
|
|
126
|
+
constraint = constraint.replace('$::env(DELAY_MAX)', str(self._io_delay_minmax[1]))
|
|
127
|
+
constraint = constraint.replace('$::env(DELAY_MIN)', str(self._io_delay_minmax[0]))
|
|
128
|
+
with open(self._path / f'src/{self._prj_name}.{fmt}', 'w') as f:
|
|
129
|
+
f.write(constraint)
|
|
102
130
|
|
|
103
131
|
# C++ binder w/ HDL wrapper for uniform bw
|
|
104
132
|
binder = binder_gen(self._pipe, f'{self._prj_name}_wrapper', 1, self._register_layers)
|
|
@@ -106,34 +134,46 @@ class RTLModel:
|
|
|
106
134
|
# Verilog IO wrapper (non-uniform bw to uniform one, clk passthrough)
|
|
107
135
|
io_wrapper = generate_io_wrapper(self._pipe, self._prj_name, True)
|
|
108
136
|
|
|
109
|
-
self._pipe.save(self._path / 'pipeline.json')
|
|
137
|
+
self._pipe.save(self._path / 'model/pipeline.json')
|
|
110
138
|
else: # Comb
|
|
111
|
-
assert isinstance(self._solution,
|
|
139
|
+
assert isinstance(self._solution, CombLogic)
|
|
140
|
+
|
|
141
|
+
# Table memory files
|
|
142
|
+
memfiles = table_mem_gen(self._solution)
|
|
112
143
|
|
|
113
144
|
# Main logic
|
|
114
145
|
code = comb_logic_gen(self._solution, self._prj_name, self._print_latency, '`timescale 1ns/1ps')
|
|
115
|
-
with open(self._path / f'{self._prj_name}.{suffix}', 'w') as f:
|
|
146
|
+
with open(self._path / f'src/{self._prj_name}.{suffix}', 'w') as f:
|
|
116
147
|
f.write(code)
|
|
117
148
|
|
|
118
149
|
# Verilog IO wrapper (non-uniform bw to uniform one, no clk)
|
|
119
150
|
io_wrapper = generate_io_wrapper(self._solution, self._prj_name, False)
|
|
120
151
|
binder = binder_gen(self._solution, f'{self._prj_name}_wrapper')
|
|
121
152
|
|
|
122
|
-
|
|
153
|
+
# Write table memory files
|
|
154
|
+
for name, mem in memfiles.items():
|
|
155
|
+
with open(self._path / 'src/memfiles' / name, 'w') as f:
|
|
156
|
+
f.write(mem)
|
|
157
|
+
|
|
158
|
+
with open(self._path / f'src/{self._prj_name}_wrapper.{suffix}', 'w') as f:
|
|
123
159
|
f.write(io_wrapper)
|
|
124
|
-
with open(self._path / f'{self._prj_name}_wrapper_binder.cc', 'w') as f:
|
|
160
|
+
with open(self._path / f'sim/{self._prj_name}_wrapper_binder.cc', 'w') as f:
|
|
125
161
|
f.write(binder)
|
|
126
162
|
|
|
127
163
|
# Common resource copy
|
|
128
|
-
for
|
|
129
|
-
shutil.copy(
|
|
130
|
-
|
|
131
|
-
shutil.copy(self.__src_root / 'common_source/build_binder.mk', self._path)
|
|
132
|
-
shutil.copy(self.__src_root / 'common_source/ioutil.hh', self._path)
|
|
133
|
-
shutil.copy(self.__src_root / 'common_source/binder_util.hh', self._path)
|
|
134
|
-
self._solution.save(self._path / 'model.json')
|
|
135
|
-
with open(self._path / '
|
|
136
|
-
|
|
164
|
+
for path in self.__src_root.glob(f'{flavor}/source/*.{suffix}'):
|
|
165
|
+
shutil.copy(path, self._path / 'src/static')
|
|
166
|
+
|
|
167
|
+
shutil.copy(self.__src_root / 'common_source/build_binder.mk', self._path / 'sim')
|
|
168
|
+
shutil.copy(self.__src_root / 'common_source/ioutil.hh', self._path / 'sim')
|
|
169
|
+
shutil.copy(self.__src_root / 'common_source/binder_util.hh', self._path / 'sim')
|
|
170
|
+
self._solution.save(self._path / 'model/comb.json')
|
|
171
|
+
with open(self._path / 'metadata.json', 'w') as f:
|
|
172
|
+
misc = {'cost': self._solution.cost}
|
|
173
|
+
if self._pipe is not None:
|
|
174
|
+
misc['latency'] = len(self._pipe[0])
|
|
175
|
+
misc['reg_bits'] = self._pipe.reg_bits
|
|
176
|
+
f.write(json.dumps(misc))
|
|
137
177
|
|
|
138
178
|
def _compile(self, verbose=False, openmp=True, nproc=None, o3: bool = False, clean=True):
|
|
139
179
|
"""Same as compile, but will not write to the library
|
|
@@ -150,7 +190,7 @@ class RTLModel:
|
|
|
150
190
|
o3 : bool | None, optional
|
|
151
191
|
Turn on -O3 flag, by default False
|
|
152
192
|
clean : bool, optional
|
|
153
|
-
Remove obsolete shared object files
|
|
193
|
+
Remove obsolete shared object files and `obj_dir`, by default True
|
|
154
194
|
|
|
155
195
|
Raises
|
|
156
196
|
------
|
|
@@ -170,18 +210,21 @@ class RTLModel:
|
|
|
170
210
|
if o3:
|
|
171
211
|
args.append('fast')
|
|
172
212
|
|
|
173
|
-
if clean
|
|
213
|
+
if clean:
|
|
174
214
|
m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
|
|
175
|
-
for p in self._path.iterdir():
|
|
215
|
+
for p in (self._path / 'sim').iterdir():
|
|
176
216
|
if not p.is_dir() and m.match(p.name):
|
|
177
217
|
p.unlink()
|
|
178
|
-
if clean:
|
|
179
218
|
subprocess.run(
|
|
180
|
-
['make', '-f', 'build_binder.mk', 'clean'],
|
|
219
|
+
['make', '-f', 'build_binder.mk', 'clean'],
|
|
220
|
+
env=env,
|
|
221
|
+
cwd=self._path / 'sim',
|
|
222
|
+
check=True,
|
|
223
|
+
capture_output=not verbose,
|
|
181
224
|
)
|
|
182
225
|
|
|
183
226
|
try:
|
|
184
|
-
r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
|
|
227
|
+
r = subprocess.run(args, env=env, check=True, cwd=self._path / 'sim', capture_output=not verbose)
|
|
185
228
|
except subprocess.CalledProcessError as e:
|
|
186
229
|
print(e.stderr.decode(), file=sys.stderr)
|
|
187
230
|
print(e.stdout.decode(), file=sys.stdout)
|
|
@@ -191,18 +234,21 @@ class RTLModel:
|
|
|
191
234
|
print(r.stdout.decode(), file=sys.stderr)
|
|
192
235
|
raise RuntimeError('Compilation failed!!')
|
|
193
236
|
|
|
237
|
+
if clean:
|
|
238
|
+
subprocess.run(['rm', '-rf', 'obj_dir'], cwd=self._path / 'sim', check=True, capture_output=not verbose)
|
|
239
|
+
|
|
194
240
|
self._load_lib(self._uuid)
|
|
195
241
|
|
|
196
242
|
def _load_lib(self, uuid: str | None = None):
|
|
197
243
|
uuid = uuid if uuid is not None else self._uuid
|
|
198
244
|
if uuid is None:
|
|
199
245
|
# load .so if there is only one, otherwise raise an error
|
|
200
|
-
libs = list(self._path.glob(f'lib{self._prj_name}_wrapper_*.so'))
|
|
246
|
+
libs = list(self._path.glob(f'sim/lib{self._prj_name}_wrapper_*.so'))
|
|
201
247
|
if len(libs) == 0:
|
|
202
248
|
raise RuntimeError(f'Cannot load library, found {len(libs)} libraries in {self._path}')
|
|
203
249
|
uuid = libs[0].name.split('_')[-1].split('.', 1)[0]
|
|
204
250
|
self._uuid = uuid
|
|
205
|
-
lib_path = self._path / f'lib{self._prj_name}_wrapper_{uuid}.so'
|
|
251
|
+
lib_path = self._path / f'sim/lib{self._prj_name}_wrapper_{uuid}.so'
|
|
206
252
|
if not lib_path.exists():
|
|
207
253
|
raise RuntimeError(f'Library {lib_path} does not exist')
|
|
208
254
|
self._lib = ctypes.CDLL(str(lib_path))
|
|
@@ -222,7 +268,7 @@ class RTLModel:
|
|
|
222
268
|
o3 : bool | None, optional
|
|
223
269
|
Turn on -O3 flag, by default False
|
|
224
270
|
clean : bool, optional
|
|
225
|
-
Remove obsolete shared object files
|
|
271
|
+
Remove obsolete shared object files and `obj_dir`, by default True
|
|
226
272
|
|
|
227
273
|
Raises
|
|
228
274
|
------
|
|
@@ -232,12 +278,12 @@ class RTLModel:
|
|
|
232
278
|
self.write()
|
|
233
279
|
self._compile(verbose=verbose, openmp=openmp, nproc=nproc, o3=o3, clean=clean)
|
|
234
280
|
|
|
235
|
-
def predict(self, data: NDArray[np.floating]) -> NDArray[np.float32]:
|
|
281
|
+
def predict(self, data: NDArray[np.floating] | Sequence[NDArray[np.floating]]) -> NDArray[np.float32]:
|
|
236
282
|
"""Run the model on the input data.
|
|
237
283
|
|
|
238
284
|
Parameters
|
|
239
285
|
----------
|
|
240
|
-
data : NDArray[np.floating]
|
|
286
|
+
data : NDArray[np.floating]|Sequence[NDArray[np.floating]]
|
|
241
287
|
Input data to the model. The shape is ignored, and the number of samples is
|
|
242
288
|
determined by the size of the data.
|
|
243
289
|
|
|
@@ -247,6 +293,9 @@ class RTLModel:
|
|
|
247
293
|
Output of the model in shape (n_samples, output_size).
|
|
248
294
|
"""
|
|
249
295
|
|
|
296
|
+
if isinstance(data, Sequence):
|
|
297
|
+
data = np.concatenate([a.reshape(a.shape[0], -1) for a in data], axis=-1)
|
|
298
|
+
|
|
250
299
|
assert self._lib is not None, 'Library not loaded, call .compile() first.'
|
|
251
300
|
inp_size, out_size = self._solution.shape
|
|
252
301
|
|
|
@@ -267,7 +316,9 @@ class RTLModel:
|
|
|
267
316
|
|
|
268
317
|
inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
|
|
269
318
|
out_buf = out_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
|
|
270
|
-
|
|
319
|
+
|
|
320
|
+
with at_path(self._path / 'src/memfiles'):
|
|
321
|
+
self._lib.inference(inp_buf, out_buf, n_sample)
|
|
271
322
|
|
|
272
323
|
# Unscale the output int32 to recover fp values
|
|
273
324
|
k, i, f = np.max(k_out), np.max(i_out), np.max(f_out)
|
|
@@ -308,7 +359,7 @@ Estimated cost: {cost} LUTs"""
|
|
|
308
359
|
class VerilogModel(RTLModel):
|
|
309
360
|
def __init__(
|
|
310
361
|
self,
|
|
311
|
-
solution:
|
|
362
|
+
solution: CombLogic | Pipeline,
|
|
312
363
|
prj_name: str,
|
|
313
364
|
path: str | Path,
|
|
314
365
|
latency_cutoff: float = -1,
|
|
@@ -337,7 +388,7 @@ class VerilogModel(RTLModel):
|
|
|
337
388
|
class VHDLModel(RTLModel):
|
|
338
389
|
def __init__(
|
|
339
390
|
self,
|
|
340
|
-
solution:
|
|
391
|
+
solution: CombLogic | Pipeline,
|
|
341
392
|
prj_name: str,
|
|
342
393
|
path: str | Path,
|
|
343
394
|
latency_cutoff: float = -1,
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
from .comb import comb_logic_gen
|
|
1
|
+
from .comb import comb_logic_gen, table_mem_gen
|
|
2
2
|
from .io_wrapper import binder_gen, generate_io_wrapper
|
|
3
3
|
from .pipeline import pipeline_logic_gen
|
|
4
4
|
|
|
5
5
|
__all__ = [
|
|
6
6
|
'comb_logic_gen',
|
|
7
|
+
'table_mem_gen',
|
|
7
8
|
'generate_io_wrapper',
|
|
8
9
|
'pipeline_logic_gen',
|
|
9
10
|
'binder_gen',
|
|
@@ -2,7 +2,7 @@ from math import ceil, log2
|
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
5
|
-
from ....cmvm.types import Op, QInterval,
|
|
5
|
+
from ....cmvm.types import CombLogic, Op, QInterval, _minimal_kif
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
def make_neg(
|
|
@@ -23,16 +23,35 @@ def make_neg(
|
|
|
23
23
|
return bw0, v0_name
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def
|
|
26
|
+
def gen_mem_file(sol: CombLogic, op: Op) -> str:
    """Render the lookup table used by `op` as hexadecimal memory-file text.

    Parameters
    ----------
    sol : CombLogic
        Solution owning `op`; its `lookup_tables` must not be None.
    op : Op
        A lookup-table operation (opcode 8). `op.data` indexes
        `sol.lookup_tables` and `op.id0` indexes the operand in `sol.ops`.

    Returns
    -------
    str
        One upper-case hexadecimal word per line, zero-padded to the number
        of hex digits needed for the table's output width, with no trailing
        newline.
    """
    assert op.opcode == 8
    assert sol.lookup_tables is not None

    lut = sol.lookup_tables[op.data]
    out_bits = sum(lut.spec.out_kif)
    hex_digits = ceil(out_bits / 4)

    # Table contents padded for the operand's quantization interval.
    contents = lut.padded_table(sol.ops[op.id0].qint)

    # Keep only the low `out_bits` bits of every entry before formatting.
    bit_mask = (1 << out_bits) - 1
    return '\n'.join(hex(word)[2:].upper().zfill(hex_digits) for word in contents & bit_mask)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_table_name(sol: CombLogic, op: Op) -> str:
    """Return the memory-file name for the lookup table referenced by `op`.

    Parameters
    ----------
    sol : CombLogic
        Solution owning `op`; its `lookup_tables` must not be None.
    op : Op
        A lookup-table operation (opcode 8).

    Returns
    -------
    str
        ``'lut_<uuid>.mem'``, where the uuid comes from the table's
        `get_uuid` called with the operand's quantization interval.
    """
    assert sol.lookup_tables is not None
    assert op.opcode == 8

    operand_qint = sol.ops[op.id0].qint
    table_id = sol.lookup_tables[op.data].get_uuid(operand_qint)
    return f'lut_{table_id}.mem'
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def ssa_gen(sol: CombLogic, neg_defined: set[int], print_latency: bool = False) -> list[str]:
|
|
27
46
|
ops = sol.ops
|
|
28
47
|
kifs = list(map(_minimal_kif, (op.qint for op in ops)))
|
|
29
|
-
widths = list(map(sum, kifs))
|
|
48
|
+
widths: list[int] = list(map(sum, kifs))
|
|
30
49
|
inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
|
|
31
50
|
inp_widths = list(map(sum, inp_kifs))
|
|
32
51
|
_inp_widths = np.cumsum([0] + inp_widths)
|
|
33
52
|
inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
|
|
34
53
|
|
|
35
|
-
lines = []
|
|
54
|
+
lines: list[str] = []
|
|
36
55
|
ref_count = sol.ref_count
|
|
37
56
|
|
|
38
57
|
for i, op in enumerate(ops):
|
|
@@ -84,6 +103,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
84
103
|
if op.opcode == -3 and op.id0 not in neg_defined:
|
|
85
104
|
neg_defined.add(op.id0)
|
|
86
105
|
bw0, v0_name = make_neg(lines, op, ops, bw0, v0_name)
|
|
106
|
+
|
|
87
107
|
line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
|
|
88
108
|
|
|
89
109
|
case 4: # constant addition
|
|
@@ -93,9 +113,11 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
93
113
|
bw0 = widths[op.id0]
|
|
94
114
|
s0 = int(kifs[op.id0][0])
|
|
95
115
|
v0 = f'v{op.id0}[{bw0 - 1}:0]'
|
|
96
|
-
v1 = f"'{bin(mag)[1:]}"
|
|
116
|
+
v1 = f"{bw1}'{bin(mag)[1:]}"
|
|
97
117
|
shift = kifs[op.id0][2] - kifs[i][2]
|
|
118
|
+
|
|
98
119
|
line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, 0, {bw}, {shift}, {sign}) op_{i} ({v0}, {v1}, {v});'
|
|
120
|
+
|
|
99
121
|
case 5: # constant
|
|
100
122
|
num = op.data
|
|
101
123
|
if num < 0:
|
|
@@ -112,8 +134,13 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
112
134
|
_shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
|
|
113
135
|
shift = f0 - f1 + _shift
|
|
114
136
|
vk, v0, v1 = f'v{k}[{bwk - 1}]', f'v{a}[{bw0 - 1}:0]', f'v{b}[{bw1 - 1}:0]'
|
|
137
|
+
if bw0 == 0:
|
|
138
|
+
v0, bw0 = "1'b0", 1
|
|
139
|
+
if bw1 == 0:
|
|
140
|
+
v1, bw1 = "1'b0", 1
|
|
115
141
|
|
|
116
142
|
line = f'{_def} mux #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {inv}) op_{i} ({vk}, {v0}, {v1}, {v});'
|
|
143
|
+
|
|
117
144
|
case 7: # Multiplication
|
|
118
145
|
bw0, bw1 = widths[op.id0], widths[op.id1] # width
|
|
119
146
|
s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
|
|
@@ -121,6 +148,12 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
121
148
|
|
|
122
149
|
line = f'{_def} multiplier #({bw0}, {bw1}, {s0}, {s1}, {bw}) op_{i} ({v0}, {v1}, {v});'
|
|
123
150
|
|
|
151
|
+
case 8: # Lookup Table
|
|
152
|
+
name = get_table_name(sol, op)
|
|
153
|
+
bw0 = widths[op.id0]
|
|
154
|
+
|
|
155
|
+
line = f'{_def} lookup_table #({bw0}, {bw}, "{name}") op_{i} (v{op.id0}, {v});'
|
|
156
|
+
|
|
124
157
|
case _:
|
|
125
158
|
raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
|
|
126
159
|
|
|
@@ -130,7 +163,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
130
163
|
return lines
|
|
131
164
|
|
|
132
165
|
|
|
133
|
-
def output_gen(sol:
|
|
166
|
+
def output_gen(sol: CombLogic, neg_defined: set[int]):
|
|
134
167
|
lines = []
|
|
135
168
|
widths = list(map(sum, map(_minimal_kif, sol.out_qint)))
|
|
136
169
|
_widths = np.cumsum([0] + widths)
|
|
@@ -157,7 +190,7 @@ def output_gen(sol: Solution, neg_defined: set[int]):
|
|
|
157
190
|
return lines
|
|
158
191
|
|
|
159
192
|
|
|
160
|
-
def comb_logic_gen(sol:
|
|
193
|
+
def comb_logic_gen(sol: CombLogic, fn_name: str, print_latency: bool = False, timescale: str | None = None):
|
|
161
194
|
inp_bits = sum(map(sum, map(_minimal_kif, sol.inp_qint)))
|
|
162
195
|
out_bits = sum(map(sum, map(_minimal_kif, sol.out_qint)))
|
|
163
196
|
|
|
@@ -191,3 +224,10 @@ def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, tim
|
|
|
191
224
|
if timescale is not None:
|
|
192
225
|
code = f'{timescale}\n\n{code}'
|
|
193
226
|
return code
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def table_mem_gen(sol: CombLogic) -> dict[str, str]:
    """Build the memory files for every lookup-table operation in `sol`.

    Parameters
    ----------
    sol : CombLogic
        Solution whose operations are scanned for opcode 8.

    Returns
    -------
    dict[str, str]
        Mapping from memory-file name to file content. Empty when the
        solution has no lookup tables.
    """
    files: dict[str, str] = {}
    if not sol.lookup_tables:
        return files

    for op in sol.ops:
        if op.opcode != 8:
            continue
        file_name = get_table_name(sol, op)
        files[file_name] = gen_mem_file(sol, op)
    return files
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from itertools import accumulate
|
|
2
2
|
|
|
3
|
-
from ....cmvm.types import
|
|
3
|
+
from ....cmvm.types import CombLogic, Pipeline, QInterval, _minimal_kif
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def hetero_io_map(qints: list[QInterval], merge: bool = False):
|
|
@@ -59,7 +59,7 @@ def hetero_io_map(qints: list[QInterval], merge: bool = False):
|
|
|
59
59
|
return regular, hetero, pads, (width_regular, width_packed)
|
|
60
60
|
|
|
61
61
|
|
|
62
|
-
def generate_io_wrapper(sol:
|
|
62
|
+
def generate_io_wrapper(sol: CombLogic | Pipeline, module_name: str, pipelined: bool = False):
|
|
63
63
|
reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
|
|
64
64
|
reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)
|
|
65
65
|
|
|
@@ -113,12 +113,12 @@ endmodule
|
|
|
113
113
|
"""
|
|
114
114
|
|
|
115
115
|
|
|
116
|
-
def binder_gen(csol:
|
|
116
|
+
def binder_gen(csol: Pipeline | CombLogic, module_name: str, II: int = 1, latency_multiplier: int = 1):
|
|
117
117
|
k_in, i_in, f_in = zip(*map(_minimal_kif, csol.inp_qint))
|
|
118
118
|
k_out, i_out, f_out = zip(*map(_minimal_kif, csol.out_qint))
|
|
119
119
|
max_inp_bw = max(k_in) + max(i_in) + max(f_in)
|
|
120
120
|
max_out_bw = max(k_out) + max(i_out) + max(f_out)
|
|
121
|
-
if isinstance(csol,
|
|
121
|
+
if isinstance(csol, CombLogic):
|
|
122
122
|
II = latency = 0
|
|
123
123
|
else:
|
|
124
124
|
latency = len(csol.solutions) * latency_multiplier
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
from ....cmvm.types import
|
|
1
|
+
from ....cmvm.types import Pipeline, _minimal_kif
|
|
2
2
|
from .comb import comb_logic_gen
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
def pipeline_logic_gen(
|
|
6
|
-
csol:
|
|
6
|
+
csol: Pipeline,
|
|
7
7
|
name: str,
|
|
8
8
|
print_latency=False,
|
|
9
9
|
timescale: str | None = '`timescale 1 ns / 1 ps',
|
|
@@ -13,36 +13,36 @@ def pipeline_logic_gen(
|
|
|
13
13
|
inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
|
|
14
14
|
out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
|
|
15
15
|
|
|
16
|
-
registers = [f'reg [{width-1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
|
|
16
|
+
registers = [f'reg [{width - 1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
|
|
17
17
|
for i in range(0, register_layers - 1):
|
|
18
|
-
registers += [f'reg [{width-1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
|
|
19
|
-
wires = [f'wire [{width-1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
|
|
18
|
+
registers += [f'reg [{width - 1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
|
|
19
|
+
wires = [f'wire [{width - 1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
|
|
20
20
|
|
|
21
21
|
comb_logic = [f'{name}_stage{i} stage{i} (.model_inp(stage{i}_inp), .model_out(stage{i}_out));' for i in range(N)]
|
|
22
22
|
|
|
23
23
|
if register_layers == 1:
|
|
24
24
|
serial_logic = ['stage0_inp <= model_inp;']
|
|
25
|
-
serial_logic += [f'stage{i}_inp <= stage{i-1}_out;' for i in range(1, N)]
|
|
25
|
+
serial_logic += [f'stage{i}_inp <= stage{i - 1}_out;' for i in range(1, N)]
|
|
26
26
|
else:
|
|
27
27
|
serial_logic = ['stage0_inp_copy0 <= model_inp;']
|
|
28
28
|
for j in range(1, register_layers - 1):
|
|
29
|
-
serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j-1};')
|
|
29
|
+
serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j - 1};')
|
|
30
30
|
serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
|
|
31
31
|
for i in range(1, N):
|
|
32
|
-
serial_logic.append(f'stage{i}_inp_copy0 <= stage{i-1}_out;')
|
|
32
|
+
serial_logic.append(f'stage{i}_inp_copy0 <= stage{i - 1}_out;')
|
|
33
33
|
for j in range(1, register_layers - 1):
|
|
34
|
-
serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};')
|
|
34
|
+
serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j - 1};')
|
|
35
35
|
serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
|
|
36
36
|
|
|
37
|
-
serial_logic += [f'model_out <= stage{N-1}_out;']
|
|
37
|
+
serial_logic += [f'model_out <= stage{N - 1}_out;']
|
|
38
38
|
|
|
39
39
|
sep0 = '\n '
|
|
40
40
|
sep1 = '\n '
|
|
41
41
|
|
|
42
42
|
module = f"""module {name} (
|
|
43
43
|
input clk,
|
|
44
|
-
input [{inp_bits[0]-1}:0] model_inp,
|
|
45
|
-
output reg [{out_bits[-1]-1}:0] model_out
|
|
44
|
+
input [{inp_bits[0] - 1}:0] model_inp,
|
|
45
|
+
output reg [{out_bits[-1] - 1}:0] model_out
|
|
46
46
|
);
|
|
47
47
|
|
|
48
48
|
{sep0.join(registers)}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
`timescale 1ns / 1ps
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
// Combinational read-only lookup table.
//
// Parameters:
//   BW_IN    - address width; the ROM holds 2**BW_IN words.
//   BW_OUT   - width of each stored word and of the output.
//   MEM_FILE - hexadecimal memory file loaded into the ROM with $readmemh.
//
// Ports:
//   in  - ROM address.
//   out - word stored at address `in`.
module lookup_table #(
    parameter BW_IN = 8,
    parameter BW_OUT = 8,
    parameter MEM_FILE = "whatever.mem"
) (
    input [BW_IN-1:0] in,
    output [BW_OUT-1:0] out
);

    // NOTE(review): rom_style is a synthesis attribute; presumably it asks the
    // tool to build the ROM from distributed (LUT) memory - confirm per vendor.
    (* rom_style = "distributed" *)
    reg [BW_OUT-1:0] lut_rom [0:(1<<BW_IN)-1];

    // Word currently addressed by `in`.
    reg [BW_OUT-1:0] rom_word;

    // Load the table contents once at start-up.
    initial begin
        $readmemh(MEM_FILE, lut_rom);
    end

    // Combinational read of the addressed word.
    always @(*) begin
        rom_word = lut_rom[in];
    end

    assign out = rom_word;

endmodule
|