da4ml 0.2.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

Files changed (86) hide show
  1. da4ml-0.3.0/.clang-format +43 -0
  2. da4ml-0.3.0/DAIS-spec.md +79 -0
  3. da4ml-0.3.0/PKG-INFO +107 -0
  4. da4ml-0.3.0/README.md +84 -0
  5. da4ml-0.3.0/interperter/DAISInterpreter.cc +314 -0
  6. da4ml-0.3.0/interperter/DAISInterpreter.hh +90 -0
  7. {da4ml-0.2.0 → da4ml-0.3.0}/pyproject.toml +1 -1
  8. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/_version.py +2 -2
  9. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/cmvm/api.py +2 -6
  10. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/cmvm/core/__init__.py +0 -1
  11. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/cmvm/types.py +99 -19
  12. da4ml-0.3.0/src/da4ml/codegen/__init__.py +12 -0
  13. da4ml-0.3.0/src/da4ml/codegen/cpp/__init__.py +4 -0
  14. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/codegen/cpp/cpp_codegen.py +58 -25
  15. da4ml-0.3.0/src/da4ml/codegen/cpp/hls_model.py +252 -0
  16. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
  17. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
  18. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
  19. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
  20. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
  21. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
  22. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
  23. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
  24. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
  25. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
  26. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
  27. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
  28. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
  29. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
  30. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
  31. da4ml-0.3.0/src/da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
  32. da4ml-0.3.0/src/da4ml/codegen/cpp/source/binder_util.hh +56 -0
  33. da4ml-0.3.0/src/da4ml/codegen/cpp/source/build_binder.mk +24 -0
  34. da4ml-0.2.0/src/da4ml/codegen/cpp/source/vitis.h → da4ml-0.3.0/src/da4ml/codegen/cpp/source/vitis_bitshift.hh +1 -1
  35. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/codegen/verilog/__init__.py +2 -3
  36. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/codegen/verilog/comb.py +65 -24
  37. da4ml-0.3.0/src/da4ml/codegen/verilog/io_wrapper.py +150 -0
  38. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/codegen/verilog/pipeline.py +21 -3
  39. da4ml-0.3.0/src/da4ml/codegen/verilog/source/binder_util.hh +72 -0
  40. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/codegen/verilog/source/build_prj.tcl +0 -1
  41. da4ml-0.3.0/src/da4ml/codegen/verilog/source/mux.v +58 -0
  42. da4ml-0.3.0/src/da4ml/codegen/verilog/source/negative.v +28 -0
  43. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/codegen/verilog/source/shift_adder.v +4 -1
  44. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/codegen/verilog/source/template.xdc +3 -0
  45. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/codegen/verilog/verilog_model.py +42 -15
  46. da4ml-0.3.0/src/da4ml/converter/__init__.py +0 -0
  47. da4ml-0.3.0/src/da4ml/converter/hgq2/parser.py +105 -0
  48. da4ml-0.3.0/src/da4ml/converter/hgq2/replica.py +383 -0
  49. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/trace/__init__.py +2 -2
  50. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/trace/fixed_variable.py +177 -18
  51. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/trace/fixed_variable_array.py +124 -9
  52. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/trace/ops/__init__.py +22 -6
  53. da4ml-0.3.0/src/da4ml/trace/ops/conv_utils.py +236 -0
  54. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/trace/ops/einsum_utils.py +9 -6
  55. da4ml-0.3.0/src/da4ml/trace/ops/reduce_utils.py +103 -0
  56. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/trace/pipeline.py +36 -34
  57. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/trace/tracer.py +37 -5
  58. da4ml-0.3.0/src/da4ml.egg-info/PKG-INFO +107 -0
  59. da4ml-0.3.0/src/da4ml.egg-info/SOURCES.txt +74 -0
  60. da4ml-0.2.0/.clang-format +0 -191
  61. da4ml-0.2.0/PKG-INFO +0 -65
  62. da4ml-0.2.0/README.md +0 -42
  63. da4ml-0.2.0/src/da4ml/codegen/__init__.py +0 -11
  64. da4ml-0.2.0/src/da4ml/codegen/cpp/__init__.py +0 -3
  65. da4ml-0.2.0/src/da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
  66. da4ml-0.2.0/src/da4ml/codegen/verilog/io_wrapper.py +0 -255
  67. da4ml-0.2.0/src/da4ml/trace/ops/conv_utils.py +0 -104
  68. da4ml-0.2.0/src/da4ml.egg-info/PKG-INFO +0 -65
  69. da4ml-0.2.0/src/da4ml.egg-info/SOURCES.txt +0 -46
  70. {da4ml-0.2.0 → da4ml-0.3.0}/.github/workflows/python-publish.yml +0 -0
  71. {da4ml-0.2.0 → da4ml-0.3.0}/.gitignore +0 -0
  72. {da4ml-0.2.0 → da4ml-0.3.0}/.pre-commit-config.yaml +0 -0
  73. {da4ml-0.2.0 → da4ml-0.3.0}/LICENSE +0 -0
  74. {da4ml-0.2.0 → da4ml-0.3.0}/setup.cfg +0 -0
  75. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/__init__.py +0 -0
  76. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/cmvm/__init__.py +0 -0
  77. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/cmvm/core/indexers.py +0 -0
  78. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/cmvm/core/state_opr.py +0 -0
  79. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/cmvm/util/__init__.py +0 -0
  80. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/cmvm/util/bit_decompose.py +0 -0
  81. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/cmvm/util/mat_decompose.py +0 -0
  82. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml/codegen/verilog/source/build_binder.mk +0 -0
  83. /da4ml-0.2.0/src/da4ml/codegen/verilog/source/ioutils.hh → /da4ml-0.3.0/src/da4ml/codegen/verilog/source/ioutil.hh +0 -0
  84. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml.egg-info/dependency_links.txt +0 -0
  85. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml.egg-info/requires.txt +0 -0
  86. {da4ml-0.2.0 → da4ml-0.3.0}/src/da4ml.egg-info/top_level.txt +0 -0
@@ -0,0 +1,43 @@
1
+ BasedOnStyle: LLVM
2
+ AlignAfterOpenBracket: BlockIndent
3
+ AllowAllParametersOfDeclarationOnNextLine: false
4
+ AllowShortBlocksOnASingleLine: true
5
+ AllowShortCaseLabelsOnASingleLine: true
6
+ AllowShortFunctionsOnASingleLine: All
7
+ AlwaysBreakAfterDefinitionReturnType: None
8
+ AlwaysBreakAfterReturnType: None
9
+ BreakBeforeBraces: Custom
10
+ BraceWrapping:
11
+ AfterClass: true
12
+ AfterControlStatement: true
13
+ AfterEnum: true
14
+ AfterFunction: true
15
+ AfterNamespace: true
16
+ AfterStruct: true
17
+ AfterUnion: true
18
+ AfterExternBlock: true
19
+ BeforeCatch: true
20
+ BeforeElse: true
21
+ IndentBraces: true
22
+ SplitEmptyFunction: false
23
+ SplitEmptyRecord: false
24
+ SplitEmptyNamespace: false
25
+ ContinuationIndentWidth: 4
26
+ ConstructorInitializerIndentWidth: 4
27
+ Cpp11BracedListStyle: true
28
+ KeepEmptyLinesAtTheStartOfBlocks: false
29
+ Language: Cpp
30
+ NamespaceIndentation: All
31
+ PenaltyExcessCharacter: 100
32
+ PenaltyReturnTypeOnItsOwnLine: 1
33
+ PointerAlignment: Right
34
+ SortIncludes: false
35
+ SpaceBeforeParens: Never
36
+ SpacesInContainerLiterals: false
37
+ Standard: Cpp11
38
+ BinPackArguments: false
39
+ BinPackParameters: false
40
+ ColumnLimit: 120
41
+ IndentWidth: 4
42
+ TabWidth: 4
43
+ UseTab: Never
@@ -0,0 +1,79 @@
1
+ Distributed Arithmetic Instruction Set (DAIS)
2
+ =================================================
3
+
4
+ DAIS is a minimal domain-specific language (DSL) for the da4ml framework. It is designed to be minimal to satisfy the requirements for representing a neural network at high granularity. Specifically, each DAIS program contains one block of logic that is fully parallelizable (i.e., fully combinational), and resource multiplexing shall be performed on a higher level.
5
+
6
+ As the name suggests, DAIS is an Instruction Set level language for easy interpretation and compilation into HDL or HLS languages. One program in DAIS consists of the following components:
7
+
8
+ ## Program Structure
9
+
10
+ - `shape`: tuple<int, int>
11
+ - The number of inputs and outputs of the program.
12
+ - `inp_shift`: vector<int>
13
+ - The shifts required to interpret the input data. (i.e., number of integers in the fixed-point representation)
14
+ - `out_idxs`: vector<int>
15
+ - The indices of the output data shall be read from the buffer
16
+ - `out_shifts`: vector<int>
17
+ - The shifts required to interpret the output data.
18
+ - `out_negs`: vector<bool>
19
+ - The signs of the output data.
20
+ - `ops`: vector<Op>
21
+ - The core list of operations for populating the full buffer.
22
+
23
+ Each operation is represented as an `Op` object, which consists of the following components:
24
+ - `opcode`: int
25
+ - The operation code, see [OpCode](#opcode).
26
+ - `id0`, `id1`: int
27
+ - The first and second operand indices in the buffer. Unused operands must be set to `-1`.
28
+ - `data`: int64
29
+ - Extra integer data for the operation, functionality depends on the opcode.
30
+ - `dtype`: tuple<float, float, float> OR tuple<int/bool, int, int>
31
+ - Annotates the datatype of the output buffer as a quantization interval.
32
+ - (min, max, step) or (signed, integer_bits (excl sign bit), fractional_bits). If using the (min, max, step) format, it is assumed that the minimal fixed-point representation that contains the full range of the quantization interval is used. (e.g., (-3., 3., 1.) is the same as (-4., 3., 1.): both are (1, 2, 0) in fixed point representation). Step **must** be a power of two.
33
+ - **Must** cause no overflow if the operation itself does not imply quantization.
34
+
35
+ The program is executed as follows:
36
+ 1. Instantiate an empty buffer of size `len(ops)`.
37
+ 2. Go through the list of operations in `ops`. Fill the i-th index of the buffer with the result of the i-th operation: buf[i] = ops[i](buf, inp)
38
+ 3. Instantiate the output buffer of size `shape[1]`.
39
+ 4. Fill output buffer:
40
+ - `output_buf[i] = buf[out_idxs[i]] * 2^out_shifts[i] * (-1 if out_negs[i] else 1)`
41
+
42
+ ### OpCode
43
+ The operation codes are defined as follows:
44
+ - `-1`: Copy from input buffer (**implies quantization**)
45
+ - `buf[i] = input[id0]`
46
+ - `0/1`: Addition/Subtraction
47
+ - `buf[i] = buf[id0] +/- buf[id1] * 2^data`
48
+ - `2/-2`: ReLU (**implies quantization**)
49
+ - `buf[i] = quantize(relu(+/- buf[id0]))`
50
+ - `3/-3`: Quantization (**implies quantization**)
51
+ - `buf[i] = quantize(+/- buf[id0])`
52
+ - `4`: Add a constant
53
+ - `buf[i] = buf[id0] + data * qint.step`
54
+ - `5`: Define a constant
55
+ - `buf[i] = data * qint.step`
56
+ - `6/-6`: Mux by MSB
57
+ - `buf[i] = MSB(buf[int32(data_lower_i32)]) ? buf[id0] : +/- buf[id1] * 2^int32(data_higher_i32)`
58
+ - `*`: Multiplication (**NOT IMPLEMENTED**)
59
+ - `buf[i] = buf[id0] * buf[id1]`
60
+
61
+ In all cases, unused id0 or id1 **must** be set to `-1`; id0, id1 (and data for opcode=+/-6) **must** be smaller than the index of the operation itself to ensure causality. All quantizations are direct bit-drops in binary format (i.e., WRAP for overflow and TRUNC for rounding).
62
+
63
+ ### Binary Representation
64
+ The binary representation of the program is as follows, in order:
65
+ - `shape`: int32[2]
66
+ - `len(ops)`: int32
67
+ - `inp_shift`: int32[shape[0]]
68
+ - `out_idxs`: int32[shape[1]]
69
+ - `out_shifts`: int32[shape[1]]
70
+ - `out_negs`: int32[shape[1]]
71
+ - `ops`: Op[len(ops)]
72
+ - `opcode`: int32
73
+ - `id0`: int32
74
+ - `id1`: int32
75
+ - `data_higher`: int32
76
+ - `data_lower`: int32
77
+ - `dtype`: int32[3] (only (signed, integer_bits, fractional_bits) format for binary representation)
78
+
79
+ In execution, the internal buffer **must** have a larger bitwidth than the maximum bitwidth that appears in any of the operations. When an operation implies quantization, the program **must** apply the quantization explicitly. When an operation does not imply quantization, the program **may** apply quantization and verify no value change is incurred as a result.
da4ml-0.3.0/PKG-INFO ADDED
@@ -0,0 +1,107 @@
1
+ Metadata-Version: 2.4
2
+ Name: da4ml
3
+ Version: 0.3.0
4
+ Summary: Digital Arithmetic for Machine Learning
5
+ Author-email: Chang Sun <chsun@cern.ch>
6
+ License: GNU Lesser General Public License v3 (LGPLv3)
7
+ Project-URL: repository, https://github.com/calad0i/da4ml
8
+ Keywords: CMVM,distributed arithmetic,hls4ml,MCM,subexpression elimination
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: Programming Language :: Python :: 3 :: Only
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Requires-Python: >=3.10
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE
20
+ Requires-Dist: llvmlite>=0.44
21
+ Requires-Dist: numba>=0.61
22
+ Dynamic: license-file
23
+
24
+ # da4ml: Distributed Arithmetic for Machine Learning
25
+
26
+ This project performs Constant Matrix-Vector Multiplication (CMVM) with Distributed Arithmetic (DA) for Machine Learning (ML) on Field Programmable Gate Arrays (FPGAs).
27
+
28
+ CMVM optimization is done through greedy CSE of two-term subexpressions, with possible Delay Constraints (DC). The optimization is done in jitted Python (Numba), and a list of optimized operations is generated as traced Python code.
29
+
30
+ The project generates Verilog or Vitis HLS code for the optimized CMVM operations. This project can be used in conjunction with [`hls4ml`](https://github.com/fastmachinelearning/hls4ml/) for optimizing the neural networks deployed on FPGAs. For a subset of neural networks, the full design can be generated standalone in Verilog or Vitis HLS.
31
+
32
+
33
+ ## Installation
34
+
35
+ The project is available on PyPI and can be installed with pip:
36
+
37
+ ```bash
38
+ pip install da4ml
39
+ ```
40
+
41
+ Notice that `numba>=0.61` is required for the project to work. The project does not work with `python<3.10`. If the project fails to compile, try upgrading `numba` and `llvmlite` to the latest versions.
42
+
43
+ ## `hls4ml`
44
+
45
+ The major use of this project is through the `distributed_arithmetic` strategy in the `hls4ml`:
46
+
47
+ ```python
48
+ model_hls = hls4ml.converters.convert_from_keras_model(
49
+ model,
50
+ hls_config={
51
+ 'Model': {
52
+ ...
53
+ 'Strategy': 'distributed_arithmetic',
54
+ },
55
+ ...
56
+ },
57
+ ...
58
+ )
59
+ ```
60
+
61
+ Currently, `Dense/Conv1D/Conv2D` layers are supported for both `io_parallel` and `io_stream` dataflows. However, notice that distributed arithmetic implies `reuse_factor=1`, as the whole kernel is implemented in combinational logic.
62
+
63
+ ## Standalone usage
64
+
65
+ ### `HGQ2`
66
+
67
+ For some models trained with `HGQ2`, the `da4ml` can be used to generate the whole model in Verilog or Vitis HLS:
68
+
69
+ ```python
70
+ from da4ml.codegen import HLSModel, VerilogModel
71
+ from da4ml.converter.hgq2.parser import trace_model
72
+ from da4ml.trace import comb_trace
73
+
74
+ inp, out = trace_model(hgq2_model)
75
+ comb_logic = comb_trace(inp[0], out[0]) # Currently, only models with 1 input and 1 output are supported
76
+
77
+ # Pipelined Verilog model generation
78
+ # `latency_cutoff` is used to control auto pipelining behavior. To disable pipelining, set it to -1.
79
+ verilog_model = VerilogModel(sol, prj_name='barbar', path='/tmp/barbar', latency_cutoff=5)
80
+ verilog_model.compile() # write and verilator binding
81
+ verilog_model.predict(inputs)
82
+
83
+ vitis_hls_model = HLSModel(sol, prj_name='foo', path='/tmp/foo', flavor='vitis') # Only vitis is supported for now
84
+ vitis_hls_model.compile() # write and hls binding
85
+ vitis_hls_model.predict(inputs)
86
+ ```
87
+
88
+ ### Functional Definition
89
+ For generic operations, one can define a combinational logic with the functional API:
90
+
91
+ ```python
92
+ from da4ml.trace import FixedVariableArray, HWConfig, comb_trace
93
+ from da4ml.trace.ops import einsum, relu, quantize, conv, pool
94
+
95
+ # k, i, f are numpy arrays of integers: keep_negative (0/1), integer bits (excl. sign), fractional bits
96
+ inp = FixedVariableArray.from_kif(k, i, f, HWConfig(1, -1, -1), solver_options={'hard_dc':2})
97
+ out = inp @ kernel
98
+ out = relu(out)
99
+ out = einsum(equation, out, weights)
100
+ ...
101
+
102
+ comb = comb_trace(inp, out)
103
+ ```
104
+
105
+ `+`, `-`, `@` are supported as well as `einsum`, `relu`, `quantize` (WRAP, with TRN or RND), `conv`, `pool` (average only). For multiplications, only power-of-two multipliers are supported, otherwise use `einsum` or `@` operators.
106
+
107
+ The `comb_trace` returns a `Solution` object that contains a list of low-level operations that are used to implement the combinational logic, which in turn can be used to generate Verilog or Vitis HLS code.
da4ml-0.3.0/README.md ADDED
@@ -0,0 +1,84 @@
1
+ # da4ml: Distributed Arithmetic for Machine Learning
2
+
3
+ This project performs Constant Matrix-Vector Multiplication (CMVM) with Distributed Arithmetic (DA) for Machine Learning (ML) on Field Programmable Gate Arrays (FPGAs).
4
+
5
+ CMVM optimization is done through greedy CSE of two-term subexpressions, with possible Delay Constraints (DC). The optimization is done in jitted Python (Numba), and a list of optimized operations is generated as traced Python code.
6
+
7
+ The project generates Verilog or Vitis HLS code for the optimized CMVM operations. This project can be used in conjunction with [`hls4ml`](https://github.com/fastmachinelearning/hls4ml/) for optimizing the neural networks deployed on FPGAs. For a subset of neural networks, the full design can be generated standalone in Verilog or Vitis HLS.
8
+
9
+
10
+ ## Installation
11
+
12
+ The project is available on PyPI and can be installed with pip:
13
+
14
+ ```bash
15
+ pip install da4ml
16
+ ```
17
+
18
+ Notice that `numba>=0.61` is required for the project to work. The project does not work with `python<3.10`. If the project fails to compile, try upgrading `numba` and `llvmlite` to the latest versions.
19
+
20
+ ## `hls4ml`
21
+
22
+ The major use of this project is through the `distributed_arithmetic` strategy in the `hls4ml`:
23
+
24
+ ```python
25
+ model_hls = hls4ml.converters.convert_from_keras_model(
26
+ model,
27
+ hls_config={
28
+ 'Model': {
29
+ ...
30
+ 'Strategy': 'distributed_arithmetic',
31
+ },
32
+ ...
33
+ },
34
+ ...
35
+ )
36
+ ```
37
+
38
+ Currently, `Dense/Conv1D/Conv2D` layers are supported for both `io_parallel` and `io_stream` dataflows. However, notice that distributed arithmetic implies `reuse_factor=1`, as the whole kernel is implemented in combinational logic.
39
+
40
+ ## Standalone usage
41
+
42
+ ### `HGQ2`
43
+
44
+ For some models trained with `HGQ2`, the `da4ml` can be used to generate the whole model in Verilog or Vitis HLS:
45
+
46
+ ```python
47
+ from da4ml.codegen import HLSModel, VerilogModel
48
+ from da4ml.converter.hgq2.parser import trace_model
49
+ from da4ml.trace import comb_trace
50
+
51
+ inp, out = trace_model(hgq2_model)
52
+ comb_logic = comb_trace(inp[0], out[0]) # Currently, only models with 1 input and 1 output are supported
53
+
54
+ # Pipelined Verilog model generation
55
+ # `latency_cutoff` is used to control auto pipelining behavior. To disable pipelining, set it to -1.
56
+ verilog_model = VerilogModel(sol, prj_name='barbar', path='/tmp/barbar', latency_cutoff=5)
57
+ verilog_model.compile() # write and verilator binding
58
+ verilog_model.predict(inputs)
59
+
60
+ vitis_hls_model = HLSModel(sol, prj_name='foo', path='/tmp/foo', flavor='vitis') # Only vitis is supported for now
61
+ vitis_hls_model.compile() # write and hls binding
62
+ vitis_hls_model.predict(inputs)
63
+ ```
64
+
65
+ ### Functional Definition
66
+ For generic operations, one can define a combinational logic with the functional API:
67
+
68
+ ```python
69
+ from da4ml.trace import FixedVariableArray, HWConfig, comb_trace
70
+ from da4ml.trace.ops import einsum, relu, quantize, conv, pool
71
+
72
+ # k, i, f are numpy arrays of integers: keep_negative (0/1), integer bits (excl. sign), fractional bits
73
+ inp = FixedVariableArray.from_kif(k, i, f, HWConfig(1, -1, -1), solver_options={'hard_dc':2})
74
+ out = inp @ kernel
75
+ out = relu(out)
76
+ out = einsum(equation, out, weights)
77
+ ...
78
+
79
+ comb = comb_trace(inp, out)
80
+ ```
81
+
82
+ `+`, `-`, `@` are supported as well as `einsum`, `relu`, `quantize` (WRAP, with TRN or RND), `conv`, `pool` (average only). For multiplications, only power-of-two multipliers are supported, otherwise use `einsum` or `@` operators.
83
+
84
+ The `comb_trace` returns a `Solution` object that contains a list of low-level operations that are used to implement the combinational logic, which in turn can be used to generate Verilog or Vitis HLS code.
@@ -0,0 +1,314 @@
1
+ #include "DAISInterpreter.hh"
2
+ #include <cstring>
3
+ #include <stdexcept>
4
+ #include <fstream>
5
+ #include <iostream>
6
+ #include <cmath>
7
+
8
+ namespace dais
9
+ {
10
+
11
+ void DAISInterpreter::load_from_binary(const std::vector<int32_t> &binary_data)
12
+ {
13
+ n_in = binary_data[0];
14
+ n_out = binary_data[1];
15
+ n_ops = binary_data[2];
16
+ size_t expect_length = 3 + n_in + 3 * n_out + 8 * n_ops;
17
+ const static size_t d_size = sizeof(int32_t);
18
+
19
+ if(binary_data.size() != expect_length)
20
+ throw std::runtime_error(
21
+ "Binary data size mismatch: expected " + std::to_string(expect_length * d_size) + "bytes , got " +
22
+ std::to_string(binary_data.size() * d_size) + " bytes"
23
+ );
24
+
25
+ ops.resize(n_ops);
26
+ inp_shift.resize(n_in);
27
+ out_idxs.resize(n_out);
28
+ out_shifts.resize(n_out);
29
+ out_negs.resize(n_out);
30
+
31
+ std::memcpy(inp_shift.data(), &binary_data[3], n_in * d_size);
32
+ std::memcpy(out_idxs.data(), &binary_data[3 + n_in], n_out * d_size);
33
+ std::memcpy(out_shifts.data(), &binary_data[3 + n_in + n_out], n_out * d_size);
34
+ std::memcpy(out_negs.data(), &binary_data[3 + n_in + 2 * n_out], n_out);
35
+ std::memcpy(ops.data(), &binary_data[3 + n_in + 3 * n_out], n_ops * 8 * d_size);
36
+
37
+ for(const auto &op : ops)
38
+ {
39
+ int32_t width = op.dtype.width();
40
+ if(op.opcode == -1)
41
+ max_inp_width = std::max(max_inp_width, width);
42
+ max_ops_width = std::max(max_ops_width, width);
43
+ }
44
+ for(const int32_t &idx : out_idxs)
45
+ {
46
+ if(idx >= 0)
47
+ max_out_width = std::max(max_out_width, ops[idx].dtype.width());
48
+ }
49
+ validate();
50
+ }
51
+
52
+ void DAISInterpreter::load_from_file(const std::string &path)
53
+ {
54
+ std::ifstream file(path, std::ios::binary);
55
+ if(!file)
56
+ throw std::runtime_error("Failed to open file: " + path);
57
+
58
+ std::vector<int32_t> binary_data;
59
+ file.seekg(0, std::ios::end);
60
+ size_t file_size = file.tellg();
61
+ file.seekg(0, std::ios::beg);
62
+ if(file_size % sizeof(int32_t) != 0)
63
+ throw std::runtime_error("File size is not a multiple of int32_t size");
64
+ if(file_size < 3 * sizeof(int32_t))
65
+ throw std::runtime_error("File size is too small to contain valid DAIS model file");
66
+ size_t num_elements = file_size / sizeof(int32_t);
67
+ binary_data.resize(num_elements);
68
+ file.read(reinterpret_cast<char *>(binary_data.data()), file_size);
69
+ load_from_binary(binary_data);
70
+ }
71
+
72
+ int64_t DAISInterpreter::shift_add(
73
+ int64_t v1,
74
+ int64_t v2,
75
+ int32_t shift,
76
+ bool is_minus,
77
+ const DType &dtype0,
78
+ const DType &dtype1,
79
+ const DType &dtype_out
80
+ ) const
81
+ {
82
+ int32_t actual_shift = shift + dtype0.fractionals - dtype1.fractionals;
83
+ int64_t _v2 = is_minus ? -v2 : v2;
84
+ if(actual_shift > 0)
85
+ return v1 + (_v2 << actual_shift);
86
+ else
87
+ return (v1 << -actual_shift) + _v2;
88
+ }
89
+
90
+ int64_t DAISInterpreter::quantize(int64_t value, const DType &dtype_from, const DType &dtype_to) const
91
+ {
92
+ int32_t shift = dtype_from.fractionals - dtype_to.fractionals;
93
+ value = value >> shift;
94
+ int32_t int_max = dtype_to.int_max();
95
+ int32_t int_min = dtype_to.int_min();
96
+ const int64_t _mod = 1LL << dtype_to.width();
97
+ // std::cout << "value = " << value << " (min=" << int_min << ", max=" << int_max << ", mod=" << _mod <<
98
+ // std::endl;
99
+ value = ((value - int_min + (std::abs(value) / _mod + 1) * _mod) % _mod) + int_min;
100
+ return value;
101
+ }
102
+
103
+ int64_t DAISInterpreter::relu(int64_t value, const DType &dtype_from, const DType &dtype_to) const
104
+ {
105
+ if(value < 0)
106
+ return 0;
107
+ return quantize(value, dtype_from, dtype_to);
108
+ }
109
+
110
+ int64_t
111
+ DAISInterpreter::const_add(int64_t value, DType dtype_from, DType dtype_to, int32_t data_high, int32_t data_low)
112
+ const
113
+ {
114
+ const int32_t _shift = dtype_to.fractionals - dtype_from.fractionals;
115
+ int64_t data = (static_cast<int64_t>(data_high) << 32) | static_cast<uint32_t>(data_low);
116
+ // std::cout << "v=" << value << " c=" << data << " shift=" << _shift << std::endl;
117
+ return (value << _shift) + data;
118
+ }
119
+
120
+ bool DAISInterpreter::get_msb(int64_t value, const DType &dtype) const
121
+ {
122
+ if(dtype.is_signed)
123
+ return value < 0;
124
+ return value >= (1LL << (dtype.width() - 2));
125
+ }
126
+
127
+ int64_t DAISInterpreter::msb_mux(
128
+ int64_t v0,
129
+ int64_t v1,
130
+ int64_t v_cond,
131
+ int32_t _shift,
132
+ const DType &dtype0,
133
+ const DType &dtype1,
134
+ const DType &dtype_cond,
135
+ const DType &dtype_out
136
+ ) const
137
+ {
138
+ bool cond = get_msb(v_cond, dtype_cond);
139
+ int32_t shift = dtype0.fractionals - dtype1.fractionals + _shift;
140
+ int64_t shifted_v0 = shift > 0 ? v0 : (v0 << -shift);
141
+ int64_t shifted_v1 = shift > 0 ? (v1 << shift) : v1;
142
+ if(cond)
143
+ return shifted_v0;
144
+ else
145
+ return shifted_v1;
146
+ }
147
+
148
+ std::vector<int64_t> DAISInterpreter::exec_ops(const std::vector<double> &inputs)
149
+ {
150
+ if(inputs.size() != n_in)
151
+ throw std::runtime_error(
152
+ "Input size mismatch: expected " + std::to_string(n_in) + ", got " + std::to_string(inputs.size())
153
+ );
154
+
155
+ std::vector<int64_t> buffer(n_ops);
156
+ std::vector<int64_t> output_buffer(n_out);
157
+
158
+ for(size_t i = 0; i < n_ops; ++i)
159
+ {
160
+ const Op &op = ops[i];
161
+ switch(op.opcode)
162
+ {
163
+ case -1: {
164
+ int64_t input_value = static_cast<int64_t>(
165
+ std::floor(inputs[op.id0] * std::pow(2.0, inp_shift[op.id0] + ops[i].dtype.fractionals))
166
+ );
167
+ buffer[i] = quantize(input_value, op.dtype, op.dtype);
168
+ break;
169
+ }
170
+ case 0:
171
+ case 1:
172
+ buffer[i] = shift_add(
173
+ buffer[op.id0],
174
+ buffer[op.id1],
175
+ op.data_low,
176
+ op.opcode == 1,
177
+ ops[op.id0].dtype,
178
+ ops[op.id1].dtype,
179
+ ops[i].dtype
180
+ );
181
+ break;
182
+ case 2:
183
+ case -2:
184
+ buffer[i] =
185
+ relu(op.opcode == -2 ? -buffer[op.id0] : buffer[op.id0], ops[op.id0].dtype, ops[i].dtype);
186
+ break;
187
+ case 3:
188
+ case -3:
189
+ buffer[i] = quantize(
190
+ op.opcode == -3 ? -buffer[op.id0] : buffer[op.id0], ops[op.id0].dtype, ops[i].dtype
191
+ );
192
+ break;
193
+ case 4:
194
+ buffer[i] =
195
+ const_add(buffer[op.id0], ops[op.id0].dtype, ops[i].dtype, op.data_high, op.data_low);
196
+ break;
197
+ case 5:
198
+ buffer[i] = static_cast<int64_t>(op.data_high) << 32 | static_cast<uint32_t>(op.data_low);
199
+ break;
200
+ case 6:
201
+ case -6:
202
+ buffer[i] = msb_mux(
203
+ buffer[op.id0],
204
+ op.opcode == -6 ? -buffer[op.id1] : buffer[op.id1],
205
+ buffer[op.data_low],
206
+ op.data_high,
207
+ ops[op.id0].dtype,
208
+ ops[op.id1].dtype,
209
+ ops[op.data_low].dtype,
210
+ ops[i].dtype
211
+ );
212
+ break;
213
+ default:
214
+ throw std::runtime_error(
215
+ "Unknown opcode: " + std::to_string(op.opcode) + " at index " + std::to_string(i)
216
+ );
217
+ }
218
+ }
219
+ for(size_t i = 0; i < n_out; ++i)
220
+ output_buffer[i] = out_idxs[i] >= 0 ? buffer[out_idxs[i]] : 0;
221
+ return output_buffer;
222
+ }
223
+
224
+ std::vector<double> DAISInterpreter::inference(const std::vector<double> &inputs)
225
+ {
226
+ std::vector<int64_t> int_outputs = exec_ops(inputs);
227
+ std::vector<double> outputs(n_out);
228
+ for(size_t i = 0; i < n_out; ++i)
229
+ {
230
+ const int64_t tmp = out_negs[i] ? -int_outputs[i] : int_outputs[i];
231
+ outputs[i] =
232
+ static_cast<double>(tmp) * std::pow(2.0, out_shifts[i] - ops[out_idxs[i]].dtype.fractionals);
233
+ }
234
+
235
+ return outputs;
236
+ }
237
+
238
+ void DAISInterpreter::print_program_info() const
239
+ {
240
+ size_t bits_in = 0, bits_out = 0;
241
+ for(int32_t i = 0; i < n_ops; ++i)
242
+ {
243
+ const Op op = ops[i];
244
+ if(op.opcode == -1)
245
+ bits_in += op.dtype.width();
246
+ }
247
+ for(int32_t i = 0; i < n_out; ++i)
248
+ {
249
+ if(out_idxs[i] >= 0)
250
+ bits_out += ops[out_idxs[i]].dtype.width();
251
+ }
252
+ std::cout << "DAIS Sequence:\n";
253
+ std::cout << n_in << " (" << bits_in << " bits) -> " << n_out << " (" << bits_out << " bits)\n";
254
+ std::cout << "# operations: " << n_ops << "\n";
255
+ std::cout << "Maximum intermediate width: " << max_ops_width << " bits\n";
256
+ }
257
+
258
+ void DAISInterpreter::validate() const
259
+ {
260
+ for(int32_t i = 0; i < n_ops; ++i) // Causality check
261
+ {
262
+ const Op &op = ops[i];
263
+ if(op.id0 >= i && op.opcode != -1)
264
+ throw std::runtime_error(
265
+ "Operation " + std::to_string(i) + " has id0=" + std::to_string(op.id0) + "violating causality"
266
+ );
267
+ if(op.id1 >= i)
268
+ throw std::runtime_error(
269
+ "Operation " + std::to_string(i) + " has id1=" + std::to_string(op.id1) + " violating causality"
270
+ );
271
+ if(abs(op.opcode) == 6 && op.data_low >= i)
272
+ throw std::runtime_error(
273
+ "Operation " + std::to_string(i) + " has cond_idx=" + std::to_string(op.data_low) +
274
+ " violating causality"
275
+ );
276
+ }
277
+
278
+ if(max_inp_width > 32 || max_out_width > 32)
279
+ {
280
+ std::cerr << "Warning: max_inp_width=" << max_inp_width << " or max_out_width=" << max_out_width
281
+ << " exceeds 32 bits, which may cause issues with the Verilator binder.\n";
282
+ }
283
+ if(max_ops_width > 64)
284
+ {
285
+ std::cerr << "Warning: max_ops_width=" << max_ops_width
286
+ << " exceeds 64 bits. This may comppromise bit-exactness of the interpreter.\n"
287
+ << "This high wdith is very unusual for a properly quantized network, so you may want to "
288
+ "check your "
289
+ "model.\n";
290
+ }
291
+ }
292
+ } // namespace dais
293
+
294
+ extern "C"
295
+ {
296
+ void run_interp(int32_t *data, double *inputs, double *outputs, size_t n_copy)
297
+ {
298
+ int32_t n_in = data[0];
299
+ int32_t n_out = data[1];
300
+ int32_t n_ops = data[2];
301
+ size_t size = 3 + n_in + 3 * n_out + 8 * n_ops;
302
+ std::vector<int32_t> binary_data(data, data + size);
303
+ dais::DAISInterpreter interp;
304
+ interp.load_from_binary(binary_data);
305
+ interp.print_program_info();
306
+
307
+ for(size_t i = 0; i < n_copy; ++i)
308
+ {
309
+ std::vector<double> inp_vec(inputs + i * n_in, inputs + (i + 1) * n_in);
310
+ auto ret = interp.inference(inp_vec);
311
+ memcpy(&outputs[i * n_out], ret.data(), n_out * sizeof(double));
312
+ }
313
+ }
314
+ }