brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,171 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
from absl.testing import absltest
|
19
|
+
from absl.testing import parameterized
|
20
|
+
|
21
|
+
import brainstate as bst
|
22
|
+
|
23
|
+
|
24
|
+
class Test_Activation(parameterized.TestCase):
    """Smoke tests for the activation layers exported by ``brainstate.nn``.

    Every test constructs a layer, applies it to random normal input and checks
    the output shape; where the activation is mathematically bounded, the
    output range is checked as well.
    """

    def _assert_in_range(self, out, low, high):
        # Element-wise bound check on an array output.
        self.assertTrue(bool((out >= low).all()))
        self.assertTrue(bool((out <= high).all()))

    def _assert_normalized(self, out, axis):
        # Softmax-style outputs must sum to one along the normalized axis.
        self.assertTrue(bool((abs(out.sum(axis=axis) - 1.0) < 1e-5).all()))

    def test_Threshold(self):
        layer = bst.nn.Threshold(5, 20)
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_ReLU(self):
        layer = bst.nn.ReLU()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self.assertTrue(bool((out >= 0).all()))

    def test_RReLU(self):
        layer = bst.nn.RReLU(lower=0, upper=1)
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_Hardtanh(self):
        layer = bst.nn.Hardtanh(min_val=0, max_val=1, )
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self._assert_in_range(out, 0, 1)

    def test_ReLU6(self):
        layer = bst.nn.ReLU6()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self._assert_in_range(out, 0, 6)

    def test_Sigmoid(self):
        layer = bst.nn.Sigmoid()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self._assert_in_range(out, 0, 1)

    def test_Hardsigmoid(self):
        layer = bst.nn.Hardsigmoid()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self._assert_in_range(out, 0, 1)

    def test_Tanh(self):
        layer = bst.nn.Tanh()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self._assert_in_range(out, -1, 1)

    def test_SiLU(self):
        layer = bst.nn.SiLU()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_Mish(self):
        layer = bst.nn.Mish()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_Hardswish(self):
        layer = bst.nn.Hardswish()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_ELU(self):
        layer = bst.nn.ELU(alpha=0.5, )
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_CELU(self):
        layer = bst.nn.CELU(alpha=0.5, )
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_SELU(self):
        layer = bst.nn.SELU()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_GLU(self):
        layer = bst.nn.GLU()
        x = bst.random.randn(4, 2)
        out = layer(x)
        # GLU splits the gated axis in half (assumed default: last axis).
        self.assertEqual(out.shape, (4, 1))

    @parameterized.product(
        approximate=['tanh', 'none']
    )
    def test_GELU(self, approximate):
        # Previously the parameter was ignored, so both cases ran the same
        # (exact) GELU. Map the torch-style names onto the boolean flag.
        layer = bst.nn.GELU(approximate=(approximate == 'tanh'))
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_Hardshrink(self):
        layer = bst.nn.Hardshrink(lambd=1)
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_LeakyReLU(self):
        layer = bst.nn.LeakyReLU()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_LogSigmoid(self):
        layer = bst.nn.LogSigmoid()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self.assertTrue(bool((out <= 0).all()))

    def test_Softplus(self):
        layer = bst.nn.Softplus()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self.assertTrue(bool((out >= 0).all()))

    def test_Softshrink(self):
        layer = bst.nn.Softshrink(lambd=1)
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_PReLU(self):
        layer = bst.nn.PReLU(num_parameters=2, init=0.5)
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_Softsign(self):
        layer = bst.nn.Softsign()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self._assert_in_range(out, -1, 1)

    def test_Tanhshrink(self):
        layer = bst.nn.Tanhshrink()
        x = bst.random.randn(2)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_Softmin(self):
        layer = bst.nn.Softmin(dim=2)
        x = bst.random.randn(2, 3, 4)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self._assert_normalized(out, axis=2)

    def test_Softmax(self):
        layer = bst.nn.Softmax(dim=2)
        x = bst.random.randn(2, 3, 4)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        self._assert_normalized(out, axis=2)

    def test_Softmax2d(self):
        layer = bst.nn.Softmax2d()
        x = bst.random.randn(2, 3, 12, 13)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)

    def test_LogSoftmax(self):
        layer = bst.nn.LogSoftmax(dim=2)
        x = bst.random.randn(2, 3, 4)
        out = layer(x)
        self.assertEqual(out.shape, x.shape)
        # log-probabilities are never positive
        self.assertTrue(bool((out <= 1e-6).all()))


if __name__ == '__main__':
    absltest.main()
|
@@ -0,0 +1,97 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
from typing import Callable
|
20
|
+
|
21
|
+
import brainunit as u
|
22
|
+
import jax.numpy as jnp
|
23
|
+
|
24
|
+
from brainstate import environ, random
|
25
|
+
from brainstate.augment import vector_grad
|
26
|
+
|
27
|
+
# Names exported by ``from ... import *`` for this module.
__all__ = ['exp_euler_step']
|
30
|
+
|
31
|
+
|
32
|
+
def exp_euler_step(
    fn: Callable, *args, **kwargs
):
    r"""
    One-step Exponential Euler method for solving ODEs (optionally with noise).

    The state ``x = args[0]`` is advanced by one step of size ``dt``, where
    ``dt`` is read from ``environ.get('dt')``. The drift ``f = fn`` is
    linearized around ``x`` and the linear part is integrated exactly:

    .. math::

        x_{n+1} = x_n + \Delta t \, \varphi(\Delta t \, A) \, f(x_n),
        \qquad \varphi(z) = \frac{e^z - 1}{z},

    where :math:`A` is the derivative of ``fn`` with respect to ``x``
    (computed with ``vector_grad``). If a diffusion function ``g`` is given,
    the noise term ``g(x_n) * sqrt(dt) * xi`` with ``xi ~ N(0, 1)`` is added.

    Examples
    --------

    >>> def fun(x, t):
    ...     return -x
    >>> x = 1.0
    >>> with environ.context(dt=0.1):  # dt must be set in the environment
    ...     exp_euler_step(fun, x, None)

    If the variable :math:`x` has units of :math:`[X]`, then the drift term
    :math:`f(x)` should have units of :math:`[X]/[T]`, where :math:`[T]` is the
    unit of time (the unit of ``dt``).

    If the variable :math:`x` has units of :math:`[X]`, then the diffusion term
    :math:`g(x)` should have units of :math:`[X]/\sqrt{[T]}`.

    Args:
        fn: Callable. The drift function, called as ``fn(x, *rest, **kwargs)``.
        *args: Either ``(x, *rest)`` or ``(diffusion, x, *rest)``. When the
            first positional argument is callable it is used as the diffusion
            function and called with the same ``(x, *rest, **kwargs)`` as
            ``fn``. ``x`` must have dtype float16, bfloat16, float32 or float64.
        **kwargs: Keyword arguments forwarded to ``fn`` (and to the diffusion
            function, if given).

    Returns:
        The one-step solution of the ODE/SDE.

    Raises:
        AssertionError: If ``fn`` is not callable or no state argument is given.
        ValueError: If the state ``x`` does not have a floating-point dtype.
    """
    assert callable(fn), 'The input function should be callable.'
    assert len(args) > 0, 'The input arguments should not be empty.'

    # An optional diffusion function may be passed in front of the state.
    if callable(args[0]):
        diffusion = args[0]
        args = args[1:]
    else:
        diffusion = None
    assert len(args) > 0, 'The input arguments should not be empty.'

    # Use the computed dtype in the message: ``args[0].dtype`` would raise
    # AttributeError for inputs (e.g. Python ints) that have no ``.dtype``.
    dtype = u.math.get_dtype(args[0])
    if dtype not in (jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16):
        raise ValueError(
            f'The input data type should be float64, float32, float16, or bfloat16 '
            f'when using Exponential Euler method. But we got {dtype}.'
        )

    dt = environ.get('dt')
    # `linear` is the derivative of fn w.r.t. x; `derivative` is fn(x) itself.
    linear, derivative = vector_grad(fn, argnums=0, return_value=True)(*args, **kwargs)
    # Give the linear coefficient the unit [fn(x)] / [x] (i.e. 1/[T]) so that
    # ``dt * linear`` below is dimensionless.
    # NOTE(review): assumes vector_grad returns the gradient carrying the unit of x — confirm.
    linear = u.Quantity(u.get_mantissa(linear), u.get_unit(derivative) / u.get_unit(linear))
    phi = u.math.exprel(dt * linear)
    x_next = args[0] + dt * phi * derivative

    if diffusion is not None:
        # TODO: unit consistency ([drift] == [diffusion] * sqrt([T])) is not
        # checked yet; a previous draft of that check was disabled.
        noise_scale = diffusion(*args, **kwargs)
        x_next += noise_scale * u.math.sqrt(dt) * random.randn_like(args[0])
    return x_next
|
@@ -0,0 +1,36 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import unittest
|
19
|
+
|
20
|
+
import brainunit as u
|
21
|
+
|
22
|
+
import brainstate as bst
|
23
|
+
|
24
|
+
|
25
|
+
class TestExpEuler(unittest.TestCase):
    """Tests for ``bst.nn.exp_euler_step``."""

    def test1(self):
        def fun(x, tau):
            return -x / tau

        # With a unitless ``dt`` the unit-carrying state cannot be advanced;
        # the step is expected to fail with an AssertionError (presumably from
        # the unit machinery, since the explicit argument checks all pass here).
        with bst.environ.context(dt=0.1):
            with self.assertRaises(AssertionError):
                bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)

        # For the linear ODE dx/dt = -x / tau the exponential Euler step is
        # exact: x(dt) = x0 * exp(-dt / tau) = exp(-1) mV for dt = tau = 1 ms.
        with bst.environ.context(dt=1. * u.ms):
            r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
        expected = u.math.exp(-1.0) * u.mV
        self.assertTrue(bool(abs(r - expected) < 1e-5 * u.mV))
|
@@ -0,0 +1,32 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from ._connections import *
|
17
|
+
from ._connections import __all__ as connections_all
|
18
|
+
from ._embedding import *
|
19
|
+
from ._embedding import __all__ as embed_all
|
20
|
+
from ._normalizations import *
|
21
|
+
from ._normalizations import __all__ as normalizations_all
|
22
|
+
from ._poolings import *
|
23
|
+
from ._poolings import __all__ as poolings_all
|
24
|
+
|
25
|
+
# Public API: the concatenated exports of each submodule, in a fixed order.
__all__ = connections_all + normalizations_all + poolings_all + embed_all

# Remove the temporary aliases so they do not leak into the package namespace.
del connections_all
del normalizations_all
del poolings_all
del embed_all
|