da4ml 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of da4ml might be problematic.

Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +239 -73
  5. da4ml/cmvm/core/__init__.py +222 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +49 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +75 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +265 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +177 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +120 -0
  35. da4ml-0.2.0.dist-info/METADATA +65 -0
  36. da4ml-0.2.0.dist-info/RECORD +39 -0
  37. {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.1.dist-info/METADATA +0 -121
  48. da4ml-0.1.1.dist-info/RECORD +0 -18
  49. {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info/licenses}/LICENSE +0 -0
  50. {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/cmvm/types.py ADDED
@@ -0,0 +1,569 @@
+ import json
+ from decimal import Decimal
+ from functools import reduce, singledispatch
+ from math import ceil, floor, log2
+ from pathlib import Path
+ from typing import TYPE_CHECKING, NamedTuple, TypeVar
+
+ import numpy as np
+ from numba import jit
+ from numpy import float32, int8
+ from numpy.typing import NDArray
+
+ if TYPE_CHECKING:
+     from ..trace.tracer import FixedVariable
+
+
+ class QInterval(NamedTuple):
+     """A class representing a quantized interval: [min, max] with a step size."""
+
+     min: float
+     max: float
+     step: float
+
+     @classmethod
+     def from_kif(cls, k: int | bool, i: int, f: int):
+         _high = 2.0**i
+         step = 2.0**-f
+         low, high = -k * _high, _high - step
+         return cls(low, high, step)
+
+     @classmethod
+     def from_precision(cls, prec: 'Precision'):
+         return cls.from_kif(*prec)
+
+     @property
+     def precision(self):
+         return Precision.from_qint(self)
+
+     def __repr__(self):
+         return f'[{self.min}, {self.max}, {self.step}]'
+
+
+ class Precision(NamedTuple):
+     """A class representing the precision of a quantized interval."""
+
+     keep_negative: bool
+     integers: int
+     fractional: int
+
+     def __str__(self):
+         k, i, f = self.keep_negative, self.integers, self.fractional
+         k, B, I = k, i + f + k, i + k
+         return f'fixed({k}, {B}, {I})'
+
+     def __repr__(self):
+         return str(self)
+
+     @classmethod
+     def from_qint(cls, qint: QInterval, symmetric: bool = False):
+         return _minimal_kif(qint, symmetric=symmetric)
+
+     @property
+     def qint(self):
+         return QInterval.from_kif(*self)
+
+
+ class Op(NamedTuple):
+     """One single operation on the data buffer.
+
+     Parameters
+     ----------
+     id0: int
+         index of the first operand
+     id1: int
+         index of the second operand, or special opcode if negative
+     opcode: int
+         0: addition, 1: subtraction, 2: relu, 3: quantize, 4: const addition;
+         -1: input copy, -2/-3: relu/quantize of the negated operand,
+         5: const definition (see Solution.__call__)
+     data: int
+         Data to be used in the operation
+     qint: QInterval
+         Quantization interval of the resultant buffer
+     latency: float
+         Latency of the data generated by this operation (t_available)
+     cost: float
+         Cost of the operation
+     """
+
+     id0: int
+     id1: int
+     opcode: int
+     data: int
+     qint: QInterval
+     latency: float
+     cost: float
+
+
+ class Pair(NamedTuple):
+     """An operation representing data[id0] +/- data[id1] * 2**shift."""
+
+     id0: int
+     id1: int
+     sub: bool
+     shift: int
+
+
+ class DAState(NamedTuple):
+     """Internal state of the DA algorithm."""
+
+     shifts: tuple[NDArray[int8], NDArray[int8]]
+     expr: list[NDArray[int8]]
+     ops: list[Op]
+     freq_stat: dict[Pair, int]
+     kernel: NDArray[float32]
+
+
+ def _minimal_kif(qi: QInterval, symmetric: bool = False) -> Precision:
+     """Calculate the minimal KIF for a given QInterval.
+
+     Parameters
+     ----------
+     qi : QInterval
+         The QInterval to calculate the KIF for.
+     symmetric : bool
+         Only relevant if qi may be negative. If True, -2**i will be regarded as forbidden.
+         May be useful in special cases only.
+         Default is False.
+
+     Returns
+     -------
+     Precision
+         A named tuple with the KIF values.
+     """
+
+     if qi.min == qi.max == 0:
+         return Precision(keep_negative=False, integers=0, fractional=0)
+     keep_negative = qi.min < 0
+     fractional = int(-log2(qi.step))
+     int_min, int_max = round(qi.min / qi.step), round(qi.max / qi.step)
+     if symmetric:
+         bits = int(ceil(log2(max(abs(int_min), int_max) + 1)))
+     else:
+         bits = int(ceil(log2(max(abs(int_min), int_max + 1))))
+     integers = bits - fractional
+     return Precision(keep_negative=keep_negative, integers=integers, fractional=fractional)
+
+
+ if TYPE_CHECKING:
+
+     def minimal_kif(qi: QInterval, symmetric: bool = False) -> Precision: ...
+ else:
+     minimal_kif = jit(_minimal_kif)
+
+
+ T = TypeVar('T', 'FixedVariable', float, int, np.float32, np.float64, Decimal)
+
+
+ @singledispatch
+ def _relu(v: 'T', i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN') -> 'T':
+     from ..trace.fixed_variable import FixedVariable
+
+     assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic relu'
+     return v.relu(i, f, round_mode=round_mode)
+
+
+ @_relu.register(float)
+ @_relu.register(int)
+ @_relu.register(np.float32)
+ @_relu.register(np.float64)
+ def _(v, i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN'):
+     if inv:
+         v = -v
+     v = max(0, v)
+     if f is not None:
+         if round_mode.upper() == 'RND':
+             v += 2.0 ** (-f - 1)
+         sf = 2.0**f
+         v = floor(v * sf) / sf
+     if i is not None:
+         v = v % 2.0**i
+     return v
+
+
+ @_relu.register
+ def _(v: Decimal, i: int | None = None, f: int | None = None, inv: bool = False, round_mode: str = 'TRN'):
+     if inv:
+         v = -v
+     v = max(Decimal(0), v)
+     if f is not None:
+         if round_mode.upper() == 'RND':
+             v += Decimal(2) ** (-f - 1)
+         sf = Decimal(2) ** f
+         v = floor(v * sf) / sf
+     if i is not None:
+         v = v % Decimal(2) ** i
+     return v
+
+
+ @singledispatch
+ def _quantize(v: 'T', k: int | bool, i: int, f: int, round_mode: str = 'TRN') -> 'T':
+     from ..trace.fixed_variable import FixedVariable
+
+     assert isinstance(v, FixedVariable), f'Unknown type {type(v)} for symbolic quantization'
+     return v.quantize(k, i, f, round_mode=round_mode)
+
+
+ @_quantize.register(float)
+ @_quantize.register(int)
+ @_quantize.register(np.float32)
+ @_quantize.register(np.float64)
+ def _(v, k: int | bool, i: int, f: int, round_mode: str = 'TRN'):
+     if round_mode.upper() == 'RND':
+         v += 2.0 ** (-f - 1)
+     b = k + i + f
+     bias = 2.0 ** (b - 1) * k
+     eps = 2.0**-f
+     return eps * ((np.floor(v / eps) + bias) % 2**b - bias)
+
+
+ @_quantize.register
+ def _(v: Decimal, k: int | bool, i: int, f: int, round_mode: str = 'TRN'):
+     if round_mode.upper() == 'RND':
+         v += Decimal(2) ** (-f - 1)
+     b = k + i + f
+     bias = Decimal(2) ** (b - 1) * k
+     eps = Decimal(2) ** -f
+     return eps * ((floor(v / eps) + bias) % Decimal(2) ** b - bias)
+
+
+ class Solution(NamedTuple):
+     """Represents a series of operations that can be applied to a vector of data.
+     May represent a CMVM solution or a general neural network.
+
+     Attributes
+     ----------
+     shape: tuple[int, int]
+         #input, #output
+     inp_shift: list[int]
+         The shifts that should be applied to the input data.
+     out_idxs: list[int]
+         The indices of the output data in the buffer.
+     out_shifts: list[int]
+         The shifts that should be applied to the output data.
+     out_negs: list[bool]
+         The signs of the output data.
+     ops: list[Op]
+         Core list of operations for generating each buffer element.
+     carry_size: int
+         Size of the carrier for the adder.
+     adder_size: int
+         Elementary size of the adder.
+
+
+     The core part of the solution is the operations in the ops list.
+     For the exact operations executed with Op, refer to the Op class.
+     After all operations are executed, output j is read from the buffer at out_idxs[j],
+     multiplied by 2**out_shifts[j], and negated if out_negs[j].
+
+     """
+
+     shape: tuple[int, int]
+     inp_shift: list[int]
+     out_idxs: list[int]
+     out_shifts: list[int]
+     out_negs: list[bool]
+     ops: list[Op]
+     carry_size: int
+     adder_size: int
+
+     def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False, dump=False):
+         """Executes the solution on the input data.
+
+         Parameters
+         ----------
+         inp : list | np.ndarray | tuple
+             Input data to be processed. The input data should be a list or numpy array of objects.
+         quantize : bool
+             If True, the input data will be quantized to the input quantization intervals.
+             Only floating point data types are supported when quantize is True.
+             Default is False.
+         debug : bool
+             If True, the function will print debug information about the operations being performed.
+             Default is False.
+         dump : bool
+             If True, return the whole buffer without applying the output shifts and signs.
+             Default is False.
+
+         Returns
+         -------
+         np.ndarray
+             The output data after applying the operations defined in the solution.
+
+         """
+         buf = np.empty(len(self.ops), dtype=object)
+         inp = np.asarray(inp)
+
+         inp_qint = [op.qint for op in self.ops if op.opcode == -1]
+         if quantize:  # TRN and WRAP
+             k, i, f = map(np.array, zip(*map(minimal_kif, inp_qint)))
+             eps = 2.0**-f
+             _low, _high = -(2.0 ** (i + f)) * k, 2.0 ** (i + f) - 1
+             inp = eps * ((np.floor(inp / eps) - _low) % 2.0 ** (k + i + f) + _low)
+
+         inp = inp * (2.0 ** np.array(self.inp_shift))
+         for i, op in enumerate(self.ops):
+             match op.opcode:
+                 case -1:  # copy from external buffer
+                     buf[i] = inp[op.id0]
+                 case 0 | 1:  # addition / subtraction
+                     v0, v1 = buf[op.id0], 2.0**op.data * buf[op.id1]
+                     buf[i] = v0 + v1 if op.opcode == 0 else v0 - v1
+                 case 2 | -2:  # relu(+/-x)
+                     v = buf[op.id0]
+                     _, _i, _f = _minimal_kif(op.qint)
+                     buf[i] = _relu(v, _i, _f, inv=op.opcode == -2, round_mode='TRN')
+                 case 3 | -3:  # quantize(+/-x)
+                     v = buf[op.id0] if op.opcode == 3 else -buf[op.id0]
+                     _k, _i, _f = _minimal_kif(op.qint)
+                     buf[i] = _quantize(v, _k, _i, _f, round_mode='TRN')
+                 case 4:  # const addition
+                     bias = op.data * op.qint.step
+                     buf[i] = buf[op.id0] + bias
+                 case 5:  # const definition
+                     buf[i] = op.data * op.qint.step
+                 case _:
+                     raise ValueError(f'Unknown opcode {op.opcode} in {op}')
+
+         sf = 2.0 ** np.array(self.out_shifts)
+         sign = np.where(self.out_negs, -1, 1)
+         out_idx = np.array(self.out_idxs)
+         mask = np.where(out_idx < 0, 0, 1)
+         if debug:
+             for i, v in enumerate(buf):
+                 op = self.ops[i]
+                 match op.opcode:
+                     case -1:
+                         op_str = 'inp'
+                     case 0:
+                         op_str = f'buf[{op.id0}] + buf[{op.id1}]<<{op.data}'
+                     case 1:
+                         op_str = f'buf[{op.id0}] - buf[{op.id1}]<<{op.data}'
+                     case 2:
+                         op_str = f'relu(buf[{op.id0}])'
+                     case -2:
+                         op_str = f'relu(-buf[{op.id0}])'
+                     case 3:
+                         op_str = f'quantize(buf[{op.id0}])'
+                     case -3:
+                         op_str = f'quantize(-buf[{op.id0}])'
+                     case 4:
+                         op_str = f'buf[{op.id0}] + {op.data * op.qint.step}'
+                     case 5:
+                         op_str = f'const {op.data * op.qint.step}'
+                     case _:
+                         raise ValueError(f'Unknown opcode {op.opcode} in {op}')
+
+                 print(f'{op_str:24} |-> buf[{i}] = {v}')
+
+         if dump:
+             return buf
+         return buf[out_idx] * sf * sign * mask
+
+
+     @property
+     def kernel(self):
+         """The kernel represented by the solution, when applicable."""
+         kernel = np.empty(self.shape, dtype=np.float32)
+         for i, one_hot in enumerate(np.identity(self.shape[0])):
+             kernel[i] = self(one_hot)
+         return kernel
+
+     @property
+     def cost(self):
+         """Total cost of the solution."""
+         return float(sum(op.cost for op in self.ops))
+
+     @property
+     def latency(self):
+         """Minimum and maximum latency of the solution."""
+         latency = [self.ops[i].latency for i in self.out_idxs]
+         if len(latency) == 0:
+             return 0.0, 0.0
+         return min(latency), max(latency)
+
+     def __repr__(self):
+         n_in, n_out = self.shape
+         cost = self.cost
+         lat_min, lat_max = self.latency
+         return f'Solution([{n_in} -> {n_out}], cost={cost}, latency={lat_min}-{lat_max})'
+
+     @property
+     def out_latency(self):
+         """Latencies of all output elements of the solution."""
+         return [self.ops[i].latency if i >= 0 else 0.0 for i in self.out_idxs]
+
+     @property
+     def out_qint(self):
+         """Quantization intervals of the output elements."""
+         buf = []
+         for i, idx in enumerate(self.out_idxs):
+             _min, _max, _step = self.ops[idx].qint
+             sf = 2.0 ** self.out_shifts[i]
+             _min, _max, _step = _min * sf, _max * sf, _step * sf
+             if self.out_negs[i]:
+                 _min, _max = -_max, -_min
+             buf.append(QInterval(_min, _max, _step))
+         return buf
+
+     @property
+     def inp_latency(self):
+         """Latencies of all input elements of the solution."""
+         return [op.latency for op in self.ops if op.opcode == -1]
+
+     @property
+     def inp_qint(self):
+         """Quantization intervals of the input elements."""
+         return [op.qint for op in self.ops if op.opcode == -1]
+
+     def save(self, path: str | Path):
+         """Save the solution to a file."""
+         with open(path, 'w') as f:
+             json.dump(self, f)
+
+     @classmethod
+     def deserialize(cls, data: dict):
+         """Reconstruct the solution from decoded JSON data."""
+         ops = []
+         for _op in data[5]:
+             op = Op(*_op[:4], QInterval(*_op[4]), *_op[5:])  # type: ignore
+             ops.append(op)
+         return cls(
+             shape=tuple(data[0]),
+             inp_shift=data[1],
+             out_idxs=data[2],
+             out_shifts=data[3],
+             out_negs=data[4],
+             ops=ops,
+             carry_size=data[6],
+             adder_size=data[7],
+         )
+
+     @classmethod
+     def load(cls, path: str | Path):
+         """Load the solution from a file."""
+         with open(path) as f:
+             data = json.load(f)
+         return cls.deserialize(data)
+
+
+ class CascadedSolution(NamedTuple):
+     """A solution that implements cascaded matrix-vector multiplications through multiple CMVM stages.
+
+     CascadedSolution represents a sequence of Solution objects where the output of each stage
+     is fed as input to the next stage.
+
+     Attributes
+     ----------
+     solutions: tuple[Solution, ...]
+         A tuple containing the individual Solution objects for each stage of the cascade.
+
+     Properties
+     ----------
+     kernel: NDArray[float32]
+         The overall kernel matrix which the cascaded solution implements: vec @ kernel = solution(vec).
+         This is calculated as the matrix product of all individual solution kernels.
+     cost: float
+         The total cost of the cascaded solution, computed as the sum of the costs of all stages.
+     latency: tuple[float, float]
+         The minimum and maximum latency of the cascaded solution.
+     inp_qint: list[QInterval]
+         Input quantization intervals
+     inp_latency: list[float]
+         Input latencies
+     inp_shift: list[int]
+         Input shifts
+     out_qint: list[QInterval]
+         Output quantization intervals
+     out_latencies: list[float]
+         Output latencies
+     out_shift: list[int]
+         Output shifts
+     out_neg: list[bool]
+         Output signs
+     shape: tuple[int, int]
+         The shape of the corresponding kernel matrix.
+     """
+
+     solutions: tuple[Solution, ...]
+
+     def __call__(self, inp: list | np.ndarray | tuple, quantize=False, debug=False):
+         out = np.asarray(inp)
+         for sol in self.solutions:
+             out = sol(out, quantize=quantize, debug=debug)
+         return out
+
+     @property
+     def kernel(self):
+         return reduce(lambda x, y: x @ y, [sol.kernel for sol in self.solutions])
+
+     @property
+     def cost(self):
+         return sum(sol.cost for sol in self.solutions)
+
+     @property
+     def latency(self):
+         return self.solutions[-1].latency
+
+     @property
+     def inp_qint(self):
+         return self.solutions[0].inp_qint
+
+     @property
+     def inp_latency(self):
+         return self.solutions[0].inp_latency
+
+     @property
+     def out_qint(self):
+         return self.solutions[-1].out_qint
+
+     @property
+     def out_latencies(self):
+         return self.solutions[-1].out_latency
+
+     @property
+     def shape(self):
+         return self.solutions[0].shape[0], self.solutions[-1].shape[1]
+
+     @property
+     def inp_shift(self):
+         return self.solutions[0].inp_shift
+
+     @property
+     def out_shift(self):
+         return self.solutions[-1].out_shifts
+
+     @property
+     def out_neg(self):
+         return self.solutions[-1].out_negs
+
+     def __repr__(self) -> str:
+         n_ins = [sol.shape[0] for sol in self.solutions] + [self.shape[1]]
+         shape_str = ' -> '.join(map(str, n_ins))
+         _cost = self.cost
+         lat_min, lat_max = self.latency
+         return f'CascadedSolution([{shape_str}], cost={_cost}, latency={lat_min}-{lat_max})'
+
+     def save(self, path: str | Path):
+         """Save the solution to a file."""
+         with open(path, 'w') as f:
+             json.dump(self, f)
+
+     @classmethod
+     def deserialize(cls, data: dict):
+         """Reconstruct the solution from decoded JSON data."""
+         return cls(solutions=tuple(Solution.deserialize(sol) for sol in data[0]))
+
+     @classmethod
+     def load(cls, path: str):
+         """Load the solution from a file."""
+         with open(path) as f:
+             data = json.load(f)
+         return cls.deserialize(data)
+
+     @property
+     def reg_bits(self):
+         """The number of bits used for the registers in the solution."""
+         bits = 0
+         for _sol in self.solutions:
+             kifs = [_minimal_kif(qint) for qint in _sol.out_qint]
+             _bits = sum(map(sum, kifs))
+             bits += _bits
+         return bits
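
The interval and precision helpers above round-trip through each other. A minimal usage sketch (illustrative only, not part of the diff; it assumes the from_kif bounds [-k * 2**i, 2**i - 2**-f] shown above):

    from da4ml.cmvm.types import Precision, QInterval

    # keep_negative + 3 integer bits + 2 fractional bits
    prec = Precision(keep_negative=True, integers=3, fractional=2)
    qint = prec.qint
    assert qint == QInterval(-8.0, 7.75, 0.25)  # [-2**i, 2**i - 2**-f], step 2**-f

    # the round trip recovers the minimal precision covering the interval
    assert tuple(qint.precision) == (True, 3, 2)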
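Likewise, the opcode table in Op and the match statement in Solution.__call__ together define execution. A hand-built sketch (illustrative only; carry_size and adder_size are placeholders since __call__ does not consult them, and id1 is unused for input copies) that computes y = x0 + 2 * x1:

    from da4ml.cmvm.types import Op, QInterval, Solution

    qi = QInterval(-8.0, 7.75, 0.25)
    ops = [
        Op(0, -1, -1, 0, qi, 0.0, 0.0),  # buf[0] = inp[0]
        Op(1, -1, -1, 0, qi, 0.0, 0.0),  # buf[1] = inp[1]
        Op(0, 1, 0, 1, QInterval(-24.0, 23.25, 0.25), 1.0, 1.0),  # buf[2] = buf[0] + buf[1]<<1
    ]
    sol = Solution((2, 1), [0, 0], [2], [0], [False], ops, 1, 1)
    assert list(sol([1.0, 2.5])) == [6.0]          # 1.0 + 2 * 2.5
    assert sol.kernel.tolist() == [[1.0], [2.0]]   # probed column by column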
da4ml/cmvm/util/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .bit_decompose import csd_decompose
+ from .mat_decompose import kernel_decompose
+
+ __all__ = [
+     'csd_decompose',
+     'kernel_decompose',
+ ]
da4ml/cmvm/util/bit_decompose.py ADDED
@@ -0,0 +1,86 @@
+ import numpy as np
+ from numba import jit
+ from numpy.typing import NDArray
+
+
+ @jit
+ def _volatile_int_arr_to_csd(x: NDArray) -> NDArray[np.int8]:
+     x = x
+     N = np.max(np.ceil(np.log2(np.abs(x) * 1.5 + 1e-19)))
+     N = int(max(N, 1))
+     buf = np.zeros((*np.shape(x), N), dtype=np.int8)
+
+     for n in range(N - 1, -1, -1):
+         _2pn = 2**n
+         thres = _2pn / 1.5
+         bit = (x > thres).astype(np.int8)
+         bit -= (x < -thres).astype(np.int8)
+         x -= _2pn * bit
+         buf[..., n] = bit
+     return buf
+
+
+ @jit(error_model='numpy')
+ def _shift_centering(arr: NDArray):
+     low, high = -64, 64
+     if np.all(arr == 0):
+         high = low = 0
+     while high - low > 1:
+         mid = (high + low) // 2
+         xs = arr * (2.0**mid)
+         if np.all(xs == np.floor(xs)):
+             high = mid
+         else:
+             low = mid
+     return -high
+
+
+ @jit(error_model='numpy')
+ def shift_centering(arr: NDArray, axis: int):
+     n = arr.shape[axis]
+     shifts = np.empty(n, dtype=np.int8)
+     for i in range(n):
+         shifts[i] = _shift_centering(arr.take(i, axis=axis))
+     return shifts
+
+
+ @jit
+ def _center(arr: NDArray):
+     shift1 = shift_centering(arr, 1)  # d_out
+     arr = arr * (2.0**-shift1)
+     shift0 = shift_centering(arr, 0)  # d_in
+     arr = arr * (2.0 ** -shift0[:, None])
+     return arr, shift0, shift1
+
+
+ @jit
+ def csd_decompose(arr: NDArray, center=True):
+     """
+     Convert a 2D array to CSD representation.
+
+     Parameters
+     ----------
+     arr : ndarray
+         Input array to be converted.
+     center : bool, optional
+         If True, the array is centered before conversion. Default is True.
+         If False, the function may accept non-2D arrays.
+
+     Returns
+     -------
+     csd : ndarray
+         CSD representation of the input array after centering, if center is True.
+     shift0 : ndarray
+         Shift values for the first axis.
+     shift1 : ndarray
+         Shift values for the second axis.
+     """
+
+     if center:
+         arr, shift0, shift1 = _center(arr)
+     else:
+         shift0 = np.zeros(arr.shape[0], dtype=np.int8)
+         shift1 = np.zeros(arr.shape[1], dtype=np.int8)
+         arr = arr.copy()
+     csd = _volatile_int_arr_to_csd(arr)
+     return csd, shift0, shift1
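
To close the loop on csd_decompose: the returned digits in {-1, 0, 1} weight powers of two along the last axis, and the centering shifts restore the original scale per row and per column. An illustrative check (a sketch, not part of the diff):

    import numpy as np
    from da4ml.cmvm.util import csd_decompose

    kernel = np.array([[1.5, -6.0], [0.75, 2.0]])
    csd, shift0, shift1 = csd_decompose(kernel)

    weights = 2.0 ** np.arange(csd.shape[-1])
    recon = (csd * weights).sum(axis=-1) * 2.0 ** (shift0[:, None] + shift1[None, :])
    assert np.allclose(recon, kernel)

    # canonical signed-digit property: no two adjacent non-zero digits
    assert not np.any((csd[..., 1:] != 0) & (csd[..., :-1] != 0))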