brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,52 +1,52 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
|
18
|
-
import jax
|
19
|
-
import jax.numpy as jnp
|
20
|
-
|
21
|
-
import brainstate
|
22
|
-
|
23
|
-
|
24
|
-
class TestJitError(unittest.TestCase):
    """Tests for ``jit_error_if``.

    The function is accessed through ``brainstate.transform`` (its
    implementation lives in ``brainstate/transform/_error_if.py``) rather than
    through the ``brainstate.compile`` alias.
    """

    def test1(self):
        # A plain string message raises when the condition is True.
        with self.assertRaises(Exception):
            brainstate.transform.jit_error_if(True, 'error')

        def err_f(x):
            raise ValueError(f'error: {x}')

        # A callable error handler must not fire when the condition is False ...
        brainstate.transform.jit_error_if(False, err_f, 1.)
        # ... and must fire when it is True. The concrete exception type depends
        # on how JAX surfaces errors raised from host callbacks, hence the broad
        # ``Exception`` here.
        with self.assertRaises(Exception):
            brainstate.transform.jit_error_if(True, err_f, 1.)

    def test_vmap(self):
        def f(x):
            brainstate.transform.jit_error_if(x, 'error: {x}', x=x)

        # No lane is True: nothing is raised.
        jax.vmap(f)(jnp.array([False, False, False]))
        # A single True lane in the batch must trigger the error.
        with self.assertRaises(Exception):
            jax.vmap(f)(jnp.array([True, False, False]))

    def test_vmap_vmap(self):
        def f(x):
            brainstate.transform.jit_error_if(x, 'error: {x}', x=x)

        # Nested batching with all-False inputs: nothing is raised.
        jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                         [False, False, False]]))
        # One True element anywhere in the nested batch must trigger the error.
        with self.assertRaises(Exception):
            jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                             [True, False, False]]))
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import jax
|
19
|
+
import jax.numpy as jnp
|
20
|
+
|
21
|
+
import brainstate
|
22
|
+
|
23
|
+
|
24
|
+
class TestJitError(unittest.TestCase):
    """Tests for ``jit_error_if``.

    The function is accessed through ``brainstate.transform`` (its
    implementation lives in ``brainstate/transform/_error_if.py``) rather than
    through the ``brainstate.compile`` alias.
    """

    def test1(self):
        # A plain string message raises when the condition is True.
        with self.assertRaises(Exception):
            brainstate.transform.jit_error_if(True, 'error')

        def err_f(x):
            raise ValueError(f'error: {x}')

        # A callable error handler must not fire when the condition is False ...
        brainstate.transform.jit_error_if(False, err_f, 1.)
        # ... and must fire when it is True. The concrete exception type depends
        # on how JAX surfaces errors raised from host callbacks, hence the broad
        # ``Exception`` here.
        with self.assertRaises(Exception):
            brainstate.transform.jit_error_if(True, err_f, 1.)

    def test_vmap(self):
        def f(x):
            brainstate.transform.jit_error_if(x, 'error: {x}', x=x)

        # No lane is True: nothing is raised.
        jax.vmap(f)(jnp.array([False, False, False]))
        # A single True lane in the batch must trigger the error.
        with self.assertRaises(Exception):
            jax.vmap(f)(jnp.array([True, False, False]))

    def test_vmap_vmap(self):
        def f(x):
            brainstate.transform.jit_error_if(x, 'error: {x}', x=x)

        # Nested batching with all-False inputs: nothing is raised.
        jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                         [False, False, False]]))
        # One True element anywhere in the nested batch must trigger the error.
        with self.assertRaises(Exception):
            jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                             [True, False, False]]))
|
@@ -1,145 +1,145 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import functools
|
17
|
-
from typing import Any, TypeVar, Callable, Sequence, Union
|
18
|
-
|
19
|
-
import jax
|
20
|
-
|
21
|
-
from brainstate import random
|
22
|
-
from brainstate._utils import set_module_as
|
23
|
-
from brainstate.graph import Node, flatten, unflatten
|
24
|
-
from ._random import restore_rngs
|
25
|
-
|
26
|
-
__all__ = ['abstract_init']

# Generic type variable tying the factory's return type to the return type of
# ``abstract_init``.
A = TypeVar('A')
|
31
|
-
|
32
|
-
|
33
|
-
class _AbstractInitOutputError(TypeError, AssertionError):
    """Raised when the factory passed to ``abstract_init`` does not return a ``Node``.

    It subclasses ``TypeError``, the idiomatic type for a value of the wrong
    type. It also subclasses ``AssertionError``, which is what the previous
    ``assert``-based check raised, so existing ``except AssertionError``
    handlers keep working.
    """


@set_module_as('brainstate.transform')
def abstract_init(
    fn: Callable[..., A],
    *args: Any,
    rngs: Union[random.RandomState, Sequence[random.RandomState]] = random.DEFAULT,
    **kwargs: Any,
) -> A:
    """
    Abstractly initialize the :class:`Node` built by ``fn`` without any FLOPs.

    ``fn`` is traced with :func:`jax.eval_shape`, so no array computation is
    actually executed. Inside the trace, the returned node is split with
    ``flatten`` into a graph definition and its states. Outside the trace it is
    rebuilt with ``unflatten``. This is useful for inspecting the structure
    and parameter shapes of a model without materializing its parameters.

    Parameters
    ----------
    fn : callable
        Factory whose return value must be a :class:`Node` (for example a
        :class:`brainstate.nn.Module`).
    *args
        Positional arguments forwarded to ``fn``. They are passed through
        :func:`jax.eval_shape`, which means they are abstracted: arrays,
        scalars, or (nested) standard Python containers (pytrees) of those
        types. :class:`jax.ShapeDtypeStruct` or other objects that duck-type as
        ndarrays (via ``shape`` and ``dtype``) can be used. Duck-typed objects
        cannot be namedtuples, because those are treated as standard Python
        containers.

        Because scalars are abstracted too, static configuration such as layer
        sizes should be captured by ``fn`` itself (e.g. with a ``lambda`` or
        :func:`functools.partial`), not passed here.
    rngs : RandomState or sequence of RandomState, default random.DEFAULT
        Random number generator(s) handed to ``restore_rngs``, which wraps the
        traced function. Defaults to the global default generator.
    **kwargs
        Keyword arguments forwarded to ``fn``. Like ``args``, they are
        abstracted by :func:`jax.eval_shape`.

    Returns
    -------
    A
        The object returned by ``fn``, rebuilt from its graph definition and
        abstract states. Because the states come out of :func:`jax.eval_shape`,
        their array leaves are :class:`jax.ShapeDtypeStruct` placeholders
        rather than concrete arrays.

    Raises
    ------
    TypeError
        If ``fn`` does not return a :class:`Node`. For backward compatibility
        the raised exception is also an instance of ``AssertionError``.

    Examples
    --------
    Basic usage with neural network initialization:

    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class MLP(brainstate.nn.Module):
        ...     def __init__(self, n_in, n_mid, n_out):
        ...         super().__init__()
        ...         self.dense1 = brainstate.nn.Linear(n_in, n_mid)
        ...         self.dense2 = brainstate.nn.Linear(n_mid, n_out)
        >>>
        >>> # Get the model structure without materializing parameters
        >>> model_shape = brainstate.transform.abstract_init(lambda: MLP(1, 2, 3))

    Capturing static configuration in the factory:

    .. code-block:: python

        >>> import functools
        >>>
        >>> def create_model(n_in, n_out):
        ...     return brainstate.nn.Linear(n_in, n_out)
        >>>
        >>> # Sizes are bound by ``partial`` so they stay concrete Python ints
        >>> model_shape = brainstate.transform.abstract_init(
        ...     functools.partial(create_model, 784, 10)
        ... )

    Using a custom random number generator:

    .. code-block:: python

        >>> rng = brainstate.random.RandomState(42)
        >>>
        >>> model_shape = brainstate.transform.abstract_init(
        ...     lambda: brainstate.nn.Linear(10, 5), rngs=rng
        ... )

    Deriving the model from an abstract input specification:

    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>>
        >>> def build_for(x):
        ...     # Only the static shape of ``x`` is used.
        ...     return brainstate.nn.Linear(x.shape[-1], 128)
        >>>
        >>> x_spec = jax.ShapeDtypeStruct((32, 784), jnp.float32)
        >>> model_shape = brainstate.transform.abstract_init(build_for, x_spec)
    """

    @functools.wraps(fn)
    @restore_rngs(rngs=rngs)
    def _eval_shape_fn(*args_, **kwargs_):
        out = fn(*args_, **kwargs_)
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``. The result is passed to ``flatten`` below and rebuilt
        # with ``unflatten``, which is why a ``Node`` is required.
        if not isinstance(out, Node):
            raise _AbstractInitOutputError(
                'The output of the function must be Node, '
                f'but got {type(out).__name__}.'
            )
        graph_def, treefy_states = flatten(out)
        return graph_def, treefy_states

    graph_def_, treefy_states_ = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
    return unflatten(graph_def_, treefy_states_)
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import functools
|
17
|
+
from typing import Any, TypeVar, Callable, Sequence, Union
|
18
|
+
|
19
|
+
import jax
|
20
|
+
|
21
|
+
from brainstate import random
|
22
|
+
from brainstate._utils import set_module_as
|
23
|
+
from brainstate.graph import Node, flatten, unflatten
|
24
|
+
from ._random import restore_rngs
|
25
|
+
|
26
|
+
__all__ = ['abstract_init']

# Generic type variable tying the factory's return type to the return type of
# ``abstract_init``.
A = TypeVar('A')
|
31
|
+
|
32
|
+
|
33
|
+
class _AbstractInitOutputError(TypeError, AssertionError):
    """Raised when the factory passed to ``abstract_init`` does not return a ``Node``.

    It subclasses ``TypeError``, the idiomatic type for a value of the wrong
    type. It also subclasses ``AssertionError``, which is what the previous
    ``assert``-based check raised, so existing ``except AssertionError``
    handlers keep working.
    """


@set_module_as('brainstate.transform')
def abstract_init(
    fn: Callable[..., A],
    *args: Any,
    rngs: Union[random.RandomState, Sequence[random.RandomState]] = random.DEFAULT,
    **kwargs: Any,
) -> A:
    """
    Abstractly initialize the :class:`Node` built by ``fn`` without any FLOPs.

    ``fn`` is traced with :func:`jax.eval_shape`, so no array computation is
    actually executed. Inside the trace, the returned node is split with
    ``flatten`` into a graph definition and its states. Outside the trace it is
    rebuilt with ``unflatten``. This is useful for inspecting the structure
    and parameter shapes of a model without materializing its parameters.

    Parameters
    ----------
    fn : callable
        Factory whose return value must be a :class:`Node` (for example a
        :class:`brainstate.nn.Module`).
    *args
        Positional arguments forwarded to ``fn``. They are passed through
        :func:`jax.eval_shape`, which means they are abstracted: arrays,
        scalars, or (nested) standard Python containers (pytrees) of those
        types. :class:`jax.ShapeDtypeStruct` or other objects that duck-type as
        ndarrays (via ``shape`` and ``dtype``) can be used. Duck-typed objects
        cannot be namedtuples, because those are treated as standard Python
        containers.

        Because scalars are abstracted too, static configuration such as layer
        sizes should be captured by ``fn`` itself (e.g. with a ``lambda`` or
        :func:`functools.partial`), not passed here.
    rngs : RandomState or sequence of RandomState, default random.DEFAULT
        Random number generator(s) handed to ``restore_rngs``, which wraps the
        traced function. Defaults to the global default generator.
    **kwargs
        Keyword arguments forwarded to ``fn``. Like ``args``, they are
        abstracted by :func:`jax.eval_shape`.

    Returns
    -------
    A
        The object returned by ``fn``, rebuilt from its graph definition and
        abstract states. Because the states come out of :func:`jax.eval_shape`,
        their array leaves are :class:`jax.ShapeDtypeStruct` placeholders
        rather than concrete arrays.

    Raises
    ------
    TypeError
        If ``fn`` does not return a :class:`Node`. For backward compatibility
        the raised exception is also an instance of ``AssertionError``.

    Examples
    --------
    Basic usage with neural network initialization:

    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class MLP(brainstate.nn.Module):
        ...     def __init__(self, n_in, n_mid, n_out):
        ...         super().__init__()
        ...         self.dense1 = brainstate.nn.Linear(n_in, n_mid)
        ...         self.dense2 = brainstate.nn.Linear(n_mid, n_out)
        >>>
        >>> # Get the model structure without materializing parameters
        >>> model_shape = brainstate.transform.abstract_init(lambda: MLP(1, 2, 3))

    Capturing static configuration in the factory:

    .. code-block:: python

        >>> import functools
        >>>
        >>> def create_model(n_in, n_out):
        ...     return brainstate.nn.Linear(n_in, n_out)
        >>>
        >>> # Sizes are bound by ``partial`` so they stay concrete Python ints
        >>> model_shape = brainstate.transform.abstract_init(
        ...     functools.partial(create_model, 784, 10)
        ... )

    Using a custom random number generator:

    .. code-block:: python

        >>> rng = brainstate.random.RandomState(42)
        >>>
        >>> model_shape = brainstate.transform.abstract_init(
        ...     lambda: brainstate.nn.Linear(10, 5), rngs=rng
        ... )

    Deriving the model from an abstract input specification:

    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>>
        >>> def build_for(x):
        ...     # Only the static shape of ``x`` is used.
        ...     return brainstate.nn.Linear(x.shape[-1], 128)
        >>>
        >>> x_spec = jax.ShapeDtypeStruct((32, 784), jnp.float32)
        >>> model_shape = brainstate.transform.abstract_init(build_for, x_spec)
    """

    @functools.wraps(fn)
    @restore_rngs(rngs=rngs)
    def _eval_shape_fn(*args_, **kwargs_):
        out = fn(*args_, **kwargs_)
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``. The result is passed to ``flatten`` below and rebuilt
        # with ``unflatten``, which is why a ``Node`` is required.
        if not isinstance(out, Node):
            raise _AbstractInitOutputError(
                'The output of the function must be Node, '
                f'but got {type(out).__name__}.'
            )
        graph_def, treefy_states = flatten(out)
        return graph_def, treefy_states

    graph_def_, treefy_states_ = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
    return unflatten(graph_def_, treefy_states_)
|
@@ -1,38 +1,38 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import unittest
|
18
|
-
|
19
|
-
import brainstate
|
20
|
-
|
21
|
-
|
22
|
-
class TestEvalShape(unittest.TestCase):
    """Tests for ``abstract_init``.

    The function is accessed through ``brainstate.transform``, the module it
    is registered under via ``set_module_as``, rather than through the
    ``brainstate.augment`` alias.
    """

    def test1(self):
        class MLP(brainstate.nn.Module):
            def __init__(self, n_in, n_mid, n_out):
                super().__init__()
                self.dense1 = brainstate.nn.Linear(n_in, n_mid)
                self.dense2 = brainstate.nn.Linear(n_mid, n_out)

            def __call__(self, x):
                x = self.dense1(x)
                x = brainstate.nn.relu(x)
                x = self.dense2(x)
                return x

        r = brainstate.transform.abstract_init(lambda: MLP(1, 2, 3))

        # ``abstract_init`` rebuilds the factory's output with ``unflatten``,
        # so the module type and its sub-module structure are preserved.
        self.assertIsInstance(r, MLP)
        self.assertIsInstance(r.dense1, brainstate.nn.Linear)
        self.assertIsInstance(r.dense2, brainstate.nn.Linear)

        # The default RNG is handed to ``restore_rngs`` during the abstract
        # trace. It should still be usable eagerly afterwards, i.e. not left
        # holding a traced key.
        self.assertEqual(brainstate.random.rand(2).shape, (2,))
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import brainstate
|
20
|
+
|
21
|
+
|
22
|
+
class TestEvalShape(unittest.TestCase):
    """Tests for ``abstract_init``.

    The function is accessed through ``brainstate.transform``, the module it
    is registered under via ``set_module_as``, rather than through the
    ``brainstate.augment`` alias.
    """

    def test1(self):
        class MLP(brainstate.nn.Module):
            def __init__(self, n_in, n_mid, n_out):
                super().__init__()
                self.dense1 = brainstate.nn.Linear(n_in, n_mid)
                self.dense2 = brainstate.nn.Linear(n_mid, n_out)

            def __call__(self, x):
                x = self.dense1(x)
                x = brainstate.nn.relu(x)
                x = self.dense2(x)
                return x

        r = brainstate.transform.abstract_init(lambda: MLP(1, 2, 3))

        # ``abstract_init`` rebuilds the factory's output with ``unflatten``,
        # so the module type and its sub-module structure are preserved.
        self.assertIsInstance(r, MLP)
        self.assertIsInstance(r.dense1, brainstate.nn.Linear)
        self.assertIsInstance(r.dense2, brainstate.nn.Linear)

        # The default RNG is handed to ``restore_rngs`` during the abstract
        # trace. It should still be usable eagerly afterwards, i.e. not left
        # holding a traced key.
        self.assertEqual(brainstate.random.rand(2).shape, (2,))
|