da4ml-0.1.1-py3-none-any.whl → da4ml-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of da4ml might be problematic; see the registry's release page for more details.

Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +239 -73
  5. da4ml/cmvm/core/__init__.py +222 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +49 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +75 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +265 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +177 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +120 -0
  35. da4ml-0.2.0.dist-info/METADATA +65 -0
  36. da4ml-0.2.0.dist-info/RECORD +39 -0
  37. {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.1.dist-info/METADATA +0 -121
  48. da4ml-0.1.1.dist-info/RECORD +0 -18
  49. {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info/licenses}/LICENSE +0 -0
  50. {da4ml-0.1.1.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/cmvm/cmvm.py DELETED
@@ -1,328 +0,0 @@
1
- import heapq
2
- from math import ceil
3
-
4
- import numpy as np
5
- from numba import njit, prange
6
- from numpy.typing import NDArray
7
-
8
- from .csd import to_csd
9
- from .nb_fixed_precision import NBFixedPrecision
10
- from .scoring import scorer
11
- from .utils import DAState, OpCode, Score
12
-
13
-
14
@njit
def extract_pairs(
    csd: list[NDArray[np.int8]], precisions: list[NBFixedPrecision], updated: list[int] | None = None, dc: int | None = None
):
    """Count two-digit patterns that recur across the CSD rows.

    A pattern is the key ``(pos0, pos1, dshift, dsign)``. It is a nonzero digit of row
    ``pos0`` at bit ``s`` together with a nonzero digit of row ``pos1`` at bit
    ``s + dshift``, both in the same output column. ``dsign`` is 1 when the two digits
    differ in sign and 0 otherwise.

    Parameters
    ----------
    csd : list of (d_out, n_bit) int8 digit arrays, one per variable.
    precisions : precision descriptor for each entry of ``csd``.
    updated : when given, only columns of these positions are (re)scanned and the
        result is returned unsorted. Otherwise every (row, column) with nonzero digits
        and nonzero ``b`` is scanned and the result is heapified.
    dc : optional depth window. Only variables whose ``_depth`` is within ``dc`` of the
        minimum depth are scanned, and only if at least two variables qualify.

    Returns
    -------
    List of ``(-score * (count - 1), count, pos0, pos1, dshift, dsign)`` for every
    pattern that occurs more than once.
    """
    n_var = len(csd)
    if n_var == 0:
        raise ValueError('csd must have at least one element')
    d_out, n_bit = csd[0].shape
    counts = np.zeros((n_var, n_var, n_bit, 2), dtype=np.int64)
    active = np.zeros((n_var, d_out), dtype=np.bool_)

    if updated is None:
        for p in range(n_var):
            for col in range(d_out):
                if np.any(csd[p][col]) and precisions[p].b != 0:
                    active[p, col] = True
    else:
        for k in range(len(updated)):
            p = updated[k]
            for col in range(d_out):
                active[p, col] = True

    if dc is not None:
        depth = np.zeros(n_var, dtype=np.int64)
        for p in range(n_var):
            depth[p] = precisions[p]._depth
        allowed = depth <= np.min(depth) + dc
        if np.count_nonzero(allowed) >= 2:
            for col in range(d_out):
                active[:, col] &= allowed

    # Flatten the (row pair, column) work items so the counting loop is a single prange.
    jobs: list[tuple[int, int, int]] = []
    for p0 in range(n_var):
        for p1 in range(n_var):
            for col in range(d_out):
                if active[p0, col] or active[p1, col]:
                    jobs.append((p0, p1, col))

    for job in prange(len(jobs)):
        p0, p1, col = jobs[job]
        for s0 in range(n_bit):
            d0 = csd[p0][col, s0]
            if d0 == 0:
                continue
            # dshift == 0 is only counted for p0 < p1. Same-row and mirrored orderings
            # start one bit higher, so every digit pair is counted exactly once.
            start = s0 if p0 < p1 else s0 + 1
            for s1 in range(start, n_bit):
                d1 = csd[p1][col, s1]
                if d1 == 0:
                    continue
                counts[p0, p1, s1 - s0, int(d0 != d1)] += 1

    result: list[tuple[float, int, int, int, int, int]] = []
    for p0 in range(n_var):
        for p1 in range(n_var):
            for dshift in range(n_bit):
                for dsign in range(2):
                    if counts[p0, p1, dshift, dsign] <= 1:
                        continue
                    count = int(counts[p0, p1, dshift, dsign])
                    gain = scorer(precisions[p0], precisions[p1], dshift, dsign)
                    result.append((-gain * (count - 1), count, p0, p1, dshift, dsign))
    if updated is None:
        heapq.heapify(result)
    return result
79
-
80
-
81
@njit
def init_var(
    k: bool,
    b: int,
    i: int,
    symmetric: bool = False,
    _depth: int = 0,
) -> NBFixedPrecision:
    """Build the precision descriptor of one fixed-point input.

    ``k`` selects a signed range. ``b`` is the number of magnitude bits: the integer
    range is [-2**b (+1 if symmetric), 2**b - 1] when signed and [0, 2**b - 1] otherwise.
    The descriptor's shift is ``b - i``.
    """
    int_max = 2**b - 1
    if k:
        int_min = -int_max - 1 + symmetric
    else:
        int_min = 0
    return NBFixedPrecision(int_min, int_max, b - i, symmetric, _depth)
99
-
100
-
101
@njit
def init_vars(
    ks: list[bool],
    bs: list[int],
    is_: list[int],
    symmetrics: list[bool],
    depths: list[int],
):
    """Build one ``init_var`` descriptor per input from parallel attribute lists (indexed by ``ks``)."""
    result = []
    for idx in range(len(ks)):
        var = init_var(ks[idx], bs[idx], is_[idx], symmetrics[idx], depths[idx])
        result.append(var)
    return result
114
-
115
-
116
@njit
def init_state(
    kernel: np.ndarray,
    signs: list[bool],
    bits: list[int],
    int_bits: list[int],
    symmetrics: list[bool],
    depths: list[int],
):
    """Create the initial search state for a 2-D constant kernel.

    The kernel is CSD-encoded with ``to_csd``. Every input variable is shifted by
    ``shifts[0]`` (the scale that undoes the integerisation in ``to_csd``) and recorded
    as an opcode. All recurring digit pairs are extracted. The initial score's
    ``potential`` is the sum of the pair gains (the negated heap keys).
    """
    assert kernel.ndim == 2
    assert len(signs) == len(bits) == len(int_bits) == len(symmetrics) == len(depths) == kernel.shape[0]
    csd, shifts = to_csd(kernel)
    base_shift = shifts[0]

    variables = init_vars(signs, bits, int_bits, symmetrics, depths)
    op_codes = []
    for idx in range(len(csd)):
        # NOTE(review): -10 is kept as-is; presumably it marks a raw input. Confirm against OpCode consumers.
        op_codes.append(OpCode(idx, -10, base_shift, 0, 0, 0))
        variables[idx] = variables[idx] << base_shift
    pairs = extract_pairs(csd, variables)

    potential = 0.0
    for entry in pairs:
        potential -= entry[0]

    return DAState(csd, variables, op_codes, pairs, Score(potential, 0.0, 0.0, 0.0), kernel)
146
-
147
-
148
@njit
def update_state(state: DAState, pair: tuple[float, int, int, int, int, int], dc=None):
    """Apply one common-subexpression extraction and return the resulting state.

    ``pair`` is an ``extract_pairs`` entry ``(neg_cum_score, count, pos0, pos1, dshift, dsign)``.
    Every occurrence of "digit of ``pos0`` at bit s, digit of ``pos1`` at bit s + dshift,
    with relative sign ``dsign``" is removed from those two rows. The occurrence is
    recorded instead in a new CSD row for the new intermediate variable
    ``var[pos0] +/- (var[pos1] << dshift)``.

    ``state`` itself is not modified. Its digit rows and pair heap are copied first,
    because beam search (``cmvm_cse``) may expand the same parent state several times.
    """
    neg_cum_score, count, pos0, pos1, dshift, dsign = pair
    variables = state.variables.copy()
    op_codes = state.op_codes.copy()
    # Deep copies: the rows are zeroed in place below and the heap is edited.
    csd = []
    for row in state.csd:
        csd.append(row.copy())
    pairs = state.pairs.copy()

    realized = state.score.realized - neg_cum_score

    d_out, n_bit = csd[0].shape
    new_csd_col = np.zeros((d_out, n_bit), dtype=np.int8)
    # Convert the 0/1 "signs differ" flag into the expected product of the two digits.
    sign = -1 if dsign else 1
    for n in range(d_out):
        for shift0 in range(0, n_bit - dshift):
            if csd[pos0][n, shift0] * csd[pos1][n, shift0 + dshift] == sign:
                new_csd_col[n, shift0] = csd[pos0][n, shift0]
                csd[pos0][n, shift0] = 0
                csd[pos1][n, shift0 + dshift] = 0

    csd.append(new_csd_col)
    v0 = variables[pos0]
    v1 = variables[pos1] << dshift
    # Same-sign digits factor as d * (x0 + x1 << dshift); opposite signs as d * (x0 - x1 << dshift).
    v = v0 - v1 if sign == -1 else v0 + v1
    variables.append(v)
    # NOTE(review): the opcode stores the +/-1 digit product (unchanged from before), not
    # the 0/1 flag from ``pair``. Confirm this against the OpCode consumer.
    op_codes.append(OpCode(pos0, pos1, 0, dshift, 1, sign))
    updated = [pos0, pos1, len(variables) - 1]

    d_pairs = extract_pairs(csd, variables, updated, dc)
    # Drop every candidate that touches a modified row; extract_pairs recomputed them.
    for i in range(len(pairs) - 1, -1, -1):
        _pair = pairs[i]
        if pos0 == _pair[2] or pos1 == _pair[2] or pos0 == _pair[3] or pos1 == _pair[3]:
            pairs.pop(i)
    for i in range(len(d_pairs)):
        pairs.append(d_pairs[i])
    # Popping arbitrary indices invalidates the heap property; rebuild it once.
    heapq.heapify(pairs)

    cur_potential = 0.0
    for i in range(len(pairs)):
        cur_potential -= pairs[i][0]

    lost = state.score.potential - cur_potential
    value = realized - lost
    score = Score(state.score.potential, realized, lost, value)
    return DAState(csd, variables, op_codes, pairs, score, state.kernel)
194
-
195
-
196
@njit(cache=True)
def get_top_n_pairs(state: DAState, n: int):
    """Return up to ``n`` best candidate pairs (smallest heap keys), best first.

    ``state.pairs`` is a binary heap. Only its first element is guaranteed to be the
    minimum, so a plain slice ``pairs[:n]`` is not the n best entries. This pops from a
    copy so the state's own heap stays intact.
    """
    heap = state.pairs.copy()
    return [heapq.heappop(heap) for _ in range(min(n, len(heap)))]
201
-
202
-
203
@njit
def cmvm_cse(state: DAState, progress=None, beams: int = 1, dc=None):
    """Beam-search common-subexpression elimination over CSD digit pairs.

    Each round scores every child (live state + one of its top ``beams`` pairs) by
    ``realized + gain``. The ``beams`` best positive-scoring children are expanded;
    states without a beneficial pair are carried over when selected. The search stops
    once no live state has a pair with a negative (beneficial) heap key. Returns the
    live state with the largest ``score.realized``.

    ``progress`` is accepted but unused. An AssertionError is raised when ``state`` has
    no candidate pairs.
    """
    assert len(state.pairs) > 0, f'{len(state.pairs)}'
    candidates = get_top_n_pairs(state, beams)
    frontier = [update_state(state, candidates[k], dc) for k in range(len(candidates))]

    proj = np.full((beams, beams), -1, dtype=np.float64)
    n_children = np.empty(beams, dtype=np.int64)
    while True:
        n_children[:] = 0
        proj[:] = -1

        has_gain = False
        for st in frontier:
            if len(st.pairs) > 0 and st.pairs[0][0] < 0:
                has_gain = True
                break
        if not has_gain:
            break

        # proj[a, b]: projected realized score after applying the b-th best pair of state a.
        for a in range(len(frontier)):
            st = frontier[a]
            proj[a, 0] = st.score.realized
            proj[a, 1:] = -1
            best = get_top_n_pairs(st, beams)
            for b in range(len(best)):
                proj[a, b] = st.score.realized - best[b][0]

        ranking = np.argsort(proj.ravel())[::-1]
        for r in range(beams):
            src, col = np.divmod(ranking[r], proj.shape[0])
            if proj[src, col] > 0:
                n_children[src] += 1

        next_frontier = []
        for a in range(beams):
            k = n_children[a]
            if k == 0:
                continue
            st = frontier[a]
            if len(st.pairs) == 0 or st.pairs[0][0] > 0:
                next_frontier.append(st)
                continue
            best = get_top_n_pairs(st, k)
            for b in range(k):
                next_frontier.append(update_state(st, best[b], dc))

        frontier = next_frontier

    best_idx = 0
    best_val = frontier[0].score.realized
    for a in range(len(frontier)):
        if frontier[a].score.realized > best_val:
            best_val = frontier[a].score.realized
            best_idx = a
    return frontier[best_idx]
259
-
260
-
261
@njit
def compile_kernel_mono(
    kernel: np.ndarray,
    signs: list[bool],
    bits: list[int],
    int_bits: list[int],
    symmetrics: list[bool],
    depths: list[int],
    n_beams: int = 1,
    dc: int | None = None,
):
    """Initialise and optimise a single (sub-)kernel within one jitted call."""
    initial = init_state(kernel, signs, bits, int_bits, symmetrics, depths)
    return cmvm_cse(initial, beams=n_beams, dc=dc)
275
-
276
-
277
def compile_kernel(
    kernel: np.ndarray,
    signs: list[bool],
    bits: list[int],
    int_bits: list[int],
    symmetrics: list[bool],
    depths: list[int],
    n_beams: int = 1,
    dc: int | None = None,
    n_inp_max: int = -1,
    n_out_max: int = -1,
) -> list[list[DAState]]:
    """Tile ``kernel`` (d_inp x d_out) and optimise every tile independently.

    A positive ``n_inp_max`` / ``n_out_max`` caps the tile size along that axis; the axis
    is split into equal-size chunks. Returns ``states[i][j]`` for input chunk ``i`` and
    output chunk ``j``. Tiles whose optimisation asserts fall back to their
    unoptimised initial state.
    """
    d_inp, d_out = kernel.shape
    n_inp_part = ceil(d_inp / n_inp_max) if 0 < n_inp_max < d_inp else 1
    n_out_part = ceil(d_out / n_out_max) if 0 < n_out_max < d_out else 1

    inp_bounds = np.arange(0, n_inp_part + 1) * ceil(d_inp / n_inp_part)
    out_bounds = np.arange(0, n_out_part + 1) * ceil(d_out / n_out_part)

    states = [[None] * n_out_part for _ in range(n_inp_part)]
    for flat in range(n_inp_part * n_out_part):
        j, i = divmod(flat, n_inp_part)
        lo_i, hi_i = inp_bounds[i], min(inp_bounds[i + 1], d_inp)  # type: ignore
        lo_o, hi_o = out_bounds[j], min(out_bounds[j + 1], d_out)  # type: ignore

        # Normalise argument types so numba reuses a single compiled specialisation.
        args = (
            np.ascontiguousarray(kernel[lo_i:hi_i, lo_o:hi_o]),
            [bool(v) for v in signs[lo_i:hi_i]],
            [int(v) for v in bits[lo_i:hi_i]],
            [int(v) for v in int_bits[lo_i:hi_i]],
            [bool(v) for v in symmetrics[lo_i:hi_i]],
            [int(v) for v in depths[lo_i:hi_i]],
        )
        try:
            states[i][j] = compile_kernel_mono(*args, n_beams, dc)
        except AssertionError:
            # cmvm_cse asserts that at least one candidate pair exists; without one the
            # tile keeps its initial (unoptimised) state.
            # NOTE(review): this relies on assert statements being compiled in; verify the
            # behaviour under `python -O`.
            states[i][j] = init_state(*args)

    return states  # type: ignore
da4ml/cmvm/codegen.py DELETED
@@ -1,159 +0,0 @@
1
- import types
2
- from math import log2
3
-
4
- from .fixed_variable import FixedVariable, Namer
5
-
6
-
7
class PyCodegenBackend:
    """Generate a straight-line Python function from a graph of ``FixedVariable`` nodes.

    Each derived (non-constant, non-input) variable becomes one assignment that adds
    or subtracts its two ``_from`` operands. Assignments are ordered by ``_depth`` and
    grouped under "Latency" comment banners; outputs are written to ``out[i]``.
    """

    _comment = '#'

    def __init__(self, namer=Namer(), fn_name: str = 'placeholder', **kwargs):
        # NOTE(review): the default ``Namer()`` is evaluated once at class definition, so
        # every backend built without an explicit namer shares that instance.
        self._namer = namer
        self._attrs = {'fn_name': fn_name, **kwargs}

    def reference_code(self, v: FixedVariable):
        """How the variable should be referenced in the code.

        Constants are inlined by value. Other variables are referenced by name, negated
        and/or multiplied by ``2.**shift`` according to their ``_factor``.
        """
        if v.int_min == v.int_max:
            return f'{v.min}'

        neg = v._factor < 0
        shift = log2(abs(v._factor))
        assert shift % 1 == 0  # |_factor| must be a power of two
        shift = int(shift)
        s_sign = '-' if neg else ''
        s_shift = f' * {2.**shift}' if shift != 0 else ''
        return f'{s_sign}{v.name}{s_shift}'

    def def_code(self, v: FixedVariable):
        """How the variable should be defined in the code (``name = a + b`` or ``name = a - b``)."""
        if v.int_min == v.int_max:
            raise ValueError('Constant variable should not be defined')
        assert v._from is not None, 'Variable not derived from other variables cannot be defined in runtime'
        v1_str = self.reference_code(v._from[0])
        v2_str = self.reference_code(v._from[1])
        if v2_str[0] == '-':
            return f'{v.name} = {v1_str} - {v2_str[1:]}'
        return f'{v.name} = {v1_str} + {v2_str}'

    def _resolve_variable(self, v: FixedVariable, _recorded: dict[str, FixedVariable]):
        """Depth-first record ``v`` and all variables it is derived from into ``_recorded``."""
        if v.name in _recorded:
            return

        if v.int_min == v.int_max:
            _recorded[v.name] = v
            return

        if v._from is None:
            raise ValueError('Variable not derived from other variables cannot be defined in runtime')

        self._resolve_variable(v._from[0], _recorded)
        self._resolve_variable(v._from[1], _recorded)
        _recorded[v.name] = v

    def resolve_all_variables(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        """Return ``{name: variable}`` for the inputs plus everything the outputs depend on."""
        _recorded = {v.name: v for v in inputs}
        for v in outputs:
            self._resolve_variable(v, _recorded)
        return _recorded

    def gen_lines(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        """Return the function body as a list of source lines (without indentation)."""
        variables = self.resolve_all_variables(inputs, outputs)
        keys = list(variables.keys())
        keys = sorted(keys, key=lambda x: variables[x]._depth)
        codes = []
        cur_depth = -1
        s_inputs = set(inputs)
        for key in keys:
            v = variables[key]
            if v.int_min == v.int_max or v in s_inputs:
                continue
            if cur_depth != v._depth:
                cur_depth = v._depth
                codes.append(f'{self._comment} ========================== Latency: {cur_depth} ==========================')
            codes.append(self.def_code(v))
        for i, out in enumerate(outputs):
            codes.append(f'out[{i}] = {self.reference_code(out)}')
        return codes

    def gen_fn(self, inputs: list[FixedVariable], outputs: list[FixedVariable], **kwargs):
        """Compile the generated code; return ``(callable, source_string)``."""
        fn_name = kwargs.get('fn_name', self._attrs['fn_name'])
        code = self.gen_lines(inputs, outputs)
        code_str = '\n    '.join(code)
        fn_str = f"""def {fn_name}(inp: list[float]):
    out = [0.]*{len(outputs)}
    {code_str}
    return out
"""
        module_code = compile(fn_str, '<string>', 'exec')
        # Find the function's code object by name rather than by position in co_consts:
        # the position depends on how the interpreter lays out constants (annotation
        # handling differs between versions), while the name is fixed by the template.
        fn_code = next(
            c for c in module_code.co_consts if isinstance(c, types.CodeType) and c.co_name == fn_name
        )
        fn = types.FunctionType(fn_code, globals())
        return fn, fn_str

    def __call__(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        return self.gen_fn(inputs, outputs)
93
-
94
-
95
class VitisCodegenBackend(PyCodegenBackend):
    """Generate a Vitis HLS (``ap_fixed``) C++ function for a ``FixedVariable`` graph.

    ``gen_fn`` also returns a Python callable of the same graph, built by
    ``PyCodegenBackend``.
    """

    _comment = '//'

    def __init__(self, namer=Namer(), fn_name: str = 'placeholder', **kwargs):
        super().__init__(namer, fn_name, **kwargs)

    def reference_code(self, v: FixedVariable):
        """How the variable should be referenced in the code.

        Constants become typed ``ap_[u]fixed`` literals. Power-of-two scaling is
        expressed with ``bit_shift<shift>(name)``.
        """
        if v.int_min == v.int_max:
            k, b, i = v.k, v.b, v.i
            u = '' if k else 'u'
            type_str = f'ap_{u}fixed<{max(b+k,1)}, {i+k}>'
            return f'{type_str}({v.min})'

        neg = v._factor < 0
        shift = log2(abs(v._factor))
        assert shift % 1 == 0  # |_factor| must be a power of two
        shift = int(shift)
        s_sign = '-' if neg else ''
        if shift == 0:
            return f'{s_sign}{v.name}'
        return f'{s_sign}bit_shift<{shift}>({v.name})'

    def gen_lines(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        """Reuse the parent's line generation, terminating the output assignments with ``;``."""
        codes = super().gen_lines(inputs, outputs)
        n = len(outputs)
        for i, out in enumerate(outputs):
            codes[-n + i] = f'out[{i}] = {self.reference_code(out)};'
        return codes

    def def_code(self, v: FixedVariable):
        """Declare ``v`` with an ``ap_[u]fixed`` type sized from its unscaled precision."""
        if v.int_min == v.int_max:
            raise ValueError('Constant variable should not be defined')
        assert v._from is not None, 'Variable not derived from other variables cannot be defined in runtime'
        v1_str = self.reference_code(v._from[0])
        v2_str = self.reference_code(v._from[1])
        vv = v * (1 / v._factor)
        k, b, i = vv.k, vv.b, vv.i
        b, i = b + k, i + k  # b and i did not include sign bit
        u = '' if k else 'u'
        type_str = f'ap_{u}fixed<{b}, {i}>'
        if v2_str[0] == '-':
            return f'{type_str} {v.name} = {v1_str} - {v2_str[1:]};'
        return f'{type_str} {v.name} = {v1_str} + {v2_str};'

    def gen_fn(self, inputs: list[FixedVariable], outputs: list[FixedVariable], **kwargs):
        """Return ``(python_callable, cpp_source)`` for the graph."""
        attrs = {**self._attrs, **kwargs}
        fn_name = attrs['fn_name']
        code = self.gen_lines(inputs, outputs)
        code_str = '\n    '.join(code)

        fn_str = f"""template <typename inp_t, typename out_t>
void {fn_name}(inp_t inp[{len(inputs)}], out_t out[{len(outputs)}]) {{
    {code_str}
}}
"""
        # The Python counterpart comes from a fresh PyCodegenBackend, which uses its own
        # class-level '#' comment marker. This instance's state is never modified, so
        # nothing needs restoring even if generation raises.
        fn, _ = PyCodegenBackend().gen_fn(inputs, outputs, fn_name=fn_name)
        return fn, fn_str

    def __call__(self, inputs: list[FixedVariable], outputs: list[FixedVariable]):
        return self.gen_fn(inputs, outputs)
da4ml/cmvm/csd.py DELETED
@@ -1,73 +0,0 @@
1
- import numpy as np
2
- from numba import njit
3
- from numpy.typing import NDArray
4
-
5
-
6
@njit
def _volatile_int_arr_to_csd(x: NDArray) -> NDArray[np.int8]:
    """CSD-encode an integer-valued array; ``x`` is consumed (modified in place).

    Digits are chosen greedily from the most significant position down, using the
    threshold 2**n / 1.5. Returns an int8 array of shape ``x.shape + (N,)`` whose last
    axis holds digits in {-1, 0, 1}; index n carries weight 2**n.
    """
    n_digits = np.max(np.ceil(np.log2(np.abs(x + 1e-19) * 1.5)))
    n_digits = int(max(n_digits, 1))
    digits = np.zeros((*np.shape(x), n_digits), dtype=np.int8)

    for n in range(n_digits - 1, -1, -1):
        weight = 2**n
        threshold = weight / 1.5
        d = (x > threshold).astype(np.int8)
        d -= (x < -threshold).astype(np.int8)
        x -= weight * d
        digits[..., n] = d
    return digits
21
-
22
-
23
@njit(error_model='numpy')
def to_csd(x: NDArray) -> tuple[list[NDArray[np.int8]], NDArray[np.int8]]:
    """Scale ``x`` by a power of two that makes it integral, then CSD-encode it.

    The exponent is the smallest one in (-32, 32] found by bisection; an all-zero ``x``
    uses exponent 0. If no exponent up to 32 makes ``x`` integral, the scaled values are
    encoded unrounded. Returns ``(rows, shifts)``: ``rows`` is the digit array split
    along its first axis, and ``shifts[k] = k - exponent``.
    """
    lo, hi = -32, 32
    if np.all(x == 0):
        hi = lo = 0
    while hi - lo > 1:
        mid = (hi + lo) // 2
        scaled = x * (2.0**mid)
        if np.all(scaled == np.floor(scaled)):
            hi = mid
        else:
            lo = mid
    digits = _volatile_int_arr_to_csd(x * (2.0**hi))
    shifts = np.arange(digits.shape[-1], dtype=np.int8) - hi
    return list(digits), shifts
39
-
40
-
41
@njit
def _volatile_int_arr_to_binary(x: NDArray) -> NDArray[np.int8]:
    """Sign-magnitude binary encoding of an integer-valued array; ``x`` is consumed.

    Each digit is +1 where the remaining value is >= 2**n, -1 where it is <= -2**n, and
    0 otherwise. Returns an int8 array of shape ``x.shape + (N,)``; index n carries
    weight 2**n.
    """
    n_digits = np.max(np.ceil(np.log2(np.abs(x) + 1)))
    n_digits = int(max(n_digits, 1))
    digits = np.zeros((*np.shape(x), n_digits), dtype=np.int8)

    for n in range(n_digits - 1, -1, -1):
        weight = 2**n
        d = (x >= weight).astype(np.int8)
        d -= (x <= -weight).astype(np.int8)
        x -= weight * d
        digits[..., n] = d
    return digits
56
-
57
-
58
@njit(error_model='numpy')
def to_binary(x: NDArray) -> tuple[list[NDArray[np.int8]], NDArray[np.int8]]:
    """Binary counterpart of ``to_csd``: same power-of-two scaling, sign-magnitude digits.

    Returns ``(rows, shifts)``: ``rows`` is the digit array split along its first axis,
    and ``shifts[k] = k - exponent``.
    """
    lo, hi = -32, 32
    if np.all(x == 0):
        hi = lo = 0
    while hi - lo > 1:
        mid = (hi + lo) // 2
        scaled = x * (2.0**mid)
        if np.all(scaled == np.floor(scaled)):
            hi = mid
        else:
            lo = mid
    digits = _volatile_int_arr_to_binary(x * (2.0**hi))
    shifts = np.arange(digits.shape[-1], dtype=np.int8) - hi
    return list(digits), shifts