da4ml 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic. Click here for more details.

Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +239 -73
  5. da4ml/cmvm/core/__init__.py +222 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +49 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +75 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +265 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +177 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +120 -0
  35. da4ml-0.2.0.dist-info/METADATA +65 -0
  36. da4ml-0.2.0.dist-info/RECORD +39 -0
  37. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.2.dist-info/METADATA +0 -122
  48. da4ml-0.1.2.dist-info/RECORD +0 -18
  49. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
  50. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/cmvm/cmvm.py DELETED
@@ -1,328 +0,0 @@
1
- import heapq
2
- from collections.abc import Sequence
3
- from math import ceil
4
-
5
- import numpy as np
6
- from numba import njit, prange
7
- from numpy.typing import NDArray
8
-
9
- from .csd import to_csd
10
- from .nb_fixed_precision import NBFixedPrecision
11
- from .scoring import scorer
12
- from .utils import DAState, OpCode, Score
13
-
14
-
@njit
def extract_pairs(
    csd: list[NDArray[np.int8]], precisions: list[NBFixedPrecision], updated: list[int] | None = None, dc: int | None = None
):
    """Count co-occurring non-zero digit pairs and score each repeated pattern.

    Args:
        csd: One 2-D ``(d_out, n_bit)`` digit array per input variable.
        precisions: Per-variable precision objects; ``.b`` and ``._depth`` are
            read here and the objects are passed to ``scorer``.
        updated: If given, only pairs touching these variable indices are
            (re)counted. If ``None``, every variable that has a non-zero digit
            in a given output row and ``b != 0`` is considered.
        dc: If given, variables whose ``_depth`` exceeds ``min(depth) + dc``
            are excluded, provided at least two variables remain.

    Returns:
        A list of ``(-score * (count - 1), count, pos0, pos1, dshift, dsign)``
        tuples, one per pattern seen more than once. The list is heapified
        only when ``updated`` is ``None``.

    Raises:
        ValueError: If ``csd`` is empty.
    """
    d_in = len(csd)
    if d_in == 0:
        raise ValueError('csd must have at least one element')
    d_out, n_bit = csd[0].shape
    # _stat[pos0, pos1, dshift, dsign] = number of occurrences of that pattern.
    _stat = np.zeros((d_in, d_in, n_bit, 2), dtype=np.int64)
    # process_locs[pos, n]: whether variable `pos` takes part in output row `n`.
    process_locs = np.zeros((d_in, d_out), dtype=np.bool_)

    if updated is not None:
        # Incremental mode: every output row of each updated variable is processed.
        for i in range(len(updated)):
            pos = updated[i]
            for n in range(d_out):
                process_locs[pos, n] = True
    else:
        # Full mode: skip rows with no non-zero digit and variables with b == 0.
        for pos in range(d_in):
            for n in range(d_out):
                if np.any(csd[pos][n]) and precisions[pos].b != 0:
                    process_locs[pos, n] = True

    if dc is not None:
        depths = np.zeros(d_in, dtype=np.int64)
        for pos in range(d_in):
            depths[pos] = precisions[pos]._depth
        depth_min = np.min(depths)
        mask = depths <= depth_min + dc
        # The depth filter is applied only if it leaves at least two variables.
        if np.count_nonzero(mask) >= 2:
            for n in range(d_out):
                process_locs[:, n] &= mask

    # Work items: (pos0, pos1, row) where at least one side is marked for processing.
    args: list[tuple[int, int, int]] = []
    for pos0 in range(d_in):
        for pos1 in range(d_in):
            for n in range(d_out):
                if process_locs[pos0, n] or process_locs[pos1, n]:
                    args.append((pos0, pos1, n))

    # NOTE(review): the decorator is a plain @njit, so prange presumably runs
    # serially here; the += on _stat below would race if parallel=True were
    # enabled — confirm before changing the decorator.
    for idx in prange(len(args)):
        pos0, pos1, n = args[idx]
        for shift0 in range(n_bit):
            if csd[pos0][n, shift0] == 0:
                continue
            # Equal shifts are counted only when pos0 < pos1; otherwise the
            # second digit must sit at a strictly higher shift.
            lower = shift0 if pos0 < pos1 else shift0 + 1
            for shift1 in range(lower, n_bit):
                if csd[pos1][n, shift1] == 0:
                    continue
                # dsign is 0 when both digits are equal, 1 when they differ.
                dsign = int(csd[pos0][n, shift0] != csd[pos1][n, shift1])
                _stat[pos0, pos1, shift1 - shift0, dsign] += 1

    stat: list[tuple[float, int, int, int, int, int]] = []
    for pos0 in range(d_in):
        for pos1 in range(d_in):
            for dshift in range(n_bit):
                for dsign in range(2):
                    # Only patterns that occur more than once are recorded.
                    if _stat[pos0, pos1, dshift, dsign] > 1:
                        count = int(_stat[pos0, pos1, dshift, dsign])
                        score = scorer(precisions[pos0], precisions[pos1], dshift, dsign)
                        # The `else 0` arm cannot trigger inside the `> 1` guard above.
                        n_count = (count - 1) if count > 1 else 0
                        # The first field is negated so that a min-heap yields the largest gain first.
                        data = (-score * n_count, count, pos0, pos1, dshift, dsign)
                        stat.append(data)
    if updated is None:
        heapq.heapify(stat)
    return stat
80
-
81
-
@njit
def init_var(
    k: bool,
    b: int,
    i: int,
    symmetric: bool = False,
    _depth: int = 0,
) -> NBFixedPrecision:
    """Build an ``NBFixedPrecision`` from a (k, b, i) description.

    The integer upper bound is ``2**b - 1``. The lower bound is ``0`` when
    ``k`` is false, otherwise ``-2**b`` (or ``-2**b + 1`` when ``symmetric``).
    The shift passed to the constructor is ``b - i``.
    """
    upper = 2**b - 1
    if k:
        lower = -upper - 1 + symmetric
    else:
        lower = 0
    return NBFixedPrecision(lower, upper, b - i, symmetric, _depth)
100
-
101
-
@njit
def init_vars(
    ks: tuple[bool, ...],
    bs: tuple[int, ...],
    is_: tuple[int, ...],
    symmetrics: tuple[bool, ...],
    depths: tuple[int, ...],
):
    """Create one precision object per input by calling ``init_var`` element-wise.

    The number of variables created equals ``len(ks)``; the other tuples are
    indexed at the same positions.
    """
    return [init_var(ks[idx], bs[idx], is_[idx], symmetrics[idx], depths[idx]) for idx in range(len(ks))]
115
-
116
-
@njit
def init_state(
    kernel: np.ndarray,
    signs: tuple[bool, ...],
    bits: tuple[int, ...],
    int_bits: tuple[int, ...],
    symmetrics: tuple[bool, ...],
    depths: tuple[int, ...],
):
    """Build the initial ``DAState`` for a 2-D kernel.

    The kernel is decomposed with ``to_csd``. Each input variable is created
    via ``init_vars`` and left-shifted by the first entry of the returned
    shifts. One ``OpCode(i, -10, shift, 0, 0, 0)`` is recorded per input, and
    the initial candidate pairs come from ``extract_pairs``. The score starts
    with ``potential`` equal to the negated sum of the first field of every
    pair; the other three score fields are zero.
    """
    assert kernel.ndim == 2
    assert len(signs) == len(bits) == len(int_bits) == len(symmetrics) == len(depths) == kernel.shape[0]
    csd, shifts = to_csd(kernel)
    base_shift = shifts[0]

    variables = init_vars(signs, bits, int_bits, symmetrics, depths)
    op_codes = []
    for idx in range(len(csd)):
        op_codes.append(OpCode(idx, -10, base_shift, 0, 0, 0))
        variables[idx] = variables[idx] << base_shift
    pairs = extract_pairs(csd, variables)

    # Pair keys are stored negated, so subtracting accumulates a positive total.
    potential = 0.0
    for entry in pairs:
        potential -= entry[0]

    return DAState(csd, variables, op_codes, pairs, Score(potential, 0.0, 0.0, 0.0), kernel)
147
-
148
-
@njit
def update_state(state: DAState, pair: tuple[float, int, int, int, int, int], dc=None):
    """Apply one pair extraction to ``state`` and return the resulting state.

    Args:
        state: Current state; its ``csd`` and ``pairs`` lists are modified in
            place, while ``variables`` and ``op_codes`` are copied first.
        pair: ``(neg_score, count, pos0, pos1, dshift, dsign)`` as produced by
            ``extract_pairs``; ``dsign`` is 0 for equal digits, 1 for differing.
        dc: Forwarded to ``extract_pairs``.

    Returns:
        A new ``DAState`` with one extra variable, digit column and op code.
    """
    neg_cum_score, _count, pos0, pos1, dshift, dsign = pair
    variables = state.variables.copy()
    op_codes = state.op_codes.copy()
    # NOTE(review): csd and pairs alias the input state's lists, so calling
    # update_state more than once on the same state (as beam search with
    # beams > 1 appears to do) operates on already-mutated data — confirm
    # whether these should be copied as well.
    csd = state.csd
    pairs = state.pairs

    realized = state.score.realized - neg_cum_score

    d_out, n_bit = csd[0].shape
    new_csd_col = np.zeros((d_out, n_bit), dtype=np.int8)
    # Remap the 0/1 flag to the digit product: +1 for equal digits, -1 for differing.
    dsign = -1 if dsign else 1
    for n in range(d_out):
        for shift0 in range(0, n_bit - dshift):
            _dsign = csd[pos0][n, shift0] * csd[pos1][n, shift0 + dshift]
            if _dsign == dsign:
                # Move the matched digit pair into the new variable's column.
                new_csd_col[n, shift0] = csd[pos0][n, shift0]
                csd[pos0][n, shift0] = 0
                csd[pos1][n, shift0 + dshift] = 0

    csd.append(new_csd_col)
    v0, v1 = variables[pos0], variables[pos1] << dshift
    # After the remap above dsign is +1 or -1, never 0. Equal digits (+1) mean
    # the two terms are added; differing digits (-1) mean they are subtracted.
    v = v0 + v1 if dsign == 1 else v0 - v1
    variables.append(v)
    op_code = OpCode(pos0, pos1, 0, dshift, 1, dsign)
    op_codes.append(op_code)
    updated = [pos0, pos1, len(variables) - 1]

    d_pairs = extract_pairs(csd, variables, updated, dc)
    # Drop every stale candidate that involves either operand; iterate
    # backwards so popping does not shift the indices still to be visited.
    for i in range(len(pairs) - 1, -1, -1):
        _pair = pairs[i]
        if pos0 == _pair[2] or pos1 == _pair[2] or pos0 == _pair[3] or pos1 == _pair[3]:
            pairs.pop(i)

    for i in range(len(d_pairs)):
        heapq.heappush(pairs, d_pairs[i])

    cur_potential = 0.0
    for i in range(len(pairs)):
        cur_potential -= pairs[i][0]

    lost = state.score.potential - cur_potential
    value = realized - lost
    score = Score(state.score.potential, realized, lost, value)
    return DAState(csd, variables, op_codes, pairs, score, state.kernel)
195
-
196
-
@njit(cache=True)
def get_top_n_pairs(state: DAState, n: int):
    """Return the first ``n`` entries of ``state.pairs`` as a new list.

    Fewer than ``n`` entries are returned when the list is shorter.

    NOTE(review): this is a plain slice of the pair list. If the list is a
    binary heap only element 0 is guaranteed to be the minimum, so the slice
    is presumably an approximation of the "top n" — confirm this is intended.
    The previously unreachable heappop-based variant after the return was
    removed; behaviour is unchanged.
    """
    return state.pairs[:n]
202
-
203
-
@njit
def cmvm_cse(state: DAState, progress=None, beams: int = 1, dc=None):
    """Repeatedly extract pairs from ``state`` using up to ``beams`` parallel candidates.

    Args:
        state: Initial state; must have at least one candidate pair.
        progress: Not used in this function.
        beams: Number of candidate states kept per iteration.
        dc: Forwarded to ``update_state``.

    Returns:
        The final state with the highest ``score.realized``.

    Raises:
        AssertionError: If ``state.pairs`` is empty.
    """
    assert len(state.pairs) > 0, f'{len(state.pairs)}'
    top_pairs = get_top_n_pairs(state, beams)
    states_0 = [update_state(state, top_pairs[i], dc) for i in range(len(top_pairs))]

    # next_score[i, j]: realized score state i would reach by taking its j-th pair.
    next_score = np.full((beams, beams), -1, dtype=np.float64)
    # use_n[i]: how many successors of state i are kept for the next round.
    use_n = np.empty(beams, dtype=np.int64)
    while True:
        use_n[:] = 0
        next_score[:] = -1
        # for/else: stop iterating once no state has a first pair with a negative key.
        for i in range(len(states_0)):
            if len(states_0[i].pairs) > 0 and states_0[i].pairs[0][0] < 0:
                break
        else:
            break

        for i in range(len(states_0)):
            _state = states_0[i]
            # Column 0 defaults to the current realized score; it is overwritten
            # below whenever the state has at least one candidate pair.
            next_score[i, 0] = _state.score.realized
            next_score[i, 1:] = -1
            top_pairs = get_top_n_pairs(_state, beams)
            for j in range(len(top_pairs)):
                next_score[i, j] = -top_pairs[j][0] + _state.score.realized

        # Take the `beams` highest entries overall and count how many belong to each state.
        order = np.argsort(next_score.ravel())[::-1]
        for i in range(beams):
            i_st, i_pair = np.divmod(order[i], next_score.shape[0])
            if next_score[i_st, i_pair] > 0:
                use_n[i_st] += 1

        states_1 = []
        for i in range(beams):
            if use_n[i] == 0:
                continue

            _state = states_0[i]
            # A state with no pairs, or whose first pair has a positive key, is carried over as-is.
            if len(_state.pairs) == 0 or _state.pairs[0][0] > 0:
                states_1.append(_state)
                continue

            top_pairs = get_top_n_pairs(_state, use_n[i])
            for j in range(use_n[i]):
                states_1.append(update_state(_state, top_pairs[j], dc))

        states_0 = states_1

    # NOTE(review): if no entry of next_score was positive in the last round,
    # states_1 (and hence states_0) would be empty and the index below would
    # fail — confirm this cannot happen.
    _max = states_0[0].score.realized
    _idx = 0
    for i in range(len(states_0)):
        _state = states_0[i]
        if _state.score.realized > _max:
            _max = _state.score.realized
            _idx = i
    return states_0[_idx]
260
-
261
-
@njit
def compile_kernel_mono(
    kernel: np.ndarray,
    signs: tuple[bool, ...],
    bits: tuple[int, ...],
    int_bits: tuple[int, ...],
    symmetrics: tuple[bool, ...],
    depths: tuple[int, ...],
    n_beams: int = 1,
    dc: int | None = None,
):
    """Compile a single, unpartitioned kernel.

    Builds the initial state with ``init_state`` and returns the result of
    running ``cmvm_cse`` on it with ``beams=n_beams`` and the given ``dc``.
    """
    initial = init_state(kernel, signs, bits, int_bits, symmetrics, depths)
    return cmvm_cse(initial, beams=n_beams, dc=dc)
276
-
277
-
def compile_kernel(
    kernel: np.ndarray,
    signs: Sequence[bool],
    bits: Sequence[int],
    int_bits: Sequence[int],
    symmetrics: Sequence[bool],
    depths: Sequence[int],
    n_beams: int = 1,
    dc: int | None = None,
    n_inp_max: int = -1,
    n_out_max: int = -1,
) -> list[list[DAState]]:
    """Compile a kernel, optionally splitting it into a grid of sub-kernels.

    Args:
        kernel: 2-D array of shape ``(d_inp, d_out)``.
        signs, bits, int_bits, symmetrics, depths: Per-input descriptions,
            sliced along the input axis for each partition.
        n_beams: Forwarded to ``compile_kernel_mono``.
        dc: Forwarded to ``compile_kernel_mono``.
        n_inp_max: Maximum inputs per partition; ignored unless
            ``0 < n_inp_max < d_inp``.
        n_out_max: Maximum outputs per partition; ignored unless
            ``0 < n_out_max < d_out``.

    Returns:
        ``states[i][j]`` for input partition ``i`` and output partition ``j``.
        When ``compile_kernel_mono`` raises ``AssertionError`` the partition
        falls back to the plain ``init_state`` result.
    """
    d_inp, d_out = kernel.shape

    n_inp_part = ceil(d_inp / n_inp_max) if 0 < n_inp_max < d_inp else 1
    n_out_part = ceil(d_out / n_out_max) if 0 < n_out_max < d_out else 1

    inp_chunk_size = ceil(d_inp / n_inp_part)
    out_chunk_size = ceil(d_out / n_out_part)

    # Partition boundaries; the last one may overshoot and is clamped below.
    inp_part_locs = np.arange(0, n_inp_part + 1) * inp_chunk_size
    out_part_locs = np.arange(0, n_out_part + 1) * out_chunk_size

    states = [[None] * n_out_part for _ in range(n_inp_part)]
    # Output partitions vary slowest, input partitions fastest.
    for j in range(n_out_part):
        out_lo, out_hi = out_part_locs[j], min(out_part_locs[j + 1], d_out)  # type: ignore
        for i in range(n_inp_part):
            inp_lo, inp_hi = inp_part_locs[i], min(inp_part_locs[i + 1], d_inp)  # type: ignore

            # unify input type to prevent recompilation
            sub_kernel = np.ascontiguousarray(kernel[inp_lo:inp_hi, out_lo:out_hi])
            sub_signs = tuple(bool(v) for v in signs[inp_lo:inp_hi])
            sub_bits = tuple(int(v) for v in bits[inp_lo:inp_hi])
            sub_int_bits = tuple(int(v) for v in int_bits[inp_lo:inp_hi])
            sub_symmetrics = tuple(bool(v) for v in symmetrics[inp_lo:inp_hi])
            sub_depths = tuple(int(v) for v in depths[inp_lo:inp_hi])
            try:
                states[i][j] = compile_kernel_mono(
                    sub_kernel, sub_signs, sub_bits, sub_int_bits, sub_symmetrics, sub_depths, n_beams, dc
                )
            except AssertionError:
                states[i][j] = init_state(sub_kernel, sub_signs, sub_bits, sub_int_bits, sub_symmetrics, sub_depths)

    return states  # type: ignore
da4ml/cmvm/codegen.py DELETED
@@ -1,159 +0,0 @@
1
- import types
2
- from math import log2
3
-
4
- from .fixed_variable import FixedVariable, Namer
5
-
6
-
class PyCodegenBackend:
    """Emit a plain-Python function that evaluates a graph of ``FixedVariable`` objects."""

    # Prefix used for the "Latency" separator lines in the generated code.
    _comment = '#'

    # NOTE(review): the default ``Namer()`` is created once at definition time
    # and shared by every instance built without an explicit namer — confirm
    # that sharing is intended.
    def __init__(self, namer=Namer(), fn_name: str = 'placeholder', **kwargs):
        self._namer = namer
        self._attrs = {'fn_name': fn_name, **kwargs}

    def reference_code(self, v: FixedVariable):
        """How the variable should be referenced in the code"""
        if v.int_min == v.int_max:
            # Constant: emit its value directly.
            return f'{v.min}'

        neg = v._factor < 0
        shift = log2(abs(v._factor))
        # The factor must be a (possibly negated) power of two.
        assert shift % 1 == 0
        shift = int(shift)
        s_sign = '-' if neg else ''
        s_shift = f' * {2.**shift}' if shift != 0 else ''
        return f'{s_sign}{v.name}{s_shift}'

    def def_code(self, v: FixedVariable):
        """How the variable should be defined in the code"""
        if v.int_min == v.int_max:
            raise ValueError('Constant variable should not be defined')
        assert v._from is not None, 'Variable not derived from other variables cannot be defined in runtime'
        v1_str = self.reference_code(v._from[0])
        v2_str = self.reference_code(v._from[1])
        if v2_str[0] == '-':
            # Fold a leading minus on the second operand into a subtraction.
            return f'{v.name} = {v1_str} - {v2_str[1:]}'
        return f'{v.name} = {v1_str} + {v2_str}'

    def _resolve_variable(self, v: FixedVariable, _recorded: dict[str, FixedVariable]):
        """Record ``v`` and, recursively, both variables it is derived from.

        Operands are recorded before ``v`` itself.

        Raises:
            ValueError: If a non-constant, unrecorded variable has no ``_from``.
        """
        if v.name in _recorded:
            return

        if v.int_min == v.int_max:
            _recorded[v.name] = v
            return

        if v._from is None:
            raise ValueError('Variable not derived from other variables cannot be defined in runtime')

        self._resolve_variable(v._from[0], _recorded)
        self._resolve_variable(v._from[1], _recorded)
        _recorded[v.name] = v

    def resolve_all_variables(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        """Return a name -> variable mapping of everything needed to compute ``outputs``."""
        _recorded = {v.name: v for v in inputs}
        for v in outputs:
            self._resolve_variable(v, _recorded)
        return _recorded

    def gen_lines(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        """Return the body lines: definitions ordered by ``_depth``, then output assignments."""
        variables = self.resolve_all_variables(inputs, outputs)
        keys = sorted(variables, key=lambda x: variables[x]._depth)
        codes = []
        cur_depth = -1
        s_inputs = set(inputs)
        for key in keys:
            v = variables[key]
            # Constants are inlined at use sites and inputs already exist.
            if v.int_min == v.int_max or v in s_inputs:
                continue
            if cur_depth != v._depth:
                cur_depth = v._depth
                codes.append(f'{self._comment} ========================== Latency: {cur_depth} ==========================')
            codes.append(self.def_code(v))
        for i, out in enumerate(outputs):
            codes.append(f'out[{i}] = {self.reference_code(out)}')
        return codes

    def gen_fn(self, inputs: list[FixedVariable], outputs: list[FixedVariable], **kwargs):
        """Generate the function and return ``(callable, source_string)``.

        ``fn_name`` may be overridden through ``kwargs``.
        """
        fn_name = kwargs.get('fn_name', self._attrs['fn_name'])
        code = self.gen_lines(inputs, outputs)
        code_str = '\n    '.join(code)
        fn_str = f"""def {fn_name}(inp: list[float]):
    out = [0.]*{len(outputs)}
    {code_str}
    return out
"""
        fn_obj = compile(fn_str, '<string>', 'exec')
        # Execute the definition into a scratch namespace and fetch the
        # function by name. This avoids relying on the position of the
        # function's code object inside ``co_consts``, which is a compiler
        # implementation detail. The function still uses this module's globals.
        namespace: dict = {}
        exec(fn_obj, globals(), namespace)
        fn = namespace[fn_name]
        return fn, fn_str

    def __call__(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        return self.gen_fn(inputs, outputs)
93
-
94
-
class VitisCodegenBackend(PyCodegenBackend):
    """Emit a templated C++ function using ``ap_fixed`` / ``ap_ufixed`` types.

    ``gen_fn`` returns the C++ source together with a Python callable built by
    ``PyCodegenBackend`` for the same graph.
    """

    # Prefix used for the "Latency" separator lines in the generated code.
    _comment = '//'

    def __init__(self, namer=Namer(), fn_name: str = 'placeholder', **kwargs):
        super().__init__(namer, fn_name, **kwargs)

    def reference_code(self, v: FixedVariable):
        """How the variable should be referenced in the code"""
        if v.int_min == v.int_max:
            # Constant: wrap the literal in an explicit fixed-point type.
            k, b, i = v.k, v.b, v.i
            u = '' if k else 'u'
            type_str = f'ap_{u}fixed<{max(b+k,1)}, {i+k}>'
            return f'{type_str}({v.min})'

        neg = v._factor < 0
        shift = log2(abs(v._factor))
        # The factor must be a (possibly negated) power of two.
        assert shift % 1 == 0
        shift = int(shift)
        s_sign = '-' if neg else ''
        if shift == 0:
            return f'{s_sign}{v.name}'
        return f'{s_sign}bit_shift<{shift}>({v.name})'

    def gen_lines(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        """Return the C++ body lines; the trailing output assignments get a ``;``."""
        codes = super().gen_lines(inputs, outputs)
        n = len(outputs)
        # The parent appends exactly one assignment per output at the end.
        for i, out in enumerate(outputs):
            codes[-n + i] = f'out[{i}] = {self.reference_code(out)};'
        return codes

    def def_code(self, v: FixedVariable):
        """How the variable should be defined in the code"""
        if v.int_min == v.int_max:
            raise ValueError('Constant variable should not be defined')
        assert v._from is not None, 'Variable not derived from other variables cannot be defined in runtime'
        v1_str = self.reference_code(v._from[0])
        v2_str = self.reference_code(v._from[1])
        # Derive the declared type from the variable with its factor divided out.
        vv = v * (1 / v._factor)
        k, b, i = vv.k, vv.b, vv.i
        b, i = b + k, i + k  # b and i did not include sign bit
        u = '' if k else 'u'
        type_str = f'ap_{u}fixed<{b}, {i}>'
        if v2_str[0] == '-':
            return f'{type_str} {v.name} = {v1_str} - {v2_str[1:]};'
        return f'{type_str} {v.name} = {v1_str} + {v2_str};'

    def gen_fn(self, inputs: list[FixedVariable], outputs: list[FixedVariable], **kwargs):
        """Generate the C++ source and a Python reference implementation.

        Returns:
            ``(fn, fn_str)`` where ``fn`` is the callable produced by a fresh
            ``PyCodegenBackend`` and ``fn_str`` is the C++ source.
        """
        attrs = {**self._attrs, **kwargs}
        fn_name = attrs['fn_name']
        code = self.gen_lines(inputs, outputs)
        code_str = '\n    '.join(code)

        fn_str = f"""template <typename inp_t, typename out_t>
void {fn_name}(inp_t inp[{len(inputs)}], out_t out[{len(outputs)}]) {{
    {code_str}
}}
"""
        # The Python reference is generated by a separate backend instance
        # with its own '#' comment prefix, so this instance's _comment is
        # left untouched and stays correct even if that call raises.
        fn, _ = PyCodegenBackend().gen_fn(inputs, outputs, fn_name=fn_name)
        return fn, fn_str
da4ml/cmvm/csd.py DELETED
@@ -1,73 +0,0 @@
1
- import numpy as np
2
- from numba import njit
3
- from numpy.typing import NDArray
4
-
5
-
@njit
def _volatile_int_arr_to_csd(x: NDArray) -> NDArray[np.int8]:
    """Decompose ``x`` into signed digits in {-1, 0, 1}, least significant last-axis index first.

    ``x`` is modified in place (hence "volatile"): each extracted digit times
    its power of two is subtracted from it. A digit at position ``n`` is set
    when the remaining magnitude exceeds ``2**n / 1.5``. The result has shape
    ``(*x.shape, N)`` with ``N >= 1``.
    """
    n_digits = int(max(np.max(np.ceil(np.log2(np.abs(x + 1e-19) * 1.5))), 1))
    digits = np.zeros((*np.shape(x), n_digits), dtype=np.int8)

    # Walk from the most significant position down to position 0.
    for pos in range(n_digits - 1, -1, -1):
        weight = 2**pos
        cutoff = weight / 1.5
        digit = (x > cutoff).astype(np.int8)
        digit -= (x < -cutoff).astype(np.int8)
        x -= weight * digit
        digits[..., pos] = digit
    return digits
21
-
22
-
@njit(error_model='numpy')
def to_csd(x: NDArray) -> tuple[list[NDArray[np.int8]], NDArray[np.int8]]:
    """Scale ``x`` by a power of two to integers and decompose it into signed digits.

    A bisection over exponents in ``[-32, 32]`` looks for an exponent ``e``
    such that ``x * 2**e`` has no fractional part; an all-zero input uses
    ``e = 0``. The scaled array is passed to ``_volatile_int_arr_to_csd``.

    Returns:
        ``(digits, shifts)``: ``digits`` is the digit array split into a list
        along its first axis, and ``shifts[n] = n - e`` for each digit
        position ``n``.
    """
    lo, hi = -32, 32
    if np.all(x == 0):
        lo = hi = 0
    while hi - lo > 1:
        probe = (hi + lo) // 2
        scaled = x * (2.0**probe)
        if np.all(scaled == np.floor(scaled)):
            hi = probe
        else:
            lo = probe
    digits = _volatile_int_arr_to_csd(x * (2.0**hi))
    shifts = np.arange(digits.shape[-1], dtype=np.int8) - hi
    return list(digits), shifts
39
-
40
-
@njit
def _volatile_int_arr_to_binary(x: NDArray) -> NDArray[np.int8]:
    """Decompose ``x`` into signed binary digits in {-1, 0, 1}.

    ``x`` is modified in place (hence "volatile"): each extracted digit times
    its power of two is subtracted from it. A digit at position ``n`` is set
    when the remaining magnitude is at least ``2**n``. The result has shape
    ``(*x.shape, N)`` with ``N >= 1``.
    """
    n_digits = int(max(np.max(np.ceil(np.log2(np.abs(x) + 1))), 1))
    digits = np.zeros((*np.shape(x), n_digits), dtype=np.int8)

    # Walk from the most significant position down to position 0.
    for pos in range(n_digits - 1, -1, -1):
        weight = 2**pos
        cutoff = weight
        digit = (x >= cutoff).astype(np.int8)
        digit -= (x <= -cutoff).astype(np.int8)
        x -= weight * digit
        digits[..., pos] = digit
    return digits
56
-
57
-
@njit(error_model='numpy')
def to_binary(x: NDArray) -> tuple[list[NDArray[np.int8]], NDArray[np.int8]]:
    """Scale ``x`` by a power of two to integers and decompose it into signed binary digits.

    A bisection over exponents in ``[-32, 32]`` looks for an exponent ``e``
    such that ``x * 2**e`` has no fractional part; an all-zero input uses
    ``e = 0``. The scaled array is passed to ``_volatile_int_arr_to_binary``.

    Returns:
        ``(digits, shifts)``: ``digits`` is the digit array split into a list
        along its first axis, and ``shifts[n] = n - e`` for each digit
        position ``n``.
    """
    lo, hi = -32, 32
    if np.all(x == 0):
        lo = hi = 0
    while hi - lo > 1:
        probe = (hi + lo) // 2
        scaled = x * (2.0**probe)
        if np.all(scaled == np.floor(scaled)):
            hi = probe
        else:
            lo = probe
    digits = _volatile_int_arr_to_binary(x * (2.0**hi))
    shifts = np.arange(digits.shape[-1], dtype=np.int8) - hi
    return list(digits), shifts