da4ml 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (59)
  1. da4ml/_version.py +2 -2
  2. da4ml/cmvm/api.py +2 -6
  3. da4ml/cmvm/core/__init__.py +0 -1
  4. da4ml/cmvm/types.py +99 -19
  5. da4ml/codegen/__init__.py +5 -4
  6. da4ml/codegen/cpp/__init__.py +2 -1
  7. da4ml/codegen/cpp/cpp_codegen.py +58 -25
  8. da4ml/codegen/cpp/hls_model.py +252 -0
  9. da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
  10. da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
  11. da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
  12. da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
  13. da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
  14. da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
  15. da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
  16. da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
  17. da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
  18. da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
  19. da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
  20. da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
  21. da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
  22. da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
  23. da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
  24. da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
  25. da4ml/codegen/cpp/source/binder_util.hh +56 -0
  26. da4ml/codegen/cpp/source/build_binder.mk +24 -0
  27. da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
  28. da4ml/codegen/verilog/__init__.py +2 -3
  29. da4ml/codegen/verilog/comb.py +65 -24
  30. da4ml/codegen/verilog/io_wrapper.py +36 -141
  31. da4ml/codegen/verilog/pipeline.py +21 -3
  32. da4ml/codegen/verilog/source/binder_util.hh +72 -0
  33. da4ml/codegen/verilog/source/build_prj.tcl +0 -1
  34. da4ml/codegen/verilog/source/mux.v +58 -0
  35. da4ml/codegen/verilog/source/negative.v +28 -0
  36. da4ml/codegen/verilog/source/shift_adder.v +4 -1
  37. da4ml/codegen/verilog/source/template.xdc +3 -0
  38. da4ml/codegen/verilog/verilog_model.py +42 -15
  39. da4ml/converter/__init__.py +0 -0
  40. da4ml/converter/hgq2/parser.py +105 -0
  41. da4ml/converter/hgq2/replica.py +383 -0
  42. da4ml/trace/__init__.py +2 -2
  43. da4ml/trace/fixed_variable.py +177 -18
  44. da4ml/trace/fixed_variable_array.py +124 -9
  45. da4ml/trace/ops/__init__.py +22 -6
  46. da4ml/trace/ops/conv_utils.py +146 -14
  47. da4ml/trace/ops/einsum_utils.py +9 -6
  48. da4ml/trace/ops/reduce_utils.py +103 -0
  49. da4ml/trace/pipeline.py +36 -34
  50. da4ml/trace/tracer.py +37 -5
  51. da4ml-0.3.0.dist-info/METADATA +107 -0
  52. da4ml-0.3.0.dist-info/RECORD +64 -0
  53. da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
  54. da4ml-0.2.0.dist-info/METADATA +0 -65
  55. da4ml-0.2.0.dist-info/RECORD +0 -39
  56. /da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
  57. {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/WHEEL +0 -0
  58. {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/licenses/LICENSE +0 -0
  59. {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/top_level.txt +0 -0

da4ml/trace/fixed_variable.py
@@ -43,9 +43,9 @@ class FixedVariable:
     ) -> None:
         assert low <= high, f'low {low} must be less than high {high}'

-        if low == high:
+        if low == high and opr != 'new':
             opr = 'const'
-            _factor = 1.0
+            _factor = _factor
             _from = ()

         low, high, step = Decimal(low), Decimal(high), Decimal(step)
@@ -72,15 +72,21 @@ class FixedVariable:
         self.latency = _latency
         self.cost = _cost

+        # Update latency for constant variables to match the current variable for piplining
+
+        for v in self._from:
+            if v.opr == 'const':
+                v.latency = self.latency
+
     def get_cost_and_latency(self):
         if self.opr == 'const':
             return 0.0, 0.0
-        if self.opr in ('vadd', 'cadd'):
+        if self.opr in ('vadd', 'cadd', 'min', 'max'):
             adder_size = self.hwconf.adder_size
             carry_size = self.hwconf.carry_size
             latency_cutoff = self.hwconf.latency_cutoff

-            if self.opr == 'vadd':
+            if self.opr in ('min', 'max', 'vadd'):
                 assert len(self._from) == 2
                 v0, v1 = self._from
                 int0, int1 = v0.qint, v1.qint
@@ -89,8 +95,6 @@ class FixedVariable:
             else:
                 assert len(self._from) == 1
                 assert self._data is not None, 'cadd must have data'
-                # int0 = self._from[0].qint
-                # int1 = QInterval(float(self._data), float(self._data), float(self.step))
                 _f = _const_f(self._data)
                 _cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -_f))) + _f
                 base_latency = self._from[0].latency
@@ -138,6 +142,12 @@ class FixedVariable:
         k = self.low < 0
         return k, i, f

+    @classmethod
+    def from_const(cls, const: float | Decimal, hwconf: HWConfig, latency: float, _factor: float | Decimal):
+        f = _const_f(const)
+        step = Decimal(2) ** -f
+        return cls(const, const, step, hwconf=hwconf, opr='const', _factor=_factor, latency=latency)
+
     def __repr__(self) -> str:
         if self._factor == 1:
             return f'FixedVariable({self.low}, {self.high}, {self.step})'
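
The new `from_const` classmethod wraps a plain constant as a degenerate variable whose interval collapses to a single point, with `step = 2**-f` taken from `_const_f`. Below is a hedged plain-Python sketch of that arithmetic, assuming `_const_f(c)` returns the number of fractional bits needed to represent `c` exactly (the helper itself is not part of this hunk):

```python
# Sketch only: mirrors the arithmetic of FixedVariable.from_const shown above,
# under the assumption that _const_f(c) counts the fractional bits of c.
from decimal import Decimal

def const_fractional_bits(c: Decimal) -> int:
    f = 0
    while c % 1 != 0:  # keep doubling until the value is an integer
        c *= 2
        f += 1
    return f

c = Decimal('0.75')
f = const_fractional_bits(c)   # 2 fractional bits
step = Decimal(2) ** -f        # 0.25
# the resulting variable would have low == high == 0.75 and opr == 'const'
```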
@@ -185,7 +195,9 @@ class FixedVariable:
             hwconf=self.hwconf,
         )

-    def _const_add(self, other: float | Decimal):
+    def _const_add(self, other: float | Decimal | None):
+        if other is None:
+            return self
         if not isinstance(other, (int, float, Decimal)):
             other = float(other)  # direct numpy to decimal raises error
         other = Decimal(other)
@@ -222,7 +234,7 @@ class FixedVariable:
         other: 'float|Decimal',
     ):
         if other == 0:
-            return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')

         assert log2(abs(other)) % 1 == 0, 'Only support pow2 multiplication'

@@ -266,8 +278,8 @@ class FixedVariable:
             step = Decimal(2) ** -f
             i = ceil(log2(val + step)) if not i else i
             eps = step / 2 if round_mode == 'RND' else 0
-            val = floor(val / step + eps) % Decimal(2) ** i * step
-            return FixedVariable(val, val, step, hwconf=self.hwconf)
+            val = (floor(val / step + eps) * step) % (Decimal(2) ** i)
+            return FixedVariable(val, val, step, hwconf=self.hwconf, opr='const')

         step = max(Decimal(2) ** -f, self.step) if f is not None else self.step
         if step > self.step and round_mode == 'RND':
@@ -281,6 +293,10 @@ class FixedVariable:
         low = Decimal(0)
         high = _high
         _factor = self._factor
+
+        if self.low == low and self.high == high and self.step == step:
+            return self
+
         return FixedVariable(
             low,
             high,
@@ -301,7 +317,7 @@ class FixedVariable:
         round_mode: str = 'TRN',
     ):
         overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
-        assert overflow_mode in ('WRAP', 'SAT')
+        assert overflow_mode in ('WRAP', 'SAT', 'SAT_SM')
         assert round_mode in ('TRN', 'RND')

         _k, _i, _f = self.kif
@@ -312,32 +328,42 @@ class FixedVariable:
         if f < _f and round_mode == 'RND':
             return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')

+        if overflow_mode in ('SAT', 'SAT_SM'):
+            step = Decimal(2) ** -f
+            _high = Decimal(2) ** i
+            high = _high - step
+            low = -_high * k if overflow_mode == 'SAT' else -high * k
+            return self.max_of(low).min_of(high).quantize(k, i, f, 'WRAP', round_mode)
+
         if self.low == self.high:
             val = self.low
             step = Decimal(2) ** -f
             _high = Decimal(2) ** i
             high, low = _high - step, -_high * k
             val = (floor(val / step) * step - low) % (2 * _high) + low
-            return FixedVariable(val, val, step, hwconf=self.hwconf)
+            return FixedVariable(val, val, step, hwconf=self.hwconf, opr='const')

         # TODO: corner cases exists (e.g., overflow to negative, or negative overflow to high value)
         # bit-exactness will be lost in these cases, but they should never happen (quantizers are used in a weird way)
         # Keeping this for now; change if absolutely necessary
         f = min(f, _f)
-        k = min(k, _k)
+        k = min(k, _k) if i >= _i else k
         i = min(i, _i)

-        step = max(Decimal(2) ** -f, self.step)
+        if i + k + f <= 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
+
+        step = Decimal(2) ** -f

         low = -k * Decimal(2) ** i
+
         high = Decimal(2) ** i - step
         _low, _high = self.low, self.high

         if _low >= low and _high <= high:
             low, high = _low, _high
-
-        if low > high:
-            return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+        low = floor(low / step) * step
+        high = ceil(high / step) * step

         return FixedVariable(
             low,
@@ -345,7 +371,7 @@ class FixedVariable:
             step,
             _from=(self,),
             _factor=abs(self._factor),
-            opr='wrap' if overflow_mode == 'WRAP' else 'sat',
+            opr='wrap',
             latency=self.latency,
             hwconf=self.hwconf,
         )
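
With this change, SAT and SAT_SM are no longer emitted as separate `sat` nodes: `quantize` first clamps the value with the new `max_of`/`min_of` ops and then applies an ordinary WRAP quantization. As a hedged illustration of the bounds involved (plain Python, not the da4ml API), for a signed format with `k=1`, `i=3`, `f=2`:

```python
# Illustration only: the saturation bounds used by the new SAT / SAT_SM branch.
from decimal import Decimal

k, i, f = 1, 3, 2
step = Decimal(2) ** -f        # 0.25
_high = Decimal(2) ** i        # 8
high = _high - step            # 7.75

low_sat = -_high * k           # SAT:    clamp to [-8.00, 7.75]
low_sat_sm = -high * k         # SAT_SM: clamp to [-7.75, 7.75] (symmetric)
print(low_sat, low_sat_sm, high)
```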
@@ -356,3 +382,136 @@ class FixedVariable:
         _high = Decimal(2) ** i
         low, high = k * _high, _high - step
         return cls(low, high, step, **kwargs)
+
+    def msb_mux(self, a: 'FixedVariable', b: 'FixedVariable', qint: tuple[Decimal, Decimal, Decimal] | None = None):
+        assert isinstance(a, FixedVariable) and isinstance(b, FixedVariable), 'msb_mux requires two FixedVariables'
+        if self._factor < 0:
+            return (-self).msb_mux(b, a, qint)
+
+        if a._factor < 0:
+            qint = (-qint[1], -qint[0], qint[2]) if qint else None
+            return -(self.msb_mux(-a, -b, qint=qint))
+
+        _factor = a._factor
+
+        if qint is None:
+            qint = (min(a.low, b.low), max(a.high, b.high), min(a.step, b.step))
+
+        dlat, dcost = cost_add(a.qint, b.qint, 0, False, self.hwconf.adder_size, self.hwconf.carry_size)
+        return FixedVariable(
+            *qint,
+            _from=(self, a, b),
+            _factor=_factor,
+            opr='msb_mux',
+            latency=max(a.latency, b.latency, self.latency) + dlat,
+            hwconf=self.hwconf,
+            cost=dcost,
+        )
+
+    def max_of(self, other):
+        if other == 0:
+            return self.relu()
+        if other == -float('inf'):
+            return self
+        if other == float('inf'):
+            raise ValueError('Cannot apply max_of with inf')
+        if not isinstance(other, FixedVariable):
+            other = FixedVariable.from_const(other, hwconf=self.hwconf, latency=self.latency, _factor=abs(self._factor))
+
+        if self.low >= other.high:
+            return self
+        if self.high <= other.low:
+            return other
+
+        qint = (max(self.low, other.low), max(self.high, other.high), min(self.step, other.step))
+        return (self - other).msb_mux(other, self, qint=qint)
+
+    def min_of(self, other):
+        if other == 0:
+            return (-self).relu()
+        if other == float('inf'):
+            return self
+        if other == -float('inf'):
+            raise ValueError('Cannot apply min_of with -inf')
+        if not isinstance(other, FixedVariable):
+            other = FixedVariable.from_const(other, hwconf=self.hwconf, latency=self.latency, _factor=(self._factor))
+
+        if self.high <= other.low:
+            return self
+        if self.low >= other.high:
+            return other
+
+        qint = (min(self.low, other.low), min(self.high, other.high), min(self.step, other.step))
+        return (self - other).msb_mux(self, other, qint=qint)
+
+
+class FixedVariableInput(FixedVariable):
+    def __init__(
+        self,
+        latency: float | None = None,
+        hwconf=HWConfig(-1, -1, -1),
+    ) -> None:
+        self.low = Decimal(1e10)
+        self.high = Decimal(-1e10)
+        self.step = Decimal(1e10)
+        self._factor = Decimal(1)
+        self._from: tuple[FixedVariable, ...] = ()
+        self.opr = 'new'
+        self._data = None
+        self.id = uuid4()
+        self.hwconf = hwconf
+
+        self.latency = latency if latency is not None else 0.0
+        self.cost = 0.0
+
+    def __add__(self, other):
+        raise ValueError('Cannot operate on unquantized input variable')
+
+    def __sub__(self, other):
+        raise ValueError('Cannot operate on unquantized input variable')
+
+    def __neg__(self):
+        raise ValueError('Cannot negate unquantized input variable')
+
+    def relu(self, *args, **kwargs):
+        raise ValueError('Cannot apply relu on unquantized input variable')
+
+    def max_of(self, other):
+        raise ValueError('Cannot apply max_of on unquantized input variable')
+
+    def min_of(self, other):
+        raise ValueError('Cannot apply min_of on unquantized input variable')
+
+    def quantize(
+        self,
+        k: int | bool,
+        i: int,
+        f: int,
+        overflow_mode: str = 'WRAP',
+        round_mode: str = 'TRN',
+    ):
+        assert overflow_mode == 'WRAP'
+
+        if k + i + f <= 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
+
+        if round_mode == 'RND':
+            return (self.quantize(k, i, f + 1) + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')
+
+        step = Decimal(2) ** -f
+        _high = Decimal(2) ** i
+        low, high = -_high * k, _high - step
+        self.high = max(self.high, high)
+        self.low = min(self.low, low)
+        self.step = min(self.step, step)
+
+        return FixedVariable(
+            low,
+            high,
+            step,
+            _from=(self,),
+            _factor=self._factor,
+            opr='wrap',
+            latency=self.latency,
+            hwconf=self.hwconf,
+        )
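
The selection behind `max_of`/`min_of` is a mux driven by the sign (MSB) of a difference: `a.max_of(b)` lowers to `(a - b).msb_mux(b, a)` and `a.min_of(b)` to `(a - b).msb_mux(a, b)`, with the cost charged as one comparison-sized add (`cost_add` over the two operand intervals). The exact select polarity lives in the code generators rather than in this file, but the intent can be read off `max_of`/`min_of`; a hedged behavioural sketch in plain Python:

```python
# Behavioural sketch only: what (a - b).msb_mux(b, a) computes for max_of and
# (a - b).msb_mux(a, b) for min_of, ignoring quantization intervals and cost.
def msb_mux(sel: float, on_neg: float, on_nonneg: float) -> float:
    # pick on_neg when the sign bit (MSB) of sel is set, i.e. sel < 0
    return on_neg if sel < 0 else on_nonneg

def max_of(a: float, b: float) -> float:
    return msb_mux(a - b, b, a)

def min_of(a: float, b: float) -> float:
    return msb_mux(a - b, a, b)

assert max_of(1.5, -0.25) == 1.5 and min_of(1.5, -0.25) == -0.25
```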

da4ml/trace/fixed_variable_array.py
@@ -1,20 +1,110 @@
-from typing import Any
+from inspect import signature
+from typing import Any, TypeVar

 import numpy as np
+from numba.typed import List as NumbaList
 from numpy.typing import NDArray

 from ..cmvm import solve
-from .fixed_variable import FixedVariable, HWConfig, QInterval
+from .fixed_variable import FixedVariable, FixedVariableInput, HWConfig, QInterval
+from .ops import einsum, reduce
+
+T = TypeVar('T')
+
+
+def to_raw_arr(obj: T) -> T:
+    if isinstance(obj, tuple):
+        return tuple(to_raw_arr(x) for x in obj)  # type: ignore
+    elif isinstance(obj, list):
+        return [to_raw_arr(x) for x in obj]  # type: ignore
+    elif isinstance(obj, dict):
+        return {k: to_raw_arr(v) for k, v in obj.items()}  # type: ignore
+    if isinstance(obj, FixedVariableArray):
+        return obj._vars  # type: ignore
+    return obj
+
+
+def _max_of(a, b):
+    if isinstance(a, FixedVariable):
+        return a.max_of(b)
+    elif isinstance(b, FixedVariable):
+        return b.max_of(a)
+    else:
+        return max(a, b)
+
+
+def _min_of(a, b):
+    if isinstance(a, FixedVariable):
+        return a.min_of(b)
+    elif isinstance(b, FixedVariable):
+        return b.min_of(a)
+    else:
+        return min(a, b)


 class FixedVariableArray:
+    __array_priority__ = 100
+
+    def __array_function__(self, func, types, args, kwargs):
+        if func is np.matmul:
+            if len(args) == 1 and isinstance(args[0], np.ndarray):
+                return self.__matmul__(args[0])
+            elif len(args) == 2 and isinstance(args[0], np.ndarray) and isinstance(args[1], np.ndarray):
+                return self.__rmatmul__(args[1])
+
+        if func in (np.mean, np.sum, np.amax, np.amin, np.max, np.min):
+            match func:
+                case np.mean:
+                    _x = reduce(lambda x, y: x + y, self, *args[1:], **kwargs)
+                    return _x * (_x.size / self._vars.size)
+                case np.sum:
+                    return reduce(lambda x, y: x + y, self, *args[1:], **kwargs)
+                case np.max | np.amax:
+                    return reduce(_max_of, self, *args[1:], **kwargs)
+                case np.min | np.amin:
+                    return reduce(_min_of, self, *args[1:], **kwargs)
+                case _:
+                    raise NotImplementedError(f'Unsupported function: {func}')
+
+        if func is np.clip:
+            assert len(args) == 3, 'Clip function requires exactly three arguments'
+            x, low, high = args
+            _x, low, high = np.broadcast_arrays(x, low, high)
+            x = FixedVariableArray(_x, self.solver_options)
+            x = np.amax(np.stack((x, low), axis=-1), axis=-1)  # type: ignore
+            return np.amin(np.stack((x, high), axis=-1), axis=-1)
+
+        if func is np.einsum:
+            # assert len(args) == 2
+            sig = signature(np.einsum)
+            bind = sig.bind(*args, **kwargs)
+            eq = args[0]
+            operands = bind.arguments['operands']
+            if isinstance(operands[0], str):
+                operands = operands[1:]
+            assert len(operands) == 2, 'Einsum on FixedVariableArray requires exactly two operands'
+            assert bind.arguments.get('out', None) is None, 'Output argument is not supported'
+            return einsum(eq, *operands)
+
+        args, kwargs = to_raw_arr(args), to_raw_arr(kwargs)
+        return FixedVariableArray(
+            func(*args, **kwargs),
+            self.solver_options,
+        )
+
     def __init__(
         self,
         vars: NDArray,
         solver_options: dict[str, Any] | None = None,
     ):
         self._vars = np.array(vars)
-        self.solver_options = solver_options
+        _solver_options = signature(solve).parameters
+        _solver_options = {k: v.default for k, v in _solver_options.items() if v.default is not v.empty}
+        if solver_options is not None:
+            _solver_options.update(solver_options)
+        _solver_options.pop('qintervals', None)
+        _solver_options.pop('latencies', None)
+        self.solver_options = _solver_options

     @classmethod
     def from_lhs(
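
With `__array_priority__` and `__array_function__` in place, ordinary NumPy calls on a `FixedVariableArray` are intercepted and routed to da4ml's symbolic ops (`reduce`, `einsum`, the clip branch) instead of falling back to object-dtype arithmetic. A hedged usage sketch; the import locations of `FixedVariableArrayInput` and `HWConfig` are assumptions, not confirmed by this diff:

```python
# Usage sketch only; import paths and defaults are assumptions, not shown in this diff.
import numpy as np
from da4ml.trace import FixedVariableArrayInput, HWConfig

x = FixedVariableArrayInput((8,), HWConfig(-1, -1, -1))
x = x.quantize(k=1, i=3, f=4)    # give the symbolic inputs a concrete fixed-point type

s = np.sum(x)                    # dispatched to reduce(lambda a, b: a + b, x, ...)
m = np.amax(x)                   # dispatched to reduce(_max_of, x, ...)
y = np.clip(x, 0.0, 1.0)         # dispatched to the np.clip branch (stacked max/min)
z = x @ np.random.rand(8, 4)     # constant matmul, routed through the CMVM solver
```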
@@ -75,8 +165,10 @@ class FixedVariableArray:
         r = []
         for i in range(mat0.shape[0]):
             vec = mat0[i]
-            qintervals = tuple([QInterval(float(v.low), float(v.high), float(v.step)) for v in vec._vars])
-            latencies = tuple([float(v.latency) for v in vec._vars])
+            _qintervals = [QInterval(float(v.low), float(v.high), float(v.step)) for v in vec._vars]
+            _latencies = [float(v.latency) for v in vec._vars]
+            qintervals = NumbaList(_qintervals)  # type: ignore
+            latencies = NumbaList(_latencies)  # type: ignore
             hwconf = self._vars.ravel()[0].hwconf
             kwargs.update(adder_size=hwconf.adder_size, carry_size=hwconf.carry_size)
             _mat = np.ascontiguousarray(mat1.astype(np.float32))
@@ -96,8 +188,8 @@ class FixedVariableArray:
         axes = _axes[ndim0 - 1 :] + _axes[: ndim0 - 1]
         return r.transpose(axes)

-    def __getitem__(self, *item):
-        vars = self._vars[*item]
+    def __getitem__(self, item):
+        vars = self._vars[item]
         if isinstance(vars, np.ndarray):
             return FixedVariableArray(vars, self.solver_options)
         else:
@@ -111,9 +203,13 @@ class FixedVariableArray:
         return self._vars.shape

     def __add__(self, other):
+        if isinstance(other, FixedVariableArray):
+            return FixedVariableArray(self._vars + other._vars, self.solver_options)
         return FixedVariableArray(self._vars + other, self.solver_options)

     def __sub__(self, other):
+        if isinstance(other, FixedVariableArray):
+            return FixedVariableArray(self._vars - other._vars, self.solver_options)
         return FixedVariableArray(self._vars - other, self.solver_options)

     def __mul__(self, other):
@@ -139,7 +235,7 @@ class FixedVariableArray:
         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
         ret = []
-        for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()):
+        for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()):  # type: ignore
             ret.append(v.relu(i=i, f=f, round_mode=round_mode))
         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)

@@ -156,7 +252,7 @@ class FixedVariableArray:
         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
         ret = []
-        for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):
+        for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):  # type: ignore
             ret.append(v.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode))
         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)

@@ -175,3 +271,22 @@ class FixedVariableArray:
     @property
     def dtype(self):
         return self._vars.dtype
+
+    @property
+    def size(self):
+        return self._vars.size
+
+    @property
+    def kif(self):
+        shape = self._vars.shape
+        kif = np.array([v.kif for v in self._vars.ravel()]).reshape(*shape, 3)
+        return np.moveaxis(kif, -1, 0)
+
+
+class FixedVariableArrayInput(FixedVariableArray):
+    def __init__(self, shape: tuple[int, ...] | int, hwconf: HWConfig, solver_options: dict[str, Any] | None = None, latency=0.0):
+        _vars = np.empty(shape, dtype=object)
+        _vars_f = _vars.ravel()
+        for i in range(_vars.size):
+            _vars_f[i] = FixedVariableInput(latency, hwconf)
+        super().__init__(_vars, solver_options)
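
The new `kif` property stacks each element's `(keep_negative, integer_bits, fractional_bits)` triplet along a leading axis, so `arr.kif` has shape `(3, *arr.shape)`. A hedged sketch of consuming that layout (the helper below is illustrative, not part of the package):

```python
# Sketch only: unpacking the (3, *shape) layout produced by FixedVariableArray.kif.
import numpy as np

def total_bits(kif: np.ndarray) -> np.ndarray:
    """kif has shape (3, *shape); return the per-element bit width k + i + f."""
    k, i, f = kif               # unpack along the leading axis created by np.moveaxis
    return k + i + f

# e.g. a (2, 2) array quantized everywhere to a signed format with k=1, i=3, f=4:
kif = np.stack([np.ones((2, 2)), np.full((2, 2), 3), np.full((2, 2), 4)])
assert total_bits(kif).shape == (2, 2) and total_bits(kif)[0, 0] == 8
```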

da4ml/trace/ops/__init__.py
@@ -1,16 +1,22 @@
-from typing import TypeVar
+from typing import TYPE_CHECKING, TypeVar

 import numpy as np
 from numpy.typing import NDArray

-from ..fixed_variable_array import FixedVariable, FixedVariableArray
-from .conv_utils import conv
+from ..fixed_variable_array import FixedVariable
+from .conv_utils import conv, pool
 from .einsum_utils import einsum
+from .reduce_utils import reduce

-T = TypeVar('T', FixedVariableArray, NDArray[np.floating], list[FixedVariable])
+if TYPE_CHECKING:
+    from ..fixed_variable_array import FixedVariableArray
+
+T = TypeVar('T', 'FixedVariableArray', NDArray[np.floating], list[FixedVariable])


 def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN') -> T:
+    from ..fixed_variable_array import FixedVariableArray
+
     if isinstance(x, FixedVariableArray):
         return x.relu(i=i, f=f, round_mode=round_mode)
     elif isinstance(x, list):
@@ -35,12 +41,20 @@ def quantize(
     overflow_mode: str = 'WRAP',
     round_mode: str = 'TRN',
 ) -> T:
-    assert overflow_mode.upper() == 'WRAP', 'Only WRAP overflow mode is supported'
+    from ..fixed_variable_array import FixedVariableArray
+
     if isinstance(x, FixedVariableArray):
         return x.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode)
     else:
+        x = x.copy()
+        if overflow_mode in ('SAT', 'SAT_SM'):
+            step = 2.0**-f
+            _high = 2.0**i
+            high = _high - step
+            low = -_high * k if overflow_mode == 'SAT' else -high * k
+            x = np.clip(x, low, high)  # type: ignore
         if round_mode.upper() == 'RND':
-            x += 2.0 ** (-f - 1)
+            x += 2.0 ** (-f - 1)  # type: ignore
         b = k + i + f
         bias = 2.0 ** (b - 1) * k
         eps = 2.0**-f
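
For the plain `ndarray` fallback, the change only prepends a saturation clip before the existing wrap/round arithmetic (the tail of the function is unchanged and not shown in this hunk). A hedged numeric illustration of just that pre-clip step:

```python
# Illustration of the SAT pre-clip added above, for k=1, i=2, f=1.
import numpy as np

k, i, f = 1, 2, 1
step = 2.0**-f                # 0.5
_high = 2.0**i                # 4.0
high = _high - step           # 3.5
low = -_high * k              # -4.0 for 'SAT' (use -high * k for 'SAT_SM')

x = np.array([5.25, -6.0, 1.3])
print(np.clip(x, low, high))  # [ 3.5 -4.   1.3]
```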
@@ -52,4 +66,6 @@ __all__ = [
     'einsum',
     'relu',
     'quantize',
+    'pool',
+    'reduce',
 ]