brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/mixin_test.py CHANGED
@@ -19,57 +19,55 @@ import brainstate as bc
19
19
 
20
20
 
21
21
  class TestMixin(unittest.TestCase):
22
- def test_mixin(self):
23
- self.assertTrue(bc.mixin.Mixin)
24
- self.assertTrue(bc.mixin.DelayedInit)
25
- self.assertTrue(bc.mixin.DelayedInitializer)
26
- self.assertTrue(bc.mixin.JointTypes)
27
- self.assertTrue(bc.mixin.OneOfTypes)
28
- self.assertTrue(bc.mixin.Mode)
29
- self.assertTrue(bc.mixin.Batching)
30
- self.assertTrue(bc.mixin.Training)
31
-
32
-
22
+ def test_mixin(self):
23
+ self.assertTrue(bc.mixin.Mixin)
24
+ self.assertTrue(bc.mixin.ParamDesc)
25
+ self.assertTrue(bc.mixin.ParamDescriber)
26
+ self.assertTrue(bc.mixin.JointTypes)
27
+ self.assertTrue(bc.mixin.OneOfTypes)
28
+ self.assertTrue(bc.mixin.Mode)
29
+ self.assertTrue(bc.mixin.Batching)
30
+ self.assertTrue(bc.mixin.Training)
33
31
 
34
32
 
35
33
  class TestMode(unittest.TestCase):
36
- def test_JointMode(self):
37
- a = bc.mixin.JointMode(bc.mixin.Batching(), bc.mixin.Training())
38
- self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Batching, bc.mixin.Training]))
39
- self.assertTrue(a.has(bc.mixin.Batching))
40
- self.assertTrue(a.has(bc.mixin.Training))
41
- b = bc.mixin.JointMode(bc.mixin.Batching())
42
- self.assertTrue(b.is_a(bc.mixin.JointTypes[bc.mixin.Batching]))
43
- self.assertTrue(b.is_a(bc.mixin.Batching))
44
- self.assertTrue(b.has(bc.mixin.Batching))
34
+ def test_JointMode(self):
35
+ a = bc.mixin.JointMode(bc.mixin.Batching(), bc.mixin.Training())
36
+ self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Batching, bc.mixin.Training]))
37
+ self.assertTrue(a.has(bc.mixin.Batching))
38
+ self.assertTrue(a.has(bc.mixin.Training))
39
+ b = bc.mixin.JointMode(bc.mixin.Batching())
40
+ self.assertTrue(b.is_a(bc.mixin.JointTypes[bc.mixin.Batching]))
41
+ self.assertTrue(b.is_a(bc.mixin.Batching))
42
+ self.assertTrue(b.has(bc.mixin.Batching))
45
43
 
46
- def test_Training(self):
47
- a = bc.mixin.Training()
48
- self.assertTrue(a.is_a(bc.mixin.Training))
49
- self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Training]))
50
- self.assertTrue(a.has(bc.mixin.Training))
51
- self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Training]))
52
- self.assertFalse(a.is_a(bc.mixin.Batching))
53
- self.assertFalse(a.has(bc.mixin.Batching))
44
+ def test_Training(self):
45
+ a = bc.mixin.Training()
46
+ self.assertTrue(a.is_a(bc.mixin.Training))
47
+ self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Training]))
48
+ self.assertTrue(a.has(bc.mixin.Training))
49
+ self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Training]))
50
+ self.assertFalse(a.is_a(bc.mixin.Batching))
51
+ self.assertFalse(a.has(bc.mixin.Batching))
54
52
 
55
- def test_Batching(self):
56
- a = bc.mixin.Batching()
57
- self.assertTrue(a.is_a(bc.mixin.Batching))
58
- self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Batching]))
59
- self.assertTrue(a.has(bc.mixin.Batching))
60
- self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Batching]))
53
+ def test_Batching(self):
54
+ a = bc.mixin.Batching()
55
+ self.assertTrue(a.is_a(bc.mixin.Batching))
56
+ self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Batching]))
57
+ self.assertTrue(a.has(bc.mixin.Batching))
58
+ self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Batching]))
61
59
 
62
- self.assertFalse(a.is_a(bc.mixin.Training))
63
- self.assertFalse(a.has(bc.mixin.Training))
60
+ self.assertFalse(a.is_a(bc.mixin.Training))
61
+ self.assertFalse(a.has(bc.mixin.Training))
64
62
 
65
- def test_Mode(self):
66
- a = bc.mixin.Mode()
67
- self.assertTrue(a.is_a(bc.mixin.Mode))
68
- self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Mode]))
69
- self.assertTrue(a.has(bc.mixin.Mode))
70
- self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Mode]))
63
+ def test_Mode(self):
64
+ a = bc.mixin.Mode()
65
+ self.assertTrue(a.is_a(bc.mixin.Mode))
66
+ self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Mode]))
67
+ self.assertTrue(a.has(bc.mixin.Mode))
68
+ self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Mode]))
71
69
 
72
- self.assertFalse(a.is_a(bc.mixin.Training))
73
- self.assertFalse(a.has(bc.mixin.Training))
74
- self.assertFalse(a.is_a(bc.mixin.Batching))
75
- self.assertFalse(a.has(bc.mixin.Batching))
70
+ self.assertFalse(a.is_a(bc.mixin.Training))
71
+ self.assertFalse(a.has(bc.mixin.Training))
72
+ self.assertFalse(a.is_a(bc.mixin.Batching))
73
+ self.assertFalse(a.has(bc.mixin.Batching))
brainstate/nn/__init__.py CHANGED
@@ -13,65 +13,40 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+
16
17
  from . import metrics
17
- from ._base import *
18
- from ._base import __all__ as base_all
19
- from ._connections import *
20
- from ._connections import __all__ as connections_all
18
+ from ._collective_ops import *
19
+ from ._collective_ops import __all__ as collective_ops_all
20
+ from ._dyn_impl import *
21
+ from ._dyn_impl import __all__ as dyn_impl_all
21
22
  from ._dynamics import *
22
23
  from ._dynamics import __all__ as dynamics_all
23
24
  from ._elementwise import *
24
25
  from ._elementwise import __all__ as elementwise_all
25
- from ._embedding import *
26
- from ._embedding import __all__ as embed_all
27
- from ._misc import *
28
- from ._misc import __all__ as _misc_all
29
- from ._normalizations import *
30
- from ._normalizations import __all__ as normalizations_all
31
- from ._others import *
32
- from ._others import __all__ as others_all
33
- from ._poolings import *
34
- from ._poolings import __all__ as poolings_all
35
- from ._projection import *
36
- from ._projection import __all__ as _projection_all
37
- from ._rate_rnns import *
38
- from ._rate_rnns import __all__ as rate_rnns
39
- from ._readout import *
40
- from ._readout import __all__ as readout_all
41
- from ._synouts import *
42
- from ._synouts import __all__ as synouts_all
43
- from .event import *
44
- from .event import __all__ as event_all
26
+ from ._exp_euler import *
27
+ from ._exp_euler import __all__ as exp_euler_all
28
+ from ._interaction import *
29
+ from ._interaction import __all__ as interaction_all
30
+ from ._module import *
31
+ from ._module import __all__ as module_all
45
32
 
46
33
  __all__ = (
47
- base_all +
48
- connections_all +
49
- dynamics_all +
50
- elementwise_all +
51
- embed_all +
52
- normalizations_all +
53
- others_all +
54
- poolings_all +
55
- rate_rnns +
56
- readout_all +
57
- synouts_all +
58
- _projection_all +
59
- _misc_all +
60
- event_all
34
+ ['metrics']
35
+ + collective_ops_all
36
+ + dyn_impl_all
37
+ + dynamics_all
38
+ + elementwise_all
39
+ + module_all
40
+ + exp_euler_all
41
+ + interaction_all
61
42
  )
62
43
 
63
44
  del (
64
- base_all,
65
- connections_all,
66
- dynamics_all,
67
- elementwise_all,
68
- embed_all,
69
- normalizations_all,
70
- others_all,
71
- poolings_all,
72
- readout_all,
73
- synouts_all,
74
- _projection_all,
75
- _misc_all,
76
- event_all
45
+ collective_ops_all,
46
+ dyn_impl_all,
47
+ dynamics_all,
48
+ elementwise_all,
49
+ module_all,
50
+ exp_euler_all,
51
+ interaction_all,
77
52
  )
@@ -0,0 +1,199 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ from collections import namedtuple
19
+ from typing import Dict, Callable, TypeVar
20
+
21
+ import jax
22
+
23
+ from brainstate._utils import set_module_as
24
+ from brainstate.graph import nodes
25
+ from ._module import Module
26
+
27
# Exclusive upper bound on the level accepted by ``call_order`` when its
# boundary check is enabled (valid levels are ``0 .. MAX_ORDER - 1``).
MAX_ORDER = 10

# Record returned by ``load_all_states``: the keys that were missing from the
# given state dict and the keys that were present but not consumed.
StateLoadResult = namedtuple('StateLoadResult', ('missing_keys', 'unexpected_keys'))

# Lets ``init_all_states`` advertise that it returns the same Module subtype
# it was given.
T = TypeVar('T', bound=Module)

__all__ = [
    'MAX_ORDER',
    'call_order',
    'init_all_states',
    'reset_all_states',
    'load_all_states',
    'save_all_states',
    'assign_state_values',
]
39
+
40
+
41
@set_module_as('brainstate.nn')
def call_order(level: int = 0, check_order_boundary: bool = True):
    """Create a decorator that tags a function with a call-order level.

    The returned decorator stores ``level`` on the decorated function as its
    ``call_order`` attribute and returns the very same function object.
    Collective helpers such as ``init_all_states`` and ``reset_all_states``
    first call the state functions that carry no ``call_order`` attribute,
    then call the tagged ones sorted by ascending level, so the lower the
    level, the earlier the function is called.

    >>> import brainstate as bst
    >>> @bst.nn.call_order(1)
    ... def init_state(*args, **kwargs):
    ...     pass
    >>> init_state.call_order
    1
    >>> @bst.nn.call_order(-1, check_order_boundary=False)
    ... def reset_state(*args, **kwargs):
    ...     pass
    >>> reset_state.call_order
    -1

    Parameters
    ----------
    level: int
        The call order level.
    check_order_boundary: bool
        Whether to check the boundary of the call order. If True, a level
        that is not in ``[0, MAX_ORDER)`` raises a ValueError.

    Returns
    -------
    The decorator that attaches ``level`` to the function it wraps.

    Raises
    ------
    ValueError
        If ``check_order_boundary`` is True and ``level`` is outside
        ``[0, MAX_ORDER)``.
    """
    # Validate eagerly so that a bad level fails at decoration time.
    if check_order_boundary and (level < 0 or level >= MAX_ORDER):
        raise ValueError(f'"call_order" must be an integer in [0, {MAX_ORDER}). but we got {level}.')

    def wrap(fun: Callable):
        # Tag in place; the function itself is returned unchanged.
        fun.call_order = level
        return fun

    return wrap
74
+
75
+
76
@set_module_as('brainstate.nn')
def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
    """
    Collectively initialize the states of every ``Module`` node in ``target``.

    Nodes whose ``init_state`` carries no ``call_order`` attribute (see
    ``call_order``) are initialized first, in traversal order. The remaining
    nodes are initialized afterwards, sorted by ascending ``call_order``
    level.

    Args:
        target: The target Module.
        *args: Positional arguments forwarded to every ``init_state`` call.
        exclude: Optional filter. Nodes selected by ``.filter(exclude)`` are
            removed from the set of nodes to initialize.
        **kwargs: Keyword arguments forwarded to every ``init_state`` call.

    Returns:
        The target Module.
    """
    candidates = nodes(target).filter(Module)
    if exclude is not None:
        # Drop every node picked out by the exclusion filter.
        candidates = candidates - candidates.filter(exclude)

    # Initialize untagged nodes immediately; defer the tagged ones.
    deferred = []
    for module in list(candidates.values()):
        if not hasattr(module.init_state, 'call_order'):
            module.init_state(*args, **kwargs)
        else:
            deferred.append(module)

    # list.sort is stable, so modules sharing a level keep traversal order.
    deferred.sort(key=lambda m: m.init_state.call_order)
    for module in deferred:
        module.init_state(*args, **kwargs)

    return target
105
+
106
+
107
@set_module_as('brainstate.nn')
def reset_all_states(target: Module, *args, **kwargs) -> Module:
    """
    Collectively reset the states of every ``Module`` node in ``target``.

    Nodes whose ``reset_state`` carries no ``call_order`` attribute (see
    ``call_order``) are reset first, in traversal order. The remaining nodes
    are reset afterwards, sorted by ascending ``call_order`` level.

    Args:
        target: The target Module.
        *args: Positional arguments forwarded to every ``reset_state`` call.
        **kwargs: Keyword arguments forwarded to every ``reset_state`` call.

    Returns:
        The target Module.
    """
    ordered = []
    for _, module in nodes(target).filter(Module).items():
        if hasattr(module.reset_state, 'call_order'):
            # Postpone tagged nodes until all untagged ones have been reset.
            ordered.append(module)
            continue
        module.reset_state(*args, **kwargs)

    # sorted() is stable, so modules sharing a level keep traversal order.
    for module in sorted(ordered, key=lambda m: m.reset_state.call_order):
        module.reset_state(*args, **kwargs)

    return target
132
+
133
+
134
@set_module_as('brainstate.nn')
def load_all_states(target: Module, state_dict: Dict, **kwargs):
    """
    Copy parameters and buffers from :attr:`state_dict` into
    this module and its descendants.

    ``state_dict`` is looked up by the node paths produced by
    ``nodes(target)``. Each per-node entry is passed to that node's
    ``load_state`` together with ``**kwargs``.

    Args:
        target: Module. The dynamical system to load its states.
        state_dict: dict. A dict containing parameters and persistent buffers.
        **kwargs: Forwarded to every ``load_state`` call.

    Returns
    -------
    ``StateLoadResult`` (a ``NamedTuple``) with ``missing_keys`` and
    ``unexpected_keys`` fields:

    * **missing_keys** is a list of str containing the missing keys. A node
      with no entry at all in ``state_dict`` is reported by its path.
    * **unexpected_keys** is a list of str containing the unexpected keys.
      A top-level entry of ``state_dict`` that matches no node is reported
      by its key.
    """
    missing_keys = []
    unexpected_keys = []
    visited = set()
    for path, node in nodes(target).items():
        visited.add(path)
        if path not in state_dict:
            # Report an absent node as missing instead of raising KeyError,
            # which is what the documented return value promises.
            missing_keys.append(f'{path}')
            continue
        r = node.load_state(state_dict[path], **kwargs)
        if r is not None:
            missing, unexpected = r
            missing_keys.extend([f'{path}.{key}' for key in missing])
            unexpected_keys.extend([f'{path}.{key}' for key in unexpected])

    # Entries in the state dict that no node consumed are unexpected.
    unexpected_keys.extend([f'{key}' for key in state_dict if key not in visited])
    return StateLoadResult(missing_keys, unexpected_keys)
160
+
161
+
162
@set_module_as('brainstate.nn')
def save_all_states(target: Module, **kwargs) -> Dict:
    """
    Collect the states of every node of ``target`` into a dictionary.

    The result is suitable for later disk serialization. Each key is a node
    path taken from ``target.nodes()``. Each value is whatever that node's
    ``save_state(**kwargs)`` returns.

    Args:
        target: Module. The node to save its states.
        **kwargs: Forwarded to every ``save_state`` call.

    Returns:
        Dict. The state dict for serialization.
    """
    collected = {}
    for path, node in target.nodes().items():
        collected[path] = node.save_state(**kwargs)
    return collected
174
+
175
+
176
@set_module_as('brainstate.nn')
def assign_state_values(target: Module, *state_by_abs_path: Dict):
    """
    Assign state values to ``target`` from one or more state dictionaries.

    The given dictionaries are merged left to right, so later ones win on
    duplicate keys. The merged keys are then matched against
    ``target.states()``. For every shared key, the state's ``value`` is set
    to ``jax.numpy.asarray`` of the supplied value.

    Parameters
    ----------
    target: Module
        The target module.
    state_by_abs_path: dict
        The state dictionary which is accessed by the "absolute" accessing method.

    Returns
    -------
    tuple
        ``(unexpected_keys, missing_keys)``. Note the order: unexpected keys
        come first, which is the reverse of ``StateLoadResult``.
    """
    merged = {}
    for mapping in state_by_abs_path:
        merged.update(mapping)

    states = target.states()
    given = set(merged.keys())
    known = set(states.keys())
    for key in known.intersection(given):
        states[key].value = jax.numpy.asarray(merged[key])

    unexpected = list(given - known)
    missing = list(known - given)
    return unexpected, missing
@@ -0,0 +1,46 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
# Re-export the public API of every dynamics-implementation submodule.
# Each submodule's ``__all__`` is aliased as ``<name>_all`` (consistently,
# including ``rate_rnns_all``) so the aliases are unlikely to collide with a
# star-imported public name. The aliases are deleted after use.
from ._dynamics_neuron import *
from ._dynamics_neuron import __all__ as dyn_neuron_all
from ._dynamics_synapse import *
from ._dynamics_synapse import __all__ as dyn_synapse_all
from ._inputs import *
from ._inputs import __all__ as inputs_all
from ._projection_alignpost import *
from ._projection_alignpost import __all__ as alignpost_all
from ._rate_rnns import *
from ._rate_rnns import __all__ as rate_rnns_all
from ._readout import *
from ._readout import __all__ as readout_all

__all__ = (
    dyn_neuron_all
    + dyn_synapse_all
    + inputs_all
    + alignpost_all
    + rate_rnns_all
    + readout_all
)

del (
    dyn_neuron_all,
    dyn_synapse_all,
    inputs_all,
    alignpost_all,
    rate_rnns_all,
    readout_all,
)