da4ml-0.5.1.post1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +204 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +246 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.1.post1.dist-info/METADATA +85 -0
  92. da4ml-0.5.1.post1.dist-info/RECORD +96 -0
  93. da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
  94. da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/trace/ops/einsum_utils.py
@@ -0,0 +1,305 @@
+ from math import prod
+ from typing import TYPE_CHECKING, TypedDict, overload
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ if TYPE_CHECKING:
+     from ..fixed_variable_array import FixedVariableArray
+
+
+ class EinsumRecipe(TypedDict):
+     direct_sum_axis: tuple[tuple[int, ...], tuple[int, ...]]
+     in_transpose_idxs: tuple[tuple[int, ...], tuple[int, ...]]
+     L0: int
+     L1: int
+     I: int
+     C: int
+     out_interpret_shape: tuple[int, ...]
+     out_transpose_idxs: tuple[int, ...]
+
+
+ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]):
+     """Validate, resolve broadcasting, and compute the output shape for an einsum string.
+
+     Parameters
+     ----------
+     fn : str
+         einsum string, e.g. 'ij,jk->ik'
+     shape0 : tuple[int, ...]
+         shape of input0
+     shape1 : tuple[int, ...]
+         shape of input1
+
+     Returns
+     -------
+     tuple[str, tuple[int, ...]]
+         einsum string without broadcasting, and the output shape
+
+     Raises
+     ------
+     ValueError
+         If the einsum string is invalid, or if it is incompatible with the input shapes
+     """
+     inp, out = map(str.strip, fn.split('->'))
+     in0, in1 = map(str.strip, inp.split(','))
+     alphabets = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+     s_alphabets = set(alphabets)
+
+     # Invalid characters
+     if not (s_alphabets >= set(in0.replace('...', '') + in1.replace('...', '') + out.replace('...', ''))):
+         raise ValueError(f"einsum string {fn} is invalid: subscripts should be in [a-zA-Z] and '...' only")
+
+     in0 = in0.replace('...', '0')
+     in1 = in1.replace('...', '0')
+     out = out.replace('...', '0')
+     ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out)
+     sax_in0, sax_in1, sax_out = set(ax_in0), set(ax_in1), set(ax_out)
+     free_indices = ''.join(sorted(s_alphabets - sax_in0 - sax_in1 - sax_out))
+
+     # Repeated indices
+     if len(sax_in0) != len(ax_in0):
+         for a in in0:
+             if in0.count(a) == 1:
+                 continue
+             a = a if a != '0' else '...'
+             raise ValueError(f"einsum string {fn} is invalid: input0 subscripts include '{a}' multiple times")
+     if len(sax_in1) != len(ax_in1):
+         for a in in1:
+             if in1.count(a) == 1:
+                 continue
+             a = a if a != '0' else '...'
+             raise ValueError(f"einsum string {fn} is invalid: input1 subscripts include '{a}' multiple times")
+     if len(sax_out) != len(ax_out):
+         for a in out:
+             if out.count(a) == 1:
+                 continue
+             a = a if a != '0' else '...'
+             raise ValueError(f"einsum string {fn} is invalid: output subscripts include '{a}' multiple times")
+
+     # Invalid broadcasting
+     if '0' in sax_in0 or '0' in sax_in1 or '0' in sax_out:
+         if '0' not in sax_out:
+             raise ValueError(f'einsum string {fn} is invalid: output does not allow broadcasting, but inputs do')
+         if '0' not in sax_in0 and '0' not in sax_in1:
+             raise ValueError(f'einsum string {fn} is invalid: output allows broadcasting, but inputs do not')
+
+     # Output index out of nowhere
+     if remaining := sax_out - sax_in0 - sax_in1:
+         raise ValueError(f'einsum string {fn} is invalid: output subscripts {remaining} not found in inputs')
+
+     _common_in = sax_in0 & sax_in1
+
+     if '0' in sax_in0 and '0' in sax_in1:
+         # Simultaneous axes expansion in both inputs
+         n_broadcast0 = len(shape0) - len(sax_in0) + 1
+         n_broadcast1 = len(shape1) - len(sax_in1) + 1
+         assert n_broadcast0 == n_broadcast1, f"'...' expands to {n_broadcast0} axes in input0 but {n_broadcast1} in input1"
+         # Replace expansion indices with free indices
+         in0 = in0.replace('0', free_indices[:n_broadcast0])
+         in1 = in1.replace('0', free_indices[:n_broadcast1])
+         out = out.replace('0', free_indices[:n_broadcast0])
+         ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out)
+         _common_in = set(ax_in0) & set(ax_in1)
+
+     else:
+         # Axes expansion in input0 or input1 only
+         if '0' in sax_in0:
+             if len(sax_in0) - 1 > len(shape0):
+                 raise ValueError(f'Input0 requires at least {len(sax_in0) - 1} dimensions, but only {len(shape0)} given')
+             # Replace auto expansion indices with free indices
+             n_broadcast = len(shape0) - len(sax_in0) + 1
+             in0 = in0.replace('0', free_indices[:n_broadcast])
+             out = out.replace('0', free_indices[:n_broadcast])
+             ax_in0 = list(in0)
+             ax_out = list(out)
+         else:
+             if len(sax_in0) != len(shape0):
+                 raise ValueError(f'Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given')
+
+         if '0' in sax_in1:
+             if len(sax_in1) - 1 > len(shape1):
+                 raise ValueError(f'Input1 requires at least {len(sax_in1) - 1} dimensions, but only {len(shape1)} given')
+             # Replace expansion indices with free indices
+             n_broadcast = len(shape1) - len(sax_in1) + 1
+             in1 = in1.replace('0', free_indices[:n_broadcast])
+             out = out.replace('0', free_indices[:n_broadcast])
+             ax_in1 = list(in1)
+             ax_out = list(out)
+         else:
+             if len(sax_in1) != len(shape1):
+                 raise ValueError(f'Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given')
+
+     # Input dimension mismatch
+     for a in _common_in:
+         ax_0 = ax_in0.index(a)
+         ax_1 = ax_in1.index(a)
+         if shape0[ax_0] != shape1[ax_1]:
+             raise ValueError(f"Input dimension size mismatches for common subscript '{a}': {shape0[ax_0]} and {shape1[ax_1]}")
+
+     out_shape = tuple(shape0[ax_in0.index(a)] if a in ax_in0 else shape1[ax_in1.index(a)] for a in ax_out)
+     return f'{in0},{in1}->{out}', out_shape
+
+
+ def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe:
+     """Parse an einsum operation on two input arrays; return a recipe for execution.
+
+     Parameters
+     ----------
+     fn : str
+         einsum string, e.g. 'ij,jk->ik'
+     input_shape0 : tuple[int, ...]
+         shape of the first input array
+     input_shape1 : tuple[int, ...]
+         shape of the second input array
+
+     Returns
+     -------
+     EinsumRecipe
+         einsum recipe; executed by _exec_einsum
+     """
+
+     fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1)
+
+     _in, _out = fn.split('->')
+     _in0, _in1 = _in.split(',')
+
+     in0, in1, out = list(_in0), list(_in1), list(_out)
+     s_in0, s_in1, s_out = set(in0), set(in1), set(out)
+     _common = s_in0 & s_in1
+     _contract = _common - s_out
+     _inplace = _common & s_out
+     contract = sorted(_contract, key=lambda x: in1.index(x))
+     inplace = sorted(_inplace, key=lambda x: in1.index(x))
+     invariant0 = sorted((s_out - _common) & s_in0, key=lambda x: in0.index(x))
+     invariant1 = sorted((s_out - _common) & s_in1, key=lambda x: in1.index(x))
+     direct_sum0 = s_in0 - s_out - _common
+     direct_sum1 = s_in1 - s_out - _common
+     direct_sum_axis = (
+         tuple(sorted(in0.index(x) for x in direct_sum0)),
+         tuple(sorted(in1.index(x) for x in direct_sum1)),
+     )
+
+     contract_idxs = tuple(map(in0.index, contract)), tuple(map(in1.index, contract))
+     inplace_idxs = tuple(map(in0.index, inplace)), tuple(map(in1.index, inplace))
+     invariant_idxs = tuple(map(in0.index, invariant0)), tuple(map(in1.index, invariant1))
+
+     inplace_shape = tuple(input_shape0[i] for i in inplace_idxs[0])
+     inplace_size = prod(inplace_shape)
+     contract_size = prod(input_shape0[i] for i in contract_idxs[0])
+     invariant_shape0 = tuple(input_shape0[i] for i in invariant_idxs[0])
+     invariant_shape1 = tuple(input_shape1[i] for i in invariant_idxs[1])
+     invariant_size0, invariant_size1 = prod(invariant_shape0), prod(invariant_shape1)
+
+     transpose_idx0 = inplace_idxs[0] + invariant_idxs[0] + contract_idxs[0]
+     transpose_idx1 = inplace_idxs[1] + invariant_idxs[1] + contract_idxs[1]
+
+     out_shape_pretranspose = inplace_shape + invariant_shape0 + invariant_shape1
+     _out_transpose_idx = np.argsort(tuple(map(out.index, inplace + invariant0 + invariant1)))
+     out_transpose_idx = tuple(int(i) for i in _out_transpose_idx)
+
+     return EinsumRecipe(
+         direct_sum_axis=direct_sum_axis,
+         in_transpose_idxs=(transpose_idx0, transpose_idx1),
+         out_interpret_shape=out_shape_pretranspose,
+         out_transpose_idxs=out_transpose_idx,
+         L0=invariant_size0,
+         L1=invariant_size1,
+         I=inplace_size,
+         C=contract_size,
+     )
+
+
+ def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray:
+     """Execute an einsum operation on two input arrays.
+
+     Parameters
+     ----------
+     recipe : EinsumRecipe
+         einsum recipe
+     input0 : np.ndarray
+         input0, the first input array
+     input1 : np.ndarray
+         input1, the second input array
+
+     Returns
+     -------
+     np.ndarray
+         output array
+     """
+     sum_axis0, sum_axis1 = recipe['direct_sum_axis']
+     if sum_axis0:
+         input0 = np.sum(input0, axis=sum_axis0)
+     if sum_axis1:
+         input1 = np.sum(input1, axis=sum_axis1)
+     input0 = input0.transpose(recipe['in_transpose_idxs'][0]).ravel()
+     input1 = input1.transpose(recipe['in_transpose_idxs'][1]).ravel()
+     out_dtype = object if input0.dtype == object or input1.dtype == object else np.float64
+     output = np.zeros(recipe['L0'] * recipe['L1'] * recipe['I'], dtype=out_dtype)
+
+     L0, L1, I, C = recipe['L0'], recipe['L1'], recipe['I'], recipe['C']
+
+     for l0 in range(L0):
+         for i in range(I):
+             A = input1[i * L1 * C : (i + 1) * L1 * C].reshape((L1, C))
+             B = input0[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C]
+             output[(i * L0 + l0) * L1 : (i * L0 + l0 + 1) * L1] = A @ B
+
+     return output.reshape(recipe['out_interpret_shape']).transpose(recipe['out_transpose_idxs'])
+
+
+ def _einsum(fn: str, input0, input1) -> np.ndarray:
+     """Execute an einsum operation on two input arrays.
+
+     WARNING: the order of multiplication is reversed (each product is formed as
+     input1-element * input0-element) -- watch out with non-commutative operands.
+
+     Parameters
+     ----------
+     fn : str
+         einsum string, e.g. 'ij,jk->ik'
+     input0 : np.ndarray
+         input0, the first input array
+     input1 : np.ndarray
+         input1, the second input array
+
+     Returns
+     -------
+     np.ndarray
+         output array
+     """
+     recipe = parse_einsum(fn, input0.shape, input1.shape)
+     return _exec_einsum(recipe, input0, input1)
+
+
+ @overload
+ def einsum(fn: str, input0: 'FixedVariableArray', input1: 'FixedVariableArray') -> 'FixedVariableArray': ...
+
+
+ @overload
+ def einsum(fn: str, input0: 'FixedVariableArray', input1: NDArray[np.integer | np.floating]) -> 'FixedVariableArray': ...
+
+
+ @overload
+ def einsum(fn: str, input0: NDArray[np.integer | np.floating], input1: 'FixedVariableArray') -> 'FixedVariableArray': ...
+
+
+ @overload
+ def einsum(
+     fn: str, input0: NDArray[np.integer | np.floating], input1: NDArray[np.integer | np.floating]
+ ) -> NDArray[np.integer | np.floating]: ...
+
+
+ def einsum(fn: str, input0, input1):
+     from ..fixed_variable_array import FixedVariableArray
+
+     fg0 = isinstance(input0, FixedVariableArray)
+     fg1 = isinstance(input1, FixedVariableArray)
+
+     r = _einsum(fn, input0, input1)
+
+     if fg0:
+         return FixedVariableArray(r, input0.solver_options)
+     elif fg1:
+         return FixedVariableArray(r, input1.solver_options)
+     else:
+         return r
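
The two entry points above can be sanity-checked directly with NumPy inputs. A minimal sketch, assuming the installed wheel exposes the file da4ml/trace/ops/einsum_utils.py from the list above as an importable module (the exact fresh subscript substituted for '...' is an implementation detail):

    import numpy as np
    from da4ml.trace.ops.einsum_utils import _validate_einsum_expr, einsum

    # '...' broadcasting is rewritten into explicit fresh subscripts
    expr, out_shape = _validate_einsum_expr('...ij,jk->...ik', (2, 3, 4), (4, 5))
    print(expr, out_shape)  # e.g. ('Aij,jk->Aik', (2, 3, 5))

    # On plain float arrays the result matches np.einsum; the reversed
    # multiplication order is only observable with non-commutative operands.
    a, b = np.random.rand(3, 4), np.random.rand(4, 5)
    assert np.allclose(einsum('ij,jk->ik', a, b), np.einsum('ij,jk->ik', a, b))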
da4ml/trace/ops/quantization.py
@@ -0,0 +1,74 @@
+ from typing import TYPE_CHECKING, TypeVar
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from quantizers.fixed_point.fixed_point_ops_np import get_fixed_quantizer_np
+
+ from ..fixed_variable_array import FixedVariable
+
+ if TYPE_CHECKING:
+     from ..fixed_variable_array import FixedVariableArray
+
+ T = TypeVar('T', 'FixedVariableArray', NDArray[np.floating], list[FixedVariable])
+
+
+ def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN') -> T:
+     from ..fixed_variable_array import FixedVariableArray
+
+     if isinstance(x, FixedVariableArray):
+         return x.relu(i=i, f=f, round_mode=round_mode)
+     elif isinstance(x, list):
+         # i and f are expected to be per-element arrays here
+         return [xx.relu(i=ii, f=ff, round_mode=round_mode) for xx, ii, ff in zip(x, i, f)]  # type: ignore
+     else:
+         round_mode = round_mode.upper()
+         assert round_mode in ('TRN', 'RND')
+         x = np.maximum(x, 0)
+         if f is not None:
+             if round_mode == 'RND':
+                 x += 2.0 ** (-f - 1)  # add half an LSB, then truncate == round-to-nearest
+             sf = 2.0**f
+             x = np.floor(x * sf) / sf
+         if i is not None:
+             x = x % 2.0**i  # wrap to i integer bits
+         return x
+
+
+ def _quantize(
+     x: NDArray[np.floating],
+     k: NDArray[np.integer] | np.integer | int,
+     i: NDArray[np.integer] | np.integer | int,
+     f: NDArray[np.integer] | np.integer | int,
+     overflow_mode: str = 'WRAP',
+     round_mode: str = 'TRN',
+ ) -> NDArray[np.floating]:
+     q = get_fixed_quantizer_np(round_mode=round_mode, overflow_mode=overflow_mode)
+     return np.where(k + i + f <= 0, 0, q(x, k=k, i=i, f=f))  # type: ignore
+
+
+ def quantize(
+     x: T,
+     k: NDArray[np.integer] | np.integer | int,
+     i: NDArray[np.integer] | np.integer | int,
+     f: NDArray[np.integer] | np.integer | int,
+     overflow_mode: str = 'WRAP',
+     round_mode: str = 'TRN',
+ ) -> T:
+     from ..fixed_variable_array import FixedVariableArray
+
+     if isinstance(x, (FixedVariableArray, FixedVariable)):
+         return x.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode)
+     elif isinstance(x, list):
+         ret: list[FixedVariable] = []
+         for idx in range(len(x)):  # `idx`, not `i`: `i` is the integer-bits argument
+             ret.append(
+                 x[idx].quantize(
+                     k=int(k[idx] if isinstance(k, (list, np.ndarray)) else k),
+                     i=int(i[idx] if isinstance(i, (list, np.ndarray)) else i),
+                     f=int(f[idx] if isinstance(f, (list, np.ndarray)) else f),
+                     overflow_mode=overflow_mode,
+                     round_mode=round_mode,
+                 )
+             )
+         return ret  # type: ignore
+     else:
+         return _quantize(x, k, i, f, overflow_mode, round_mode)
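
For the NumPy fall-through branch of relu, a small sketch (same import-path assumption as above): values are clamped at zero, truncated or rounded to f fractional bits, and wrapped modulo 2**i when i is given:

    import numpy as np
    from da4ml.trace.ops.quantization import relu

    x = np.array([-1.3, 0.4, 2.7])
    print(relu(x, f=np.array(1)))                    # [0.  0.  2.5]  truncate to 1 fractional bit
    print(relu(x, f=np.array(1), round_mode='RND'))  # [0.  0.5 2.5]  round to nearest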
da4ml/trace/ops/reduce_utils.py
@@ -0,0 +1,105 @@
+ import heapq
+ import typing
+ from collections.abc import Callable, Sequence
+ from math import prod
+ from typing import TypeVar
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ if typing.TYPE_CHECKING:
+     from ..fixed_variable import FixedVariable
+     from ..fixed_variable_array import FixedVariableArray
+
+
+ T = typing.TypeVar('T', 'FixedVariable', float, np.floating)
+ TA = TypeVar('TA', 'FixedVariableArray', NDArray[np.integer | np.floating])
+
+
+ class Packet:
+     """Heap entry: orders FixedVariables by latency, then sign of _factor, then width."""
+
+     def __init__(self, v):
+         self.value = v
+
+     def __gt__(self, other: 'Packet') -> bool:  # type: ignore
+         from ..fixed_variable_array import FixedVariable
+
+         a, b = self.value, other.value
+
+         if isinstance(a, FixedVariable):
+             if isinstance(b, FixedVariable):
+                 if b.latency > a.latency:
+                     return False
+                 if b.latency < a.latency:
+                     return True
+                 if b._factor > 0 and a._factor < 0:
+                     return False
+                 if b._factor < 0 and a._factor > 0:
+                     return True
+                 return sum(a.kif[:2]) > sum(b.kif[:2])
+             return True
+
+         return False
+
+     def __lt__(self, other: 'Packet') -> bool:  # type: ignore
+         # Not a strict order on ties, but sufficient for heapq
+         return not self.__gt__(other)
+
+
+ def _reduce(operator: Callable[[T, T], T], arr: Sequence[T]) -> T:
+     from ..fixed_variable_array import FixedVariable
+
+     if isinstance(arr, np.ndarray):
+         arr = list(arr.ravel())
+     assert len(arr) > 0, 'Array must not be empty'
+     if len(arr) == 1:
+         return arr[0]
+     dtype = arr[0].__class__
+     if not issubclass(dtype, FixedVariable):
+         # Plain numbers: a simple left fold
+         r = operator(arr[0], arr[1])
+         for i in range(2, len(arr)):
+             r = operator(r, arr[i])
+         return r
+
+     # FixedVariables: always combine the two cheapest operands first (Huffman-style)
+     heap = [Packet(v) for v in arr]  # type: ignore
+     heapq.heapify(heap)
+     while len(heap) > 1:
+         v1 = heapq.heappop(heap).value
+         v2 = heapq.heappop(heap).value
+         v = operator(v1, v2)
+         heapq.heappush(heap, Packet(v))  # type: ignore
+     return heap[0].value
+
+
+ def reduce(operator: Callable[[T, T], T], x: TA, axis: int | Sequence[int] | None = None, keepdims: bool = False) -> TA:
+     """
+     Reduce the array with the given binary operator over the specified axis or axes.
+     """
+     from ..fixed_variable_array import FixedVariableArray
+
+     if isinstance(x, FixedVariableArray):
+         solver_config = x.solver_options
+         arr = x._vars
+     else:
+         solver_config = None
+         arr = x
+     all_axis = tuple(range(arr.ndim))
+     axis = axis if axis is not None else all_axis
+     axis = (axis,) if isinstance(axis, int) else tuple(axis)
+     axis = tuple(a if a >= 0 else a + arr.ndim for a in axis)
+
+     # Move the reduced axes to the end, preserving the order of the remaining axes
+     xpose_axis = sorted(all_axis, key=lambda a: (a in axis) * 1000 + a)
+     if keepdims:
+         target_shape = tuple(d if ax not in axis else 1 for ax, d in enumerate(arr.shape))
+     else:
+         target_shape = tuple(d for ax, d in enumerate(arr.shape) if ax not in axis)
+
+     dim_contract = prod(arr.shape[a] for a in axis)
+     arr = np.transpose(arr, xpose_axis)  # type: ignore
+     _arr = arr.reshape(-1, dim_contract)
+     _arr = np.array([_reduce(operator, _arr[i]) for i in range(_arr.shape[0])])
+     r = _arr.reshape(target_shape)  # type: ignore
+
+     if isinstance(x, FixedVariableArray):
+         r = FixedVariableArray(r, solver_config)
+         if r.shape == ():
+             return r._vars.item()  # type: ignore
+     return r if r.shape != () or keepdims else r.item()  # type: ignore
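
On NumPy inputs reduce is an ordinary fold over the selected axes; on a FixedVariableArray the same call dispatches to the heap-based _reduce and builds a latency-balanced tree instead. A NumPy-only sketch (same import-path assumption as above):

    import numpy as np
    from operator import add
    from da4ml.trace.ops.reduce_utils import reduce

    x = np.arange(12, dtype=np.float64).reshape(3, 4)
    print(reduce(add, x, axis=1))                 # [ 6. 22. 38.]
    print(reduce(add, x, axis=1, keepdims=True))  # shape (3, 1): [[ 6.], [22.], [38.]]
    print(reduce(add, x))                         # 66.0, a scalar when all axes are reduced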
da4ml/trace/pipeline.py
@@ -0,0 +1,181 @@
+ from math import ceil, floor
+
+ from ..cmvm.types import CombLogic, Op, Pipeline
+ from .fixed_variable import FixedVariable, HWConfig
+ from .tracer import comb_trace
+
+
+ def retime_pipeline(csol: Pipeline, verbose=True):
+     # Binary-search the smallest latency cutoff that still fits in the same number of stages
+     n_stages = len(csol[0])
+     cutoff_high = ceil(max(max(sol.out_latency) / (i + 1) for i, sol in enumerate(csol[0])))
+     cutoff_low = 0
+     adder_size, carry_size = csol[0][0].adder_size, csol[0][0].carry_size
+     best = csol
+     while cutoff_high - cutoff_low > 1:
+         cutoff = (cutoff_high + cutoff_low) // 2
+         _hwconf = HWConfig(adder_size, carry_size, cutoff)
+         inp = [FixedVariable(*qint, hwconf=_hwconf) for qint in csol.inp_qint]
+         try:
+             out = list(csol(inp))
+         except AssertionError:
+             cutoff_low = cutoff
+             continue
+         _sol = to_pipeline(comb_trace(inp, out), cutoff, retiming=False)
+         if len(_sol[0]) > n_stages:
+             cutoff_low = cutoff
+         else:
+             cutoff_high = cutoff
+             best = _sol
+     if verbose:
+         print(f'actual cutoff: {cutoff_high}')
+     return best
+
+
+ def _get_new_idx(
+     idx: int,
+     locator: list[dict[int, int]],
+     opd: dict[int, list[Op]],
+     out_idxd: dict[int, list[int]],
+     ops: list[Op],
+     stage: int,
+     latency_cutoff: float,
+ ):
+     if idx < 0:
+         return idx
+     p0_stages = locator[idx].keys()
+     if stage not in p0_stages:
+         # Need to copy the parent forward to later stages
+         p0_stage = max(p0_stages)
+         p0_idx = locator[idx][p0_stage]
+         for j in range(p0_stage, stage):
+             op0 = ops[idx]
+             latency = float(latency_cutoff * (j + 1))
+             out_idxd.setdefault(j, []).append(locator[idx][j])
+             _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op0.qint, latency, 0.0)
+             opd.setdefault(j + 1, []).append(_copy_op)
+             p0_idx = len(opd[j + 1]) - 1
+             locator[idx][j + 1] = p0_idx
+     else:
+         p0_idx = locator[idx][stage]
+     return p0_idx
+
+
+ def to_pipeline(comb: CombLogic, latency_cutoff: float, retiming=True, verbose=True) -> Pipeline:
+     """Split the combinational logic into multiple stages based on the latency of the operations.
+     Only useful for HDL generation.
+
+     Parameters
+     ----------
+     comb : CombLogic
+         The combinational logic to be pipelined into multiple stages.
+     latency_cutoff : float
+         The latency cutoff for splitting the operations.
+     retiming : bool
+         Whether to retime the solution after splitting. Default is True.
+         If False, new stages are created when the propagation latency exceeds the cutoff.
+         If True, after the first round of splitting, the solution is retimed to balance the delay within each stage.
+     verbose : bool
+         Whether to print the actual cutoff used for splitting. Only used if retiming is True.
+         Default is True.
+
+     Returns
+     -------
+     Pipeline
+         The pipelined solution with multiple stages.
+     """
+     assert len(comb.ops) > 0, 'No operations in the record'
+     for i, op in enumerate(comb.ops):
+         if op.id1 != -1:
+             break
+
+     def get_stage(op: Op):
+         return floor(op.latency / (latency_cutoff + 1e-9)) if latency_cutoff > 0 else 0
+
+     opd: dict[int, list[Op]] = {}
+     out_idxd: dict[int, list[int]] = {}
+
+     locator: list[dict[int, int]] = []
+
+     ops = comb.ops.copy()
+     lat = max(ops[i].latency for i in comb.out_idxs)
+     # Append sentinel ops (id1 == -1001) marking the external outputs
+     for i in comb.out_idxs:
+         op_out = ops[i]
+         ops.append(Op(i, -1001, -1001, 0, op_out.qint, lat, 0.0))
+
+     for i, op in enumerate(ops):
+         stage = get_stage(op)
+         if op.opcode == -1:
+             # Copy from external buffer
+             opd.setdefault(stage, []).append(op)
+             locator.append({stage: len(opd[stage]) - 1})
+             continue
+
+         p0_idx = _get_new_idx(op.id0, locator, opd, out_idxd, ops, stage, latency_cutoff)
+         p1_idx = _get_new_idx(op.id1, locator, opd, out_idxd, ops, stage, latency_cutoff)
+         if op.opcode in (6, -6):
+             # data packs (shift << 32 | operand index); remap the index part
+             k = op.data & 0xFFFFFFFF
+             _shift = (op.data >> 32) & 0xFFFFFFFF
+             k = _get_new_idx(k, locator, opd, out_idxd, ops, stage, latency_cutoff)
+             data = _shift << 32 | k
+         else:
+             data = op.data
+
+         if p1_idx == -1001:
+             # Output to external buffer
+             out_idxd.setdefault(stage, []).append(p0_idx)
+         else:
+             _Op = Op(p0_idx, p1_idx, op.opcode, data, op.qint, op.latency, op.cost)
+             opd.setdefault(stage, []).append(_Op)
+             locator.append({stage: len(opd[stage]) - 1})
+
+     sols = []
+     max_stage = max(opd.keys())
+     n_in = comb.shape[0]
+     for i, stage in enumerate(opd.keys()):
+         _ops = opd[stage]
+         _out_idx = out_idxd[stage]
+         n_out = len(_out_idx)
+
+         if stage == max_stage:
+             out_shifts = comb.out_shifts
+             out_negs = comb.out_negs
+         else:
+             out_shifts = [0] * len(_out_idx)
+             out_negs = [False] * len(_out_idx)
+
+         if comb.lookup_tables is not None:
+             _ops, lookup_tables = remap_table_idxs(comb, _ops)
+         else:
+             lookup_tables = None
+         _sol = CombLogic(
+             shape=(n_in, n_out),
+             inp_shifts=[0] * n_in,
+             out_idxs=_out_idx,
+             out_shifts=out_shifts,
+             out_negs=out_negs,
+             ops=_ops,
+             carry_size=comb.carry_size,
+             adder_size=comb.adder_size,
+             lookup_tables=lookup_tables,
+         )
+         sols.append(_sol)
+
+         n_in = n_out
+     csol = Pipeline(tuple(sols))
+
+     if retiming:
+         csol = retime_pipeline(csol, verbose=verbose)
+     return csol
+
+
+ def remap_table_idxs(comb: CombLogic, _ops):
+     assert comb.lookup_tables is not None
+     table_idxs = sorted({op.data for op in _ops if op.opcode == 8})
+     remap = {j: i for i, j in enumerate(table_idxs)}
+     _ops_remap = []
+     for op in _ops:
+         if op.opcode == 8:
+             op = Op(op.id0, op.id1, op.opcode, remap[op.data], op.qint, op.latency, op.cost)
+         _ops_remap.append(op)
+     _ops = _ops_remap
+     lookup_tables = tuple(comb.lookup_tables[i] for i in table_idxs)
+     return _ops, lookup_tables
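
The stage assignment in to_pipeline is a plain bucketing of every op by its accumulated latency; the 1e-9 term keeps an op that lands exactly on a cutoff boundary in the earlier stage. A self-contained sketch of just that rule:

    from math import floor

    def stage_of(latency: float, latency_cutoff: float) -> int:
        # Mirrors get_stage() inside to_pipeline
        return floor(latency / (latency_cutoff + 1e-9)) if latency_cutoff > 0 else 0

    assert [stage_of(t, 5.0) for t in (0.0, 4.9, 5.0, 12.3)] == [0, 0, 0, 2]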