brainstate 0.0.1.post20240612__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -5
- brainstate/_module.py +148 -43
- brainstate/_module_test.py +95 -21
- brainstate/environ.py +0 -1
- brainstate/functional/__init__.py +2 -2
- brainstate/functional/_activations.py +7 -26
- brainstate/functional/_spikes.py +0 -1
- brainstate/mixin.py +2 -2
- brainstate/nn/_elementwise.py +5 -4
- brainstate/nn/_misc.py +4 -3
- brainstate/nn/_others.py +3 -2
- brainstate/nn/_poolings.py +21 -20
- brainstate/nn/_poolings_test.py +4 -4
- brainstate/optim/__init__.py +0 -1
- brainstate/optim/_sgd_optimizer.py +18 -17
- brainstate/transform/__init__.py +2 -3
- brainstate/transform/_autograd.py +1 -1
- brainstate/transform/_autograd_test.py +0 -2
- brainstate/transform/_jit_test.py +0 -3
- brainstate/transform/_make_jaxpr.py +0 -1
- brainstate/transform/_make_jaxpr_test.py +0 -2
- brainstate/transform/_progress_bar.py +1 -3
- brainstate/util.py +0 -1
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +2 -12
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/RECORD +28 -35
- brainstate/math/__init__.py +0 -21
- brainstate/math/_einops.py +0 -787
- brainstate/math/_einops_parsing.py +0 -169
- brainstate/math/_einops_parsing_test.py +0 -126
- brainstate/math/_einops_test.py +0 -346
- brainstate/math/_misc.py +0 -298
- brainstate/math/_misc_test.py +0 -58
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
brainstate/math/_misc.py
DELETED
@@ -1,298 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
from typing import Optional, Sequence
|
18
|
-
|
19
|
-
import jax
|
20
|
-
import jax.numpy as jnp
|
21
|
-
import numpy as np
|
22
|
-
|
23
|
-
from brainstate import environ
|
24
|
-
from brainstate._utils import set_module_as
|
25
|
-
|
26
|
-
|
27
|
-
__all__ = [
|
28
|
-
'get_dtype',
|
29
|
-
'is_float',
|
30
|
-
'is_int',
|
31
|
-
'exprel',
|
32
|
-
'flatten',
|
33
|
-
'unflatten',
|
34
|
-
'remove_diag',
|
35
|
-
'clip_by_norm',
|
36
|
-
'from_numpy',
|
37
|
-
'as_numpy',
|
38
|
-
'tree_zeros_like',
|
39
|
-
'tree_ones_like',
|
40
|
-
]
|
41
|
-
|
42
|
-
|
43
|
-
@set_module_as('brainstate.math')
|
44
|
-
def get_dtype(a):
|
45
|
-
"""
|
46
|
-
Get the dtype of a.
|
47
|
-
"""
|
48
|
-
if hasattr(a, 'dtype'):
|
49
|
-
return a.dtype
|
50
|
-
else:
|
51
|
-
if isinstance(a, bool):
|
52
|
-
return bool
|
53
|
-
elif isinstance(a, int):
|
54
|
-
return environ.ditype()
|
55
|
-
elif isinstance(a, float):
|
56
|
-
return environ.dftype()
|
57
|
-
elif isinstance(a, complex):
|
58
|
-
return environ.dctype()
|
59
|
-
else:
|
60
|
-
raise ValueError(f'Can not get dtype of {a}.')
|
61
|
-
|
62
|
-
|
63
|
-
@set_module_as('brainstate.math')
|
64
|
-
def is_float(array):
|
65
|
-
"""
|
66
|
-
Check if the array is a floating point array.
|
67
|
-
|
68
|
-
Args:
|
69
|
-
array: The input array.
|
70
|
-
|
71
|
-
Returns:
|
72
|
-
A boolean value indicating if the array is a floating point array.
|
73
|
-
"""
|
74
|
-
return jnp.issubdtype(get_dtype(array), jnp.floating)
|
75
|
-
|
76
|
-
|
77
|
-
@set_module_as('brainstate.math')
|
78
|
-
def is_int(array):
|
79
|
-
"""
|
80
|
-
Check if the array is an integer array.
|
81
|
-
|
82
|
-
Args:
|
83
|
-
array: The input array.
|
84
|
-
|
85
|
-
Returns:
|
86
|
-
A boolean value indicating if the array is an integer array.
|
87
|
-
"""
|
88
|
-
return jnp.issubdtype(get_dtype(array), jnp.integer)
|
89
|
-
|
90
|
-
|
91
|
-
@set_module_as('brainstate.math')
|
92
|
-
def exprel(x):
|
93
|
-
"""
|
94
|
-
Relative error exponential, ``(exp(x) - 1)/x``.
|
95
|
-
|
96
|
-
When ``x`` is near zero, ``exp(x)`` is near 1, so the numerical calculation of ``exp(x) - 1`` can
|
97
|
-
suffer from catastrophic loss of precision. ``exprel(x)`` is implemented to avoid the loss of
|
98
|
-
precision that occurs when ``x`` is near zero.
|
99
|
-
|
100
|
-
Args:
|
101
|
-
x: ndarray. Input array. ``x`` must contain real numbers.
|
102
|
-
|
103
|
-
Returns:
|
104
|
-
``(exp(x) - 1)/x``, computed element-wise.
|
105
|
-
"""
|
106
|
-
|
107
|
-
# following the implementation of exprel from scipy.special
|
108
|
-
x = jnp.asarray(x)
|
109
|
-
dtype = x.dtype
|
110
|
-
|
111
|
-
# Adjust the tolerance based on the dtype of x
|
112
|
-
if dtype == jnp.float64:
|
113
|
-
small_threshold = 1e-16
|
114
|
-
big_threshold = 717
|
115
|
-
elif dtype == jnp.float32:
|
116
|
-
small_threshold = 1e-8
|
117
|
-
big_threshold = 100
|
118
|
-
elif dtype == jnp.float16:
|
119
|
-
small_threshold = 1e-4
|
120
|
-
big_threshold = 10
|
121
|
-
else:
|
122
|
-
small_threshold = 1e-4
|
123
|
-
big_threshold = 10
|
124
|
-
|
125
|
-
small = jnp.abs(x) < small_threshold
|
126
|
-
big = x > big_threshold
|
127
|
-
origin = jnp.expm1(x) / x
|
128
|
-
return jnp.where(small, 1.0, jnp.where(big, jnp.inf, origin))
|
129
|
-
|
130
|
-
|
131
|
-
@set_module_as('brainstate.math')
|
132
|
-
def remove_diag(arr):
|
133
|
-
"""Remove the diagonal of the matrix.
|
134
|
-
|
135
|
-
Parameters
|
136
|
-
----------
|
137
|
-
arr: ArrayType
|
138
|
-
The matrix with the shape of `(M, N)`.
|
139
|
-
|
140
|
-
Returns
|
141
|
-
-------
|
142
|
-
arr: Array
|
143
|
-
The matrix without diagonal which has the shape of `(M, N-1)`.
|
144
|
-
"""
|
145
|
-
if arr.ndim != 2:
|
146
|
-
raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.')
|
147
|
-
eyes = jnp.fill_diagonal(jnp.ones(arr.shape, dtype=bool), False)
|
148
|
-
return jnp.reshape(arr[eyes], (arr.shape[0], arr.shape[1] - 1))
|
149
|
-
|
150
|
-
|
151
|
-
@set_module_as('brainstate.math')
|
152
|
-
def clip_by_norm(t, clip_norm, axis=None):
|
153
|
-
"""
|
154
|
-
Clip the tensor by the norm of the tensor.
|
155
|
-
|
156
|
-
Args:
|
157
|
-
t: The tensor to be clipped.
|
158
|
-
clip_norm: The maximum norm value.
|
159
|
-
axis: The axis to calculate the norm. If None, the norm is calculated over the whole tensor.
|
160
|
-
|
161
|
-
Returns:
|
162
|
-
The clipped tensor.
|
163
|
-
|
164
|
-
"""
|
165
|
-
return jax.tree.map(
|
166
|
-
lambda l: l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm),
|
167
|
-
t
|
168
|
-
)
|
169
|
-
|
170
|
-
|
171
|
-
@set_module_as('brainstate.math')
|
172
|
-
def flatten(
|
173
|
-
input: jax.typing.ArrayLike,
|
174
|
-
start_dim: Optional[int] = None,
|
175
|
-
end_dim: Optional[int] = None
|
176
|
-
) -> jax.Array:
|
177
|
-
"""Flattens input by reshaping it into a one-dimensional tensor.
|
178
|
-
If ``start_dim`` or ``end_dim`` are passed, only dimensions starting
|
179
|
-
with ``start_dim`` and ending with ``end_dim`` are flattened.
|
180
|
-
The order of elements in input is unchanged.
|
181
|
-
|
182
|
-
.. note::
|
183
|
-
Flattening a zero-dimensional tensor will return a one-dimensional view.
|
184
|
-
|
185
|
-
Parameters
|
186
|
-
----------
|
187
|
-
input: Array
|
188
|
-
The input array.
|
189
|
-
start_dim: int
|
190
|
-
the first dim to flatten
|
191
|
-
end_dim: int
|
192
|
-
the last dim to flatten
|
193
|
-
|
194
|
-
Returns
|
195
|
-
-------
|
196
|
-
out: Array
|
197
|
-
"""
|
198
|
-
shape = input.shape
|
199
|
-
ndim = input.ndim
|
200
|
-
if ndim == 0:
|
201
|
-
ndim = 1
|
202
|
-
if start_dim is None:
|
203
|
-
start_dim = 0
|
204
|
-
elif start_dim < 0:
|
205
|
-
start_dim = ndim + start_dim
|
206
|
-
if end_dim is None:
|
207
|
-
end_dim = ndim - 1
|
208
|
-
elif end_dim < 0:
|
209
|
-
end_dim = ndim + end_dim
|
210
|
-
end_dim += 1
|
211
|
-
if start_dim < 0 or start_dim > ndim:
|
212
|
-
raise ValueError(f'start_dim {start_dim} is out of size.')
|
213
|
-
if end_dim < 0 or end_dim > ndim:
|
214
|
-
raise ValueError(f'end_dim {end_dim} is out of size.')
|
215
|
-
new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int),) + shape[end_dim:]
|
216
|
-
return jnp.reshape(input, new_shape)
|
217
|
-
|
218
|
-
|
219
|
-
@set_module_as('brainstate.math')
|
220
|
-
def unflatten(x: jax.typing.ArrayLike, dim: int, sizes: Sequence[int]) -> jax.Array:
|
221
|
-
"""
|
222
|
-
Expands a dimension of the input tensor over multiple dimensions.
|
223
|
-
|
224
|
-
Args:
|
225
|
-
x: input tensor.
|
226
|
-
dim: Dimension to be unflattened, specified as an index into ``x.shape``.
|
227
|
-
sizes: New shape of the unflattened dimension. One of its elements can be -1
|
228
|
-
in which case the corresponding output dimension is inferred.
|
229
|
-
Otherwise, the product of ``sizes`` must equal ``input.shape[dim]``.
|
230
|
-
|
231
|
-
Returns:
|
232
|
-
A tensor with the same data as ``input``, but with ``dim`` split into multiple dimensions.
|
233
|
-
The returned tensor has one more dimension than the input tensor.
|
234
|
-
The returned tensor shares the same underlying data with this tensor.
|
235
|
-
"""
|
236
|
-
assert x.ndim > dim, ('The dimension to be unflattened should '
|
237
|
-
'be less than the tensor dimension. '
|
238
|
-
f'Got {dim} and {x.ndim}.')
|
239
|
-
shape = x.shape
|
240
|
-
new_shape = shape[:dim] + tuple(sizes) + shape[dim + 1:]
|
241
|
-
return jnp.reshape(x, new_shape)
|
242
|
-
|
243
|
-
|
244
|
-
@set_module_as('brainstate.math')
|
245
|
-
def from_numpy(x):
|
246
|
-
"""
|
247
|
-
Convert the numpy array to jax array.
|
248
|
-
|
249
|
-
Args:
|
250
|
-
x: The numpy array.
|
251
|
-
|
252
|
-
Returns:
|
253
|
-
The jax array.
|
254
|
-
"""
|
255
|
-
return jnp.array(x)
|
256
|
-
|
257
|
-
|
258
|
-
@set_module_as('brainstate.math')
|
259
|
-
def as_numpy(x):
|
260
|
-
"""
|
261
|
-
Convert the array to numpy array.
|
262
|
-
|
263
|
-
Args:
|
264
|
-
x: The array.
|
265
|
-
|
266
|
-
Returns:
|
267
|
-
The numpy array.
|
268
|
-
"""
|
269
|
-
return np.array(x)
|
270
|
-
|
271
|
-
|
272
|
-
@set_module_as('brainstate.math')
|
273
|
-
def tree_zeros_like(tree):
|
274
|
-
"""
|
275
|
-
Create a tree with the same structure as the input tree, but with zeros in each leaf.
|
276
|
-
|
277
|
-
Args:
|
278
|
-
tree: The input tree.
|
279
|
-
|
280
|
-
Returns:
|
281
|
-
The tree with zeros in each leaf.
|
282
|
-
"""
|
283
|
-
return jax.tree_map(jnp.zeros_like, tree)
|
284
|
-
|
285
|
-
|
286
|
-
@set_module_as('brainstate.math')
|
287
|
-
def tree_ones_like(tree):
|
288
|
-
"""
|
289
|
-
Create a tree with the same structure as the input tree, but with ones in each leaf.
|
290
|
-
|
291
|
-
Args:
|
292
|
-
tree: The input tree.
|
293
|
-
|
294
|
-
Returns:
|
295
|
-
The tree with ones in each leaf.
|
296
|
-
|
297
|
-
"""
|
298
|
-
return jax.tree_map(jnp.ones_like, tree)
|
brainstate/math/_misc_test.py
DELETED
@@ -1,58 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import jax.numpy as jnp
|
17
|
-
import numpy as np
|
18
|
-
from scipy.special import exprel
|
19
|
-
|
20
|
-
import brainstate as bc
|
21
|
-
from brainstate import math
|
22
|
-
|
23
|
-
|
24
|
-
def test_exprel():
|
25
|
-
np.printoptions(precision=30)
|
26
|
-
|
27
|
-
with bc.environ.context(precision=64):
|
28
|
-
# Test with float64 input
|
29
|
-
x = jnp.array([0.0, 1.0, 10.0, 100.0, 717.0, 718.0], dtype=jnp.float64)
|
30
|
-
# expected = jnp.array([1.0, 1.718281828459045, 2.2025466e+03, jnp.inf, jnp.inf, jnp.inf])
|
31
|
-
# print(math.exprel(x), exprel(np.asarray(x)))
|
32
|
-
assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-6)
|
33
|
-
|
34
|
-
with bc.environ.context(precision=32):
|
35
|
-
# Test with float32 input
|
36
|
-
x = jnp.array([0.0, 1.0, 10.0, 100.0], dtype=jnp.float32)
|
37
|
-
# expected = jnp.array([1.0, 1.7182817, 2.2025466e+03, jnp.inf])
|
38
|
-
# print(math.exprel(x), exprel(np.asarray(x)))
|
39
|
-
assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-6)
|
40
|
-
|
41
|
-
# Test with float16 input
|
42
|
-
x = jnp.array([0.0, 1.0, 10.0], dtype=jnp.float16)
|
43
|
-
# expected = jnp.array([1.0, 1.71875, 2.2025466e+03])
|
44
|
-
# print(math.exprel(x), exprel(np.asarray(x)))
|
45
|
-
assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-3)
|
46
|
-
|
47
|
-
# Test with int input
|
48
|
-
x = jnp.array([0, 1, 10])
|
49
|
-
# expected = jnp.array([1.0, 1.718281828459045, 2.20254658e+03])
|
50
|
-
# print(math.exprel(x), exprel(np.asarray(x)))
|
51
|
-
assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-6)
|
52
|
-
|
53
|
-
with bc.environ.context(precision=64):
|
54
|
-
# Test with negative input
|
55
|
-
x = jnp.array([-1.0, -10.0, -100.0], dtype=jnp.float64)
|
56
|
-
# expected = jnp.array([0.63212055, 0.09999546, 0.01 ])
|
57
|
-
# print(math.exprel(x), exprel(np.asarray(x)))
|
58
|
-
assert jnp.allclose(math.exprel(x), exprel(np.asarray(x)), rtol=1e-6)
|
File without changes
|
File without changes
|
{brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt
RENAMED
File without changes
|