da4ml-0.5.1.post1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. da4ml/__init__.py +4 -0
  2. da4ml/_binary/__init__.py +15 -0
  3. da4ml/_binary/dais_bin.cpython-311-x86_64-linux-gnu.so +0 -0
  4. da4ml/_binary/dais_bin.pyi +5 -0
  5. da4ml/_cli/__init__.py +30 -0
  6. da4ml/_cli/convert.py +204 -0
  7. da4ml/_cli/report.py +295 -0
  8. da4ml/_version.py +32 -0
  9. da4ml/cmvm/__init__.py +4 -0
  10. da4ml/cmvm/api.py +264 -0
  11. da4ml/cmvm/core/__init__.py +221 -0
  12. da4ml/cmvm/core/indexers.py +83 -0
  13. da4ml/cmvm/core/state_opr.py +284 -0
  14. da4ml/cmvm/types.py +739 -0
  15. da4ml/cmvm/util/__init__.py +7 -0
  16. da4ml/cmvm/util/bit_decompose.py +86 -0
  17. da4ml/cmvm/util/mat_decompose.py +121 -0
  18. da4ml/codegen/__init__.py +9 -0
  19. da4ml/codegen/hls/__init__.py +4 -0
  20. da4ml/codegen/hls/hls_codegen.py +196 -0
  21. da4ml/codegen/hls/hls_model.py +255 -0
  22. da4ml/codegen/hls/source/ap_types/ap_binary.h +78 -0
  23. da4ml/codegen/hls/source/ap_types/ap_common.h +376 -0
  24. da4ml/codegen/hls/source/ap_types/ap_decl.h +212 -0
  25. da4ml/codegen/hls/source/ap_types/ap_fixed.h +360 -0
  26. da4ml/codegen/hls/source/ap_types/ap_fixed_base.h +2354 -0
  27. da4ml/codegen/hls/source/ap_types/ap_fixed_ref.h +718 -0
  28. da4ml/codegen/hls/source/ap_types/ap_fixed_special.h +230 -0
  29. da4ml/codegen/hls/source/ap_types/ap_int.h +330 -0
  30. da4ml/codegen/hls/source/ap_types/ap_int_base.h +1885 -0
  31. da4ml/codegen/hls/source/ap_types/ap_int_ref.h +1346 -0
  32. da4ml/codegen/hls/source/ap_types/ap_int_special.h +223 -0
  33. da4ml/codegen/hls/source/ap_types/ap_shift_reg.h +138 -0
  34. da4ml/codegen/hls/source/ap_types/etc/ap_private.h +7199 -0
  35. da4ml/codegen/hls/source/ap_types/hls_math.h +27 -0
  36. da4ml/codegen/hls/source/ap_types/hls_stream.h +263 -0
  37. da4ml/codegen/hls/source/ap_types/utils/x_hls_utils.h +80 -0
  38. da4ml/codegen/hls/source/binder_util.hh +71 -0
  39. da4ml/codegen/hls/source/build_binder.mk +22 -0
  40. da4ml/codegen/hls/source/vitis_bitshift.hh +32 -0
  41. da4ml/codegen/rtl/__init__.py +15 -0
  42. da4ml/codegen/rtl/common_source/binder_util.hh +99 -0
  43. da4ml/codegen/rtl/common_source/build_binder.mk +34 -0
  44. da4ml/codegen/rtl/common_source/build_quartus_prj.tcl +104 -0
  45. da4ml/codegen/rtl/common_source/build_vivado_prj.tcl +111 -0
  46. da4ml/codegen/rtl/common_source/ioutil.hh +124 -0
  47. da4ml/codegen/rtl/common_source/template.sdc +27 -0
  48. da4ml/codegen/rtl/common_source/template.xdc +30 -0
  49. da4ml/codegen/rtl/rtl_model.py +486 -0
  50. da4ml/codegen/rtl/verilog/__init__.py +10 -0
  51. da4ml/codegen/rtl/verilog/comb.py +239 -0
  52. da4ml/codegen/rtl/verilog/io_wrapper.py +113 -0
  53. da4ml/codegen/rtl/verilog/pipeline.py +67 -0
  54. da4ml/codegen/rtl/verilog/source/lookup_table.v +27 -0
  55. da4ml/codegen/rtl/verilog/source/multiplier.v +37 -0
  56. da4ml/codegen/rtl/verilog/source/mux.v +58 -0
  57. da4ml/codegen/rtl/verilog/source/negative.v +31 -0
  58. da4ml/codegen/rtl/verilog/source/shift_adder.v +59 -0
  59. da4ml/codegen/rtl/vhdl/__init__.py +9 -0
  60. da4ml/codegen/rtl/vhdl/comb.py +206 -0
  61. da4ml/codegen/rtl/vhdl/io_wrapper.py +120 -0
  62. da4ml/codegen/rtl/vhdl/pipeline.py +71 -0
  63. da4ml/codegen/rtl/vhdl/source/lookup_table.vhd +52 -0
  64. da4ml/codegen/rtl/vhdl/source/multiplier.vhd +40 -0
  65. da4ml/codegen/rtl/vhdl/source/mux.vhd +102 -0
  66. da4ml/codegen/rtl/vhdl/source/negative.vhd +35 -0
  67. da4ml/codegen/rtl/vhdl/source/shift_adder.vhd +101 -0
  68. da4ml/converter/__init__.py +63 -0
  69. da4ml/converter/hgq2/__init__.py +3 -0
  70. da4ml/converter/hgq2/layers/__init__.py +11 -0
  71. da4ml/converter/hgq2/layers/_base.py +132 -0
  72. da4ml/converter/hgq2/layers/activation.py +81 -0
  73. da4ml/converter/hgq2/layers/attn.py +148 -0
  74. da4ml/converter/hgq2/layers/batchnorm.py +15 -0
  75. da4ml/converter/hgq2/layers/conv.py +149 -0
  76. da4ml/converter/hgq2/layers/dense.py +39 -0
  77. da4ml/converter/hgq2/layers/ops.py +246 -0
  78. da4ml/converter/hgq2/layers/pool.py +107 -0
  79. da4ml/converter/hgq2/layers/table.py +176 -0
  80. da4ml/converter/hgq2/parser.py +161 -0
  81. da4ml/trace/__init__.py +6 -0
  82. da4ml/trace/fixed_variable.py +965 -0
  83. da4ml/trace/fixed_variable_array.py +600 -0
  84. da4ml/trace/ops/__init__.py +13 -0
  85. da4ml/trace/ops/einsum_utils.py +305 -0
  86. da4ml/trace/ops/quantization.py +74 -0
  87. da4ml/trace/ops/reduce_utils.py +105 -0
  88. da4ml/trace/pipeline.py +181 -0
  89. da4ml/trace/tracer.py +186 -0
  90. da4ml/typing/__init__.py +3 -0
  91. da4ml-0.5.1.post1.dist-info/METADATA +85 -0
  92. da4ml-0.5.1.post1.dist-info/RECORD +96 -0
  93. da4ml-0.5.1.post1.dist-info/WHEEL +6 -0
  94. da4ml-0.5.1.post1.dist-info/entry_points.txt +3 -0
  95. da4ml-0.5.1.post1.dist-info/sboms/auditwheel.cdx.json +1 -0
  96. da4ml.libs/libgomp-e985bcbb.so.1.0.0 +0 -0
da4ml/cmvm/types.py ADDED
@@ -0,0 +1,739 @@
+ import json
+ from collections.abc import Sequence
+ from decimal import Decimal
+ from functools import reduce, singledispatch
+ from math import ceil, floor, log2
+ from pathlib import Path
+ from typing import TYPE_CHECKING, NamedTuple, TypeVar
+
+ import numpy as np
+ from numba import jit
+ from numpy import float32, int8
+ from numpy.typing import NDArray
+
+ from .._binary import dais_interp_run
+
+ if TYPE_CHECKING:
+     from ..trace.fixed_variable import FixedVariable, LookupTable
+
+
+ class QInterval(NamedTuple):
+     """A class representing a quantized interval: [min, max] with a step size."""
+
+     min: float
+     max: float
+     step: float
+
+     @classmethod
+     def from_kif(cls, k: int | bool, i: int, f: int):
+         _high = 2.0**i
+         step = 2.0**-f
+         low, high = -k * _high, _high - step
+         return cls(low, high, step)
+
+     @classmethod
+     def from_precision(cls, prec: 'Precision'):
+         return cls.from_kif(*prec)
+
+     @property
+     def precision(self):
+         return Precision.from_qint(self)
+
+     def __repr__(self):
+         return f'[{self.min}, {self.max}, {self.step}]'
+
+
+ class Precision(NamedTuple):
+     """A class representing the precision of a quantized interval."""
+
+     keep_negative: bool
+     integers: int
+     fractional: int
+
+     def __str__(self):
+         k, i, f = self.keep_negative, self.integers, self.fractional
+         return f'fixed({k=}, {i=}, {f=})'
+
+     def __repr__(self):
+         return str(self)
+
+     @classmethod
+     def from_qint(cls, qint: QInterval, symmetric: bool = False):
+         return _minimal_kif(qint, symmetric=symmetric)
+
+     @property
+     def qint(self):
+         return QInterval.from_kif(*self)
+
+
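# Editor's sketch (not part of the package): QInterval and Precision are
# designed to round-trip. Assuming the convention min = -k * 2**i, a
# fixed(k=1, i=2, f=3) value spans [-4, 3.875] with step 2**-3.
from da4ml.cmvm.types import Precision, QInterval

qi = QInterval.from_kif(True, 2, 3)
assert qi == QInterval(min=-4.0, max=3.875, step=0.125)
assert qi.precision == Precision(keep_negative=True, integers=2, fractional=3)
assert Precision(True, 2, 3).qint == qi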
+ class Op(NamedTuple):
+     """One single operation on the data buffer.
+
+     Parameters
+     ----------
+     id0: int
+         Index of the first operand.
+     id1: int
+         Index of the second operand, or -1 if unused.
+     opcode: int
+         -1: input copy, 0: addition, 1: subtraction, +/-2: relu(+/-x),
+         +/-3: quantize(+/-x), 4: const addition, 5: const definition,
+         +/-6: MSB mux, 7: multiplication, 8: table lookup.
+     data: int
+         Operation-dependent payload (e.g., the shift for addition/subtraction,
+         the packed control word for MSB mux, or the table index for lookup).
+     qint: QInterval
+         Quantization interval of the resultant buffer.
+     latency: float
+         Latency of the data generated by this operation (t_available).
+     cost: float
+         Cost of the operation.
+     """
+
+     id0: int
+     id1: int
+     opcode: int
+     data: int
+     qint: QInterval
+     latency: float
+     cost: float
+
+
+ class Pair(NamedTuple):
+     """An operation representing data[id0] +/- data[id1] * 2**shift."""
+
+     id0: int
+     id1: int
+     sub: bool
+     shift: int
+
+
+ class DAState(NamedTuple):
+     """Internal state of the DA algorithm."""
+
+     shifts: tuple[NDArray[int8], NDArray[int8]]
+     expr: list[NDArray[int8]]
+     ops: list[Op]
+     freq_stat: dict[Pair, int]
+     kernel: NDArray[float32]
+
+
+ def _minimal_kif(qi: QInterval, symmetric: bool = False) -> Precision:
+     """Calculate the minimal KIF for a given QInterval.
+
+     Parameters
+     ----------
+     qi : QInterval
+         The QInterval to calculate the KIF for.
+     symmetric : bool
+         Only relevant if qi may be negative. If True, -2**i is regarded as
+         forbidden, which may be useful in special cases.
+         Default is False.
+
+     Returns
+     -------
+     Precision
+         A named tuple with the KIF values.
+     """
+
+     if qi.min == qi.max == 0:
+         return Precision(keep_negative=False, integers=0, fractional=0)
+     keep_negative = qi.min < 0
+     fractional = int(-log2(qi.step))
+     int_min, int_max = round(qi.min / qi.step), round(qi.max / qi.step)
+     if symmetric:
+         bits = int(ceil(log2(max(abs(int_min), int_max) + 1)))
+     else:
+         bits = int(ceil(log2(max(abs(int_min), int_max + 1))))
+     integers = bits - fractional
+     return Precision(keep_negative=keep_negative, integers=integers, fractional=fractional)
+
+
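# Editor's sketch (not part of the package): a worked minimal_kif example.
# For [-4, 3.875] with step 0.125: fractional = -log2(0.125) = 3, the
# integer range is [-32, 31], ceil(log2(32)) = 5 bits, so integers = 5 - 3 = 2.
from da4ml.cmvm.types import QInterval, minimal_kif

assert minimal_kif(QInterval(-4.0, 3.875, 0.125)) == (True, 2, 3)
# An unsigned interval needs no sign bit: [0, 3] with step 1 -> fixed(0, 2, 0).
assert minimal_kif(QInterval(0.0, 3.0, 1.0)) == (False, 2, 0)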
+ if TYPE_CHECKING:
+
+     def minimal_kif(qi: QInterval, symmetric: bool = False) -> Precision: ...
+ else:
+     minimal_kif = jit(_minimal_kif)
+
+
+ T = TypeVar('T', 'FixedVariable', float, int, np.float32, np.float64, Decimal)
+
+
+ @singledispatch
+ def _relu(v: 'T', i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN') -> 'T':
+     from ..trace.fixed_variable import FixedVariable
+
+     assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic relu'
+     if inv:
+         v = -v
+     return v.relu(i, f, round_mode=round_mode)
+
+
+ @_relu.register(float)
+ @_relu.register(int)
+ @_relu.register(np.float32)
+ @_relu.register(np.float64)
+ def _(v, i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN'):
+     if inv:
+         v = -v
+     v = max(0, v)
+     if f is not None:
+         if round_mode.upper() == 'RND':
+             v += 2.0 ** (-f - 1)
+         sf = 2.0**f
+         v = floor(v * sf) / sf
+     if i is not None:
+         v = v % 2.0**i
+     return v
+
+
+ @_relu.register
+ def _(v: Decimal, i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN'):
+     if inv:
+         v = -v
+     v = max(Decimal(0), v)
+     if f is not None:
+         if round_mode.upper() == 'RND':
+             v += Decimal(2) ** (-f - 1)
+         sf = Decimal(2) ** f
+         v = floor(v * sf) / sf
+     if i is not None:
+         v = v % Decimal(2) ** i
+     return v
+
+
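# Editor's sketch (not part of the package): numeric behaviour of the
# private _relu helper. It clamps at zero, truncates (or rounds) to f
# fractional bits, and wraps modulo 2**i.
from da4ml.cmvm.types import _relu

assert _relu(-1.0, i=2, f=2) == 0.0    # negative values clamp to 0
assert _relu(1.9375, i=2, f=2) == 1.75  # 'TRN': truncate to step 0.25
assert _relu(1.9375, i=2, f=2, round_mode='RND') == 2.0  # round to nearest
assert _relu(4.5, i=2, f=2) == 0.5     # wraps modulo 2**2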
+ @singledispatch
+ def _quantize(v: 'T', k: int | bool, i: int, f: int, round_mode: str = 'TRN') -> 'T':
+     from ..trace.fixed_variable import FixedVariable
+
+     assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic quantization'
+     return v.quantize(k, i, f, round_mode=round_mode)
+
+
+ @_quantize.register(float)
+ @_quantize.register(int)
+ @_quantize.register(np.float32)
+ @_quantize.register(np.float64)
+ def _(v, k: int | bool, i: int, f: int, round_mode: str = 'TRN'):
+     if round_mode.upper() == 'RND':
+         v += 2.0 ** (-f - 1)
+     b = k + i + f
+     bias = 2.0 ** (b - 1) * k
+     eps = 2.0**-f
+     return eps * ((np.floor(v / eps) + bias) % 2**b - bias)
+
+
+ @_quantize.register
+ def _(v: Decimal, k: int | bool, i: int, f: int, round_mode: str = 'TRN'):
+     if round_mode.upper() == 'RND':
+         v += Decimal(2) ** (-f - 1)
+     b = k + i + f
+     bias = Decimal(2) ** (b - 1) * k
+     eps = Decimal(2) ** -f
+     return eps * ((floor(v / eps) + bias) % Decimal(2) ** b - bias)
+
+
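# Editor's sketch (not part of the package): _quantize implements
# truncation (or rounding) with two's-complement wrap-around. With
# fixed(k=1, i=2, f=3): b = 6 bits, eps = 0.125, bias = 32.
from da4ml.cmvm.types import _quantize

assert _quantize(1.2, 1, 2, 3) == 1.125  # truncated down to a step
assert _quantize(1.2, 1, 2, 3, round_mode='RND') == 1.25
assert _quantize(4.0, 1, 2, 3) == -4.0   # overflow wraps around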
+ class JSONEncoder(json.JSONEncoder):
+     def default(self, o):
+         if hasattr(o, 'to_dict'):
+             return o.to_dict()
+         return super().default(o)
+
+
240
+ class CombLogic(NamedTuple):
241
+ """A combinational logic that describes a series of operations on input data to produce output data.
242
+
243
+ Attributes
244
+ ----------
245
+ shape: tuple[int, int]
246
+ #input, #output
247
+ inp_shifts: list[int]
248
+ The shifts that should be applied to the input data.
249
+ out_idxs: list[int]
250
+ The indices of the output data in the buffer.
251
+ out_shifts: list[int]
252
+ The shifts that should be applied to the output data.
253
+ out_negs: list[bool]
254
+ The signs of the output data.
255
+ ops: list[Op]
256
+ Core list of operations for generating each buffer element.
257
+ carry_size: int
258
+ Size of the carrier for the adder, used for cost and latency estimation.
259
+ adder_size: int
260
+ Elementary size of the adder, used for cost and latency estimation.
261
+ lookup_tables: tuple[LookupTable, ...] | None
262
+ Lookup table arrays for lookup operations, if any.
263
+
264
+
265
+ The core part of the comb logic is the operations in the ops list.
266
+ For the exact operations executed with Op, refer to the Op class.
267
+ After all operations are executed, the output data is read from data[op.out_idx] and multiplied by 2**out_shift.
268
+
269
+ """
270
+
271
+ shape: tuple[int, int]
272
+ inp_shifts: list[int]
273
+ out_idxs: list[int]
274
+ out_shifts: list[int]
275
+ out_negs: list[bool]
276
+ ops: list[Op]
277
+ carry_size: int
278
+ adder_size: int
279
+ lookup_tables: 'tuple[LookupTable, ...] | None' = None
280
+
+     def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False, dump=False):
+         """Executes the solution on the input data.
+
+         Parameters
+         ----------
+         inp : list | np.ndarray | tuple
+             Input data to be processed. The input data should be a list or numpy array of objects.
+         quantize : bool
+             If True, the input data will be quantized to the input quantization intervals.
+             Only floating point data types are supported when quantize is True.
+             Default is False.
+         debug : bool
+             If True, the function will print debug information about the operations being performed.
+             Default is False.
+         dump : bool
+             If True, return the whole buffer, without applying the output shifts and signs.
+             Default is False.
+
+         Returns
+         -------
+         np.ndarray
+             The output data after applying the operations defined in the solution.
+
+         """
+
+         from ..trace.fixed_variable import FixedVariable
+
+         buf = np.empty(len(self.ops), dtype=object)
+         inp = np.asarray(inp)
+
+         inp_qint = [op.qint for op in self.ops if op.opcode == -1]
+         if quantize:  # TRN and WRAP
+             k, i, f = map(np.array, zip(*map(minimal_kif, inp_qint)))
+             inp = [_quantize(*x, round_mode='TRN') for x in zip(inp, k, i, f)]
+
+         inp = inp * (2.0 ** np.array(self.inp_shifts))
+         for i, op in enumerate(self.ops):
+             match op.opcode:
+                 case -1:  # copy from external buffer
+                     buf[i] = inp[op.id0]
+                 case 0 | 1:  # addition or subtraction
+                     v0, v1 = buf[op.id0], 2.0**op.data * buf[op.id1]
+                     buf[i] = v0 + v1 if op.opcode == 0 else v0 - v1
+                 case 2 | -2:  # relu(+/-x)
+                     v = buf[op.id0]
+                     _, _i, _f = _minimal_kif(op.qint)
+                     buf[i] = _relu(v, _i, _f, inv=op.opcode == -2, round_mode='TRN')
+                 case 3 | -3:  # quantize(+/-x)
+                     v = buf[op.id0] if op.opcode == 3 else -buf[op.id0]
+                     _k, _i, _f = _minimal_kif(op.qint)
+                     buf[i] = _quantize(v, _k, _i, _f, round_mode='TRN')
+                 case 4:  # const addition
+                     bias = op.data * op.qint.step
+                     buf[i] = buf[op.id0] + bias
+                 case 5:  # const definition
+                     buf[i] = op.data * op.qint.step
+                 case 6 | -6:  # MSB mux
+                     id_c = op.data & 0xFFFFFFFF
+                     k, v0, v1 = buf[id_c], buf[op.id0], buf[op.id1]
+                     shift = (op.data >> 32) & 0xFFFFFFFF
+                     shift = shift if shift < 0x80000000 else shift - 0x100000000
+                     if op.opcode == -6:
+                         v1 = -v1
+
+                     if isinstance(k, FixedVariable):
+                         buf[i] = k.msb_mux(v0, v1 * 2**shift, op.qint)
+                     else:
+                         qint_k = self.ops[id_c].qint
+                         if qint_k.min < 0:
+                             buf[i] = v0 if k < 0 else v1 * 2.0**shift
+                         else:
+                             _k, _i, _f = _minimal_kif(qint_k)
+                             buf[i] = v0 if k >= 2.0 ** (_i - 1) else v1 * 2.0**shift
+                 case 7:  # multiplication
+                     v0, v1 = buf[op.id0], buf[op.id1]
+                     buf[i] = v0 * v1
+                 case 8:  # table lookup
+                     v0 = buf[op.id0]
+                     tables = self.lookup_tables
+                     assert tables is not None, 'No lookup table provided for lookup operation'
+                     table = tables[op.data]
+                     buf[i] = table.lookup(v0, self.ops[op.id0].qint)
+                 case _:
+                     raise ValueError(f'Unknown opcode {op.opcode} in {op}')
+
+         sf = 2.0 ** np.array(self.out_shifts, dtype=np.float64)
+         sign = np.where(self.out_negs, -1, 1)
+         out_idx = np.array(self.out_idxs, dtype=np.int32)
+         mask = np.where(out_idx < 0, 0, 1)
+         if debug:
+             operands = []
+             for i, v in enumerate(buf):
+                 op = self.ops[i]
+                 match op.opcode:
+                     case -1:
+                         op_str = 'inp'
+                     case 0 | 1:
+                         _sign = '-' if op.opcode == 1 else '+'
+                         op_str = f'buf[{op.id0}] {_sign} buf[{op.id1}]<<{op.data}'
+                     case 2 | -2:
+                         _sign = '' if op.opcode == 2 else '-'
+                         op_str = f'relu({_sign}buf[{op.id0}])'
+                     case 3 | -3:
+                         _sign = '' if op.opcode == 3 else '-'
+                         op_str = f'quantize({_sign}buf[{op.id0}])'
+                     case 4:
+                         op_str = f'buf[{op.id0}] + {op.data * op.qint.step}'
+                     case 5:
+                         op_str = f'const {op.data * op.qint.step}'
+                     case 6 | -6:
+                         _sign = '-' if op.opcode == -6 else ''
+                         op_str = f'msb(buf[{op.data}]) ? buf[{op.id0}] : {_sign}buf[{op.id1}]'
+                     case 7:
+                         op_str = f'buf[{op.id0}] * buf[{op.id1}]'
+                     case 8:
+                         op_str = f'tables[{int(op.data)}].lookup(buf[{op.id0}])'
+                     case _:
+                         raise ValueError(f'Unknown opcode {op.opcode} in {op}')
+
+                 result = f'|-> buf[{i}] = {v}'
+                 operands.append((op_str, result))
+             max_len = max(len(op[0]) for op in operands)
+             for op_str, result in operands:
+                 print(f'{op_str:<{max_len}} {result}')
+
+         if dump:
+             return buf
+         return buf[out_idx] * sf * sign * mask
+
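# Editor's sketch (not part of the package): a hand-built two-input adder,
# interpreted with __call__. Ops 0 and 1 copy the inputs into the buffer
# (opcode -1); op 2 computes buf[0] + buf[1] << 0 (opcode 0). The
# carry_size/adder_size values only affect cost/latency estimates here.
from da4ml.cmvm.types import CombLogic, Op, QInterval

qi_in = QInterval(0.0, 3.0, 1.0)   # 2-bit unsigned inputs
qi_out = QInterval(0.0, 6.0, 1.0)  # their sum
sol = CombLogic(
    shape=(2, 1),
    inp_shifts=[0, 0],
    out_idxs=[2],
    out_shifts=[0],
    out_negs=[False],
    ops=[
        Op(0, -1, -1, 0, qi_in, 0.0, 0.0),
        Op(1, -1, -1, 0, qi_in, 0.0, 0.0),
        Op(0, 1, 0, 0, qi_out, 1.0, 3.0),
    ],
    carry_size=8,
    adder_size=8,
)
print(sol([1, 2]))  # -> [3.0]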
+     @property
+     def kernel(self):
+         """The kernel represented by the solution, when applicable."""
+         kernel = np.empty(self.shape, dtype=np.float32)
+         for i, one_hot in enumerate(np.identity(self.shape[0])):
+             kernel[i] = self(one_hot)
+         return kernel
+
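# Editor's sketch (not part of the package): for a purely linear CombLogic,
# `kernel` recovers the matrix it implements, so x @ sol.kernel == sol(x).
# Reusing `sol` from the adder sketch above:
import numpy as np

assert np.allclose(sol.kernel, [[1.0], [1.0]])
x = np.array([2.0, 3.0])
assert np.allclose(x @ sol.kernel, sol(x))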
+     @property
+     def cost(self):
+         """Total cost of the solution."""
+         return float(sum(op.cost for op in self.ops))
+
+     @property
+     def latency(self):
+         """Minimum and maximum latency of the solution."""
+         latency = [self.ops[i].latency for i in self.out_idxs]
+         if len(latency) == 0:
+             return 0.0, 0.0
+         return min(latency), max(latency)
+
+     def __repr__(self):
+         n_in, n_out = self.shape
+         cost = self.cost
+         lat_min, lat_max = self.latency
+         return f'Solution([{n_in} -> {n_out}], cost={cost}, latency={lat_min}-{lat_max})'
+
+     @property
+     def out_latency(self):
+         """Latencies of all output elements of the solution."""
+         return [self.ops[i].latency if i >= 0 else 0.0 for i in self.out_idxs]
+
+     @property
+     def out_qint(self):
+         """Quantization intervals of the output elements."""
+         buf = []
+         for i, idx in enumerate(self.out_idxs):
+             _min, _max, _step = self.ops[idx].qint
+             sf = 2.0 ** self.out_shifts[i]
+             _min, _max, _step = _min * sf, _max * sf, _step * sf
+             if self.out_negs[i]:
+                 _min, _max = -_max, -_min
+             buf.append(QInterval(_min, _max, _step))
+         return buf
+
+     @property
+     def out_kifs(self):
+         """KIFs of all output elements of the solution."""
+         return np.array([_minimal_kif(qi) for qi in self.out_qint]).T
+
+     @property
+     def inp_latency(self):
+         """Latencies of all input elements of the solution."""
+         return [op.latency for op in self.ops if op.opcode == -1]
+
+     @property
+     def inp_qint(self):
+         """Quantization intervals of the input elements."""
+         qints = [QInterval(0.0, 0.0, 1.0) for _ in range(self.shape[0])]
+         for op in self.ops:
+             if op.opcode != -1:
+                 continue
+             qints[op.id0] = op.qint
+         return qints
+
+     @property
+     def inp_kifs(self):
+         """KIFs of all input elements of the solution."""
+         return np.array([_minimal_kif(qi) for qi in self.inp_qint]).T
+
479
+
480
+ def save(self, path: str | Path):
481
+ """Save the solution to a file."""
482
+ with open(path, 'w') as f:
483
+ json.dump(self, f, cls=JSONEncoder)
484
+
485
+ @classmethod
486
+ def deserialize(cls, data: list):
487
+ """Load the solution from a file."""
488
+ ops = []
489
+ for _op in data[5]:
490
+ op = Op(*_op[:4], QInterval(*_op[4]), *_op[5:]) # type: ignore
491
+ ops.append(op)
492
+ assert len(data) in (8, 9), f'{len(data)}'
493
+ lookup_tables = data[8] if len(data) > 8 else None
494
+ if lookup_tables is not None:
495
+ from ..trace.fixed_variable import LookupTable
496
+
497
+ lookup_tables = tuple(LookupTable.from_dict(tab) for tab in lookup_tables)
498
+ return cls(
499
+ shape=tuple(data[0]),
500
+ inp_shifts=data[1],
501
+ out_idxs=data[2],
502
+ out_shifts=data[3],
503
+ out_negs=data[4],
504
+ ops=ops,
505
+ carry_size=data[6],
506
+ adder_size=data[7],
507
+ lookup_tables=lookup_tables,
508
+ )
509
+
510
+ @classmethod
511
+ def load(cls, path: str | Path):
512
+ """Load the solution from a file."""
513
+ with open(path) as f:
514
+ data = json.load(f)
515
+ return cls.deserialize(data)
516
+
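# Editor's sketch (not part of the package): the JSON save/load round trip,
# reusing `sol` from the adder sketch above. NamedTuples serialize as JSON
# arrays, which is why deserialize() indexes into a list.
import numpy as np

from da4ml.cmvm.types import CombLogic

sol.save('adder.json')
restored = CombLogic.load('adder.json')
assert np.allclose(restored([1.0, 2.0]), [3.0])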
+     @property
+     def ref_count(self) -> np.ndarray:
+         """The number of references to the output elements in the solution."""
+         ref_count = np.zeros(len(self.ops), dtype=np.uint64)
+         for op in self.ops:
+             if op.opcode == -1:
+                 continue
+             id0, id1 = op.id0, op.id1
+             if id0 != -1:
+                 ref_count[id0] += 1
+             if id1 != -1:
+                 ref_count[id1] += 1
+             if op.opcode in (6, -6):
+                 # msb_mux operation
+                 ref_count[op.data & 0xFFFFFFFF] += 1
+         for i in self.out_idxs:
+             if i < 0:
+                 continue
+             ref_count[i] += 1
+         return ref_count
+
+     def to_binary(self, version: int = 0) -> NDArray[np.int32]:
+         n_in, n_out = self.shape
+         header_size_i32 = 6 + n_in + n_out * 3
+         n_tables = len(self.lookup_tables) if self.lookup_tables is not None else 0
+
+         header = np.concatenate(
+             [
+                 [0, version, n_in, n_out, len(self.ops), n_tables],
+                 self.inp_shifts,
+                 self.out_idxs,
+                 self.out_shifts,
+                 self.out_negs,
+             ],
+             axis=0,
+             dtype=np.int32,
+         )
+         assert len(header) == header_size_i32, f'Header size mismatch: {len(header)} != {header_size_i32}'
+         code = np.empty((len(self.ops), 8), dtype=np.int32)
+         for i, op in enumerate(self.ops):
+             buf = code[i]
+             buf[0] = op.opcode
+             buf[1] = op.id0
+             buf[2] = op.id1
+             buf[5:] = _minimal_kif(op.qint)
+             buf_i64 = buf[3:5].view(np.int64)
+             if op.opcode != 8:
+                 buf_i64[0] = op.data
+             else:
+                 assert self.lookup_tables is not None
+                 pad_left = self.lookup_tables[op.data]._get_pads(self.ops[op.id0].qint)[0]
+                 buf_i64[0] = (pad_left << 32) | op.data
+         data = np.concatenate([header, code.flatten()])
+
+         if self.lookup_tables is None:
+             return data
+
+         tables = [table.table for table in self.lookup_tables]
+         table_sizes = [len(tab) for tab in tables]
+         table_data = np.concatenate([table_sizes] + tables, axis=0, dtype=np.int32)
+         data = np.concatenate([data, table_data])
+         return data
+
+     def save_binary(self, path: str | Path, version: int = 0):
+         """Dump the solution to a binary file."""
+         data = self.to_binary(version=version)
+         with open(path, 'wb') as f:
+             data.tofile(f)
+
+     def predict(
+         self,
+         data: NDArray | Sequence[NDArray],
+         n_threads: int = -1,
+     ) -> NDArray[np.float64]:
+         """Predict the output of the solution for a batch of input data with the C++-backed DAIS interpreter.
+         Cannot be used if the binary interpreter is not installed.
+
+         Parameters
+         ----------
+         data : NDArray | Sequence[NDArray]
+             Input data to the model. The shape is ignored, and the number of samples is
+             determined by the size of the data.
+         n_threads : int
+             Number of threads to use for prediction.
+             Negative or zero values will use the maximum available threads. Default is -1.
+             If OpenMP is not supported, this parameter is ignored.
+
+         Returns
+         -------
+         NDArray[np.float64]
+             Output of the model in shape (n_samples, output_size).
+         """
+
+         if isinstance(data, Sequence):
+             data = np.concatenate([a.reshape(a.shape[0], -1) for a in data], axis=-1)
+
+         if n_threads == 0:
+             n_threads = -1
+
+         bin_logic = self.to_binary()
+         return dais_interp_run(bin_logic, data, n_threads)
+
+
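# Editor's sketch (not part of the package): batch inference through the
# compiled DAIS interpreter, reusing `sol` from the adder sketch above.
# This assumes the bundled dais_bin extension loads on this platform.
import numpy as np

batch = np.random.randint(0, 4, size=(1000, 2)).astype(np.float64)
out = sol.predict(batch, n_threads=2)  # shape (1000, 1)
assert np.allclose(out[:, 0], batch.sum(axis=1))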
+ class Pipeline(NamedTuple):
+     """A pipeline with II=1, with each stage represented by a CombLogic.
+
+     Attributes
+     ----------
+     solutions: tuple[CombLogic, ...]
+         A tuple containing the individual CombLogic objects for each stage of the cascade.
+
+     Properties
+     ----------
+     kernel: NDArray[float32]
+         Only useful when the pipeline describes a linear operation.
+         The overall kernel matrix which the cascaded solution implements: vec @ kernel = solution(vec).
+         This is calculated as the matrix product of all individual solution kernels.
+     cost: float
+         The total cost of the cascaded solution, computed as the sum of the costs of all stages.
+     latency: tuple[float, float]
+         The minimum and maximum latency of the pipeline, determined by the last stage.
+     inp_qint: list[QInterval]
+         Input quantization intervals.
+     inp_latency: list[float]
+         Input latencies.
+     inp_shifts: list[int]
+         Input shifts.
+     out_qint: list[QInterval]
+         Output quantization intervals.
+     out_latencies: list[float]
+         Output latencies.
+     out_shift: list[int]
+         Output shifts.
+     out_neg: list[bool]
+         Output signs.
+     shape: tuple[int, int]
+         The shape of the corresponding kernel matrix.
+     """
+
+     solutions: tuple[CombLogic, ...]
+
+     def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False):
+         out = np.asarray(inp)
+         for sol in self.solutions:
+             out = sol(out, quantize=quantize, debug=debug)
+         return out
+
+     @property
+     def kernel(self):
+         return reduce(lambda x, y: x @ y, [sol.kernel for sol in self.solutions])
+
+     @property
+     def cost(self):
+         return sum(sol.cost for sol in self.solutions)
+
+     @property
+     def latency(self):
+         return self.solutions[-1].latency
+
+     @property
+     def inp_qint(self):
+         return self.solutions[0].inp_qint
+
+     @property
+     def inp_latency(self):
+         return self.solutions[0].inp_latency
+
+     @property
+     def out_qint(self):
+         return self.solutions[-1].out_qint
+
+     @property
+     def out_latencies(self):
+         return self.solutions[-1].out_latency
+
+     @property
+     def shape(self):
+         return self.solutions[0].shape[0], self.solutions[-1].shape[1]
+
+     @property
+     def inp_shifts(self):
+         return self.solutions[0].inp_shifts
+
+     @property
+     def out_shift(self):
+         return self.solutions[-1].out_shifts
+
+     @property
+     def out_neg(self):
+         return self.solutions[-1].out_negs
+
+     def __repr__(self) -> str:
+         n_ins = [sol.shape[0] for sol in self.solutions] + [self.shape[1]]
+         shape_str = ' -> '.join(map(str, n_ins))
+         _cost = self.cost
+         lat_min, lat_max = self.latency
+         return f'CascadedSolution([{shape_str}], cost={_cost}, latency={lat_min}-{lat_max})'
+
+     def save(self, path: str | Path):
+         """Save the solution to a file."""
+         with open(path, 'w') as f:
+             json.dump(self, f, cls=JSONEncoder)
+
+     @classmethod
+     def deserialize(cls, data: list):
+         """Reconstruct the pipeline from deserialized JSON data."""
+         return cls(solutions=tuple(CombLogic.deserialize(sol) for sol in data[0]))
+
+     @classmethod
+     def load(cls, path: str | Path):
+         """Load the solution from a file."""
+         with open(path) as f:
+             data = json.load(f)
+         return cls.deserialize(data)
+
+     @property
+     def reg_bits(self):
+         """The number of bits used for the registers in the solution."""
+         bits = sum(map(sum, (_minimal_kif(qint) for qint in self.inp_qint)))
+         for _sol in self.solutions:
+             kifs = [_minimal_kif(qint) for qint in _sol.out_qint]
+             _bits = sum(map(sum, kifs))
+             bits += _bits
+         return bits
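# Editor's sketch (not part of the package): a Pipeline chains CombLogic
# stages with II=1; a single-stage pipeline behaves like its stage.
# Reusing `sol` from the adder sketch above:
import numpy as np

from da4ml.cmvm.types import Pipeline

pipe = Pipeline(solutions=(sol,))
assert np.allclose(pipe([1.0, 2.0]), sol([1.0, 2.0]))
print(pipe)  # CascadedSolution([2 -> 1], cost=3.0, latency=1.0-1.0)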