da4ml-0.1.2-py3-none-any.whl → da4ml-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da4ml/__init__.py +16 -16
- da4ml/_version.py +2 -2
- da4ml/cmvm/__init__.py +3 -34
- da4ml/cmvm/api.py +239 -73
- da4ml/cmvm/core/__init__.py +222 -0
- da4ml/cmvm/core/indexers.py +83 -0
- da4ml/cmvm/core/state_opr.py +284 -0
- da4ml/cmvm/types.py +569 -0
- da4ml/cmvm/util/__init__.py +7 -0
- da4ml/cmvm/util/bit_decompose.py +86 -0
- da4ml/cmvm/util/mat_decompose.py +121 -0
- da4ml/codegen/__init__.py +11 -0
- da4ml/codegen/cpp/__init__.py +3 -0
- da4ml/codegen/cpp/cpp_codegen.py +148 -0
- da4ml/codegen/cpp/source/vitis.h +30 -0
- da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
- da4ml/codegen/verilog/__init__.py +13 -0
- da4ml/codegen/verilog/comb.py +146 -0
- da4ml/codegen/verilog/io_wrapper.py +255 -0
- da4ml/codegen/verilog/pipeline.py +49 -0
- da4ml/codegen/verilog/source/build_binder.mk +27 -0
- da4ml/codegen/verilog/source/build_prj.tcl +75 -0
- da4ml/codegen/verilog/source/ioutils.hh +117 -0
- da4ml/codegen/verilog/source/shift_adder.v +56 -0
- da4ml/codegen/verilog/source/template.xdc +29 -0
- da4ml/codegen/verilog/verilog_model.py +265 -0
- da4ml/trace/__init__.py +6 -0
- da4ml/trace/fixed_variable.py +358 -0
- da4ml/trace/fixed_variable_array.py +177 -0
- da4ml/trace/ops/__init__.py +55 -0
- da4ml/trace/ops/conv_utils.py +104 -0
- da4ml/trace/ops/einsum_utils.py +299 -0
- da4ml/trace/pipeline.py +155 -0
- da4ml/trace/tracer.py +120 -0
- da4ml-0.2.0.dist-info/METADATA +65 -0
- da4ml-0.2.0.dist-info/RECORD +39 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
- da4ml/cmvm/balanced_reduction.py +0 -46
- da4ml/cmvm/cmvm.py +0 -328
- da4ml/cmvm/codegen.py +0 -159
- da4ml/cmvm/csd.py +0 -73
- da4ml/cmvm/fixed_variable.py +0 -205
- da4ml/cmvm/graph_compile.py +0 -85
- da4ml/cmvm/nb_fixed_precision.py +0 -98
- da4ml/cmvm/scoring.py +0 -55
- da4ml/cmvm/utils.py +0 -5
- da4ml-0.1.2.dist-info/METADATA +0 -122
- da4ml-0.1.2.dist-info/RECORD +0 -18
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/trace/fixed_variable_array.py
@@ -0,0 +1,177 @@
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ..cmvm import solve
+from .fixed_variable import FixedVariable, HWConfig, QInterval
+
+
+class FixedVariableArray:
+    def __init__(
+        self,
+        vars: NDArray,
+        solver_options: dict[str, Any] | None = None,
+    ):
+        self._vars = np.array(vars)
+        self.solver_options = solver_options
+
+    @classmethod
+    def from_lhs(
+        cls,
+        low: NDArray[np.floating],
+        high: NDArray[np.floating],
+        step: NDArray[np.floating],
+        hwconf: HWConfig,
+        latency: np.ndarray | float = 0.0,
+        solver_options: dict[str, Any] | None = None,
+    ):
+        shape = low.shape
+        assert shape == high.shape == step.shape
+
+        low, high, step = low.ravel(), high.ravel(), step.ravel()
+        latency = np.full_like(low, latency) if isinstance(latency, (int, float)) else latency.ravel()
+
+        vars = []
+        for i, (l, h, s, lat) in enumerate(zip(low, high, step, latency)):
+            var = FixedVariable(
+                low=float(l),
+                high=float(h),
+                step=float(s),
+                hwconf=hwconf,
+                latency=float(
+                    lat,
+                ),
+            )
+            vars.append(var)
+        vars = np.array(vars).reshape(shape)
+        return cls(vars, solver_options)
+
+    __array_priority__ = 100
+
+    @classmethod
+    def from_kif(
+        cls,
+        k: NDArray[np.bool_ | np.integer],
+        i: NDArray[np.integer],
+        f: NDArray[np.integer],
+        hwconf: HWConfig,
+        latency: NDArray[np.floating] | float = 0.0,
+        solver_options: dict[str, Any] | None = None,
+    ):
+        step = 2.0**-f
+        _high = 2.0**i
+        high, low = _high - step, -_high * k
+        return cls.from_lhs(low, high, step, hwconf, latency, solver_options)
+
+    def __matmul__(self, other):
+        assert isinstance(other, np.ndarray)
+        kwargs = (self.solver_options or {}).copy()
+        shape0, shape1 = self.shape, other.shape
+        assert shape0[-1] == shape1[0], f'Matrix shapes do not match: {shape0} @ {shape1}'
+        c = shape1[0]
+        out_shape = shape0[:-1] + shape1[1:]
+        mat0, mat1 = self.reshape((-1, c)), other.reshape((c, -1))
+        r = []
+        for i in range(mat0.shape[0]):
+            vec = mat0[i]
+            qintervals = tuple([QInterval(float(v.low), float(v.high), float(v.step)) for v in vec._vars])
+            latencies = tuple([float(v.latency) for v in vec._vars])
+            hwconf = self._vars.ravel()[0].hwconf
+            kwargs.update(adder_size=hwconf.adder_size, carry_size=hwconf.carry_size)
+            _mat = np.ascontiguousarray(mat1.astype(np.float32))
+            sol = solve(_mat, qintervals=qintervals, latencies=latencies, **kwargs)
+            _r = sol(vec._vars)
+            r.append(_r)
+        r = np.array(r).reshape(out_shape)
+        return FixedVariableArray(r, self.solver_options)
+
+    def __rmatmul__(self, other):
+        mat1 = np.moveaxis(other, -1, 0)
+        mat0 = np.moveaxis(self._vars, 0, -1)
+        ndim0, ndim1 = mat0.ndim, mat1.ndim
+        r = FixedVariableArray(mat0, self.solver_options) @ mat1
+
+        _axes = tuple(range(0, ndim0 + ndim1 - 2))
+        axes = _axes[ndim0 - 1 :] + _axes[: ndim0 - 1]
+        return r.transpose(axes)
+
+    def __getitem__(self, *item):
+        vars = self._vars[*item]
+        if isinstance(vars, np.ndarray):
+            return FixedVariableArray(vars, self.solver_options)
+        else:
+            return vars
+
+    def __len__(self):
+        return len(self._vars)
+
+    @property
+    def shape(self):
+        return self._vars.shape
+
+    def __add__(self, other):
+        return FixedVariableArray(self._vars + other, self.solver_options)
+
+    def __sub__(self, other):
+        return FixedVariableArray(self._vars - other, self.solver_options)
+
+    def __mul__(self, other):
+        return FixedVariableArray(self._vars * other, self.solver_options)
+
+    def __truediv__(self, other):
+        return FixedVariableArray(self._vars * (1 / other), self.solver_options)
+
+    def __radd__(self, other):
+        return self + other
+
+    def __neg__(self):
+        return FixedVariableArray(-self._vars, self.solver_options)
+
+    def __repr__(self):
+        shape = self._vars.shape
+        hwconf_str = str(self._vars.ravel()[0].hwconf)[8:]
+        max_lat = max(v.latency for v in self._vars.ravel())
+        return f'FixedVariableArray(shape={shape}, hwconf={hwconf_str}, latency={max_lat})'
+
+    def relu(self, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN'):
+        shape = self._vars.shape
+        i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
+        f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
+        ret = []
+        for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()):
+            ret.append(v.relu(i=i, f=f, round_mode=round_mode))
+        return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)
+
+    def quantize(
+        self,
+        k: NDArray[np.integer] | None = None,
+        i: NDArray[np.integer] | None = None,
+        f: NDArray[np.integer] | None = None,
+        overflow_mode: str = 'WRAP',
+        round_mode: str = 'TRN',
+    ):
+        shape = self._vars.shape
+        k = np.broadcast_to(k, shape) if k is not None else np.full(shape, None)
+        i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
+        f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
+        ret = []
+        for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):
+            ret.append(v.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode))
+        return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)
+
+    def flatten(self):
+        return FixedVariableArray(self._vars.flatten(), self.solver_options)
+
+    def reshape(self, shape):
+        return FixedVariableArray(self._vars.reshape(shape), self.solver_options)
+
+    def transpose(self, axes=None):
+        return FixedVariableArray(self._vars.transpose(axes), self.solver_options)
+
+    def ravel(self):
+        return FixedVariableArray(self._vars.ravel(), self.solver_options)
+
+    @property
+    def dtype(self):
+        return self._vars.dtype
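This class is the symbolic core of the new `da4ml.trace` package: every element of the array is a `FixedVariable`, and `__matmul__` hands each row-times-constant-matrix product to the CMVM solver (`solve` from `da4ml.cmvm`). Below is a minimal usage sketch, not taken from the diff; in particular, `HWConfig`'s constructor arguments are an assumption based on the only two fields this module reads (`adder_size` and `carry_size`), and the real signature may differ.

```python
import numpy as np

from da4ml.trace.fixed_variable import HWConfig
from da4ml.trace.fixed_variable_array import FixedVariableArray

# Assumed constructor: the diff only shows that hwconf.adder_size and
# hwconf.carry_size are read inside __matmul__.
hwconf = HWConfig(adder_size=2, carry_size=8)

# Four signed fixed-point inputs, each with 2 integer and 3 fractional bits.
k, i, f = np.ones(4, np.int8), np.full(4, 2), np.full(4, 3)
x = FixedVariableArray.from_kif(k, i, f, hwconf)

# Multiplying by a constant matrix triggers one CMVM solve per flattened row.
w = np.array([[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0], [1.0, -0.5]])
y = x @ w
print(y)  # FixedVariableArray(shape=(2,), hwconf=..., latency=...)
```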
da4ml/trace/ops/__init__.py
@@ -0,0 +1,55 @@
+from typing import TypeVar
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ..fixed_variable_array import FixedVariable, FixedVariableArray
+from .conv_utils import conv
+from .einsum_utils import einsum
+
+T = TypeVar('T', FixedVariableArray, NDArray[np.floating], list[FixedVariable])
+
+
+def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN') -> T:
+    if isinstance(x, FixedVariableArray):
+        return x.relu(i=i, f=f, round_mode=round_mode)
+    elif isinstance(x, list):
+        return [xx.relu(i=ii, f=ff, round_mode=round_mode) for xx, ii, ff in zip(x, i, f)]  # type: ignore
+    else:
+        x = np.maximum(x, 0)
+        if f is not None:
+            if round_mode.upper() == 'RND':
+                x += 2.0 ** (-f - 1)
+            sf = 2.0**f
+            x = np.floor(x * sf) / sf
+        if i is not None:
+            x = x % 2.0**i
+        return x
+
+
+def quantize(
+    x: T,
+    k: NDArray[np.integer],
+    i: NDArray[np.integer],
+    f: NDArray[np.integer],
+    overflow_mode: str = 'WRAP',
+    round_mode: str = 'TRN',
+) -> T:
+    assert overflow_mode.upper() == 'WRAP', 'Only WRAP overflow mode is supported'
+    if isinstance(x, FixedVariableArray):
+        return x.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode)
+    else:
+        if round_mode.upper() == 'RND':
+            x += 2.0 ** (-f - 1)
+        b = k + i + f
+        bias = 2.0 ** (b - 1) * k
+        eps = 2.0**-f
+        return eps * ((np.floor(x / eps) + bias) % 2.0**b - bias)  # type: ignore
+
+
+__all__ = [
+    'conv',
+    'einsum',
+    'relu',
+    'quantize',
+]
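For plain NumPy inputs, the `quantize` above implements fixed-point casting with WRAP overflow: floor onto the `2**-f` grid, shift by a sign bias, wrap modulo `2**b`, and shift back. A small numeric sketch of that branch (scalar `k`, `i`, `f` stand in for the array-typed arguments, which NumPy broadcasting permits):

```python
import numpy as np

from da4ml.trace.ops import quantize

# k=1 (signed), i=2 integer bits, f=3 fractional bits: a 6-bit word covering
# [-4.0, 3.875] in steps of 0.125.
x = np.array([0.3, 1.9, -2.2, 5.0])
print(quantize(x, k=1, i=2, f=3))
# [ 0.25   1.875 -2.25  -3.   ]  <- 5.0 is out of range and wraps to -3.0
```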
da4ml/trace/ops/conv_utils.py
@@ -0,0 +1,104 @@
+from collections.abc import Sequence
+from typing import TypeVar
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ..fixed_variable_array import FixedVariableArray
+
+
+def r_im2col(kernel_size: Sequence[int], arr: np.ndarray, buffer: np.ndarray, axis: int):
+    w = kernel_size[0]
+    if len(kernel_size) == 3:  # 1D
+        for i in range(arr.shape[axis] - w + 1):
+            patch = np.take(arr, range(i, i + w), axis=axis)
+            buffer[i] = patch.flatten()
+    else:  # 2D+
+        for i in range(arr.shape[axis] - w + 1):
+            patch = arr[i : i + w]
+            r_im2col(kernel_size[1:], patch, buffer[i], axis + 1)
+
+
+def _im2col(kernel_size: Sequence[int], arr: np.ndarray):
+    if len(kernel_size) < 3:
+        return arr
+    shape = [inp_d - ker_d + 1 for inp_d, ker_d in zip(arr.shape, kernel_size[:-2])]
+    shape.append(np.prod(kernel_size[:-1]))  # type: ignore
+    buf = np.empty(shape, dtype=arr.dtype)
+    r_im2col(kernel_size, arr, buf, 0)
+    return buf
+
+
+def stride_arr(stride: int | tuple[int, ...], arr: np.ndarray):
+    ndim = arr.ndim
+    if isinstance(stride, int):
+        stride = (stride,) * (ndim - 1)
+    assert len(stride) == ndim - 1, f'Invalid stride {stride} for array with {ndim} dimensions'
+
+    _idx = tuple(slice(None, None, st) for st in stride)
+    return arr[*_idx]
+
+
+T = TypeVar('T', FixedVariableArray, NDArray[np.integer | np.floating])
+
+
+def conv(
+    x: T,
+    kernel: NDArray[np.integer | np.floating],
+    bias: NDArray[np.integer | np.floating] | None = None,
+    strides: int | tuple[int, ...] = 1,
+    padding: tuple[tuple[int, int], ...] | str = 'VALID',
+    format: str = 'channels_last',
+):
+    if isinstance(x, FixedVariableArray):
+        solver_options = x.solver_options
+        data = x._vars
+        is_symbolic = True
+    else:
+        solver_options = None
+        data = x
+        is_symbolic = False
+
+    ndim = data.ndim
+    ch_in, ch_out = kernel.shape[-2:]
+    _ch_in = data.shape[-1]
+    assert ch_in == _ch_in, f'Invalid input shape {data.shape} for kernel {kernel.shape}'
+    assert kernel.ndim == ndim + 1
+
+    assert format in ('channels_last', 'channels_first'), f'Invalid format {format}'
+
+    if isinstance(strides, int):
+        strides = (strides,) * (ndim - 1)
+    assert len(strides) == ndim - 1, f'Invalid stride {strides} for array with {ndim} dimensions'
+
+    if isinstance(padding, str):
+        padding = padding.upper()
+        if padding == 'VALID':
+            padding = ((0, 0),) * (ndim - 1)
+        elif padding == 'SAME':
+            _padding = []
+            for i in range(ndim - 1):
+                pad0 = kernel.shape[i] // 2
+                pad1 = kernel.shape[i] - pad0 - 1
+                _padding.append((pad1, pad0))
+            padding = tuple(_padding)
+        else:
+            raise ValueError(f'Invalid padding {padding}')
+    assert len(padding) == ndim - 1, f'Invalid padding {padding} for array with {ndim} dimensions'
+    assert all(len(p) == 2 for p in padding), f'Invalid padding {padding} for array with {ndim} dimensions'
+
+    data = np.pad(data, padding + ((0, 0),), mode='constant', constant_values=0.0)
+    data = _im2col(kernel.shape, data)
+    if is_symbolic:
+        _data = FixedVariableArray(data, solver_options) @ kernel.reshape(-1, ch_out)
+        data = _data._vars
+    else:
+        data = data @ kernel.reshape(-1, ch_out)
+    data = stride_arr(strides, data)
+    if bias is not None:
+        data = data + bias
+    if format == 'channels_first':
+        data = np.moveaxis(data, -1, 1)
+    if solver_options is not None:
+        return FixedVariableArray(data, solver_options)
+    return data
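`conv` lowers an N-dimensional convolution to im2col followed by a single matrix product, and applies `strides` only afterwards by subsampling the full unit-stride output. A sketch of the plain-NumPy branch, cross-checked against a sliding-window einsum (the check is illustrative and not part of the diff):

```python
import numpy as np

from da4ml.trace.ops.conv_utils import conv

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 3))          # channels_last input (H, W, C)
kernel = rng.normal(size=(3, 3, 3, 4))  # (kh, kw, ch_in, ch_out)

y = conv(x, kernel, strides=2, padding='SAME')
print(y.shape)  # (4, 4, 4)

# Reference: pad, take all 3x3 windows, contract, then apply the stride.
padded = np.pad(x, ((1, 1), (1, 1), (0, 0)))
win = np.lib.stride_tricks.sliding_window_view(padded, (3, 3), axis=(0, 1))
ref = np.einsum('hwcij,ijco->hwo', win, kernel)[::2, ::2]
print(np.allclose(y, ref))  # True
```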
da4ml/trace/ops/einsum_utils.py
@@ -0,0 +1,299 @@
+from math import prod
+from typing import TypedDict, overload
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ..fixed_variable_array import FixedVariableArray
+
+
+class EinsumRecipe(TypedDict):
+    direct_sum_axis: tuple[tuple[int, ...], tuple[int, ...]]
+    in_transpose_idxs: tuple[tuple[int, ...], tuple[int, ...]]
+    L0: int
+    L1: int
+    I: int
+    C: int
+    out_interpert_shape: tuple[int, ...]
+    out_transpose_idxs: tuple[int, ...]
+
+
+def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]):
+    """Validate, resolve broadcasting, and compute output shape for einsum string
+
+    Parameters
+    ----------
+    fn : str
+        einsum string, e.g. 'ij,jk->ik'
+    shape0 : tuple[int,...]
+        shape of input0
+    shape1 : tuple[int,...]
+        shape of input1
+
+    Returns
+    -------
+    tuple[str, tuple[int,...]]
+        einsum string w/o broadcasting, and output shape
+
+    Raises
+    ------
+    ValueError
+        If the einsum string is invalid, or if it is incompatible with the input shapes
+    """
+    inp, out = map(str.strip, fn.split('->'))
+    in0, in1 = map(str.strip, inp.split(','))
+    alphabets = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+    s_alphabets = set(alphabets)
+
+    # Invalid characters
+    if not (s_alphabets >= set(in0.replace('...', '') + in1.replace('...', '') + out.replace('...', ''))):
+        raise ValueError(f"einsum string {fn} is invalid: subscripts should be in [a-zA-Z] and '...' only")
+
+    in0 = in0.replace('...', '0')
+    in1 = in1.replace('...', '0')
+    out = out.replace('...', '0')
+    ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out)
+    sax_in0, sax_in1, sax_out = set(ax_in0), set(ax_in1), set(ax_out)
+    free_indices = ''.join(sorted(s_alphabets - sax_in0 - sax_in1 - sax_out))
+
+    # Repeated indices
+    if len(sax_in0) != len(ax_in0):
+        for a in in0:
+            if in0.count(a) == 1:
+                continue
+            a = a if a != '0' else '...'
+            raise ValueError(f"einsum string {fn} is invalid: input0 subscripts includes '{a}' multiple times")
+    if len(sax_in1) != len(ax_in1):
+        for a in in1:
+            if in1.count(a) == 1:
+                continue
+            a = a if a != '0' else '...'
+            raise ValueError(f"einsum string {fn} is invalid: input1 subscripts includes '{a}' multiple times")
+    if len(sax_out) != len(ax_out):
+        for a in out:
+            if out.count(a) == 1:
+                continue
+            a = a if a != '0' else '...'
+            raise ValueError(f"einsum string {fn} is invalid: output subscripts includes '{a}' multiple times")
+
+    # Invalid broadcasting
+    if '0' in sax_in0 or '0' in sax_in1 or '0' in sax_out:
+        if '0' not in sax_out:
+            raise ValueError(f'einsum string {fn} is invalid: output does not allow broadcasting, but inputs do')
+        if '0' not in sax_in0 and '0' not in sax_in1:
+            raise ValueError(f'einsum string {fn} is invalid: output allows broadcasting, but inputs do not')
+
+    # Output index out of nowhere
+    if remaining := sax_out - sax_in0 - sax_in1:
+        raise ValueError(f'einsum string {fn} is invalid: output subscripts {remaining} not found in inputs')
+
+    _common_in = sax_in0 & sax_in1
+
+    if '0' in sax_in0 and '0' in sax_in1:
+        # Simultaneous axes expansion in both inputs
+        n_boardcast0 = len(shape0) - len(sax_in0) + 1
+        n_boardcast1 = len(shape1) - len(sax_in1) + 1
+        assert n_boardcast0 == n_boardcast1, f'... expands to {n_boardcast0} and {n_boardcast1}-axis in input0 and input1.'
+        # Replace expansion indices with free indices
+        in0 = in0.replace('0', free_indices[:n_boardcast0])
+        in1 = in1.replace('0', free_indices[:n_boardcast1])
+        out = out.replace('0', free_indices[:n_boardcast0])
+        ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out)
+        _common_in = set(ax_in0) & set(ax_in1)
+
+    else:
+        # Axes expansion in input0 or input1 only
+        if '0' in sax_in0:
+            if len(sax_in0) - 1 > len(shape0):
+                raise ValueError(f'Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given')
+            # Replace auto expansion indices with free indices
+            n_broadcast = len(shape0) - len(sax_in0) + 1
+            in0 = in0.replace('0', free_indices[:n_broadcast])
+            out = out.replace('0', free_indices[:n_broadcast])
+            ax_in0 = list(in0)
+            ax_out = list(out)
+        else:
+            if len(sax_in0) != len(shape0):
+                raise ValueError(f'Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given')
+
+        if '0' in sax_in1:
+            if len(sax_in1) - 1 > len(shape1):
+                raise ValueError(f'Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given')
+            # Replace expansion indices with free indices
+            n_broadcast = len(shape1) - len(sax_in1) + 1
+            in1 = in1.replace('0', free_indices[:n_broadcast])
+            out = out.replace('0', free_indices[:n_broadcast])
+            ax_in1 = list(in1)
+            ax_out = list(out)
+        else:
+            if len(sax_in1) != len(shape1):
+                raise ValueError(f'Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given')
+
+    # Input dimension mismatch
+    for a in _common_in:
+        ax_0 = ax_in0.index(a)
+        ax_1 = ax_in1.index(a)
+        if shape0[ax_0] != shape1[ax_1]:
+            raise ValueError(f"Input dimension size mismatches for common subscript '{a}': {shape0[ax_0]} and {shape1[ax_1]}")
+
+    out_shape = tuple(shape0[ax_in0.index(a)] if a in ax_in0 else shape1[ax_in1.index(a)] for a in ax_out)
+    return f'{in0},{in1}->{out}', out_shape
+
+
+def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe:
+    """Parse einsum operation on two input arrays, return a recipe for execution
+
+    Parameters
+    ----------
+    fn : str
+        einsum string, e.g. 'ij,jk->ik'
+    input : np.ndarray
+        input0, the first input array
+    input1 : np.ndarray
+        input1, the second input array
+
+    Returns
+    -------
+    EinsumRecipe
+        einsum recipe; executed by _exec_einsum
+    """
+
+    fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1)
+
+    _in, _out = fn.split('->')
+    _in0, _in1 = _in.split(',')
+
+    in0, in1, out = list(_in0), list(_in1), list(_out)
+    s_in0, s_in1, s_out = set(in0), set(in1), set(out)
+    _common = s_in0 & s_in1
+    _contract = _common - s_out
+    _inplace = _common & s_out
+    contract = sorted(_contract, key=lambda x: in1.index(x))
+    inplace = sorted(_inplace, key=lambda x: in1.index(x))
+    invariant0 = sorted((s_out - _common) & s_in0, key=lambda x: in0.index(x))
+    invariant1 = sorted((s_out - _common) & s_in1, key=lambda x: in1.index(x))
+    direct_sum0 = s_in0 - s_out - _common
+    direct_sum1 = s_in1 - s_out - _common
+    direct_sum_axis = (
+        tuple(sorted(in0.index(x) for x in direct_sum0)),
+        tuple(sorted(in1.index(x) for x in direct_sum1)),
+    )
+
+    contract_idxs = tuple(map(in0.index, contract)), tuple(map(in1.index, contract))
+    inplace_idxs = tuple(map(in0.index, inplace)), tuple(map(in1.index, inplace))
+    invariant_idxs = tuple(map(in0.index, invariant0)), tuple(map(in1.index, invariant1))
+
+    inplace_shape = tuple(input_shape0[i] for i in inplace_idxs[0])
+    inplace_size = prod(inplace_shape)
+    contract_size = prod(input_shape0[i] for i in contract_idxs[0])
+    invariant_shape0 = tuple(input_shape0[i] for i in invariant_idxs[0])
+    invariant_shape1 = tuple(input_shape1[i] for i in invariant_idxs[1])
+    invariant_size0, invariant_size1 = prod(invariant_shape0), prod(invariant_shape1)
+
+    transpose_idx0 = inplace_idxs[0] + invariant_idxs[0] + contract_idxs[0]
+    transpose_idx1 = inplace_idxs[1] + invariant_idxs[1] + contract_idxs[1]
+
+    out_shape_pretranspose = inplace_shape + invariant_shape0 + invariant_shape1
+    _out_transpose_idx = np.argsort(tuple(map(out.index, inplace + invariant0 + invariant1)))
+    out_transpose_idx = tuple(int(i) for i in _out_transpose_idx)
+
+    return EinsumRecipe(
+        direct_sum_axis=direct_sum_axis,
+        in_transpose_idxs=(transpose_idx0, transpose_idx1),
+        out_interpert_shape=out_shape_pretranspose,
+        out_transpose_idxs=out_transpose_idx,
+        L0=invariant_size0,
+        L1=invariant_size1,
+        I=inplace_size,
+        C=contract_size,
+    )
+
+
+def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray:
+    """Execute einsum operation on two input arrays
+
+    Parameters
+    ----------
+    recipe : EinsumRecipe
+        einsum recipe
+    input0 : np.ndarray
+        input0, the first input array
+    input1 : np.ndarray
+        input1, the second input array
+
+    Returns
+    -------
+    np.ndarray
+        output array
+    """
+    sum_axis0, sum_axis1 = recipe['direct_sum_axis']
+    if sum_axis0:
+        input0 = np.sum(input0, axis=sum_axis0)
+    if sum_axis1:
+        input1 = np.sum(input1, axis=sum_axis1)
+    input0 = input0.transpose(recipe['in_transpose_idxs'][0]).ravel()
+    input1 = input1.transpose(recipe['in_transpose_idxs'][1]).ravel()
+    out_dtype = object if input0.dtype == object or input1.dtype == object else np.float64
+    output = np.zeros(recipe['L0'] * recipe['L1'] * recipe['I'], dtype=out_dtype)
+
+    L0, L1, I, C = recipe['L0'], recipe['L1'], recipe['I'], recipe['C']
+
+    for l0 in range(L0):
+        for i in range(I):
+            A = input1[i * L1 * C : (i + 1) * L1 * C].reshape((L1, C))
+            B = input0[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C]
+            output[(i * L0 + l0) * L1 : (i * L0 + l0 + 1) * L1] = A @ B
+
+    return output.reshape(recipe['out_interpert_shape']).transpose(recipe['out_transpose_idxs'])


+def _einsum(fn: str, input0, input1) -> np.ndarray:
+    """Execute einsum operation on two input arrays.
+
+    WARNING: Order of multiplication is reversed -- watchout if you are using non-commutative operators
+
+    Parameters
+    ----------
+    fn : str
+        einsum string, e.g. 'ij,jk->ik'
+    input : np.ndarray
+        input0, the first input array
+    input1 : np.ndarray
+        input1, the second input array
+
+    Returns
+    -------
+    np.ndarray
+        output array
+    """
+    recipe = parse_einsum(fn, input0.shape, input1.shape)
+    return _exec_einsum(recipe, input0, input1)
+
+
+@overload
+def einsum(fn: str, input0: FixedVariableArray, input1: NDArray[np.integer | np.floating]) -> FixedVariableArray: ...
+
+
+@overload
+def einsum(fn: str, input0: NDArray[np.integer | np.floating], input1: FixedVariableArray) -> FixedVariableArray: ...
+
+
+@overload
+def einsum(
+    fn: str, input0: NDArray[np.integer | np.floating], input1: NDArray[np.integer | np.floating]
+) -> NDArray[np.integer | np.floating]: ...
+
+
+def einsum(fn: str, input0, input1):
+    fg0 = isinstance(input0, FixedVariableArray)
+    fg1 = isinstance(input1, FixedVariableArray)
+    if fg0 and fg1:
+        raise ValueError('Einsum does not support two FixedVariableArray inputs')
+
+    r = _einsum(fn, input0, input1)
+    if fg0:
+        return FixedVariableArray(r, input0.solver_options)
+    elif fg1:
+        return FixedVariableArray(r, input1.solver_options)
+    else:
+        return r
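This module first validates and de-broadcasts the einsum string, then lowers the contraction to a flat loop of length-`C` dot products; because that loop only uses `@`, slicing, and `transpose`, the same path works on object arrays of `FixedVariable`. A sketch of the numeric path (imports follow the file layout of this release; the comparison against `np.einsum` is illustrative):

```python
import numpy as np

from da4ml.trace.ops.einsum_utils import einsum, parse_einsum

rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4))
b = rng.normal(size=(4, 5))

r = einsum('...j,jk->...k', a, b)
print(r.shape, np.allclose(r, np.einsum('...j,jk->...k', a, b)))  # (2, 3, 5) True

# The intermediate recipe is a plain TypedDict and can be inspected directly.
print(parse_einsum('...j,jk->...k', a.shape, b.shape))
```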