da4ml 0.4.0__py3-none-any.whl → 0.5.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

Files changed (40) hide show
  1. da4ml/__init__.py +2 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +2 -2
  4. da4ml/cmvm/api.py +15 -4
  5. da4ml/cmvm/core/__init__.py +2 -2
  6. da4ml/cmvm/types.py +32 -18
  7. da4ml/cmvm/util/bit_decompose.py +2 -2
  8. da4ml/codegen/hls/hls_codegen.py +10 -5
  9. da4ml/codegen/hls/hls_model.py +7 -4
  10. da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
  11. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  12. da4ml/codegen/rtl/common_source/{build_prj.tcl → build_vivado_prj.tcl} +39 -18
  13. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  14. da4ml/codegen/rtl/common_source/template.xdc +11 -13
  15. da4ml/codegen/rtl/rtl_model.py +105 -53
  16. da4ml/codegen/rtl/verilog/__init__.py +2 -1
  17. da4ml/codegen/rtl/verilog/comb.py +47 -7
  18. da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
  19. da4ml/codegen/rtl/verilog/pipeline.py +12 -12
  20. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  21. da4ml/codegen/rtl/vhdl/comb.py +27 -21
  22. da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
  23. da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
  24. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  25. da4ml/converter/__init__.py +57 -1
  26. da4ml/converter/hgq2/parser.py +4 -25
  27. da4ml/converter/hgq2/replica.py +210 -25
  28. da4ml/trace/fixed_variable.py +239 -29
  29. da4ml/trace/fixed_variable_array.py +276 -48
  30. da4ml/trace/ops/__init__.py +31 -15
  31. da4ml/trace/ops/reduce_utils.py +3 -3
  32. da4ml/trace/pipeline.py +40 -18
  33. da4ml/trace/tracer.py +33 -8
  34. da4ml/typing/__init__.py +3 -0
  35. {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/METADATA +2 -1
  36. {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/RECORD +39 -35
  37. da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
  38. {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/WHEEL +0 -0
  39. {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/licenses/LICENSE +0 -0
  40. {da4ml-0.4.0.dist-info → da4ml-0.5.0b0.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,45 @@
1
1
  import ctypes
2
+ import json
2
3
  import os
3
4
  import re
4
5
  import shutil
5
6
  import subprocess
6
7
  import sys
8
+ from collections.abc import Sequence
7
9
  from pathlib import Path
8
10
  from uuid import uuid4
9
11
 
10
12
  import numpy as np
11
13
  from numpy.typing import NDArray
12
14
 
13
- from ...cmvm.types import CascadedSolution, Solution, _minimal_kif
15
+ from ...cmvm.types import CombLogic, Pipeline, _minimal_kif
14
16
  from ...trace.pipeline import to_pipeline
15
17
  from .. import rtl
16
18
 
17
19
 
18
- def get_io_kifs(sol: Solution | CascadedSolution):
20
def get_io_kifs(sol: CombLogic | Pipeline):
    """Collect the minimal-kif descriptors of a solution's inputs and outputs.

    ``_minimal_kif`` is applied to every entry of ``sol.inp_qint`` and
    ``sol.out_qint``. The per-entry results are then transposed, so each row
    of a returned array gathers one component across all entries.

    Parameters
    ----------
    sol : CombLogic | Pipeline
        Object exposing ``inp_qint`` and ``out_qint``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(inp_kifs, out_kifs)``, both built with dtype ``np.int8``.
    """

    def _transposed_kifs(qints):
        # zip(*...) turns a sequence of per-entry tuples into per-component tuples.
        return tuple(zip(*(_minimal_kif(qint) for qint in qints)))

    kifs_in = _transposed_kifs(sol.inp_qint)
    kifs_out = _transposed_kifs(sol.out_qint)
    return np.array(kifs_in, np.int8), np.array(kifs_out, np.int8)
22
24
 
23
25
 
26
+ class at_path:
27
+ def __init__(self, path: str | Path):
28
+ self._path = Path(path)
29
+ self._orig_cwd = None
30
+
31
+ def __enter__(self):
32
+ self._orig_cwd = Path.cwd()
33
+ os.chdir(self._path)
34
+
35
+ def __exit__(self, exc_type, exc_value, traceback):
36
+ os.chdir(self._orig_cwd) # type: ignore
37
+
38
+
24
39
  class RTLModel:
25
40
  def __init__(
26
41
  self,
27
- solution: Solution | CascadedSolution,
42
+ solution: CombLogic | Pipeline,
28
43
  prj_name: str,
29
44
  path: str | Path,
30
45
  flavor: str = 'verilog',
@@ -51,9 +66,9 @@ class RTLModel:
51
66
 
52
67
  assert self._flavor in ('vhdl', 'verilog'), f'Unsupported flavor {flavor}, only vhdl and verilog are supported.'
53
68
 
54
- self._pipe = solution if isinstance(solution, CascadedSolution) else None
69
+ self._pipe = solution if isinstance(solution, Pipeline) else None
55
70
  if latency_cutoff > 0 and self._pipe is None:
56
- assert isinstance(solution, Solution)
71
+ assert isinstance(solution, CombLogic)
57
72
  self._pipe = to_pipeline(solution, latency_cutoff, verbose=False)
58
73
 
59
74
  if self._pipe is not None:
@@ -72,32 +87,46 @@ class RTLModel:
72
87
  else: # verilog
73
88
  from .verilog import binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_logic_gen
74
89
 
75
- self._path.mkdir(parents=True, exist_ok=True)
90
+ from .verilog.comb import table_mem_gen
91
+
92
+ (self._path / 'src/static').mkdir(parents=True, exist_ok=True)
93
+ (self._path / 'sim').mkdir(exist_ok=True)
94
+ (self._path / 'model').mkdir(exist_ok=True)
95
+ (self._path / 'src/memfiles').mkdir(exist_ok=True)
76
96
  if self._pipe is not None: # Pipeline
77
97
  # Main logic
78
98
  codes = pipeline_logic_gen(self._pipe, self._prj_name, self._print_latency, register_layers=self._register_layers)
99
+
100
+ # Table memory files
101
+ memfiles: dict[str, str] = {}
102
+ for comb in self._pipe.solutions:
103
+ memfiles.update(table_mem_gen(comb))
104
+
79
105
  for k, v in codes.items():
80
- with open(self._path / f'{k}.{suffix}', 'w') as f:
106
+ with open(self._path / f'src/{k}.{suffix}', 'w') as f:
81
107
  f.write(v)
82
108
 
83
- # Build script
84
- with open(self.__src_root / 'common_source/build_prj.tcl') as f:
85
- tcl = f.read()
86
- tcl = tcl.replace('${DEVICE}', self._part_name)
87
- tcl = tcl.replace('${PROJECT_NAME}', self._prj_name)
88
- with open(self._path / 'build_prj.tcl', 'w') as f:
89
- f.write(tcl)
90
-
91
- # XDC
92
- with open(self.__src_root / 'common_source/template.xdc') as f:
93
- xdc = f.read()
94
- xdc = xdc.replace('${CLOCK_PERIOD}', str(self._clock_period))
95
- xdc = xdc.replace('${UNCERTAINITY_SETUP}', str(self._clock_uncertainty))
96
- xdc = xdc.replace('${UNCERTAINITY_HOLD}', str(self._clock_uncertainty))
97
- xdc = xdc.replace('${DELAY_MAX}', str(self._io_delay_minmax[0]))
98
- xdc = xdc.replace('${DELAY_MIN}', str(self._io_delay_minmax[1]))
99
- with open(self._path / f'{self._prj_name}.xdc', 'w') as f:
100
- f.write(xdc)
109
+ # Build scripts
110
+ for path in (self.__src_root).glob('common_source/build_*_prj.tcl'):
111
+ with open(path) as f:
112
+ tcl = f.read()
113
+ tcl = tcl.replace('$::env(DEVICE)', self._part_name)
114
+ tcl = tcl.replace('$::env(PROJECT_NAME)', self._prj_name)
115
+ tcl = tcl.replace('$::env(SOURCE_TYPE)', flavor)
116
+ with open(self._path / path.name, 'w') as f:
117
+ f.write(tcl)
118
+
119
+ # Timing constraint
120
+ for fmt in ('xdc', 'sdc'):
121
+ with open(self.__src_root / f'common_source/template.{fmt}') as f:
122
+ constraint = f.read()
123
+ constraint = constraint.replace('$::env(CLOCK_PERIOD)', str(self._clock_period))
124
+ constraint = constraint.replace('$::env(UNCERTAINITY_SETUP)', str(self._clock_uncertainty))
125
+ constraint = constraint.replace('$::env(UNCERTAINITY_HOLD)', str(self._clock_uncertainty))
126
+ constraint = constraint.replace('$::env(DELAY_MAX)', str(self._io_delay_minmax[1]))
127
+ constraint = constraint.replace('$::env(DELAY_MIN)', str(self._io_delay_minmax[0]))
128
+ with open(self._path / f'src/{self._prj_name}.{fmt}', 'w') as f:
129
+ f.write(constraint)
101
130
 
102
131
  # C++ binder w/ HDL wrapper for uniform bw
103
132
  binder = binder_gen(self._pipe, f'{self._prj_name}_wrapper', 1, self._register_layers)
@@ -105,34 +134,46 @@ class RTLModel:
105
134
  # Verilog IO wrapper (non-uniform bw to uniform one, clk passthrough)
106
135
  io_wrapper = generate_io_wrapper(self._pipe, self._prj_name, True)
107
136
 
108
- self._pipe.save(self._path / 'pipeline.json')
137
+ self._pipe.save(self._path / 'model/pipeline.json')
109
138
  else: # Comb
110
- assert isinstance(self._solution, Solution)
139
+ assert isinstance(self._solution, CombLogic)
140
+
141
+ # Table memory files
142
+ memfiles = table_mem_gen(self._solution)
111
143
 
112
144
  # Main logic
113
145
  code = comb_logic_gen(self._solution, self._prj_name, self._print_latency, '`timescale 1ns/1ps')
114
- with open(self._path / f'{self._prj_name}.{suffix}', 'w') as f:
146
+ with open(self._path / f'src/{self._prj_name}.{suffix}', 'w') as f:
115
147
  f.write(code)
116
148
 
117
149
  # Verilog IO wrapper (non-uniform bw to uniform one, no clk)
118
150
  io_wrapper = generate_io_wrapper(self._solution, self._prj_name, False)
119
151
  binder = binder_gen(self._solution, f'{self._prj_name}_wrapper')
120
152
 
121
- with open(self._path / f'{self._prj_name}_wrapper.{suffix}', 'w') as f:
153
+ # Write table memory files
154
+ for name, mem in memfiles.items():
155
+ with open(self._path / 'src/memfiles' / name, 'w') as f:
156
+ f.write(mem)
157
+
158
+ with open(self._path / f'src/{self._prj_name}_wrapper.{suffix}', 'w') as f:
122
159
  f.write(io_wrapper)
123
- with open(self._path / f'{self._prj_name}_wrapper_binder.cc', 'w') as f:
160
+ with open(self._path / f'sim/{self._prj_name}_wrapper_binder.cc', 'w') as f:
124
161
  f.write(binder)
125
162
 
126
163
  # Common resource copy
127
- for fname in self.__src_root.glob(f'{flavor}/source/*.{suffix}'):
128
- shutil.copy(fname, self._path)
129
-
130
- shutil.copy(self.__src_root / 'common_source/build_binder.mk', self._path)
131
- shutil.copy(self.__src_root / 'common_source/ioutil.hh', self._path)
132
- shutil.copy(self.__src_root / 'common_source/binder_util.hh', self._path)
133
- self._solution.save(self._path / 'model.json')
134
- with open(self._path / 'misc.json', 'w') as f:
135
- f.write(f'{{"cost": {self._solution.cost}}}')
164
+ for path in self.__src_root.glob(f'{flavor}/source/*.{suffix}'):
165
+ shutil.copy(path, self._path / 'src/static')
166
+
167
+ shutil.copy(self.__src_root / 'common_source/build_binder.mk', self._path / 'sim')
168
+ shutil.copy(self.__src_root / 'common_source/ioutil.hh', self._path / 'sim')
169
+ shutil.copy(self.__src_root / 'common_source/binder_util.hh', self._path / 'sim')
170
+ self._solution.save(self._path / 'model/comb.json')
171
+ with open(self._path / 'metadata.json', 'w') as f:
172
+ misc = {'cost': self._solution.cost}
173
+ if self._pipe is not None:
174
+ misc['latency'] = len(self._pipe[0])
175
+ misc['reg_bits'] = self._pipe.reg_bits
176
+ f.write(json.dumps(misc))
136
177
 
137
178
  def _compile(self, verbose=False, openmp=True, nproc=None, o3: bool = False, clean=True):
138
179
  """Same as compile, but will not write to the library
@@ -149,7 +190,7 @@ class RTLModel:
149
190
  o3 : bool | None, optional
150
191
  Turn on -O3 flag, by default False
151
192
  clean : bool, optional
152
- Remove obsolete shared object files, by default True
193
+ Remove obsolete shared object files and `obj_dir`, by default True
153
194
 
154
195
  Raises
155
196
  ------
@@ -169,18 +210,21 @@ class RTLModel:
169
210
  if o3:
170
211
  args.append('fast')
171
212
 
172
- if clean is not False:
213
+ if clean:
173
214
  m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
174
- for p in self._path.iterdir():
215
+ for p in (self._path / 'sim').iterdir():
175
216
  if not p.is_dir() and m.match(p.name):
176
217
  p.unlink()
177
- if clean:
178
218
  subprocess.run(
179
- ['make', '-f', 'build_binder.mk', 'clean'], env=env, cwd=self._path, check=True, capture_output=not verbose
219
+ ['make', '-f', 'build_binder.mk', 'clean'],
220
+ env=env,
221
+ cwd=self._path / 'sim',
222
+ check=True,
223
+ capture_output=not verbose,
180
224
  )
181
225
 
182
226
  try:
183
- r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
227
+ r = subprocess.run(args, env=env, check=True, cwd=self._path / 'sim', capture_output=not verbose)
184
228
  except subprocess.CalledProcessError as e:
185
229
  print(e.stderr.decode(), file=sys.stderr)
186
230
  print(e.stdout.decode(), file=sys.stdout)
@@ -190,18 +234,21 @@ class RTLModel:
190
234
  print(r.stdout.decode(), file=sys.stderr)
191
235
  raise RuntimeError('Compilation failed!!')
192
236
 
237
+ if clean:
238
+ subprocess.run(['rm', '-rf', 'obj_dir'], cwd=self._path / 'sim', check=True, capture_output=not verbose)
239
+
193
240
  self._load_lib(self._uuid)
194
241
 
195
242
  def _load_lib(self, uuid: str | None = None):
196
243
  uuid = uuid if uuid is not None else self._uuid
197
244
  if uuid is None:
198
245
  # load .so if there is only one, otherwise raise an error
199
- libs = list(self._path.glob(f'lib{self._prj_name}_wrapper_*.so'))
246
+ libs = list(self._path.glob(f'sim/lib{self._prj_name}_wrapper_*.so'))
200
247
  if len(libs) == 0:
201
248
  raise RuntimeError(f'Cannot load library, found {len(libs)} libraries in {self._path}')
202
249
  uuid = libs[0].name.split('_')[-1].split('.', 1)[0]
203
250
  self._uuid = uuid
204
- lib_path = self._path / f'lib{self._prj_name}_wrapper_{uuid}.so'
251
+ lib_path = self._path / f'sim/lib{self._prj_name}_wrapper_{uuid}.so'
205
252
  if not lib_path.exists():
206
253
  raise RuntimeError(f'Library {lib_path} does not exist')
207
254
  self._lib = ctypes.CDLL(str(lib_path))
@@ -221,7 +268,7 @@ class RTLModel:
221
268
  o3 : bool | None, optional
222
269
  Turn on -O3 flag, by default False
223
270
  clean : bool, optional
224
- Remove obsolete shared object files, by default True
271
+ Remove obsolete shared object files and `obj_dir`, by default True
225
272
 
226
273
  Raises
227
274
  ------
@@ -231,12 +278,12 @@ class RTLModel:
231
278
  self.write()
232
279
  self._compile(verbose=verbose, openmp=openmp, nproc=nproc, o3=o3, clean=clean)
233
280
 
234
- def predict(self, data: NDArray[np.floating]) -> NDArray[np.float32]:
281
+ def predict(self, data: NDArray[np.floating] | Sequence[NDArray[np.floating]]) -> NDArray[np.float32]:
235
282
  """Run the model on the input data.
236
283
 
237
284
  Parameters
238
285
  ----------
239
- data : NDArray[np.floating]
286
+ data : NDArray[np.floating]|Sequence[NDArray[np.floating]]
240
287
  Input data to the model. The shape is ignored, and the number of samples is
241
288
  determined by the size of the data.
242
289
 
@@ -246,6 +293,9 @@ class RTLModel:
246
293
  Output of the model in shape (n_samples, output_size).
247
294
  """
248
295
 
296
+ if isinstance(data, Sequence):
297
+ data = np.concatenate([a.reshape(a.shape[0], -1) for a in data], axis=-1)
298
+
249
299
  assert self._lib is not None, 'Library not loaded, call .compile() first.'
250
300
  inp_size, out_size = self._solution.shape
251
301
 
@@ -266,7 +316,9 @@ class RTLModel:
266
316
 
267
317
  inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
268
318
  out_buf = out_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
269
- self._lib.inference(inp_buf, out_buf, n_sample)
319
+
320
+ with at_path(self._path / 'src/memfiles'):
321
+ self._lib.inference(inp_buf, out_buf, n_sample)
270
322
 
271
323
  # Unscale the output int32 to recover fp values
272
324
  k, i, f = np.max(k_out), np.max(i_out), np.max(f_out)
@@ -307,7 +359,7 @@ Estimated cost: {cost} LUTs"""
307
359
  class VerilogModel(RTLModel):
308
360
  def __init__(
309
361
  self,
310
- solution: Solution | CascadedSolution,
362
+ solution: CombLogic | Pipeline,
311
363
  prj_name: str,
312
364
  path: str | Path,
313
365
  latency_cutoff: float = -1,
@@ -336,7 +388,7 @@ class VerilogModel(RTLModel):
336
388
  class VHDLModel(RTLModel):
337
389
  def __init__(
338
390
  self,
339
- solution: Solution | CascadedSolution,
391
+ solution: CombLogic | Pipeline,
340
392
  prj_name: str,
341
393
  path: str | Path,
342
394
  latency_cutoff: float = -1,
@@ -1,9 +1,10 @@
1
- from .comb import comb_logic_gen
1
+ from .comb import comb_logic_gen, table_mem_gen
2
2
  from .io_wrapper import binder_gen, generate_io_wrapper
3
3
  from .pipeline import pipeline_logic_gen
4
4
 
5
5
  __all__ = [
6
6
  'comb_logic_gen',
7
+ 'table_mem_gen',
7
8
  'generate_io_wrapper',
8
9
  'pipeline_logic_gen',
9
10
  'binder_gen',
@@ -2,7 +2,7 @@ from math import ceil, log2
2
2
 
3
3
  import numpy as np
4
4
 
5
- from ....cmvm.types import Op, QInterval, Solution, _minimal_kif
5
+ from ....cmvm.types import CombLogic, Op, QInterval, _minimal_kif
6
6
 
7
7
 
8
8
  def make_neg(
@@ -23,16 +23,35 @@ def make_neg(
23
23
  return bw0, v0_name
24
24
 
25
25
 
26
- def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
26
def gen_mem_file(sol: CombLogic, op: Op) -> str:
    """Render the lookup table referenced by ``op`` as hex memory-file text.

    Parameters
    ----------
    sol : CombLogic
        Solution holding ``lookup_tables`` and ``ops``.
    op : Op
        Operation with opcode 8; ``op.data`` indexes ``sol.lookup_tables`` and
        ``op.id0`` indexes the operation whose ``qint`` is passed to
        ``padded_table``.

    Returns
    -------
    str
        One upper-case hex word per table entry, joined by newlines with no
        trailing newline. Each word is zero-padded to ``ceil(width / 4)``
        digits, where ``width`` is the sum of the table's ``spec.out_kif``.
    """
    assert op.opcode == 8
    tables = sol.lookup_tables
    assert tables is not None

    lut = tables[op.data]
    n_bits = sum(lut.spec.out_kif)
    n_hex = ceil(n_bits / 4)

    # Masking to the output width maps every entry onto its unsigned bit pattern.
    mask = (1 << n_bits) - 1
    entries = lut.padded_table(sol.ops[op.id0].qint) & mask
    return '\n'.join(hex(entry)[2:].upper().zfill(n_hex) for entry in entries)
35
+
36
+
37
def get_table_name(sol: CombLogic, op: Op) -> str:
    """Return the memory-file name for the lookup table used by ``op``.

    Parameters
    ----------
    sol : CombLogic
        Solution holding ``lookup_tables`` and ``ops``.
    op : Op
        Operation with opcode 8; ``op.data`` selects the table and ``op.id0``
        selects the operation whose ``qint`` is handed to ``get_uuid``.

    Returns
    -------
    str
        ``'lut_<uuid>.mem'``, where ``<uuid>`` is what the table's
        ``get_uuid`` returns for that ``qint``.
    """
    tables = sol.lookup_tables
    assert tables is not None
    assert op.opcode == 8

    input_qint = sol.ops[op.id0].qint
    return f'lut_{tables[op.data].get_uuid(input_qint)}.mem'
43
+
44
+
45
+ def ssa_gen(sol: CombLogic, neg_defined: set[int], print_latency: bool = False) -> list[str]:
27
46
  ops = sol.ops
28
47
  kifs = list(map(_minimal_kif, (op.qint for op in ops)))
29
- widths = list(map(sum, kifs))
48
+ widths: list[int] = list(map(sum, kifs))
30
49
  inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
31
50
  inp_widths = list(map(sum, inp_kifs))
32
51
  _inp_widths = np.cumsum([0] + inp_widths)
33
52
  inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
34
53
 
35
- lines = []
54
+ lines: list[str] = []
36
55
  ref_count = sol.ref_count
37
56
 
38
57
  for i, op in enumerate(ops):
@@ -84,6 +103,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
84
103
  if op.opcode == -3 and op.id0 not in neg_defined:
85
104
  neg_defined.add(op.id0)
86
105
  bw0, v0_name = make_neg(lines, op, ops, bw0, v0_name)
106
+
87
107
  line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
88
108
 
89
109
  case 4: # constant addition
@@ -93,9 +113,11 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
93
113
  bw0 = widths[op.id0]
94
114
  s0 = int(kifs[op.id0][0])
95
115
  v0 = f'v{op.id0}[{bw0 - 1}:0]'
96
- v1 = f"'{bin(mag)[1:]}"
116
+ v1 = f"{bw1}'{bin(mag)[1:]}"
97
117
  shift = kifs[op.id0][2] - kifs[i][2]
118
+
98
119
  line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, 0, {bw}, {shift}, {sign}) op_{i} ({v0}, {v1}, {v});'
120
+
99
121
  case 5: # constant
100
122
  num = op.data
101
123
  if num < 0:
@@ -112,8 +134,13 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
112
134
  _shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
113
135
  shift = f0 - f1 + _shift
114
136
  vk, v0, v1 = f'v{k}[{bwk - 1}]', f'v{a}[{bw0 - 1}:0]', f'v{b}[{bw1 - 1}:0]'
137
+ if bw0 == 0:
138
+ v0, bw0 = "1'b0", 1
139
+ if bw1 == 0:
140
+ v1, bw1 = "1'b0", 1
115
141
 
116
142
  line = f'{_def} mux #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {inv}) op_{i} ({vk}, {v0}, {v1}, {v});'
143
+
117
144
  case 7: # Multiplication
118
145
  bw0, bw1 = widths[op.id0], widths[op.id1] # width
119
146
  s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
@@ -121,6 +148,12 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
121
148
 
122
149
  line = f'{_def} multiplier #({bw0}, {bw1}, {s0}, {s1}, {bw}) op_{i} ({v0}, {v1}, {v});'
123
150
 
151
+ case 8: # Lookup Table
152
+ name = get_table_name(sol, op)
153
+ bw0 = widths[op.id0]
154
+
155
+ line = f'{_def} lookup_table #({bw0}, {bw}, "{name}") op_{i} (v{op.id0}, {v});'
156
+
124
157
  case _:
125
158
  raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
126
159
 
@@ -130,7 +163,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
130
163
  return lines
131
164
 
132
165
 
133
- def output_gen(sol: Solution, neg_defined: set[int]):
166
+ def output_gen(sol: CombLogic, neg_defined: set[int]):
134
167
  lines = []
135
168
  widths = list(map(sum, map(_minimal_kif, sol.out_qint)))
136
169
  _widths = np.cumsum([0] + widths)
@@ -157,7 +190,7 @@ def output_gen(sol: Solution, neg_defined: set[int]):
157
190
  return lines
158
191
 
159
192
 
160
- def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, timescale: str | None = None):
193
+ def comb_logic_gen(sol: CombLogic, fn_name: str, print_latency: bool = False, timescale: str | None = None):
161
194
  inp_bits = sum(map(sum, map(_minimal_kif, sol.inp_qint)))
162
195
  out_bits = sum(map(sum, map(_minimal_kif, sol.out_qint)))
163
196
 
@@ -191,3 +224,10 @@ def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, tim
191
224
  if timescale is not None:
192
225
  code = f'{timescale}\n\n{code}'
193
226
  return code
227
+
228
+
229
def table_mem_gen(sol: CombLogic) -> dict[str, str]:
    """Build the memory files for every lookup-table operation in ``sol``.

    Parameters
    ----------
    sol : CombLogic
        Solution whose ``ops`` are scanned for opcode 8.

    Returns
    -------
    dict[str, str]
        Maps each file name from ``get_table_name`` to the text from
        ``gen_mem_file``. Empty when ``sol.lookup_tables`` is falsy or no
        operation has opcode 8.
    """
    mem_files: dict[str, str] = {}
    if not sol.lookup_tables:
        return mem_files

    for op in sol.ops:
        if op.opcode != 8:
            continue
        # Operations that resolve to the same name simply overwrite the entry.
        name = get_table_name(sol, op)
        mem_files[name] = gen_mem_file(sol, op)
    return mem_files
@@ -1,6 +1,6 @@
1
1
  from itertools import accumulate
2
2
 
3
- from ....cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
3
+ from ....cmvm.types import CombLogic, Pipeline, QInterval, _minimal_kif
4
4
 
5
5
 
6
6
  def hetero_io_map(qints: list[QInterval], merge: bool = False):
@@ -59,7 +59,7 @@ def hetero_io_map(qints: list[QInterval], merge: bool = False):
59
59
  return regular, hetero, pads, (width_regular, width_packed)
60
60
 
61
61
 
62
- def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipelined: bool = False):
62
+ def generate_io_wrapper(sol: CombLogic | Pipeline, module_name: str, pipelined: bool = False):
63
63
  reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
64
64
  reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)
65
65
 
@@ -113,12 +113,12 @@ endmodule
113
113
  """
114
114
 
115
115
 
116
- def binder_gen(csol: CascadedSolution | Solution, module_name: str, II: int = 1, latency_multiplier: int = 1):
116
+ def binder_gen(csol: Pipeline | CombLogic, module_name: str, II: int = 1, latency_multiplier: int = 1):
117
117
  k_in, i_in, f_in = zip(*map(_minimal_kif, csol.inp_qint))
118
118
  k_out, i_out, f_out = zip(*map(_minimal_kif, csol.out_qint))
119
119
  max_inp_bw = max(k_in) + max(i_in) + max(f_in)
120
120
  max_out_bw = max(k_out) + max(i_out) + max(f_out)
121
- if isinstance(csol, Solution):
121
+ if isinstance(csol, CombLogic):
122
122
  II = latency = 0
123
123
  else:
124
124
  latency = len(csol.solutions) * latency_multiplier
@@ -1,9 +1,9 @@
1
- from ....cmvm.types import CascadedSolution, _minimal_kif
1
+ from ....cmvm.types import Pipeline, _minimal_kif
2
2
  from .comb import comb_logic_gen
3
3
 
4
4
 
5
5
  def pipeline_logic_gen(
6
- csol: CascadedSolution,
6
+ csol: Pipeline,
7
7
  name: str,
8
8
  print_latency=False,
9
9
  timescale: str | None = '`timescale 1 ns / 1 ps',
@@ -13,36 +13,36 @@ def pipeline_logic_gen(
13
13
  inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
14
14
  out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
15
15
 
16
- registers = [f'reg [{width-1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
16
+ registers = [f'reg [{width - 1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
17
17
  for i in range(0, register_layers - 1):
18
- registers += [f'reg [{width-1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
19
- wires = [f'wire [{width-1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
18
+ registers += [f'reg [{width - 1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
19
+ wires = [f'wire [{width - 1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
20
20
 
21
21
  comb_logic = [f'{name}_stage{i} stage{i} (.model_inp(stage{i}_inp), .model_out(stage{i}_out));' for i in range(N)]
22
22
 
23
23
  if register_layers == 1:
24
24
  serial_logic = ['stage0_inp <= model_inp;']
25
- serial_logic += [f'stage{i}_inp <= stage{i-1}_out;' for i in range(1, N)]
25
+ serial_logic += [f'stage{i}_inp <= stage{i - 1}_out;' for i in range(1, N)]
26
26
  else:
27
27
  serial_logic = ['stage0_inp_copy0 <= model_inp;']
28
28
  for j in range(1, register_layers - 1):
29
- serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j-1};')
29
+ serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j - 1};')
30
30
  serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
31
31
  for i in range(1, N):
32
- serial_logic.append(f'stage{i}_inp_copy0 <= stage{i-1}_out;')
32
+ serial_logic.append(f'stage{i}_inp_copy0 <= stage{i - 1}_out;')
33
33
  for j in range(1, register_layers - 1):
34
- serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};')
34
+ serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j - 1};')
35
35
  serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
36
36
 
37
- serial_logic += [f'model_out <= stage{N-1}_out;']
37
+ serial_logic += [f'model_out <= stage{N - 1}_out;']
38
38
 
39
39
  sep0 = '\n '
40
40
  sep1 = '\n '
41
41
 
42
42
  module = f"""module {name} (
43
43
  input clk,
44
- input [{inp_bits[0]-1}:0] model_inp,
45
- output reg [{out_bits[-1]-1}:0] model_out
44
+ input [{inp_bits[0] - 1}:0] model_inp,
45
+ output reg [{out_bits[-1] - 1}:0] model_out
46
46
  );
47
47
 
48
48
  {sep0.join(registers)}
@@ -0,0 +1,27 @@
1
+ `timescale 1ns / 1ps
2
+
3
+
4
// Combinational lookup table: `out` is the word of `lut_rom` addressed by `in`.
// The table contents are loaded from MEM_FILE with $readmemh.
module lookup_table #(
    parameter BW_IN = 8,                 // width of the address input, in bits
    parameter BW_OUT = 8,                // width of each stored word, in bits
    parameter MEM_FILE = "whatever.mem"  // hex file read by $readmemh
) (
    input [BW_IN-1:0] in,
    output [BW_OUT-1:0] out
);

    // One BW_OUT-bit word for every possible value of `in`.
    // rom_style is a synthesis attribute requesting a distributed implementation.
    (* rom_style = "distributed" *)
    reg [BW_OUT-1:0] lut_rom [0:(1<<BW_IN)-1];
    reg [BW_OUT-1:0] readout;

    initial begin
        $readmemh(MEM_FILE, lut_rom);
    end

    // Asynchronous read: no clock is involved.
    always @(*) begin
        readout = lut_rom[in];
    end

    assign out = readout;

endmodule