brainstate 0.0.1.post20240612__py2.py3-none-any.whl → 0.0.1.post20240623__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -5
- brainstate/_module.py +147 -42
- brainstate/_module_test.py +95 -21
- brainstate/environ.py +0 -1
- brainstate/functional/__init__.py +2 -2
- brainstate/functional/_activations.py +7 -26
- brainstate/functional/_spikes.py +0 -1
- brainstate/mixin.py +2 -2
- brainstate/nn/_elementwise.py +5 -4
- brainstate/nn/_misc.py +4 -3
- brainstate/nn/_others.py +3 -2
- brainstate/nn/_poolings.py +21 -20
- brainstate/nn/_poolings_test.py +4 -4
- brainstate/optim/__init__.py +0 -1
- brainstate/optim/_sgd_optimizer.py +18 -17
- brainstate/transform/__init__.py +2 -3
- brainstate/transform/_autograd.py +1 -1
- brainstate/transform/_autograd_test.py +0 -2
- brainstate/transform/_jit_test.py +0 -3
- brainstate/transform/_make_jaxpr.py +0 -1
- brainstate/transform/_make_jaxpr_test.py +0 -2
- brainstate/transform/_progress_bar.py +1 -3
- brainstate/util.py +0 -1
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/METADATA +2 -12
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/RECORD +28 -35
- brainstate/math/__init__.py +0 -21
- brainstate/math/_einops.py +0 -787
- brainstate/math/_einops_parsing.py +0 -169
- brainstate/math/_einops_parsing_test.py +0 -126
- brainstate/math/_einops_test.py +0 -346
- brainstate/math/_misc.py +0 -298
- brainstate/math/_misc_test.py +0 -58
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/top_level.txt +0 -0
@@ -1,169 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
import keyword
|
18
|
-
import warnings
|
19
|
-
from typing import List, Optional, Set, Tuple, Union
|
20
|
-
|
21
|
-
_ellipsis: str = '…' # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
|
22
|
-
|
23
|
-
|
24
|
-
class EinopsError(Exception):
    """Error raised for malformed einops-style expressions or axis lengths."""
|
26
|
-
|
27
|
-
|
28
|
-
class AnonymousAxis:
    """Axis written as a bare integer length (e.g. the ``3`` in ``'b 3 c'``).

    No ``__eq__`` is defined, so equality is identity-based: two instances are
    never equal, even when they have the same length.
    """

    def __init__(self, value: str):
        self.value = int(value)
        # Length-1 axes are handled separately by the parser, and
        # zero/negative lengths are meaningless, so both are rejected here.
        if self.value == 1:
            raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue')
        if self.value < 1:
            raise EinopsError(f'Anonymous axis should have positive length, not {self.value}')

    def __repr__(self):
        return f"{self.value}-axis"
|
41
|
-
|
42
|
-
|
43
|
-
class ParsedExpression:
    """
    Parsed form of one side of an einops-style expression (e.g. ``'b c (h w)'``).

    Meant to be treated as read-only once constructed; it keeps the
    information about the axes on that side that downstream code needs.
    """

    def __init__(self, expression: str, *, allow_underscore: bool = False,
                 allow_duplicates: bool = False):
        """
        Parse ``expression`` and populate the attributes below.

        Attributes:
            has_ellipsis: ``True`` if the expression contains ``...``.
            has_ellipsis_parenthesized: ``None`` if there is no ellipsis,
                otherwise whether the ellipsis sits inside brackets.
            identifiers: named axes, ``AnonymousAxis`` objects for numeric
                axes other than 1, and the ellipsis placeholder if present.
            has_non_unitary_anonymous_axes: ``True`` if a numeric axis other
                than 1 (e.g. ``3``) occurs.
            composition: one entry per top-level position. Each entry is a
                list of axes (empty for a standalone ``1`` or ``()``), or the
                bare ellipsis placeholder when ``...`` appears outside brackets.

        Args:
            expression: one side of a pattern, e.g. ``'b (h w) ...'``.
            allow_underscore: accept ``_`` as an axis name, including repeats.
            allow_duplicates: accept the same axis name more than once.

        Raises:
            EinopsError: for dots outside a single ``...``, nested or
                unbalanced brackets, unknown characters, duplicate axis names,
                invalid axis identifiers, or a numeric axis of length 0.
        """
        # Keep the caller's spelling for error messages. ``expression`` is
        # rewritten below so that ``...`` becomes a one-character placeholder.
        original_expression = expression
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
        self.identifiers: Set[str] = set()
        # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
        self.has_non_unitary_anonymous_axes: bool = False
        # composition keeps structure of composite axes, see how different corner cases are handled in tests
        self.composition: List[Union[List[str], str]] = []
        if '.' in expression:
            if '...' not in expression:
                raise EinopsError('Expression may contain dots only inside ellipsis (...)')
            if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
                raise EinopsError(
                    'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ')
            # A single placeholder character lets the per-character scan below
            # treat the ellipsis as one token.
            expression = expression.replace('...', _ellipsis)
            self.has_ellipsis = True

        # Axes of the bracket group currently being read; None outside brackets.
        bracket_group: Optional[List[str]] = None

        def add_axis_name(x):
            if x in self.identifiers:
                if not (allow_underscore and x == "_") and not allow_duplicates:
                    raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
            else:
                is_number = str.isdecimal(x)
                if is_number and int(x) == 1:
                    # handling the case of anonymous axis of length 1
                    if bracket_group is None:
                        self.composition.append([])
                    else:
                        pass  # no need to think about 1s inside parenthesis
                    return
                is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
                if not (is_number or is_axis_name):
                    raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
                if is_number:
                    # Each numeric axis becomes its own AnonymousAxis object, so
                    # repeated numbers are never reported as duplicates.
                    x = AnonymousAxis(x)
                self.identifiers.add(x)
                if is_number:
                    self.has_non_unitary_anonymous_axes = True
                if bracket_group is None:
                    self.composition.append([x])
                else:
                    bracket_group.append(x)

        current_identifier = None
        for char in expression:
            if char in '() ':
                if current_identifier is not None:
                    add_axis_name(current_identifier)
                current_identifier = None
                if char == '(':
                    if bracket_group is not None:
                        raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
                    bracket_group = []
                elif char == ')':
                    if bracket_group is None:
                        raise EinopsError('Brackets are not balanced')
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ['_', _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise EinopsError("Unknown character '{}'".format(char))

        if bracket_group is not None:
            # Report the expression as the caller wrote it (with "...", not the placeholder).
            raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(original_expression))
        if current_identifier is not None:
            add_axis_name(current_identifier)

    def flat_axes_order(self) -> List:
        """Return every axis in left-to-right order with bracket groups flattened.

        Only valid when no ellipsis appears outside brackets. A top-level
        ellipsis is stored as a bare string, which fails the assertion below.
        """
        result = []
        for composed_axis in self.composition:
            assert isinstance(composed_axis, list), 'does not work with ellipsis'
            result.extend(composed_axis)
        return result

    def has_composed_axes(self) -> bool:
        """Return ``True`` if any bracket group holds more than one axis.

        Length-1 axes inside brackets are dropped during parsing, so they do
        not count.
        """
        return any(isinstance(axes, list) and len(axes) > 1 for axes in self.composition)

    @staticmethod
    def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
        """Validate ``name`` as an axis name.

        Returns:
            ``(is_valid, reason)``, where ``reason`` is empty for valid names.
            Keywords and ``'axis'`` are accepted, but they emit a
            ``RuntimeWarning`` / ``FutureWarning`` respectively.
        """
        if not str.isidentifier(name):
            return False, 'not a valid python identifier'
        elif name[0] == '_' or name[-1] == '_':
            if name == '_' and allow_underscore:
                return True, ''
            return False, 'axis name should not start or end with underscore'
        else:
            if keyword.iskeyword(name):
                warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
            if name in ['axis']:
                warnings.warn("It is discouraged to use 'axis' as an axis name "
                              "and will raise an error in future", FutureWarning)
            return True, ''

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """
        Valid axis names are python identifiers that do not start or end with
        an underscore. Keywords are accepted, but they trigger a RuntimeWarning.
        """
        is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid
|
@@ -1,126 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import pytest
|
17
|
-
|
18
|
-
from brainstate.math._einops_parsing import EinopsError, ParsedExpression, AnonymousAxis, _ellipsis
|
19
|
-
|
20
|
-
|
21
|
-
class AnonymousAxisPlaceholder:
    """Test helper equal to any ``AnonymousAxis`` that has the same length."""

    def __init__(self, value: int):
        assert isinstance(value, int)
        self.value = value

    def __eq__(self, other):
        if not isinstance(other, AnonymousAxis):
            return False
        return self.value == other.value
|
28
|
-
|
29
|
-
|
30
|
-
def test_anonymous_axes():
    """AnonymousAxis objects never equal each other, but a placeholder of the
    same length equals each of them."""
    first = AnonymousAxis('2')
    second = AnonymousAxis('2')
    assert first != second

    two = AnonymousAxisPlaceholder(2)
    three = AnonymousAxisPlaceholder(3)
    assert first == two and second == two
    assert first != three and second != three
    assert [first, 2, second] == [two, 2, two]
|
37
|
-
|
38
|
-
|
39
|
-
def test_elementary_axis_name():
    """check_axis_name accepts identifiers and rejects numbers, names starting
    or ending with an underscore, and ellipses."""
    accepted = ['a', 'b', 'h', 'dx', 'h1', 'zz', 'i9123', 'somelongname',
                'Alex', 'camelCase', 'u_n_d_e_r_score', 'unreasonablyLongAxisName']
    rejected = ['', '2b', '12', '_startWithUnderscore', 'endWithUnderscore_', '_', '...', _ellipsis]

    for candidate in accepted:
        assert ParsedExpression.check_axis_name(candidate)
    for candidate in rejected:
        assert not ParsedExpression.check_axis_name(candidate)
|
46
|
-
|
47
|
-
|
48
|
-
def test_invalid_expressions():
    """Malformed expressions raise EinopsError. A well-formed expression from
    each category is parsed first to show the category itself is supported."""
    categories = [
        # double ellipsis should raise an error
        ('... a b c d',
         ['... a b c d ...', '... a b c (d ...)', '(... a) b c (d ...)']),
        # double/missing/enclosed parenthesis
        ('(a) b c (d ...)',
         ['(a)) b c (d ...)', '(a b c (d ...)', '(a) (()) b c (d ...)', '(a) ((b c) (d ...))']),
        # invalid identifiers
        ('camelCase under_scored cApiTaLs ß ...',
         ['1a', '_pre', '...pre', 'pre...']),
    ]
    for well_formed, malformed_variants in categories:
        ParsedExpression(well_formed)
        for malformed in malformed_variants:
            with pytest.raises(EinopsError):
                ParsedExpression(malformed)
|
79
|
-
|
80
|
-
|
81
|
-
def test_parse_expression():
    """Check identifiers, composition and flags for representative expressions."""
    parsed = ParsedExpression('a1 b1 c1 d1')
    assert parsed.identifiers == {'a1', 'b1', 'c1', 'd1'}
    assert parsed.composition == [['a1'], ['b1'], ['c1'], ['d1']]
    assert not parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    parsed = ParsedExpression('() () () ()')
    assert parsed.identifiers == set()
    assert parsed.composition == [[], [], [], []]
    assert not parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # Outside brackets, a standalone 1 is stored like an empty group.
    parsed = ParsedExpression('1 1 1 ()')
    assert parsed.identifiers == set()
    assert parsed.composition == [[], [], [], []]
    assert not parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    aap = AnonymousAxisPlaceholder

    parsed = ParsedExpression('5 (3 4)')
    assert len(parsed.identifiers) == 3 and {i.value for i in parsed.identifiers} == {3, 4, 5}
    assert parsed.composition == [[aap(5)], [aap(3), aap(4)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert not parsed.has_ellipsis

    # A 1 inside brackets is dropped; a 1 outside brackets becomes [].
    parsed = ParsedExpression('5 1 (1 4) 1')
    assert len(parsed.identifiers) == 2 and {i.value for i in parsed.identifiers} == {4, 5}
    assert parsed.composition == [[aap(5)], [], [aap(4)], []]

    parsed = ParsedExpression('name1 ... a1 12 (name2 14)')
    assert len(parsed.identifiers) == 6
    # the two remaining identifiers are the anonymous axes 12 and 14
    assert len(parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'})) == 2
    assert parsed.composition == [['name1'], _ellipsis, ['a1'], [aap(12)], ['name2', aap(14)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert parsed.has_ellipsis
    assert not parsed.has_ellipsis_parenthesized

    parsed = ParsedExpression('(name1 ... a1 12) name2 14')
    assert len(parsed.identifiers) == 6
    assert len(parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'})) == 2
    assert parsed.composition == [['name1', _ellipsis, 'a1', aap(12)], ['name2'], [aap(14)]]
    assert parsed.has_non_unitary_anonymous_axes
    assert parsed.has_ellipsis
    assert parsed.has_ellipsis_parenthesized
|
brainstate/math/_einops_test.py
DELETED
@@ -1,346 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import numpy as np
|
17
|
-
import pytest
|
18
|
-
|
19
|
-
import brainstate.math as bcm
|
20
|
-
from brainstate.math._einops import einrearrange, einreduce, einrepeat, _enumerate_directions
|
21
|
-
from brainstate.math._einops_parsing import EinopsError
|
22
|
-
|
23
|
-
# Reduction names exercised by the imperative reduction and gradient tests.
REDUCTIONS = tuple("min max sum mean prod".split())

# Patterns that must leave a 5-D input unchanged.
identity_patterns = [
    "...->...",
    "a b c d e-> a b c d e",
    "a b c d e ...-> ... a b c d e",
    "a b c d e ...-> a ... b c d e",
    "... a b c d e -> ... a b c d e",
    "a ... e-> a ... e",
    "a ... -> a ... ",
    "a ... c d e -> a (...) c d e",
]

# (explicit pattern, ellipsis pattern) pairs that must rearrange a 5-D input identically.
equivalent_rearrange_patterns = [
    ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "),
    ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"),
    ("a b c d e -> a b c d e", "... -> ... "),
    ("a b c d e -> (a b c d e)", "... -> (...)"),
    ("a b c d e -> b (c d e) a", "a b ... -> b (...) a"),
    ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"),
]

# (explicit pattern, ellipsis pattern) pairs that must reduce a 5-D input identically.
equivalent_reduction_patterns = [
    ("a b c d e -> ", " ... -> "),
    ("a b c d e -> (e a)", "a ... e -> (e a)"),
    ("a b c d e -> d (a e)", " a b c d e ... -> d (a e) "),
    ("a b c d e -> (a b)", " ... c d e -> (...) "),
]
|
51
|
-
|
52
|
-
|
53
|
-
def test_collapsed_ellipsis_errors_out():
    """A parenthesised ellipsis is accepted on the output side of a pattern
    but rejected on the input side."""
    x = np.zeros([1, 1, 1, 1, 1])
    for accepted, rejected in [
        ("a b c d ... -> a b c ... d", "a b c d (...) -> a b c ... d"),
        ("... -> (...)", "(...) -> (...)"),
    ]:
        einrearrange(x, accepted)
        with pytest.raises(EinopsError):
            einrearrange(x, rejected)
|
62
|
-
|
63
|
-
|
64
|
-
def test_ellipsis_ops_numpy():
    """Ellipsis patterns must behave exactly like their spelled-out equivalents."""
    x = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
    for pattern in identity_patterns:
        assert np.array_equal(x, einrearrange(x, pattern)), pattern

    for pattern1, pattern2 in equivalent_rearrange_patterns:
        assert np.array_equal(einrearrange(x, pattern1),
                              einrearrange(x, pattern2)), (pattern1, pattern2)

    for reduction in ["min", "max", "sum"]:
        for pattern1, pattern2 in equivalent_reduction_patterns:
            assert np.array_equal(einreduce(x, pattern1, reduction=reduction),
                                  einreduce(x, pattern2, reduction=reduction)), (reduction, pattern1, pattern2)
|
81
|
-
|
82
|
-
|
83
|
-
def test_rearrange_consistency_numpy():
    """Rearrangements keep every element, round-trip correctly, and place
    elements where the pattern says."""
    shape = [1, 2, 3, 5, 7, 11]
    x = np.arange(np.prod(shape)).reshape(shape)

    permuting_patterns = (
        "a b c d e f -> a b c d e f",
        "b a c d e f -> a b d e f c",
        "a b c d e f -> f e d c b a",
        "a b c d e f -> (f e) d (c b a)",
        "a b c d e f -> (f e d c b a)",
    )
    for pattern in permuting_patterns:
        rearranged = einrearrange(x, pattern)
        assert len(np.setdiff1d(x, rearranged)) == 0

    # Grouping without reordering keeps the flattened order.
    regrouped = einrearrange(x, "a b c d e f -> a (b) (c d e) f")
    assert np.array_equal(x.flatten(), regrouped.flatten())

    # Multi-character axis names are parsed as whole identifiers.
    unchanged = einrearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11")
    assert np.array_equal(x, unchanged)

    forward = einrearrange(x, "a b c d e f -> f e d c b a")
    backward = einrearrange(x, "f e d c b a -> a b c d e f")
    assert np.array_equal(forward, backward)

    # Round trip through composed axes, with only some sizes given ...
    grouped = einrearrange(x, "a b c d e f -> (f d) c (e b) a")
    restored = einrearrange(grouped, "(f d) c (e b) a -> a b c d e f", b=2, d=5)
    assert np.array_equal(x, restored)

    # ... and with every size given.
    sizes = dict(zip("abcdef", shape))
    grouped = einrearrange(x, "a b c d e f -> (f d) c (e b) a", **sizes)
    restored = einrearrange(grouped, "(f d) c (e b) a -> a b c d e f", **sizes)
    assert np.array_equal(x, restored)

    small = np.arange(2 * 3 * 4).reshape([2, 3, 4])
    moved = einrearrange(small, "a b c -> b c a")
    assert small[1, 2, 3] == moved[2, 3, 1]
    assert small[0, 1, 2] == moved[1, 2, 0]
|
118
|
-
|
119
|
-
|
120
|
-
def test_rearrange_permutations_numpy():
    """Random axis permutations agree with two independent numpy computations."""
    # Seeded generator: failures are reproducible and global RNG state is untouched.
    rng = np.random.default_rng(0)

    # 1) Compare sampled elements by direct indexing.
    for n_axes in range(1, 10):
        x = np.arange(2 ** n_axes).reshape([2] * n_axes)
        permutation = rng.permutation(n_axes)
        left_expression = " ".join(f"i{axis}" for axis in range(n_axes))
        right_expression = " ".join(f"i{axis}" for axis in permutation)
        result = einrearrange(x, f"{left_expression} -> {right_expression}")

        for pick in rng.integers(0, 2, size=(10, n_axes)):
            assert x[tuple(pick)] == result[tuple(pick[permutation])]

    # 2) Compare the whole result against a bit-manipulation reconstruction.
    for n_axes in range(1, 10):
        x = np.arange(2 ** n_axes).reshape([2] * n_axes)
        permutation = rng.permutation(n_axes)
        left_expression = " ".join(f"i{axis}" for axis in reversed(range(n_axes)))
        right_expression = " ".join(f"i{axis}" for axis in permutation[::-1])
        result = einrearrange(x, f"{left_expression} -> {right_expression}")
        assert result.shape == x.shape

        # With names listed in reverse, axis ``i{k}`` indexes bit k of the
        # stored value, so the expected result can be built with shifts.
        expected_result = np.zeros_like(x)
        for original_axis, result_axis in enumerate(permutation):
            expected_result |= ((x >> original_axis) & 1) << result_axis

        assert np.array_equal(result, expected_result)
|
146
|
-
|
147
|
-
|
148
|
-
def test_reduction_imperatives():
    """einreduce on brainstate arrays matches the corresponding numpy reductions."""
    for reduction in REDUCTIONS:
        x = np.arange(2 * 3 * 4 * 5 * 6, dtype="int64").reshape([2, 3, 4, 5, 6])
        if reduction in ["mean", "prod"]:
            # use floats of order one for the mean/prod reductions
            x = x / x.astype("float64").mean()

        reduce_all = getattr(x, reduction)()
        # reduce axes b and d, leaving (a, c, e), then reorder to (e, c, a)
        reduce_bd = getattr(x, reduction)(axis=(1, 3)).transpose(2, 1, 0)
        test_cases = [
            ["a b c d e -> ", {}, reduce_all],
            ["a ... -> ", {}, reduce_all],
            ["(a1 a2) ... (e1 e2) -> ", dict(a1=1, e2=2), reduce_all],
            ["a b c d e -> (e c) a", {}, reduce_bd.reshape([-1, 2])],
            ["a ... c d e -> (e c) a", {}, reduce_bd.reshape([-1, 2])],
            ["a b c d e ... -> (e c) a", {}, reduce_bd.reshape([-1, 2])],
            ["a b c d e -> (e c a)", {}, reduce_bd.reshape([-1])],
            ["(a a2) ... -> (a2 a) ...", dict(a2=1), x],
        ]
        for pattern, axes_lengths, expected_result in test_cases:
            result = einreduce(bcm.from_numpy(x.copy()), pattern, reduction=reduction, **axes_lengths)
            result = bcm.as_numpy(result)
            assert np.allclose(result, expected_result), f"Failed at {reduction}: {pattern}"
|
181
|
-
|
182
|
-
|
183
|
-
def test_enumerating_directions():
    """_enumerate_directions gives matching results for numpy and converted arrays."""
    for shape in ([], [1], [1, 1, 1], [2, 3, 5, 7]):
        x = np.arange(np.prod(shape)).reshape(shape)
        numpy_axes = _enumerate_directions(x)
        converted_axes = _enumerate_directions(bcm.from_numpy(x))
        assert len(numpy_axes) == len(converted_axes) == len(shape)
        for expected, actual in zip(numpy_axes, converted_axes):
            actual = bcm.as_numpy(actual)
            assert expected.shape == actual.shape
            assert np.allclose(expected, actual)
|
193
|
-
|
194
|
-
|
195
|
-
def test_concatenations_and_stacking():
    """A list of arrays passed to einrearrange is stacked like ``np.asarray``,
    for both numpy and converted arrays."""
    shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
    for n_arrays in (1, 2, 5):
        for shape in shapes:
            numpy_arrays = [np.arange(i, i + np.prod(shape)).reshape(shape) for i in range(n_arrays)]
            converted_arrays = list(map(bcm.from_numpy, numpy_arrays))
            stacked = np.asarray(numpy_arrays)

            identity_np = einrearrange(numpy_arrays, "...->...")
            identity_converted = einrearrange(converted_arrays, "...->...")
            assert np.array_equal(stacked, identity_np)
            assert np.array_equal(identity_np, bcm.as_numpy(identity_converted))

            moved_np = einrearrange(numpy_arrays, "b ... -> ... b")
            moved_converted = einrearrange(converted_arrays, "b ... -> ... b")
            assert np.array_equal(moved_np, bcm.as_numpy(moved_converted))
|
210
|
-
|
211
|
-
|
212
|
-
def test_gradients_imperatives():
    """Gradients flow correctly through a chain of ``einreduce`` calls.

    The chain collapses the whole input to a scalar, using the same reduction
    at every step. The gradient with respect to the input therefore has a
    closed form for each reduction.
    """
    import jax  # local import: only this test needs autodiff

    x = np.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype("float32")

    def reduce_chain(y0, reduction):
        y1 = einreduce(y0, "a b c -> c a", reduction=reduction)
        y2 = einreduce(y1, "c a -> a c", reduction=reduction)
        y3 = einreduce(y2, "a (c1 c2) -> a", reduction=reduction, c1=2)
        return einreduce(y3, "... -> ", reduction=reduction)

    # x holds distinct values, so max/min are unique and their gradients are one-hot.
    one_hot_max = np.zeros_like(x)
    one_hot_max[np.unravel_index(np.argmax(x), x.shape)] = 1.0
    one_hot_min = np.zeros_like(x)
    one_hot_min[np.unravel_index(np.argmin(x), x.shape)] = 1.0
    x64 = x.astype("float64")
    expected_grads = {
        "min": one_hot_min,
        "max": one_hot_max,
        "sum": np.ones_like(x),
        # every step averages equally sized groups, so the chain is the global mean
        "mean": np.full_like(x, 1.0 / x.size),
        # d prod(x) / d x_i = prod(x) / x_i (x has no zeros)
        "prod": np.prod(x64) / x64,
    }

    for reduction in REDUCTIONS:
        grad_fn = jax.grad(lambda y, r=reduction: reduce_chain(y, r))
        grad = bcm.as_numpy(grad_fn(bcm.from_numpy(x)))
        assert grad.shape == x.shape, reduction
        assert np.allclose(grad, expected_grads[reduction], rtol=1e-3), reduction
|
229
|
-
|
230
|
-
|
231
|
-
def test_tiling_imperatives():
    """Tiling a converted array gives the same result as tiling the numpy original."""
    x = np.arange(2 * 3 * 5, dtype="int64").reshape([2, 1, 3, 1, 5])
    test_cases = [
        (1, 1, 1, 1, 1),
        (1, 2, 1, 3, 1),
        (3, 1, 1, 4, 1),
    ]
    # The conversion does not depend on ``repeats``, so do it once.
    converted = bcm.from_numpy(x)
    for repeats in test_cases:
        expected = np.tile(x, repeats)
        # NOTE(review): np.tile is applied to the converted array too, so this
        # mainly checks the from_numpy/as_numpy round trip rather than a
        # brainstate-native tile.
        result = bcm.as_numpy(np.tile(converted, repeats))
        assert np.array_equal(result, expected), repeats
|
244
|
-
|
245
|
-
|
246
|
-
# (repeat pattern, axis sizes) pairs; every pattern assumes an input of shape [2, 3, 5].
repeat_test_cases = [
    ("a b c -> c a b", {}),
    ("a b c -> (c copy a b)", {"copy": 2, "a": 2, "b": 3, "c": 5}),
    ("a b c -> (a copy) b c ", {"copy": 1}),
    ("a b c -> (c a) (copy1 b copy2)", {"a": 2, "copy1": 1, "copy2": 2}),
    ("a ... -> a ... copy", {"copy": 4}),
    ("... c -> ... (copy1 c copy2)", {"copy1": 1, "copy2": 2}),
    ("... -> ... ", {}),
    (" ... -> copy1 ... copy2 ", {"copy1": 2, "copy2": 3}),
    ("a b c -> copy1 a copy2 b c () ", {"copy1": 2, "copy2": 1}),
]
|
258
|
-
|
259
|
-
|
260
|
-
def check_reversion(x, repeat_pattern, **sizes):
    """Repeat ``x`` with ``repeat_pattern``, then reduce it back with the
    mirrored pattern. Both the min and max reductions must recover ``x``."""
    source_side, target_side = repeat_pattern.split("->")
    inverse_pattern = "->".join([target_side, source_side])
    repeated = einrepeat(x, repeat_pattern, **sizes)
    restored = {
        reduction: einreduce(repeated, inverse_pattern, reduction=reduction, **sizes)
        for reduction in ("min", "max")
    }
    for reduction in ("min", "max"):
        assert np.array_equal(x, restored[reduction])
|
269
|
-
|
270
|
-
|
271
|
-
def test_repeat_numpy():
    """einrepeat is verified by reducing its output back with min and max."""
    x = np.arange(2 * 3 * 5).reshape([2, 3, 5])
    with_copy_axis = einrepeat(x, "a b c -> copy a b c ", copy=1)
    assert np.array_equal(x[np.newaxis], with_copy_axis)
    for pattern, axis_dimensions in repeat_test_cases:
        check_reversion(x, pattern, **axis_dimensions)
|
278
|
-
|
279
|
-
|
280
|
-
# (repeat pattern, axis sizes) pairs using numeric axes; every pattern assumes
# an input of shape [1, 2, 4, 6].
test_cases_repeat_anonymous = [
    ("a b c d -> c a d b", {}),
    ("a b c d -> (c 2 d a b)", {"a": 1, "c": 4, "d": 6}),
    ("1 b c d -> (d copy 1) 3 b c ", {"copy": 3}),
    ("1 ... -> 3 ... ", {}),
    ("() ... d -> 1 (copy1 d copy2) ... ", {"copy1": 2, "copy2": 3}),
    ("1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)", {}),
]
|
289
|
-
|
290
|
-
|
291
|
-
def test_anonymous_axes():
    """Repeat patterns containing numeric axes can be reversed by reduction."""
    x = np.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6])
    for case in test_cases_repeat_anonymous:
        pattern, axis_dimensions = case
        check_reversion(x, pattern, **axis_dimensions)
|
295
|
-
|
296
|
-
|
297
|
-
def test_list_inputs():
    """A list of sub-arrays gives the same result as the stacked array."""
    x = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
    checks = [
        (einrearrange, ("... -> (...)",), {}),
        (einreduce, ("a ... e -> (...)", "min"), {}),
        (einrepeat, ("... -> b (...)",), {"b": 3}),
    ]
    for operation, args, kwargs in checks:
        from_list = operation(list(x), *args, **kwargs)
        from_array = operation(x, *args, **kwargs)
        assert np.array_equal(from_list, from_array)
|
312
|
-
|
313
|
-
|
314
|
-
def bit_count(x):
    """Return the number of set bits in the non-negative integer ``x``.

    Uses ``bin`` so there is no fixed bit-width limit (scanning a fixed 20
    bits would silently under-count for ``x >= 2**20``). For negative ``x``,
    ``bin`` counts the bits of ``abs(x)``.
    """
    return bin(x).count("1")
|
316
|
-
|
317
|
-
|
318
|
-
def test_reduction_imperatives_booleans():
    """einreduce supports the boolean ``any`` and ``all`` reductions."""
    x_np = np.asarray([bit_count(v) % 2 == 0 for v in range(2 ** 6)]).reshape([2] * 6)

    def reduce_both(pattern):
        res_any = einreduce(bcm.from_numpy(x_np), pattern, reduction="any")
        res_all = einreduce(bcm.from_numpy(x_np), pattern, reduction="all")
        return res_any, res_all

    letters = "abcdef"
    for axis in range(6):
        expected_any = np.any(x_np, axis=axis, keepdims=True)
        expected_all = np.all(x_np, axis=axis, keepdims=True)
        # sanity check: any and all must disagree, otherwise the test proves nothing
        assert not np.array_equal(expected_any, expected_all)

        output_axes = ["1" if position == axis else name for position, name in enumerate(letters)]
        pattern = " ".join(letters) + " -> " + " ".join(output_axes)

        res_any, res_all = reduce_both(pattern)
        assert np.array_equal(expected_any, bcm.as_numpy(res_any))
        assert np.array_equal(expected_all, bcm.as_numpy(res_all))

    # Reduce the first two axes at once, keeping them as size-1 axes.
    expected_any = np.any(x_np, axis=(0, 1), keepdims=True)
    expected_all = np.all(x_np, axis=(0, 1), keepdims=True)
    res_any, res_all = reduce_both("a b ... -> 1 1 ...")
    assert np.array_equal(expected_any, bcm.as_numpy(res_any))
    assert np.array_equal(expected_all, bcm.as_numpy(res_all))
|