da4ml 0.4.1__py3-none-any.whl → 0.5.0b0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.


Files changed (40)
  1. da4ml/__init__.py +2 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +2 -2
  4. da4ml/cmvm/api.py +15 -4
  5. da4ml/cmvm/core/__init__.py +2 -2
  6. da4ml/cmvm/types.py +32 -18
  7. da4ml/cmvm/util/bit_decompose.py +2 -2
  8. da4ml/codegen/hls/hls_codegen.py +10 -5
  9. da4ml/codegen/hls/hls_model.py +7 -4
  10. da4ml/codegen/rtl/common_source/build_binder.mk +6 -5
  11. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  12. da4ml/codegen/rtl/common_source/{build_prj.tcl → build_vivado_prj.tcl} +39 -18
  13. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  14. da4ml/codegen/rtl/common_source/template.xdc +11 -13
  15. da4ml/codegen/rtl/rtl_model.py +105 -54
  16. da4ml/codegen/rtl/verilog/__init__.py +2 -1
  17. da4ml/codegen/rtl/verilog/comb.py +47 -7
  18. da4ml/codegen/rtl/verilog/io_wrapper.py +4 -4
  19. da4ml/codegen/rtl/verilog/pipeline.py +12 -12
  20. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  21. da4ml/codegen/rtl/vhdl/comb.py +27 -21
  22. da4ml/codegen/rtl/vhdl/io_wrapper.py +11 -11
  23. da4ml/codegen/rtl/vhdl/pipeline.py +12 -12
  24. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  25. da4ml/converter/__init__.py +57 -1
  26. da4ml/converter/hgq2/parser.py +4 -25
  27. da4ml/converter/hgq2/replica.py +208 -22
  28. da4ml/trace/fixed_variable.py +239 -29
  29. da4ml/trace/fixed_variable_array.py +276 -48
  30. da4ml/trace/ops/__init__.py +31 -15
  31. da4ml/trace/ops/reduce_utils.py +3 -3
  32. da4ml/trace/pipeline.py +40 -18
  33. da4ml/trace/tracer.py +33 -8
  34. da4ml/typing/__init__.py +3 -0
  35. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/METADATA +2 -1
  36. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/RECORD +39 -35
  37. da4ml/codegen/rtl/vhdl/source/template.xdc +0 -32
  38. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/WHEEL +0 -0
  39. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/licenses/LICENSE +0 -0
  40. {da4ml-0.4.1.dist-info → da4ml-0.5.0b0.dist-info}/top_level.txt +0 -0
da4ml/trace/fixed_variable_array.py CHANGED
@@ -1,13 +1,15 @@
+from collections.abc import Callable
+from decimal import Decimal
 from inspect import signature
-from typing import Any, TypeVar
+from typing import TypeVar
 
 import numpy as np
 from numba.typed import List as NumbaList
 from numpy.typing import NDArray
 
-from ..cmvm import solve
-from .fixed_variable import FixedVariable, FixedVariableInput, HWConfig, QInterval
-from .ops import einsum, reduce
+from ..cmvm.api import solve, solver_options_t
+from .fixed_variable import FixedVariable, FixedVariableInput, HWConfig, LookupTable, QInterval
+from .ops import _quantize, einsum, reduce
 
 T = TypeVar('T')
 
@@ -42,7 +44,79 @@ def _min_of(a, b):
     return min(a, b)
 
 
+def mmm(mat0: np.ndarray, mat1: np.ndarray):
+    shape = mat0.shape[:-1] + mat1.shape[1:]
+    mat0, mat1 = mat0.reshape((-1, mat0.shape[-1])), mat1.reshape((mat1.shape[0], -1))
+    _shape = (mat0.shape[0], mat1.shape[1])
+    _vars = np.empty(_shape, dtype=object)
+    for i in range(mat0.shape[0]):
+        for j in range(mat1.shape[1]):
+            vec0 = mat0[i]
+            vec1 = mat1[:, j]
+            _vars[i, j] = reduce(lambda x, y: x + y, vec0 * vec1)
+    return _vars.reshape(shape)
+
+
+def cmvm(cm: np.ndarray, v: 'FixedVariableArray', solver_options: solver_options_t) -> np.ndarray:
+    mask = offload_mask(cm, v)
+    if np.any(mask):
+        offload_cm = cm * mask.astype(cm.dtype)
+        cm = cm * (~mask).astype(cm.dtype)
+    else:
+        offload_cm = None
+    _qintervals = [QInterval(float(_v.low), float(_v.high), float(_v.step)) for _v in v._vars]
+    _latencies = [float(_v.latency) for _v in v._vars]
+    qintervals = NumbaList(_qintervals)  # type: ignore
+    latencies = NumbaList(_latencies)  # type: ignore
+    hwconf = v._vars.ravel()[0].hwconf
+    solver_options.setdefault('adder_size', hwconf.adder_size)
+    solver_options.setdefault('carry_size', hwconf.carry_size)
+    _mat = np.ascontiguousarray(cm.astype(np.float32))
+    sol = solve(_mat, qintervals=qintervals, latencies=latencies, **solver_options)
+    _r: np.ndarray = sol(v._vars)
+    if offload_cm is not None:
+        _r = _r + mmm(v._vars, offload_cm)
+    return _r
+
+
+def offload_mask(cm: NDArray, v: 'FixedVariableArray') -> NDArray[np.bool_]:
+    assert v.ndim == 1
+    assert cm.ndim == 2
+    assert cm.shape[0] == v.shape[0]
+    bits = np.sum(v.kif, axis=0)[:, None]
+    return (bits == 0) & (cm != 0)
+
+
+_unary_functions = (
+    np.sin,
+    np.cos,
+    np.tan,
+    np.exp,
+    np.log,
+    np.invert,
+    np.sqrt,
+    np.tanh,
+    np.sinh,
+    np.cosh,
+    np.arccos,
+    np.arcsin,
+    np.arctan,
+    np.arcsinh,
+    np.arccosh,
+    np.arctanh,
+    np.exp2,
+    np.expm1,
+    np.log2,
+    np.log10,
+    np.log1p,
+    np.cbrt,
+    np.reciprocal,
+)
+
+
 class FixedVariableArray:
+    """Symbolic array of FixedVariable for tracing operations. Supports numpy ufuncs and array functions."""
+
     __array_priority__ = 100
 
     def __array_function__(self, func, types, args, kwargs):
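
Note: the new mmm helper above is a plain object-dtype matrix product built from pairwise reduction, which is what lets it contract arrays of symbolic FixedVariables. A standalone sketch of the same contraction on ordinary integer arrays (functools.reduce stands in for da4ml's reduce):

import numpy as np
from functools import reduce as _reduce

def mmm_sketch(mat0, mat1):
    # Flatten all batch axes, contract the last axis of mat0 with the
    # first axis of mat1 pair by pair, then restore the output shape.
    shape = mat0.shape[:-1] + mat1.shape[1:]
    mat0 = mat0.reshape((-1, mat0.shape[-1]))
    mat1 = mat1.reshape((mat1.shape[0], -1))
    out = np.empty((mat0.shape[0], mat1.shape[1]), dtype=object)
    for i in range(mat0.shape[0]):
        for j in range(mat1.shape[1]):
            out[i, j] = _reduce(lambda x, y: x + y, mat0[i] * mat1[:, j])
    return out.reshape(shape)

a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)
assert (mmm_sketch(a, b) == a @ b).all()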
@@ -52,17 +126,19 @@ class FixedVariableArray:
         elif len(args) == 2 and isinstance(args[0], np.ndarray) and isinstance(args[1], np.ndarray):
             return self.__rmatmul__(args[1])
 
-        if func in (np.mean, np.sum, np.amax, np.amin, np.max, np.min):
+        if func in (np.mean, np.sum, np.amax, np.amin, np.prod, np.max, np.min):
             match func:
                 case np.mean:
-                    _x = reduce(lambda x, y: x + y, self, *args[1:], **kwargs)
+                    _x = reduce(lambda x, y: x + y, *args, **kwargs)
                     return _x * (_x.size / self._vars.size)
                 case np.sum:
-                    return reduce(lambda x, y: x + y, self, *args[1:], **kwargs)
+                    return reduce(lambda x, y: x + y, *args, **kwargs)
                 case np.max | np.amax:
-                    return reduce(_max_of, self, *args[1:], **kwargs)
+                    return reduce(_max_of, *args, **kwargs)
                 case np.min | np.amin:
-                    return reduce(_min_of, self, *args[1:], **kwargs)
+                    return reduce(_min_of, *args, **kwargs)
+                case np.prod:
+                    return reduce(lambda x, y: x * y, *args, **kwargs)
                 case _:
                     raise NotImplementedError(f'Unsupported function: {func}')
 
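Note: np.prod joins the reductions intercepted here via NumPy's NEP 18 __array_function__ protocol. A toy receiver, not the da4ml class, showing the dispatch mechanics:

import numpy as np

class SumOnly:
    def __init__(self, values):
        self.values = list(values)

    def __array_function__(self, func, types, args, kwargs):
        # NumPy routes np.sum(obj) here instead of coercing obj to an ndarray.
        if func is np.sum:
            total = self.values[0]
            for v in self.values[1:]:
                total = total + v
            return total
        return NotImplemented

assert np.sum(SumOnly([1, 2, 3])) == 6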
@@ -86,7 +162,7 @@ class FixedVariableArray:
             assert bind.arguments.get('out', None) is None, 'Output argument is not supported'
             return einsum(eq, *operands)
 
-        if func in (np.dot, np.matmul):
+        if func is np.dot:
            assert len(args) in (2, 3), 'Dot function requires exactly two or three arguments'
 
            assert len(args) == 2
@@ -107,19 +183,85 @@ class FixedVariableArray:
                 self.solver_options,
             )
 
+    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+        assert method == '__call__', f'Only __call__ method is supported for ufuncs, got {method}'
+
+        match ufunc:
+            case np.add | np.subtract | np.multiply | np.true_divide | np.negative:
+                inputs = [to_raw_arr(x) for x in inputs]
+                return FixedVariableArray(ufunc(*inputs, **kwargs), self.solver_options)
+            case np.negative:
+                assert len(inputs) == 1
+                return FixedVariableArray(ufunc(to_raw_arr(inputs[0]), **kwargs), self.solver_options)
+            case np.maximum | np.minimum:
+                op = _max_of if ufunc is np.maximum else _min_of
+                a, b = np.broadcast_arrays(inputs[0], inputs[1])
+                shape = a.shape
+                a, b = a.ravel(), b.ravel()
+                r = np.empty(a.size, dtype=object)
+                for i in range(a.size):
+                    r[i] = op(a[i], b[i])
+                return FixedVariableArray(r.reshape(shape), self.solver_options)
+            case np.matmul:
+                assert len(inputs) == 2
+                assert isinstance(inputs[0], FixedVariableArray) or isinstance(inputs[1], FixedVariableArray)
+                if isinstance(inputs[0], FixedVariableArray):
+                    return inputs[0].matmul(inputs[1])
+                else:
+                    return inputs[1].rmatmul(inputs[0])
+            case np.power:
+                assert len(inputs) == 2
+                base, exp = inputs
+                return base**exp
+
+            case np.abs | np.absolute:
+                assert len(inputs) == 1
+                assert inputs[0] is self
+                mask: np.ndarray = (self.kif[0] == 0).ravel()
+                arr = self._vars.ravel()
+
+                r = np.empty(arr.size, dtype=object)
+                for i in range(arr.size):
+                    if mask[i]:
+                        r[i] = arr[i]
+                        continue
+                    v = arr[i]
+                    v = v.msb_mux(-v, v)
+                    v.low = Decimal(0)
+                    r[i] = v
+                return FixedVariableArray(r.reshape(self.shape), self.solver_options)
+
+            case np.square:
+                assert len(inputs) == 1
+                assert inputs[0] is self
+                return self**2
+
+        if ufunc in _unary_functions:
+            assert len(inputs) == 1
+            assert inputs[0] is self
+            return self.apply(ufunc)
+
+        raise NotImplementedError(f'Unsupported ufunc: {ufunc}')
+
     def __init__(
         self,
         vars: NDArray,
-        solver_options: dict[str, Any] | None = None,
+        solver_options: solver_options_t | None = None,
     ):
-        self._vars = np.array(vars)
+        _vars = np.array(vars)
+        _vars_f = _vars.ravel()
+        hwconf = next(iter(v for v in _vars_f if isinstance(v, FixedVariable))).hwconf
+        for i, v in enumerate(_vars_f):
+            if not isinstance(v, FixedVariable):
+                _vars_f[i] = FixedVariable(float(v), float(v), 1.0, hwconf=hwconf)
+        self._vars = _vars
         _solver_options = signature(solve).parameters
         _solver_options = {k: v.default for k, v in _solver_options.items() if v.default is not v.empty}
         if solver_options is not None:
             _solver_options.update(solver_options)
         _solver_options.pop('qintervals', None)
         _solver_options.pop('latencies', None)
-        self.solver_options = _solver_options
+        self.solver_options: solver_options_t = _solver_options  # type: ignore
 
     @classmethod
     def from_lhs(
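
Note: __array_ufunc__ (NEP 13) is new in this release; it is what keeps expressions like np.maximum(x, 0), np.abs(x), or np.tanh(x) symbolic instead of letting NumPy coerce the array. A minimal receiver illustrating only the protocol, not da4ml's dispatch table:

import numpy as np

class MaxAware:
    def __init__(self, values):
        self.values = np.asarray(values, dtype=float)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NumPy hands every ufunc call on this object to this hook;
        # unhandled ufuncs must return NotImplemented.
        if method != '__call__' or ufunc is not np.maximum:
            return NotImplemented
        a, b = (x.values if isinstance(x, MaxAware) else x for x in inputs)
        return MaxAware(np.maximum(a, b))

assert np.maximum(MaxAware([1.0, -2.0]), 0).values.tolist() == [1.0, 0.0]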
@@ -129,7 +271,7 @@ class FixedVariableArray:
         step: NDArray[np.floating],
         hwconf: HWConfig,
         latency: np.ndarray | float = 0.0,
-        solver_options: dict[str, Any] | None = None,
+        solver_options: solver_options_t | None = None,
     ):
         shape = low.shape
         assert shape == high.shape == step.shape
@@ -162,7 +304,7 @@ class FixedVariableArray:
         f: NDArray[np.integer],
         hwconf: HWConfig,
         latency: NDArray[np.floating] | float = 0.0,
-        solver_options: dict[str, Any] | None = None,
+        solver_options: solver_options_t | None = None,
     ):
         mask = k + i + f <= 0
         k = np.where(mask, 0, k)
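
Note: from_kif's (k, i, f) convention is one sign bit (k), i integer bits, and f fractional bits; the context lines of the next hunk compute the induced range. Worked numbers for (k, i, f) = (1, 2, 3):

# step = 2**-f, high = 2**i - step, low = -2**i * k
k, i, f = 1, 2, 3
step = 2.0**-f
high = 2.0**i - step
low = -(2.0**i) * k
assert (low, high, step) == (-4.0, 3.875, 0.125)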
@@ -173,47 +315,34 @@ class FixedVariableArray:
         high, low = _high - step, -_high * k
         return cls.from_lhs(low, high, step, hwconf, latency, solver_options)
 
-    def __matmul__(self, other):
+    def matmul(self, other):
         if isinstance(other, FixedVariableArray):
             other = other._vars
         if not isinstance(other, np.ndarray):
             other = np.array(other)
         if any(isinstance(x, FixedVariable) for x in other.ravel()):
             mat0, mat1 = self._vars, other
-            shape = mat0.shape[:-1] + mat1.shape[1:]
-            mat0, mat1 = mat0.reshape((-1, mat0.shape[-1])), mat1.reshape((mat1.shape[0], -1))
-            _shape = (mat0.shape[0], mat1.shape[1])
-            _vars = np.empty(_shape, dtype=object)
-            for i in range(mat0.shape[0]):
-                for j in range(mat1.shape[1]):
-                    vec0 = mat0[i]
-                    vec1 = mat1[:, j]
-                    _vars[i, j] = reduce(lambda x, y: x + y, vec0 * vec1)
-            return FixedVariableArray(_vars.reshape(shape), self.solver_options)
-
-        kwargs = (self.solver_options or {}).copy()
+            _vars = mmm(mat0, mat1)
+            return FixedVariableArray(_vars, self.solver_options)
+
+        solver_options = (self.solver_options or {}).copy()
         shape0, shape1 = self.shape, other.shape
         assert shape0[-1] == shape1[0], f'Matrix shapes do not match: {shape0} @ {shape1}'
-        c = shape1[0]
+        contract_len = shape1[0]
         out_shape = shape0[:-1] + shape1[1:]
-        mat0, mat1 = self.reshape((-1, c)), other.reshape((c, -1))
+        mat0, mat1 = self.reshape((-1, contract_len)), other.reshape((contract_len, -1))
         r = []
         for i in range(mat0.shape[0]):
             vec = mat0[i]
-            _qintervals = [QInterval(float(v.low), float(v.high), float(v.step)) for v in vec._vars]
-            _latencies = [float(v.latency) for v in vec._vars]
-            qintervals = NumbaList(_qintervals)  # type: ignore
-            latencies = NumbaList(_latencies)  # type: ignore
-            hwconf = self._vars.ravel()[0].hwconf
-            kwargs.update(adder_size=hwconf.adder_size, carry_size=hwconf.carry_size)
-            _mat = np.ascontiguousarray(mat1.astype(np.float32))
-            sol = solve(_mat, qintervals=qintervals, latencies=latencies, **kwargs)
-            _r = sol(vec._vars)
+            _r = cmvm(mat1, vec, solver_options)
             r.append(_r)
         r = np.array(r).reshape(out_shape)
         return FixedVariableArray(r, self.solver_options)
 
-    def __rmatmul__(self, other):
+    def __matmul__(self, other):
+        return self.matmul(other)
+
+    def rmatmul(self, other):
         mat1 = np.moveaxis(other, -1, 0)
         mat0 = np.moveaxis(self, 0, -1)  # type: ignore
         ndim0, ndim1 = mat0.ndim, mat1.ndim
@@ -223,6 +352,9 @@ class FixedVariableArray:
         axes = _axes[ndim0 - 1 :] + _axes[: ndim0 - 1]
         return r.transpose(axes)
 
+    def __rmatmul__(self, other):
+        return self.rmatmul(other)
+
     def __getitem__(self, item):
         vars = self._vars[item]
         if isinstance(vars, np.ndarray):
@@ -269,10 +401,17 @@ class FixedVariableArray:
 
     def __pow__(self, power: int | float):
         _power = int(power)
-        assert _power == power, 'Power must be an integer'
-        return FixedVariableArray(self._vars**_power, self.solver_options)
+        if _power == power and _power >= 0:
+            return FixedVariableArray(self._vars**_power, self.solver_options)
+        else:
+            return self.apply(lambda x: x**power)
 
-    def relu(self, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN'):
+    def relu(
+        self,
+        i: NDArray[np.integer] | None = None,
+        f: NDArray[np.integer] | None = None,
+        round_mode: str = 'TRN',
+    ):
         shape = self._vars.shape
         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
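
Note: __pow__ now keeps non-negative integer exponents in the shift-add network and sends everything else (e.g. x**0.5 or x**-2) through apply(), i.e. to a lookup table. The guard in isolation:

def pow_stays_arithmetic(power) -> bool:
    # Mirrors the new __pow__ branch: integral and non-negative.
    _power = int(power)
    return _power == power and _power >= 0

assert pow_stays_arithmetic(2)
assert not pow_stays_arithmetic(0.5)  # routed to a lookup table
assert not pow_stays_arithmetic(-1)   # reciprocal: lookup table too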
@@ -290,9 +429,11 @@ class FixedVariableArray:
         round_mode: str = 'TRN',
     ):
         shape = self._vars.shape
-        k = np.broadcast_to(k, shape) if k is not None else np.full(shape, None)
-        i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
-        f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
+        if any(x is None for x in (k, i, f)):
+            kif = self.kif
+        k = np.broadcast_to(k, shape) if k is not None else kif[0]  # type: ignore
+        i = np.broadcast_to(i, shape) if i is not None else kif[1]  # type: ignore
+        f = np.broadcast_to(f, shape) if f is not None else kif[2]  # type: ignore
         ret = []
         for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):  # type: ignore
             ret.append(v.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode))
@@ -324,17 +465,28 @@ class FixedVariableArray:
 
     @property
     def kif(self):
+        """[k, i, f] array"""
         shape = self._vars.shape
         kif = np.array([v.kif for v in self._vars.ravel()]).reshape(*shape, 3)
         return np.moveaxis(kif, -1, 0)
 
+    def apply(self, fn: Callable[[NDArray], NDArray]) -> 'RetardedFixedVariableArray':
+        """Apply a unary operator to all elements, returning a RetardedFixedVariableArray."""
+        return RetardedFixedVariableArray(
+            self._vars,
+            self.solver_options,
+            operator=fn,
+        )
+
 
 class FixedVariableArrayInput(FixedVariableArray):
+    """Similar to FixedVariableArray, but initializes all elements as FixedVariableInput - the precisions are unspecified when initialized, and the highest precision requested (i.e., quantized to) will be recorded for generation of the logic."""
+
     def __init__(
         self,
         shape: tuple[int, ...] | int,
-        hwconf: HWConfig = HWConfig(1, -1, -1),
-        solver_options: dict[str, Any] | None = None,
+        hwconf: HWConfig | tuple[int, int, int] = HWConfig(1, -1, -1),
+        solver_options: solver_options_t | None = None,
         latency=0.0,
     ):
         _vars = np.empty(shape, dtype=object)
@@ -342,3 +494,79 @@ class FixedVariableArrayInput(FixedVariableArray):
         for i in range(_vars.size):
             _vars_f[i] = FixedVariableInput(latency, hwconf)
         super().__init__(_vars, solver_options)
+
+
+def make_table(fn: Callable[[NDArray], NDArray], qint: QInterval) -> LookupTable:
+    low, high, step = qint
+    n = round((high - low) / step) + 1
+    return LookupTable(fn(np.linspace(low, high, n)))
+
+
+class RetardedFixedVariableArray(FixedVariableArray):
+    """Ephemeral FixedVariableArray generated from operations of unspecified output precision.
+    This object translates to a normal FixedVariableArray upon quantization.
+    Does not inherit the maximum precision like FixedVariableArrayInput.
+
+    This object can be used in two ways:
+    1. Quantization with specified precision, which converts it to a FixedVariableArray.
+    2. Applying a further unary operation, which returns another RetardedFixedVariableArray (e.g., composite functions).
+    """
+
+    def __init__(self, vars: NDArray, solver_options: solver_options_t | None, operator: Callable[[NDArray], NDArray]):
+        self._operator = operator
+        super().__init__(vars, solver_options)
+
+    def __array_function__(self, ufunc, method, *inputs, **kwargs):
+        raise RuntimeError('RetardedFixedVariableArray only supports quantization or further unary operations.')
+
+    def apply(self, fn: Callable[[NDArray], NDArray]) -> 'RetardedFixedVariableArray':
+        return RetardedFixedVariableArray(
+            self._vars,
+            self.solver_options,
+            operator=lambda x: fn(self._operator(x)),
+        )
+
+    def quantize(
+        self,
+        k: NDArray[np.integer] | np.integer | int | None = None,
+        i: NDArray[np.integer] | np.integer | int | None = None,
+        f: NDArray[np.integer] | np.integer | int | None = None,
+        overflow_mode: str = 'WRAP',
+        round_mode: str = 'TRN',
+    ):
+        if any(x is None for x in (k, i, f)):
+            assert all(x is None for x in (k, i, f)), 'Either all or none of k, i, f must be specified'
+            _k = _i = _f = [None] * self.size
+        else:
+            _k = np.broadcast_to(k, self.shape).ravel()  # type: ignore
+            _i = np.broadcast_to(i, self.shape).ravel()  # type: ignore
+            _f = np.broadcast_to(f, self.shape).ravel()  # type: ignore
+
+        op = lambda x: _quantize(self._operator(x), k, i, f, overflow_mode, round_mode)  # type: ignore
+
+        local_tables: dict[tuple[QInterval, tuple[int, int, int] | None], LookupTable] = {}
+        variables = []
+        for v, _kk, _ii, _ff in zip(self._vars.ravel(), _k, _i, _f):
+            if (_kk is None) or (_ii is None) or (_ff is None):
+                op = self._operator
+                _key = v.qint
+            else:
+                op = lambda x: _quantize(self._operator(x), _kk, _ii, _ff, overflow_mode, round_mode)  # type: ignore
+                _key = (v.qint, (int(_kk), int(_ii), int(_ff)))
+
+            if _key in local_tables:
+                table = local_tables[_key]
+            else:
+                table = make_table(op, v.qint)
+                local_tables[_key] = table
+            variables.append(v.lookup(table))
+
+        variables = np.array(variables).reshape(self._vars.shape)
+        return FixedVariableArray(variables, self.solver_options)
+
+    def __repr__(self):
+        return 'Retarded' + super().__repr__()
+
+    @property
+    def kif(self):
+        raise RuntimeError('RetardedFixedVariableArray does not have defined kif until quantized.')
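
Note: make_table samples the deferred function once per representable input value, so a b-bit input yields a 2**b-entry table, and quantize() above caches tables per distinct (input QInterval, output precision) pair. A standalone sketch with a hypothetical 4-bit signed input spanning [-1, 0.875] in steps of 0.125:

import numpy as np

def make_table_sketch(fn, low, high, step):
    # One entry per representable input on the grid low, low+step, ..., high.
    n = round((high - low) / step) + 1
    return fn(np.linspace(low, high, n))

table = make_table_sketch(np.tanh, -1.0, 0.875, 0.125)
assert table.shape == (16,)  # 2**4 entries for a 4-bit input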
da4ml/trace/ops/__init__.py CHANGED
@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, TypeVar
 
 import numpy as np
 from numpy.typing import NDArray
+from quantizers.fixed_point.fixed_point_ops_np import get_fixed_quantizer_np
 
 from ..fixed_variable_array import FixedVariable
 from .conv_utils import conv, pool
@@ -22,9 +23,11 @@ def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | No
     elif isinstance(x, list):
         return [xx.relu(i=ii, f=ff, round_mode=round_mode) for xx, ii, ff in zip(x, i, f)]  # type: ignore
     else:
+        round_mode = round_mode.upper()
+        assert round_mode in ('TRN', 'RND')
         x = np.maximum(x, 0)
         if f is not None:
-            if round_mode.upper() == 'RND':
+            if round_mode == 'RND':
                 x += 2.0 ** (-f - 1)
             sf = 2.0**f
             x = np.floor(x * sf) / sf
@@ -33,6 +36,18 @@ def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | No
     return x
 
 
+def _quantize(
+    x: NDArray[np.floating],
+    k: NDArray[np.integer] | np.integer | int,
+    i: NDArray[np.integer] | np.integer | int,
+    f: NDArray[np.integer] | np.integer | int,
+    overflow_mode: str = 'WRAP',
+    round_mode: str = 'TRN',
+) -> NDArray[np.floating]:
+    q = get_fixed_quantizer_np(round_mode=round_mode, overflow_mode=overflow_mode)
+    return q(x, k=k, i=i, f=f)  # type: ignore
+
+
 def quantize(
     x: T,
     k: NDArray[np.integer] | np.integer | int,
@@ -43,22 +58,23 @@ def quantize(
 ) -> T:
     from ..fixed_variable_array import FixedVariableArray
 
-    if isinstance(x, FixedVariableArray):
+    if isinstance(x, (FixedVariableArray, FixedVariable)):
         return x.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode)
+    elif isinstance(x, list):
+        ret: list[FixedVariable] = []
+        for idx in range(len(x)):
+            ret.append(
+                x[idx].quantize(
+                    k=int(k[idx] if isinstance(k, (list, np.ndarray)) else k),
+                    i=int(i[idx] if isinstance(i, (list, np.ndarray)) else i),
+                    f=int(f[idx] if isinstance(f, (list, np.ndarray)) else f),
+                    overflow_mode=overflow_mode,
+                    round_mode=round_mode,
+                )
+            )
+        return ret  # type: ignore
     else:
-        x = x.copy()
-        if overflow_mode in ('SAT', 'SAT_SYM'):
-            step = 2.0**-f
-            _high = 2.0**i
-            high = _high - step
-            low = -_high * k if overflow_mode == 'SAT' else -high * k
-            x = np.clip(x, low, high)  # type: ignore
-        if round_mode.upper() == 'RND':
-            x += 2.0 ** (-f - 1)  # type: ignore
-        b = k + i + f
-        bias = 2.0 ** (b - 1) * k
-        eps = 2.0**-f
-        return eps * ((np.floor(x / eps) + bias) % 2.0**b - bias)  # type: ignore
+        return _quantize(x, k, i, f, overflow_mode, round_mode)
 
 
 __all__ = [
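
Note: the scalar/ndarray path of quantize now delegates to the quantizers package through _quantize. For the default TRN/WRAP mode, the deleted inline math was truncation to 2**-f steps followed by a two's-complement wrap over k + i + f bits:

import numpy as np

def quantize_trn_wrap(x, k, i, f):
    b = k + i + f                 # total bit width
    bias = 2.0 ** (b - 1) * k     # maps the signed range onto the modulo
    eps = 2.0**-f                 # quantization step
    return eps * ((np.floor(x / eps) + bias) % 2.0**b - bias)

print(quantize_trn_wrap(np.array([1.3, -2.7, 9.0]), 1, 3, 2))
# [ 1.25 -2.75 -7.  ] -- 9.0 wraps around the [-8, 7.75] range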
da4ml/trace/ops/reduce_utils.py CHANGED
@@ -100,6 +100,6 @@ def reduce(operator: Callable[[T, T], T], x: TA, axis: int | Sequence[int] | Non
 
     if isinstance(x, FixedVariableArray):
         r = FixedVariableArray(r, solver_config)
-    if r.size == 1 and not keepdims:
-        return r.ravel()[0]  # type: ignore
-    return r if r.size > 1 or keepdims else r.ravel()[0]  # type: ignore
+    if r.shape == ():
+        return r._vars.item()  # type: ignore
+    return r if r.shape != () or keepdims else r.item()  # type: ignore
da4ml/trace/pipeline.py CHANGED
@@ -1,11 +1,11 @@
 from math import ceil, floor
 
-from ..cmvm.types import CascadedSolution, Op, Solution
+from ..cmvm.types import CombLogic, Op, Pipeline
 from .fixed_variable import FixedVariable, HWConfig
 from .tracer import comb_trace
 
 
-def retime_pipeline(csol: CascadedSolution, verbose=True):
+def retime_pipeline(csol: Pipeline, verbose=True):
     n_stages = len(csol[0])
     cutoff_high = ceil(max(max(sol.out_latency) / (i + 1) for i, sol in enumerate(csol[0])))
     cutoff_low = 0
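
Note: cutoff_high is the upper bound for retime_pipeline's search over per-stage latency budgets: the worst end-to-end output latency amortized over the stages consumed so far. With hypothetical per-stage output latencies:

from math import ceil

out_latency_per_stage = [[4.0], [9.0], [15.0]]  # toy numbers, one list per stage
cutoff_high = ceil(max(max(lats) / (i + 1) for i, lats in enumerate(out_latency_per_stage)))
assert cutoff_high == 5  # dominated by stage 2: 15.0 over 3 stages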
@@ -60,14 +60,14 @@ def _get_new_idx(
     return p0_idx
 
 
-def to_pipeline(sol: Solution, latency_cutoff: float, retiming=True, verbose=True) -> CascadedSolution:
+def to_pipeline(comb: CombLogic, latency_cutoff: float, retiming=True, verbose=True) -> Pipeline:
     """Split the record into multiple stages based on the latency of the operations.
     Only useful for HDL generation.
 
     Parameters
     ----------
-    sol : Solution
-        The solution to be split into multiple stages.
+    comb : CombLogic
+        The combinational logic to be pipelined into multiple stages.
     latency_cutoff : float
         The latency cutoff for splitting the operations.
     retiming : bool
@@ -83,8 +83,8 @@ def to_pipeline(sol: Solution, latency_cutoff: float, retiming=True, verbose=Tru
     CascadedSolution
         The cascaded solution with multiple stages.
     """
-    assert len(sol.ops) > 0, 'No operations in the record'
-    for i, op in enumerate(sol.ops):
+    assert len(comb.ops) > 0, 'No operations in the record'
+    for i, op in enumerate(comb.ops):
         if op.id1 != -1:
             break
@@ -96,9 +96,9 @@
 
     locator: list[dict[int, int]] = []
 
-    ops = sol.ops.copy()
-    lat = max(ops[i].latency for i in sol.out_idxs)
-    for i in sol.out_idxs:
+    ops = comb.ops.copy()
+    lat = max(ops[i].latency for i in comb.out_idxs)
+    for i in comb.out_idxs:
         op_out = ops[i]
         ops.append(Op(i, -1001, -1001, 0, op_out.qint, lat, 0.0))
@@ -113,7 +113,10 @@
         p0_idx = _get_new_idx(op.id0, locator, opd, out_idxd, ops, stage, latency_cutoff)
         p1_idx = _get_new_idx(op.id1, locator, opd, out_idxd, ops, stage, latency_cutoff)
         if op.opcode in (6, -6):
-            data = _get_new_idx(op.data, locator, opd, out_idxd, ops, stage, latency_cutoff)
+            k = op.data & 0xFFFFFFFF
+            _shift = (op.data >> 32) & 0xFFFFFFFF
+            k = _get_new_idx(k, locator, opd, out_idxd, ops, stage, latency_cutoff)
+            data = _shift << 32 | k
         else:
             data = op.data
 
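Note: for opcode ±6 operations, Op.data now packs a 32-bit operand index in the low half and a shift amount in the high half, so the pipeliner unpacks, remaps only the index, and repacks. The packing on its own:

def pack(shift: int, idx: int) -> int:
    return (shift << 32) | idx

def unpack(data: int) -> tuple[int, int]:
    idx = data & 0xFFFFFFFF            # low 32 bits: operand index
    shift = (data >> 32) & 0xFFFFFFFF  # high 32 bits: shift amount
    return shift, idx

assert unpack(pack(3, 42)) == (3, 42)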
@@ -126,34 +129,53 @@
         locator.append({stage: len(opd[stage]) - 1})
     sols = []
     max_stage = max(opd.keys())
-    n_in = sol.shape[0]
+    n_in = comb.shape[0]
     for i, stage in enumerate(opd.keys()):
         _ops = opd[stage]
         _out_idx = out_idxd[stage]
         n_out = len(_out_idx)
 
         if i == max_stage:
-            out_shifts = sol.out_shifts
-            out_negs = sol.out_negs
+            out_shifts = comb.out_shifts
+            out_negs = comb.out_negs
         else:
             out_shifts = [0] * len(_out_idx)
             out_negs = [False] * len(_out_idx)
 
-        _sol = Solution(
+        if comb.lookup_tables is not None:
+            _ops, lookup_tables = remap_table_idxs(comb, _ops)
+        else:
+            lookup_tables = None
+        _sol = CombLogic(
             shape=(n_in, n_out),
             inp_shift=[0] * n_in,
             out_idxs=_out_idx,
             out_shifts=out_shifts,
             out_negs=out_negs,
             ops=_ops,
-            carry_size=sol.carry_size,
-            adder_size=sol.adder_size,
+            carry_size=comb.carry_size,
+            adder_size=comb.adder_size,
+            lookup_tables=lookup_tables,
         )
         sols.append(_sol)
 
         n_in = n_out
-    csol = CascadedSolution(tuple(sols))
+    csol = Pipeline(tuple(sols))
 
     if retiming:
         csol = retime_pipeline(csol, verbose=verbose)
     return csol
+
+
+def remap_table_idxs(comb: CombLogic, _ops):
+    assert comb.lookup_tables is not None
+    table_idxs = sorted(list({op.data for op in _ops if op.opcode == 8}))
+    remap = {j: i for i, j in enumerate(table_idxs)}
+    _ops_remap = []
+    for op in _ops:
+        if op.opcode == 8:
+            op = Op(op.id0, op.id1, op.opcode, remap[op.data], op.qint, op.latency, op.cost)
+        _ops_remap.append(op)
+    _ops = _ops_remap
+    lookup_tables = tuple(comb.lookup_tables[i] for i in table_idxs)
+    return _ops, lookup_tables
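
Note: remap_table_idxs compacts the global lookup-table ids a stage references (opcode 8) into a dense 0..n-1 range and keeps only that stage's tables. The id compaction alone:

table_idxs = sorted({7, 2, 9})                 # global ids used by one stage
remap = {j: i for i, j in enumerate(table_idxs)}
assert remap == {2: 0, 7: 1, 9: 2}             # dense per-stage ids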