da4ml 0.5.0__cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +194 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +240 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.0.dist-info/METADATA +85 -0
  92. da4ml-0.5.0.dist-info/RECORD +96 -0
  93. da4ml-0.5.0.dist-info/WHEEL +6 -0
  94. da4ml-0.5.0.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/cmvm/core/state_opr.py
@@ -0,0 +1,284 @@
+ from math import ceil, log2
+
+ import numpy as np
+ from numba import jit
+
+ from ..types import DAState, Op, Pair, QInterval
+ from ..util import csd_decompose
+
+
+ @jit
+ def qint_add(qint0: QInterval, qint1: QInterval, shift: int, sub0=False, sub1=False) -> QInterval:
+     """Value range of (±qint0) + (±qint1 << shift); the signs are flipped when sub0/sub1 are set."""
+     min0, max0, step0 = qint0
+     min1, max1, step1 = qint1
+     if sub0:
+         min0, max0 = -max0, -min0
+     if sub1:
+         min1, max1 = -max1, -min1
+
+     s = 2.0**shift
+     min1, max1, step1 = min1 * s, max1 * s, step1 * s
+
+     return QInterval(min0 + min1, max0 + max1, min(step0, step1))
+
+
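For orientation, here is a small usage sketch of qint_add, assuming (as the tuple unpacking above suggests) that QInterval is a (min, max, step) namedtuple exported by da4ml.cmvm.types:

    # Sketch only; the QInterval field order (min, max, step) is inferred from the unpacking above.
    from da4ml.cmvm.types import QInterval
    from da4ml.cmvm.core.state_opr import qint_add

    a = QInterval(-2.0, 1.75, 0.25)   # e.g. a signed value with a 0.25 quantization step
    b = QInterval(0.0, 3.5, 0.5)
    out = qint_add(a, b, shift=1)     # range of a + (b << 1)
    # b's bounds are scaled by 2**1, then the bounds add and the finer step wins:
    # QInterval(min=-2.0, max=8.75, step=0.25)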
+ @jit
+ def cost_add(
+     qint0: QInterval, qint1: QInterval, shift: int, sub: bool = False, adder_size: int = -1, carry_size: int = -1
+ ) -> tuple[float, float]:
+     """Calculate the latency and cost of an addition operation.
+
+     Parameters
+     ----------
+     qint0 : QInterval
+         The first QInterval.
+     qint1 : QInterval
+         The second QInterval.
+     shift : int
+         The left shift applied to qint1 before the addition.
+     sub : bool
+         If True, the operation is a subtraction (a - b) instead of an addition (a + b).
+     adder_size : int
+         The atomic size of the adder.
+     carry_size : int
+         The size of the look-ahead carry.
+
+     Returns
+     -------
+     tuple[float, float]
+         The latency and cost of the addition operation.
+     """
+     if adder_size < 0 and carry_size < 0:
+         return 1.0, 1.0
+     if adder_size < 0:
+         adder_size = 65535
+     if carry_size < 0:
+         carry_size = 65535
+
+     min0, max0, step0 = qint0
+     min1, max1, step1 = qint1
+     if sub:
+         min1, max1 = max1, min1
+     sf = 2.0**shift
+     min1, max1, step1 = min1 * sf, max1 * sf, step1 * sf
+     max0, max1 = max0 + step0, max1 + step1
+
+     f = -log2(max(step0, step1))
+     i = ceil(log2(max(abs(min0), abs(min1), abs(max0), abs(max1))))
+     k = int(qint0.min < 0 or qint1.min < 0)
+     n_accum = k + i + f
+     # Align to the number of carry and adder bits when they are block-based
+     # (e.g., 4/8-bit look-ahead carry in Xilinx FPGAs).
+     # For Altera, the carry appears to be a single-bit adder chain, but this needs to be checked.
+     return float(ceil(n_accum / carry_size)), float(ceil(n_accum / adder_size))
+
+
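A quick sketch of what this cost model returns, continuing the QInterval assumption from the earlier snippet: with no adder or carry block sizes set, every addition costs one unit of latency and area; with block sizes set, both scale with the aligned bit width of the result.

    # Unit cost model: both adder_size and carry_size unset (< 0)
    cost_add(QInterval(-8.0, 7.0, 1.0), QInterval(0.0, 15.0, 1.0), shift=0)
    # -> (1.0, 1.0)

    # Block-based model, e.g. 1-bit adders behind an 8-bit carry block:
    cost_add(QInterval(-8.0, 7.0, 1.0), QInterval(0.0, 15.0, 1.0), shift=0,
             adder_size=1, carry_size=8)
    # n_accum = 1 (sign) + 4 (integer bits) + 0 (fractional bits) = 5
    # -> (ceil(5/8), ceil(5/1)) = (1.0, 5.0)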
+ @jit
+ def create_state(
+     kernel: np.ndarray,
+     qintervals: list[QInterval],
+     inp_latencies: list[float],
+     no_stat_init: bool = False,
+ ):
+     """Build the initial DAState for a constant kernel: CSD expressions, per-input ops, and 2-term pair statistics."""
+     assert len(qintervals) == kernel.shape[0]
+     assert len(inp_latencies) == kernel.shape[0]
+     assert kernel.ndim == 2
+
+     kernel = kernel.astype(np.float64)
+     n_in, n_out = kernel.shape
+     kernel = np.asarray(kernel)
+     csd, shift0, shift1 = csd_decompose(kernel)
+     for i, qint in enumerate(qintervals):
+         if qint.min == qint.max == 0:
+             csd[i] = 0
+     n_bits = csd.shape[-1]
+     expr = list(csd)
+     shifts = (shift0, shift1)
+
+     # Dirty numba typing trick: seed the dict with a dummy entry so its key/value types are known, then empty it
+     stat = {Pair(-1, -1, False, 0): 0}
+     del stat[Pair(-1, -1, False, 0)]
+
+     # Loop over outputs, in0, in1, shift0, shift1 to gather all two-term pairs
+     # Force i1 >= i0
+     if not no_stat_init:
+         # Initialize the stat dictionary
+         # Skipped if no_stat_init is True (skip optimization)
+         for i_out in range(n_out):
+             for i0 in range(n_in):
+                 for j0 in range(n_bits):
+                     bit0 = csd[i0, i_out, j0]
+                     if not bit0:
+                         continue
+                     for i1 in range(i0, n_in):
+                         for j1 in range(n_bits):
+                             bit1 = csd[i1, i_out, j1]
+                             if not bit1:
+                                 continue
+                             # For the same input, keep only j0 > j1 to avoid self-pairs and double counting
+                             if i0 == i1 and j0 <= j1:
+                                 continue
+                             pair = Pair(i0, i1, bit0 != bit1, j1 - j0)
+                             stat[pair] = stat.get(pair, 0) + 1
+
+         for k in list(stat.keys()):
+             if stat[k] < 2.0:
+                 del stat[k]
+
+     ops = [Op(i, -1, -1, 0, qintervals[i], inp_latencies[i], 0.0) for i in range(n_in)]
+
+     return DAState(
+         shifts=shifts,
+         expr=expr,
+         ops=ops,
+         freq_stat=stat,
+         kernel=kernel,
+     )
+
+
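To make the data flow concrete, a hypothetical construction of the initial state for a tiny 2-by-2 kernel (in practice numba's nopython mode may require typed lists rather than plain Python lists):

    import numpy as np
    # Two inputs, both in [-8, 7] with unit step, arriving at latency 0
    kernel = np.array([[3.0, 5.0],
                       [6.0, -7.0]])
    qints = [QInterval(-8.0, 7.0, 1.0), QInterval(-8.0, 7.0, 1.0)]
    state = create_state(kernel, qints, [0.0, 0.0])
    # state.expr holds one CSD bit-plane per input; state.freq_stat counts every
    # 2-term (id0, id1, sub, shift) pattern that occurs at least twice.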
+ @jit
+ def update_stats(
+     state: DAState,
+     pair: Pair,
+ ):
+     """Update the statistics of any 2-term pair in the state that may be affected by implementing the chosen pair."""
+     id0, id1 = pair.id0, pair.id1
+
+     ks = list(state.freq_stat.keys())
+     for k in ks:
+         if k.id0 == id0 or k.id1 == id1 or k.id1 == id0 or k.id0 == id1:
+             del state.freq_stat[k]
+
+     n_constructed = len(state.expr)
+     modified = [n_constructed - 1]
+     modified.append(id0)
+     if id1 != id0:
+         modified.append(id1)
+
+     n_bits = state.expr[0].shape[-1]
+
+     # Loop over outputs, in0, in1, shift0, shift1 to gather all two-term pairs
+     for i_out in range(state.kernel.shape[1]):
+         for _in0 in modified:
+             for _in1 in range(n_constructed):
+                 if _in1 in modified and _in0 > _in1:
+                     # Both endpoints were modified: count the pair only once (when _in0 <= _in1)
+                     continue
+                 # Order the endpoints, as _in0 may be larger or smaller than _in1
+                 id0, id1 = (_in0, _in1) if _in0 <= _in1 else (_in1, _in0)
+                 for j0 in range(n_bits):
+                     bit0 = state.expr[id0][i_out, j0]
+                     if not bit0:
+                         continue
+                     for j1 in range(n_bits):
+                         bit1 = state.expr[id1][i_out, j1]
+                         if not bit1:
+                             continue
+                         if id0 == id1 and j0 <= j1:
+                             continue
+                         pair = Pair(id0, id1, bit0 != bit1, j1 - j0)
+                         state.freq_stat[pair] = state.freq_stat.get(pair, 0) + 1
+
+     ks, vs = list(state.freq_stat.keys()), list(state.freq_stat.values())
+     for k, v in zip(ks, vs):
+         if v < 2.0:
+             del state.freq_stat[k]
+     return state
+
+
+ @jit
+ def gather_matching_idxs(state: DAState, pair: Pair):
+     """Generate all (i_out, j0, j1) such that expr[id0][i_out, j0] and expr[id1][i_out, j1] realize the provided pair."""
+     id0, id1 = pair.id0, pair.id1
+     shift = pair.shift
+     sub = pair.sub
+     n_out = state.kernel.shape[1]
+     n_bits = state.expr[0].shape[-1]
+
+     flip = False
+     if shift < 0:
+         id0, id1 = id1, id0
+         shift = -shift
+         flip = True
+
+     sign = 1 if not sub else -1
+
+     for j0 in range(n_bits - shift):
+         for i_out in range(n_out):
+             bit0 = state.expr[id0][i_out, j0]
+             j1 = j0 + shift
+             bit1 = state.expr[id1][i_out, j1]
+             if sign * bit1 * bit0 != 1:
+                 continue
+
+             if flip:
+                 yield i_out, j1, j0
+             else:
+                 yield i_out, j0, j1
+
+
+ @jit
+ def pair_to_op(pair: Pair, state: DAState, adder_size: int = -1, carry_size: int = -1):
+     """Convert a chosen 2-term pair into an Op with its QInterval, latency, and cost."""
+     id0, id1 = pair.id0, pair.id1
+     dlat, cost = cost_add(
+         state.ops[pair.id0].qint,
+         state.ops[pair.id1].qint,
+         pair.shift,
+         pair.sub,
+         adder_size=adder_size,
+         carry_size=carry_size,
+     )
+     lat = max(state.ops[id0].latency, state.ops[id1].latency) + dlat
+     qint = qint_add(
+         state.ops[pair.id0].qint,
+         state.ops[pair.id1].qint,
+         shift=pair.shift,
+         sub1=pair.sub,
+     )
+     return Op(id0, id1, int(pair.sub), pair.shift, qint, lat, cost)
+
+
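Continuing the sketch above: once a frequent pair has been chosen from state.freq_stat, pair_to_op folds qint_add and cost_add into the bookkeeping record for the new adder node. A hypothetical preview of that step:

    # Hypothetical: take the most frequent remaining pair and preview the op it would create
    best = max(state.freq_stat, key=state.freq_stat.get)
    op = pair_to_op(best, state, adder_size=-1, carry_size=-1)
    # op.qint is the value range of in0 ± (in1 << shift); op.latency is the
    # larger operand latency plus the adder delay estimated by cost_add.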
+ @jit
+ def update_expr(
+     state: DAState,
+     pair: Pair,
+     adder_size: int,
+     carry_size: int,
+ ):
+     "Updates the state by implementing the chosen pair as an op, except for the common 2-term pair frequency update."
+     id0, id1 = pair.id0, pair.id1
+     op = pair_to_op(pair, state, adder_size=adder_size, carry_size=carry_size)
+     n_out = state.kernel.shape[1]
+     n_bits = state.expr[0].shape[-1]
+
+     expr = state.expr.copy()
+     ops = state.ops.copy()
+
+     ops.append(op)
+
+     new_slice = np.zeros((n_out, n_bits), dtype=np.int8)
+
+     # Move every matched occurrence of the pair into the new expression slice
+     for i_out, j0, j1 in gather_matching_idxs(state, pair):
+         new_slice[i_out, j0] = expr[id0][i_out, j0]
+         expr[id0][i_out, j0] = 0
+         expr[id1][i_out, j1] = 0
+
+     expr.append(new_slice)
+
+     return DAState(
+         shifts=state.shifts,
+         expr=expr,
+         ops=ops,
+         freq_stat=state.freq_stat,
+         kernel=state.kernel,
+     )
+
+
+ @jit
+ def update_state(
+     state: DAState,
+     pair_chosen: Pair,
+     adder_size: int,
+     carry_size: int,
+ ):
+     """Update the state by removing all occurrences of pair_chosen, recording the new op, and updating the statistics."""
+     state = update_expr(state, pair_chosen, adder_size=adder_size, carry_size=carry_size)
+     state = update_stats(state, pair_chosen)
+     return state
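Taken together, these helpers support a greedy common-subexpression elimination loop over the CSD representation of the kernel. A minimal sketch of that outer loop under stated assumptions (the real driver lives elsewhere in da4ml.cmvm and likely weighs frequency against latency and cost; this only shows the shape of it):

    def greedy_cse(state, adder_size=-1, carry_size=-1):
        # Repeatedly implement the most frequent remaining 2-term pair
        # until no pair occurs at least twice.
        while len(state.freq_stat) > 0:
            best = max(state.freq_stat, key=state.freq_stat.get)
            state = update_state(state, best, adder_size, carry_size)
        return state

    # state.ops then lists one Op per input plus one per extracted adder, and
    # state.expr records which input or adder output feeds each output bit.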