da4ml 0.4.1__tar.gz → 0.5.0b0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of da4ml might be problematic (see the release page for details).

Files changed (112)
  1. {da4ml-0.4.1/src/da4ml.egg-info → da4ml-0.5.0b0}/PKG-INFO +2 -1
  2. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/getting_started.md +5 -4
  3. {da4ml-0.4.1 → da4ml-0.5.0b0}/pyproject.toml +8 -2
  4. da4ml-0.5.0b0/src/da4ml/__init__.py +3 -0
  5. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/_version.py +3 -3
  6. da4ml-0.5.0b0/src/da4ml/cmvm/__init__.py +4 -0
  7. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/cmvm/api.py +15 -4
  8. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/cmvm/core/__init__.py +2 -2
  9. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/cmvm/types.py +32 -18
  10. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/cmvm/util/bit_decompose.py +2 -2
  11. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/hls_codegen.py +10 -5
  12. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/hls_model.py +7 -4
  13. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
  14. da4ml-0.5.0b0/src/da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  15. da4ml-0.4.1/src/da4ml/codegen/rtl/common_source/build_prj.tcl → da4ml-0.5.0b0/src/da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +39 -18
  16. da4ml-0.5.0b0/src/da4ml/codegen/rtl/common_source/template.sdc +27 -0
  17. da4ml-0.5.0b0/src/da4ml/codegen/rtl/common_source/template.xdc +30 -0
  18. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/rtl_model.py +105 -54
  19. {da4ml-0.4.1/src/da4ml/codegen/rtl/vhdl → da4ml-0.5.0b0/src/da4ml/codegen/rtl/verilog}/__init__.py +2 -1
  20. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/verilog/comb.py +47 -7
  21. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
  22. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/verilog/pipeline.py +12 -12
  23. da4ml-0.5.0b0/src/da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  24. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/vhdl/comb.py +27 -21
  25. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
  26. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
  27. da4ml-0.5.0b0/src/da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  28. da4ml-0.5.0b0/src/da4ml/converter/__init__.py +59 -0
  29. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/converter/hgq2/parser.py +4 -25
  30. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/converter/hgq2/replica.py +208 -22
  31. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/trace/fixed_variable.py +239 -29
  32. da4ml-0.5.0b0/src/da4ml/trace/fixed_variable_array.py +572 -0
  33. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/trace/ops/__init__.py +31 -15
  34. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/trace/ops/reduce_utils.py +3 -3
  35. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/trace/pipeline.py +40 -18
  36. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/trace/tracer.py +33 -8
  37. da4ml-0.5.0b0/src/da4ml/typing/__init__.py +3 -0
  38. {da4ml-0.4.1 → da4ml-0.5.0b0/src/da4ml.egg-info}/PKG-INFO +2 -1
  39. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml.egg-info/SOURCES.txt +7 -3
  40. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml.egg-info/requires.txt +1 -0
  41. da4ml-0.4.1/src/da4ml/__init__.py +0 -17
  42. da4ml-0.4.1/src/da4ml/cmvm/__init__.py +0 -4
  43. da4ml-0.4.1/src/da4ml/codegen/rtl/common_source/template.xdc +0 -32
  44. da4ml-0.4.1/src/da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
  45. da4ml-0.4.1/src/da4ml/converter/__init__.py +0 -3
  46. da4ml-0.4.1/src/da4ml/trace/fixed_variable_array.py +0 -344
  47. {da4ml-0.4.1 → da4ml-0.5.0b0}/.clang-format +0 -0
  48. {da4ml-0.4.1 → da4ml-0.5.0b0}/.github/workflows/python-publish.yml +0 -0
  49. {da4ml-0.4.1 → da4ml-0.5.0b0}/.github/workflows/sphinx-build.yml +0 -0
  50. {da4ml-0.4.1 → da4ml-0.5.0b0}/.gitignore +0 -0
  51. {da4ml-0.4.1 → da4ml-0.5.0b0}/.pre-commit-config.yaml +0 -0
  52. {da4ml-0.4.1 → da4ml-0.5.0b0}/LICENSE +0 -0
  53. {da4ml-0.4.1 → da4ml-0.5.0b0}/README.md +0 -0
  54. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/Makefile +0 -0
  55. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/_static/example.svg +0 -0
  56. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/_static/icon.svg +0 -0
  57. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/_static/stage1.svg +0 -0
  58. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/_static/stage2.svg +0 -0
  59. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/_static/workflow.svg +0 -0
  60. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/cmvm.md +0 -0
  61. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/conf.py +0 -0
  62. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/dais.md +0 -0
  63. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/faq.md +0 -0
  64. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/index.rst +0 -0
  65. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/install.md +0 -0
  66. {da4ml-0.4.1 → da4ml-0.5.0b0}/docs/status.md +0 -0
  67. {da4ml-0.4.1 → da4ml-0.5.0b0}/interperter/DAISInterpreter.cc +0 -0
  68. {da4ml-0.4.1 → da4ml-0.5.0b0}/interperter/DAISInterpreter.hh +0 -0
  69. {da4ml-0.4.1 → da4ml-0.5.0b0}/setup.cfg +0 -0
  70. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/cmvm/core/indexers.py +0 -0
  71. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/cmvm/core/state_opr.py +0 -0
  72. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/cmvm/util/__init__.py +0 -0
  73. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/cmvm/util/mat_decompose.py +0 -0
  74. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/__init__.py +0 -0
  75. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/__init__.py +0 -0
  76. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_binary.h +0 -0
  77. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_common.h +0 -0
  78. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_decl.h +0 -0
  79. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_fixed.h +0 -0
  80. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +0 -0
  81. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +0 -0
  82. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +0 -0
  83. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_int.h +0 -0
  84. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_int_base.h +0 -0
  85. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_int_ref.h +0 -0
  86. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_int_special.h +0 -0
  87. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +0 -0
  88. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/etc/ap_private.h +0 -0
  89. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/hls_math.h +0 -0
  90. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/hls_stream.h +0 -0
  91. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +0 -0
  92. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/binder_util.hh +0 -0
  93. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/build_binder.mk +0 -0
  94. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/hls/source/vitis_bitshift.hh +0 -0
  95. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/__init__.py +0 -0
  96. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/common_source/binder_util.hh +0 -0
  97. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/common_source/ioutil.hh +0 -0
  98. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/verilog/source/multiplier.v +0 -0
  99. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/verilog/source/mux.v +0 -0
  100. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/verilog/source/negative.v +0 -0
  101. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/verilog/source/shift_adder.v +0 -0
  102. {da4ml-0.4.1/src/da4ml/codegen/rtl/verilog → da4ml-0.5.0b0/src/da4ml/codegen/rtl/vhdl}/__init__.py +0 -0
  103. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/vhdl/source/multiplier.vhd +0 -0
  104. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/vhdl/source/mux.vhd +0 -0
  105. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/vhdl/source/negative.vhd +0 -0
  106. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +0 -0
  107. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/converter/hgq2/__init__.py +0 -0
  108. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/trace/__init__.py +0 -0
  109. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/trace/ops/conv_utils.py +0 -0
  110. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml/trace/ops/einsum_utils.py +0 -0
  111. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml.egg-info/dependency_links.txt +0 -0
  112. {da4ml-0.4.1 → da4ml-0.5.0b0}/src/da4ml.egg-info/top_level.txt +0 -0

PKG-INFO:
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: da4ml
- Version: 0.4.1
+ Version: 0.5.0b0
  Summary: Distributed Arithmetic for Machine Learning
  Author-email: Chang Sun <chsun@cern.ch>
  License: GNU Lesser General Public License v3 (LGPLv3)
@@ -19,6 +19,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: llvmlite>=0.44
  Requires-Dist: numba>=0.61
+ Requires-Dist: quantizers<2,>=1
  Provides-Extra: docs
  Requires-Dist: hgq2; extra == "docs"
  Requires-Dist: myst-parser; extra == "docs"

docs/getting_started.md:
@@ -34,11 +34,12 @@ inp = FixedVariableArrayInput((4, 5))
  out = operation(inp)

  # Generate pipelined Verilog code form the traced operation
+ # flavor can be 'verilog' or 'vhdl'. VHDL code generated will be in 2008 standard.
  comb_logic = comb_trace(inp, out)
- verilog_model = VerilogModel(comb_logic, 'vmodel', '/tmp/verilog', latency_cutoff=5) # can also be HLSModel
- verilog_model.write()
- # verilog_model.compile() # compile the generated Verilog code with verilator
- # verilog_model.predict(data_inp) # run inference with the compiled model; bit-accurate
+ rtl_model = RTLModel(comb_logic, 'vmodel', '/tmp/rtl', flavor='verilog', latency_cutoff=5) # can also be HLSModel
+ rtl_model.write()
+ # rtl_model.compile() # compile the generated Verilog code with verilator (with GHDL, if using vhdl)
+ # rtl_model.predict(data_inp) # run inference with the compiled model; bit-accurate
  ```

  ## HGQ2/Keras3 integration:

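For orientation, a minimal end-to-end sketch of the 0.5.0 flow implied by the hunk above. The import paths and the numpy-style operator on the traced array are assumptions, not shown in this diff.

```python
# Sketch only: import locations are guesses based on the package layout in the file list.
import numpy as np
from da4ml.trace import FixedVariableArrayInput, comb_trace   # assumed re-exports
from da4ml.codegen import RTLModel                             # assumed re-export of rtl/rtl_model.py

inp = FixedVariableArrayInput((4, 5))      # symbolic fixed-point input of shape (4, 5)
out = inp + inp                            # stand-in for a real traced computation (assumes numpy-style operators)
comb_logic = comb_trace(inp, out)          # capture the combinational logic between inp and out

rtl_model = RTLModel(comb_logic, 'vmodel', '/tmp/rtl', flavor='verilog', latency_cutoff=5)
rtl_model.write()                          # emit Verilog (or VHDL-2008 with flavor='vhdl')
# rtl_model.compile()                      # verilator; GHDL first when flavor='vhdl'
# rtl_model.predict(np.random.rand(16, 20).astype(np.float32))  # bit-accurate emulation
```
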
pyproject.toml:
@@ -29,8 +29,14 @@ classifiers = [
  "Programming Language :: Python :: 3.13",
  ]
  dynamic = [ "version" ]
- dependencies = [ "llvmlite>=0.44", "numba>=0.61" ]
- optional-dependencies.docs = [ "hgq2", "myst-parser", "pyparsing", "sphinx", "sphinx-rtd-theme" ]
+ dependencies = [ "llvmlite>=0.44", "numba>=0.61", "quantizers>=1,<2" ]
+ optional-dependencies.docs = [
+ "hgq2",
+ "myst-parser",
+ "pyparsing",
+ "sphinx",
+ "sphinx-rtd-theme",
+ ]
  urls.repository = "https://github.com/calad0i/da4ml"

  [tool.setuptools]

src/da4ml/__init__.py (new file):
@@ -0,0 +1,3 @@
+ from . import cmvm, codegen, converter, trace, typing
+
+ __all__ = ['cmvm', 'codegen', 'converter', 'trace', 'typing']

src/da4ml/_version.py:
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.4.1'
- __version_tuple__ = version_tuple = (0, 4, 1)
+ __version__ = version = '0.5.0b0'
+ __version_tuple__ = version_tuple = (0, 5, 0, 'b0')

- __commit_id__ = commit_id = 'g359d5d425'
+ __commit_id__ = commit_id = 'gcdf2e46cb'

src/da4ml/cmvm/__init__.py (new file):
@@ -0,0 +1,4 @@
+ from .api import minimal_latency, solve
+ from .types import CombLogic, Op, QInterval
+
+ __all__ = ['minimal_latency', 'solve', 'QInterval', 'Op', 'CombLogic']

src/da4ml/cmvm/api.py:
@@ -1,10 +1,11 @@
  from math import ceil, log2
+ from typing import TypedDict

  import numpy as np
  from numba import jit, prange

  from .core import _solve, create_state, to_solution
- from .types import CascadedSolution, QInterval
+ from .types import Pipeline, QInterval
  from .util import kernel_decompose


@@ -56,7 +57,7 @@ def jit_solve(
  latencies: list[float] | None = None,
  adder_size: int = -1,
  carry_size: int = -1,
- ) -> CascadedSolution:
+ ) -> Pipeline:
  """Optimized implementation of a CMVM computation with cascaded two matrices.

  Parameters
@@ -144,7 +145,7 @@ def jit_solve(
  if max(latencies1) > latency_allowed:
  # When latency depends on the bw, may happen
  print(f'Latency constraint not satisfied: {int(latency_allowed)} < {int(max(latencies1))}')
- return CascadedSolution((sol0, sol1))
+ return Pipeline((sol0, sol1))


  @jit(cache=True, parallel=True)
@@ -159,7 +160,7 @@ def solve(
  adder_size: int = -1,
  carry_size: int = -1,
  search_all_decompose_dc: bool = True,
- ) -> CascadedSolution:
+ ) -> Pipeline:
  """Solve the CMVM problem with cascaded two matrices.

  Parameters
@@ -251,3 +252,13 @@ def solve(
  carry_size=carry_size,
  )
  return csol
+
+
+ class solver_options_t(TypedDict, total=False):
+ method0: str
+ method1: str
+ hard_dc: int
+ decompose_dc: int
+ adder_size: int
+ carry_size: int
+ search_all_decompose_dc: bool

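The new solver_options_t TypedDict mirrors the tunable keyword arguments of solve; a hedged sketch of building such an options dict follows (forwarding it with ** is an assumed usage pattern, not shown in this diff).

```python
# Sketch only: solver_options_t is total=False, so every key is optional.
import numpy as np
from da4ml.cmvm import solve                  # exported by the new cmvm/__init__.py above
from da4ml.cmvm.api import solver_options_t

options: solver_options_t = {
    'adder_size': -1,                         # -1 appears to mean "use the default", per solve()'s signature
    'carry_size': -1,
    'search_all_decompose_dc': True,
}
kernel = np.random.randint(-8, 8, size=(16, 8)).astype(np.float32)
# Assumed call pattern: pipeline = solve(kernel, **options)   # returns a Pipeline of CombLogic stages
```
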
src/da4ml/cmvm/core/__init__.py:
@@ -4,7 +4,7 @@ from math import log2
  import numpy as np
  from numba import jit

- from ..types import DAState, Op, QInterval, Solution
+ from ..types import CombLogic, DAState, Op, QInterval
  from .indexers import (
  idx_mc,
  idx_mc_dc,
@@ -194,7 +194,7 @@ def to_solution(
  out_neg.append(sub)
  out_shift[i_out] = out_shift[i_out] + shift0

- return Solution(
+ return CombLogic(
  shape=state.kernel.shape, # type: ignore
  inp_shift=list(in_shift),
  out_idxs=out_idx,

src/da4ml/cmvm/types.py:
@@ -11,7 +11,7 @@ from numpy import float32, int8
  from numpy.typing import NDArray

  if TYPE_CHECKING:
- from ..trace.tracer import FixedVariable
+ from ..trace.fixed_variable import FixedVariable, LookupTable


  class QInterval(NamedTuple):
@@ -228,9 +228,15 @@ def _(v: Decimal, k: int | bool, i: int, f: int, round_mode: str = 'TRN'):
  return eps * ((floor(v / eps) + bias) % Decimal(2) ** b - bias)


- class Solution(NamedTuple):
- """Represents a series of operations that can be applied to a vector of data.
- May represent a CMVM solution or a general neural network
+ class JSONEncoder(json.JSONEncoder):
+ def default(self, o):
+ if hasattr(o, 'to_dict'):
+ return o.to_dict()
+ super().default(o)
+
+
+ class CombLogic(NamedTuple):
+ """A combinational logic that describes a series of operations on input data to produce output data.

  Attributes
  ----------
@@ -247,12 +253,14 @@ class Solution(NamedTuple):
  ops: list[Op]
  Core list of operations for generating each buffer element.
  carry_size: int
- Size of the carrier for the adder.
+ Size of the carrier for the adder, used for cost and latency estimation.
  adder_size: int
- Elementary size of the adder.
+ Elementary size of the adder, used for cost and latency estimation.
+ lookup_tables: tuple[LookupTable, ...] | None
+ Lookup table arrays for lookup operations, if any.


- The core part of the solution is the operations in the ops list.
+ The core part of the comb logic is the operations in the ops list.
  For the exact operations executed with Op, refer to the Op class.
  After all operations are executed, the output data is read from data[op.out_idx] and multiplied by 2**out_shift.

@@ -266,6 +274,7 @@ class Solution(NamedTuple):
  ops: list[Op]
  carry_size: int
  adder_size: int
+ lookup_tables: 'tuple[LookupTable, ...] | None' = None

  def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False, dump=False):
  """Executes the solution on the input data.
@@ -343,6 +352,12 @@ class Solution(NamedTuple):
  case 7:
  v0, v1 = buf[op.id0], buf[op.id1]
  buf[i] = v0 * v1
+ case 8:
+ v0 = buf[op.id0]
+ tables = self.lookup_tables
+ assert tables is not None, 'No lookup table provided for lookup operation'
+ table = tables[op.data]
+ buf[i] = table.lookup(v0, self.ops[op.id0].qint)
  case _:
  raise ValueError(f'Unknown opcode {op.opcode} in {op}')

@@ -375,6 +390,8 @@ class Solution(NamedTuple):
  op_str = f'msb(buf[{op.data}]) ? buf[{op.id0}] : {_sign}buf[{op.id1}]'
  case 7:
  op_str = f'buf[{op.id0}] * buf[{op.id1}]'
+ case 8:
+ op_str = f'tables[{int(op.data)}].lookup(buf[{op.id0}])'
  case _:
  raise ValueError(f'Unknown opcode {op.opcode} in {op}')

@@ -451,7 +468,7 @@ class Solution(NamedTuple):
  def save(self, path: str | Path):
  """Save the solution to a file."""
  with open(path, 'w') as f:
- json.dump(self, f)
+ json.dump(self, f, cls=JSONEncoder)

  @classmethod
  def deserialize(cls, data: dict):
@@ -534,12 +551,8 @@ class Solution(NamedTuple):
  data.tofile(f)


- class CascadedSolution(NamedTuple):
- """A solution that implements cascaded matrix-vector multiplications through multiple CMVM stages.
-
- CascadedSolution represents a sequence of Solution objects where the output of each stage
- is fed as input to the next stage.
-
+ class Pipeline(NamedTuple):
+ """A pipeline with II=1,with each stage represented by a CombLogic
  Attributes
  ----------
  solutions: tuple[Solution, ...]
@@ -548,12 +561,13 @@ class CascadedSolution(NamedTuple):
  Properties
  ----------
  kernel: NDArray[float32]
+ Only useful when the pipeline describes a linear operation.
  The overall kernel matrix which the cascaded solution implements: vec @ kernel = solution(vec).
  This is calculated as the matrix product of all individual solution kernels.
  cost: float
  The total cost of the cascaded solution, computed as the sum of the costs of all stages.
  latency: tuple[float, float]
- The minimum and maximum latency of the cascaded solution.
+ The minimum and maximum latency of the pipeline, determined by the last stage.
  inp_qint: list[QInterval]
  Input quantization intervals
  inp_lat: list[float]
@@ -572,7 +586,7 @@ class CascadedSolution(NamedTuple):
  The shape of the corresponding kernel matrix.
  """

- solutions: tuple[Solution, ...]
+ solutions: tuple[CombLogic, ...]

  def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False):
  out = np.asarray(inp)
@@ -634,12 +648,12 @@ class CascadedSolution(NamedTuple):
  def save(self, path: str | Path):
  """Save the solution to a file."""
  with open(path, 'w') as f:
- json.dump(self, f)
+ json.dump(self, f, cls=JSONEncoder)

  @classmethod
  def deserialize(cls, data: dict):
  """Load the solution from a file."""
- return cls(solutions=tuple(Solution.deserialize(sol) for sol in data[0]))
+ return cls(solutions=tuple(CombLogic.deserialize(sol) for sol in data[0]))

  @classmethod
  def load(cls, path: str):

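With the rename to CombLogic/Pipeline and the new JSONEncoder, the save/load round trip shown above stays intact; a hedged sketch follows, with the Pipeline instance assumed to come from elsewhere (e.g. solve).

```python
# Sketch only: `pipeline` stands for a Pipeline produced elsewhere (e.g. by da4ml.cmvm.solve);
# the __call__/save/load signatures below are the ones visible in this diff.
import numpy as np
from da4ml.cmvm.types import Pipeline

def save_and_reload(pipeline: Pipeline, x: np.ndarray) -> np.ndarray:
    y = pipeline(x)                          # run every CombLogic stage in sequence
    pipeline.save('pipeline.json')           # json.dump(..., cls=JSONEncoder), using to_dict() where available
    restored = Pipeline.load('pipeline.json')
    assert np.allclose(restored(x), y)       # the reloaded pipeline should reproduce the same outputs
    return y
```
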
src/da4ml/cmvm/util/bit_decompose.py:
@@ -15,7 +15,7 @@ def _volatile_int_arr_to_csd(x: NDArray) -> NDArray[np.int8]:
  thres = _2pn / 1.5
  bit = (x > thres).astype(np.int8)
  bit -= (x < -thres).astype(np.int8)
- x -= _2pn * bit
+ x -= _2pn * bit.astype(x.dtype)
  buf[..., n] = bit
  return buf

@@ -50,7 +50,7 @@ def _center(arr: NDArray):
  arr = arr * (2.0**-shift1)
  shift0 = shift_centering(arr, 0) # d_in
  arr = arr * (2.0 ** -shift0[:, None])
- return arr, shift0, shift1
+ return arr, shift0.astype(np.int8), shift1.astype(np.int8)


  @jit

src/da4ml/codegen/hls/hls_codegen.py:
@@ -1,6 +1,6 @@
  from collections.abc import Callable

- from ...cmvm.types import QInterval, Solution, _minimal_kif
+ from ...cmvm.types import CombLogic, QInterval, _minimal_kif
  from ...trace.fixed_variable import _const_f


@@ -34,7 +34,7 @@ def get_typestr_fn(flavor: str):
  return typestr_fn


- def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
+ def ssa_gen(sol: CombLogic, print_latency: bool, typestr_fn: Callable[[bool | int, int, int], str]):
  ops = sol.ops
  all_kifs = list(map(_minimal_kif, (op.qint for op in ops)))
  all_types = list(map(lambda x: typestr_fn(*x), all_kifs))
@@ -92,6 +92,11 @@ def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int
  ref_k = f'v{id_c}[{bw_k - 1}]'
  sign = '-' if op.opcode == -6 else ''
  ref1 = f'v{op.id1}' if shift == 0 else f'bit_shift<{shift}>(v{op.id1})'
+ bw0, bw1 = sum(all_kifs[op.id0]), sum(all_kifs[op.id1])
+ if bw0 == 0:
+ ref0 = '0'
+ if bw1 == 0:
+ ref1 = '0'
  val = f'{ref_k} ? {_type}({ref0}) : {_type}({sign}{ref1})'
  case 7:
  # Multiplication
@@ -108,7 +113,7 @@ def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int
  return lines


- def output_gen(sol: Solution, typestr_fn: Callable[[bool | int, int, int], str]):
+ def output_gen(sol: CombLogic, typestr_fn: Callable[[bool | int, int, int], str]):
  lines = []
  for i, idx in enumerate(sol.out_idxs):
  if idx < 0:
@@ -124,7 +129,7 @@ def output_gen(sol: Solution, typestr_fn: Callable[[bool | int, int, int], str])
  return lines


- def get_io_types(sol: Solution, flavor: str):
+ def get_io_types(sol: CombLogic, flavor: str):
  typestr_fn = get_typestr_fn(flavor)
  in_kif = map(max, zip(*map(_minimal_kif, sol.inp_qint)))
  inp_type = typestr_fn(*in_kif)
@@ -134,7 +139,7 @@ def get_io_types(sol: Solution, flavor: str):


  def hls_logic_and_bridge_gen(
- sol: Solution,
+ sol: CombLogic,
  fn_name: str,
  flavor: str,
  pragmas: list[str] | None = None,

src/da4ml/codegen/hls/hls_model.py:
@@ -12,7 +12,7 @@ from uuid import uuid4
  import numpy as np
  from numpy.typing import NDArray

- from da4ml.cmvm.types import Solution
+ from da4ml.cmvm.types import CombLogic
  from da4ml.codegen.hls.hls_codegen import get_io_types, hls_logic_and_bridge_gen

  from ... import codegen
@@ -24,7 +24,7 @@ T = TypeVar('T', bound=np.floating)
  class HLSModel:
  def __init__(
  self,
- solution: Solution,
+ solution: CombLogic,
  prj_name: str,
  path: str | Path,
  flavor: str = 'vitis',
@@ -192,12 +192,12 @@ class HLSModel:
  self.write()
  self._compile(verbose, openmp, o3, clean)

- def predict(self, data: NDArray[T]) -> NDArray[T]:
+ def predict(self, data: NDArray[T] | Sequence[NDArray[T]]) -> NDArray[T]:
  """Run the model on the input data.

  Parameters
  ----------
- data : NDArray[np.floating]
+ data: NDArray[np.floating] | Sequence[NDArray[np.floating]]
  Input data to the model. The shape is ignored, and the number of samples is
  determined by the size of the data.

@@ -209,6 +209,9 @@ class HLSModel:
  assert self._lib is not None, 'Library not loaded, call .compile() first.'
  inp_size, out_size = self._solution.shape

+ if isinstance(data, Sequence):
+ data = np.concatenate([a.reshape(a.shape[0], -1) for a in data], axis=-1)
+
  dtype = data.dtype
  if dtype not in (np.float32, np.float64):
  raise TypeError(f'Unsupported input data type: {dtype}. Expected float32 or float64.')

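The predict() change above means multi-input models can be fed a sequence of arrays, which is flattened per sample and concatenated along the feature axis before dispatch to the compiled library. A sketch of the new call pattern; the model construction is elided and the shapes are illustrative.

```python
# Sketch only: `model` is a compiled HLSModel (or RTLModel) for a two-input network.
import numpy as np

x0 = np.random.rand(1000, 4, 5).astype(np.float32)   # first input, any per-sample shape
x1 = np.random.rand(1000, 7).astype(np.float32)      # second input

# Per the isinstance(data, Sequence) branch added above, this is equivalent to
# model.predict(np.concatenate([x0.reshape(1000, -1), x1], axis=-1)).
y = model.predict([x0, x1])
```
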
src/da4ml/codegen/rtl/common_source/build_binder.mk:
@@ -1,7 +1,7 @@
  default: slow

  VERILATOR_ROOT = $(shell verilator -V | grep -a VERILATOR_ROOT | tail -1 | awk '{{print $$3}}')
- INCLUDES = -I./obj_dir -I$(VERILATOR_ROOT)/include
+ INCLUDES = -I./obj_dir -I$(VERILATOR_ROOT)/include -I../src
  WARNINGS = -Wl,--no-undefined
  CFLAGS = -std=c++17 -fPIC
  LINKFLAGS = $(INCLUDES) $(WARNINGS)
@@ -9,14 +9,15 @@ LIBNAME = lib$(VM_PREFIX)_$(STAMP).so
  N_JOBS ?= $(shell nproc)
  VERILATOR_FLAGS ?=

- $(VM_PREFIX).v: $(wildcard $(VM_PREFIX).vhd)
+ ../src/$(VM_PREFIX).v: $(wildcard ../src/$(VM_PREFIX).vhd) $(wildcard ../src/$(VM_PREFIX)_stage*.vhd)
  # vhdl specific - convert to verilog first for verilating
  mkdir -p obj_dir
- ghdl -a --std=08 --workdir=obj_dir multiplier.vhd mux.vhd negative.vhd shift_adder.vhd $(wildcard $(VM_PREFIX:_wrapper=)_stage*.vhd) $(wildcard $(VM_PREFIX:_wrapper=).vhd) $(VM_PREFIX).vhd
+ cp ../src/memfiles/* ./
+ ghdl -a --std=08 --workdir=obj_dir ../src/static/multiplier.vhd ../src/static/mux.vhd ../src/static/negative.vhd ../src/static/shift_adder.vhd ../src/static/lookup_table.vhd $(wildcard ../src/$(VM_PREFIX:_wrapper=)_stage*.vhd) $(wildcard ../src/$(VM_PREFIX:_wrapper=).vhd) ../src/$(VM_PREFIX).vhd
  ghdl synth --std=08 --workdir=obj_dir --out=verilog $(VM_PREFIX) > $(VM_PREFIX).v

- ./obj_dir/libV$(VM_PREFIX).a ./obj_dir/libverilated.a ./obj_dir/V$(VM_PREFIX)__ALL.a: $(VM_PREFIX).v
- verilator --cc -j $(N_JOBS) -build $(VM_PREFIX).v --prefix V$(VM_PREFIX) $(VERILATOR_FLAGS) -CFLAGS "$(CFLAGS)"
+ ./obj_dir/libV$(VM_PREFIX).a ./obj_dir/libverilated.a ./obj_dir/V$(VM_PREFIX)__ALL.a: ../src/$(VM_PREFIX).v $(wildcard ../src/$(VM_PREFIX)_stage*.v)
+ verilator --cc -j $(N_JOBS) -build $(VM_PREFIX).v --prefix V$(VM_PREFIX) $(VERILATOR_FLAGS) -CFLAGS "$(CFLAGS)" -I../src -I../src/static

  $(LIBNAME): ./obj_dir/libV$(VM_PREFIX).a ./obj_dir/libverilated.a ./obj_dir/V$(VM_PREFIX)__ALL.a $(VM_PREFIX)_binder.cc
  $(CXX) $(CFLAGS) $(LINKFLAGS) $(CXXFLAGS2) -pthread -shared -o $(LIBNAME) $(VM_PREFIX)_binder.cc ./obj_dir/libV$(VM_PREFIX).a ./obj_dir/libverilated.a ./obj_dir/V$(VM_PREFIX)__ALL.a $(EXTRA_CXXFLAGS)

src/da4ml/codegen/rtl/common_source/build_quartus_prj.tcl (new file):
@@ -0,0 +1,104 @@
+ set project_name "$::env(PROJECT_NAME)"
+ set device "$::env(DEVICE)"
+ set source_type "$::env(SOURCE_TYPE)"
+
+ set top_module "${project_name}"
+ set output_dir "./output_${project_name}"
+
+ file mkdir $output_dir
+ file mkdir "${output_dir}/reports"
+
+ project_new "${project_name}" -overwrite -revision "${project_name}"
+
+ set_global_assignment -name FAMILY [lindex [split "${device}" "-"] 0]
+ set_global_assignment -name DEVICE "${device}"
+
+ if { "${source_type}" != "vhdl" && "${source_type}" != "verilog" } {
+ puts "Error: SOURCE_TYPE must be either 'vhdl' or 'verilog'."
+ exit 1
+ }
+
+ # Add source files based on type
+ if { "${source_type}" == "vhdl" } {
+ set_global_assignment -name VHDL_INPUT_VERSION VHDL_2008
+
+ foreach file [glob -nocomplain "src/static/*.vhd"] {
+ set_global_assignment -name VHDL_FILE "${file}"
+ }
+
+ set_global_assignment -name VHDL_FILE "src/${project_name}.vhd"
+ foreach file [glob -nocomplain "src/${project_name}_stage*.vhd"] {
+ set_global_assignment -name VHDL_FILE "${file}"
+ }
+ } else {
+ foreach file [glob -nocomplain "src/static/*.v"] {
+ set_global_assignment -name VERILOG_FILE "${file}"
+ }
+
+ set_global_assignment -name VERILOG_FILE "src/${project_name}.v"
+ foreach file [glob -nocomplain "src/${project_name}_stage*.v"] {
+ set_global_assignment -name VERILOG_FILE "${file}"
+ }
+ }
+
+ set mems [glob -nocomplain "src/memfiles/*.mem"]
+
+ # VHDL only uses relative path to working dir apparently...
+ if { "${source_type}" == "vhdl" } {
+ foreach f $mems {
+ file copy -force $f [file tail $f]
+ }
+ set mems [glob -nocomplain "*.mem"]
+ }
+
+ foreach f $mems {
+ set_global_assignment -name MIF_FILE "${f}"
+ }
+
+ # Add SDC constraint file if it exists
+ if { [file exists "src/${project_name}.sdc"] } {
+ set_global_assignment -name SDC_FILE "${project_name}.sdc"
+ }
+
+ # Set top-level entity
+ set_global_assignment -name TOP_LEVEL_ENTITY "${top_module}"
+
+ # OOC
+ load_package flow
+
+ proc make_all_pins_virtual {} {
+ execute_module -tool map
+
+ set name_ids [get_names -filter * -node_type pin]
+
+ foreach_in_collection name_id $name_ids {
+ set pin_name [get_name_info -info full_path $name_id]
+ post_message "Making VIRTUAL_PIN assignment to $pin_name"
+ set_instance_assignment -to $pin_name -name VIRTUAL_PIN ON
+ }
+ export_assignments
+ }
+
+ make_all_pins_virtual
+
+ # Config
+ set_global_assignment -name OPTIMIZATION_MODE "HIGH PERFORMANCE EFFORT"
+ set_global_assignment -name OPTIMIZATION_TECHNIQUE SPEED
+ set_global_assignment -name AUTO_RESOURCE_SHARING ON
+ set_global_assignment -name ALLOW_ANY_RAM_SIZE_FOR_RECOGNITION ON
+ set_global_assignment -name ALLOW_ANY_ROM_SIZE_FOR_RECOGNITION ON
+ set_global_assignment -name ALLOW_REGISTER_RETIMING ON
+
+ set_global_assignment -name TIMEQUEST_MULTICORNER_ANALYSIS ON
+ set_global_assignment -name TIMEQUEST_DO_CCPP_REMOVAL ON
+
+ set_global_assignment -name FITTER_EFFORT "STANDARD FIT"
+
+ set_global_assignment -name SYNTH_TIMING_DRIVEN_SYNTHESIS ON
+ set_global_assignment -name SYNTHESIS_EFFORT AUTO
+ set_global_assignment -name ADV_NETLIST_OPT_SYNTH_WYSIWYG_REMAP ON
+
+ # Run!!!
+ execute_flow -compile
+
+ project_close

src/da4ml/codegen/rtl/common_source/build_prj.tcl → build_vivado_prj.tcl:
@@ -1,6 +1,6 @@
- set project_name "${PROJECT_NAME}"
- set device "${DEVICE}"
- set source_type "${SOURCE_TYPE}"
+ set project_name "$::env(PROJECT_NAME)"
+ set device "$::env(DEVICE)"
+ set source_type "$::env(SOURCE_TYPE)"

  set top_module "${project_name}"
  set output_dir "./output_${project_name}"
@@ -17,28 +17,47 @@ if { $source_type != "vhdl" && $source_type != "verilog" } {
  if { $source_type == "vhdl" } {
  set_property TARGET_LANGUAGE VHDL [current_project]

- read_vhdl -vhdl2008 "${project_name}.vhd"
- read_vhdl -vhdl2008 "shift_adder.vhd"
- read_vhdl -vhdl2008 "negative.vhd"
- read_vhdl -vhdl2008 "mux.vhd"
- read_vhdl -vhdl2008 "multiplier.vhd"
- foreach file [glob -nocomplain "${project_name}_stage*.vhd"] {
+ foreach file [glob -nocomplain "src/static/*.vhd"] {
+ read_vhdl -vhdl2008 $file
+ }
+
+ read_vhdl -vhdl2008 "src/${project_name}.vhd"
+ foreach file [glob -nocomplain "src/${project_name}_stage*.vhd"] {
  read_vhdl -vhdl2008 $file
  }
  } else {
  set_property TARGET_LANGUAGE Verilog [current_project]

- read_verilog "${project_name}.v"
- read_verilog "shift_adder.v"
- read_verilog "negative.v"
- read_verilog "mux.v"
- read_verilog "multiplier.v"
- foreach file [glob -nocomplain "${project_name}_stage*.v"] {
+ foreach file [glob -nocomplain "src/static/*.v"] {
+ read_verilog $file
+ }
+
+ read_verilog "src/${project_name}.v"
+ foreach file [glob -nocomplain "src/${project_name}_stage*.v"] {
  read_verilog $file
  }
  }

- read_xdc "${project_name}.xdc" -mode out_of_context
+
+ set mems [glob -nocomplain "src/memfiles/*.mem"]
+
+ # VHDL only uses relative path to working dir apparently...
+ if { $source_type == "vhdl" } {
+ foreach f $mems {
+ file copy -force $f [file tail $f]
+ }
+ set mems [glob -nocomplain "*.mem"]
+ }
+
+ foreach f $mems {
+ add_files -fileset [current_fileset] $f
+ set_property used_in_synthesis true [get_files $f]
+ }
+
+ # Add XDC constraint if it exists
+ if { [file exists "src/${project_name}.xdc"] } {
+ read_xdc "src/${project_name}.xdc" -mode out_of_context
+ }

  set_property top $top_module [current_fileset]

@@ -46,8 +65,8 @@ file mkdir $output_dir
  file mkdir "${output_dir}/reports"

  # synth
- synth_design -top $top_module -mode out_of_context -retiming \
- -flatten_hierarchy full -resource_sharing auto
+ synth_design -top $top_module -mode out_of_context -global_retiming on \
+ -flatten_hierarchy full -resource_sharing auto -directive PerformanceOptimized

  write_checkpoint -force "${output_dir}/${project_name}_post_synth.dcp"

@@ -66,6 +85,7 @@ report_design_analysis -congestion -file "${output_dir}/reports/${project_name}_

  phys_opt_design -directive AggressiveExplore
  write_checkpoint -force "${output_dir}/${project_name}_post_place.dcp"
+ file delete -force "${output_dir}/${project_name}_post_synth.dcp"

  report_design_analysis -congestion -file "${output_dir}/reports/${project_name}_post_place_congestion_final.rpt"

@@ -75,6 +95,7 @@ report_utilization -hierarchical -file "${output_dir}/reports/${project_name}_po
  # route
  route_design -directive NoTimingRelaxation
  write_checkpoint -force "${output_dir}/${project_name}_post_route.dcp"
+ file delete -force "${output_dir}/${project_name}_post_place.dcp"


  report_timing_summary -file "${output_dir}/reports/${project_name}_post_route_timing.rpt"

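Both build scripts now take their configuration from environment variables instead of templated placeholders. A hedged sketch of driving them from Python; the batch-mode CLIs are standard for Quartus and Vivado, but the part names and working-directory layout are illustrative assumptions.

```python
# Sketch only: assumes the generated project directory is the working directory
# and the vendor tools are on PATH.
import os
import subprocess

common = {'PROJECT_NAME': 'vmodel', 'SOURCE_TYPE': 'verilog'}

# Quartus: build_quartus_prj.tcl derives FAMILY from everything before the first '-' in DEVICE.
quartus_env = dict(os.environ, **common, DEVICE='Agilex-AGFB014R24B2E2V')
subprocess.run(['quartus_sh', '-t', 'build_quartus_prj.tcl'], env=quartus_env, check=True)

# Vivado: DEVICE would be an AMD/Xilinx part string instead.
vivado_env = dict(os.environ, **common, DEVICE='xcvu9p-flga2104-2-i')
subprocess.run(['vivado', '-mode', 'batch', '-source', 'build_vivado_prj.tcl'], env=vivado_env, check=True)
```
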
src/da4ml/codegen/rtl/common_source/template.sdc (new file):
@@ -0,0 +1,27 @@
+ set clock_period $::env(CLOCK_PERIOD)
+
+ # Clock uncertainty as percentage of clock period
+ set uncertainty_setup_r $::env(UNCERTAINITY_SETUP)
+ set uncertainty_hold_r $::env(UNCERTAINITY_HOLD)
+ set delay_max_r $::env(DELAY_MAX)
+ set delay_min_r $::env(DELAY_MIN)
+
+ # Calculate actual uncertainty values
+ set uncertainty_setup [expr {$clock_period * $uncertainty_setup_r}]
+ set uncertainty_hold [expr {$clock_period * $uncertainty_hold_r}]
+ set delay_max [expr {$clock_period * $delay_max_r}]
+ set delay_min [expr {$clock_period * $delay_min_r}]
+
+ # Create clock with variable period
+ create_clock -period $clock_period -name sys_clk [get_ports {clk}]
+
+ # Input/Output constraints
+ set_input_delay -clock sys_clk -max $delay_max [get_ports {model_inp[*]}]
+ set_input_delay -clock sys_clk -min $delay_min [get_ports {model_inp[*]}]
+
+ set_output_delay -clock sys_clk -max $delay_max [get_ports {model_out[*]}]
+ set_output_delay -clock sys_clk -min $delay_min [get_ports {model_out[*]}]
+
+ # Apply calculated uncertainty values
+ set_clock_uncertainty -setup -to [get_clocks sys_clk] $uncertainty_setup
+ set_clock_uncertainty -hold -to [get_clocks sys_clk] $uncertainty_hold

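To make the ratio-based constraints concrete, this is the arithmetic the template performs, shown with illustrative numbers; the environment-variable names come from template.sdc, the values do not.

```python
clock_period = 4.0          # ns, i.e. 250 MHz (CLOCK_PERIOD)
uncertainty_setup_r = 0.10  # UNCERTAINITY_SETUP (spelled as in the template)
uncertainty_hold_r = 0.05   # UNCERTAINITY_HOLD
delay_max_r = 0.20          # DELAY_MAX
delay_min_r = 0.05          # DELAY_MIN

uncertainty_setup = clock_period * uncertainty_setup_r  # 0.4 ns of setup uncertainty on sys_clk
uncertainty_hold = clock_period * uncertainty_hold_r    # 0.2 ns of hold uncertainty
delay_max = clock_period * delay_max_r                  # 0.8 ns max I/O delay on model_inp/model_out
delay_min = clock_period * delay_min_r                  # 0.2 ns min I/O delay
```
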