da4ml 0.4.1__py3-none-any.whl → 0.5.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

Files changed (40) hide show
  1. da4ml/__init__.py +2 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +2 -2
  4. da4ml/cmvm/api.py +15 -4
  5. da4ml/cmvm/core/__init__.py +2 -2
  6. da4ml/cmvm/types.py +32 -18
  7. da4ml/cmvm/util/bit_decompose.py +2 -2
  8. da4ml/codegen/hls/hls_codegen.py +10 -5
  9. da4ml/codegen/hls/hls_model.py +7 -4
  10. da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
  11. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  12. da4ml/codegen/rtl/common_source/{build_prj.tcl → build_vivado_prj.tcl} +39 -18
  13. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  14. da4ml/codegen/rtl/common_source/template.xdc +11 -13
  15. da4ml/codegen/rtl/rtl_model.py +105 -54
  16. da4ml/codegen/rtl/verilog/__init__.py +2 -1
  17. da4ml/codegen/rtl/verilog/comb.py +47 -7
  18. da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
  19. da4ml/codegen/rtl/verilog/pipeline.py +12 -12
  20. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  21. da4ml/codegen/rtl/vhdl/comb.py +27 -21
  22. da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
  23. da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
  24. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  25. da4ml/converter/__init__.py +57 -1
  26. da4ml/converter/hgq2/parser.py +4 -25
  27. da4ml/converter/hgq2/replica.py +208 -22
  28. da4ml/trace/fixed_variable.py +239 -29
  29. da4ml/trace/fixed_variable_array.py +276 -48
  30. da4ml/trace/ops/__init__.py +31 -15
  31. da4ml/trace/ops/reduce_utils.py +3 -3
  32. da4ml/trace/pipeline.py +40 -18
  33. da4ml/trace/tracer.py +33 -8
  34. da4ml/typing/__init__.py +3 -0
  35. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/METADATA +2 -1
  36. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/RECORD +39 -35
  37. da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
  38. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/WHEEL +0 -0
  39. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/licenses/LICENSE +0 -0
  40. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,45 @@
1
1
  import ctypes
2
+ import json
2
3
  import os
3
4
  import re
4
5
  import shutil
5
6
  import subprocess
6
7
  import sys
8
+ from collections.abc import Sequence
7
9
  from pathlib import Path
8
10
  from uuid import uuid4
9
11
 
10
12
  import numpy as np
11
13
  from numpy.typing import NDArray
12
14
 
13
- from ...cmvm.types import CascadedSolution, Solution, _minimal_kif
15
+ from ...cmvm.types import CombLogic, Pipeline, _minimal_kif
14
16
  from ...trace.pipeline import to_pipeline
15
17
  from .. import rtl
16
18
 
17
19
 
18
- def get_io_kifs(sol: Solution | CascadedSolution):
20
+ def get_io_kifs(sol: CombLogic | Pipeline):
19
21
  inp_kifs = tuple(zip(*map(_minimal_kif, sol.inp_qint)))
20
22
  out_kifs = tuple(zip(*map(_minimal_kif, sol.out_qint)))
21
23
  return np.array(inp_kifs, np.int8), np.array(out_kifs, np.int8)
22
24
 
23
25
 
26
+ class at_path:
27
+ def __init__(self, path: str | Path):
28
+ self._path = Path(path)
29
+ self._orig_cwd = None
30
+
31
+ def __enter__(self):
32
+ self._orig_cwd = Path.cwd()
33
+ os.chdir(self._path)
34
+
35
+ def __exit__(self, exc_type, exc_value, traceback):
36
+ os.chdir(self._orig_cwd) # type: ignore
37
+
38
+
24
39
  class RTLModel:
25
40
  def __init__(
26
41
  self,
27
- solution: Solution | CascadedSolution,
42
+ solution: CombLogic | Pipeline,
28
43
  prj_name: str,
29
44
  path: str | Path,
30
45
  flavor: str = 'verilog',
@@ -51,9 +66,9 @@ class RTLModel:
51
66
 
52
67
  assert self._flavor in ('vhdl', 'verilog'), f'Unsupported flavor {flavor}, only vhdl and verilog are supported.'
53
68
 
54
- self._pipe = solution if isinstance(solution, CascadedSolution) else None
69
+ self._pipe = solution if isinstance(solution, Pipeline) else None
55
70
  if latency_cutoff > 0 and self._pipe is None:
56
- assert isinstance(solution, Solution)
71
+ assert isinstance(solution, CombLogic)
57
72
  self._pipe = to_pipeline(solution, latency_cutoff, verbose=False)
58
73
 
59
74
  if self._pipe is not None:
@@ -72,33 +87,46 @@ class RTLModel:
72
87
  else: # verilog
73
88
  from .verilog import binder_gen, comb_logic_gen, generate_io_wrapper, pipeline_logic_gen
74
89
 
75
- self._path.mkdir(parents=True, exist_ok=True)
90
+ from .verilog.comb import table_mem_gen
91
+
92
+ (self._path / 'src/static').mkdir(parents=True, exist_ok=True)
93
+ (self._path / 'sim').mkdir(exist_ok=True)
94
+ (self._path / 'model').mkdir(exist_ok=True)
95
+ (self._path / 'src/memfiles').mkdir(exist_ok=True)
76
96
  if self._pipe is not None: # Pipeline
77
97
  # Main logic
78
98
  codes = pipeline_logic_gen(self._pipe, self._prj_name, self._print_latency, register_layers=self._register_layers)
99
+
100
+ # Table memory files
101
+ memfiles: dict[str, str] = {}
102
+ for comb in self._pipe.solutions:
103
+ memfiles.update(table_mem_gen(comb))
104
+
79
105
  for k, v in codes.items():
80
- with open(self._path / f'{k}.{suffix}', 'w') as f:
106
+ with open(self._path / f'src/{k}.{suffix}', 'w') as f:
81
107
  f.write(v)
82
108
 
83
- # Build script
84
- with open(self.__src_root / 'common_source/build_prj.tcl') as f:
85
- tcl = f.read()
86
- tcl = tcl.replace('${DEVICE}', self._part_name)
87
- tcl = tcl.replace('${PROJECT_NAME}', self._prj_name)
88
- tcl = tcl.replace('${SOURCE_TYPE}', flavor)
89
- with open(self._path / 'build_prj.tcl', 'w') as f:
90
- f.write(tcl)
91
-
92
- # XDC
93
- with open(self.__src_root / 'common_source/template.xdc') as f:
94
- xdc = f.read()
95
- xdc = xdc.replace('${CLOCK_PERIOD}', str(self._clock_period))
96
- xdc = xdc.replace('${UNCERTAINITY_SETUP}', str(self._clock_uncertainty))
97
- xdc = xdc.replace('${UNCERTAINITY_HOLD}', str(self._clock_uncertainty))
98
- xdc = xdc.replace('${DELAY_MAX}', str(self._io_delay_minmax[1]))
99
- xdc = xdc.replace('${DELAY_MIN}', str(self._io_delay_minmax[0]))
100
- with open(self._path / f'{self._prj_name}.xdc', 'w') as f:
101
- f.write(xdc)
109
+ # Build scripts
110
+ for path in (self.__src_root).glob('common_source/build_*_prj.tcl'):
111
+ with open(path) as f:
112
+ tcl = f.read()
113
+ tcl = tcl.replace('$::env(DEVICE)', self._part_name)
114
+ tcl = tcl.replace('$::env(PROJECT_NAME)', self._prj_name)
115
+ tcl = tcl.replace('$::env(SOURCE_TYPE)', flavor)
116
+ with open(self._path / path.name, 'w') as f:
117
+ f.write(tcl)
118
+
119
+ # Timing constraint
120
+ for fmt in ('xdc', 'sdc'):
121
+ with open(self.__src_root / f'common_source/template.{fmt}') as f:
122
+ constraint = f.read()
123
+ constraint = constraint.replace('$::env(CLOCK_PERIOD)', str(self._clock_period))
124
+ constraint = constraint.replace('$::env(UNCERTAINITY_SETUP)', str(self._clock_uncertainty))
125
+ constraint = constraint.replace('$::env(UNCERTAINITY_HOLD)', str(self._clock_uncertainty))
126
+ constraint = constraint.replace('$::env(DELAY_MAX)', str(self._io_delay_minmax[1]))
127
+ constraint = constraint.replace('$::env(DELAY_MIN)', str(self._io_delay_minmax[0]))
128
+ with open(self._path / f'src/{self._prj_name}.{fmt}', 'w') as f:
129
+ f.write(constraint)
102
130
 
103
131
  # C++ binder w/ HDL wrapper for uniform bw
104
132
  binder = binder_gen(self._pipe, f'{self._prj_name}_wrapper', 1, self._register_layers)
@@ -106,34 +134,46 @@ class RTLModel:
106
134
  # Verilog IO wrapper (non-uniform bw to uniform one, clk passthrough)
107
135
  io_wrapper = generate_io_wrapper(self._pipe, self._prj_name, True)
108
136
 
109
- self._pipe.save(self._path / 'pipeline.json')
137
+ self._pipe.save(self._path / 'model/pipeline.json')
110
138
  else: # Comb
111
- assert isinstance(self._solution, Solution)
139
+ assert isinstance(self._solution, CombLogic)
140
+
141
+ # Table memory files
142
+ memfiles = table_mem_gen(self._solution)
112
143
 
113
144
  # Main logic
114
145
  code = comb_logic_gen(self._solution, self._prj_name, self._print_latency, '`timescale 1ns/1ps')
115
- with open(self._path / f'{self._prj_name}.{suffix}', 'w') as f:
146
+ with open(self._path / f'src/{self._prj_name}.{suffix}', 'w') as f:
116
147
  f.write(code)
117
148
 
118
149
  # Verilog IO wrapper (non-uniform bw to uniform one, no clk)
119
150
  io_wrapper = generate_io_wrapper(self._solution, self._prj_name, False)
120
151
  binder = binder_gen(self._solution, f'{self._prj_name}_wrapper')
121
152
 
122
- with open(self._path / f'{self._prj_name}_wrapper.{suffix}', 'w') as f:
153
+ # Write table memory files
154
+ for name, mem in memfiles.items():
155
+ with open(self._path / 'src/memfiles' / name, 'w') as f:
156
+ f.write(mem)
157
+
158
+ with open(self._path / f'src/{self._prj_name}_wrapper.{suffix}', 'w') as f:
123
159
  f.write(io_wrapper)
124
- with open(self._path / f'{self._prj_name}_wrapper_binder.cc', 'w') as f:
160
+ with open(self._path / f'sim/{self._prj_name}_wrapper_binder.cc', 'w') as f:
125
161
  f.write(binder)
126
162
 
127
163
  # Common resource copy
128
- for fname in self.__src_root.glob(f'{flavor}/source/*.{suffix}'):
129
- shutil.copy(fname, self._path)
130
-
131
- shutil.copy(self.__src_root / 'common_source/build_binder.mk', self._path)
132
- shutil.copy(self.__src_root / 'common_source/ioutil.hh', self._path)
133
- shutil.copy(self.__src_root / 'common_source/binder_util.hh', self._path)
134
- self._solution.save(self._path / 'model.json')
135
- with open(self._path / 'misc.json', 'w') as f:
136
- f.write(f'{{"cost": {self._solution.cost}}}')
164
+ for path in self.__src_root.glob(f'{flavor}/source/*.{suffix}'):
165
+ shutil.copy(path, self._path / 'src/static')
166
+
167
+ shutil.copy(self.__src_root / 'common_source/build_binder.mk', self._path / 'sim')
168
+ shutil.copy(self.__src_root / 'common_source/ioutil.hh', self._path / 'sim')
169
+ shutil.copy(self.__src_root / 'common_source/binder_util.hh', self._path / 'sim')
170
+ self._solution.save(self._path / 'model/comb.json')
171
+ with open(self._path / 'metadata.json', 'w') as f:
172
+ misc = {'cost': self._solution.cost}
173
+ if self._pipe is not None:
174
+ misc['latency'] = len(self._pipe[0])
175
+ misc['reg_bits'] = self._pipe.reg_bits
176
+ f.write(json.dumps(misc))
137
177
 
138
178
  def _compile(self, verbose=False, openmp=True, nproc=None, o3: bool = False, clean=True):
139
179
  """Same as compile, but will not write to the library
@@ -150,7 +190,7 @@ class RTLModel:
150
190
  o3 : bool | None, optional
151
191
  Turn on -O3 flag, by default False
152
192
  clean : bool, optional
153
- Remove obsolete shared object files, by default True
193
+ Remove obsolete shared object files and `obj_dir`, by default True
154
194
 
155
195
  Raises
156
196
  ------
@@ -170,18 +210,21 @@ class RTLModel:
170
210
  if o3:
171
211
  args.append('fast')
172
212
 
173
- if clean is not False:
213
+ if clean:
174
214
  m = re.compile(r'^lib.*[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.so$')
175
- for p in self._path.iterdir():
215
+ for p in (self._path / 'sim').iterdir():
176
216
  if not p.is_dir() and m.match(p.name):
177
217
  p.unlink()
178
- if clean:
179
218
  subprocess.run(
180
- ['make', '-f', 'build_binder.mk', 'clean'], env=env, cwd=self._path, check=True, capture_output=not verbose
219
+ ['make', '-f', 'build_binder.mk', 'clean'],
220
+ env=env,
221
+ cwd=self._path / 'sim',
222
+ check=True,
223
+ capture_output=not verbose,
181
224
  )
182
225
 
183
226
  try:
184
- r = subprocess.run(args, env=env, check=True, cwd=self._path, capture_output=not verbose)
227
+ r = subprocess.run(args, env=env, check=True, cwd=self._path / 'sim', capture_output=not verbose)
185
228
  except subprocess.CalledProcessError as e:
186
229
  print(e.stderr.decode(), file=sys.stderr)
187
230
  print(e.stdout.decode(), file=sys.stdout)
@@ -191,18 +234,21 @@ class RTLModel:
191
234
  print(r.stdout.decode(), file=sys.stderr)
192
235
  raise RuntimeError('Compilation failed!!')
193
236
 
237
+ if clean:
238
+ subprocess.run(['rm', '-rf', 'obj_dir'], cwd=self._path / 'sim', check=True, capture_output=not verbose)
239
+
194
240
  self._load_lib(self._uuid)
195
241
 
196
242
  def _load_lib(self, uuid: str | None = None):
197
243
  uuid = uuid if uuid is not None else self._uuid
198
244
  if uuid is None:
199
245
  # load .so if there is only one, otherwise raise an error
200
- libs = list(self._path.glob(f'lib{self._prj_name}_wrapper_*.so'))
246
+ libs = list(self._path.glob(f'sim/lib{self._prj_name}_wrapper_*.so'))
201
247
  if len(libs) == 0:
202
248
  raise RuntimeError(f'Cannot load library, found {len(libs)} libraries in {self._path}')
203
249
  uuid = libs[0].name.split('_')[-1].split('.', 1)[0]
204
250
  self._uuid = uuid
205
- lib_path = self._path / f'lib{self._prj_name}_wrapper_{uuid}.so'
251
+ lib_path = self._path / f'sim/lib{self._prj_name}_wrapper_{uuid}.so'
206
252
  if not lib_path.exists():
207
253
  raise RuntimeError(f'Library {lib_path} does not exist')
208
254
  self._lib = ctypes.CDLL(str(lib_path))
@@ -222,7 +268,7 @@ class RTLModel:
222
268
  o3 : bool | None, optional
223
269
  Turn on -O3 flag, by default False
224
270
  clean : bool, optional
225
- Remove obsolete shared object files, by default True
271
+ Remove obsolete shared object files and `obj_dir`, by default True
226
272
 
227
273
  Raises
228
274
  ------
@@ -232,12 +278,12 @@ class RTLModel:
232
278
  self.write()
233
279
  self._compile(verbose=verbose, openmp=openmp, nproc=nproc, o3=o3, clean=clean)
234
280
 
235
- def predict(self, data: NDArray[np.floating]) -> NDArray[np.float32]:
281
+ def predict(self, data: NDArray[np.floating] | Sequence[NDArray[np.floating]]) -> NDArray[np.float32]:
236
282
  """Run the model on the input data.
237
283
 
238
284
  Parameters
239
285
  ----------
240
- data : NDArray[np.floating]
286
+ data : NDArray[np.floating]|Sequence[NDArray[np.floating]]
241
287
  Input data to the model. The shape is ignored, and the number of samples is
242
288
  determined by the size of the data.
243
289
 
@@ -247,6 +293,9 @@ class RTLModel:
247
293
  Output of the model in shape (n_samples, output_size).
248
294
  """
249
295
 
296
+ if isinstance(data, Sequence):
297
+ data = np.concatenate([a.reshape(a.shape[0], -1) for a in data], axis=-1)
298
+
250
299
  assert self._lib is not None, 'Library not loaded, call .compile() first.'
251
300
  inp_size, out_size = self._solution.shape
252
301
 
@@ -267,7 +316,9 @@ class RTLModel:
267
316
 
268
317
  inp_buf = inp_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
269
318
  out_buf = out_data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
270
- self._lib.inference(inp_buf, out_buf, n_sample)
319
+
320
+ with at_path(self._path / 'src/memfiles'):
321
+ self._lib.inference(inp_buf, out_buf, n_sample)
271
322
 
272
323
  # Unscale the output int32 to recover fp values
273
324
  k, i, f = np.max(k_out), np.max(i_out), np.max(f_out)
@@ -308,7 +359,7 @@ Estimated cost: {cost} LUTs"""
308
359
  class VerilogModel(RTLModel):
309
360
  def __init__(
310
361
  self,
311
- solution: Solution | CascadedSolution,
362
+ solution: CombLogic | Pipeline,
312
363
  prj_name: str,
313
364
  path: str | Path,
314
365
  latency_cutoff: float = -1,
@@ -337,7 +388,7 @@ class VerilogModel(RTLModel):
337
388
  class VHDLModel(RTLModel):
338
389
  def __init__(
339
390
  self,
340
- solution: Solution | CascadedSolution,
391
+ solution: CombLogic | Pipeline,
341
392
  prj_name: str,
342
393
  path: str | Path,
343
394
  latency_cutoff: float = -1,
@@ -1,9 +1,10 @@
1
- from .comb import comb_logic_gen
1
+ from .comb import comb_logic_gen, table_mem_gen
2
2
  from .io_wrapper import binder_gen, generate_io_wrapper
3
3
  from .pipeline import pipeline_logic_gen
4
4
 
5
5
  __all__ = [
6
6
  'comb_logic_gen',
7
+ 'table_mem_gen',
7
8
  'generate_io_wrapper',
8
9
  'pipeline_logic_gen',
9
10
  'binder_gen',
@@ -2,7 +2,7 @@ from math import ceil, log2
2
2
 
3
3
  import numpy as np
4
4
 
5
- from ....cmvm.types import Op, QInterval, Solution, _minimal_kif
5
+ from ....cmvm.types import CombLogic, Op, QInterval, _minimal_kif
6
6
 
7
7
 
8
8
  def make_neg(
@@ -23,16 +23,35 @@ def make_neg(
23
23
  return bw0, v0_name
24
24
 
25
25
 
26
- def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
26
+ def gen_mem_file(sol: CombLogic, op: Op) -> str:
27
+ assert op.opcode == 8
28
+ assert sol.lookup_tables is not None
29
+ table = sol.lookup_tables[op.data]
30
+ width = sum(table.spec.out_kif)
31
+ ndigits = ceil(width / 4)
32
+ data = table.padded_table(sol.ops[op.id0].qint)
33
+ mem_lines = [f'{hex(value)[2:].upper().zfill(ndigits)}' for value in data & ((1 << width) - 1)]
34
+ return '\n'.join(mem_lines)
35
+
36
+
37
+ def get_table_name(sol: CombLogic, op: Op) -> str:
38
+ assert sol.lookup_tables is not None
39
+ assert op.opcode == 8
40
+ qint_in = sol.ops[op.id0].qint
41
+ uuid = sol.lookup_tables[op.data].get_uuid(qint_in)
42
+ return f'lut_{uuid}.mem'
43
+
44
+
45
+ def ssa_gen(sol: CombLogic, neg_defined: set[int], print_latency: bool = False) -> list[str]:
27
46
  ops = sol.ops
28
47
  kifs = list(map(_minimal_kif, (op.qint for op in ops)))
29
- widths = list(map(sum, kifs))
48
+ widths: list[int] = list(map(sum, kifs))
30
49
  inp_kifs = [_minimal_kif(qint) for qint in sol.inp_qint]
31
50
  inp_widths = list(map(sum, inp_kifs))
32
51
  _inp_widths = np.cumsum([0] + inp_widths)
33
52
  inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
34
53
 
35
- lines = []
54
+ lines: list[str] = []
36
55
  ref_count = sol.ref_count
37
56
 
38
57
  for i, op in enumerate(ops):
@@ -84,6 +103,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
84
103
  if op.opcode == -3 and op.id0 not in neg_defined:
85
104
  neg_defined.add(op.id0)
86
105
  bw0, v0_name = make_neg(lines, op, ops, bw0, v0_name)
106
+
87
107
  line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
88
108
 
89
109
  case 4: # constant addition
@@ -93,9 +113,11 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
93
113
  bw0 = widths[op.id0]
94
114
  s0 = int(kifs[op.id0][0])
95
115
  v0 = f'v{op.id0}[{bw0 - 1}:0]'
96
- v1 = f"'{bin(mag)[1:]}"
116
+ v1 = f"{bw1}'{bin(mag)[1:]}"
97
117
  shift = kifs[op.id0][2] - kifs[i][2]
118
+
98
119
  line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, 0, {bw}, {shift}, {sign}) op_{i} ({v0}, {v1}, {v});'
120
+
99
121
  case 5: # constant
100
122
  num = op.data
101
123
  if num < 0:
@@ -112,8 +134,13 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
112
134
  _shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
113
135
  shift = f0 - f1 + _shift
114
136
  vk, v0, v1 = f'v{k}[{bwk - 1}]', f'v{a}[{bw0 - 1}:0]', f'v{b}[{bw1 - 1}:0]'
137
+ if bw0 == 0:
138
+ v0, bw0 = "1'b0", 1
139
+ if bw1 == 0:
140
+ v1, bw1 = "1'b0", 1
115
141
 
116
142
  line = f'{_def} mux #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {inv}) op_{i} ({vk}, {v0}, {v1}, {v});'
143
+
117
144
  case 7: # Multiplication
118
145
  bw0, bw1 = widths[op.id0], widths[op.id1] # width
119
146
  s0, s1 = int(kifs[op.id0][0]), int(kifs[op.id1][0])
@@ -121,6 +148,12 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
121
148
 
122
149
  line = f'{_def} multiplier #({bw0}, {bw1}, {s0}, {s1}, {bw}) op_{i} ({v0}, {v1}, {v});'
123
150
 
151
+ case 8: # Lookup Table
152
+ name = get_table_name(sol, op)
153
+ bw0 = widths[op.id0]
154
+
155
+ line = f'{_def} lookup_table #({bw0}, {bw}, "{name}") op_{i} (v{op.id0}, {v});'
156
+
124
157
  case _:
125
158
  raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
126
159
 
@@ -130,7 +163,7 @@ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
130
163
  return lines
131
164
 
132
165
 
133
- def output_gen(sol: Solution, neg_defined: set[int]):
166
+ def output_gen(sol: CombLogic, neg_defined: set[int]):
134
167
  lines = []
135
168
  widths = list(map(sum, map(_minimal_kif, sol.out_qint)))
136
169
  _widths = np.cumsum([0] + widths)
@@ -157,7 +190,7 @@ def output_gen(sol: Solution, neg_defined: set[int]):
157
190
  return lines
158
191
 
159
192
 
160
- def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, timescale: str | None = None):
193
+ def comb_logic_gen(sol: CombLogic, fn_name: str, print_latency: bool = False, timescale: str | None = None):
161
194
  inp_bits = sum(map(sum, map(_minimal_kif, sol.inp_qint)))
162
195
  out_bits = sum(map(sum, map(_minimal_kif, sol.out_qint)))
163
196
 
@@ -191,3 +224,10 @@ def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, tim
191
224
  if timescale is not None:
192
225
  code = f'{timescale}\n\n{code}'
193
226
  return code
227
+
228
+
229
+ def table_mem_gen(sol: CombLogic) -> dict[str, str]:
230
+ if not sol.lookup_tables:
231
+ return {}
232
+ mem_files = {get_table_name(sol, op): gen_mem_file(sol, op) for op in sol.ops if op.opcode == 8}
233
+ return mem_files
@@ -1,6 +1,6 @@
1
1
  from itertools import accumulate
2
2
 
3
- from ....cmvm.types import CascadedSolution, QInterval, Solution, _minimal_kif
3
+ from ....cmvm.types import CombLogic, Pipeline, QInterval, _minimal_kif
4
4
 
5
5
 
6
6
  def hetero_io_map(qints: list[QInterval], merge: bool = False):
@@ -59,7 +59,7 @@ def hetero_io_map(qints: list[QInterval], merge: bool = False):
59
59
  return regular, hetero, pads, (width_regular, width_packed)
60
60
 
61
61
 
62
- def generate_io_wrapper(sol: Solution | CascadedSolution, module_name: str, pipelined: bool = False):
62
+ def generate_io_wrapper(sol: CombLogic | Pipeline, module_name: str, pipelined: bool = False):
63
63
  reg_in, het_in, _, shape_in = hetero_io_map(sol.inp_qint, merge=True)
64
64
  reg_out, het_out, pad_out, shape_out = hetero_io_map(sol.out_qint, merge=True)
65
65
 
@@ -113,12 +113,12 @@ endmodule
113
113
  """
114
114
 
115
115
 
116
- def binder_gen(csol: CascadedSolution | Solution, module_name: str, II: int = 1, latency_multiplier: int = 1):
116
+ def binder_gen(csol: Pipeline | CombLogic, module_name: str, II: int = 1, latency_multiplier: int = 1):
117
117
  k_in, i_in, f_in = zip(*map(_minimal_kif, csol.inp_qint))
118
118
  k_out, i_out, f_out = zip(*map(_minimal_kif, csol.out_qint))
119
119
  max_inp_bw = max(k_in) + max(i_in) + max(f_in)
120
120
  max_out_bw = max(k_out) + max(i_out) + max(f_out)
121
- if isinstance(csol, Solution):
121
+ if isinstance(csol, CombLogic):
122
122
  II = latency = 0
123
123
  else:
124
124
  latency = len(csol.solutions) * latency_multiplier
@@ -1,9 +1,9 @@
1
- from ....cmvm.types import CascadedSolution, _minimal_kif
1
+ from ....cmvm.types import Pipeline, _minimal_kif
2
2
  from .comb import comb_logic_gen
3
3
 
4
4
 
5
5
  def pipeline_logic_gen(
6
- csol: CascadedSolution,
6
+ csol: Pipeline,
7
7
  name: str,
8
8
  print_latency=False,
9
9
  timescale: str | None = '`timescale 1 ns / 1 ps',
@@ -13,36 +13,36 @@ def pipeline_logic_gen(
13
13
  inp_bits = [sum(map(sum, map(_minimal_kif, sol.inp_qint))) for sol in csol.solutions]
14
14
  out_bits = inp_bits[1:] + [sum(map(sum, map(_minimal_kif, csol.out_qint)))]
15
15
 
16
- registers = [f'reg [{width-1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
16
+ registers = [f'reg [{width - 1}:0] stage{i}_inp;' for i, width in enumerate(inp_bits)]
17
17
  for i in range(0, register_layers - 1):
18
- registers += [f'reg [{width-1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
19
- wires = [f'wire [{width-1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
18
+ registers += [f'reg [{width - 1}:0] stage{j}_inp_copy{i};' for j, width in enumerate(inp_bits)]
19
+ wires = [f'wire [{width - 1}:0] stage{i}_out;' for i, width in enumerate(out_bits)]
20
20
 
21
21
  comb_logic = [f'{name}_stage{i} stage{i} (.model_inp(stage{i}_inp), .model_out(stage{i}_out));' for i in range(N)]
22
22
 
23
23
  if register_layers == 1:
24
24
  serial_logic = ['stage0_inp <= model_inp;']
25
- serial_logic += [f'stage{i}_inp <= stage{i-1}_out;' for i in range(1, N)]
25
+ serial_logic += [f'stage{i}_inp <= stage{i - 1}_out;' for i in range(1, N)]
26
26
  else:
27
27
  serial_logic = ['stage0_inp_copy0 <= model_inp;']
28
28
  for j in range(1, register_layers - 1):
29
- serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j-1};')
29
+ serial_logic.append(f'stage0_inp_copy{j} <= stage0_inp_copy{j - 1};')
30
30
  serial_logic.append(f'stage0_inp <= stage0_inp_copy{register_layers - 2};')
31
31
  for i in range(1, N):
32
- serial_logic.append(f'stage{i}_inp_copy0 <= stage{i-1}_out;')
32
+ serial_logic.append(f'stage{i}_inp_copy0 <= stage{i - 1}_out;')
33
33
  for j in range(1, register_layers - 1):
34
- serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j-1};')
34
+ serial_logic.append(f'stage{i}_inp_copy{j} <= stage{i}_inp_copy{j - 1};')
35
35
  serial_logic.append(f'stage{i}_inp <= stage{i}_inp_copy{register_layers - 2};')
36
36
 
37
- serial_logic += [f'model_out <= stage{N-1}_out;']
37
+ serial_logic += [f'model_out <= stage{N - 1}_out;']
38
38
 
39
39
  sep0 = '\n '
40
40
  sep1 = '\n '
41
41
 
42
42
  module = f"""module {name} (
43
43
  input clk,
44
- input [{inp_bits[0]-1}:0] model_inp,
45
- output reg [{out_bits[-1]-1}:0] model_out
44
+ input [{inp_bits[0] - 1}:0] model_inp,
45
+ output reg [{out_bits[-1] - 1}:0] model_out
46
46
  );
47
47
 
48
48
  {sep0.join(registers)}
@@ -0,0 +1,27 @@
1
+ `timescale 1ns / 1ps
2
+
3
+
4
+ module lookup_table #(
5
+ parameter BW_IN = 8,
6
+ parameter BW_OUT = 8,
7
+ parameter MEM_FILE = "whatever.mem"
8
+ ) (
9
+ input [BW_IN-1:0] in,
10
+ output [BW_OUT-1:0] out
11
+ );
12
+
13
+ (*rom_style = "distributed" *)
14
+ reg [BW_OUT-1:0] lut_rom [0:(1<<BW_IN)-1];
15
+ reg [BW_OUT-1:0] readout;
16
+
17
+ initial begin
18
+ $readmemh(MEM_FILE, lut_rom);
19
+ end
20
+
21
+ assign out[BW_OUT-1:0] = readout[BW_OUT-1:0];
22
+
23
+ always @(*) begin
24
+ readout = lut_rom[in];
25
+ end
26
+
27
+ endmodule