da4ml 0.5.0__cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-312-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +194 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +240 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.0.dist-info/METADATA +85 -0
  92. da4ml-0.5.0.dist-info/RECORD +96 -0
  93. da4ml-0.5.0.dist-info/WHEEL +6 -0
  94. da4ml-0.5.0.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.0.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/trace/fixed_variable_array.py
@@ -0,0 +1,600 @@
+ from collections.abc import Callable
+ from inspect import signature
+ from typing import TypeVar
+
+ import numpy as np
+ from numba.typed import List as NumbaList
+ from numpy.typing import NDArray
+
+ from ..cmvm.api import solve, solver_options_t
+ from .fixed_variable import FixedVariable, FixedVariableInput, HWConfig, LookupTable, QInterval
+ from .ops import _quantize, einsum, reduce
+
+ T = TypeVar('T')
+
+
+ def to_raw_arr(obj: T) -> T:
+     if isinstance(obj, tuple):
+         return tuple(to_raw_arr(x) for x in obj)  # type: ignore
+     elif isinstance(obj, list):
+         return [to_raw_arr(x) for x in obj]  # type: ignore
+     elif isinstance(obj, dict):
+         return {k: to_raw_arr(v) for k, v in obj.items()}  # type: ignore
+     if isinstance(obj, FixedVariableArray):
+         return obj._vars  # type: ignore
+     return obj
+
+
+ def _max_of(a, b):
+     if isinstance(a, FixedVariable):
+         return a.max_of(b)
+     elif isinstance(b, FixedVariable):
+         return b.max_of(a)
+     else:
+         return max(a, b)
+
+
+ def _min_of(a, b):
+     if isinstance(a, FixedVariable):
+         return a.min_of(b)
+     elif isinstance(b, FixedVariable):
+         return b.min_of(a)
+     else:
+         return min(a, b)
+
+
+ def mmm(mat0: np.ndarray, mat1: np.ndarray):
+     shape = mat0.shape[:-1] + mat1.shape[1:]
+     mat0, mat1 = mat0.reshape((-1, mat0.shape[-1])), mat1.reshape((mat1.shape[0], -1))
+     _shape = (mat0.shape[0], mat1.shape[1])
+     _vars = np.empty(_shape, dtype=object)
+     for i in range(mat0.shape[0]):
+         for j in range(mat1.shape[1]):
+             vec0 = mat0[i]
+             vec1 = mat1[:, j]
+             _vars[i, j] = reduce(lambda x, y: x + y, vec0 * vec1)
+     return _vars.reshape(shape)
+
+
+ def cmvm(cm: np.ndarray, v: 'FixedVariableArray', solver_options: solver_options_t) -> np.ndarray:
+     mask = offload_mask(cm, v)
+     if np.any(mask):
+         offload_cm = cm * mask.astype(cm.dtype)
+         cm = cm * (~mask).astype(cm.dtype)
+     else:
+         offload_cm = None
+     _qintervals = [QInterval(float(_v.low), float(_v.high), float(_v.step)) for _v in v._vars]
+     _latencies = [float(_v.latency) for _v in v._vars]
+     qintervals = NumbaList(_qintervals)  # type: ignore
+     latencies = NumbaList(_latencies)  # type: ignore
+     hwconf = v._vars.ravel()[0].hwconf
+     solver_options.setdefault('adder_size', hwconf.adder_size)
+     solver_options.setdefault('carry_size', hwconf.carry_size)
+     _mat = np.ascontiguousarray(cm.astype(np.float32))
+     sol = solve(_mat, qintervals=qintervals, latencies=latencies, **solver_options)
+     _r: np.ndarray = sol(v._vars)
+     if offload_cm is not None:
+         _r = _r + mmm(v._vars, offload_cm)
+     return _r
+
+
+ def offload_mask(cm: NDArray, v: 'FixedVariableArray') -> NDArray[np.bool_]:
+     assert v.ndim == 1
+     assert cm.ndim == 2
+     assert cm.shape[0] == v.shape[0]
+     bits = np.sum(v.kif, axis=0)[:, None]
+     return (bits == 0) & (cm != 0)
+
+
+ _unary_functions = (
+     np.sin,
+     np.cos,
+     np.tan,
+     np.exp,
+     np.log,
+     np.invert,
+     np.sqrt,
+     np.tanh,
+     np.sinh,
+     np.cosh,
+     np.arccos,
+     np.arcsin,
+     np.arctan,
+     np.arcsinh,
+     np.arccosh,
+     np.arctanh,
+     np.exp2,
+     np.expm1,
+     np.log2,
+     np.log10,
+     np.log1p,
+     np.cbrt,
+     np.reciprocal,
+ )
+
+
+ class FixedVariableArray:
+     """Symbolic array of FixedVariable for tracing operations. Supports numpy ufuncs and array functions."""
+
+     __array_priority__ = 100
+
+     def __array_function__(self, func, types, args, kwargs):
+         if func in (np.mean, np.sum, np.amax, np.amin, np.prod, np.max, np.min):
+             match func:
+                 case np.mean:
+                     # sum, then rescale: _x.size / self._vars.size equals 1 / (number of elements reduced over)
+                     _x = reduce(lambda x, y: x + y, *args, **kwargs)
+                     return _x * (_x.size / self._vars.size)
+                 case np.sum:
+                     return reduce(lambda x, y: x + y, *args, **kwargs)
+                 case np.max | np.amax:
+                     return reduce(_max_of, *args, **kwargs)
+                 case np.min | np.amin:
+                     return reduce(_min_of, *args, **kwargs)
+                 case np.prod:
+                     return reduce(lambda x, y: x * y, *args, **kwargs)
+                 case _:
+                     raise NotImplementedError(f'Unsupported function: {func}')
+
+         if func is np.clip:
+             assert len(args) == 3, 'Clip function requires exactly three arguments'
+             x, low, high = args
+             _x, low, high = np.broadcast_arrays(x, low, high)
+             x = FixedVariableArray(_x, self.solver_options)
+             x = np.amax(np.stack((x, low), axis=-1), axis=-1)  # type: ignore
+             return np.amin(np.stack((x, high), axis=-1), axis=-1)
+
+         if func is np.einsum:
+             # assert len(args) == 2
+             sig = signature(np.einsum)
+             bind = sig.bind(*args, **kwargs)
+             eq = args[0]
+             operands = bind.arguments['operands']
+             if isinstance(operands[0], str):
+                 operands = operands[1:]
+             assert len(operands) == 2, 'Einsum on FixedVariableArray requires exactly two operands'
+             assert bind.arguments.get('out', None) is None, 'Output argument is not supported'
+             return einsum(eq, *operands)
+
+         if func is np.dot:
+             assert len(args) == 2, 'np.dot on FixedVariableArray requires exactly two arguments (the out argument is not supported)'
+             a, b = args
+             if not isinstance(a, FixedVariableArray):
+                 a = np.array(a)
+             if not isinstance(b, FixedVariableArray):
+                 b = np.array(b)
+             if a.shape[-1] == b.shape[0]:
+                 return a @ b
+
+             assert a.size == 1 or b.size == 1, f'Error in dot product: {a.shape} @ {b.shape}'
+             return a * b
+
+         args, kwargs = to_raw_arr(args), to_raw_arr(kwargs)
+         return FixedVariableArray(
+             func(*args, **kwargs),
+             self.solver_options,
+         )
+
+     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+         assert method == '__call__', f'Only __call__ method is supported for ufuncs, got {method}'
+
+         match ufunc:
+             # np.negative is unary and is handled by the same generic call path
+             case np.add | np.subtract | np.multiply | np.true_divide | np.negative:
+                 inputs = [to_raw_arr(x) for x in inputs]
+                 return FixedVariableArray(ufunc(*inputs, **kwargs), self.solver_options)
+
+             case np.maximum | np.minimum:
+                 op = _max_of if ufunc is np.maximum else _min_of
+                 a, b = np.broadcast_arrays(inputs[0], inputs[1])
+                 shape = a.shape
+                 a, b = a.ravel(), b.ravel()
+                 r = np.empty(a.size, dtype=object)
+                 for i in range(a.size):
+                     r[i] = op(a[i], b[i])
+                 return FixedVariableArray(r.reshape(shape), self.solver_options)
+
+             case np.matmul:
+                 assert len(inputs) == 2
+                 assert isinstance(inputs[0], FixedVariableArray) or isinstance(inputs[1], FixedVariableArray)
+                 if isinstance(inputs[0], FixedVariableArray):
+                     return inputs[0].matmul(inputs[1])
+                 else:
+                     return inputs[1].rmatmul(inputs[0])
+
+             case np.power:
+                 assert len(inputs) == 2
+                 base, exp = inputs
+                 return base**exp
+
+             case np.abs | np.absolute:
+                 assert len(inputs) == 1
+                 assert inputs[0] is self
+                 arr = self._vars.ravel()
+                 r = np.array([v.__abs__() for v in arr])
+                 return FixedVariableArray(r.reshape(self.shape), self.solver_options)
+
+             case np.square:
+                 assert len(inputs) == 1
+                 assert inputs[0] is self
+                 return self**2
+
+         if ufunc in _unary_functions:
+             assert len(inputs) == 1
+             assert inputs[0] is self
+             return self.apply(ufunc)
+
+         raise NotImplementedError(f'Unsupported ufunc: {ufunc}')
+
+     def __init__(
+         self,
+         vars: NDArray,
+         solver_options: solver_options_t | None = None,
+     ):
+         _vars = np.array(vars)
+         _vars_f = _vars.ravel()
+         hwconf = next(iter(v for v in _vars_f if isinstance(v, FixedVariable))).hwconf
+         for i, v in enumerate(_vars_f):
+             if not isinstance(v, FixedVariable):
+                 _vars_f[i] = FixedVariable(float(v), float(v), 1.0, hwconf=hwconf)
+         self._vars = _vars
+         _solver_options = signature(solve).parameters
+         _solver_options = {k: v.default for k, v in _solver_options.items() if v.default is not v.empty}
+         if solver_options is not None:
+             _solver_options.update(solver_options)
+         _solver_options.pop('qintervals', None)
+         _solver_options.pop('latencies', None)
+         self.solver_options: solver_options_t = _solver_options  # type: ignore
+
+     @classmethod
+     def from_lhs(
+         cls,
+         low: NDArray[np.floating],
+         high: NDArray[np.floating],
+         step: NDArray[np.floating],
+         hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
+         latency: np.ndarray | float = 0.0,
+         solver_options: solver_options_t | None = None,
+     ):
+         low, high, step = np.array(low), np.array(high), np.array(step)
+         shape = low.shape
+         assert shape == high.shape == step.shape
+
+         low, high, step = low.ravel(), high.ravel(), step.ravel()
+         latency = np.full_like(low, latency) if isinstance(latency, (int, float)) else latency.ravel()
+
+         vars = []
+         for l, h, s, lat in zip(low, high, step, latency):
+             var = FixedVariable(
+                 low=float(l),
+                 high=float(h),
+                 step=float(s),
+                 hwconf=hwconf,
+                 latency=float(lat),
+             )
+             vars.append(var)
+         vars = np.array(vars).reshape(shape)
+         return cls(vars, solver_options)
+
+     @classmethod
+     def from_kif(
+         cls,
+         k: NDArray[np.bool_ | np.integer],
+         i: NDArray[np.integer],
+         f: NDArray[np.integer],
+         hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
+         latency: NDArray[np.floating] | float = 0.0,
+         solver_options: solver_options_t | None = None,
+     ):
+         mask = k + i + f <= 0
+         k = np.where(mask, 0, k)
+         i = np.where(mask, 0, i)
+         f = np.where(mask, 0, f)
+         step = 2.0**-f
+         _high = 2.0**i
+         high, low = _high - step, -_high * k
+         return cls.from_lhs(low, high, step, hwconf, latency, solver_options)
+
+     def matmul(self, other) -> 'FixedVariableArray':
+         if self.collapsed:
+             self_mat = np.array([v.low for v in self._vars.ravel()], dtype=np.float64).reshape(self._vars.shape)
+             if isinstance(other, FixedVariableArray):
+                 if not other.collapsed:
+                     return self_mat @ other  # type: ignore
+                 other_mat = np.array([v.low for v in other._vars.ravel()], dtype=np.float64).reshape(other._vars.shape)
+             else:
+                 other_mat = np.array(other, dtype=np.float64)
+
+             r = self_mat @ other_mat
+             return FixedVariableArray.from_lhs(
+                 low=r,
+                 high=r,
+                 step=np.ones_like(r),
+                 hwconf=self._vars.ravel()[0].hwconf,
+                 solver_options=self.solver_options,
+             )
+
+         if isinstance(other, FixedVariableArray):
+             other = other._vars
+         if not isinstance(other, np.ndarray):
+             other = np.array(other)
+         if any(isinstance(x, FixedVariable) for x in other.ravel()):
+             mat0, mat1 = self._vars, other
+             _vars = mmm(mat0, mat1)
+             return FixedVariableArray(_vars, self.solver_options)
+
+         solver_options = (self.solver_options or {}).copy()
+         shape0, shape1 = self.shape, other.shape
+         assert shape0[-1] == shape1[0], f'Matrix shapes do not match: {shape0} @ {shape1}'
+         contract_len = shape1[0]
+         out_shape = shape0[:-1] + shape1[1:]
+         mat0, mat1 = self.reshape((-1, contract_len)), other.reshape((contract_len, -1))
+         r = []
+         for i in range(mat0.shape[0]):
+             vec = mat0[i]
+             _r = cmvm(mat1, vec, solver_options)
+             r.append(_r)
+         r = np.array(r).reshape(out_shape)
+         return FixedVariableArray(r, self.solver_options)
+
+     def __matmul__(self, other):
+         return self.matmul(other)
+
+     def rmatmul(self, other):
+         mat1 = np.moveaxis(other, -1, 0)
+         mat0 = np.moveaxis(self, 0, -1)  # type: ignore
+         ndim0, ndim1 = mat0.ndim, mat1.ndim
+         r = mat0 @ mat1
+
+         _axes = tuple(range(0, ndim0 + ndim1 - 2))
+         axes = _axes[ndim0 - 1 :] + _axes[: ndim0 - 1]
+         return r.transpose(axes)
+
+     def __rmatmul__(self, other):
+         return self.rmatmul(other)
+
+     def __getitem__(self, item):
+         vars = self._vars[item]
+         if isinstance(vars, np.ndarray):
+             return FixedVariableArray(vars, self.solver_options)
+         else:
+             return vars
+
+     def __len__(self):
+         return len(self._vars)
+
+     @property
+     def shape(self):
+         return self._vars.shape
+
+     def __add__(self, other):
+         if isinstance(other, FixedVariableArray):
+             return FixedVariableArray(self._vars + other._vars, self.solver_options)
+         return FixedVariableArray(self._vars + other, self.solver_options)
+
+     def __sub__(self, other):
+         if isinstance(other, FixedVariableArray):
+             return FixedVariableArray(self._vars - other._vars, self.solver_options)
+         return FixedVariableArray(self._vars - other, self.solver_options)
+
+     def __mul__(self, other):
+         if isinstance(other, FixedVariableArray):
+             return FixedVariableArray(self._vars * other._vars, self.solver_options)
+         return FixedVariableArray(self._vars * other, self.solver_options)
+
+     def __truediv__(self, other):
+         return FixedVariableArray(self._vars * (1 / other), self.solver_options)
+
+     def __radd__(self, other):
+         return self + other
+
+     def __neg__(self):
+         return FixedVariableArray(-self._vars, self.solver_options)
+
+     def __repr__(self):
+         shape = self._vars.shape
+         hwconf_str = str(self._vars.ravel()[0].hwconf)[8:]
+         max_lat = max(v.latency for v in self._vars.ravel())
+         return f'FixedVariableArray(shape={shape}, hwconf={hwconf_str}, latency={max_lat})'
+
+     def __pow__(self, power: int | float):
+         _power = int(power)
+         if _power == power and _power >= 0:
+             return FixedVariableArray(self._vars**_power, self.solver_options)
+         else:
+             return self.apply(lambda x: x**power)
+
+     def relu(
+         self,
+         i: NDArray[np.integer] | None = None,
+         f: NDArray[np.integer] | None = None,
+         round_mode: str = 'TRN',
+     ):
+         shape = self._vars.shape
+         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
+         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
+         ret = []
+         for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()):  # type: ignore
+             ret.append(v.relu(i=i, f=f, round_mode=round_mode))
+         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)
+
+     def quantize(
+         self,
+         k: NDArray[np.integer] | np.integer | int | None = None,
+         i: NDArray[np.integer] | np.integer | int | None = None,
+         f: NDArray[np.integer] | np.integer | int | None = None,
+         overflow_mode: str = 'WRAP',
+         round_mode: str = 'TRN',
+     ):
+         shape = self._vars.shape
+         if any(x is None for x in (k, i, f)):
+             kif = self.kif
+             k = np.broadcast_to(k, shape) if k is not None else kif[0]  # type: ignore
+             i = np.broadcast_to(i, shape) if i is not None else kif[1]  # type: ignore
+             f = np.broadcast_to(f, shape) if f is not None else kif[2]  # type: ignore
+         else:
+             # broadcast here as well so scalar k/i/f (allowed by the signature) do not crash on .ravel() below
+             k, i, f = (np.broadcast_to(x, shape) for x in (k, i, f))
+         ret = []
+         for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):  # type: ignore
+             ret.append(v.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode))
+         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)
+
+     def flatten(self):
+         return FixedVariableArray(self._vars.flatten(), self.solver_options)
+
+     def reshape(self, *shape):
+         return FixedVariableArray(self._vars.reshape(*shape), self.solver_options)
+
+     def transpose(self, axes=None):
+         return FixedVariableArray(self._vars.transpose(axes), self.solver_options)
+
+     def ravel(self):
+         return FixedVariableArray(self._vars.ravel(), self.solver_options)
+
+     @property
+     def dtype(self):
+         return self._vars.dtype
+
+     @property
+     def size(self):
+         return self._vars.size
+
+     @property
+     def ndim(self):
+         return self._vars.ndim
+
+     @property
+     def kif(self):
+         """[k, i, f] array"""
+         shape = self._vars.shape
+         kif = np.array([v.kif for v in self._vars.ravel()]).reshape(*shape, 3)
+         return np.moveaxis(kif, -1, 0)
+
+     @property
+     def lhs(self):
+         """[low, high, step] array"""
+         shape = self._vars.shape
+         lhs = np.array([(v.low, v.high, v.step) for v in self._vars.ravel()], dtype=np.float32).reshape(*shape, 3)
+         return np.moveaxis(lhs, -1, 0)
+
+     @property
+     def latency(self):
+ """Maximum latency among all elements."""
+         return np.array([v.latency for v in self._vars.ravel()]).reshape(self._vars.shape)
+
+     @property
+     def collapsed(self):
+         return all(v.low == v.high for v in self._vars.ravel())
+
+     def apply(self, fn: Callable[[NDArray], NDArray]) -> 'RetardedFixedVariableArray':
+         """Apply a unary operator to all elements, returning a RetardedFixedVariableArray."""
+         return RetardedFixedVariableArray(
+             self._vars,
+             self.solver_options,
+             operator=fn,
+         )
+
+     @property
+     def T(self):
+         return self.transpose()
+
+
+ class FixedVariableArrayInput(FixedVariableArray):
+     """Similar to FixedVariableArray, but initializes all elements as FixedVariableInput - the precisions are unspecified when initialized, and the highest precision requested (i.e., quantized to) will be recorded for generation of the logic."""
+
+     def __init__(
+         self,
+         shape: tuple[int, ...] | int,
+         hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
+         solver_options: solver_options_t | None = None,
+         latency=0.0,
+     ):
+         _vars = np.empty(shape, dtype=object)
+         _vars_f = _vars.ravel()
+         for i in range(_vars.size):
+             _vars_f[i] = FixedVariableInput(latency, hwconf)
+         super().__init__(_vars, solver_options)
+
+
+ def make_table(fn: Callable[[NDArray], NDArray], qint: QInterval) -> LookupTable:
+     low, high, step = qint
+     n = round(abs(high - low) / step) + 1
+     return LookupTable(fn(np.linspace(low, high, n)))
+
+
+ class RetardedFixedVariableArray(FixedVariableArray):
+     """Ephemeral FixedVariableArray generated from operations of unspecified output precision.
+     This object translates to a normal FixedVariableArray upon quantization.
+     Unlike FixedVariableArrayInput, it does not record the maximum precision requested.
+
+     This object can be used in two ways:
+     1. Quantize with a specified precision, which converts it to a FixedVariableArray.
+     2. Apply a further unary operation, which returns another RetardedFixedVariableArray (e.g., for composite functions).
+     """
+
+     def __init__(self, vars: NDArray, solver_options: solver_options_t | None, operator: Callable[[NDArray], NDArray]):
+         self._operator = operator
+         super().__init__(vars, solver_options)
+
+     def __array_function__(self, func, types, args, kwargs):
+         raise RuntimeError('RetardedFixedVariableArray only supports quantization or further unary operations.')
+
+     def apply(self, fn: Callable[[NDArray], NDArray]) -> 'RetardedFixedVariableArray':
+         return RetardedFixedVariableArray(
+             self._vars,
+             self.solver_options,
+             operator=lambda x: fn(self._operator(x)),
+         )
+
+     def quantize(
+         self,
+         k: NDArray[np.integer] | np.integer | int | None = None,
+         i: NDArray[np.integer] | np.integer | int | None = None,
+         f: NDArray[np.integer] | np.integer | int | None = None,
+         overflow_mode: str = 'WRAP',
+         round_mode: str = 'TRN',
+     ):
+         if any(x is None for x in (k, i, f)):
+             assert all(x is None for x in (k, i, f)), 'Either all or none of k, i, f must be specified'
+             _k = _i = _f = [None] * self.size
+         else:
+             _k = np.broadcast_to(k, self.shape).ravel()  # type: ignore
+             _i = np.broadcast_to(i, self.shape).ravel()  # type: ignore
+             _f = np.broadcast_to(f, self.shape).ravel()  # type: ignore
+
+         local_tables: dict[tuple[QInterval, tuple[int, int, int]] | QInterval, LookupTable] = {}
+         variables = []
+         for v, _kk, _ii, _ff in zip(self._vars.ravel(), _k, _i, _f):
+             v: FixedVariable
+             qint = v.qint if v._factor >= 0 else QInterval(v.qint.max, v.qint.min, v.qint.step)
+             if (_kk is None) or (_ii is None) or (_ff is None):
+                 op = self._operator
+                 _key = qint
+             else:
+                 # the lambda is consumed by make_table within this same iteration, so late binding of _kk/_ii/_ff is not an issue
+                 op = lambda x: _quantize(self._operator(x), _kk, _ii, _ff, overflow_mode, round_mode)  # type: ignore
+                 _key = (qint, (int(_kk), int(_ii), int(_ff)))
+
+             if _key in local_tables:
+                 table = local_tables[_key]
+             else:
+                 table = make_table(op, qint)
+                 local_tables[_key] = table
+             variables.append(v.lookup(table))
+
+         variables = np.array(variables).reshape(self._vars.shape)
+         return FixedVariableArray(variables, self.solver_options)
+
+     def __repr__(self):
+         return 'Retarded' + super().__repr__()
+
+     @property
+     def kif(self):
+         raise RuntimeError('RetardedFixedVariableArray does not have defined kif until quantized.')
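
For orientation, a minimal usage sketch of the tracing API this file defines, based only on the constructors and methods visible above. The input width, bit precisions, and weight values are illustrative assumptions, not values from the package:

import numpy as np

from da4ml.trace.fixed_variable_array import FixedVariableArrayInput

# Eight symbolic inputs; the precision is recorded when quantize() is called
# (see the FixedVariableArrayInput docstring above). k/i/f = sign/integer/fraction bits.
x = FixedVariableArrayInput(8)
xq = x.quantize(k=1, i=3, f=4)  # 8-bit signed fixed point, illustrative
w = np.random.randint(-8, 8, (8, 4)).astype(np.float64)  # illustrative constant weights
y = (xq @ w).relu()  # a constant matmul routes through cmvm() and the CMVM solver
print(y)  # FixedVariableArray(shape=(4,), hwconf=..., latency=...)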
da4ml/trace/ops/__init__.py
@@ -0,0 +1,13 @@
+ from .einsum_utils import einsum
+ from .quantization import _quantize, quantize, relu
+ from .reduce_utils import reduce
+
+ __all__ = [
+     'einsum',
+     'relu',
+     'quantize',
+     'reduce',
+     '_quantize',
+ ]
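
These re-exports are the primitives FixedVariableArray dispatches to: np.sum, np.max, and np.prod funnel into reduce, np.einsum into einsum, and the quantize/relu methods into _quantize and relu. A hedged sketch of that correspondence, with the shape and precision again as illustrative assumptions:

import numpy as np

from da4ml.trace.fixed_variable_array import FixedVariableArrayInput

x = FixedVariableArrayInput(4).quantize(k=1, i=2, f=2)  # illustrative 5-bit inputs
s = np.sum(x)   # __array_function__ -> reduce(lambda a, b: a + b, x)
m = np.max(x)   # __array_function__ -> reduce(_max_of, x)
p = np.prod(x)  # __array_function__ -> reduce(lambda a, b: a * b, x)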