da4ml 0.2.1__py3-none-any.whl → 0.3.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

Files changed (55)
  1. da4ml/_version.py +2 -2
  2. da4ml/cmvm/types.py +95 -15
  3. da4ml/codegen/__init__.py +5 -4
  4. da4ml/codegen/cpp/__init__.py +2 -1
  5. da4ml/codegen/cpp/cpp_codegen.py +56 -23
  6. da4ml/codegen/cpp/hls_model.py +252 -0
  7. da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
  8. da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
  9. da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
  10. da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
  11. da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
  12. da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
  13. da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
  14. da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
  15. da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
  16. da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
  17. da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
  18. da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
  19. da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
  20. da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
  21. da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
  22. da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
  23. da4ml/codegen/cpp/source/binder_util.hh +56 -0
  24. da4ml/codegen/cpp/source/build_binder.mk +24 -0
  25. da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
  26. da4ml/codegen/verilog/__init__.py +2 -3
  27. da4ml/codegen/verilog/comb.py +65 -24
  28. da4ml/codegen/verilog/io_wrapper.py +36 -141
  29. da4ml/codegen/verilog/source/binder_util.hh +72 -0
  30. da4ml/codegen/verilog/source/mux.v +58 -0
  31. da4ml/codegen/verilog/source/negative.v +28 -0
  32. da4ml/codegen/verilog/source/shift_adder.v +4 -1
  33. da4ml/codegen/verilog/source/template.xdc +3 -0
  34. da4ml/codegen/verilog/verilog_model.py +36 -12
  35. da4ml/converter/__init__.py +0 -0
  36. da4ml/converter/hgq2/parser.py +105 -0
  37. da4ml/converter/hgq2/replica.py +383 -0
  38. da4ml/trace/__init__.py +2 -2
  39. da4ml/trace/fixed_variable.py +175 -16
  40. da4ml/trace/fixed_variable_array.py +109 -4
  41. da4ml/trace/ops/__init__.py +22 -6
  42. da4ml/trace/ops/conv_utils.py +147 -15
  43. da4ml/trace/ops/einsum_utils.py +9 -6
  44. da4ml/trace/ops/reduce_utils.py +103 -0
  45. da4ml/trace/pipeline.py +36 -34
  46. da4ml/trace/tracer.py +37 -7
  47. da4ml-0.3.0.post1.dist-info/METADATA +107 -0
  48. da4ml-0.3.0.post1.dist-info/RECORD +64 -0
  49. da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
  50. da4ml-0.2.1.dist-info/METADATA +0 -65
  51. da4ml-0.2.1.dist-info/RECORD +0 -39
  52. da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
  53. {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/WHEEL +0 -0
  54. {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/licenses/LICENSE +0 -0
  55. {da4ml-0.2.1.dist-info → da4ml-0.3.0.post1.dist-info}/top_level.txt +0 -0
da4ml/codegen/cpp/source/ap_types/hls_math.h
@@ -0,0 +1,27 @@
1
+ #ifndef X_HLS_MATH_H
2
+ #define X_HLS_MATH_H
3
+
4
+ #include <cmath>
5
+ #include "ap_fixed.h"
6
+
7
+ namespace hls {
8
+
9
+ template<class T>
10
+ static T exp(const T x) {
11
+ return (T) std::exp(x.to_double());
12
+ }
13
+
14
+ template <typename T> T sin(T x) { return (T) std::sin(x.to_double()); };
15
+
16
+ template <typename T> T cos(T x) { return (T) std::cos(x.to_double()); };
17
+
18
+ template <typename T> T asin(T x) { return (T) std::asin(x.to_double()); };
19
+
20
+ template <typename T> T acos(T x) { return (T) std::acos(x.to_double()); };
21
+
22
+ template <typename T> T atan(T x) { return (T) std::atan(x.to_double()); };
23
+
24
+ template <typename T> T atan2(T y, T x) { return (T) std::atan2(y.to_double(), x.to_double()); }; // use std::atan2: hls::atan2 on doubles would re-instantiate this template, and double has no to_double()
25
+
26
+ }
27
+ #endif
da4ml/codegen/cpp/source/ap_types/hls_stream.h
@@ -0,0 +1,263 @@
1
+ /*
2
+ #- (c) Copyright 2011-2018 Xilinx, Inc. All rights reserved.
3
+ #-
4
+ #- This file contains confidential and proprietary information
5
+ #- of Xilinx, Inc. and is protected under U.S. and
6
+ #- international copyright and other intellectual property
7
+ #- laws.
8
+ #-
9
+ #- DISCLAIMER
10
+ #- This disclaimer is not a license and does not grant any
11
+ #- rights to the materials distributed herewith. Except as
12
+ #- otherwise provided in a valid license issued to you by
13
+ #- Xilinx, and to the maximum extent permitted by applicable
14
+ #- law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
15
+ #- WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
16
+ #- AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
17
+ #- BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
18
+ #- INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
19
+ #- (2) Xilinx shall not be liable (whether in contract or tort,
20
+ #- including negligence, or under any other theory of
21
+ #- liability) for any loss or damage of any kind or nature
22
+ #- related to, arising under or in connection with these
23
+ #- materials, including for any direct, or any indirect,
24
+ #- special, incidental, or consequential loss or damage
25
+ #- (including loss of data, profits, goodwill, or any type of
26
+ #- loss or damage suffered as a result of any action brought
27
+ #- by a third party) even if such damage or loss was
28
+ #- reasonably foreseeable or Xilinx had been advised of the
29
+ #- possibility of the same.
30
+ #-
31
+ #- CRITICAL APPLICATIONS
32
+ #- Xilinx products are not designed or intended to be fail-
33
+ #- safe, or for use in any application requiring fail-safe
34
+ #- performance, such as life-support or safety devices or
35
+ #- systems, Class III medical devices, nuclear facilities,
36
+ #- applications related to the deployment of airbags, or any
37
+ #- other applications that could lead to death, personal
38
+ #- injury, or severe property or environmental damage
39
+ #- (individually and collectively, "Critical
40
+ #- Applications"). Customer assumes the sole risk and
41
+ #- liability of any use of Xilinx products in Critical
42
+ #- Applications, subject only to applicable laws and
43
+ #- regulations governing limitations on product liability.
44
+ #-
45
+ #- THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
46
+ #- PART OF THIS FILE AT ALL TIMES.
47
+ #- ************************************************************************
48
+
49
+
50
+ Licensed under the Apache License, Version 2.0 (the "License");
51
+ you may not use this file except in compliance with the License.
52
+ You may obtain a copy of the License at
53
+
54
+ http://www.apache.org/licenses/LICENSE-2.0
55
+
56
+ Unless required by applicable law or agreed to in writing, software
57
+ distributed under the License is distributed on an "AS IS" BASIS,
58
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
59
+ See the License for the specific language governing permissions and
60
+ limitations under the License.
61
+ */
62
+
63
+ #ifndef X_HLS_STREAM_SIM_H
64
+ #define X_HLS_STREAM_SIM_H
65
+
66
+ /*
67
+ * This file contains a C++ model of hls::stream.
68
+ * It defines C simulation model.
69
+ */
70
+ #ifndef __cplusplus
71
+
72
+ #error C++ is required to include this header file
73
+
74
+ #else
75
+
76
+ //////////////////////////////////////////////
77
+ // C level simulation models for hls::stream
78
+ //////////////////////////////////////////////
79
+ #include <queue>
80
+ #include <iostream>
81
+ #include <typeinfo>
82
+ #include <string>
83
+ #include <sstream>
84
+
85
+ #ifdef HLS_STREAM_THREAD_SAFE
86
+ #include <mutex>
87
+ #include <condition_variable>
88
+ #endif
89
+
90
+ #ifndef _MSC_VER
91
+ #include <cxxabi.h>
92
+ #include <stdlib.h>
93
+ #endif
94
+
95
+ namespace hls {
96
+
97
+ template<typename __STREAM_T__>
98
+ class stream
99
+ {
100
+ protected:
101
+ std::string _name;
102
+ std::deque<__STREAM_T__> _data; // container for the elements
103
+ #ifdef HLS_STREAM_THREAD_SAFE
104
+ std::mutex _mutex;
105
+ std::condition_variable _condition_var;
106
+ #endif
107
+
108
+ public:
109
+ /// Constructors
110
+ // Keep consistent with the synthesis model's constructors
111
+ stream() {
112
+ static unsigned _counter = 1;
113
+ std::stringstream ss;
114
+ #ifndef _MSC_VER
115
+ char* _demangle_name = abi::__cxa_demangle(typeid(*this).name(), 0, 0, 0);
116
+ if (_demangle_name) {
117
+ _name = _demangle_name;
118
+ free(_demangle_name);
119
+ }
120
+ else {
121
+ _name = "hls_stream";
122
+ }
123
+ #else
124
+ _name = typeid(*this).name();
125
+ #endif
126
+
127
+ ss << _counter++;
128
+ _name += "." + ss.str();
129
+ }
130
+
131
+ stream(const std::string name) {
132
+ // named constructor: the given name is used verbatim
133
+ // (there is no capacity limit: _data is an unbounded std::deque)
134
+ _name = name;
135
+ }
136
+
137
+ /// Make copy constructor and assignment operator private
138
+ private:
139
+ stream(const stream< __STREAM_T__ >& chn):
140
+ _name(chn._name), _data(chn._data) {
141
+ }
142
+
143
+ stream& operator = (const stream< __STREAM_T__ >& chn) {
144
+ _name = chn._name;
145
+ _data = chn._data;
146
+ return *this;
147
+ }
148
+
149
+ public:
150
+ /// Overload >> and << operators to implement read() and write()
151
+ void operator >> (__STREAM_T__& rdata) {
152
+ read(rdata);
153
+ }
154
+
155
+ void operator << (const __STREAM_T__& wdata) {
156
+ write(wdata);
157
+ }
158
+
159
+
160
+ public:
161
+ /// Destructor
162
+ /// Check status of the queue
163
+ virtual ~stream() {
164
+ if (!_data.empty())
165
+ {
166
+ std::cout << "WARNING: Hls::stream '"
167
+ << _name
168
+ << "' contains leftover data,"
169
+ << " which may result in RTL simulation hanging."
170
+ << std::endl;
171
+ }
172
+ }
173
+
174
+ /// Status of the queue
175
+ bool empty() {
176
+ #ifdef HLS_STREAM_THREAD_SAFE
177
+ std::lock_guard<std::mutex> lg(_mutex);
178
+ #endif
179
+ return _data.empty();
180
+ }
181
+
182
+ bool full() const { return false; } // unbounded model: never reports full
183
+
184
+ /// Read one element (waits for data only when HLS_STREAM_THREAD_SAFE is defined)
185
+ void read(__STREAM_T__& head) {
186
+ head = read();
187
+ }
188
+
189
+ #ifdef HLS_STREAM_THREAD_SAFE
190
+ __STREAM_T__ read() {
191
+ std::unique_lock<std::mutex> ul(_mutex);
192
+ while (_data.empty()) {
193
+ _condition_var.wait(ul);
194
+ }
195
+
196
+ __STREAM_T__ elem;
197
+ elem = _data.front();
198
+ _data.pop_front();
199
+ return elem;
200
+ }
201
+ #else
202
+ __STREAM_T__ read() { // does not block: warns and returns __STREAM_T__() when empty
203
+ __STREAM_T__ elem;
204
+ if (_data.empty()) {
205
+ std::cout << "WARNING: Hls::stream '"
206
+ << _name
207
+ << "' is read while empty,"
208
+ << " which may result in RTL simulation hanging."
209
+ << std::endl;
210
+ elem = __STREAM_T__();
211
+ } else {
212
+ elem = _data.front();
213
+ _data.pop_front();
214
+ }
215
+ return elem;
216
+ }
217
+ #endif
218
+
219
+ /// Write one element (never waits for space: the deque is unbounded)
220
+ void write(const __STREAM_T__& tail) {
221
+ #ifdef HLS_STREAM_THREAD_SAFE
222
+ std::unique_lock<std::mutex> ul(_mutex);
223
+ #endif
224
+ _data.push_back(tail);
225
+ #ifdef HLS_STREAM_THREAD_SAFE
226
+ _condition_var.notify_one();
227
+ #endif
228
+ }
229
+
230
+ /// Nonblocking read
231
+ bool read_nb(__STREAM_T__& head) {
232
+ #ifdef HLS_STREAM_THREAD_SAFE
233
+ std::lock_guard<std::mutex> lg(_mutex);
234
+ #endif
235
+ bool is_empty = _data.empty();
236
+ if (is_empty) {
237
+ head = __STREAM_T__();
238
+ } else {
239
+ __STREAM_T__ elem(_data.front());
240
+ _data.pop_front();
241
+ head = elem;
242
+ }
243
+ return !is_empty;
244
+ }
245
+
246
+ /// Nonblocking write (full() is always false, so this always writes and returns true)
247
+ bool write_nb(const __STREAM_T__& tail) {
248
+ bool is_full = full();
249
+ write(tail);
250
+ return !is_full;
251
+ }
252
+
253
+ /// Fifo size
254
+ size_t size() { // NOTE(review): reads _data without taking _mutex, unlike empty()/read_nb()
255
+ return _data.size();
256
+ }
257
+ };
258
+
259
+ } // namespace hls
260
+
261
+ #endif // __cplusplus
262
+ #endif // X_HLS_STREAM_SIM_H
263
+
da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h
@@ -0,0 +1,80 @@
1
+ #ifndef X_HLS_UTILS_H
2
+ #define X_HLS_UTILS_H
3
+ #include "ap_fixed.h"
4
+ #include <limits>
5
+
6
+ namespace hls {
7
+
8
+ template<typename T>
9
+ class numeric_limits { // primary template: forwards to std::numeric_limits<T>
10
+ public:
11
+ static T max() { return std::numeric_limits<T>::max(); }
12
+ static T min() { return std::numeric_limits<T>::min(); }
13
+ static T epsilon() { return std::numeric_limits<T>::epsilon(); }
14
+ };
15
+
16
+ template <int W, int I, ap_q_mode Q, ap_o_mode O>
17
+ class numeric_limits<ap_fixed<W,I,Q,O> > { // max/min copy the raw bit patterns of numeric_limits<ap_int<W> >
18
+ public:
19
+ static ap_fixed<W,I,Q,O> max() {
20
+ ap_int<W> m = ::hls::numeric_limits<ap_int<W> >::max();
21
+ ap_fixed<W,I,Q,O> x;
22
+ x(W-1,0) = m(W-1,0);
23
+ return x;
24
+ }
25
+ static ap_fixed<W,I,Q,O> min() { // most negative value (only the MSB set), not the smallest positive value
26
+ ap_int<W> m = ::hls::numeric_limits<ap_int<W> >::min();
27
+ ap_fixed<W,I,Q,O> x;
28
+ x(W-1,0) = m(W-1,0);
29
+ return x;
30
+ }
31
+ static ap_fixed<W,I,Q,O> epsilon() { // one LSB: only bit 0 set
32
+ ap_fixed<W,I,Q,O> x = 0;
33
+ x[0] = 1;
34
+ return x;
35
+ }
36
+ };
37
+
38
+ template <int W, int I, ap_q_mode Q, ap_o_mode O>
39
+ class numeric_limits<ap_ufixed<W,I,Q,O> > { // unsigned: min() is 0, max() has all W bits set
40
+ public:
41
+ static ap_ufixed<W,I,Q,O> max() {
42
+ ap_uint<W> m = ::hls::numeric_limits<ap_uint<W> >::max();
43
+ ap_ufixed<W,I,Q,O> x;
44
+ x(W-1,0) = m(W-1,0);
45
+ return x;
46
+ }
47
+ static ap_ufixed<W,I,Q,O> min() { return 0; }
48
+ static ap_ufixed<W,I,Q,O> epsilon() {
49
+ ap_ufixed<W,I,Q,O> x = 0;
50
+ x[0] = 1;
51
+ return x;
52
+ }
53
+ };
54
+
55
+ template <int W>
56
+ class numeric_limits<ap_int<W> > { // min(): only the sign bit set; max(): its bitwise complement
57
+ public:
58
+ static ap_int<W> max() { ap_int<W> m = min(); return ~m; }
59
+ static ap_int<W> min() { ap_int<W> m = 0; m[W-1] = 1; return m; }
60
+ static ap_int<W> epsilon() {
61
+ ap_int<W> x = 0;
62
+ x[0] = 1;
63
+ return x;
64
+ }
65
+ };
66
+
67
+ template <int W>
68
+ class numeric_limits<ap_uint<W> > { // max(): all W bits set; min(): 0
69
+ public:
70
+ static ap_uint<W> max() { ap_uint<W> zero = 0; return ~zero; }
71
+ static ap_uint<W> min() { return 0; }
72
+ static ap_uint<W> epsilon() {
73
+ ap_uint<W> x = 0;
74
+ x[0] = 1;
75
+ return x;
76
+ }
77
+ };
78
+ }
79
+
80
+ #endif
da4ml/codegen/cpp/source/binder_util.hh
@@ -0,0 +1,56 @@
1
+ #pragma once
2
+ #include <cstddef>
3
+
4
+ #ifdef _OPENMP
5
+ #include <algorithm>
6
+ #include <omp.h>
7
+ constexpr bool _openmp = true; // compile-time flag: this translation unit was built with OpenMP
8
+ #else
9
+ constexpr bool _openmp = false; // compile-time flag: built without OpenMP (batch_inference then runs sequentially)
10
+ #endif
11
+
12
+ template <typename CONFIG_T, typename T> void _inference(T *c_inp, T *c_out, size_t n_samples)
13
+ {
14
+ // Staging buffers in the model's own input/output types, reused for every sample.
15
+ typename CONFIG_T::inp_t inp_buf[CONFIG_T::N_inp];
16
+ typename CONFIG_T::out_t out_buf[CONFIG_T::N_out];
17
+ for(size_t s = 0; s < n_samples; ++s)
18
+ {
19
+ // Sample s occupies a contiguous run of N_inp inputs and N_out outputs.
20
+ T *src = c_inp + s * CONFIG_T::N_inp;
21
+ T *dst = c_out + s * CONFIG_T::N_out;
22
+
23
+ // Convert T -> inp_t element by element.
24
+ for(size_t k = 0; k < CONFIG_T::N_inp; ++k)
25
+ inp_buf[k] = src[k];
26
+
27
+ CONFIG_T::f(inp_buf, out_buf);
28
+
29
+ // Convert out_t -> T element by element.
30
+ for(size_t k = 0; k < CONFIG_T::N_out; ++k)
31
+ dst[k] = out_buf[k];
32
+ }
33
+ }
34
+
35
+ template <typename CONFIG_T, typename T> void batch_inference(T *c_inp, T *c_out, size_t n_samples)
36
+ { // With OpenMP: split into chunks of max(n_samples / max_threads, 32) samples (last may be shorter), one thread per chunk; otherwise run sequentially.
37
+ #ifdef _OPENMP
38
+ size_t n_max_threads = omp_get_max_threads();
39
+ size_t n_samples_per_thread = std::max<size_t>(n_samples / n_max_threads, 32);
40
+ size_t n_thread = n_samples / n_samples_per_thread;
41
+ n_thread += (n_samples % n_samples_per_thread) ? 1 : 0;
42
+ if(n_thread == 0) return; // empty batch: the num_threads clause requires a positive value
43
+ #pragma omp parallel for num_threads(n_thread) schedule(static)
44
+ for(size_t i = 0; i < n_thread; ++i)
45
+ {
46
+ size_t start = i * n_samples_per_thread;
47
+ size_t end = std::min<size_t>(start + n_samples_per_thread, n_samples);
48
+ size_t n_samples_this_thread = end - start;
49
+ size_t offset_in = start * CONFIG_T::N_inp;
50
+ size_t offset_out = start * CONFIG_T::N_out;
51
+ _inference<CONFIG_T, T>(&c_inp[offset_in], &c_out[offset_out], n_samples_this_thread);
52
+ }
53
+ #else
54
+ _inference<CONFIG_T, T>(c_inp, c_out, n_samples);
55
+ #endif
56
+ }
da4ml/codegen/cpp/source/build_binder.mk
@@ -0,0 +1,24 @@
1
+ default: slow
2
+ CXX = g++
3
+ CC = gcc
4
+ INCLUDES = -I ap_types -I .
5
+ CXXFLAGS = -std=c++17 -fPIC
6
+ CFLAGS = -std=c++17 -fPIC
7
+ LIBNAME = lib$(PRJ_NAME)_$(STAMP).so
8
+ # Usage: make [fast|slow] PRJ_NAME=<name> STAMP=<stamp> [EXTRA_CXXFLAGS=...]  (C++ sources: built with $(CXX)/$(CXXFLAGS))
9
+ fast: CXXFLAGS += -O3
10
+ fast: $(LIBNAME)
11
+
12
+ slow: CXXFLAGS += -O
13
+ slow: $(LIBNAME)
14
+
15
+ $(PRJ_NAME)_$(STAMP).o: $(PRJ_NAME).cc
16
+ $(CXX) -c $(PRJ_NAME).cc -o $(PRJ_NAME)_$(STAMP).o $(INCLUDES) $(CXXFLAGS) $(EXTRA_CXXFLAGS)
17
+
18
+ $(LIBNAME): $(PRJ_NAME)_$(STAMP).o $(PRJ_NAME)_bridge.cc
19
+ $(CXX) $(INCLUDES) $(CXXFLAGS) -shared -o $@ $(PRJ_NAME)_$(STAMP).o $(PRJ_NAME)_bridge.cc $(EXTRA_CXXFLAGS)
20
+
21
+ clean:
22
+ rm -f $(LIBNAME) $(PRJ_NAME)_$(STAMP).o
23
+
24
+ .PHONY: clean default fast slow
da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh}
@@ -1,5 +1,5 @@
1
1
  #pragma once
2
- #include "ap_fixed.h"
2
+ #include "ap_types/ap_fixed.h"
3
3
 
4
4
  template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N> ap_fixed<b, i + s> bit_shift(ap_fixed<b, i, Q, O, N> x) {
5
5
  #pragma HLS INLINE
da4ml/codegen/verilog/__init__.py
@@ -1,13 +1,12 @@
1
1
  from .comb import comb_logic_gen
2
- from .io_wrapper import comb_binder_gen, generate_io_wrapper, pipeline_binder_gen
2
+ from .io_wrapper import binder_gen, generate_io_wrapper
3
3
  from .pipeline import pipeline_logic_gen
4
4
  from .verilog_model import VerilogModel
5
5
 
6
6
  __all__ = [
7
7
  'comb_logic_gen',
8
8
  'generate_io_wrapper',
9
- 'comb_binder_gen',
10
9
  'pipeline_logic_gen',
11
- 'pipeline_binder_gen',
10
+ 'binder_gen',  # exported in place of the removed comb_binder_gen / pipeline_binder_gen
12
11
  'VerilogModel',
13
12
  ]
da4ml/codegen/verilog/comb.py
@@ -2,10 +2,11 @@ from math import ceil, log2
2
2
 
3
3
  import numpy as np
4
4
 
5
- from da4ml.cmvm.types import Op, QInterval, Solution, _minimal_kif
5
+ from ...cmvm.types import QInterval, Solution, _minimal_kif
6
6
 
7
7
 
8
- def ssa_gen(ops: list[Op], print_latency: bool = False):
8
+ def ssa_gen(sol: Solution, neg_defined: set[int], print_latency: bool = False):
9
+ ops = sol.ops
9
10
  kifs = list(map(_minimal_kif, (op.qint for op in ops)))
10
11
  widths = list(map(sum, kifs))
11
12
  inp_kifs = [_minimal_kif(op.qint) for op in ops if op.opcode == -1]
@@ -14,11 +15,17 @@ def ssa_gen(ops: list[Op], print_latency: bool = False):
14
15
  inp_idxs = np.stack([_inp_widths[1:] - 1, _inp_widths[:-1]], axis=1)
15
16
 
16
17
  lines = []
18
+ ref_count = sol.ref_count
17
19
 
18
20
  for i, op in enumerate(ops):
21
+ if ref_count[i] == 0:
22
+ continue
23
+
19
24
  bw = widths[i]
20
- v = f'v{i}[{bw-1}:0]'
21
- _def = f'wire [{bw-1}:0] v{i};'
25
+ v = f'v{i}[{bw - 1}:0]'
26
+ _def = f'wire [{bw - 1}:0] v{i};'
27
+ if bw == 0:
28
+ continue
22
29
 
23
30
  match op.opcode:
24
31
  case -1: # Input marker
@@ -34,12 +41,16 @@ def ssa_gen(ops: list[Op], print_latency: bool = False):
34
41
  if op.opcode == -2:
35
42
  _min, _max, step = ops[op.id0].qint
36
43
  bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
37
- lines.append(
38
- f'wire [{bw_neg-1}:0] v{op.id0}_neg; assign v{op.id0}_neg[{bw_neg-1}:0] = -{v0_name}[{bw0-1}:0];'
39
- )
44
+ if op.id0 not in neg_defined:
45
+ neg_defined.add(op.id0)
46
+ was_signed = int(kifs[op.id0][0])
47
+ lines.append(
48
+ f'wire [{bw_neg - 1}:0] v{op.id0}_neg; negative #({bw0}, {bw_neg}, {was_signed}) op_neg_{op.id0} ({v0_name}, v{op.id0}_neg);'
49
+ )
50
+ bw0 = bw_neg
40
51
  v0_name = f'v{op.id0}_neg'
41
52
  if ops[op.id0].qint.min < 0:
42
- line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}] & {{{bw}{{~{v0_name}[{bw0-1}]}}}};'
53
+ line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}] & {{{bw}{{~{v0_name}[{bw0 - 1}]}}}};'
43
54
  else:
44
55
  line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
45
56
  case 3 | -3: # Explicit quantization
@@ -50,23 +61,31 @@ def ssa_gen(ops: list[Op], print_latency: bool = False):
50
61
 
51
62
  if op.opcode == -3:
52
63
  _min, _max, step = ops[op.id0].qint
64
+ lines.append('/* verilator lint_off WIDTHTRUNC */')
53
65
  bw_neg = max(sum(_minimal_kif(QInterval(-_max, -_min, step))), bw0)
54
- lines.append(
55
- f'wire [{bw_neg-1}:0] v{op.id0}_neg; assign v{op.id0}_neg[{bw_neg-1}:0] = -{v0_name}[{bw0-1}:0];'
56
- )
66
+ if op.id0 not in neg_defined:
67
+ neg_defined.add(op.id0)
68
+ # lines.append('/* verilator lint_off WIDTHTRUNC */')
69
+ # lines.append(
70
+ # f'wire [{bw_neg - 1}:0] v{op.id0}_neg; assign v{op.id0}_neg[{bw_neg - 1}:0] = -{v0_name}[{bw0 - 1}:0];'
71
+ # )
72
+ # lines.append('/* verilator lint_on WIDTHTRUNC */')
73
+ was_signed = int(kifs[op.id0][0])
74
+ lines.append(
75
+ f'wire [{bw_neg - 1}:0] v{op.id0}_neg; negative #({bw0}, {bw_neg}, {was_signed}) op_neg_{op.id0} ({v0_name}, v{op.id0}_neg);'
76
+ )
57
77
  v0_name = f'v{op.id0}_neg'
58
78
 
59
79
  line = f'{_def} assign {v} = {v0_name}[{i0}:{i1}];'
60
80
  case 4: # constant addition
61
81
  num = op.data
62
82
  sign, mag = int(num < 0), abs(num)
63
- line = f"{_def} assign {v} = '{bin(mag)[1:]};"
64
83
  bw1 = ceil(log2(mag + 1))
65
84
  bw0 = widths[op.id0]
66
85
  s0 = int(kifs[op.id0][0])
67
- v0 = f'v{op.id0}[{bw0-1}:0]'
86
+ v0 = f'v{op.id0}[{bw0 - 1}:0]'
68
87
  v1 = f"'{bin(mag)[1:]}"
69
- shift = int(log2(op.qint.step / ops[op.id0].qint.step))
88
+ shift = kifs[op.id0][2] - kifs[i][2]
70
89
  line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, 0, {bw}, {shift}, {sign}) op_{i} ({v0}, {v1}, {v});'
71
90
  case 5: # constant
72
91
  num = op.data
@@ -80,9 +99,22 @@ def ssa_gen(ops: list[Op], print_latency: bool = False):
80
99
  bw0, bw1 = widths[op.id0], widths[op.id1] # width
81
100
  s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
82
101
  shift = op.data + f0 - f1
83
- v0, v1 = f'v{op.id0}[{bw0-1}:0]', f'v{op.id1}[{bw1-1}:0]'
102
+ v0, v1 = f'v{op.id0}[{bw0 - 1}:0]', f'v{op.id1}[{bw1 - 1}:0]'
84
103
 
85
104
  line = f'{_def} shift_adder #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {op.opcode}) op_{i} ({v0}, {v1}, {v});'
105
+
106
+ case 6 | -6: # MSB Muxing
107
+ k, a, b = op.data & 0xFFFFFFFF, op.id0, op.id1
108
+ p0, p1 = kifs[a], kifs[b]
109
+ inv = '1' if op.opcode == -6 else '0'
110
+ bwk, bw0, bw1 = widths[k], widths[a], widths[b]
111
+ s0, f0, s1, f1 = int(p0[0]), p0[2], int(p1[0]), p1[2]
112
+ _shift = (op.data >> 32) & 0xFFFFFFFF
113
+ _shift = _shift if _shift < 0x80000000 else _shift - 0x100000000
114
+ shift = f0 - f1 + _shift
115
+ vk, v0, v1 = f'v{k}[{bwk - 1}]', f'v{a}[{bw0 - 1}:0]', f'v{b}[{bw1 - 1}:0]'
116
+
117
+ line = f'{_def} mux #({bw0}, {bw1}, {s0}, {s1}, {bw}, {shift}, {inv}) op_{i} ({vk}, {v0}, {v1}, {v});'
86
118
  case _:
87
119
  raise ValueError(f'Unknown opcode {op.opcode} for operation {i} ({op})')
88
120
 
@@ -92,7 +124,7 @@ def ssa_gen(ops: list[Op], print_latency: bool = False):
92
124
  return lines
93
125
 
94
126
 
95
- def output_gen(sol: Solution):
127
+ def output_gen(sol: Solution, neg_defined: set[int]):
96
128
  lines = []
97
129
  widths = list(map(sum, map(_minimal_kif, sol.out_qint)))
98
130
  _widths = np.cumsum([0] + widths)
@@ -101,13 +133,21 @@ def output_gen(sol: Solution):
101
133
  if idx < 0:
102
134
  continue
103
135
  i0, i1 = out_idxs[i]
136
+ if i0 == i1 - 1:
137
+ continue
104
138
  bw = widths[i]
105
- bw0 = sum(_minimal_kif(sol.ops[idx].qint))
106
139
  if sol.out_negs[i]:
107
- lines.append(f'wire [{bw-1}:0] out_neg{i}; assign out_neg{i} = -v{idx}[{bw0-1}:0];')
108
- lines.append(f'assign out[{i0}:{i1}] = out_neg{i}[{bw-1}:0];')
140
+ if idx not in neg_defined:
141
+ neg_defined.add(idx)
142
+ bw0 = sum(_minimal_kif(sol.ops[idx].qint))
143
+ was_signed = int(sol.ops[idx].qint[0] < 0)
144
+ lines.append(
145
+ f'wire [{bw - 1}:0] v{idx}_neg; negative #({bw0}, {bw}, {was_signed}) op_neg_{idx} (v{idx}, v{idx}_neg);'
146
+ )
147
+ lines.append(f'assign out[{i0}:{i1}] = v{idx}_neg[{bw - 1}:0];')
148
+
109
149
  else:
110
- lines.append(f'assign out[{i0}:{i1}] = v{idx}[{bw-1}:0];')
150
+ lines.append(f'assign out[{i0}:{i1}] = v{idx}[{bw - 1}:0];')
111
151
  return lines
112
152
 
113
153
 
@@ -117,13 +157,14 @@ def comb_logic_gen(sol: Solution, fn_name: str, print_latency: bool = False, tim
117
157
 
118
158
  fn_signature = [
119
159
  f'module {fn_name} (',
120
- f' input [{inp_bits-1}:0] inp,',
121
- f' output [{out_bits-1}:0] out',
160
+ f' input [{inp_bits - 1}:0] inp,',
161
+ f' output [{out_bits - 1}:0] out',
122
162
  ');',
123
163
  ]
124
164
 
125
- ssa_lines = ssa_gen(sol.ops, print_latency=print_latency)
126
- output_lines = output_gen(sol)
165
+ neg_defined = set()
166
+ ssa_lines = ssa_gen(sol, neg_defined=neg_defined, print_latency=print_latency)
167
+ output_lines = output_gen(sol, neg_defined)
127
168
 
128
169
  indent = ' '
129
170
  base_indent = '\n'