brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,631 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
|
19
|
+
"""
|
20
|
+
All the basic dynamics classes for ``brainstate``.
|
21
|
+
|
22
|
+
For handling dynamical systems:
|
23
|
+
|
24
|
+
- ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
|
25
|
+
then ``Dynamics``, finally others.
|
26
|
+
- ``Projection``: The class for the synaptic projection.
|
27
|
+
- ``Dynamics``: The class for the dynamical system.
|
28
|
+
|
29
|
+
For handling the delays:
|
30
|
+
|
31
|
+
- ``Delay``: The class for all delays.
|
32
|
+
- ``DelayAccess``: The class for the delay access.
|
33
|
+
|
34
|
+
"""
|
35
|
+
from __future__ import annotations
|
36
|
+
|
37
|
+
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
|
38
|
+
|
39
|
+
import brainunit as u
|
40
|
+
import numpy as np
|
41
|
+
|
42
|
+
from brainstate import environ
|
43
|
+
from brainstate._state import State
|
44
|
+
from brainstate.graph import Node
|
45
|
+
from brainstate.mixin import ParamDescriber
|
46
|
+
from brainstate.nn._module import Module
|
47
|
+
from brainstate.typing import Size, ArrayLike
|
48
|
+
from ._state_delay import StateWithDelay, Delay
|
49
|
+
|
50
|
+
# Public names exported by this module.
__all__ = [
    'DynamicsGroup',
    'Projection',
    'Dynamics',
    'Prefetch',
]

# Generic type variable used in the ``align_pre`` and decorator signatures.
T = TypeVar('T')

# Module-level constant; it is not referenced anywhere else in this module.
_max_order = 10
|
56
|
+
|
57
|
+
|
58
|
+
class Projection(Module):
    """
    Base class to model synaptic projections.

    The default :py:meth:`update` forwards all of its arguments to every direct
    child node (hierarchy level 1). A projection without any child node must
    override :py:meth:`update`; otherwise calling it raises ``ValueError``.
    """

    __module__ = 'brainstate.nn'

    def update(self, *args, **kwargs):
        children = tuple(self.nodes(allowed_hierarchy=(1, 1)).values())
        if not children:
            # nothing to delegate to, and no custom implementation was given
            raise ValueError('Do not implement the update() function.')
        for child in children:
            child(*args, **kwargs)
|
72
|
+
|
73
|
+
|
74
|
+
def _accumulate_inputs(
    inputs: Optional[Dict[str, Any]],
    init: Any,
    args: tuple,
    kwargs: dict,
    label: Optional[str],
):
    """
    Add the entries of an input dictionary onto ``init``.

    Shared implementation of :py:meth:`Dynamics.sum_current_inputs` and
    :py:meth:`Dynamics.sum_delta_inputs`.

    Args:
      inputs: The input dictionary; ``None`` means nothing was registered.
      init: The initial value to which the inputs are added.
      args: Positional arguments forwarded to callable inputs.
      kwargs: Keyword arguments forwarded to callable inputs.
      label: If given, only keys registered under this label are summed.

    Returns:
      ``init`` plus every selected input. A callable input contributes
      ``fn(*args, **kwargs)``. A non-callable input contributes itself and is
      then removed from ``inputs``, so it is consumed exactly once.
    """
    if inputs is None:
        return init
    prefix = None if label is None else _input_label_start(label)
    # Iterate over a snapshot of the keys because consumed entries are popped.
    for key in tuple(inputs.keys()):
        if prefix is not None and not key.startswith(prefix):
            continue
        out = inputs[key]
        if callable(out):
            init = init + out(*args, **kwargs)
        else:
            init = init + out
            inputs.pop(key)
    return init


class Dynamics(Module):
    """
    Base class to model dynamics.

    A :py:class:`Dynamics` instance defines the evolving function of a system at
    a single time step :math:`t` through its ``update()`` method. Running a model
    across many steps is left to looping utilities (e.g. those in
    ``brainstate.compile``).

    Calling the instance runs, in order:

    1. every registered "before update" (:py:attr:`before_updates`). Objects
       marked with ``_receive_update_input`` receive the call arguments; the
       others are called without arguments.
    2. ``self.update(*args, **kwargs)``.
    3. every registered "after update" (:py:attr:`after_updates`). Each receives
       the return value of ``update``, unless it is marked with
       ``_not_receive_update_output``.

    The instance also keeps two dictionaries of external inputs. They are
    accumulated by :py:meth:`sum_current_inputs` and :py:meth:`sum_delta_inputs`.

    Essential attributes:

    - ``in_size``: the geometry of the neuron group, always stored as a tuple.
      For example, `(10, )` denotes a line of neurons, `(10, 10)` denotes a
      neuron group aligned in a 2D space, and `(10, 15, 4)` denotes a
      3-dimensional neuron group.
    - ``out_size``: the output geometry, initialized to ``in_size``.
    - ``varshape``: read-only alias of ``in_size``.

    Args:
      in_size: The neuron group geometry: an int, or a non-empty tuple/list of ints.
      name: The name of the dynamic system.
    """

    __module__ = 'brainstate.nn'

    # before updates: run before ``update`` in ``__call__``
    _before_updates: Optional[Dict[Hashable, Callable]]

    # after updates: run after ``update`` in ``__call__``
    _after_updates: Optional[Dict[Hashable, Callable]]

    # current inputs, accumulated by ``sum_current_inputs``
    _current_inputs: Optional[Dict[str, ArrayLike | Callable]]

    # delta inputs, accumulated by ``sum_delta_inputs``
    _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
    ):
        # initialize
        super().__init__(name=name)

        # geometry size of the neuron population, normalized to a tuple of ints
        if isinstance(in_size, (list, tuple)):
            if len(in_size) <= 0:
                raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
            # every dimension (not only the first one) has to be an integer
            if not all(isinstance(s, (int, np.integer)) for s in in_size):
                raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {in_size}')
            in_size = tuple(in_size)
        elif isinstance(in_size, (int, np.integer)):
            in_size = (in_size,)
        else:
            raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
        self.in_size = in_size

        # The dictionaries below are created lazily on first registration.
        # current inputs
        self._current_inputs = None

        # delta inputs
        self._delta_inputs = None

        # before updates
        self._before_updates = None

        # after updates
        self._after_updates = None

        # in-/out- size of neuron population
        self.out_size = self.in_size

    @property
    def varshape(self):
        """The shape of variables in the neuron group (same as ``in_size``)."""
        return self.in_size

    @property
    def current_inputs(self):
        """
        The current inputs of the model: a dict mapping keys to arrays or callables,
        or ``None`` if no current input has been added yet.
        """
        return self._current_inputs

    @property
    def delta_inputs(self):
        """
        The delta inputs of the model: a dict mapping keys to arrays or callables,
        or ``None`` if no delta input has been added yet.
        """
        return self._delta_inputs

    def add_current_input(
        self,
        key: str,
        inp: Union[Callable, ArrayLike],
        label: Optional[str] = None
    ):
        """
        Add a current input function.

        Re-adding the very same object under an existing key is a no-op.

        Args:
          key: str. The dict key.
          inp: Callable, ArrayLike. The currents or the function to generate currents.
          label: str. The input label. If given, the stored key is prefixed with it.

        Raises:
          ValueError: If ``key`` is already bound to a different object.
        """
        key = _input_label_repr(key, label)
        if self._current_inputs is None:
            self._current_inputs = dict()
        if key in self._current_inputs and self._current_inputs[key] is not inp:
            raise ValueError(f'Key "{key}" has been defined and used in the current inputs of {self}.')
        self._current_inputs[key] = inp

    def add_delta_input(
        self,
        key: str,
        inp: Union[Callable, ArrayLike],
        label: Optional[str] = None
    ):
        """
        Add a delta input function.

        Re-adding the very same object under an existing key is a no-op.

        Args:
          key: str. The dict key.
          inp: Callable, ArrayLike. The currents or the function to generate currents.
          label: str. The input label. If given, the stored key is prefixed with it.

        Raises:
          ValueError: If ``key`` is already bound to a different object.
        """
        key = _input_label_repr(key, label)
        if self._delta_inputs is None:
            self._delta_inputs = dict()
        if key in self._delta_inputs and self._delta_inputs[key] is not inp:
            raise ValueError(f'Key "{key}" has been defined and used in the delta inputs of {self}.')
        self._delta_inputs[key] = inp

    def get_input(self, key: str):
        """Get the input function.

        The current inputs are searched first, then the delta inputs.

        Args:
          key: str. The (already label-prefixed) key.

        Returns:
          The input function which generates currents (or the stored array).

        Raises:
          ValueError: If ``key`` is in neither dictionary.
        """
        if self._current_inputs is not None and key in self._current_inputs:
            return self._current_inputs[key]
        elif self._delta_inputs is not None and key in self._delta_inputs:
            return self._delta_inputs[key]
        else:
            raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')

    def sum_current_inputs(
        self,
        init: Any,
        *args,
        label: Optional[str] = None,
        **kwargs
    ):
        """
        Summarize all current inputs by the defined input functions ``.current_inputs``.

        Non-callable inputs are consumed: they are removed after being summed.

        Args:
          init: The initial input data.
          *args: The arguments for input functions.
          **kwargs: The arguments for input functions.
          label: str. If given, only the inputs registered with this label are summed.

        Returns:
          The total currents.
        """
        return _accumulate_inputs(self._current_inputs, init, args, kwargs, label)

    def sum_delta_inputs(
        self,
        init: Any,
        *args,
        label: Optional[str] = None,
        **kwargs
    ):
        """
        Summarize all delta inputs by the defined input functions ``.delta_inputs``.

        Non-callable inputs are consumed: they are removed after being summed.

        Args:
          init: The initial input data.
          *args: The arguments for input functions.
          **kwargs: The arguments for input functions.
          label: str. If given, only the inputs registered with this label are summed.

        Returns:
          The total currents.
        """
        return _accumulate_inputs(self._delta_inputs, init, args, kwargs, label)

    @property
    def before_updates(self):
        """
        The before updates of the model: a dict of updating functions, or ``None``.
        """
        return self._before_updates

    @property
    def after_updates(self):
        """
        The after updates of the model: a dict of updating functions, or ``None``.
        """
        return self._after_updates

    def _add_before_update(self, key: Any, fun: Callable):
        """
        Add the before update into this node.

        Raises:
          KeyError: If ``key`` is already registered.
        """
        if self._before_updates is None:
            self._before_updates = dict()
        if key in self.before_updates:
            raise KeyError(f'{key} has been registered in before_updates of {self}')
        self.before_updates[key] = fun

    def _add_after_update(self, key: Any, fun: Callable):
        """
        Add the after update into this node.

        Raises:
          KeyError: If ``key`` is already registered.
        """
        if self._after_updates is None:
            self._after_updates = dict()
        if key in self.after_updates:
            raise KeyError(f'{key} has been registered in after_updates of {self}')
        self.after_updates[key] = fun

    def _get_before_update(self, key: Any):
        """Get the before update of this node by the given ``key``; raise ``KeyError`` if absent."""
        if self._before_updates is None or key not in self.before_updates:
            raise KeyError(f'{key} is not registered in before_updates of {self}')
        return self.before_updates.get(key)

    def _get_after_update(self, key: Any):
        """Get the after update of this node by the given ``key``; raise ``KeyError`` if absent."""
        if self._after_updates is None or key not in self.after_updates:
            raise KeyError(f'{key} is not registered in after_updates of {self}')
        return self.after_updates.get(key)

    def _has_before_update(self, key: Any):
        """Whether this node has the before update of the given ``key``."""
        if self._before_updates is None:
            return False
        return key in self.before_updates

    def _has_after_update(self, key: Any):
        """Whether this node has the after update of the given ``key``."""
        if self._after_updates is None:
            return False
        return key in self.after_updates

    def __call__(self, *args, **kwargs):
        """
        The shortcut to call ``update`` methods, wrapped by the before/after updates.
        """

        # ``before_updates``
        if self.before_updates is not None:
            for model in self.before_updates.values():
                if hasattr(model, '_receive_update_input'):
                    model(*args, **kwargs)
                else:
                    model()

        # update the model self
        ret = self.update(*args, **kwargs)

        # ``after_updates``
        if self.after_updates is not None:
            for model in self.after_updates.values():
                if hasattr(model, '_not_receive_update_output'):
                    model()
                else:
                    model(ret)
        return ret

    def prefetch(self, item: str) -> 'Prefetch':
        """Return a lazy :py:class:`Prefetch` reference to the attribute ``item`` of this module."""
        return Prefetch(self, item)

    def align_pre(
        self, dyn: Union[ParamDescriber[T], T]
    ) -> T:
        """
        Align the dynamics before the interaction.

        ``dyn`` is registered as an after update of this module, so it runs right
        after ``self.update`` in ``__call__``.

        Args:
          dyn: Either a :py:class:`Dynamics` instance, registered under ``dyn.name``,
            or a ``ParamDescriber`` of a :py:class:`Dynamics` subclass. A descriptor
            is instantiated only once and cached under ``dyn.identifier``.

        Returns:
          The registered :py:class:`Dynamics` instance.

        Raises:
          TypeError: If ``dyn`` is neither of the accepted kinds.
          KeyError: If a ``Dynamics`` instance with the same name is already registered.
        """
        if isinstance(dyn, Dynamics):
            self._add_after_update(dyn.name, dyn)
            return dyn
        elif isinstance(dyn, ParamDescriber):
            # ``dyn.cls`` is the class to instantiate, so it must be checked with
            # ``issubclass`` (``isinstance`` of a class against Dynamics is always False).
            if not (isinstance(dyn.cls, type) and issubclass(dyn.cls, Dynamics)):
                raise TypeError(f'The input {dyn} should describe a subclass of {Dynamics}.')
            if not self._has_after_update(dyn.identifier):
                self._add_after_update(dyn.identifier, dyn())
            return self._get_after_update(dyn.identifier)
        else:
            raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')

    def __leaf_fn__(self, name, value):
        # Private bookkeeping attributes are shown without the leading underscore,
        # unless their value is still ``None``.
        if name in ['_in_size', '_out_size', '_name', '_mode',
                    '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
            return (name, value) if value is None else (name[1:], value)  # skip the first `_`
        return name, value
|
426
|
+
|
427
|
+
|
428
|
+
class Prefetch(Node):
    """
    Lazy reference to an attribute of a module.

    Calling the instance looks the attribute up on ``module`` at call time. If it
    is a :py:class:`State`, its ``.value`` is returned; otherwise the attribute
    itself is returned. A missing attribute (or one equal to ``None``) raises
    ``AttributeError``.

    Args:
      module: The module that owns the attribute.
      item: The attribute name.
    """

    def __init__(self, module: Module, item: str):
        super().__init__()
        self.module = module
        self.item = item

    @property
    def delay(self):
        """A new :py:class:`PrefetchDelay` handle for the same module and item."""
        return PrefetchDelay(self.module, self.item)

    def __call__(self, *args, **kwargs):
        target = _get_prefetch_item(self)
        if isinstance(target, State):
            return target.value
        return target
|
445
|
+
|
446
|
+
|
447
|
+
class PrefetchDelay(Node):
    """
    Lazy handle to the delay of a variable of the given module.

    Calling :py:meth:`at` creates a :py:class:`PrefetchDelayAt`. That object
    registers the delay on ``module`` and retrieves the delayed value when called.

    Args:
      module: The dynamics module that owns the item.
      item: The attribute name of the item whose delay is accessed.
    """

    def __init__(self, module: Dynamics, item: str):
        # initialize the Node base, consistent with the sibling Prefetch classes
        super().__init__()
        self.module = module
        self.item = item

    def at(self, time: ArrayLike) -> 'PrefetchDelayAt':
        """
        Access the delayed value of the item at the given delay ``time``.

        Args:
          time: The delay time.

        Returns:
          A :py:class:`PrefetchDelayAt` for ``(module, item, time)``.
        """
        return PrefetchDelayAt(self.module, self.item, time)
|
454
|
+
|
455
|
+
|
456
|
+
class PrefetchDelayAt(Node):
    """
    Prefetch the delay of a variable in the given module at a specific time.

    On construction, a ``StateWithDelay`` for ``item`` is attached to ``module``
    as an after update (created once per item and shared afterwards), and
    ``time`` is registered on it. Calling the instance retrieves the delayed
    value at the corresponding step.

    Args:
      module: The module that has the item with the name specified by ``item`` argument.
      item: The item that has the delay.
      time: The time to retrieve the delay.
    """

    def __init__(self, module: Dynamics, item: str, time: ArrayLike):
        super().__init__()
        assert isinstance(module, Dynamics), (
            f'The module of {PrefetchDelayAt.__name__} should be an instance of '
            f'Dynamics. But got {module}.'
        )
        self.module = module
        self.item = item
        self.time = time
        # NOTE(review): casting to an integer dtype presumably truncates rather than
        # rounds, so a time that is not an exact multiple of dt (or that suffers float
        # error) maps to the lower step. Confirm this matches how
        # ``StateWithDelay.register_delay`` converts times.
        self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())

        # register the delay; the StateWithDelay ignores the output of ``update``
        key = _get_delay_key(item)
        if not module._has_after_update(key):
            module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
        self.state_delay: StateWithDelay = module._get_after_update(key)
        self.state_delay.register_delay(time)

    def __call__(self, *args, **kwargs):
        return self.state_delay.retrieve_at_step(self.step)
|
484
|
+
|
485
|
+
|
486
|
+
def _get_delay_key(item) -> str:
    """Return the after-update key under which the delay of ``item`` is stored."""
    return '{}-delay'.format(item)
|
488
|
+
|
489
|
+
|
490
|
+
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
    """
    Look up ``target.item`` on ``target.module``.

    Raises:
      AttributeError: If the attribute is missing or equal to ``None``.
    """
    found = getattr(target.module, target.item, None)
    if found is not None:
        return found
    raise AttributeError(f'The target {target.module} should have an `{target.item}` attribute.')
|
495
|
+
|
496
|
+
|
497
|
+
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
    """
    Fetch the ``StateWithDelay`` registered for ``target.item`` on ``target.module``.

    Raises:
      AssertionError: If ``target.module`` is not a :py:class:`Dynamics`.
      KeyError: If no delay has been registered for the item.
      TypeError: If the registered after update is not a ``StateWithDelay``.
    """
    owner = target.module
    assert isinstance(owner, Dynamics), (f'The target module should be an instance '
                                         f'of Dynamics. But got {owner}.')
    found = owner._get_after_update(_get_delay_key(target.item))
    if isinstance(found, StateWithDelay):
        return found
    raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
                    f'its delay. But got {found}.')
|
505
|
+
|
506
|
+
|
507
|
+
def maybe_init_prefetch(target, *args, **kwargs):
    """
    Validate (and, for delays, register) a prefetch target.

    - ``Prefetch``: checks that the attribute exists.
    - ``PrefetchDelay``: checks that a delay is registered for the item.
    - ``PrefetchDelayAt``: additionally registers ``target.time`` on that delay.

    Any other object is ignored. Extra arguments are accepted and ignored.
    """
    if isinstance(target, Prefetch):
        _get_prefetch_item(target)
        return
    if isinstance(target, PrefetchDelay):
        _get_prefetch_item_delay(target)
        return
    if isinstance(target, PrefetchDelayAt):
        _get_prefetch_item_delay(target).register_delay(target.time)
|
517
|
+
|
518
|
+
|
519
|
+
class DynamicsGroup(Module):
    """
    A group of :py:class:`~.Module` objects that are updated together.

    At every :py:meth:`update`, the direct children are updated in three passes:
    first all :py:class:`Projection` children, then all :py:class:`Dynamics`
    children, and finally every other child (e.g. delays). Children are called
    without arguments.

    Args:
      children_as_tuple: Positional children objects, kept in ``layers_tuple``.
      children_as_dict: Keyword children objects, kept in ``layers_dict``.
    """

    __module__ = 'brainstate.nn'

    if not TYPE_CHECKING:
        def __init__(self, *children_as_tuple, **children_as_dict):
            super().__init__()
            self.layers_tuple = tuple(children_as_tuple)
            self.layers_dict = dict(children_as_dict)

    def update(self, *args, **kwargs):
        """
        Update function of a network.

        The arguments are accepted for interface compatibility but are not
        forwarded to the children.
        """
        projections, dynamics, rest = self.nodes(allowed_hierarchy=(1, 1)).split(Projection, Dynamics)
        # projections first, then dynamics, then the remaining nodes
        for group in (projections, dynamics, rest):
            for child in group.values():
                child()
|
555
|
+
|
556
|
+
|
557
|
+
def receive_update_output(cls: object):
    """
    Mark ``cls`` (used as an after update) to receive the output of ``update``.

    This removes the ``_not_receive_update_output`` flag, if present, so that
    ``Dynamics.__call__`` invokes it as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun(ret)

    The same object is returned, so this can be used as a decorator.

    NOTE: ``delattr`` only removes an attribute stored on the object itself. If
    the flag is inherited from the object's class, this raises ``AttributeError``.
    """
    flag = '_not_receive_update_output'
    if hasattr(cls, flag):
        delattr(cls, flag)
    return cls
|
572
|
+
|
573
|
+
|
574
|
+
def not_receive_update_output(cls: T) -> T:
    """
    Mark ``cls`` (used as an after update) to NOT receive the output of ``update``.

    This sets the ``_not_receive_update_output`` flag, so that
    ``Dynamics.__call__`` invokes it as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun()

    The same object is returned, so this can be used as a decorator.
    """
    setattr(cls, '_not_receive_update_output', True)
    return cls
|
588
|
+
|
589
|
+
|
590
|
+
def receive_update_input(cls: object):
    """
    Mark ``cls`` (used as a before update) to receive the inputs of ``update``.

    This sets the ``_receive_update_input`` flag, so that ``Dynamics.__call__``
    invokes it as::

        for fun in model.before_updates.values():
            fun(*args, **kwargs)
        model.update(*args, **kwargs)

    The same object is returned, so this can be used as a decorator.
    """
    setattr(cls, '_receive_update_input', True)
    return cls
|
605
|
+
|
606
|
+
|
607
|
+
def not_receive_update_input(cls: object):
    """
    Mark ``cls`` (used as a before update) to NOT receive the inputs of ``update``.

    This removes the ``_receive_update_input`` flag, if present, so that
    ``Dynamics.__call__`` invokes it as::

        for fun in model.before_updates.values():
            fun()
        model.update(*args, **kwargs)

    The same object is returned, so this can be used as a decorator.

    NOTE: ``delattr`` only removes an attribute stored on the object itself. If
    the flag is inherited from the object's class, this raises ``AttributeError``.
    """
    flag = '_receive_update_input'
    if hasattr(cls, flag):
        delattr(cls, flag)
    return cls
|
622
|
+
|
623
|
+
|
624
|
+
def _input_label_start(label: str):
    # Prefix shared by every input key registered under ``label``.
    return '{} // '.format(label)
|
627
|
+
|
628
|
+
|
629
|
+
def _input_label_repr(name: str, label: Optional[str] = None):
    # Without a label the key is the bare name; otherwise it gets the label prefix.
    if label is None:
        return name
    return _input_label_start(label) + str(name)
|
@@ -0,0 +1,79 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
import unittest
|
21
|
+
|
22
|
+
import numpy as np
|
23
|
+
|
24
|
+
import brainstate as bst
|
25
|
+
|
26
|
+
|
27
|
+
class TestModuleGroup(unittest.TestCase):
    def test_initialization(self):
        # an empty group can be constructed
        self.assertIsInstance(bst.nn.DynamicsGroup(), bst.nn.DynamicsGroup)
|
31
|
+
|
32
|
+
|
33
|
+
class TestProjection(unittest.TestCase):
    def test_initialization(self):
        projection = bst.nn.Projection()
        self.assertIsInstance(projection, bst.nn.Projection)

    def test_update_not_implemented(self):
        # a projection without child nodes has nothing to delegate its update to
        projection = bst.nn.Projection()
        self.assertRaises(ValueError, projection.update)
|
42
|
+
|
43
|
+
|
44
|
+
class TestDynamics(unittest.TestCase):
    def test_initialization(self):
        dyn = bst.nn.Dynamics(in_size=10)
        self.assertIsInstance(dyn, bst.nn.Dynamics)
        self.assertEqual(dyn.in_size, (10,))
        self.assertEqual(dyn.out_size, (10,))

    def test_size_validation(self):
        with self.assertRaises(ValueError):
            bst.nn.Dynamics(in_size=[])
        with self.assertRaises(ValueError):
            bst.nn.Dynamics(in_size="invalid")

    def test_input_handling(self):
        dyn = bst.nn.Dynamics(in_size=10)
        dyn.add_current_input("test_current", lambda: np.random.rand(10))
        dyn.add_delta_input("test_delta", lambda: np.random.rand(10))

        self.assertIn("test_current", dyn.current_inputs)
        self.assertIn("test_delta", dyn.delta_inputs)

    def test_duplicate_input_key(self):
        dyn = bst.nn.Dynamics(in_size=10)
        dyn.add_current_input("test", lambda: np.random.rand(10))
        with self.assertRaises(ValueError):
            dyn.add_current_input("test", lambda: np.random.rand(10))

        # re-adding the very same object under the same key is allowed
        fn = lambda: np.zeros(10)
        dyn.add_current_input("same", fn)
        dyn.add_current_input("same", fn)
        self.assertIs(dyn.get_input("same"), fn)

    def test_sum_current_inputs(self):
        dyn = bst.nn.Dynamics(in_size=3)
        dyn.add_current_input("const", np.ones(3))
        dyn.add_current_input("fn", lambda: 2. * np.ones(3))
        np.testing.assert_allclose(dyn.sum_current_inputs(np.zeros(3)), 3. * np.ones(3))
        # non-callable inputs are one-shot: removed after being summed
        self.assertNotIn("const", dyn.current_inputs)
        self.assertIn("fn", dyn.current_inputs)

    def test_labeled_current_inputs(self):
        dyn = bst.nn.Dynamics(in_size=2)
        dyn.add_current_input("a", lambda: np.ones(2), label="exc")
        dyn.add_current_input("b", lambda: 10. * np.ones(2))
        # only inputs registered under the label are summed
        np.testing.assert_allclose(dyn.sum_current_inputs(np.zeros(2), label="exc"), np.ones(2))
        # without a label, every input is summed
        np.testing.assert_allclose(dyn.sum_current_inputs(np.zeros(2)), 11. * np.ones(2))

    def test_varshape(self):
        # a tuple geometry is kept as-is
        dyn = bst.nn.Dynamics(in_size=(2, 3))
        self.assertEqual(dyn.varshape, (2, 3))
        # a list geometry is normalized to a tuple
        dyn = bst.nn.Dynamics(in_size=[2, 3])
        self.assertEqual(dyn.varshape, (2, 3))
|
76
|
+
|
77
|
+
|
78
|
+
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|