brainstate 0.0.1.post20240612__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. brainstate/__init__.py +4 -5
  2. brainstate/_module.py +148 -43
  3. brainstate/_module_test.py +95 -21
  4. brainstate/environ.py +0 -1
  5. brainstate/functional/__init__.py +2 -2
  6. brainstate/functional/_activations.py +7 -26
  7. brainstate/functional/_spikes.py +0 -1
  8. brainstate/mixin.py +2 -2
  9. brainstate/nn/_elementwise.py +5 -4
  10. brainstate/nn/_misc.py +4 -3
  11. brainstate/nn/_others.py +3 -2
  12. brainstate/nn/_poolings.py +21 -20
  13. brainstate/nn/_poolings_test.py +4 -4
  14. brainstate/optim/__init__.py +0 -1
  15. brainstate/optim/_sgd_optimizer.py +18 -17
  16. brainstate/transform/__init__.py +2 -3
  17. brainstate/transform/_autograd.py +1 -1
  18. brainstate/transform/_autograd_test.py +0 -2
  19. brainstate/transform/_jit_test.py +0 -3
  20. brainstate/transform/_make_jaxpr.py +0 -1
  21. brainstate/transform/_make_jaxpr_test.py +0 -2
  22. brainstate/transform/_progress_bar.py +1 -3
  23. brainstate/util.py +0 -1
  24. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +2 -12
  25. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/RECORD +28 -35
  26. brainstate/math/__init__.py +0 -21
  27. brainstate/math/_einops.py +0 -787
  28. brainstate/math/_einops_parsing.py +0 -169
  29. brainstate/math/_einops_parsing_test.py +0 -126
  30. brainstate/math/_einops_test.py +0 -346
  31. brainstate/math/_misc.py +0 -298
  32. brainstate/math/_misc_test.py +0 -58
  33. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
  34. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
  35. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
@@ -1,169 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
- import keyword
18
- import warnings
19
- from typing import List, Optional, Set, Tuple, Union
20
-
21
- _ellipsis: str = '…' # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
22
-
23
-
24
- class EinopsError(Exception):
25
- pass
26
-
27
-
28
class AnonymousAxis(object):
    """Unnamed axis written as an integer literal in an expression (e.g. ``5``).

    No ``__eq__`` is defined, so equality is identity-based: two instances are
    never equal, even when they have the same length.
    """

    def __init__(self, value: str):
        length = int(value)
        self.value = length
        if length == 1:
            raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue')
        if length < 1:
            raise EinopsError('Anonymous axis should have positive length, not {}'.format(length))

    def __repr__(self):
        return f"{self.value}-axis"
41
-
42
-
43
class ParsedExpression:
    """Parsed form of one side of an einops-style expression, e.g. ``'b c (h w)'``.

    After construction the object exposes:

    * ``identifiers`` -- axis names (``str``), ``AnonymousAxis`` objects for
      numeric axes other than 1, and ``_ellipsis`` when ``...`` occurs.
    * ``composition`` -- one entry per top-level element, in order: a list of
      axes for a name, a number or a parenthesized group (``[]`` for a
      length-1 axis), or the bare ``_ellipsis`` string for an unparenthesized
      ``...``.
    * ``has_ellipsis`` -- whether ``...`` appears in the expression.
    * ``has_ellipsis_parenthesized`` -- whether the ellipsis sits inside
      parentheses; ``None`` when no ellipsis was seen.
    * ``has_non_unitary_anonymous_axes`` -- whether a numeric axis > 1 appears.

    These attributes are only written during ``__init__``; the other methods
    are read-only queries.
    """

    def __init__(self, expression: str, *, allow_underscore: bool = False,
                 allow_duplicates: bool = False):
        """Parse ``expression``.

        Args:
          expression: one side of a pattern, e.g. ``'b (h w) ...'``.
          allow_underscore: accept the bare name ``_`` as an axis, and allow it
            to appear more than once.
          allow_duplicates: accept repeated axis names.

        Raises:
          EinopsError: on dots outside a single ``...``, more than one ellipsis,
            unknown characters, nested or unbalanced parentheses, invalid axis
            names, or duplicate names when duplicates are not allowed.
        """
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
        self.identifiers: Set[str] = set()
        # Axes like 2, 3, 4 or 5. Axes of size 1 are exceptional and are
        # replaced with an empty composition.
        self.has_non_unitary_anonymous_axes: bool = False
        # Keeps the structure of composite axes; see the tests for corner cases.
        self.composition: List[Union[List[str], str]] = []
        if '.' in expression:
            if '...' not in expression:
                raise EinopsError('Expression may contain dots only inside ellipsis (...)')
            if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
                raise EinopsError(
                    'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ')
            # Replace the three-dot token with a single character so the
            # tokenizer below can read it like an identifier character.
            expression = expression.replace('...', _ellipsis)
            self.has_ellipsis = True

        # Axes of the parenthesized group currently being read, or None while
        # outside parentheses.
        bracket_group: Optional[List[str]] = None

        def add_axis_name(x):
            # Validate token ``x`` and record it in identifiers and in either
            # the open bracket group or the top-level composition.
            if x in self.identifiers:
                if not (allow_underscore and x == "_") and not allow_duplicates:
                    raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
            else:
                is_number = str.isdecimal(x)
                if is_number and int(x) == 1:
                    # An anonymous axis of length 1 becomes an empty group at
                    # top level and is simply dropped inside parentheses.
                    if bracket_group is None:
                        self.composition.append([])
                    return
                is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
                if not (is_number or is_axis_name):
                    raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
                if is_number:
                    # Each numeric axis is a distinct object, so e.g. '5 5' is
                    # not reported as a duplicate by the check above.
                    x = AnonymousAxis(x)
                self.identifiers.add(x)
                if is_number:
                    self.has_non_unitary_anonymous_axes = True
                if bracket_group is None:
                    self.composition.append([x])
                else:
                    bracket_group.append(x)

        current_identifier = None
        for char in expression:
            if char in '() ':
                # A delimiter ends the identifier currently being read.
                if current_identifier is not None:
                    add_axis_name(current_identifier)
                current_identifier = None
                if char == '(':
                    if bracket_group is not None:
                        raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
                    bracket_group = []
                elif char == ')':
                    if bracket_group is None:
                        raise EinopsError('Brackets are not balanced')
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ['_', _ellipsis]:
                # NOTE(review): a literal '…' typed directly in the input is also
                # accepted here as an ellipsis, even though has_ellipsis is only
                # set for '...' above -- confirm whether that is intended.
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise EinopsError("Unknown character '{}'".format(char))

        if bracket_group is not None:
            raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
        if current_identifier is not None:
            add_axis_name(current_identifier)

    def flat_axes_order(self) -> List:
        """Return every axis of ``composition`` in order, with groups flattened.

        Only valid when no unparenthesized ellipsis is present, i.e. every
        composition entry is a list; this is enforced with an ``assert``.
        """
        result = []
        for composed_axis in self.composition:
            assert isinstance(composed_axis, list), 'does not work with ellipsis'
            result.extend(composed_axis)
        return result

    def has_composed_axes(self) -> bool:
        """Return True if some parenthesized group holds more than one axis.

        Length-1 axes inside parentheses are dropped while parsing, so a group
        such as ``(1 a)`` does not count as composed.
        """
        for axes in self.composition:
            if isinstance(axes, list) and len(axes) > 1:
                return True
        return False

    @staticmethod
    def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
        """Validate ``name`` as an axis name.

        Returns ``(True, '')`` when valid, otherwise ``(False, reason)``. A name
        is valid when it is a Python identifier that neither starts nor ends
        with ``_``. The bare name ``_`` is accepted only when
        ``allow_underscore`` is true. Python keywords and the name ``'axis'``
        are accepted, but emit a ``RuntimeWarning`` or ``FutureWarning``
        respectively.
        """
        if not str.isidentifier(name):
            return False, 'not a valid python identifier'
        elif name[0] == '_' or name[-1] == '_':
            if name == '_' and allow_underscore:
                return True, ''
            return False, 'axis name should not start or end with underscore'
        else:
            if keyword.iskeyword(name):
                warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
            if name in ['axis']:
                warnings.warn("It is discouraged to use 'axis' as an axis name "
                              "and will raise an error in future", FutureWarning)
            return True, ''

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """Return True if ``name`` is a valid axis name.

        Valid names are Python identifiers that neither start nor end with an
        underscore (so the bare ``_`` is rejected here). Keywords are accepted,
        but a ``RuntimeWarning`` is emitted for them.
        """
        is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid
@@ -1,126 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import pytest
17
-
18
- from brainstate.math._einops_parsing import EinopsError, ParsedExpression, AnonymousAxis, _ellipsis
19
-
20
-
21
class AnonymousAxisPlaceholder:
    """Test helper that equals any ``AnonymousAxis`` with the same length.

    Real ``AnonymousAxis`` objects never compare equal to each other, so the
    expected ``composition`` values in these tests use this placeholder.
    """

    def __init__(self, value: int):
        self.value = value
        assert isinstance(self.value, int)

    def __eq__(self, other):
        if not isinstance(other, AnonymousAxis):
            return False
        return self.value == other.value
28
-
29
-
30
def test_anonymous_axes():
    """AnonymousAxis objects are all distinct; placeholders match them by length."""
    first = AnonymousAxis('2')
    second = AnonymousAxis('2')
    assert first != second

    two = AnonymousAxisPlaceholder(2)
    three = AnonymousAxisPlaceholder(3)
    assert first == two and second == two
    assert first != three and second != three
    assert [first, 2, second] == [two, 2, two]
37
-
38
-
39
def test_elementary_axis_name():
    """check_axis_name accepts ordinary identifiers and rejects everything else."""
    accepted = (
        'a', 'b', 'h', 'dx', 'h1', 'zz', 'i9123', 'somelongname',
        'Alex', 'camelCase', 'u_n_d_e_r_score', 'unreasonablyLongAxisName',
    )
    rejected = ('', '2b', '12', '_startWithUnderscore', 'endWithUnderscore_', '_', '...', _ellipsis)

    for candidate in accepted:
        assert ParsedExpression.check_axis_name(candidate), candidate
    for candidate in rejected:
        assert not ParsedExpression.check_axis_name(candidate), candidate
46
-
47
-
48
def test_invalid_expressions():
    """Malformed expressions raise EinopsError; nearby valid ones still parse."""
    # Each entry: an expression that must parse, then variants that must fail.
    groups = [
        # double ellipsis should raise an error
        ('... a b c d', [
            '... a b c d ...',
            '... a b c (d ...)',
            '(... a) b c (d ...)',
        ]),
        # double/missing/enclosed parenthesis
        ('(a) b c (d ...)', [
            '(a)) b c (d ...)',
            '(a b c (d ...)',
            '(a) (()) b c (d ...)',
            '(a) ((b c) (d ...))',
        ]),
        # invalid identifiers
        ('camelCase under_scored cApiTaLs ß ...', [
            '1a',
            '_pre',
            '...pre',
            'pre...',
        ]),
    ]
    for well_formed, broken_variants in groups:
        ParsedExpression(well_formed)
        for broken in broken_variants:
            with pytest.raises(EinopsError):
                ParsedExpression(broken)
79
-
80
-
81
def test_parse_expression():
    """Checks identifiers, composition and flags for representative expressions."""
    anon = AnonymousAxisPlaceholder

    # plain named axes
    p = ParsedExpression('a1 b1 c1 d1')
    assert p.identifiers == {'a1', 'b1', 'c1', 'd1'}
    assert p.composition == [['a1'], ['b1'], ['c1'], ['d1']]
    assert not p.has_non_unitary_anonymous_axes
    assert not p.has_ellipsis

    # empty groups and unit axes both give empty compositions
    for source in ('() () () ()', '1 1 1 ()'):
        p = ParsedExpression(source)
        assert p.identifiers == set()
        assert p.composition == [[], [], [], []]
        assert not p.has_non_unitary_anonymous_axes
        assert not p.has_ellipsis

    # numeric axes turn into distinct anonymous axes
    p = ParsedExpression('5 (3 4)')
    assert len(p.identifiers) == 3 and {axis.value for axis in p.identifiers} == {3, 4, 5}
    assert p.composition == [[anon(5)], [anon(3), anon(4)]]
    assert p.has_non_unitary_anonymous_axes
    assert not p.has_ellipsis

    # ones vanish inside groups and become [] at top level
    p = ParsedExpression('5 1 (1 4) 1')
    assert len(p.identifiers) == 2 and {axis.value for axis in p.identifiers} == {4, 5}
    assert p.composition == [[anon(5)], [], [anon(4)], []]

    named = {'name1', _ellipsis, 'a1', 'name2'}

    # unparenthesized ellipsis
    p = ParsedExpression('name1 ... a1 12 (name2 14)')
    assert len(p.identifiers) == 6
    assert len(p.identifiers - named) == 2
    assert p.composition == [['name1'], _ellipsis, ['a1'], [anon(12)], ['name2', anon(14)]]
    assert p.has_non_unitary_anonymous_axes
    assert p.has_ellipsis
    assert not p.has_ellipsis_parenthesized

    # parenthesized ellipsis
    p = ParsedExpression('(name1 ... a1 12) name2 14')
    assert len(p.identifiers) == 6
    assert len(p.identifiers - named) == 2
    assert p.composition == [['name1', _ellipsis, 'a1', anon(12)], ['name2'], [anon(14)]]
    assert p.has_non_unitary_anonymous_axes
    assert p.has_ellipsis
    assert p.has_ellipsis_parenthesized
@@ -1,346 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import numpy as np
17
- import pytest
18
-
19
- import brainstate.math as bcm
20
- from brainstate.math._einops import einrearrange, einreduce, einrepeat, _enumerate_directions
21
- from brainstate.math._einops_parsing import EinopsError
22
-
23
# Reduction names exercised by the reduction/gradient tests below.
REDUCTIONS = ("min", "max", "sum", "mean", "prod")

# Patterns that must leave a 5-d array unchanged.
identity_patterns = [
    "...->...",
    "a b c d e-> a b c d e",
    "a b c d e ...-> ... a b c d e",
    "a b c d e ...-> a ... b c d e",
    "... a b c d e -> ... a b c d e",
    "a ... e-> a ... e",
    "a ... -> a ... ",
    "a ... c d e -> a (...) c d e",
]

# Pairs of rearrange patterns (explicit vs. ellipsis form) that must agree
# on a 5-d array.
equivalent_rearrange_patterns = [
    ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "),
    ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"),
    ("a b c d e -> a b c d e", "... -> ... "),
    ("a b c d e -> (a b c d e)", "... -> (...)"),
    ("a b c d e -> b (c d e) a", "a b ... -> b (...) a"),
    ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"),
]

# Pairs of reduce patterns (explicit vs. ellipsis form) that must agree
# on a 5-d array.
equivalent_reduction_patterns = [
    ("a b c d e -> ", " ... -> "),
    ("a b c d e -> (e a)", "a ... e -> (e a)"),
    ("a b c d e -> d (a e)", " a b c d e ... -> d (a e) "),
    ("a b c d e -> (a b)", " ... c d e -> (...) "),
]
51
-
52
-
53
def test_collapsed_ellipsis_errors_out():
    """A parenthesized ellipsis on the input side is rejected."""
    x = np.zeros([1, 1, 1, 1, 1])
    accepted_and_rejected = [
        ("a b c d ... -> a b c ... d", "a b c d (...) -> a b c ... d"),
        ("... -> (...)", "(...) -> (...)"),
    ]
    for good_pattern, bad_pattern in accepted_and_rejected:
        einrearrange(x, good_pattern)
        with pytest.raises(EinopsError):
            einrearrange(x, bad_pattern)
62
-
63
-
64
def test_ellipsis_ops_numpy():
    """Ellipsis patterns agree with their explicit equivalents on numpy input.

    Also checks that every rearrange pattern only moves data around: the
    result must contain exactly the same values as the input.
    """
    x = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
    for pattern in identity_patterns:
        assert np.array_equal(x, einrearrange(x, pattern)), pattern

    for pattern1, pattern2 in equivalent_rearrange_patterns:
        assert np.array_equal(einrearrange(x, pattern1), einrearrange(x, pattern2)), (pattern1, pattern2)

    for reduction in ["min", "max", "sum"]:
        for pattern1, pattern2 in equivalent_reduction_patterns:
            assert np.array_equal(einreduce(x, pattern1, reduction=reduction),
                                  einreduce(x, pattern2, reduction=reduction)), (reduction, pattern1, pattern2)

    # Every rearrange pattern above is a pure permutation/regrouping of x, so
    # the flattened, sorted values must be identical to those of x.
    all_rearrange_patterns = [*identity_patterns]
    for pattern_pairs in equivalent_rearrange_patterns:
        all_rearrange_patterns.extend(pattern_pairs)
    expected_values = np.sort(x, axis=None)
    for pattern in all_rearrange_patterns:
        result = np.asarray(einrearrange(x, pattern))
        assert result.size == x.size, pattern
        assert np.array_equal(np.sort(result, axis=None), expected_values), pattern
81
-
82
-
83
def test_rearrange_consistency_numpy():
    """Rearrangements of numpy input keep every element and invert correctly."""
    shape = [1, 2, 3, 5, 7, 11]
    x = np.arange(np.prod(shape)).reshape(shape)

    permuting_patterns = (
        "a b c d e f -> a b c d e f",
        "b a c d e f -> a b d e f c",
        "a b c d e f -> f e d c b a",
        "a b c d e f -> (f e) d (c b a)",
        "a b c d e f -> (f e d c b a)",
    )
    for pattern in permuting_patterns:
        assert len(np.setdiff1d(x, einrearrange(x, pattern))) == 0

    # Singleton and contiguous groups do not change memory order.
    regrouped = einrearrange(x, "a b c d e f -> a (b) (c d e) f")
    assert np.array_equal(x.flatten(), regrouped.flatten())

    # Multi-character axis names.
    renamed = einrearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11")
    assert np.array_equal(x, renamed)

    # Reversing axes gives the same result whichever side is written reversed.
    assert np.array_equal(einrearrange(x, "a b c d e f -> f e d c b a"),
                          einrearrange(x, "f e d c b a -> a b c d e f"))

    # Round trip through a grouped layout, with minimal and with full sizes.
    forward = "a b c d e f -> (f d) c (e b) a"
    backward = "(f d) c (e b) a -> a b c d e f"
    assert np.array_equal(x, einrearrange(einrearrange(x, forward), backward, b=2, d=5))

    sizes = dict(zip("abcdef", shape))
    assert np.array_equal(x, einrearrange(einrearrange(x, forward, **sizes), backward, **sizes))

    # Spot-check single elements of a cyclic axis permutation.
    x2 = np.arange(2 * 3 * 4).reshape([2, 3, 4])
    cycled = einrearrange(x2, "a b c -> b c a")
    assert x2[1, 2, 3] == cycled[2, 3, 1]
    assert x2[0, 1, 2] == cycled[1, 2, 0]
118
-
119
-
120
def test_rearrange_permutations_numpy():
    """Random axis permutations agree with two independent numpy computations.

    A fixed-seed generator is used so that any failure is reproducible.
    """
    rng = np.random.default_rng(0)

    # 1) Element-wise: result at the permuted index equals x at the original index.
    for n_axes in range(1, 10):
        x = np.arange(2 ** n_axes).reshape([2] * n_axes)
        permutation = rng.permutation(n_axes)
        left_expression = " ".join("i" + str(axis) for axis in range(n_axes))
        right_expression = " ".join("i" + str(axis) for axis in permutation)
        result = einrearrange(x, left_expression + " -> " + right_expression)

        for pick in rng.integers(0, 2, size=(10, n_axes)):
            assert x[tuple(pick)] == result[tuple(pick[permutation])]

    # 2) Bitwise: each value of x encodes its own index in binary, so the
    #    expected result is obtained by moving bits between positions.
    for n_axes in range(1, 10):
        x = np.arange(2 ** n_axes).reshape([2] * n_axes)
        permutation = rng.permutation(n_axes)
        left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1])
        right_expression = " ".join("i" + str(axis) for axis in permutation[::-1])
        result = einrearrange(x, left_expression + " -> " + right_expression)
        assert result.shape == x.shape
        expected_result = np.zeros_like(x)
        for original_axis, result_axis in enumerate(permutation):
            expected_result |= ((x >> original_axis) & 1) << result_axis

        assert np.array_equal(result, expected_result)
146
-
147
-
148
def test_reduction_imperatives():
    """einreduce agrees with the equivalent numpy reduction for every entry of REDUCTIONS."""
    for reduction in REDUCTIONS:
        # The numpy reference array is rebuilt for each reduction.
        x = np.arange(2 * 3 * 4 * 5 * 6, dtype="int64").reshape([2, 3, 4, 5, 6])
        if reduction in ["mean", "prod"]:
            # use floating-point input for mean and prod
            x = x / x.astype("float64").mean()
        reduce_np = getattr(x, reduction)
        # reduce over b and d, then lay out the result as (e, c, a)
        e_c_a = reduce_np(axis=(1, 3)).transpose(2, 1, 0)
        test_cases = [
            ["a b c d e -> ", {}, reduce_np()],
            ["a ... -> ", {}, reduce_np()],
            ["(a1 a2) ... (e1 e2) -> ", dict(a1=1, e2=2), reduce_np()],
            ["a b c d e -> (e c) a", {}, e_c_a.reshape([-1, 2])],
            ["a ... c d e -> (e c) a", {}, e_c_a.reshape([-1, 2])],
            ["a b c d e ... -> (e c) a", {}, e_c_a.reshape([-1, 2])],
            ["a b c d e -> (e c a)", {}, e_c_a.reshape([-1])],
            ["(a a2) ... -> (a2 a) ...", dict(a2=1), x],
        ]
        for pattern, axes_lengths, expected_result in test_cases:
            result = einreduce(bcm.from_numpy(x.copy()), pattern, reduction=reduction, **axes_lengths)
            result = bcm.as_numpy(result)
            assert np.allclose(result, expected_result), f"Failed at {reduction}: {pattern}"
181
-
182
-
183
def test_enumerating_directions():
    """_enumerate_directions gives matching results for numpy and converted arrays."""
    shapes = ([], [1], [1, 1, 1], [2, 3, 5, 7])
    for shape in shapes:
        x = np.arange(np.prod(shape)).reshape(shape)
        numpy_axes = _enumerate_directions(x)
        converted_axes = _enumerate_directions(bcm.from_numpy(x))
        assert len(numpy_axes) == len(converted_axes) == len(shape)
        for np_axis, converted in zip(numpy_axes, converted_axes):
            converted = bcm.as_numpy(converted)
            assert np_axis.shape == converted.shape
            assert np.allclose(np_axis, converted)
193
-
194
-
195
def test_concatenations_and_stacking():
    """A list of arrays is rearranged as if stacked along a new leading axis."""
    shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
    for n_arrays in [1, 2, 5]:
        for shape in shapes:
            size = np.prod(shape)
            numpy_list = [np.arange(i, i + size).reshape(shape) for i in range(n_arrays)]
            converted_list = [bcm.from_numpy(arr) for arr in numpy_list]

            stacked = np.asarray(numpy_list)
            same_numpy = einrearrange(numpy_list, "...->...")
            same_converted = einrearrange(converted_list, "...->...")
            assert np.array_equal(stacked, same_numpy)
            assert np.array_equal(same_numpy, bcm.as_numpy(same_converted))

            moved_numpy = einrearrange(numpy_list, "b ... -> ... b")
            moved_converted = einrearrange(converted_list, "b ... -> ... b")
            assert np.array_equal(moved_numpy, bcm.as_numpy(moved_converted))
210
-
211
-
212
def test_gradients_imperatives():
    """Chained einreduce calls back-propagate a gradient of the input's shape.

    A reduction is skipped when the converted array has no ``.grad``
    attribute, or when it is ``any``/``all``.
    """
    for reduction in REDUCTIONS:
        if reduction in ("any", "all"):
            continue  # non-differentiable ops
        x = np.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype("float32")
        y0 = bcm.from_numpy(x)
        if not hasattr(y0, "grad"):
            continue

        y1 = einreduce(y0, "a b c -> c a", reduction=reduction)
        y2 = einreduce(y1, "c a -> a c", reduction=reduction)
        y3 = einreduce(y2, "a (c1 c2) -> a", reduction=reduction, c1=2)
        y4 = einreduce(y3, "... -> ", reduction=reduction)

        y4.backward()
        grad = bcm.as_numpy(y0.grad)
        # The gradient must reach the input with the input's shape and finite values.
        assert grad.shape == x.shape, reduction
        assert np.all(np.isfinite(grad)), reduction
229
-
230
-
231
def test_tiling_imperatives():
    """np.tile on a converted array matches np.tile on the numpy original."""
    source = np.arange(2 * 3 * 5, dtype="int64").reshape([2, 1, 3, 1, 5])
    repeat_options = [
        (1, 1, 1, 1, 1),
        (1, 2, 1, 3, 1),
        (3, 1, 1, 4, 1),
    ]
    for repeats in repeat_options:
        expected = np.tile(source, repeats)
        actual = bcm.as_numpy(np.tile(bcm.from_numpy(source), repeats))
        assert np.array_equal(actual, expected)
244
-
245
-
246
# (repeat pattern, axis sizes) pairs; each pattern expects an input of shape [2, 3, 5].
repeat_test_cases = [
    ("a b c -> c a b", {}),
    ("a b c -> (c copy a b)", {"copy": 2, "a": 2, "b": 3, "c": 5}),
    ("a b c -> (a copy) b c ", {"copy": 1}),
    ("a b c -> (c a) (copy1 b copy2)", {"a": 2, "copy1": 1, "copy2": 2}),
    ("a ... -> a ... copy", {"copy": 4}),
    ("... c -> ... (copy1 c copy2)", {"copy1": 1, "copy2": 2}),
    ("... -> ... ", {}),
    (" ... -> copy1 ... copy2 ", {"copy1": 2, "copy2": 3}),
    ("a b c -> copy1 a copy2 b c () ", {"copy1": 2, "copy2": 1}),
]
258
-
259
-
260
def check_reversion(x, repeat_pattern, **sizes):
    """Checks repeat pattern by running reduction.

    ``x`` is repeated with ``repeat_pattern`` and then reduced with the
    reversed pattern using both ``min`` and ``max``. Each reduction must
    reproduce ``x`` exactly.
    """
    source_side, target_side = repeat_pattern.split("->")
    reverse_pattern = "->".join([target_side, source_side])
    repeated = einrepeat(x, repeat_pattern, **sizes)
    reduced = {
        reduction: einreduce(repeated, reverse_pattern, reduction=reduction, **sizes)
        for reduction in ("min", "max")
    }
    for reduction in ("min", "max"):
        assert np.array_equal(x, reduced[reduction])
269
-
270
-
271
def test_repeat_numpy():
    """Each repeat pattern is undone by the reversed min/max reduction."""
    x = np.arange(2 * 3 * 5).reshape([2, 3, 5])
    with_copy_axis = einrepeat(x, "a b c -> copy a b c ", copy=1)
    assert np.array_equal(x[np.newaxis], with_copy_axis)
    for pattern, sizes in repeat_test_cases:
        check_reversion(x, pattern, **sizes)
278
-
279
-
280
# (repeat pattern, axis sizes) pairs using numeric axes; each pattern expects
# an input of shape [1, 2, 4, 6].
test_cases_repeat_anonymous = [
    ("a b c d -> c a d b", {}),
    ("a b c d -> (c 2 d a b)", {"a": 1, "c": 4, "d": 6}),
    ("1 b c d -> (d copy 1) 3 b c ", {"copy": 3}),
    ("1 ... -> 3 ... ", {}),
    ("() ... d -> 1 (copy1 d copy2) ... ", {"copy1": 2, "copy2": 3}),
    ("1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)", {}),
]
289
-
290
-
291
def test_anonymous_axes():
    """Repeat patterns with numeric (anonymous) axes round-trip through reduction."""
    x = np.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6])
    for pattern, sizes in test_cases_repeat_anonymous:
        check_reversion(x, pattern, **sizes)
295
-
296
-
297
def test_list_inputs():
    """A list of sub-arrays gives the same result as the array itself."""
    x = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
    as_list = list(x)
    operations = [
        (einrearrange, ("... -> (...)",), {}),
        (einreduce, ("a ... e -> (...)", "min"), {}),
        (einrepeat, ("... -> b (...)",), {"b": 3}),
    ]
    for op, args, kwargs in operations:
        assert np.array_equal(op(as_list, *args, **kwargs), op(x, *args, **kwargs))
312
-
313
-
314
def bit_count(x):
    """Return the number of set bits among the 20 lowest bits of ``x``."""
    total = 0
    for shift in range(20):
        total += (x >> shift) & 1
    return total
316
-
317
-
318
def test_reduction_imperatives_booleans():
    """Checks that any/all reduction works in all frameworks"""
    # 6-d boolean array: True where the flat index has an even number of set bits.
    x_np = np.asarray([bit_count(v) % 2 == 0 for v in range(2 ** 6)]).reshape([2] * 6)

    def compare(pattern, expected_any, expected_all):
        res_any = einreduce(bcm.from_numpy(x_np), pattern, reduction="any")
        res_all = einreduce(bcm.from_numpy(x_np), pattern, reduction="all")
        assert np.array_equal(expected_any, bcm.as_numpy(res_any))
        assert np.array_equal(expected_all, bcm.as_numpy(res_all))

    names = list("abcdef")
    for axis in range(6):
        expected_any = np.any(x_np, axis=axis, keepdims=True)
        expected_all = np.all(x_np, axis=axis, keepdims=True)
        # sanity check: any and all must differ, otherwise the comparison is uninformative
        assert not np.array_equal(expected_any, expected_all)

        # replace the reduced axis with a kept unit axis
        names_out = list(names)
        names_out[axis] = "1"
        compare(" ".join(names) + " -> " + " ".join(names_out), expected_any, expected_all)

    # reduce the two leading axes at once
    compare("a b ... -> 1 1 ...",
            np.any(x_np, axis=(0, 1), keepdims=True),
            np.all(x_np, axis=(0, 1), keepdims=True))