da4ml-0.5.1.post1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +204 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +246 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.1.post1.dist-info/METADATA +85 -0
  92. da4ml-0.5.1.post1.dist-info/RECORD +96 -0
  93. da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
  94. da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/cmvm/api.py ADDED
@@ -0,0 +1,264 @@
+ from math import ceil, log2
+ from typing import TypedDict
+
+ import numpy as np
+ from numba import jit, prange
+
+ from .core import _solve, create_state, to_solution
+ from .types import Pipeline, QInterval
+ from .util import kernel_decompose
+
+
+ @jit(cache=True)
+ def minimal_latency(
+     kernel: np.ndarray,
+     qintervals: list[QInterval],
+     latencies: list[float],
+     carry_size: int = -1,
+     adder_size: int = -1,
+ ):
+     """Fast latency calculation for a given kernel, QIntervals, and input latencies.
+     When carry_size=-1 and the input latency is a constant `l`, this is the same as
+     `l + max(ceil(log2(max(#CSD bits for each column, 1))))`.
+
+     Parameters
+     ----------
+     kernel : np.ndarray
+         The input kernel matrix.
+     qintervals : list[QInterval]
+         List of QIntervals for each input.
+     latencies : list[float]
+         List of latencies for each input.
+     carry_size : int, optional
+         The size of the carry unit for latency computation, by default -1 (fixed latency for each addition operation)
+     adder_size : int, optional
+         The size of the adder unit for latency computation, by default -1 (fixed cost for each addition operation)
+
+     Returns
+     -------
+     float
+         The minimal latency for the given kernel, QIntervals, and input latencies.
+     """
+
+     state = create_state(kernel, qintervals, latencies, no_stat_init=True)
+     sol = to_solution(state, adder_size=adder_size, carry_size=carry_size)
+     latencies = [sol.ops[i].latency if i >= 0 else 0.0 for i in sol.out_idxs]
+     return max(latencies)
+
+
+ @jit(cache=True)
+ def jit_solve(
+     kernel: np.ndarray,
+     method0: str = 'wmc',
+     method1: str = 'auto',
+     hard_dc: int = -1,
+     decompose_dc: int = -2,
+     qintervals: list[QInterval] | None = None,
+     latencies: list[float] | None = None,
+     adder_size: int = -1,
+     carry_size: int = -1,
+ ) -> Pipeline:
+     """Optimized implementation of a CMVM computation as a cascade of two matrices.
+
+     Parameters
+     ----------
+     kernel : np.ndarray
+         The input kernel matrix to be implemented.
+     method0 : str, optional
+         Optimization method for the first stage. Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`].
+     method1 : str, optional
+         Optimization method for the second stage. When 'auto', it is selected based on hard_dc and method0, by default 'auto'
+     hard_dc : int, optional
+         Hard depth constraint (additional latency allowed beyond the minimal latency), by default -1 (no constraint)
+     decompose_dc : int, optional
+         Decomposition depth constraint, by default -2 (derived from hard_dc and the kernel size)
+     qintervals : list[QInterval] | None, optional
+         List of quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
+     latencies : list[float] | None, optional
+         List of input latencies, by default None (0. for all inputs)
+     adder_size : int, optional
+         Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
+     carry_size : int, optional
+         Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)
+
+     Returns
+     -------
+     Pipeline
+         A pipeline containing the optimized implementation of the CMVM computation with cascaded stages.
+     """
+
+     if hard_dc < 0:
+         hard_dc = int(1e9)
+
+     if method1 == 'auto':
+         if hard_dc >= 6 or method0.endswith('dc'):
+             method1 = method0
+         else:
+             method1 = method0 + '-dc'
+     if hard_dc == 0 and not method0.endswith('dc'):
+         method0 = method0 + '-dc'
+
+     if qintervals is None:
+         _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+     else:
+         _qintervals = list(qintervals)
+     if latencies is None:
+         _inp_latencies = [0.0] * kernel.shape[0]
+     else:
+         _inp_latencies = [float(lat) for lat in latencies]
+     assert len(_qintervals) == kernel.shape[0]
+     assert len(_inp_latencies) == kernel.shape[0]
+
+     min_lat = minimal_latency(kernel, _qintervals, _inp_latencies, carry_size=carry_size, adder_size=adder_size)
+     latency_allowed = hard_dc + min_lat
+     if decompose_dc == -2:
+         decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
+     else:
+         decompose_dc = min(hard_dc, decompose_dc, ceil(log2(kernel.shape[0])))
+
+     while True:
+         if decompose_dc < 0 and hard_dc >= 0:
+             if method0 != 'dummy':
+                 method0, method1 = 'wmc-dc', 'wmc-dc'
+             else:
+                 method0, method1 = 'dummy', 'dummy'
+         mat0, mat1 = kernel_decompose(kernel, dc=decompose_dc)
+         sol0 = _solve(
+             mat0, method=method0, qintervals=_qintervals, latencies=_inp_latencies, adder_size=adder_size, carry_size=carry_size
+         )
+         latencies0 = [sol0.ops[i].latency if i >= 0 else 0.0 for i in sol0.out_idxs]
+         qintervals0 = [sol0.ops[i].qint if i >= 0 else QInterval(0.0, 0.0, np.inf) for i in sol0.out_idxs]
+         if max(latencies0) > latency_allowed:
+             if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
+                 decompose_dc -= 1
+                 continue
+         sol1 = _solve(
+             mat1, method=method1, qintervals=qintervals0, latencies=latencies0, adder_size=adder_size, carry_size=carry_size
+         )
+         latencies1 = [sol1.ops[i].latency if i >= 0 else 0.0 for i in sol1.out_idxs]
+         if max(latencies1) > latency_allowed:
+             # Prevent infinite loop, shouldn't happen though
+             if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
+                 decompose_dc -= 1
+                 continue
+         break
+     if max(latencies1) > latency_allowed:
+         # When latency depends on the bw, may happen
+         print(f'Latency constraint not satisfied: {int(latency_allowed)} < {int(max(latencies1))}')
+     return Pipeline((sol0, sol1))
+
+
+ @jit(cache=True, parallel=True)
+ def solve(
+     kernel: np.ndarray,
+     method0: str = 'wmc',
+     method1: str = 'auto',
+     hard_dc: int = -1,
+     decompose_dc: int = -2,
+     qintervals: list[QInterval] | None = None,
+     latencies: list[float] | None = None,
+     adder_size: int = -1,
+     carry_size: int = -1,
+     search_all_decompose_dc: bool = True,
+ ) -> Pipeline:
+     """Solve the CMVM problem as a cascade of two matrices.
+
+     Parameters
+     ----------
+     kernel : np.ndarray
+         The input kernel matrix to be implemented.
+     method0 : str, optional
+         Optimization method for the first stage. Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`].
+     method1 : str, optional
+         Optimization method for the second stage. When 'auto', it is selected based on hard_dc and method0, by default 'auto'
+     hard_dc : int, optional
+         Hard depth constraint (additional latency allowed beyond the minimal latency), by default -1 (no constraint)
+     decompose_dc : int, optional
+         Decomposition depth constraint, by default -2 (derived from hard_dc and the kernel size)
+     qintervals : list[QInterval] | None, optional
+         List of quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
+     latencies : list[float] | None, optional
+         List of input latencies, by default None (0. for all inputs)
+     adder_size : int, optional
+         Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
+     carry_size : int, optional
+         Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)
+     search_all_decompose_dc : bool, optional
+         If True, search over all possible decomposition depth constraints. If False, use the provided decompose_dc value.
+         Default is True.
+
+     Returns
+     -------
+     Pipeline
+         A pipeline containing the optimized implementation of the CMVM computation with cascaded stages.
+     """
+
+     if qintervals is None:
+         _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+     else:
+         _qintervals = list(qintervals)
+     if latencies is None:
+         _latencies = [0.0] * kernel.shape[0]
+     else:
+         _latencies = [float(lat) for lat in latencies]
+
+     if not search_all_decompose_dc:
+         return jit_solve(
+             kernel,
+             method0=method0,
+             method1=method1,
+             hard_dc=hard_dc,
+             decompose_dc=decompose_dc,
+             qintervals=_qintervals,
+             latencies=_latencies,
+             adder_size=adder_size,
+             carry_size=carry_size,
+         )
+
+     if hard_dc < 0:
+         hard_dc = int(1e9)
+
+     max_decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
+     try_decompose_dcs = list(range(-1, max_decompose_dc + 1))
+
+     costs = np.empty(len(try_decompose_dcs), dtype=np.float64)
+
+     for i in prange(len(try_decompose_dcs)):
+         decompose_dc = try_decompose_dcs[i]
+         _csol = jit_solve(
+             kernel,
+             method0=method0,
+             method1=method1,
+             hard_dc=hard_dc,
+             decompose_dc=decompose_dc,
+             qintervals=_qintervals,
+             latencies=_latencies,
+             adder_size=adder_size,
+             carry_size=carry_size,
+         )
+         _cost = sum([sum([op.cost for op in sol.ops]) for sol in _csol.solutions])
+         costs[i] = _cost
+
+     decompose_dc = try_decompose_dcs[np.argmin(costs)]
+     csol = jit_solve(
+         kernel,
+         method0=method0,
+         method1=method1,
+         hard_dc=hard_dc,
+         decompose_dc=decompose_dc,
+         qintervals=_qintervals,
+         latencies=_latencies,
+         adder_size=adder_size,
+         carry_size=carry_size,
+     )
+     return csol
+
+
+ class solver_options_t(TypedDict, total=False):
+     method0: str
+     method1: str
+     hard_dc: int
+     decompose_dc: int
+     adder_size: int
+     carry_size: int
+     search_all_decompose_dc: bool
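For orientation only (not part of the wheel contents): a minimal usage sketch of the `solve` entry point defined above. The kernel values and the `hard_dc` setting are illustrative assumptions; the function signature, the `QInterval(min, max, step)` constructor, and the `solutions`/`ops`/`cost` attributes are taken from the code shown here.

import numpy as np
from da4ml.cmvm.api import solve
from da4ml.cmvm.types import QInterval

# Example constant matrix: 3 inputs (rows) mapped to 2 outputs (columns).
kernel = np.array([[3.0, 5.0], [7.0, -9.0], [2.0, 6.0]])

# One quantization interval and one arrival latency per input row.
qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
latencies = [0.0] * kernel.shape[0]

pipeline = solve(kernel, qintervals=qintervals, latencies=latencies, hard_dc=2)

# Total adder cost over both cascaded stages, mirroring the cost metric used inside `solve`.
total_cost = sum(sum(op.cost for op in sol.ops) for sol in pipeline.solutions)
print(total_cost)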
da4ml/cmvm/core/__init__.py ADDED
@@ -0,0 +1,221 @@
+ import heapq
+ from math import log2
+
+ import numpy as np
+ from numba import jit
+
+ from ..types import CombLogic, DAState, Op, QInterval
+ from .indexers import (
+     idx_mc,
+     idx_mc_dc,
+     idx_wmc,
+     idx_wmc_dc,
+ )
+ from .state_opr import cost_add, create_state, qint_add, update_state
+
+
+ @jit(cache=True)
+ def cmvm(
+     kernel: np.ndarray,
+     method: str = 'wmc',
+     qintervals: list[QInterval] | None = None,
+     inp_latencies: list[float] | None = None,
+     adder_size: int = -1,
+     carry_size: int = -1,
+ ) -> DAState:
+     """Optimizes the kernel using the CMVM algorithm.
+
+     Parameters
+     ----------
+     kernel : np.ndarray
+         The kernel to optimize.
+     method : str, optional
+         Which indexing method to use, by default 'wmc' (weighted most common).
+         Must be one of [`mc`, `mc-dc`, `mc-pdc`, `wmc`, `wmc-dc`, `wmc-pdc`, `dummy`]
+     qintervals : list[QInterval] | None, optional
+         List of QIntervals for each input, by default None.
+         If None, defaults to [-128., 127., 1.] for each input.
+     inp_latencies : list[float] | None, optional
+         List of latencies for each input, by default None.
+         If None, defaults to 0. for each input.
+     adder_size : int, optional
+         The atomic size of the adder for cost computation, by default -1.
+         If -1, each adder can be arbitrarily large, and the cost is the number of adders.
+     carry_size : int, optional
+         The size of the carry unit for latency computation, by default -1.
+         If -1, each carry unit can be arbitrarily large, and the latency is the depth of the adder tree.
+
+     Returns
+     -------
+     DAState
+         The optimized kernel as a DAState object.
+     """
+
+     if qintervals is None:
+         _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+     else:
+         _qintervals = [QInterval(*qi) for qi in qintervals]
+     if inp_latencies is None:
+         _inp_latencies = [0.0] * kernel.shape[0]
+     else:
+         _inp_latencies = [float(lat) for lat in inp_latencies]
+     assert len(_qintervals) == kernel.shape[0]
+     assert len(_inp_latencies) == kernel.shape[0]
+
+     state = create_state(kernel, _qintervals, _inp_latencies)
+     while True:
+         if len(state.freq_stat) == 0:
+             break
+         match method:
+             case 'mc':
+                 pair_idx = idx_mc(state)
+             case 'mc-dc':
+                 pair_idx = idx_mc_dc(state, absolute=True)
+             case 'mc-pdc':
+                 pair_idx = idx_mc_dc(state, absolute=False)
+             case 'wmc':
+                 pair_idx = idx_wmc(state)
+             case 'wmc-dc':
+                 pair_idx = idx_wmc_dc(state, absolute=True)
+             case 'wmc-pdc':
+                 pair_idx = idx_wmc_dc(state, absolute=False)
+             case 'dummy':
+                 break
+             case _:
+                 raise ValueError(f'Unknown method: {method}')
+         if pair_idx < 0:
+             break
+         pair_chosen = list(state.freq_stat.keys())[pair_idx]
+         state = update_state(state, pair_chosen, adder_size=adder_size, carry_size=carry_size)
+     return state
+
+
+ @jit(cache=True)
+ def to_solution(
+     state: DAState,
+     adder_size: int,
+     carry_size: int,
+ ):
+     """Converts the DAState to a solution object with balanced-tree reduction of the non-extracted bits in the kernel.
+
+     Parameters
+     ----------
+     state : DAState
+         The DAState to convert.
+     adder_size : int
+         The atomic size of the adder for cost computation.
+         If -1, each adder can be arbitrarily large, and the cost is the number of adders.
+     carry_size : int
+         The size of the carry unit for latency computation.
+         If -1, each carry unit can be arbitrarily large, and the latency is the depth of the adder tree.
+
+     Returns
+     -------
+     CombLogic
+         The solution object with the optimized kernel.
+     """
+
+     ops = state.ops.copy()
+     n_out = state.kernel.shape[1]
+     expr = np.empty((len(state.expr), *state.expr[0].shape), dtype=np.int8)
+     for i, v in enumerate(state.expr):
+         expr[i] = v
+     in_shifts, out_shifts = state.shifts
+
+     out_qints = []
+     out_lats = []
+     out_idx = []
+     in_shift = in_shifts.copy()
+     out_shift = out_shifts.copy()
+     out_neg = []
+
+     _global_id = len(ops)
+     for i_out in range(n_out):
+         idx, shifts = np.where(expr[:, i_out] != 0)
+         sub = np.empty(len(idx), dtype=np.int64)
+         for i, (i_in, shift) in enumerate(zip(idx, shifts)):
+             sub[i] = expr[i_in, i_out, shift] == -1
+
+         qints: list[QInterval] = [state.ops[i].qint for i in idx]
+         lats: list[float] = [state.ops[i].latency for i in idx]
+
+         # No reduction required, dump the realized value directly
+         if len(sub) == 1:
+             out_shift[i_out] = out_shift[i_out] + shifts[0]
+             out_qints.append(qints[0])
+             out_lats.append(lats[0])
+             out_idx.append(idx[0])
+             out_neg.append(sub[0])
+             continue
+         # Output is zero
+         if len(sub) == 0:
+             out_idx.append(-1)  # -1 means output constant zero
+             out_qints.append(QInterval(0.0, 0.0, np.inf))
+             out_lats.append(0.0)
+             out_neg.append(False)
+             continue
+
+         # Sort by latency -> location of rightmost bit -> lower bound
+         left_align: list[int] = []
+         for i, qint in enumerate(qints):
+             n_int = int(log2(max(abs(qint.max + qint.step), abs(qint.min))))
+             left_align.append(n_int + shifts[i])
+         heap = list(zip(lats, sub, left_align, qints, idx, shifts))
+         heapq.heapify(heap)
+
+         while len(heap) > 1:
+             lat0, sub0, _, qint0, id0, shift0 = heapq.heappop(heap)
+             lat1, sub1, _, qint1, id1, shift1 = heapq.heappop(heap)
+
+             if sub0:
+                 shift = shift0 - shift1
+                 qint = qint_add(qint1, qint0, shift, sub1, sub0)
+                 dlat, dcost = cost_add(qint1, qint0, shift=shift, sub=1 ^ sub1, adder_size=adder_size, carry_size=carry_size)
+                 lat = max(lat0, lat1) + dlat
+                 op = Op(id1, id0, 1 ^ sub1, shift, qint, lat, dcost)
+                 shift = shift1
+             else:
+                 shift = shift1 - shift0
+                 qint = qint_add(qint0, qint1, shift, sub0, sub1)
+                 dlat, dcost = cost_add(qint0, qint1, shift=shift, sub=sub1, adder_size=adder_size, carry_size=carry_size)
+                 lat = max(lat0, lat1) + dlat
+                 op = Op(id0, id1, sub1, shift, qint, lat, dcost)
+                 shift = shift0
+
+             left_align = int(log2(max(abs(qint.max + qint.step), abs(qint.min)))) + shift
+             heapq.heappush(heap, (lat, sub0 & sub1, left_align, qint, _global_id, shift))
+             ops.append(op)
+             _global_id += 1
+
+         lat, sub, _, qint, id0, shift0 = heap[0]
+         out_idx.append(_global_id - 1)
+         out_qints.append(qint)
+         out_lats.append(lat)
+         out_neg.append(sub)
+         out_shift[i_out] = out_shift[i_out] + shift0
+
+     return CombLogic(
+         shape=state.kernel.shape,  # type: ignore
+         inp_shifts=list(in_shift),
+         out_idxs=out_idx,
+         out_shifts=list(out_shift),
+         out_negs=out_neg,
+         ops=ops,
+         carry_size=carry_size,
+         adder_size=adder_size,
+     )
+
+
+ @jit
+ def _solve(
+     kernel: np.ndarray,
+     method: str,
+     qintervals: list[QInterval],
+     latencies: list[float],
+     adder_size: int,
+     carry_size: int,
+ ):
+     state = cmvm(
+         kernel, method=method, qintervals=qintervals, inp_latencies=latencies, adder_size=adder_size, carry_size=carry_size
+     )
+     return to_solution(state, adder_size=adder_size, carry_size=carry_size)
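The `_solve` helper above shows the intended two-step flow: `cmvm` performs the greedy common-subexpression extraction, and `to_solution` reduces whatever remains with a balanced adder tree. A hedged sketch of calling the two steps directly (example kernel values; only names visible in the code above are used):

import numpy as np
from da4ml.cmvm.core import cmvm, to_solution

kernel = np.array([[3.0, 5.0], [7.0, -9.0]])  # example values

# Stage 1: greedy subexpression extraction, driven by the 'wmc' indexer.
state = cmvm(kernel, method='wmc')

# Stage 2: balanced-tree reduction of the non-extracted terms into CombLogic.
logic = to_solution(state, adder_size=-1, carry_size=-1)
print(logic.out_idxs, logic.out_shifts)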
da4ml/cmvm/core/indexers.py ADDED
@@ -0,0 +1,83 @@
+ from math import ceil, log2
+
+ import numpy as np
+ from numba import jit
+
+ from ..types import DAState, QInterval
+
+
+ @jit
+ def idx_mc(state: DAState):
+     """Choose the pair with the highest frequency."""
+     freqs = list(state.freq_stat.values())
+     max_freq = max(freqs)
+     pair_idx = freqs.index(max_freq)
+     return pair_idx
+
+
+ @jit
+ def idx_mc_dc(state: DAState, absolute: bool = False):
+     """Choose the pair with the highest frequency, with a latency penalty.
+     If absolute is True, return -1 if any latency overhead may be present."""
+     freqs = list(state.freq_stat.values())
+     factor = max(freqs) + 1
+     ops = state.ops
+     lat_penalty = [abs(ops[pair.id1].latency - ops[pair.id0].latency) * factor for pair in state.freq_stat.keys()]
+     score = [freq - lat_penalty[i] for i, freq in enumerate(freqs)]
+     max_score = max(score)
+     if absolute and max_score < 0:
+         return -1
+     pair_idx = score.index(max_score)
+     return pair_idx
+
+
+ @jit
+ def overlap_and_accum(qint0: QInterval, qint1: QInterval):
+     """Calculate the number of overlapping bits and the total number of accumulator bits for two QIntervals, when represented in fixed-point format."""
+
+     min0, max0, step0 = qint0
+     min1, max1, step1 = qint1
+     max0, max1 = max0 + step0, max1 + step1
+
+     f = -log2(max(step0, step1))
+     i_high = ceil(log2(max(abs(min0), abs(min1), abs(max0), abs(max1))))
+     i_low = ceil(log2(min(max(abs(min0), abs(max0)), max(abs(min1), abs(max1)))))
+     k = int(qint0.min < 0 or qint1.min < 0)
+     n_accum = k + i_high + f
+     n_overlap = k + i_low + f
+     return n_overlap, n_accum
+
+
+ @jit
+ def idx_wmc(state: DAState):
+     """Choose the pair with the highest weighted most-common (WMC) subexpression score."""
+     freqs = list(state.freq_stat.values())
+     keys = list(state.freq_stat.keys())
+     score = np.empty(len(freqs), dtype=np.float32)
+     for i, (k, v) in enumerate(zip(keys, freqs)):
+         id0, id1 = k.id0, k.id1
+         qint0, qint1 = state.ops[id0].qint, state.ops[id1].qint
+         n_overlap, _ = overlap_and_accum(qint0, qint1)
+         score[i] = v * n_overlap
+     max_score = np.max(score)
+     if max_score < 0:
+         return -1
+     return np.argmax(score)
+
+
+ @jit
+ def idx_wmc_dc(state: DAState, absolute: bool = False):
+     """Choose the pair with the highest weighted most-common (WMC) subexpression score, with a latency penalty.
+     When absolute is True, return -1 if any latency overhead may be present."""
+     freqs = list(state.freq_stat.values())
+     keys = list(state.freq_stat.keys())
+     score = np.empty(len(freqs), dtype=np.float32)
+     for i, (k, v) in enumerate(zip(keys, freqs)):
+         id0, id1 = k.id0, k.id1
+         qint0, qint1 = state.ops[id0].qint, state.ops[id1].qint
+         lat0, lat1 = state.ops[id0].latency, state.ops[id1].latency
+         n_overlap, _ = overlap_and_accum(qint0, qint1)
+         score[i] = v * n_overlap - 256 * abs(lat0 - lat1)
+     if absolute and np.max(score) < 0:
+         return -1
+     return np.argmax(score)
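To make the bit-width arithmetic in `overlap_and_accum` concrete, here is a standalone restatement of the same formula in plain Python (an editorial sketch, not part of the package: it takes plain (min, max, step) tuples and skips the numba/QInterval machinery), applied to one example pair of intervals:

from math import ceil, log2

def bits_overlap_and_accum(qint0, qint1):
    # Same arithmetic as overlap_and_accum above, for two (min, max, step) tuples.
    min0, max0, step0 = qint0
    min1, max1, step1 = qint1
    max0, max1 = max0 + step0, max1 + step1
    f = -log2(max(step0, step1))                 # fractional bits
    i_high = ceil(log2(max(abs(min0), abs(min1), abs(max0), abs(max1))))
    i_low = ceil(log2(min(max(abs(min0), abs(max0)), max(abs(min1), abs(max1)))))
    k = int(qint0[0] < 0 or qint1[0] < 0)        # sign bit if either interval is signed
    return k + i_low + f, k + i_high + f         # (overlap bits, accumulator bits)

# An 8-bit signed input against a 4-bit signed value:
# the operands overlap over 4 bits, and the accumulator needs 8 bits.
print(bits_overlap_and_accum((-128.0, 127.0, 1.0), (-8.0, 7.0, 1.0)))  # (4.0, 8.0)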