brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,159 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import unittest
18
-
19
- import jax
20
- from absl.testing import absltest
21
-
22
- import brainstate
23
-
24
-
25
- class TestNestedMapping(absltest.TestCase):
26
- def test_create_state(self):
27
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
28
-
29
- assert state['a'].value == 1
30
- assert state['b']['c'].value == 2
31
-
32
- def test_get_attr(self):
33
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
34
-
35
- assert state.a.value == 1
36
- assert state.b['c'].value == 2
37
-
38
- def test_set_attr(self):
39
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
40
-
41
- state.a.value = 3
42
- state.b['c'].value = 4
43
-
44
- assert state['a'].value == 3
45
- assert state['b']['c'].value == 4
46
-
47
- def test_set_attr_variables(self):
48
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
49
-
50
- state.a.value = 3
51
- state.b['c'].value = 4
52
-
53
- assert isinstance(state.a, brainstate.ParamState)
54
- assert state.a.value == 3
55
- assert isinstance(state.b['c'], brainstate.ParamState)
56
- assert state.b['c'].value == 4
57
-
58
- def test_add_nested_attr(self):
59
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
60
- state.b['d'] = brainstate.ParamState(5)
61
-
62
- assert state['b']['d'].value == 5
63
-
64
- def test_delete_nested_attr(self):
65
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
66
- del state['b']['c']
67
-
68
- assert 'c' not in state['b']
69
-
70
- def test_integer_access(self):
71
- class Foo(brainstate.nn.Module):
72
- def __init__(self):
73
- super().__init__()
74
- self.layers = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
75
-
76
- module = Foo()
77
- state_refs = brainstate.graph.treefy_states(module)
78
-
79
- assert module.layers[0].weight.value['weight'].shape == (1, 2)
80
- assert state_refs.layers[0]['weight'].value['weight'].shape == (1, 2)
81
- assert module.layers[1].weight.value['weight'].shape == (2, 3)
82
- assert state_refs.layers[1]['weight'].value['weight'].shape == (2, 3)
83
-
84
- def test_pure_dict(self):
85
- module = brainstate.nn.Linear(4, 5)
86
- state_map = brainstate.graph.treefy_states(module)
87
- pure_dict = state_map.to_pure_dict()
88
- assert isinstance(pure_dict, dict)
89
- assert isinstance(pure_dict['weight'].value['weight'], jax.Array)
90
- assert isinstance(pure_dict['weight'].value['bias'], jax.Array)
91
-
92
-
93
- class TestSplit(unittest.TestCase):
94
- def test_split(self):
95
- class Model(brainstate.nn.Module):
96
- def __init__(self):
97
- super().__init__()
98
- self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
99
- self.linear = brainstate.nn.Linear([10, 3], [10, 4])
100
-
101
- def __call__(self, x):
102
- return self.linear(self.batchnorm(x))
103
-
104
- with brainstate.environ.context(fit=True):
105
- model = Model()
106
- x = brainstate.random.randn(1, 10, 3)
107
- y = model(x)
108
- self.assertEqual(y.shape, (1, 10, 4))
109
-
110
- state_map = brainstate.graph.treefy_states(model)
111
-
112
- with self.assertRaises(ValueError):
113
- params, others = state_map.split(brainstate.ParamState)
114
-
115
- params, others = state_map.split(brainstate.ParamState, ...)
116
- print()
117
- print(params)
118
- print(others)
119
-
120
- self.assertTrue(len(params.to_flat()) == 2)
121
- self.assertTrue(len(others.to_flat()) == 2)
122
-
123
-
124
- class TestStateMap2(unittest.TestCase):
125
- def test1(self):
126
- class Model(brainstate.nn.Module):
127
- def __init__(self):
128
- super().__init__()
129
- self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
130
- self.linear = brainstate.nn.Linear([10, 3], [10, 4])
131
-
132
- def __call__(self, x):
133
- return self.linear(self.batchnorm(x))
134
-
135
- with brainstate.environ.context(fit=True):
136
- model = Model()
137
- state_map = brainstate.graph.treefy_states(model).to_flat()
138
- state_map = brainstate.util.NestedDict(state_map)
139
-
140
-
141
- class TestFlattedMapping(unittest.TestCase):
142
- def test1(self):
143
- class Model(brainstate.nn.Module):
144
- def __init__(self):
145
- super().__init__()
146
- self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
147
- self.linear = brainstate.nn.Linear([10, 3], [10, 4])
148
-
149
- def __call__(self, x):
150
- return self.linear(self.batchnorm(x))
151
-
152
- model = Model()
153
- # print(model.states())
154
- # print(brainstate.graph.states(model))
155
- self.assertTrue(model.states() == brainstate.graph.states(model))
156
-
157
- print(model.nodes())
158
- # print(brainstate.graph.nodes(model))
159
- self.assertTrue(model.nodes() == brainstate.graph.nodes(model))