da4ml 0.4.0__py3-none-any.whl → 0.5.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of da4ml might be problematic. Click here for more details.
- da4ml/__init__.py +2 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +2 -2
- da4ml/cmvm/api.py +15 -4
- da4ml/cmvm/core/__init__.py +2 -2
- da4ml/cmvm/types.py +32 -18
- da4ml/cmvm/util/bit_decompose.py +2 -2
- da4ml/codegen/hls/hls_codegen.py +10 -5
- da4ml/codegen/hls/hls_model.py +7 -4
- da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
- da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
- da4ml/codegen/rtl/common_source/{build_prj.tcl → build_vivado_prj.tcl} +39 -18
- da4ml/codegen/rtl/common_source/template.sdc +27 -0
- da4ml/codegen/rtl/common_source/template.xdc +11 -13
- da4ml/codegen/rtl/rtl_model.py +105 -53
- da4ml/codegen/rtl/verilog/__init__.py +2 -1
- da4ml/codegen/rtl/verilog/comb.py +47 -7
- da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
- da4ml/codegen/rtl/verilog/pipeline.py +12 -12
- da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
- da4ml/codegen/rtl/vhdl/comb.py +27 -21
- da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
- da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
- da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
- da4ml/converter/__init__.py +57 -1
- da4ml/converter/hgq2/parser.py +4 -25
- da4ml/converter/hgq2/replica.py +210 -25
- da4ml/trace/fixed_variable.py +239 -29
- da4ml/trace/fixed_variable_array.py +276 -48
- da4ml/trace/ops/__init__.py +31 -15
- da4ml/trace/ops/reduce_utils.py +3 -3
- da4ml/trace/pipeline.py +40 -18
- da4ml/trace/tracer.py +33 -8
- da4ml/typing/__init__.py +3 -0
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/METADATA +2 -1
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/RECORD +39 -35
- da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/WHEEL +0 -0
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/top_level.txt +0 -0
da4ml/codegen/rtl/rtl_model.py
CHANGED
|
@@ -1,30 +1,45 @@
|
|
|
1
1
|
import ctypes
|
|
2
|
+
import json
|
|
2
3
|
import os
|
|
3
4
|
import re
|
|
4
5
|
import shutil
|
|
5
6
|
import subprocess
|
|
6
7
|
import sys
|
|
8
|
+
from collections.abc import Sequence
|
|
7
9
|
from pathlib import Path
|
|
8
10
|
from uuid import uuid4
|
|
9
11
|
|
|
10
12
|
import numpy as np
|
|
11
13
|
from numpy.typing import NDArray
|
|
12
14
|
|
|
13
|
-
from ...cmvm.types import
|
|
15
|
+
from ...cmvm.types import CombLogic, Pipeline, _minimal_kif
|
|
14
16
|
from ...trace.pipeline import to_pipeline
|
|
15
17
|
from .. import rtl
|
|
16
18
|
|
|
17
19
|
|
|
18
|
-
def get_io_kifs(sol:
|
|
20
|
+
def get_io_kifs(sol: CombLogic | Pipeline):
    """Collect the minimal fixed-point descriptors of a solution's inputs and outputs.

    Every entry of ``sol.inp_qint`` and ``sol.out_qint`` is passed through
    ``_minimal_kif``; the per-port results are then transposed so that each
    row of a returned array holds one descriptor component across all ports.

    Parameters
    ----------
    sol : CombLogic | Pipeline
        Object exposing ``inp_qint`` and ``out_qint``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``int8`` arrays for the inputs and for the outputs, in that order.
    """

    def _transposed_kifs(qints):
        # zip(*...) regroups the per-port tuples component by component.
        return tuple(zip(*(_minimal_kif(qint) for qint in qints)))

    inputs = _transposed_kifs(sol.inp_qint)
    outputs = _transposed_kifs(sol.out_qint)
    return np.array(inputs, np.int8), np.array(outputs, np.int8)
|
|
22
24
|
|
|
23
25
|
|
|
26
|
+
class at_path:
|
|
27
|
+
def __init__(self, path: str | Path):
|
|
28
|
+
self._path = Path(path)
|
|
29
|
+
self._orig_cwd = None
|
|
30
|
+
|
|
31
|
+
def __enter__(self):
|
|
32
|
+
self._orig_cwd = Path.cwd()
|
|
33
|
+
os.chdir(self._path)
|
|
34
|
+
|
|
35
|
+
def __exit__(self, exc_type, exc_value, traceback):
|
|
36
|
+
os.chdir(self._orig_cwd) # type: ignore
|
|
37
|
+
|
|
38
|
+
|
|
24
39
|
class RTLModel:
|
|
25
40
|
def __init__(
|
|
26
41
|
self,
|
|
27
|
-
solution:
|
|
42
|
+
solution: CombLogic | Pipeline,
|
|
28
43
|
prj_name: str,
|
|
29
44
|
path: str | Path,
|
|
30
45
|
flavor: str = 'verilog',
|
|
@@ -51,9 +66,9 @@ class RTLModel:
|
|
|
51
66
|
|
|
52
67
|
assert self._flavor in ('vhdl', 'verilog'), f'Unsupported flavor {flavor}, only vhdl and verilog are supported.'
|
|
53
68
|
|
|
54
|
-
self._pipe = solution if isinstance(solution,
|
|
69
|
+
self._pipe = solution if isinstance(solution, Pipeline) else None
|
|
55
70
|
if latency_cutoff > 0 and self._pipe is None:
|
|
56
|
-
assert isinstance(solution,
|
|
71
|
+
assert isinstance(solution, CombLogic)
|
|
57
72
|
self._pipe = to_pipeline(solution, latency_cutoff, verbose=False)
|
|
58
73
|
|
|
59
74
|
if self._pipe is not None:
|
|
@@ -72,32 +87,46 @@ class RTLModel:
|
|
|
72
87
|
else: # verilog
|
|
73
88
|
from .verilog import binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_logic_gen
|
|
74
89
|
|
|
75
|
-
|
|
90
|
+
from .verilog.comb import table_mem_gen
|
|
91
|
+
|
|
92
|
+
(self._path / 'src/static').mkdir(parents=True, exist_ok=True)
|
|
93
|
+
(self._path / 'sim').mkdir(exist_ok=True)
|
|
94
|
+
(self._path / 'model').mkdir(exist_ok=True)
|
|
95
|
+
(self._path / 'src/memfiles').mkdir(exist_ok=True)
|
|
76
96
|
if self._pipe is not None: # Pipeline
|
|
77
97
|
# Main logic
|
|
78
98
|
codes = pipeline_logic_gen(self._pipe, self._prj_name, self._print_latency, register_layers=self._register_layers)
|
|
99
|
+
|
|
100
|
+
# Table memory files
|
|
101
|
+
memfiles: dict[str, str] = {}
|
|
102
|
+
for comb in self._pipe.solutions:
|
|
103
|
+
memfiles.update(table_mem_gen(comb))
|
|
104
|
+
|
|
79
105
|
for k, v in codes.items():
|
|
80
|
-
with open(self._path / f'{k}.{suffix}', 'w') as f:
|
|
106
|
+
with open(self._path / f'src/{k}.{suffix}', 'w') as f:
|
|
81
107
|
f.write(v)
|
|
82
108
|
|
|
83
|
-
# Build
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
109
|
+
# Build scripts
|
|
110
|
+
for path in (self.__src_root).glob('common_source/build_*_prj.tcl'):
|
|
111
|
+
with open(path) as f:
|
|
112
|
+
tcl = f.read()
|
|
113
|
+
tcl = tcl.replace('$::env(DEVICE)', self._part_name)
|
|
114
|
+
tcl = tcl.replace('$::env(PROJECT_NAME)', self._prj_name)
|
|
115
|
+
tcl = tcl.replace('$::env(SOURCE_TYPE)', flavor)
|
|
116
|
+
with open(self._path / path.name, 'w') as f:
|
|
117
|
+
f.write(tcl)
|
|
118
|
+
|
|
119
|
+
# Timing constraint
|
|
120
|
+
for fmt in ('xdc', 'sdc'):
|
|
121
|
+
with open(self.__src_root / f'common_source/template.{fmt}') as f:
|
|
122
|
+
constraint = f.read()
|
|
123
|
+
constraint = constraint.replace('$::env(CLOCK_PERIOD)', str(self._clock_period))
|
|
124
|
+
constraint = constraint.replace('$::env(UNCERTAINITY_SETUP)', str(self._clock_uncertainty))
|
|
125
|
+
constraint = constraint.replace('$::env(UNCERTAINITY_HOLD)', str(self._clock_uncertainty))
|
|
126
|
+
constraint = constraint.replace('$::env(DELAY_MAX)', str(self._io_delay_minmax[1]))
|
|
127
|
+
constraint = constraint.replace('$::env(DELAY_MIN)', str(self._io_delay_minmax[0]))
|
|
128
|
+
with open(self._path / f'src/{self._prj_name}.{fmt}', 'w') as f:
|
|
129
|
+
f.write(constraint)
|
|
101
130
|
|
|
102
131
|
# C++ binder w/ HDL wrapper for uniform bw
|
|
103
132
|
binder = binder_gen(self._pipe, f'{self._prj_name}_wrapper', 1, self._register_layers)
|
|
@@ -105,34 +134,46 @@ class RTLModel:
|
|
|
105
134
|
# Verilog IO wrapper (non-uniform bw to uniform one, clk passthrough)
|
|
106
135
|
io_wrapper = generate_io_wrapper(self._pipe, self._prj_name, True)
|
|
107
136
|
|
|
108
|
-
self._pipe.save(self._path / 'pipeline.json')
|
|
137
|
+
self._pipe.save(self._path / 'model/pipeline.json')
|
|
109
138
|
else: # Comb
|
|
110
|
-
assert isinstance(self._solution,
|
|
139
|
+
assert isinstance(self._solution, CombLogic)
|
|
140
|
+
|
|
141
|
+
# Table memory files
|
|
142
|
+
memfiles = table_mem_gen(self._solution)
|
|
111
143
|
|
|
112
144
|
# Main logic
|
|
113
145
|
code = comb_logic_gen(self._solution, self._prj_name, self._print_latency, '`timescale 1ns/1ps')
|
|
114
|
-
with open(self._path / f'{self._prj_name}.{suffix}', 'w') as f:
|
|
146
|
+
with open(self._path / f'src/{self._prj_name}.{suffix}', 'w') as f:
|
|
115
147
|
f.write(code)
|
|
116
148
|
|
|
117
149
|
# Verilog IO wrapper (non-uniform bw to uniform one, no clk)
|
|
118
150
|
io_wrapper = generate_io_wrapper(self._solution, self._prj_name, False)
|
|
119
151
|
binder = binder_gen(self._solution, f'{self._prj_name}_wrapper')
|
|
120
152
|
|
|
121
|
-
|
|
153
|
+
# Write table memory files
|
|
154
|
+
for name, mem in memfiles.items():
|
|
155
|
+
with open(self._path / 'src/memfiles' / name, 'w') as f:
|
|
156
|
+
f.write(mem)
|
|
157
|
+
|
|
158
|
+
with open(self._path / f'src/{self._prj_name}_wrapper.{suffix}', 'w') as f:
|
|
122
159
|
f.write(io_wrapper)
|
|
123
|
-
with open(self._path / f'{self._prj_name}_wrapper_binder.cc', 'w') as f:
|
|
160
|
+
with open(self._path / f'sim/{self._prj_name}_wrapper_binder.cc', 'w') as f:
|
|
124
161
|
f.write(binder)
|
|
125
162
|
|
|
126
163
|
# Common resource copy
|
|
127
|
-
for
|
|
128
|
-
shutil.copy(
|
|
129
|
-
|
|
130
|
-
shutil.copy(self.__src_root / 'common_source/build_binder.mk', self._path)
|
|
131
|
-
shutil.copy(self.__src_root / 'common_source/ioutil.hh', self._path)
|
|
132
|
-
shutil.copy(self.__src_root / 'common_source/binder_util.hh', self._path)
|
|
133
|
-
self._solution.save(self._path / 'model.json')
|
|
134
|
-
with open(self._path / '
|
|
135
|
-
|
|
164
|
+
for path in self.__src_root.glob(f'{flavor}/source/*.{suffix}'):
|
|
165
|
+
shutil.copy(path, self._path / 'src/static')
|
|
166
|
+
|
|
167
|
+
shutil.copy(self.__src_root / 'common_source/build_binder.mk', self._path / 'sim')
|
|
168
|
+
shutil.copy(self.__src_root / 'common_source/ioutil.hh', self._path / 'sim')
|
|
169
|
+
shutil.copy(self.__src_root / 'common_source/binder_util.hh', self._path / 'sim')
|
|
170
|
+
self._solution.save(self._path / 'model/comb.json')
|
|
171
|
+
with open(self._path / 'metadata.json', 'w') as f:
|
|
172
|
+
misc = {'cost': self._solution.cost}
|
|
173
|
+
if self._pipe is not None:
|
|
174
|
+
misc['latency'] = len(self._pipe[0])
|
|
175
|
+
misc['reg_bits'] = self._pipe.reg_bits
|
|
176
|
+
f.write(json.dumps(misc))
|
|
136
177
|
|
|
137
178
|
def _compile(self, verbose=False, openmp=True, nproc=None, o3: bool = False, clean=True):
|
|
138
179
|
"""Same as compile, but will not write to the library
|
|
@@ -149,7 +190,7 @@ class RTLModel:
|
|
|
149
190
|
o3 : bool | None, optional
|
|
150
191
|
Turn on -O3 flag, by default False
|
|
151
192
|
clean : bool, optional
|
|
152
|
-
Remove obsolete shared object files
|
|
193
|
+
Remove obsolete shared object files and `obj_dir`, by default True
|
|
153
194
|
|
|
154
195
|
Raises
|
|
155
196
|
------
|
|
@@ -169,18 +210,21 @@ class RTLModel:
|
|
|
169
210
|
if o3:
|
|
170
211
|
args.append('fast')
|
|
171
212
|
|
|
172
|
-
if clean
|
|
213
|
+
if clean:
|
|
173
214
|
m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
|
|
174
|
-
for p in self._path.iterdir():
|
|
215
|
+
for p in (self._path / 'sim').iterdir():
|
|
175
216
|
if not p.is_dir() and m.match(p.name):
|
|
176
217
|
p.unlink()
|
|
177
|
-
if clean:
|
|
178
218
|
subprocess.run(
|
|
179
|
-
['make', '-f', 'build_binder.mk', 'clean'],
|
|
219
|
+
['make', '-f', 'build_binder.mk', 'clean'],
|
|
220
|
+
env=env,
|
|
221
|
+
cwd=self._path / 'sim',
|
|
222
|
+
check=True,
|
|
223
|
+
capture_output=not verbose,
|
|
180
224
|
)
|
|
181
225
|
|
|
182
226
|
try:
|
|
183
|
-
r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
|
|
227
|
+
r = subprocess.run(args, env=env, check=True, cwd=self._path / 'sim', capture_output=not verbose)
|
|
184
228
|
except subprocess.CalledProcessError as e:
|
|
185
229
|
print(e.stderr.decode(), file=sys.stderr)
|
|
186
230
|
print(e.stdout.decode(), file=sys.stdout)
|
|
@@ -190,18 +234,21 @@ class RTLModel:
|
|
|
190
234
|
print(r.stdout.decode(), file=sys.stderr)
|
|
191
235
|
raise RuntimeError('Compilation failed!!')
|
|
192
236
|
|
|
237
|
+
if clean:
|
|
238
|
+
subprocess.run(['rm', '-rf', 'obj_dir'], cwd=self._path / 'sim', check=True, capture_output=not verbose)
|
|
239
|
+
|
|
193
240
|
self._load_lib(self._uuid)
|
|
194
241
|
|
|
195
242
|
def _load_lib(self, uuid: str | None = None):
|
|
196
243
|
uuid = uuid if uuid is not None else self._uuid
|
|
197
244
|
if uuid is None:
|
|
198
245
|
# load .so if there is only one, otherwise raise an error
|
|
199
|
-
libs = list(self._path.glob(f'lib{self._prj_name}_wrapper_*.so'))
|
|
246
|
+
libs = list(self._path.glob(f'sim/lib{self._prj_name}_wrapper_*.so'))
|
|
200
247
|
if len(libs) == 0:
|
|
201
248
|
raise RuntimeError(f'Cannot load library, found {len(libs)} libraries in {self._path}')
|
|
202
249
|
uuid = libs[0].name.split('_')[-1].split('.', 1)[0]
|
|
203
250
|
self._uuid = uuid
|
|
204
|
-
lib_path = self._path / f'lib{self._prj_name}_wrapper_{uuid}.so'
|
|
251
|
+
lib_path = self._path / f'sim/lib{self._prj_name}_wrapper_{uuid}.so'
|
|
205
252
|
if not lib_path.exists():
|
|
206
253
|
raise RuntimeError(f'Library {lib_path} does not exist')
|
|
207
254
|
self._lib = ctypes.CDLL(str(lib_path))
|
|
@@ -221,7 +268,7 @@ class RTLModel:
|
|
|
221
268
|
o3 : bool | None, optional
|
|
222
269
|
Turn on -O3 flag, by default False
|
|
223
270
|
clean : bool, optional
|
|
224
|
-
Remove obsolete shared object files
|
|
271
|
+
Remove obsolete shared object files and `obj_dir`, by default True
|
|
225
272
|
|
|
226
273
|
Raises
|
|
227
274
|
------
|
|
@@ -231,12 +278,12 @@ class RTLModel:
|
|
|
231
278
|
self.write()
|
|
232
279
|
self._compile(verbose=verbose, openmp=openmp, nproc=nproc, o3=o3, clean=clean)
|
|
233
280
|
|
|
234
|
-
def predict(self, data: NDArray[np.floating]) -> NDArray[np.float32]:
|
|
281
|
+
def predict(self, data: NDArray[np.floating] | Sequence[NDArray[np.floating]]) -> NDArray[np.float32]:
|
|
235
282
|
"""Run the model on the input data.
|
|
236
283
|
|
|
237
284
|
Parameters
|
|
238
285
|
----------
|
|
239
|
-
data : NDArray[np.floating]
|
|
286
|
+
data : NDArray[np.floating]|Sequence[NDArray[np.floating]]
|
|
240
287
|
Input data to the model. The shape is ignored, and the number of samples is
|
|
241
288
|
determined by the size of the data.
|
|
242
289
|
|
|
@@ -246,6 +293,9 @@ class RTLModel:
|
|
|
246
293
|
Output of the model in shape (n_samples, output_size).
|
|
247
294
|
"""
|
|
248
295
|
|
|
296
|
+
if isinstance(data, Sequence):
|
|
297
|
+
data = np.concatenate([a.reshape(a.shape[0], -1) for a in data], axis=-1)
|
|
298
|
+
|
|
249
299
|
assert self._lib is not None, 'Library not loaded, call .compile() first.'
|
|
250
300
|
inp_size, out_size = self._solution.shape
|
|
251
301
|
|
|
@@ -266,7 +316,9 @@ class RTLModel:
|
|
|
266
316
|
|
|
267
317
|
inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
|
|
268
318
|
out_buf = out_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
|
|
269
|
-
|
|
319
|
+
|
|
320
|
+
with at_path(self._path / 'src/memfiles'):
|
|
321
|
+
self._lib.inference(inp_buf, out_buf, n_sample)
|
|
270
322
|
|
|
271
323
|
# Unscale the output int32 to recover fp values
|
|
272
324
|
k, i, f = np.max(k_out), np.max(i_out), np.max(f_out)
|
|
@@ -307,7 +359,7 @@ Estimated cost: {cost} LUTs"""
|
|
|
307
359
|
class VerilogModel(RTLModel):
|
|
308
360
|
def __init__(
|
|
309
361
|
self,
|
|
310
|
-
solution:
|
|
362
|
+
solution: CombLogic | Pipeline,
|
|
311
363
|
prj_name: str,
|
|
312
364
|
path: str | Path,
|
|
313
365
|
latency_cutoff: float = -1,
|
|
@@ -336,7 +388,7 @@ class VerilogModel(RTLModel):
|
|
|
336
388
|
class VHDLModel(RTLModel):
|
|
337
389
|
def __init__(
|
|
338
390
|
self,
|
|
339
|
-
solution:
|
|
391
|
+
solution: CombLogic | Pipeline,
|
|
340
392
|
prj_name: str,
|
|
341
393
|
path: str | Path,
|
|
342
394
|
latency_cutoff: float = -1,
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
from .comb import comb_logic_gen
|
|
1
|
+
from .comb import comb_logic_gen, table_mem_gen
|
|
2
2
|
from .io_wrapper import binder_gen, generate_io_wrapper
|
|
3
3
|
from .pipeline import pipeline_logic_gen
|
|
4
4
|
|
|
5
5
|
__all__ = [
|
|
6
6
|
'comb_logic_gen',
|
|
7
|
+
'table_mem_gen',
|
|
7
8
|
'generate_io_wrapper',
|
|
8
9
|
'pipeline_logic_gen',
|
|
9
10
|
'binder_gen',
|
|
@@ -2,7 +2,7 @@ from math import ceil, log2
|
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
5
|
-
from ....cmvm.types import Op, QInterval,
|
|
5
|
+
from ....cmvm.types import CombLogic, Op, QInterval, _minimal_kif
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
def make_neg(
|
|
@@ -23,16 +23,35 @@ def make_neg(
|
|
|
23
23
|
return bw0, v0_name
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def
|
|
26
|
+
def gen_mem_file(sol: CombLogic, op: Op) -> str:
    """Render the lookup table referenced by ``op`` as hex memory-file text.

    Each table entry becomes one line of upper-case hexadecimal, zero-padded
    to the number of hex digits needed for the table's output width.

    Parameters
    ----------
    sol : CombLogic
        Logic block owning ``lookup_tables`` and ``ops``.
    op : Op
        Lookup operation (opcode 8); ``op.data`` selects the table and
        ``op.id0`` the operation whose ``qint`` is handed to ``padded_table``.

    Returns
    -------
    str
        Newline-joined hex words, without a trailing newline.
    """
    assert op.opcode == 8
    assert sol.lookup_tables is not None

    lut = sol.lookup_tables[op.data]
    n_bits = sum(lut.spec.out_kif)
    n_hex = ceil(n_bits / 4)
    mask = (1 << n_bits) - 1
    entries = lut.padded_table(sol.ops[op.id0].qint)
    # Keep only the low ``n_bits`` bits of every entry before formatting.
    return '\n'.join(f'{int(word):X}'.zfill(n_hex) for word in entries & mask)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_table_name(sol: CombLogic, op: Op) -> str:
    """Return the memory-file name for the lookup table used by ``op``.

    The name is built from the identifier the table reports via ``get_uuid``
    for the ``qint`` of the operation feeding the lookup.

    Parameters
    ----------
    sol : CombLogic
        Logic block owning ``lookup_tables`` and ``ops``.
    op : Op
        Lookup operation (opcode 8); ``op.data`` selects the table and
        ``op.id0`` the operation supplying the input ``qint``.

    Returns
    -------
    str
        File name of the form ``lut_<uuid>.mem``.
    """
    tables = sol.lookup_tables
    assert tables is not None
    assert op.opcode == 8

    source_qint = sol.ops[op.id0].qint
    table_id = tables[op.data].get_uuid(source_qint)
    return f'lut_{table_id}.mem'
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def ssa_gen(sol: CombLogic, neg_defined: set[int], print_latency: bool = False) -> list[str]:
|
|
27
46
|
ops = sol.ops
|
|
28
47
|
kifs = list(map(_minimal_kif, (op.qint for op in ops)))
|
|
29
|
-
widths = list(map(sum, kifs))
|
|
48
|
+
widths: list[int] = list(map(sum, kifs))
|
|
30
49
|
inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
|
|
31
50
|
inp_widths = list(map(sum, inp_kifs))
|
|
32
51
|
_inp_widths = np.cumsum([0] + inp_widths)
|
|
33
52
|
inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
|
|
34
53
|
|
|
35
|
-
lines = []
|
|
54
|
+
lines: list[str] = []
|
|
36
55
|
ref_count = sol.ref_count
|
|
37
56
|
|
|
38
57
|
for i, op in enumerate(ops):
|
|
@@ -84,6 +103,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
84
103
|
if op.opcode == -3 and op.id0 not in neg_defined:
|
|
85
104
|
neg_defined.add(op.id0)
|
|
86
105
|
bw0, v0_name = make_neg(lines, op, ops, bw0, v0_name)
|
|
106
|
+
|
|
87
107
|
line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
|
|
88
108
|
|
|
89
109
|
case 4: # constant addition
|
|
@@ -93,9 +113,11 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
93
113
|
bw0 = widths[op.id0]
|
|
94
114
|
s0 = int(kifs[op.id0][0])
|
|
95
115
|
v0 = f'v{op.id0}[{bw0 - 1}:0]'
|
|
96
|
-
v1 = f"'{bin(mag)[1:]}"
|
|
116
|
+
v1 = f"{bw1}'{bin(mag)[1:]}"
|
|
97
117
|
shift = kifs[op.id0][2] - kifs[i][2]
|
|
118
|
+
|
|
98
119
|
line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, 0, {bw}, {shift}, {sign}) op_{i} ({v0}, {v1}, {v});'
|
|
120
|
+
|
|
99
121
|
case 5: # constant
|
|
100
122
|
num = op.data
|
|
101
123
|
if num < 0:
|
|
@@ -112,8 +134,13 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
112
134
|
_shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
|
|
113
135
|
shift = f0 - f1 + _shift
|
|
114
136
|
vk, v0, v1 = f'v{k}[{bwk - 1}]', f'v{a}[{bw0 - 1}:0]', f'v{b}[{bw1 - 1}:0]'
|
|
137
|
+
if bw0 == 0:
|
|
138
|
+
v0, bw0 = "1'b0", 1
|
|
139
|
+
if bw1 == 0:
|
|
140
|
+
v1, bw1 = "1'b0", 1
|
|
115
141
|
|
|
116
142
|
line = f'{_def} mux #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {inv}) op_{i} ({vk}, {v0}, {v1}, {v});'
|
|
143
|
+
|
|
117
144
|
case 7: # Multiplication
|
|
118
145
|
bw0, bw1 = widths[op.id0], widths[op.id1] # width
|
|
119
146
|
s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
|
|
@@ -121,6 +148,12 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
121
148
|
|
|
122
149
|
line = f'{_def} multiplier #({bw0}, {bw1}, {s0}, {s1}, {bw}) op_{i} ({v0}, {v1}, {v});'
|
|
123
150
|
|
|
151
|
+
case 8: # Lookup Table
|
|
152
|
+
name = get_table_name(sol, op)
|
|
153
|
+
bw0 = widths[op.id0]
|
|
154
|
+
|
|
155
|
+
line = f'{_def} lookup_table #({bw0}, {bw}, "{name}") op_{i} (v{op.id0}, {v});'
|
|
156
|
+
|
|
124
157
|
case _:
|
|
125
158
|
raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
|
|
126
159
|
|
|
@@ -130,7 +163,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
|
|
|
130
163
|
return lines
|
|
131
164
|
|
|
132
165
|
|
|
133
|
-
def output_gen(sol:
|
|
166
|
+
def output_gen(sol: CombLogic, neg_defined: set[int]):
|
|
134
167
|
lines = []
|
|
135
168
|
widths = list(map(sum, map(_minimal_kif, sol.out_qint)))
|
|
136
169
|
_widths = np.cumsum([0] + widths)
|
|
@@ -157,7 +190,7 @@ def output_gen(sol: Solution, neg_defined: set[int]):
|
|
|
157
190
|
return lines
|
|
158
191
|
|
|
159
192
|
|
|
160
|
-
def comb_logic_gen(sol:
|
|
193
|
+
def comb_logic_gen(sol: CombLogic, fn_name: str, print_latency: bool = False, timescale: str | None = None):
|
|
161
194
|
inp_bits = sum(map(sum, map(_minimal_kif, sol.inp_qint)))
|
|
162
195
|
out_bits = sum(map(sum, map(_minimal_kif, sol.out_qint)))
|
|
163
196
|
|
|
@@ -191,3 +224,10 @@ def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, tim
|
|
|
191
224
|
if timescale is not None:
|
|
192
225
|
code = f'{timescale}\n\n{code}'
|
|
193
226
|
return code
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def table_mem_gen(sol: CombLogic) -> dict[str, str]:
    """Build the memory files for every lookup operation in ``sol``.

    Parameters
    ----------
    sol : CombLogic
        Logic block whose operations are scanned for opcode 8.

    Returns
    -------
    dict[str, str]
        Mapping from memory-file name to file content; empty when ``sol``
        has no lookup tables.
    """
    mem_files: dict[str, str] = {}
    if not sol.lookup_tables:
        return mem_files

    for op in sol.ops:
        if op.opcode != 8:
            continue
        # Resolve the name first, then the content; operations that map to
        # the same name overwrite each other, the last one winning.
        name = get_table_name(sol, op)
        mem_files[name] = gen_mem_file(sol, op)
    return mem_files
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from itertools import accumulate
|
|
2
2
|
|
|
3
|
-
from ....cmvm.types import
|
|
3
|
+
from ....cmvm.types import CombLogic, Pipeline, QInterval, _minimal_kif
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def hetero_io_map(qints: list[QInterval], merge: bool = False):
|
|
@@ -59,7 +59,7 @@ def hetero_io_map(qints: list[QInterval], merge: bool = False):
|
|
|
59
59
|
return regular, hetero, pads, (width_regular, width_packed)
|
|
60
60
|
|
|
61
61
|
|
|
62
|
-
def generate_io_wrapper(sol:
|
|
62
|
+
def generate_io_wrapper(sol: CombLogic | Pipeline, module_name: str, pipelined: bool = False):
|
|
63
63
|
reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
|
|
64
64
|
reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)
|
|
65
65
|
|
|
@@ -113,12 +113,12 @@ endmodule
|
|
|
113
113
|
"""
|
|
114
114
|
|
|
115
115
|
|
|
116
|
-
def binder_gen(csol:
|
|
116
|
+
def binder_gen(csol: Pipeline | CombLogic, module_name: str, II: int = 1, latency_multiplier: int = 1):
|
|
117
117
|
k_in, i_in, f_in = zip(*map(_minimal_kif, csol.inp_qint))
|
|
118
118
|
k_out, i_out, f_out = zip(*map(_minimal_kif, csol.out_qint))
|
|
119
119
|
max_inp_bw = max(k_in) + max(i_in) + max(f_in)
|
|
120
120
|
max_out_bw = max(k_out) + max(i_out) + max(f_out)
|
|
121
|
-
if isinstance(csol,
|
|
121
|
+
if isinstance(csol, CombLogic):
|
|
122
122
|
II = latency = 0
|
|
123
123
|
else:
|
|
124
124
|
latency = len(csol.solutions) * latency_multiplier
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
from ....cmvm.types import
|
|
1
|
+
from ....cmvm.types import Pipeline, _minimal_kif
|
|
2
2
|
from .comb import comb_logic_gen
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
def pipeline_logic_gen(
|
|
6
|
-
csol:
|
|
6
|
+
csol: Pipeline,
|
|
7
7
|
name: str,
|
|
8
8
|
print_latency=False,
|
|
9
9
|
timescale: str | None = '`timescale 1 ns / 1 ps',
|
|
@@ -13,36 +13,36 @@ def pipeline_logic_gen(
|
|
|
13
13
|
inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
|
|
14
14
|
out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
|
|
15
15
|
|
|
16
|
-
registers = [f'reg [{width-1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
|
|
16
|
+
registers = [f'reg [{width - 1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
|
|
17
17
|
for i in range(0, register_layers - 1):
|
|
18
|
-
registers += [f'reg [{width-1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
|
|
19
|
-
wires = [f'wire [{width-1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
|
|
18
|
+
registers += [f'reg [{width - 1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
|
|
19
|
+
wires = [f'wire [{width - 1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
|
|
20
20
|
|
|
21
21
|
comb_logic = [f'{name}_stage{i} stage{i} (.model_inp(stage{i}_inp), .model_out(stage{i}_out));' for i in range(N)]
|
|
22
22
|
|
|
23
23
|
if register_layers == 1:
|
|
24
24
|
serial_logic = ['stage0_inp <= model_inp;']
|
|
25
|
-
serial_logic += [f'stage{i}_inp <= stage{i-1}_out;' for i in range(1, N)]
|
|
25
|
+
serial_logic += [f'stage{i}_inp <= stage{i - 1}_out;' for i in range(1, N)]
|
|
26
26
|
else:
|
|
27
27
|
serial_logic = ['stage0_inp_copy0 <= model_inp;']
|
|
28
28
|
for j in range(1, register_layers - 1):
|
|
29
|
-
serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j-1};')
|
|
29
|
+
serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j - 1};')
|
|
30
30
|
serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
|
|
31
31
|
for i in range(1, N):
|
|
32
|
-
serial_logic.append(f'stage{i}_inp_copy0 <= stage{i-1}_out;')
|
|
32
|
+
serial_logic.append(f'stage{i}_inp_copy0 <= stage{i - 1}_out;')
|
|
33
33
|
for j in range(1, register_layers - 1):
|
|
34
|
-
serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};')
|
|
34
|
+
serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j - 1};')
|
|
35
35
|
serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
|
|
36
36
|
|
|
37
|
-
serial_logic += [f'model_out <= stage{N-1}_out;']
|
|
37
|
+
serial_logic += [f'model_out <= stage{N - 1}_out;']
|
|
38
38
|
|
|
39
39
|
sep0 = '\n '
|
|
40
40
|
sep1 = '\n '
|
|
41
41
|
|
|
42
42
|
module = f"""module {name} (
|
|
43
43
|
input clk,
|
|
44
|
-
input [{inp_bits[0]-1}:0] model_inp,
|
|
45
|
-
output reg [{out_bits[-1]-1}:0] model_out
|
|
44
|
+
input [{inp_bits[0] - 1}:0] model_inp,
|
|
45
|
+
output reg [{out_bits[-1] - 1}:0] model_out
|
|
46
46
|
);
|
|
47
47
|
|
|
48
48
|
{sep0.join(registers)}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
`timescale 1ns / 1ps
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
// Read-only lookup table with a combinational read port.
// The table contents are loaded from the hex file MEM_FILE via $readmemh.
module lookup_table #(
    // Index width; the table holds 2**BW_IN entries.
    parameter BW_IN = 8,
    // Width of each stored entry.
    parameter BW_OUT = 8,
    // Hex memory file to load; the default is only a placeholder name.
    parameter MEM_FILE = "whatever.mem"
) (
    // Index of the entry to read.
    input [BW_IN-1:0] in,
    // Entry stored at index `in`.
    output [BW_OUT-1:0] out
);

    // NOTE(review): rom_style looks like a vendor synthesis hint asking for
    // a LUT-based (distributed) ROM -- confirm against the target toolchain.
    (*rom_style = "distributed" *)
    reg [BW_OUT-1:0] lut_rom [0:(1<<BW_IN)-1];
    reg [BW_OUT-1:0] readout;

    // Fill the table from the hex file.
    initial begin
        $readmemh(MEM_FILE, lut_rom);
    end

    assign out[BW_OUT-1:0] = readout[BW_OUT-1:0];

    // Combinational read: readout follows the entry selected by `in`.
    always @(*) begin
        readout = lut_rom[in];
    end

endmodule
|