da4ml 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of da4ml might be problematic.

Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +239 -73
  5. da4ml/cmvm/core/__init__.py +222 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +49 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +75 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +265 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +177 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +120 -0
  35. da4ml-0.2.0.dist-info/METADATA +65 -0
  36. da4ml-0.2.0.dist-info/RECORD +39 -0
  37. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.2.dist-info/METADATA +0 -122
  48. da4ml-0.1.2.dist-info/RECORD +0 -18
  49. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
  50. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/trace/fixed_variable_array.py (new file)
@@ -0,0 +1,177 @@
+ from typing import Any
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ..cmvm import solve
+ from .fixed_variable import FixedVariable, HWConfig, QInterval
+
+
+ class FixedVariableArray:
+     def __init__(
+         self,
+         vars: NDArray,
+         solver_options: dict[str, Any] | None = None,
+     ):
+         self._vars = np.array(vars)
+         self.solver_options = solver_options
+
+     @classmethod
+     def from_lhs(
+         cls,
+         low: NDArray[np.floating],
+         high: NDArray[np.floating],
+         step: NDArray[np.floating],
+         hwconf: HWConfig,
+         latency: np.ndarray | float = 0.0,
+         solver_options: dict[str, Any] | None = None,
+     ):
+         shape = low.shape
+         assert shape == high.shape == step.shape
+
+         low, high, step = low.ravel(), high.ravel(), step.ravel()
+         latency = np.full_like(low, latency) if isinstance(latency, (int, float)) else latency.ravel()
+
+         vars = []
+         for i, (l, h, s, lat) in enumerate(zip(low, high, step, latency)):
+             var = FixedVariable(
+                 low=float(l),
+                 high=float(h),
+                 step=float(s),
+                 hwconf=hwconf,
+                 latency=float(
+                     lat,
+                 ),
+             )
+             vars.append(var)
+         vars = np.array(vars).reshape(shape)
+         return cls(vars, solver_options)
+
+     __array_priority__ = 100
+
+     @classmethod
+     def from_kif(
+         cls,
+         k: NDArray[np.bool_ | np.integer],
+         i: NDArray[np.integer],
+         f: NDArray[np.integer],
+         hwconf: HWConfig,
+         latency: NDArray[np.floating] | float = 0.0,
+         solver_options: dict[str, Any] | None = None,
+     ):
+         step = 2.0**-f
+         _high = 2.0**i
+         high, low = _high - step, -_high * k
+         return cls.from_lhs(low, high, step, hwconf, latency, solver_options)
+
+     def __matmul__(self, other):
+         assert isinstance(other, np.ndarray)
+         kwargs = (self.solver_options or {}).copy()
+         shape0, shape1 = self.shape, other.shape
+         assert shape0[-1] == shape1[0], f'Matrix shapes do not match: {shape0} @ {shape1}'
+         c = shape1[0]
+         out_shape = shape0[:-1] + shape1[1:]
+         mat0, mat1 = self.reshape((-1, c)), other.reshape((c, -1))
+         r = []
+         for i in range(mat0.shape[0]):
+             vec = mat0[i]
+             qintervals = tuple([QInterval(float(v.low), float(v.high), float(v.step)) for v in vec._vars])
+             latencies = tuple([float(v.latency) for v in vec._vars])
+             hwconf = self._vars.ravel()[0].hwconf
+             kwargs.update(adder_size=hwconf.adder_size, carry_size=hwconf.carry_size)
+             _mat = np.ascontiguousarray(mat1.astype(np.float32))
+             sol = solve(_mat, qintervals=qintervals, latencies=latencies, **kwargs)
+             _r = sol(vec._vars)
+             r.append(_r)
+         r = np.array(r).reshape(out_shape)
+         return FixedVariableArray(r, self.solver_options)
+
+     def __rmatmul__(self, other):
+         mat1 = np.moveaxis(other, -1, 0)
+         mat0 = np.moveaxis(self._vars, 0, -1)
+         ndim0, ndim1 = mat0.ndim, mat1.ndim
+         r = FixedVariableArray(mat0, self.solver_options) @ mat1
+
+         _axes = tuple(range(0, ndim0 + ndim1 - 2))
+         axes = _axes[ndim0 - 1 :] + _axes[: ndim0 - 1]
+         return r.transpose(axes)
+
+     def __getitem__(self, *item):
+         vars = self._vars[*item]
+         if isinstance(vars, np.ndarray):
+             return FixedVariableArray(vars, self.solver_options)
+         else:
+             return vars
+
+     def __len__(self):
+         return len(self._vars)
+
+     @property
+     def shape(self):
+         return self._vars.shape
+
+     def __add__(self, other):
+         return FixedVariableArray(self._vars + other, self.solver_options)
+
+     def __sub__(self, other):
+         return FixedVariableArray(self._vars - other, self.solver_options)
+
+     def __mul__(self, other):
+         return FixedVariableArray(self._vars * other, self.solver_options)
+
+     def __truediv__(self, other):
+         return FixedVariableArray(self._vars * (1 / other), self.solver_options)
+
+     def __radd__(self, other):
+         return self + other
+
+     def __neg__(self):
+         return FixedVariableArray(-self._vars, self.solver_options)
+
+     def __repr__(self):
+         shape = self._vars.shape
+         hwconf_str = str(self._vars.ravel()[0].hwconf)[8:]
+         max_lat = max(v.latency for v in self._vars.ravel())
+         return f'FixedVariableArray(shape={shape}, hwconf={hwconf_str}, latency={max_lat})'
+
+     def relu(self, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN'):
+         shape = self._vars.shape
+         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
+         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
+         ret = []
+         for v, i, f in zip(self._vars.ravel(), i.ravel(), f.ravel()):
+             ret.append(v.relu(i=i, f=f, round_mode=round_mode))
+         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)
+
+     def quantize(
+         self,
+         k: NDArray[np.integer] | None = None,
+         i: NDArray[np.integer] | None = None,
+         f: NDArray[np.integer] | None = None,
+         overflow_mode: str = 'WRAP',
+         round_mode: str = 'TRN',
+     ):
+         shape = self._vars.shape
+         k = np.broadcast_to(k, shape) if k is not None else np.full(shape, None)
+         i = np.broadcast_to(i, shape) if i is not None else np.full(shape, None)
+         f = np.broadcast_to(f, shape) if f is not None else np.full(shape, None)
+         ret = []
+         for v, k, i, f in zip(self._vars.ravel(), k.ravel(), i.ravel(), f.ravel()):
+             ret.append(v.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode))
+         return FixedVariableArray(np.array(ret).reshape(shape), self.solver_options)
+
+     def flatten(self):
+         return FixedVariableArray(self._vars.flatten(), self.solver_options)
+
+     def reshape(self, shape):
+         return FixedVariableArray(self._vars.reshape(shape), self.solver_options)
+
+     def transpose(self, axes=None):
+         return FixedVariableArray(self._vars.transpose(axes), self.solver_options)
+
+     def ravel(self):
+         return FixedVariableArray(self._vars.ravel(), self.solver_options)
+
+     @property
+     def dtype(self):
+         return self._vars.dtype
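The class above is the new symbolic-array front end: it wraps an ndarray of FixedVariable objects, and `@` against a constant matrix dispatches each output column to the CMVM solver (`solve`). A minimal usage sketch follows; it is not part of the release, and because HWConfig is defined in da4ml/trace/fixed_variable.py (not shown in this diff), the HWConfig arguments below are assumptions based on the adder_size/carry_size attributes read in __matmul__.

import numpy as np

from da4ml.trace.fixed_variable import HWConfig
from da4ml.trace.fixed_variable_array import FixedVariableArray

# Assumed HWConfig fields: only .adder_size and .carry_size are visible above.
hwconf = HWConfig(adder_size=1, carry_size=8)

# 4 signed fixed-point inputs with 3 integer and 4 fractional bits:
# step = 2**-4, high = 2**3 - step, low = -2**3 (see from_kif above).
k = np.ones(4, dtype=np.int8)
i = np.full(4, 3, dtype=np.int8)
f = np.full(4, 4, dtype=np.int8)
x = FixedVariableArray.from_kif(k, i, f, hwconf)

w = np.array([[1.0, -0.5], [2.0, 0.25], [0.0, 1.0], [-1.0, 0.5]], dtype=np.float32)
y = x @ w  # each output column is lowered to shift-and-add logic via cmvm.solve
print(y)   # FixedVariableArray(shape=(2,), hwconf=..., latency=...)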
da4ml/trace/ops/__init__.py (new file)
@@ -0,0 +1,55 @@
+ from typing import TypeVar
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ..fixed_variable_array import FixedVariable, FixedVariableArray
+ from .conv_utils import conv
+ from .einsum_utils import einsum
+
+ T = TypeVar('T', FixedVariableArray, NDArray[np.floating], list[FixedVariable])
+
+
+ def relu(x: T, i: NDArray[np.integer] | None = None, f: NDArray[np.integer] | None = None, round_mode: str = 'TRN') -> T:
+     if isinstance(x, FixedVariableArray):
+         return x.relu(i=i, f=f, round_mode=round_mode)
+     elif isinstance(x, list):
+         return [xx.relu(i=ii, f=ff, round_mode=round_mode) for xx, ii, ff in zip(x, i, f)]  # type: ignore
+     else:
+         x = np.maximum(x, 0)
+         if f is not None:
+             if round_mode.upper() == 'RND':
+                 x += 2.0 ** (-f - 1)
+             sf = 2.0**f
+             x = np.floor(x * sf) / sf
+         if i is not None:
+             x = x % 2.0**i
+         return x
+
+
+ def quantize(
+     x: T,
+     k: NDArray[np.integer],
+     i: NDArray[np.integer],
+     f: NDArray[np.integer],
+     overflow_mode: str = 'WRAP',
+     round_mode: str = 'TRN',
+ ) -> T:
+     assert overflow_mode.upper() == 'WRAP', 'Only WRAP overflow mode is supported'
+     if isinstance(x, FixedVariableArray):
+         return x.quantize(k=k, i=i, f=f, overflow_mode=overflow_mode, round_mode=round_mode)
+     else:
+         if round_mode.upper() == 'RND':
+             x += 2.0 ** (-f - 1)
+         b = k + i + f
+         bias = 2.0 ** (b - 1) * k
+         eps = 2.0**-f
+         return eps * ((np.floor(x / eps) + bias) % 2.0**b - bias)  # type: ignore
+
+
+ __all__ = [
+     'conv',
+     'einsum',
+     'relu',
+     'quantize',
+ ]
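For plain ndarrays, relu and quantize above are pure NumPy arithmetic, so their fixed-point behavior can be checked by hand. A small worked example (values computed from the formulas above; k=1, i=2, f=1 is a 4-bit signed word with step 0.5, truncation, and WRAP overflow):

import numpy as np

from da4ml.trace.ops import quantize, relu

x = np.array([0.3, 3.9, 4.0, -4.2])
# b = k+i+f = 4 bits, bias = 2**(b-1) = 8, eps = 2**-f = 0.5:
# eps * ((floor(x / eps) + bias) % 2**b - bias)
print(quantize(x, k=1, i=2, f=1))  # [ 0.   3.5 -4.   3.5]  (4.0 wraps to -4.0)

# relu clamps at 0, then truncates to f fractional bits:
print(relu(np.array([-1.0, 0.7]), f=np.array(1)))  # [0.  0.5]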
da4ml/trace/ops/conv_utils.py (new file)
@@ -0,0 +1,104 @@
+ from collections.abc import Sequence
+ from typing import TypeVar
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ..fixed_variable_array import FixedVariableArray
+
+
+ def r_im2col(kernel_size: Sequence[int], arr: np.ndarray, buffer: np.ndarray, axis: int):
+     w = kernel_size[0]
+     if len(kernel_size) == 3:  # 1D
+         for i in range(arr.shape[axis] - w + 1):
+             patch = np.take(arr, range(i, i + w), axis=axis)
+             buffer[i] = patch.flatten()
+     else:  # 2D+
+         for i in range(arr.shape[axis] - w + 1):
+             patch = arr[i : i + w]
+             r_im2col(kernel_size[1:], patch, buffer[i], axis + 1)
+
+
+ def _im2col(kernel_size: Sequence[int], arr: np.ndarray):
+     if len(kernel_size) < 3:
+         return arr
+     shape = [inp_d - ker_d + 1 for inp_d, ker_d in zip(arr.shape, kernel_size[:-2])]
+     shape.append(np.prod(kernel_size[:-1]))  # type: ignore
+     buf = np.empty(shape, dtype=arr.dtype)
+     r_im2col(kernel_size, arr, buf, 0)
+     return buf
+
+
+ def stride_arr(stride: int | tuple[int, ...], arr: np.ndarray):
+     ndim = arr.ndim
+     if isinstance(stride, int):
+         stride = (stride,) * (ndim - 1)
+     assert len(stride) == ndim - 1, f'Invalid stride {stride} for array with {ndim} dimensions'
+
+     _idx = tuple(slice(None, None, st) for st in stride)
+     return arr[*_idx]
+
+
+ T = TypeVar('T', FixedVariableArray, NDArray[np.integer | np.floating])
+
+
+ def conv(
+     x: T,
+     kernel: NDArray[np.integer | np.floating],
+     bias: NDArray[np.integer | np.floating] | None = None,
+     strides: int | tuple[int, ...] = 1,
+     padding: tuple[tuple[int, int], ...] | str = 'VALID',
+     format: str = 'channels_last',
+ ):
+     if isinstance(x, FixedVariableArray):
+         solver_options = x.solver_options
+         data = x._vars
+         is_symbolic = True
+     else:
+         solver_options = None
+         data = x
+         is_symbolic = False
+
+     ndim = data.ndim
+     ch_in, ch_out = kernel.shape[-2:]
+     _ch_in = data.shape[-1]
+     assert ch_in == _ch_in, f'Invalid input shape {data.shape} for kernel {kernel.shape}'
+     assert kernel.ndim == ndim + 1
+
+     assert format in ('channels_last', 'channels_first'), f'Invalid format {format}'
+
+     if isinstance(strides, int):
+         strides = (strides,) * (ndim - 1)
+     assert len(strides) == ndim - 1, f'Invalid stride {strides} for array with {ndim} dimensions'
+
+     if isinstance(padding, str):
+         padding = padding.upper()
+         if padding == 'VALID':
+             padding = ((0, 0),) * (ndim - 1)
+         elif padding == 'SAME':
+             _padding = []
+             for i in range(ndim - 1):
+                 pad0 = kernel.shape[i] // 2
+                 pad1 = kernel.shape[i] - pad0 - 1
+                 _padding.append((pad1, pad0))
+             padding = tuple(_padding)
+         else:
+             raise ValueError(f'Invalid padding {padding}')
+     assert len(padding) == ndim - 1, f'Invalid padding {padding} for array with {ndim} dimensions'
+     assert all(len(p) == 2 for p in padding), f'Invalid padding {padding} for array with {ndim} dimensions'
+
+     data = np.pad(data, padding + ((0, 0),), mode='constant', constant_values=0.0)
+     data = _im2col(kernel.shape, data)
+     if is_symbolic:
+         _data = FixedVariableArray(data, solver_options) @ kernel.reshape(-1, ch_out)
+         data = _data._vars
+     else:
+         data = data @ kernel.reshape(-1, ch_out)
+     data = stride_arr(strides, data)
+     if bias is not None:
+         data = data + bias
+     if format == 'channels_first':
+         data = np.moveaxis(data, -1, 1)
+     if is_symbolic:
+         return FixedVariableArray(data, solver_options)
+     return data
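conv reduces the convolution to an im2col patch matrix followed by a single matmul (symbolic or numeric), so the numeric path can be sanity-checked against hand-counted window sums. A sketch, not part of the release, for a 1D channels_last input matching the shape assertions above:

import numpy as np

from da4ml.trace.ops import conv

x = np.arange(12, dtype=np.float64).reshape(6, 2)  # length 6, ch_in=2
kernel = np.ones((3, 2, 1))                        # width 3, ch_in=2, ch_out=1
y = conv(x, kernel, padding='VALID')               # im2col -> (4, 6) @ (6, 1)
print(y.shape)   # (4, 1): 6 - 3 + 1 output positions
print(y[:, 0])   # [15. 27. 39. 51.]: each entry sums a width-3 window over both channels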
da4ml/trace/ops/einsum_utils.py (new file)
@@ -0,0 +1,299 @@
+ from math import prod
+ from typing import TypedDict, overload
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ..fixed_variable_array import FixedVariableArray
+
+
+ class EinsumRecipe(TypedDict):
+     direct_sum_axis: tuple[tuple[int, ...], tuple[int, ...]]
+     in_transpose_idxs: tuple[tuple[int, ...], tuple[int, ...]]
+     L0: int
+     L1: int
+     I: int
+     C: int
+     out_interpert_shape: tuple[int, ...]
+     out_transpose_idxs: tuple[int, ...]
+
+
+ def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]):
+     """Validate, resolve broadcasting, and compute output shape for einsum string
+
+     Parameters
+     ----------
+     fn : str
+         einsum string, e.g. 'ij,jk->ik'
+     shape0 : tuple[int,...]
+         shape of input0
+     shape1 : tuple[int,...]
+         shape of input1
+
+     Returns
+     -------
+     tuple[str, tuple[int,...]]
+         einsum string w/o broadcasting, and output shape
+
+     Raises
+     ------
+     ValueError
+         If the einsum string is invalid, or if it is incompatible with the input shapes
+     """
+     inp, out = map(str.strip, fn.split('->'))
+     in0, in1 = map(str.strip, inp.split(','))
+     alphabets = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+     s_alphabets = set(alphabets)
+
+     # Invalid characters
+     if not (s_alphabets >= set(in0.replace('...', '') + in1.replace('...', '') + out.replace('...', ''))):
+         raise ValueError(f"einsum string {fn} is invalid: subscripts should be in [a-zA-Z] and '...' only")
+
+     in0 = in0.replace('...', '0')
+     in1 = in1.replace('...', '0')
+     out = out.replace('...', '0')
+     ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out)
+     sax_in0, sax_in1, sax_out = set(ax_in0), set(ax_in1), set(ax_out)
+     free_indices = ''.join(sorted(s_alphabets - sax_in0 - sax_in1 - sax_out))
+
+     # Repeated indices
+     if len(sax_in0) != len(ax_in0):
+         for a in in0:
+             if in0.count(a) == 1:
+                 continue
+             a = a if a != '0' else '...'
+             raise ValueError(f"einsum string {fn} is invalid: input0 subscripts includes '{a}' multiple times")
+     if len(sax_in1) != len(ax_in1):
+         for a in in1:
+             if in1.count(a) == 1:
+                 continue
+             a = a if a != '0' else '...'
+             raise ValueError(f"einsum string {fn} is invalid: input1 subscripts includes '{a}' multiple times")
+     if len(sax_out) != len(ax_out):
+         for a in out:
+             if out.count(a) == 1:
+                 continue
+             a = a if a != '0' else '...'
+             raise ValueError(f"einsum string {fn} is invalid: output subscripts includes '{a}' multiple times")
+
+     # Invalid broadcasting
+     if '0' in sax_in0 or '0' in sax_in1 or '0' in sax_out:
+         if '0' not in sax_out:
+             raise ValueError(f'einsum string {fn} is invalid: output does not allow broadcasting, but inputs do')
+         if '0' not in sax_in0 and '0' not in sax_in1:
+             raise ValueError(f'einsum string {fn} is invalid: output allows broadcasting, but inputs do not')
+
+     # Output index out of nowhere
+     if remaining := sax_out - sax_in0 - sax_in1:
+         raise ValueError(f'einsum string {fn} is invalid: output subscripts {remaining} not found in inputs')
+
+     _common_in = sax_in0 & sax_in1
+
+     if '0' in sax_in0 and '0' in sax_in1:
+         # Simultaneous axes expansion in both inputs
+         n_boardcast0 = len(shape0) - len(sax_in0) + 1
+         n_boardcast1 = len(shape1) - len(sax_in1) + 1
+         assert n_boardcast0 == n_boardcast1, f'... expands to {n_boardcast0} and {n_boardcast1}-axis in input0 and input1.'
+         # Replace expansion indices with free indices
+         in0 = in0.replace('0', free_indices[:n_boardcast0])
+         in1 = in1.replace('0', free_indices[:n_boardcast1])
+         out = out.replace('0', free_indices[:n_boardcast0])
+         ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out)
+         _common_in = set(ax_in0) & set(ax_in1)
+
+     else:
+         # Axes expansion in input0 or input1 only
+         if '0' in sax_in0:
+             if len(sax_in0) - 1 > len(shape0):
+                 raise ValueError(f'Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given')
+             # Replace auto expansion indices with free indices
+             n_broadcast = len(shape0) - len(sax_in0) + 1
+             in0 = in0.replace('0', free_indices[:n_broadcast])
+             out = out.replace('0', free_indices[:n_broadcast])
+             ax_in0 = list(in0)
+             ax_out = list(out)
+         else:
+             if len(sax_in0) != len(shape0):
+                 raise ValueError(f'Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} is given')
+
+         if '0' in sax_in1:
+             if len(sax_in1) - 1 > len(shape1):
+                 raise ValueError(f'Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given')
+             # Replace expansion indices with free indices
+             n_broadcast = len(shape1) - len(sax_in1) + 1
+             in1 = in1.replace('0', free_indices[:n_broadcast])
+             out = out.replace('0', free_indices[:n_broadcast])
+             ax_in1 = list(in1)
+             ax_out = list(out)
+         else:
+             if len(sax_in1) != len(shape1):
+                 raise ValueError(f'Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} is given')
+
+     # Input dimension mismatch
+     for a in _common_in:
+         ax_0 = ax_in0.index(a)
+         ax_1 = ax_in1.index(a)
+         if shape0[ax_0] != shape1[ax_1]:
+             raise ValueError(f"Input dimension size mismatches for common subscript '{a}': {shape0[ax_0]} and {shape1[ax_1]}")
+
+     out_shape = tuple(shape0[ax_in0.index(a)] if a in ax_in0 else shape1[ax_in1.index(a)] for a in ax_out)
+     return f'{in0},{in1}->{out}', out_shape
+
+
+ def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe:
+     """Parse einsum operation on two input arrays, return a recipe for execution
+
+     Parameters
+     ----------
+     fn : str
+         einsum string, e.g. 'ij,jk->ik'
+     input_shape0 : tuple[int, ...]
+         shape of the first input array
+     input_shape1 : tuple[int, ...]
+         shape of the second input array
+
+     Returns
+     -------
+     EinsumRecipe
+         einsum recipe; executed by _exec_einsum
+     """
+
+     fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1)
+
+     _in, _out = fn.split('->')
+     _in0, _in1 = _in.split(',')
+
+     in0, in1, out = list(_in0), list(_in1), list(_out)
+     s_in0, s_in1, s_out = set(in0), set(in1), set(out)
+     _common = s_in0 & s_in1
+     _contract = _common - s_out
+     _inplace = _common & s_out
+     contract = sorted(_contract, key=lambda x: in1.index(x))
+     inplace = sorted(_inplace, key=lambda x: in1.index(x))
+     invariant0 = sorted((s_out - _common) & s_in0, key=lambda x: in0.index(x))
+     invariant1 = sorted((s_out - _common) & s_in1, key=lambda x: in1.index(x))
+     direct_sum0 = s_in0 - s_out - _common
+     direct_sum1 = s_in1 - s_out - _common
+     direct_sum_axis = (
+         tuple(sorted(in0.index(x) for x in direct_sum0)),
+         tuple(sorted(in1.index(x) for x in direct_sum1)),
+     )
+
+     contract_idxs = tuple(map(in0.index, contract)), tuple(map(in1.index, contract))
+     inplace_idxs = tuple(map(in0.index, inplace)), tuple(map(in1.index, inplace))
+     invariant_idxs = tuple(map(in0.index, invariant0)), tuple(map(in1.index, invariant1))
+
+     inplace_shape = tuple(input_shape0[i] for i in inplace_idxs[0])
+     inplace_size = prod(inplace_shape)
+     contract_size = prod(input_shape0[i] for i in contract_idxs[0])
+     invariant_shape0 = tuple(input_shape0[i] for i in invariant_idxs[0])
+     invariant_shape1 = tuple(input_shape1[i] for i in invariant_idxs[1])
+     invariant_size0, invariant_size1 = prod(invariant_shape0), prod(invariant_shape1)
+
+     transpose_idx0 = inplace_idxs[0] + invariant_idxs[0] + contract_idxs[0]
+     transpose_idx1 = inplace_idxs[1] + invariant_idxs[1] + contract_idxs[1]
+
+     out_shape_pretranspose = inplace_shape + invariant_shape0 + invariant_shape1
+     _out_transpose_idx = np.argsort(tuple(map(out.index, inplace + invariant0 + invariant1)))
+     out_transpose_idx = tuple(int(i) for i in _out_transpose_idx)
+
+     return EinsumRecipe(
+         direct_sum_axis=direct_sum_axis,
+         in_transpose_idxs=(transpose_idx0, transpose_idx1),
+         out_interpert_shape=out_shape_pretranspose,
+         out_transpose_idxs=out_transpose_idx,
+         L0=invariant_size0,
+         L1=invariant_size1,
+         I=inplace_size,
+         C=contract_size,
+     )
+
+
+ def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray:
+     """Execute einsum operation on two input arrays
+
+     Parameters
+     ----------
+     recipe : EinsumRecipe
+         einsum recipe
+     input0 : np.ndarray
+         input0, the first input array
+     input1 : np.ndarray
+         input1, the second input array
+
+     Returns
+     -------
+     np.ndarray
+         output array
+     """
+     sum_axis0, sum_axis1 = recipe['direct_sum_axis']
+     if sum_axis0:
+         input0 = np.sum(input0, axis=sum_axis0)
+     if sum_axis1:
+         input1 = np.sum(input1, axis=sum_axis1)
+     input0 = input0.transpose(recipe['in_transpose_idxs'][0]).ravel()
+     input1 = input1.transpose(recipe['in_transpose_idxs'][1]).ravel()
+     out_dtype = object if input0.dtype == object or input1.dtype == object else np.float64
+     output = np.zeros(recipe['L0'] * recipe['L1'] * recipe['I'], dtype=out_dtype)
+
+     L0, L1, I, C = recipe['L0'], recipe['L1'], recipe['I'], recipe['C']
+
+     for l0 in range(L0):
+         for i in range(I):
+             A = input1[i * L1 * C : (i + 1) * L1 * C].reshape((L1, C))
+             B = input0[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C]
+             output[(i * L0 + l0) * L1 : (i * L0 + l0 + 1) * L1] = A @ B
+
+     return output.reshape(recipe['out_interpert_shape']).transpose(recipe['out_transpose_idxs'])
+
+
+ def _einsum(fn: str, input0, input1) -> np.ndarray:
+     """Execute einsum operation on two input arrays.
+     WARNING: Order of multiplication is reversed -- watch out if you are using non-commutative operators
+
+     Parameters
+     ----------
+     fn : str
+         einsum string, e.g. 'ij,jk->ik'
+     input0 : np.ndarray
+         the first input array
+     input1 : np.ndarray
+         the second input array
+
+     Returns
+     -------
+     np.ndarray
+         output array
+     """
+     recipe = parse_einsum(fn, input0.shape, input1.shape)
+     return _exec_einsum(recipe, input0, input1)
+
+
+ @overload
+ def einsum(fn: str, input0: FixedVariableArray, input1: NDArray[np.integer | np.floating]) -> FixedVariableArray: ...
+
+
+ @overload
+ def einsum(fn: str, input0: NDArray[np.integer | np.floating], input1: FixedVariableArray) -> FixedVariableArray: ...
+
+
+ @overload
+ def einsum(
+     fn: str, input0: NDArray[np.integer | np.floating], input1: NDArray[np.integer | np.floating]
+ ) -> NDArray[np.integer | np.floating]: ...
+
+
+ def einsum(fn: str, input0, input1):
+     fg0 = isinstance(input0, FixedVariableArray)
+     fg1 = isinstance(input1, FixedVariableArray)
+     if fg0 and fg1:
+         raise ValueError('Einsum does not support two FixedVariableArray inputs')
+
+     r = _einsum(fn, input0, input1)
+     if fg0:
+         return FixedVariableArray(r, input0.solver_options)
+     elif fg1:
+         return FixedVariableArray(r, input1.solver_options)
+     else:
+         return r
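einsum keeps at most one operand symbolic; with two plain arrays the recipe path should agree with np.einsum (the reversed multiplication order flagged in _einsum only matters for non-commutative element types such as FixedVariable). A quick self-check, not part of the release:

import numpy as np

from da4ml.trace.ops import einsum

rng = np.random.default_rng(0)
a = rng.random((2, 3))
b = rng.random((3, 4))

r = einsum('ij,jk->ik', a, b)  # plain-array overload, executed via parse_einsum/_exec_einsum
assert np.allclose(r, np.einsum('ij,jk->ik', a, b))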