da4ml 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/_version.py +2 -2
- da4ml/cmvm/api.py +2 -6
- da4ml/cmvm/core/__init__.py +0 -1
- da4ml/cmvm/types.py +99 -19
- da4ml/codegen/__init__.py +5 -4
- da4ml/codegen/cpp/__init__.py +2 -1
- da4ml/codegen/cpp/cpp_codegen.py +58 -25
- da4ml/codegen/cpp/hls_model.py +252 -0
- da4ml/codegen/cpp/source/ap_types/ap_binary.h +78 -0
- da4ml/codegen/cpp/source/ap_types/ap_common.h +376 -0
- da4ml/codegen/cpp/source/ap_types/ap_decl.h +212 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed.h +360 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_base.h +2354 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_ref.h +718 -0
- da4ml/codegen/cpp/source/ap_types/ap_fixed_special.h +230 -0
- da4ml/codegen/cpp/source/ap_types/ap_int.h +330 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_base.h +1885 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_ref.h +1346 -0
- da4ml/codegen/cpp/source/ap_types/ap_int_special.h +223 -0
- da4ml/codegen/cpp/source/ap_types/ap_shift_reg.h +138 -0
- da4ml/codegen/cpp/source/ap_types/etc/ap_private.h +7199 -0
- da4ml/codegen/cpp/source/ap_types/hls_math.h +27 -0
- da4ml/codegen/cpp/source/ap_types/hls_stream.h +263 -0
- da4ml/codegen/cpp/source/ap_types/utils/x_hls_utils.h +80 -0
- da4ml/codegen/cpp/source/binder_util.hh +56 -0
- da4ml/codegen/cpp/source/build_binder.mk +24 -0
- da4ml/codegen/cpp/source/{vitis.h → vitis_bitshift.hh} +1 -1
- da4ml/codegen/verilog/__init__.py +2 -3
- da4ml/codegen/verilog/comb.py +65 -24
- da4ml/codegen/verilog/io_wrapper.py +36 -141
- da4ml/codegen/verilog/pipeline.py +21 -3
- da4ml/codegen/verilog/source/binder_util.hh +72 -0
- da4ml/codegen/verilog/source/build_prj.tcl +0 -1
- da4ml/codegen/verilog/source/mux.v +58 -0
- da4ml/codegen/verilog/source/negative.v +28 -0
- da4ml/codegen/verilog/source/shift_adder.v +4 -1
- da4ml/codegen/verilog/source/template.xdc +3 -0
- da4ml/codegen/verilog/verilog_model.py +42 -15
- da4ml/converter/__init__.py +0 -0
- da4ml/converter/hgq2/parser.py +105 -0
- da4ml/converter/hgq2/replica.py +383 -0
- da4ml/trace/__init__.py +2 -2
- da4ml/trace/fixed_variable.py +177 -18
- da4ml/trace/fixed_variable_array.py +124 -9
- da4ml/trace/ops/__init__.py +22 -6
- da4ml/trace/ops/conv_utils.py +146 -14
- da4ml/trace/ops/einsum_utils.py +9 -6
- da4ml/trace/ops/reduce_utils.py +103 -0
- da4ml/trace/pipeline.py +36 -34
- da4ml/trace/tracer.py +37 -5
- da4ml-0.3.0.dist-info/METADATA +107 -0
- da4ml-0.3.0.dist-info/RECORD +64 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +0 -17
- da4ml-0.2.0.dist-info/METADATA +0 -65
- da4ml-0.2.0.dist-info/RECORD +0 -39
- da4ml/codegen/verilog/source/{ioutils.hh → ioutil.hh} +0 -0
- {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/WHEEL +0 -0
- {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.2.0.dist-info → da4ml-0.3.0.dist-info}/top_level.txt +0 -0
da4ml/trace/fixed_variable.py
CHANGED
@@ -43,9 +43,9 @@ class FixedVariable:
     ) -> None:
         assert low <= high, f'low {low} must be less than high {high}'

-        if low == high:
+        if low == high and opr != 'new':
             opr = 'const'
-            _factor =
+            _factor = _factor
             _from = ()

         low, high, step = Decimal(low), Decimal(high), Decimal(step)
@@ -72,15 +72,21 @@ class FixedVariable:
         self.latency = _latency
         self.cost = _cost

+        # Update latency for constant variables to match the current variable for piplining
+
+        for v in self._from:
+            if v.opr == 'const':
+                v.latency = self.latency
+
     def get_cost_and_latency(self):
         if self.opr == 'const':
             return 0.0, 0.0
-        if self.opr in ('vadd', 'cadd'):
+        if self.opr in ('vadd', 'cadd', 'min', 'max'):
             adder_size = self.hwconf.adder_size
             carry_size = self.hwconf.carry_size
             latency_cutoff = self.hwconf.latency_cutoff

-            if self.opr
+            if self.opr in ('min', 'max', 'vadd'):
                 assert len(self._from) == 2
                 v0, v1 = self._from
                 int0, int1 = v0.qint, v1.qint
@@ -89,8 +95,6 @@ class FixedVariable:
             else:
                 assert len(self._from) == 1
                 assert self._data is not None, 'cadd must have data'
-                # int0 = self._from[0].qint
-                # int1 = QInterval(float(self._data), float(self._data), float(self.step))
                 _f = _const_f(self._data)
                 _cost = float(ceil(log2(abs(self._data) + Decimal(2) ** -_f))) + _f
                 base_latency = self._from[0].latency
@@ -138,6 +142,12 @@ class FixedVariable:
         k = self.low < 0
         return k, i, f

+    @classmethod
+    def from_const(cls, const: float | Decimal, hwconf: HWConfig, latency: float, _factor: float | Decimal):
+        f = _const_f(const)
+        step = Decimal(2) ** -f
+        return cls(const, const, step, hwconf=hwconf, opr='const', _factor=_factor, latency=latency)
+
     def __repr__(self) -> str:
         if self._factor == 1:
             return f'FixedVariable({self.low}, {self.high}, {self.step})'
@@ -185,7 +195,9 @@ class FixedVariable:
             hwconf=self.hwconf,
         )

-    def _const_add(self, other: float | Decimal):
+    def _const_add(self, other: float | Decimal | None):
+        if other is None:
+            return self
         if not isinstance(other, (int, float, Decimal)):
             other = float(other)  # direct numpy to decimal raises error
         other = Decimal(other)
@@ -222,7 +234,7 @@ class FixedVariable:
         other: 'float|Decimal',
     ):
         if other == 0:
-            return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')

         assert log2(abs(other)) % 1 == 0, 'Only support pow2 multiplication'

@@ -266,8 +278,8 @@ class FixedVariable:
             step = Decimal(2) ** -f
             i = ceil(log2(val + step)) if not i else i
             eps = step / 2 if round_mode == 'RND' else 0
-            val = floor(val / step + eps) % Decimal(2) ** i
-            return FixedVariable(val, val, step, hwconf=self.hwconf)
+            val = (floor(val / step + eps) * step) % (Decimal(2) ** i)
+            return FixedVariable(val, val, step, hwconf=self.hwconf, opr='const')

         step = max(Decimal(2) ** -f, self.step) if f is not None else self.step
         if step > self.step and round_mode == 'RND':
@@ -281,6 +293,10 @@ class FixedVariable:
         low = Decimal(0)
         high = _high
         _factor = self._factor
+
+        if self.low == low and self.high == high and self.step == step:
+            return self
+
         return FixedVariable(
             low,
             high,
@@ -301,7 +317,7 @@ class FixedVariable:
         round_mode: str = 'TRN',
     ):
         overflow_mode, round_mode = overflow_mode.upper(), round_mode.upper()
-        assert overflow_mode in ('WRAP', 'SAT')
+        assert overflow_mode in ('WRAP', 'SAT', 'SAT_SM')
         assert round_mode in ('TRN', 'RND')

         _k, _i, _f = self.kif
@@ -312,32 +328,42 @@ class FixedVariable:
         if f < _f and round_mode == 'RND':
             return (self + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')

+        if overflow_mode in ('SAT', 'SAT_SM'):
+            step = Decimal(2) ** -f
+            _high = Decimal(2) ** i
+            high = _high - step
+            low = -_high * k if overflow_mode == 'SAT' else -high * k
+            return self.max_of(low).min_of(high).quantize(k, i, f, 'WRAP', round_mode)
+
         if self.low == self.high:
             val = self.low
             step = Decimal(2) ** -f
             _high = Decimal(2) ** i
             high, low = _high - step, -_high * k
             val = (floor(val / step) * step - low) % (2 * _high) + low
-            return FixedVariable(val, val, step, hwconf=self.hwconf)
+            return FixedVariable(val, val, step, hwconf=self.hwconf, opr='const')

         # TODO: corner cases exists (e.g., overflow to negative, or negative overflow to high value)
         # bit-exactness will be lost in these cases, but they should never happen (quantizers are used in a weird way)
         # Keeping this for now; change if absolutely necessary
         f = min(f, _f)
-        k = min(k, _k)
+        k = min(k, _k) if i >= _i else k
         i = min(i, _i)

-
+        if i + k + f <= 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
+
+        step = Decimal(2) ** -f

         low = -k * Decimal(2) ** i
+
         high = Decimal(2) ** i - step
         _low, _high = self.low, self.high

         if _low >= low and _high <= high:
             low, high = _low, _high
-
-
-        return FixedVariable(0, 0, 1, hwconf=self.hwconf)
+        low = floor(low / step) * step
+        high = ceil(high / step) * step

         return FixedVariable(
             low,
@@ -345,7 +371,7 @@ class FixedVariable:
             step,
             _from=(self,),
             _factor=abs(self._factor),
-            opr='wrap'
+            opr='wrap',
             latency=self.latency,
             hwconf=self.hwconf,
         )
@@ -356,3 +382,136 @@ class FixedVariable:
         _high = Decimal(2) ** i
         low, high = k * _high, _high - step
         return cls(low, high, step, **kwargs)
+
+    def msb_mux(self, a: 'FixedVariable', b: 'FixedVariable', qint: tuple[Decimal, Decimal, Decimal] | None = None):
+        assert isinstance(a, FixedVariable) and isinstance(b, FixedVariable), 'msb_mux requires two FixedVariables'
+        if self._factor < 0:
+            return (-self).msb_mux(b, a, qint)
+
+        if a._factor < 0:
+            qint = (-qint[1], -qint[0], qint[2]) if qint else None
+            return -(self.msb_mux(-a, -b, qint=qint))
+
+        _factor = a._factor
+
+        if qint is None:
+            qint = (min(a.low, b.low), max(a.high, b.high), min(a.step, b.step))
+
+        dlat, dcost = cost_add(a.qint, b.qint, 0, False, self.hwconf.adder_size, self.hwconf.carry_size)
+        return FixedVariable(
+            *qint,
+            _from=(self, a, b),
+            _factor=_factor,
+            opr='msb_mux',
+            latency=max(a.latency, b.latency, self.latency) + dlat,
+            hwconf=self.hwconf,
+            cost=dcost,
+        )
+
+    def max_of(self, other):
+        if other == 0:
+            return self.relu()
+        if other == -float('inf'):
+            return self
+        if other == float('inf'):
+            raise ValueError('Cannot apply max_of with inf')
+        if not isinstance(other, FixedVariable):
+            other = FixedVariable.from_const(other, hwconf=self.hwconf, latency=self.latency, _factor=abs(self._factor))
+
+        if self.low >= other.high:
+            return self
+        if self.high <= other.low:
+            return other
+
+        qint = (max(self.low, other.low), max(self.high, other.high), min(self.step, other.step))
+        return (self - other).msb_mux(other, self, qint=qint)
+
+    def min_of(self, other):
+        if other == 0:
+            return (-self).relu()
+        if other == float('inf'):
+            return self
+        if other == -float('inf'):
+            raise ValueError('Cannot apply min_of with -inf')
+        if not isinstance(other, FixedVariable):
+            other = FixedVariable.from_const(other, hwconf=self.hwconf, latency=self.latency, _factor=(self._factor))
+
+        if self.high <= other.low:
+            return self
+        if self.low >= other.high:
+            return other
+
+        qint = (min(self.low, other.low), min(self.high, other.high), min(self.step, other.step))
+        return (self - other).msb_mux(self, other, qint=qint)
+
+
+class FixedVariableInput(FixedVariable):
+    def __init__(
+        self,
+        latency: float | None = None,
+        hwconf=HWConfig(-1, -1, -1),
+    ) -> None:
+        self.low = Decimal(1e10)
+        self.high = Decimal(-1e10)
+        self.step = Decimal(1e10)
+        self._factor = Decimal(1)
+        self._from: tuple[FixedVariable, ...] = ()
+        self.opr = 'new'
+        self._data = None
+        self.id = uuid4()
+        self.hwconf = hwconf
+
+        self.latency = latency if latency is not None else 0.0
+        self.cost = 0.0
+
+    def __add__(self, other):
+        raise ValueError('Cannot operate on unquantized input variable')
+
+    def __sub__(self, other):
+        raise ValueError('Cannot operate on unquantized input variable')
+
+    def __neg__(self):
+        raise ValueError('Cannot negate unquantized input variable')
+
+    def relu(self, *args, **kwargs):
+        raise ValueError('Cannot apply relu on unquantized input variable')
+
+    def max_of(self, other):
+        raise ValueError('Cannot apply max_of on unquantized input variable')
+
+    def min_of(self, other):
+        raise ValueError('Cannot apply min_of on unquantized input variable')
+
+    def quantize(
+        self,
+        k: int | bool,
+        i: int,
+        f: int,
+        overflow_mode: str = 'WRAP',
+        round_mode: str = 'TRN',
+    ):
+        assert overflow_mode == 'WRAP'
+
+        if k + i + f <= 0:
+            return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
+
+        if round_mode == 'RND':
+            return (self.quantize(k, i, f + 1) + 2.0 ** (-f - 1)).quantize(k, i, f, overflow_mode, 'TRN')

+        step = Decimal(2) ** -f
+        _high = Decimal(2) ** i
+        low, high = -_high * k, _high - step
+        self.high = max(self.high, high)
+        self.low = min(self.low, low)
+        self.step = min(self.step, step)
+
+        return FixedVariable(
+            low,
+            high,
+            step,
+            _from=(self,),
+            _factor=self._factor,
+            opr='wrap',
+            latency=self.latency,
+            hwconf=self.hwconf,
+        )
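For orientation: the new max_of/min_of methods lower an elementwise max or min onto msb_mux, i.e. a 2-to-1 mux driven by the sign bit (MSB) of the difference of the two operands. A minimal NumPy sketch of that identity, using plain floats and a hypothetical helper name rather than da4ml objects:

```python
import numpy as np


def msb_mux_max(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """max(a, b) expressed as a sign-driven mux on the difference a - b."""
    d = a - b
    # the sign (MSB) of d selects between the two inputs, as FixedVariable.msb_mux does in hardware
    return np.where(d >= 0, a, b)


a = np.array([0.5, -1.25, 3.0])
b = np.array([1.0, -2.0, 3.0])
assert np.array_equal(msb_mux_max(a, b), np.maximum(a, b))
```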
da4ml/trace/fixed_variable_array.py
CHANGED

@@ -1,20 +1,110 @@
-from
+from inspect import signature
+from typing import Any, TypeVar

 import numpy as np
+from numba.typed import List as NumbaList
 from numpy.typing import NDArray

 from ..cmvm import solve
-from .fixed_variable import FixedVariable, HWConfig, QInterval
+from .fixed_variable import FixedVariable, FixedVariableInput, HWConfig, QInterval
+from .ops import einsum, reduce
+
+T = TypeVar('T')
+
+
+def to_raw_arr(obj: T) -> T:
+    if isinstance(obj, tuple):
+        return tuple(to_raw_arr(x) for x in obj)  # type: ignore
+    elif isinstance(obj, list):
+        return [to_raw_arr(x) for x in obj]  # type: ignore
+    elif isinstance(obj, dict):
+        return {k: to_raw_arr(v) for k, v in obj.items()}  # type: ignore
+    if isinstance(obj, FixedVariableArray):
+        return obj._vars  # type: ignore
+    return obj
+
+
+def _max_of(a, b):
+    if isinstance(a, FixedVariable):
+        return a.max_of(b)
+    elif isinstance(b, FixedVariable):
+        return b.max_of(a)
+    else:
+        return max(a, b)
+
+
+def _min_of(a, b):
+    if isinstance(a, FixedVariable):
+        return a.min_of(b)
+    elif isinstance(b, FixedVariable):
+        return b.min_of(a)
+    else:
+        return min(a, b)


 class FixedVariableArray:
+    __array_priority__ = 100
+
+    def __array_function__(self, func, types, args, kwargs):
+        if func is np.matmul:
+            if len(args) == 1 and isinstance(args[0], np.ndarray):
+                return self.__matmul__(args[0])
+            elif len(args) == 2 and isinstance(args[0], np.ndarray) and isinstance(args[1], np.ndarray):
+                return self.__rmatmul__(args[1])
+
+        if func in (np.mean, np.sum, np.amax, np.amin, np.max, np.min):
+            match func:
+                case np.mean:
+                    _x = reduce(lambda x, y: x + y, self, *args[1:], **kwargs)
+                    return _x * (_x.size / self._vars.size)
+                case np.sum:
+                    return reduce(lambda x, y: x + y, self, *args[1:], **kwargs)
+                case np.max | np.amax:
+                    return reduce(_max_of, self, *args[1:], **kwargs)
+                case np.min | np.amin:
+                    return reduce(_min_of, self, *args[1:], **kwargs)
+                case _:
+                    raise NotImplementedError(f'Unsupported function: {func}')
+
+        if func is np.clip:
+            assert len(args) == 3, 'Clip function requires exactly three arguments'
+            x, low, high = args
+            _x, low, high = np.broadcast_arrays(x, low, high)
+            x = FixedVariableArray(_x, self.solver_options)
+            x = np.amax(np.stack((x, low), axis=-1), axis=-1)  # type: ignore
+            return np.amin(np.stack((x, high), axis=-1), axis=-1)
+
+        if func is np.einsum:
+            # assert len(args) == 2
+            sig = signature(np.einsum)
+            bind = sig.bind(*args, **kwargs)
+            eq = args[0]
+            operands = bind.arguments['operands']
+            if isinstance(operands[0], str):
+                operands = operands[1:]
+            assert len(operands) == 2, 'Einsum on FixedVariableArray requires exactly two operands'
+            assert bind.arguments.get('out', None) is None, 'Output argument is not supported'
+            return einsum(eq, *operands)
+
+        args, kwargs = to_raw_arr(args), to_raw_arr(kwargs)
+        return FixedVariableArray(
+            func(*args, **kwargs),
+            self.solver_options,
+        )
+
     def __init__(
         self,
         vars: NDArray,
         solver_options: dict[str, Any] | None = None,
     ):
         self._vars = np.array(vars)
-
+        _solver_options = signature(solve).parameters
+        _solver_options = {k: v.default for k, v in _solver_options.items() if v.default is not v.empty}
+        if solver_options is not None:
+            _solver_options.update(solver_options)
+        _solver_options.pop('qintervals', None)
+        _solver_options.pop('latencies', None)
+        self.solver_options = _solver_options

     @classmethod
     def from_lhs(
@@ -75,8 +165,10 @@ class FixedVariableArray:
         r = []
         for i in range(mat0.shape[0]):
             vec = mat0[i]
-
-
+            _qintervals = [QInterval(float(v.low), float(v.high), float(v.step)) for v in vec._vars]
+            _latencies = [float(v.latency) for v in vec._vars]
+            qintervals = NumbaList(_qintervals)  # type: ignore
+            latencies = NumbaList(_latencies)  # type: ignore
             hwconf = self._vars.ravel()[0].hwconf
             kwargs.update(adder_size=hwconf.adder_size, carry_size=hwconf.carry_size)
             _mat = np.ascontiguousarray(mat1.astype(np.float32))
@@ -96,8 +188,8 @@ class FixedVariableArray:
         axes = _axes[ndim0 - 1 :] + _axes[: ndim0 - 1]
         return r.transpose(axes)

-    def __getitem__(self,
-        vars = self._vars[
+    def __getitem__(self, item):
+        vars = self._vars[item]
         if isinstance(vars, np.ndarray):
             return FixedVariableArray(vars, self.solver_options)
         else:
@@ -111,9 +203,13 @@ class FixedVariableArray:
         return self._vars.shape

     def __add__(self, other):
+        if isinstance(other, FixedVariableArray):
+            return FixedVariableArray(self._vars + other._vars, self.solver_options)
         return FixedVariableArray(self._vars + other, self.solver_options)

     def __sub__(self, other):
+        if isinstance(other, FixedVariableArray):
+            return FixedVariableArray(self._vars - other._vars, self.solver_options)
         return FixedVariableArray(self._vars - other, self.solver_options)

     def __mul__(self, other):
@@ -139,7 +235,7 @@ class FixedVariableArray:
         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
         ret = []
-        for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()):
+        for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()):  # type: ignore
             ret.append(v.relu(i=i, f=f, round_mode=round_mode))
         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)

@@ -156,7 +252,7 @@ class FixedVariableArray:
         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
         ret = []
-        for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):
+        for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):  # type: ignore
             ret.append(v.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode))
         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)

@@ -175,3 +271,22 @@ class FixedVariableArray:
     @property
     def dtype(self):
         return self._vars.dtype
+
+    @property
+    def size(self):
+        return self._vars.size
+
+    @property
+    def kif(self):
+        shape = self._vars.shape
+        kif = np.array([v.kif for v in self._vars.ravel()]).reshape(*shape, 3)
+        return np.moveaxis(kif, -1, 0)
+
+
+class FixedVariableArrayInput(FixedVariableArray):
+    def __init__(self, shape: tuple[int, ...] | int, hwconf: HWConfig, solver_options: dict[str, Any] | None = None, latency=0.0):
+        _vars = np.empty(shape, dtype=object)
+        _vars_f = _vars.ravel()
+        for i in range(_vars.size):
+            _vars_f[i] = FixedVariableInput(latency, hwconf)
+        super().__init__(_vars, solver_options)
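The __array_function__ method added above is NumPy's NEP-18 dispatch hook; it is what lets np.sum, np.mean, np.clip and np.einsum route into the fixed-point tracer instead of producing a plain ndarray. A minimal self-contained sketch of that protocol, with a hypothetical TracedArray class standing in for FixedVariableArray:

```python
import numpy as np


class TracedArray:
    """Hypothetical stand-in illustrating the NEP-18 hook used by FixedVariableArray."""

    def __init__(self, values):
        self._values = np.asarray(values)

    def __array_function__(self, func, types, args, kwargs):
        # NumPy routes np.sum(x), np.mean(x), np.clip(x, ...) here whenever
        # one of the arguments is a TracedArray.
        if func is np.sum:
            return float(np.add.reduce(self._values.ravel()))

        # Fallback: unwrap TracedArray arguments and call the plain NumPy function.
        def unwrap(v):
            return v._values if isinstance(v, TracedArray) else v

        return func(*(unwrap(a) for a in args), **{k: unwrap(v) for k, v in kwargs.items()})


x = TracedArray([[1.0, 2.0], [3.0, 4.0]])
print(np.sum(x))   # dispatched through __array_function__ -> 10.0
print(np.mean(x))  # falls back to plain NumPy on the unwrapped data -> 2.5
```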
da4ml/trace/ops/__init__.py
CHANGED
@@ -1,16 +1,22 @@
-from typing import TypeVar
+from typing import TYPE_CHECKING, TypeVar

 import numpy as np
 from numpy.typing import NDArray

-from ..fixed_variable_array import FixedVariable
-from .conv_utils import conv
+from ..fixed_variable_array import FixedVariable
+from .conv_utils import conv, pool
 from .einsum_utils import einsum
+from .reduce_utils import reduce

-
+if TYPE_CHECKING:
+    from ..fixed_variable_array import FixedVariableArray
+
+T = TypeVar('T', 'FixedVariableArray', NDArray[np.floating], list[FixedVariable])


 def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN') -> T:
+    from ..fixed_variable_array import FixedVariableArray
+
     if isinstance(x, FixedVariableArray):
         return x.relu(i=i, f=f, round_mode=round_mode)
     elif isinstance(x, list):
@@ -35,12 +41,20 @@ def quantize(
     overflow_mode: str = 'WRAP',
     round_mode: str = 'TRN',
 ) -> T:
-
+    from ..fixed_variable_array import FixedVariableArray
+
     if isinstance(x, FixedVariableArray):
         return x.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode)
     else:
+        x = x.copy()
+        if overflow_mode in ('SAT', 'SAT_SM'):
+            step = 2.0**-f
+            _high = 2.0**i
+            high = _high - step
+            low = -_high * k if overflow_mode == 'SAT' else -high * k
+            x = np.clip(x, low, high)  # type: ignore
         if round_mode.upper() == 'RND':
-            x += 2.0 ** (-f - 1)
+            x += 2.0 ** (-f - 1)  # type: ignore
         b = k + i + f
         bias = 2.0 ** (b - 1) * k
         eps = 2.0**-f
@@ -52,4 +66,6 @@ __all__ = [
     'einsum',
     'relu',
     'quantize',
+    'pool',
+    'reduce',
 ]
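The SAT/SAT_SM branch added to the NumPy-side quantize helper clips values to the representable range of a (k, i, f) fixed-point format before rounding and wrapping. A small sketch of just that bound computation, using a hypothetical function name and assuming the same k/i/f convention as the code above:

```python
import numpy as np


def sat_bounds(k: int, i: int, f: int, overflow_mode: str = 'SAT'):
    """Representable range for k sign bit(s), i integer bits and f fractional bits,
    mirroring the new SAT / SAT_SM clipping branch (illustrative sketch, not the
    packaged function)."""
    step = 2.0 ** -f
    _high = 2.0 ** i
    high = _high - step                                          # largest representable value
    low = -_high * k if overflow_mode == 'SAT' else -high * k    # SAT_SM is symmetric around zero
    return low, high


low, high = sat_bounds(k=1, i=3, f=2)
print(np.clip(np.array([-9.0, 0.3, 9.0]), low, high))  # -> [-8.    0.3   7.75]
```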