da4ml 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic.

Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +235 -73
  5. da4ml/cmvm/core/__init__.py +221 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +67 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +74 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +268 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +187 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +122 -0
  35. da4ml-0.2.1.dist-info/METADATA +65 -0
  36. da4ml-0.2.1.dist-info/RECORD +39 -0
  37. {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.2.dist-info/METADATA +0 -122
  48. da4ml-0.1.2.dist-info/RECORD +0 -18
  49. {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
  50. {da4ml-0.1.2.dist-info → da4ml-0.2.1.dist-info}/top_level.txt +0 -0
da4ml/cmvm/core/indexers.py
@@ -0,0 +1,83 @@
+ from math import ceil, log2
+
+ import numpy as np
+ from numba import jit
+
+ from ..types import DAState, QInterval
+
+
+ @jit
+ def idx_mc(state: DAState):
+     """Choose the pair with the highest frequency."""
+     freqs = list(state.freq_stat.values())
+     max_freq = max(freqs)
+     pair_idx = freqs.index(max_freq)
+     return pair_idx
+
+
+ @jit
+ def idx_mc_dc(state: DAState, absolute: bool = False):
+     """Choose the pair with the highest frequency, with a latency penalty.
+     If absolute is True, return -1 if any latency overhead may be present."""
+     freqs = list(state.freq_stat.values())
+     factor = max(freqs) + 1
+     ops = state.ops
+     lat_penalty = [abs(ops[pair.id1].latency - ops[pair.id0].latency) * factor for pair in state.freq_stat.keys()]
+     score = [freq - lat_penalty[i] for i, freq in enumerate(freqs)]
+     max_score = max(score)
+     if absolute and max_score < 0:
+         return -1
+     pair_idx = score.index(max_score)
+     return pair_idx
+
+
+ @jit
+ def overlap_and_accum(qint0: QInterval, qint1: QInterval):
+     """Calculate the overlap and total number of bits for two QIntervals, when represented in fixed-point format."""
+
+     min0, max0, step0 = qint0
+     min1, max1, step1 = qint1
+     max0, max1 = max0 + step0, max1 + step1
+
+     f = -log2(max(step0, step1))
+     i_high = ceil(log2(max(abs(min0), abs(min1), abs(max0), abs(max1))))
+     i_low = ceil(log2(min(max(abs(min0), abs(max0)), max(abs(min1), abs(max1)))))
+     k = int(qint0.min < 0 or qint1.min < 0)
+     n_accum = k + i_high + f
+     n_overlap = k + i_low + f
+     return n_overlap, n_accum
+
+
+ @jit
+ def idx_wmc(state: DAState):
+     """Choose the pair with the highest weighted most common subexpression (WMC) score."""
+     freqs = list(state.freq_stat.values())
+     keys = list(state.freq_stat.keys())
+     score = np.empty(len(freqs), dtype=np.float32)
+     for i, (k, v) in enumerate(zip(keys, freqs)):
+         id0, id1 = k.id0, k.id1
+         qint0, qint1 = state.ops[id0].qint, state.ops[id1].qint
+         n_overlap, _ = overlap_and_accum(qint0, qint1)
+         score[i] = v * n_overlap
+     max_score = np.max(score)
+     if max_score < 0:
+         return -1
+     return np.argmax(score)
+
+
+ @jit
+ def idx_wmc_dc(state: DAState, absolute: bool = False):
+     """Choose the pair with the highest weighted most common subexpression (WMC) score, with a latency penalty.
+     When absolute is True, return -1 if any latency overhead may be present."""
+     freqs = list(state.freq_stat.values())
+     keys = list(state.freq_stat.keys())
+     score = np.empty(len(freqs), dtype=np.float32)
+     for i, (k, v) in enumerate(zip(keys, freqs)):
+         id0, id1 = k.id0, k.id1
+         qint0, qint1 = state.ops[id0].qint, state.ops[id1].qint
+         lat0, lat1 = state.ops[id0].latency, state.ops[id1].latency
+         n_overlap, _ = overlap_and_accum(qint0, qint1)
+         score[i] = v * n_overlap - 256 * abs(lat0 - lat1)
+     if absolute and np.max(score) < 0:
+         return -1
+     return np.argmax(score)
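
The weighted indexers above score each candidate pair by its frequency times the number of fixed-point bits the two operands share, as computed by overlap_and_accum. The snippet below is a minimal, numba-free sketch of that bit-width arithmetic so the numbers can be checked by hand; the QInt stand-in only mirrors the (min, max, step) unpacking used in this hunk and is not the package's actual QInterval class.

    from math import ceil, log2
    from typing import NamedTuple

    class QInt(NamedTuple):  # stand-in mirroring the (min, max, step) unpacking above
        min: float
        max: float
        step: float

    def overlap_and_accum(qint0: QInt, qint1: QInt):
        # Same arithmetic as the @jit version in indexers.py, without numba.
        min0, max0, step0 = qint0
        min1, max1, step1 = qint1
        max0, max1 = max0 + step0, max1 + step1
        f = -log2(max(step0, step1))  # fractional bits of the coarser operand
        i_high = ceil(log2(max(abs(min0), abs(min1), abs(max0), abs(max1))))
        i_low = ceil(log2(min(max(abs(min0), abs(max0)), max(abs(min1), abs(max1)))))
        k = int(qint0.min < 0 or qint1.min < 0)  # sign bit if either operand can be negative
        return k + i_low + f, k + i_high + f     # (n_overlap, n_accum)

    a = QInt(-8.0, 7.9375, 0.0625)  # 8-bit signed value with 4 fractional bits
    b = QInt(-2.0, 1.75, 0.25)      # 4-bit signed value with 2 fractional bits
    print(overlap_and_accum(a, b))  # (4.0, 6.0): 4 overlapping bits, 6-bit accumulator
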
da4ml/cmvm/core/state_opr.py
@@ -0,0 +1,284 @@
+ from math import ceil, log2
+
+ import numpy as np
+ from numba import jit
+
+ from ..types import DAState, Op, Pair, QInterval
+ from ..util import csd_decompose
+
+
+ @jit
+ def qint_add(qint0: QInterval, qint1: QInterval, shift: int, sub0=False, sub1=False) -> QInterval:
+     min0, max0, step0 = qint0
+     min1, max1, step1 = qint1
+     if sub0:
+         min0, max0 = -max0, -min0
+     if sub1:
+         min1, max1 = -max1, -min1
+
+     s = 2.0**shift
+     min1, max1, step1 = min1 * s, max1 * s, step1 * s
+
+     return QInterval(min0 + min1, max0 + max1, min(step0, step1))
+
+
+ @jit
+ def cost_add(
+     qint0: QInterval, qint1: QInterval, shift: int, sub: bool = False, adder_size: int = -1, carry_size: int = -1
+ ) -> tuple[float, float]:
+     """Calculate the latency and cost of an addition operation.
+
+     Parameters
+     ----------
+     qint0 : QInterval
+         The first QInterval.
+     qint1 : QInterval
+         The second QInterval.
+     sub : bool
+         If True, the operation is a subtraction (a - b) instead of an addition (a + b).
+     adder_size : int
+         The atomic size of the adder.
+     carry_size : int
+         The size of the look-ahead carry.
+
+     Returns
+     -------
+     tuple[float, float]
+         The latency and cost of the addition operation.
+     """
+     if adder_size < 0 and carry_size < 0:
+         return 1.0, 1.0
+     if adder_size < 0:
+         adder_size = 65535
+     if carry_size < 0:
+         carry_size = 65535
+
+     min0, max0, step0 = qint0
+     min1, max1, step1 = qint1
+     if sub:
+         min1, max1 = max1, min1
+     sf = 2.0**shift
+     min1, max1, step1 = min1 * sf, max1 * sf, step1 * sf
+     max0, max1 = max0 + step0, max1 + step1
+
+     f = -log2(max(step0, step1))
+     i = ceil(log2(max(abs(min0), abs(min1), abs(max0), abs(max1))))
+     k = int(qint0.min < 0 or qint1.min < 0)
+     n_accum = k + i + f
+     # Align to the number of carry and adder bits when they are block-based (e.g., 4/8-bit look-ahead carry in Xilinx FPGAs).
+     # For Altera, the carry seems to be single-bit adder chains, but this needs to be checked.
+     return float(ceil(n_accum / carry_size)), float(ceil(n_accum / adder_size))
+
+
+ @jit
+ def create_state(
+     kernel: np.ndarray,
+     qintervals: list[QInterval],
+     inp_latencies: list[float],
+     no_stat_init: bool = False,
+ ):
+     assert len(qintervals) == kernel.shape[0]
+     assert len(inp_latencies) == kernel.shape[0]
+     assert kernel.ndim == 2
+
+     kernel = kernel.astype(np.float64)
+     n_in, n_out = kernel.shape
+     kernel = np.asarray(kernel)
+     csd, shift0, shift1 = csd_decompose(kernel)
+     for i, qint in enumerate(qintervals):
+         if qint.min == qint.max == 0:
+             csd[i] = 0
+     n_bits = csd.shape[-1]
+     expr = list(csd)
+     shifts = (shift0, shift1)
+
+     # Dirty numba typing trick
+     stat = {Pair(-1, -1, False, 0): 0}
+     del stat[Pair(-1, -1, False, 0)]
+
+     # Loop over outputs, in0, in1, shift0, shift1 to gather all two-term pairs
+     # Force i1 >= i0
+     if not no_stat_init:
+         # Initialize the stat dictionary
+         # Skipped if no_stat_init is True (skips optimization)
+         for i_out in range(n_out):
+             for i0 in range(n_in):
+                 for j0 in range(n_bits):
+                     bit0 = csd[i0, i_out, j0]
+                     if not bit0:
+                         continue
+                     for i1 in range(i0, n_in):
+                         for j1 in range(n_bits):
+                             bit1 = csd[i1, i_out, j1]
+                             if not bit1:
+                                 continue
+                             # Avoid counting the same bit twice
+                             if i0 == i1 and j0 <= j1:
+                                 continue
+                             pair = Pair(i0, i1, bit0 != bit1, j1 - j0)
+                             stat[pair] = stat.get(pair, 0) + 1
+
+     for k in list(stat.keys()):
+         if stat[k] < 2.0:
+             del stat[k]
+
+     ops = [Op(i, -1, -1, 0, qintervals[i], inp_latencies[i], 0.0) for i in range(n_in)]
+
+     return DAState(
+         shifts=shifts,
+         expr=expr,
+         ops=ops,
+         freq_stat=stat,
+         kernel=kernel,
+     )
+
+
+ @jit
+ def update_stats(
+     state: DAState,
+     pair: Pair,
+ ):
+     """Update the statistics of any 2-term pair in the state that may be affected by implementing the chosen pair."""
+     id0, id1 = pair.id0, pair.id1
+
+     ks = list(state.freq_stat.keys())
+     for k in ks:
+         if k.id0 == id0 or k.id1 == id1 or k.id1 == id0 or k.id0 == id1:
+             del state.freq_stat[k]
+
+     n_constructed = len(state.expr)
+     modified = [n_constructed - 1]
+     modified.append(id0)
+     if id1 != id0:
+         modified.append(id1)
+
+     n_bits = state.expr[0].shape[-1]
+
+     # Loop over outputs, in0, in1, shift0, shift1 to gather all two-term pairs
+     for i_out in range(state.kernel.shape[1]):
+         for _in0 in modified:
+             for _in1 in range(n_constructed):
+                 if _in1 in modified and _in0 > _in1:
+                     # Avoid double counting of the two locations when _in0 != _in1
+                     continue
+                 # Order inputs, as _in0 can be either in0 or in1; the range of _in1 is not restricted
+                 id0, id1 = (_in0, _in1) if _in0 <= _in1 else (_in1, _in0)
+                 for j0 in range(n_bits):
+                     bit0 = state.expr[id0][i_out, j0]
+                     if not bit0:
+                         continue
+                     for j1 in range(n_bits):
+                         bit1 = state.expr[id1][i_out, j1]
+                         if not bit1:
+                             continue
+                         if id0 == id1 and j0 <= j1:
+                             continue
+                         pair = Pair(id0, id1, bit0 != bit1, j1 - j0)
+                         state.freq_stat[pair] = state.freq_stat.get(pair, 0) + 1
+
+     ks, vs = list(state.freq_stat.keys()), list(state.freq_stat.values())
+     for k, v in zip(ks, vs):
+         if v < 2.0:
+             del state.freq_stat[k]
+     return state
+
+
+ @jit
+ def gather_matching_idxs(state: DAState, pair: Pair):
+     """Generate all (i_out, j0, j1) such that expr[id0][i_out, j0] and expr[id1][i_out, j1] correspond to the pair provided."""
+     id0, id1 = pair.id0, pair.id1
+     shift = pair.shift
+     sub = pair.sub
+     n_out = state.kernel.shape[1]
+     n_bits = state.expr[0].shape[-1]
+
+     flip = False
+     if shift < 0:
+         id0, id1 = id1, id0
+         shift = -shift
+         flip = True
+
+     sign = 1 if not sub else -1
+
+     for j0 in range(n_bits - shift):
+         for i_out in range(n_out):
+             bit0 = state.expr[id0][i_out, j0]
+             j1 = j0 + shift
+             bit1 = state.expr[id1][i_out, j1]
+             if sign * bit1 * bit0 != 1:
+                 continue
+
+             if flip:
+                 yield i_out, j1, j0
+             else:
+                 yield i_out, j0, j1
+
+
+ @jit
+ def pair_to_op(pair: Pair, state: DAState, adder_size: int = -1, carry_size: int = -1):
+     id0, id1 = pair.id0, pair.id1
+     dlat, cost = cost_add(
+         state.ops[pair.id0].qint,
+         state.ops[pair.id1].qint,
+         pair.shift,
+         pair.sub,
+         adder_size=adder_size,
+         carry_size=carry_size,
+     )
+     lat = max(state.ops[id0].latency, state.ops[id1].latency) + dlat
+     qint = qint_add(
+         state.ops[pair.id0].qint,
+         state.ops[pair.id1].qint,
+         shift=pair.shift,
+         sub1=pair.sub,
+     )
+     return Op(id0, id1, int(pair.sub), pair.shift, qint, lat, cost)
+
+
+ @jit
+ def update_expr(
+     state: DAState,
+     pair: Pair,
+     adder_size: int,
+     carry_size: int,
+ ):
+     """Update the state by implementing the operation, except for the common 2-term pair frequency update."""
+     id0, id1 = pair.id0, pair.id1
+     op = pair_to_op(pair, state, adder_size=adder_size, carry_size=carry_size)
+     n_out = state.kernel.shape[1]
+     n_bits = state.expr[0].shape[-1]
+
+     expr = state.expr.copy()
+     ops = state.ops.copy()
+
+     ops.append(op)
+
+     new_slice = np.zeros((n_out, n_bits), dtype=np.int8)
+
+     for i_out, j0, j1 in gather_matching_idxs(state, pair):
+         new_slice[i_out, j0] = expr[id0][i_out, j0]
+         expr[id0][i_out, j0] = 0
+         expr[id1][i_out, j1] = 0
+
+     expr.append(new_slice)
+
+     return DAState(
+         shifts=state.shifts,
+         expr=expr,
+         ops=ops,
+         freq_stat=state.freq_stat,
+         kernel=state.kernel,
+     )
+
+
+ @jit
+ def update_state(
+     state: DAState,
+     pair_chosen: Pair,
+     adder_size: int,
+     carry_size: int,
+ ):
+     """Update the state by removing all occurrences of pair_chosen from the state, registering the op, and updating the statistics."""
+     state = update_expr(state, pair_chosen, adder_size=adder_size, carry_size=carry_size)
+     state = update_stats(state, pair_chosen)
+     return state
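
Together with the indexers, these state operations form the greedy common-subexpression loop of the new CMVM backend: create_state builds the CSD expression and pair statistics for a kernel, an indexer selects the next pair, and update_state materializes it as an adder Op and refreshes the statistics. The sketch below shows one plausible way to wire them together, based only on the signatures in this diff; the actual driver lives in da4ml/cmvm/api.py (changed in this release but not shown here), and the name solve_kernel and the 8-bit input assumption are illustrative.

    import numpy as np

    from da4ml.cmvm.core.indexers import idx_wmc_dc
    from da4ml.cmvm.core.state_opr import create_state, update_state
    from da4ml.cmvm.types import QInterval

    def solve_kernel(kernel: np.ndarray, adder_size: int = -1, carry_size: int = -1):
        # Hypothetical driver: greedily extract 2-term common subexpressions from a CMVM kernel.
        n_in = kernel.shape[0]
        qintervals = [QInterval(-128.0, 127.0, 1.0) for _ in range(n_in)]  # assume 8-bit signed integer inputs
        latencies = [0.0 for _ in range(n_in)]

        state = create_state(kernel, qintervals, latencies)
        while state.freq_stat:
            idx = idx_wmc_dc(state, absolute=True)  # latency-aware weighted-most-common heuristic
            if idx < 0:  # no remaining pair is worth implementing
                break
            pair = list(state.freq_stat.keys())[idx]
            state = update_state(state, pair, adder_size=adder_size, carry_size=carry_size)
        return state  # state.ops now holds the input variables followed by the extracted adders

Note that update_stats prunes and refills freq_stat in place while update_expr copies expr and ops into a fresh DAState, so the state returned by update_state should be used as the sole handle for the next iteration.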