brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,52 +1,52 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
|
18
|
-
import jax
|
19
|
-
import jax.numpy as jnp
|
20
|
-
|
21
|
-
import brainstate
|
22
|
-
|
23
|
-
|
24
|
-
class TestJitError(unittest.TestCase):
    """Tests for ``brainstate.compile.jit_error_if`` in eager and vmapped code."""

    def test1(self):
        # A plain string message with a true condition must raise.
        with self.assertRaises(Exception):
            brainstate.compile.jit_error_if(True, 'error')

        def raise_error(x):
            raise ValueError(f'error: {x}')

        # A callable handler fires only when the condition holds.
        brainstate.compile.jit_error_if(False, raise_error, 1.)
        with self.assertRaises(Exception):
            brainstate.compile.jit_error_if(True, raise_error, 1.)

    def test_vmap(self):
        def guarded(flag):
            brainstate.compile.jit_error_if(flag, 'error: {x}', x=flag)

        batched = jax.vmap(guarded)
        all_false = jnp.zeros(3, dtype=bool)
        batched(all_false)
        # A single true element in the batch must trigger the error.
        with self.assertRaises(Exception):
            batched(all_false.at[0].set(True))

    def test_vmap_vmap(self):
        def guarded(flag):
            brainstate.compile.jit_error_if(flag, 'error: {x}', x=flag)

        doubly_batched = jax.vmap(jax.vmap(guarded))
        all_false = jnp.zeros((2, 3), dtype=bool)
        doubly_batched(all_false)
        # One true entry in the inner batch of the second row must raise.
        with self.assertRaises(Exception):
            doubly_batched(all_false.at[1, 0].set(True))
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import jax
|
19
|
+
import jax.numpy as jnp
|
20
|
+
|
21
|
+
import brainstate
|
22
|
+
|
23
|
+
|
24
|
+
class TestJitError(unittest.TestCase):
    """Tests for ``brainstate.compile.jit_error_if`` in eager and vmapped code."""

    def test1(self):
        # A plain string message with a true condition must raise.
        with self.assertRaises(Exception):
            brainstate.compile.jit_error_if(True, 'error')

        def raise_error(x):
            raise ValueError(f'error: {x}')

        # A callable handler fires only when the condition holds.
        brainstate.compile.jit_error_if(False, raise_error, 1.)
        with self.assertRaises(Exception):
            brainstate.compile.jit_error_if(True, raise_error, 1.)

    def test_vmap(self):
        def guarded(flag):
            brainstate.compile.jit_error_if(flag, 'error: {x}', x=flag)

        batched = jax.vmap(guarded)
        all_false = jnp.zeros(3, dtype=bool)
        batched(all_false)
        # A single true element in the batch must trigger the error.
        with self.assertRaises(Exception):
            batched(all_false.at[0].set(True))

    def test_vmap_vmap(self):
        def guarded(flag):
            brainstate.compile.jit_error_if(flag, 'error: {x}', x=flag)

        doubly_batched = jax.vmap(jax.vmap(guarded))
        all_false = jnp.zeros((2, 3), dtype=bool)
        doubly_batched(all_false)
        # One true entry in the inner batch of the second row must raise.
        with self.assertRaises(Exception):
            doubly_batched(all_false.at[1, 0].set(True))
|
@@ -0,0 +1,200 @@
|
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from typing import Callable, Dict, Hashable, Literal, Sequence, Any
|
17
|
+
|
18
|
+
from brainstate._state import State
|
19
|
+
from brainstate.util.filter import Filter, to_predicate
|
20
|
+
from ._make_jaxpr import StatefulFunction
|
21
|
+
|
22
|
+
# Public API of this module.
__all__ = ['StateFinder']
|
25
|
+
|
26
|
+
|
27
|
+
class StateFinder:
|
28
|
+
"""
|
29
|
+
Discover :class:`~brainstate.State` instances touched by a callable.
|
30
|
+
|
31
|
+
``StateFinder`` wraps a function in :class:`~brainstate.transform.StatefulFunction`
|
32
|
+
and exposes the collection of states the function reads or writes. The finder
|
33
|
+
can filter states by predicates, request only read or write states, and return
|
34
|
+
the result in several convenient formats.
|
35
|
+
|
36
|
+
Parameters
|
37
|
+
----------
|
38
|
+
fn : callable
|
39
|
+
Function whose state usage should be inspected.
|
40
|
+
filter : Filter, optional
|
41
|
+
Predicate (see :mod:`brainstate.util.filter`) used to select states.
|
42
|
+
usage : {'all', 'read', 'write', 'both'}, default 'all'
|
43
|
+
Portion of the state trace to return. ``'both'`` returns a mapping with
|
44
|
+
separate read and write entries.
|
45
|
+
return_type : {'dict', 'list', 'tuple'}, default 'dict'
|
46
|
+
Controls the container type returned for the selected states. When
|
47
|
+
``usage='both'``, the same container type is used for the ``'read'`` and
|
48
|
+
``'write'`` entries.
|
49
|
+
key_fn : callable, optional
|
50
|
+
Callable ``key_fn(index, state)`` that produces dictionary keys when
|
51
|
+
``return_type='dict'``. Defaults to using the positional index so existing
|
52
|
+
code continues to receive integer keys.
|
53
|
+
|
54
|
+
Examples
|
55
|
+
--------
|
56
|
+
.. code-block:: python
|
57
|
+
|
58
|
+
>>> import brainstate
|
59
|
+
>>> import jax.numpy as jnp
|
60
|
+
>>>
|
61
|
+
>>> param = brainstate.ParamState(jnp.ones(()), name='weights')
|
62
|
+
>>> bias = brainstate.ParamState(jnp.zeros(()), name='bias')
|
63
|
+
>>>
|
64
|
+
>>> def forward(x):
|
65
|
+
... _ = bias.value # read-only
|
66
|
+
... param.value = param.value * x # read + write
|
67
|
+
... return param.value + bias.value
|
68
|
+
>>>
|
69
|
+
>>> finder = brainstate.transform.StateFinder(
|
70
|
+
... forward,
|
71
|
+
... filter=brainstate.ParamState,
|
72
|
+
... usage='both',
|
73
|
+
... key_fn=lambda i, st: st.name or i,
|
74
|
+
... )
|
75
|
+
>>> finder(2.0)['write'] # doctest: +ELLIPSIS
|
76
|
+
{'weights': ParamState(...}
|
77
|
+
|
78
|
+
Notes
|
79
|
+
-----
|
80
|
+
The underlying :class:`StatefulFunction` is cached, so subsequent calls with
|
81
|
+
compatible arguments will reuse the compiled trace.
|
82
|
+
"""
|
83
|
+
|
84
|
+
_VALID_USAGE: tuple[str, ...] = ('all', 'read', 'write', 'both')
|
85
|
+
_VALID_RETURN_TYPE: tuple[str, ...] = ('dict', 'list', 'tuple')
|
86
|
+
|
87
|
+
def __init__(
|
88
|
+
self,
|
89
|
+
fn: Callable,
|
90
|
+
filter: Filter = None,
|
91
|
+
*,
|
92
|
+
usage: Literal['all', 'read', 'write', 'both'] = 'all',
|
93
|
+
return_type: Literal['dict', 'list', 'tuple'] = 'dict',
|
94
|
+
key_fn: Callable[[int, State], Hashable] | None = None,
|
95
|
+
) -> None:
|
96
|
+
if usage not in self._VALID_USAGE:
|
97
|
+
raise ValueError(f"Invalid usage '{usage}'. Expected one of {self._VALID_USAGE}.")
|
98
|
+
if return_type not in self._VALID_RETURN_TYPE:
|
99
|
+
raise ValueError(
|
100
|
+
f"Invalid return_type '{return_type}'. Expected one of {self._VALID_RETURN_TYPE}."
|
101
|
+
)
|
102
|
+
|
103
|
+
self.fn = fn
|
104
|
+
self._usage = usage
|
105
|
+
self._return_type = return_type
|
106
|
+
self._key_fn = key_fn if key_fn is not None else self._default_key_fn
|
107
|
+
self._filter = to_predicate(filter) if filter is not None else None
|
108
|
+
self.stateful_fn = StatefulFunction(self.fn)
|
109
|
+
|
110
|
+
def __call__(self, *args, **kwargs):
|
111
|
+
"""
|
112
|
+
Invoke :meth:`find` to retrieve states touched by ``fn``.
|
113
|
+
"""
|
114
|
+
return self.find(*args, **kwargs)
|
115
|
+
|
116
|
+
def find(self, *args, **kwargs):
|
117
|
+
"""
|
118
|
+
Execute the wrapped function symbolically and return the selected states.
|
119
|
+
|
120
|
+
Parameters
|
121
|
+
----------
|
122
|
+
*args, **kwargs
|
123
|
+
Arguments forwarded to ``fn`` to determine the state trace.
|
124
|
+
|
125
|
+
Returns
|
126
|
+
-------
|
127
|
+
Any
|
128
|
+
Container holding the requested states as configured by ``usage`` and
|
129
|
+
``return_type``.
|
130
|
+
"""
|
131
|
+
if self._usage == 'both':
|
132
|
+
read_states = self._collect_states('read', *args, **kwargs)
|
133
|
+
write_states = self._collect_states('write', *args, **kwargs)
|
134
|
+
return {
|
135
|
+
'read': self._format_states(read_states),
|
136
|
+
'write': self._format_states(write_states),
|
137
|
+
}
|
138
|
+
|
139
|
+
states = self._collect_states(self._usage, *args, **kwargs)
|
140
|
+
return self._format_states(states)
|
141
|
+
|
142
|
+
def _collect_states(self, usage: str, *args, **kwargs) -> Sequence[State]:
|
143
|
+
usage_map = {
|
144
|
+
'all': self.stateful_fn.get_states,
|
145
|
+
'read': self.stateful_fn.get_read_states,
|
146
|
+
'write': self.stateful_fn.get_write_states,
|
147
|
+
}
|
148
|
+
collector = usage_map.get(usage)
|
149
|
+
if collector is None:
|
150
|
+
raise ValueError(f"Unsupported usage '{usage}'.")
|
151
|
+
states = list(collector(*args, **kwargs))
|
152
|
+
if self._filter is not None:
|
153
|
+
states = [st for st in states if self._filter(tuple(), st)]
|
154
|
+
return states
|
155
|
+
|
156
|
+
def _format_states(self, states: Sequence[State]):
|
157
|
+
if self._return_type == 'list':
|
158
|
+
return list(states)
|
159
|
+
if self._return_type == 'tuple':
|
160
|
+
return tuple(states)
|
161
|
+
return self._states_to_dict(states)
|
162
|
+
|
163
|
+
def _states_to_dict(self, states: Sequence[State]) -> Dict[Hashable, State]:
|
164
|
+
result: Dict[Hashable, State] = {}
|
165
|
+
used_keys: set[Hashable] = set()
|
166
|
+
for idx, state in enumerate(states):
|
167
|
+
key = self._key_fn(idx, state)
|
168
|
+
key = self._ensure_hashable(key)
|
169
|
+
key = self._ensure_unique_key(key, idx, state, used_keys)
|
170
|
+
result[key] = state
|
171
|
+
used_keys.add(key)
|
172
|
+
return result
|
173
|
+
|
174
|
+
@staticmethod
|
175
|
+
def _default_key_fn(idx: int, state: State) -> Hashable:
|
176
|
+
return idx
|
177
|
+
|
178
|
+
@staticmethod
|
179
|
+
def _ensure_hashable(key: Any) -> Hashable:
|
180
|
+
if key is None:
|
181
|
+
return None
|
182
|
+
try:
|
183
|
+
hash(key)
|
184
|
+
except TypeError:
|
185
|
+
return str(key)
|
186
|
+
return key
|
187
|
+
|
188
|
+
@staticmethod
|
189
|
+
def _ensure_unique_key(key: Hashable, idx: int, state: State, used: set[Hashable]) -> Hashable:
|
190
|
+
if key is None or key in used:
|
191
|
+
base_name = getattr(state, 'name', None)
|
192
|
+
base = base_name if base_name not in (None, '') else f"state_{idx}"
|
193
|
+
candidate = base
|
194
|
+
suffix = 1
|
195
|
+
while candidate in used:
|
196
|
+
candidate = f"{base}_{suffix}"
|
197
|
+
suffix += 1
|
198
|
+
return candidate
|
199
|
+
return key
|
200
|
+
|
@@ -0,0 +1,84 @@
|
|
1
|
+
import unittest
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
|
5
|
+
import brainstate as bst
|
6
|
+
from brainstate.transform import StateFinder
|
7
|
+
|
8
|
+
|
9
|
+
class TestStateFinder(unittest.TestCase):
    """Behavioral tests for :class:`brainstate.transform.StateFinder`."""

    def test_default_dictionary_output(self):
        read_state = bst.State(jnp.array(0.0), name='read_state')
        param_state = bst.ParamState(jnp.array(1.0), name='param_state')

        def fn(scale):
            _ = read_state.value
            param_state.value = param_state.value * scale
            return param_state.value + _

        finder = StateFinder(fn)
        result = finder(2.0)
        self.assertEqual(len(result), 2)
        # The default key_fn uses positional indices, so the keys are 0..n-1.
        self.assertEqual(set(result.keys()), {0, 1})
        self.assertEqual(set(result.values()), {read_state, param_state})

    def test_filter_and_usage_read(self):
        buffer_state = bst.State(jnp.array(1.0), name='buffer')
        param_state = bst.ParamState(jnp.array(3.0), name='param')

        def fn(offset):
            _ = buffer_state.value
            param_state.value = param_state.value + offset
            return param_state.value

        read_finder = StateFinder(fn, usage='read', return_type='list')
        reads = read_finder(1.0)
        self.assertEqual(reads, [buffer_state])

        param_finder = StateFinder(fn, filter=bst.ParamState, usage='all')
        param_states = param_finder(1.0)
        self.assertEqual(list(param_states.values()), [param_state])

    def test_usage_write_with_custom_key(self):
        param_state = bst.ParamState(jnp.array(5.0), name='param')

        def fn(scale):
            param_state.value = param_state.value * scale
            return param_state.value

        finder = StateFinder(fn, usage='write', return_type='dict', key_fn=lambda idx, st: f"w_{idx}")
        write_states = finder(2.0)
        self.assertIn('w_0', write_states)
        self.assertIs(write_states['w_0'], param_state)

    def test_usage_both_returns_separated_collections(self):
        read_state = bst.State(jnp.array(2.0), name='read')
        write_state = bst.ParamState(jnp.array(4.0), name='write')

        def fn(delta):
            _ = read_state.value
            write_state.value = write_state.value + delta
            return write_state.value

        finder = StateFinder(fn, usage='both', return_type='tuple')
        result = finder(1.5)
        self.assertEqual(set(result.keys()), {'read', 'write'})
        self.assertEqual(result['read'], (read_state,))
        self.assertEqual(result['write'], (write_state,))

    def test_duplicate_names_are_disambiguated(self):
        first = bst.State(jnp.array(0.0), name='dup')
        second = bst.State(jnp.array(1.0), name='dup')

        def fn():
            _ = first.value
            _ = second.value
            return None

        # With the default integer keys there is no collision to resolve.
        finder = StateFinder(fn)
        states = finder()
        self.assertEqual(len(states), 2)
        self.assertEqual(set(states.values()), {first, second})

        # Keying by name makes both states claim 'dup'. The second one must
        # get a numeric suffix instead of overwriting the first.
        named = StateFinder(fn, key_fn=lambda idx, st: st.name)()
        self.assertEqual(len(named), 2)
        self.assertEqual(set(named.keys()), {'dup', 'dup_1'})
        self.assertEqual(set(named.values()), {first, second})

    def test_none_key_falls_back_to_state_name(self):
        alpha = bst.State(jnp.array(0.0), name='alpha')
        beta = bst.State(jnp.array(1.0), name='beta')

        def fn():
            return alpha.value + beta.value

        # A key_fn returning None defers to each state's name.
        states = StateFinder(fn, key_fn=lambda idx, st: None)()
        self.assertEqual(set(states.keys()), {'alpha', 'beta'})
        self.assertEqual(set(states.values()), {alpha, beta})

    def test_unhashable_keys_are_stringified(self):
        only = bst.State(jnp.array(0.0), name='only')

        def fn():
            return only.value

        # A list key cannot be hashed, so it is replaced by its string form.
        states = StateFinder(fn, key_fn=lambda idx, st: [idx])()
        self.assertEqual(list(states.keys()), ['[0]'])
        self.assertIs(states['[0]'], only)

    def test_invalid_options_raise(self):
        # Options are validated in the constructor, before any tracing.
        with self.assertRaises(ValueError):
            StateFinder(lambda: None, usage='unknown')
        with self.assertRaises(ValueError):
            StateFinder(lambda: None, return_type='set')
|
81
|
+
|
82
|
+
|
83
|
+
if __name__ == '__main__':
    # Allow running this module directly, e.g. `python _find_state_test.py`.
    unittest.main()
|