brainstate 0.0.1.post20240612__py2.py3-none-any.whl → 0.0.1.post20240623__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. brainstate/__init__.py +4 -5
  2. brainstate/_module.py +147 -42
  3. brainstate/_module_test.py +95 -21
  4. brainstate/environ.py +0 -1
  5. brainstate/functional/__init__.py +2 -2
  6. brainstate/functional/_activations.py +7 -26
  7. brainstate/functional/_spikes.py +0 -1
  8. brainstate/mixin.py +2 -2
  9. brainstate/nn/_elementwise.py +5 -4
  10. brainstate/nn/_misc.py +4 -3
  11. brainstate/nn/_others.py +3 -2
  12. brainstate/nn/_poolings.py +21 -20
  13. brainstate/nn/_poolings_test.py +4 -4
  14. brainstate/optim/__init__.py +0 -1
  15. brainstate/optim/_sgd_optimizer.py +18 -17
  16. brainstate/transform/__init__.py +2 -3
  17. brainstate/transform/_autograd.py +1 -1
  18. brainstate/transform/_autograd_test.py +0 -2
  19. brainstate/transform/_jit_test.py +0 -3
  20. brainstate/transform/_make_jaxpr.py +0 -1
  21. brainstate/transform/_make_jaxpr_test.py +0 -2
  22. brainstate/transform/_progress_bar.py +1 -3
  23. brainstate/util.py +0 -1
  24. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/METADATA +2 -12
  25. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/RECORD +28 -35
  26. brainstate/math/__init__.py +0 -21
  27. brainstate/math/_einops.py +0 -787
  28. brainstate/math/_einops_parsing.py +0 -169
  29. brainstate/math/_einops_parsing_test.py +0 -126
  30. brainstate/math/_einops_test.py +0 -346
  31. brainstate/math/_misc.py +0 -298
  32. brainstate/math/_misc_test.py +0 -58
  33. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/LICENSE +0 -0
  34. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/WHEEL +0 -0
  35. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/top_level.txt +0 -0
brainstate/math/_misc.py DELETED
@@ -1,298 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
- from typing import Optional, Sequence
18
-
19
- import jax
20
- import jax.numpy as jnp
21
- import numpy as np
22
-
23
- from brainstate import environ
24
- from brainstate._utils import set_module_as
25
-
26
-
27
- __all__ = [
28
- 'get_dtype',
29
- 'is_float',
30
- 'is_int',
31
- 'exprel',
32
- 'flatten',
33
- 'unflatten',
34
- 'remove_diag',
35
- 'clip_by_norm',
36
- 'from_numpy',
37
- 'as_numpy',
38
- 'tree_zeros_like',
39
- 'tree_ones_like',
40
- ]
41
-
42
-
43
- @set_module_as('brainstate.math')
44
- def get_dtype(a):
45
- """
46
- Get the dtype of a.
47
- """
48
- if hasattr(a, 'dtype'):
49
- return a.dtype
50
- else:
51
- if isinstance(a, bool):
52
- return bool
53
- elif isinstance(a, int):
54
- return environ.ditype()
55
- elif isinstance(a, float):
56
- return environ.dftype()
57
- elif isinstance(a, complex):
58
- return environ.dctype()
59
- else:
60
- raise ValueError(f'Can not get dtype of {a}.')
61
-
62
-
63
- @set_module_as('brainstate.math')
64
- def is_float(array):
65
- """
66
- Check if the array is a floating point array.
67
-
68
- Args:
69
- array: The input array.
70
-
71
- Returns:
72
- A boolean value indicating if the array is a floating point array.
73
- """
74
- return jnp.issubdtype(get_dtype(array), jnp.floating)
75
-
76
-
77
- @set_module_as('brainstate.math')
78
- def is_int(array):
79
- """
80
- Check if the array is an integer array.
81
-
82
- Args:
83
- array: The input array.
84
-
85
- Returns:
86
- A boolean value indicating if the array is an integer array.
87
- """
88
- return jnp.issubdtype(get_dtype(array), jnp.integer)
89
-
90
-
91
- @set_module_as('brainstate.math')
92
- def exprel(x):
93
- """
94
- Relative error exponential, ``(exp(x) - 1)/x``.
95
-
96
- When ``x`` is near zero, ``exp(x)`` is near 1, so the numerical calculation of ``exp(x) - 1`` can
97
- suffer from catastrophic loss of precision. ``exprel(x)`` is implemented to avoid the loss of
98
- precision that occurs when ``x`` is near zero.
99
-
100
- Args:
101
- x: ndarray. Input array. ``x`` must contain real numbers.
102
-
103
- Returns:
104
- ``(exp(x) - 1)/x``, computed element-wise.
105
- """
106
-
107
- # following the implementation of exprel from scipy.special
108
- x = jnp.asarray(x)
109
- dtype = x.dtype
110
-
111
- # Adjust the tolerance based on the dtype of x
112
- if dtype == jnp.float64:
113
- small_threshold = 1e-16
114
- big_threshold = 717
115
- elif dtype == jnp.float32:
116
- small_threshold = 1e-8
117
- big_threshold = 100
118
- elif dtype == jnp.float16:
119
- small_threshold = 1e-4
120
- big_threshold = 10
121
- else:
122
- small_threshold = 1e-4
123
- big_threshold = 10
124
-
125
- small = jnp.abs(x) < small_threshold
126
- big = x > big_threshold
127
- origin = jnp.expm1(x) / x
128
- return jnp.where(small, 1.0, jnp.where(big, jnp.inf, origin))
129
-
130
-
131
- @set_module_as('brainstate.math')
132
- def remove_diag(arr):
133
- """Remove the diagonal of the matrix.
134
-
135
- Parameters
136
- ----------
137
- arr: ArrayType
138
- The matrix with the shape of `(M, N)`.
139
-
140
- Returns
141
- -------
142
- arr: Array
143
- The matrix without diagonal which has the shape of `(M, N-1)`.
144
- """
145
- if arr.ndim != 2:
146
- raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.')
147
- eyes = jnp.fill_diagonal(jnp.ones(arr.shape, dtype=bool), False)
148
- return jnp.reshape(arr[eyes], (arr.shape[0], arr.shape[1] - 1))
149
-
150
-
151
- @set_module_as('brainstate.math')
152
- def clip_by_norm(t, clip_norm, axis=None):
153
- """
154
- Clip the tensor by the norm of the tensor.
155
-
156
- Args:
157
- t: The tensor to be clipped.
158
- clip_norm: The maximum norm value.
159
- axis: The axis to calculate the norm. If None, the norm is calculated over the whole tensor.
160
-
161
- Returns:
162
- The clipped tensor.
163
-
164
- """
165
- return jax.tree.map(
166
- lambda l: l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm),
167
- t
168
- )
169
-
170
-
171
- @set_module_as('brainstate.math')
172
- def flatten(
173
- input: jax.typing.ArrayLike,
174
- start_dim: Optional[int] = None,
175
- end_dim: Optional[int] = None
176
- ) -> jax.Array:
177
- """Flattens input by reshaping it into a one-dimensional tensor.
178
- If ``start_dim`` or ``end_dim`` are passed, only dimensions starting
179
- with ``start_dim`` and ending with ``end_dim`` are flattened.
180
- The order of elements in input is unchanged.
181
-
182
- .. note::
183
- Flattening a zero-dimensional tensor will return a one-dimensional view.
184
-
185
- Parameters
186
- ----------
187
- input: Array
188
- The input array.
189
- start_dim: int
190
- the first dim to flatten
191
- end_dim: int
192
- the last dim to flatten
193
-
194
- Returns
195
- -------
196
- out: Array
197
- """
198
- shape = input.shape
199
- ndim = input.ndim
200
- if ndim == 0:
201
- ndim = 1
202
- if start_dim is None:
203
- start_dim = 0
204
- elif start_dim < 0:
205
- start_dim = ndim + start_dim
206
- if end_dim is None:
207
- end_dim = ndim - 1
208
- elif end_dim < 0:
209
- end_dim = ndim + end_dim
210
- end_dim += 1
211
- if start_dim < 0 or start_dim > ndim:
212
- raise ValueError(f'start_dim {start_dim} is out of size.')
213
- if end_dim < 0 or end_dim > ndim:
214
- raise ValueError(f'end_dim {end_dim} is out of size.')
215
- new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int),) + shape[end_dim:]
216
- return jnp.reshape(input, new_shape)
217
-
218
-
219
- @set_module_as('brainstate.math')
220
- def unflatten(x: jax.typing.ArrayLike, dim: int, sizes: Sequence[int]) -> jax.Array:
221
- """
222
- Expands a dimension of the input tensor over multiple dimensions.
223
-
224
- Args:
225
- x: input tensor.
226
- dim: Dimension to be unflattened, specified as an index into ``x.shape``.
227
- sizes: New shape of the unflattened dimension. One of its elements can be -1
228
- in which case the corresponding output dimension is inferred.
229
- Otherwise, the product of ``sizes`` must equal ``input.shape[dim]``.
230
-
231
- Returns:
232
- A tensor with the same data as ``input``, but with ``dim`` split into multiple dimensions.
233
- The returned tensor has one more dimension than the input tensor.
234
- The returned tensor shares the same underlying data with this tensor.
235
- """
236
- assert x.ndim > dim, ('The dimension to be unflattened should '
237
- 'be less than the tensor dimension. '
238
- f'Got {dim} and {x.ndim}.')
239
- shape = x.shape
240
- new_shape = shape[:dim] + tuple(sizes) + shape[dim + 1:]
241
- return jnp.reshape(x, new_shape)
242
-
243
-
244
- @set_module_as('brainstate.math')
245
- def from_numpy(x):
246
- """
247
- Convert the numpy array to jax array.
248
-
249
- Args:
250
- x: The numpy array.
251
-
252
- Returns:
253
- The jax array.
254
- """
255
- return jnp.array(x)
256
-
257
-
258
- @set_module_as('brainstate.math')
259
- def as_numpy(x):
260
- """
261
- Convert the array to numpy array.
262
-
263
- Args:
264
- x: The array.
265
-
266
- Returns:
267
- The numpy array.
268
- """
269
- return np.array(x)
270
-
271
-
272
- @set_module_as('brainstate.math')
273
- def tree_zeros_like(tree):
274
- """
275
- Create a tree with the same structure as the input tree, but with zeros in each leaf.
276
-
277
- Args:
278
- tree: The input tree.
279
-
280
- Returns:
281
- The tree with zeros in each leaf.
282
- """
283
- return jax.tree_map(jnp.zeros_like, tree)
284
-
285
-
286
- @set_module_as('brainstate.math')
287
- def tree_ones_like(tree):
288
- """
289
- Create a tree with the same structure as the input tree, but with ones in each leaf.
290
-
291
- Args:
292
- tree: The input tree.
293
-
294
- Returns:
295
- The tree with ones in each leaf.
296
-
297
- """
298
- return jax.tree_map(jnp.ones_like, tree)
brainstate/math/_misc_test.py DELETED
@@ -1,58 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import jax.numpy as jnp
17
- import numpy as np
18
- from scipy.special import exprel
19
-
20
- import brainstate as bc
21
- from brainstate import math
22
-
23
-
24
- def test_exprel():
25
- np.printoptions(precision=30)
26
-
27
- with bc.environ.context(precision=64):
28
- # Test with float64 input
29
- x = jnp.array([0.0, 1.0, 10.0, 100.0, 717.0, 718.0], dtype=jnp.float64)
30
- # expected = jnp.array([1.0, 1.718281828459045, 2.2025466e+03, jnp.inf, jnp.inf, jnp.inf])
31
- # print(math.exprel(x), exprel(np.asarray(x)))
32
- assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-6)
33
-
34
- with bc.environ.context(precision=32):
35
- # Test with float32 input
36
- x = jnp.array([0.0, 1.0, 10.0, 100.0], dtype=jnp.float32)
37
- # expected = jnp.array([1.0, 1.7182817, 2.2025466e+03, jnp.inf])
38
- # print(math.exprel(x), exprel(np.asarray(x)))
39
- assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-6)
40
-
41
- # Test with float16 input
42
- x = jnp.array([0.0, 1.0, 10.0], dtype=jnp.float16)
43
- # expected = jnp.array([1.0, 1.71875, 2.2025466e+03])
44
- # print(math.exprel(x), exprel(np.asarray(x)))
45
- assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-3)
46
-
47
- # Test with int input
48
- x = jnp.array([0, 1, 10])
49
- # expected = jnp.array([1.0, 1.718281828459045, 2.20254658e+03])
50
- # print(math.exprel(x), exprel(np.asarray(x)))
51
- assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-6)
52
-
53
- with bc.environ.context(precision=64):
54
- # Test with negative input
55
- x = jnp.array([-1.0, -10.0, -100.0], dtype=jnp.float64)
56
- # expected = jnp.array([0.63212055, 0.09999546, 0.01 ])
57
- # print(math.exprel(x), exprel(np.asarray(x)))
58
- assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-6)