brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/transform/_random.py
DELETED
@@ -1,171 +0,0 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import functools
|
17
|
-
from typing import Callable, Sequence, Union
|
18
|
-
|
19
|
-
from brainstate._utils import set_module_as
|
20
|
-
from brainstate.random import DEFAULT, RandomState
|
21
|
-
from brainstate.typing import Missing
|
22
|
-
from brainstate.util import PrettyObject
|
23
|
-
|
24
|
-
# Public names of this module. ``RngRestore`` is listed alongside
# ``restore_rngs`` because its docstring example refers to it as
# ``brainstate.transform.RngRestore`` and the class sets
# ``__module__ = 'brainstate.transform'``; without it here, any
# ``from ._random import *`` style re-export would silently drop it.
__all__ = [
    'RngRestore',
    'restore_rngs',
]
|
27
|
-
|
28
|
-
|
29
|
-
class RngRestore(PrettyObject):
    """
    Snapshot the keys of several random states and write them back later.

    Parameters
    ----------
    rngs : Sequence[RandomState]
        The :class:`~brainstate.random.RandomState` objects whose keys are
        saved by :meth:`backup` and reinstated by :meth:`restore`.

    Attributes
    ----------
    rngs : Sequence[RandomState]
        The random states under management (stored as passed in).
    rng_keys : list
        Keys recorded by the most recent :meth:`backup`; emptied in place by
        :meth:`restore`.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rng = brainstate.random.RandomState(0)
        >>> restorer = brainstate.transform.RngRestore([rng])
        >>> restorer.backup()
        >>> _ = rng.random()
        >>> restorer.restore()
    """
    __module__ = 'brainstate.transform'

    def __init__(self, rngs: Sequence[RandomState]):
        """
        Create a restorer bound to ``rngs`` with an empty key cache.

        Parameters
        ----------
        rngs : Sequence[RandomState]
            Random states that later calls to :meth:`backup` and
            :meth:`restore` operate on.
        """
        self.rngs: Sequence[RandomState] = rngs
        self.rng_keys = []

    def backup(self):
        """
        Record the current ``value`` of every managed random state.

        Notes
        -----
        A fresh list replaces ``rng_keys``; it stays populated until
        :meth:`restore` empties it.
        """
        snapshot = []
        for state in self.rngs:
            snapshot.append(state.value)
        self.rng_keys = snapshot

    def restore(self):
        """
        Write each recorded key back into its random state, then empty the cache.

        Raises
        ------
        ValueError
            If the number of recorded keys differs from the number of
            managed random states (e.g. :meth:`backup` was never called).
        """
        cached = self.rng_keys
        if len(cached) != len(self.rngs):
            raise ValueError('The number of random keys does not match the number of random states.')
        for position, state in enumerate(self.rngs):
            state.restore_value(cached[position])
        # Empty the existing list object in place (not rebind), as before.
        del cached[:]
|
97
|
-
|
98
|
-
|
99
|
-
def _rng_backup(
    fn: Callable,
    rngs: Union[RandomState, Sequence[RandomState]]
) -> Callable:
    """
    Wrap ``fn`` so the keys of ``rngs`` are reinstated after every call.

    Parameters
    ----------
    fn : Callable
        Function to wrap.
    rngs : Union[RandomState, Sequence[RandomState]]
        Random state(s) backed up before each call and restored afterwards.
        A single :class:`RandomState` is treated as a one-element sequence.

    Returns
    -------
    Callable
        A wrapper carrying ``fn``'s metadata (via :func:`functools.wraps`)
        that returns whatever ``fn`` returns and re-raises whatever ``fn``
        raises, restoring the random states in both cases.
    """
    # The annotation accepts a bare RandomState, but RngRestore expects a
    # sequence of states, so normalise here.
    if isinstance(rngs, RandomState):
        rngs = [rngs]

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Use a fresh restorer per call. A single shared restorer let a
        # nested/recursive call overwrite and then clear the outer call's
        # cached keys, making the outer restore raise ValueError.
        restorer = RngRestore(rngs)
        restorer.backup()
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore even when ``fn`` raises, so a failed call does not
            # leave the random states advanced.
            restorer.restore()

    return wrapper
|
116
|
-
|
117
|
-
|
118
|
-
@set_module_as('brainstate.transform')
def restore_rngs(
    fn: Callable = Missing(),
    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
) -> Callable:
    """
    Decorate a function so specified random states are restored after execution.

    The keys of ``rngs`` are backed up immediately before each call to the
    wrapped function and written back once the call finishes, so repeated
    calls see the same random stream.

    Parameters
    ----------
    fn : Callable, optional
        Function to wrap. When omitted, :func:`restore_rngs` returns a decorator
        preconfigured with ``rngs``.
    rngs : Union[RandomState, Sequence[RandomState]], optional
        Random states whose keys should be backed up before running ``fn`` and
        restored afterwards. Defaults to :data:`brainstate.random.DEFAULT`.
        A sequence is copied at decoration time, so later changes to the
        caller's container do not affect the decorated function.

    Returns
    -------
    Callable
        Wrapped callable that restores the random state or a partially applied
        decorator depending on how :func:`restore_rngs` is used.

    Raises
    ------
    AssertionError
        If ``rngs`` is neither a :class:`~brainstate.random.RandomState` instance nor
        a sequence of such instances. The error is raised explicitly (not via
        ``assert``) so the check still runs under ``python -O``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rng = brainstate.random.RandomState(0)
        >>>
        >>> @brainstate.transform.restore_rngs(rngs=rng)
        ... def sample_pair():
        ...     first = rng.random()
        ...     second = rng.random()
        ...     return first, second
        >>>
        >>> assert sample_pair()[0] == sample_pair()[0]
    """
    if isinstance(fn, Missing):
        # Used as ``@restore_rngs(rngs=...)``: wait until the function arrives.
        return functools.partial(restore_rngs, rngs=rngs)

    error_msg = 'rngs must be a RandomState or a sequence of RandomState instances.'
    if isinstance(rngs, RandomState):
        rngs = [rngs]
    if not isinstance(rngs, Sequence):
        raise AssertionError(error_msg)
    # Snapshot the sequence so mutating a caller-owned list afterwards cannot
    # change (or desynchronise) the set of states being backed up/restored.
    rngs = tuple(rngs)
    for rng in rngs:
        if not isinstance(rng, RandomState):
            raise AssertionError(error_msg)
    return _rng_backup(fn, rngs=rngs)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|