brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,724 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import unittest
19
+ from collections.abc import Callable
20
+ from functools import partial
21
+ from threading import Thread
22
+ from typing import Any
23
+
24
+ import jax
25
+ import jax.numpy as jnp
26
+ import pytest
27
+ from absl.testing import absltest, parameterized
28
+
29
+ import brainstate as bst
30
+
31
+
32
class TestIter(unittest.TestCase):
    """Exercise ``bst.graph.iter_leaf`` and ``bst.graph.iter_node``."""

    def test1(self):
        """Print the result of each traversal variant; nothing is asserted."""

        class Model(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = bst.nn.Linear(1, 2)
                self.b = bst.nn.Linear(2, 3)
                self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
                self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
                # Attach a further module underneath ``b``.
                self.b.a = bst.nn.LIF(2)

        # Each traversal is applied to its own freshly built model, in order.
        traversals = (
            bst.graph.iter_leaf,
            bst.graph.iter_node,
            partial(bst.graph.iter_node, allowed_hierarchy=(1, 1)),
            partial(bst.graph.iter_node, allowed_hierarchy=(2, 2)),
        )
        for traverse in traversals:
            for path, node in traverse(Model()):
                print(path, node)

    def test_iter_leaf_v1(self):
        """``iter_leaf`` over a list holding one module twice yields 3 items."""

        class Linear(bst.nn.Module):
            def __init__(self, din, dout):
                super().__init__()
                self.weight = bst.ParamState(bst.random.randn(din, dout))
                self.bias = bst.ParamState(bst.random.randn(dout))
                self.a = 1

        shared = Linear(3, 4)

        seen = 0
        for seen, (path, value) in enumerate(bst.graph.iter_leaf([shared, shared]), start=1):
            print(path, type(value).__name__)

        assert seen == 3

    def test_iter_node_v1(self):
        """``iter_node`` over a list holding one model twice yields 8 items."""

        class Model(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = bst.nn.Linear(1, 2)
                self.b = bst.nn.Linear(2, 3)
                self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
                self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
                self.b.a = bst.nn.LIF(2)

        shared = Model()

        seen = 0
        for seen, (path, node) in enumerate(bst.graph.iter_node([shared, shared]), start=1):
            print(path, node.__class__.__name__)
        assert seen == 8
87
+
88
+
89
class List(bst.nn.Module):
    """Module wrapper that stores a copy of ``items`` as a Python list."""

    def __init__(self, items):
        super().__init__()
        # Materialise the iterable so the module owns its own list.
        self.items = [*items]

    def __getitem__(self, idx):
        """Return the stored element at ``idx``."""
        stored = self.items
        return stored[idx]

    def __setitem__(self, idx, value):
        """Replace the stored element at ``idx`` with ``value``."""
        stored = self.items
        stored[idx] = value
99
+
100
+
101
class Dict(bst.nn.Module):
    """Module wrapper that stores its constructor arguments as a ``dict``."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Accepts exactly what the built-in ``dict`` constructor accepts.
        mapping = dict(*args, **kwargs)
        self.items = mapping

    def __getitem__(self, key):
        """Return the stored value for ``key``."""
        mapping = self.items
        return mapping[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``."""
        mapping = self.items
        mapping[key] = value
111
+
112
+
113
class StatefulLinear(bst.nn.Module):
    """Affine layer that also counts how many times it has been used.

    Holds a weight ``w`` and bias ``b`` as ``ParamState`` and a ``uint32``
    counter ``count`` as a plain ``State``.
    """

    def __init__(self, din, dout):
        super().__init__()
        weight_init = bst.random.rand(din, dout)
        bias_init = jnp.zeros((dout,))
        counter_init = jnp.array(0, dtype=jnp.uint32)
        self.w = bst.ParamState(weight_init)
        self.b = bst.ParamState(bias_init)
        self.count = bst.State(counter_init)

    def increment(self):
        """Add one to the call counter."""
        self.count.value += 1

    def __call__(self, x):
        """Bump the counter, then return ``x @ w + b``."""
        StatefulLinear.increment(self)
        projected = x @ self.w.value
        return projected + self.b.value
126
+
127
+
128
class TestGraphUtils(absltest.TestCase):
    """Tests for the flatten / unflatten / split / merge helpers in ``bst.graph``."""

    def test_flatten_treey_state(self):
        """``flatten(..., treefy_state=True)`` yields ``TreefyState`` leaves."""
        a = {'a': 1, 'b': bst.ParamState(2)}
        g = [a, 3, a, bst.ParamState(4)]

        refmap = bst.graph.RefMap()
        graphdef, states = bst.graph.flatten(g, ref_index=refmap, treefy_state=True)

        states[0]['b'].value = 2
        states[3].value = 4

        assert isinstance(states[0]['b'], bst.TreefyState)
        assert isinstance(states[3], bst.TreefyState)
        assert isinstance(states, bst.util.NestedDict)
        # Only the two distinct ParamState objects are recorded in the ref map.
        assert len(refmap) == 2
        assert a['b'] in refmap
        assert g[3] in refmap

    def test_flatten(self):
        """``flatten(..., treefy_state=False)`` yields ``State`` leaves."""
        a = {'a': 1, 'b': bst.ParamState(2)}
        g = [a, 3, a, bst.ParamState(4)]

        refmap = bst.graph.RefMap()
        graphdef, states = bst.graph.flatten(g, ref_index=refmap, treefy_state=False)

        states[0]['b'].value = 2
        states[3].value = 4

        assert isinstance(states[0]['b'], bst.State)
        assert isinstance(states[3], bst.State)
        assert len(refmap) == 2
        assert a['b'] in refmap
        assert g[3] in refmap

    def test_unflatten_treey_state(self):
        """Round trip with ``treefy_state=True`` produces new state objects."""
        a = bst.graph.Dict(a=1, b=bst.ParamState(2))
        g1 = bst.graph.List([a, 3, a, bst.ParamState(4)])

        graphdef, references = bst.graph.flatten(g1, treefy_state=True)
        g = bst.graph.unflatten(graphdef, references)

        print(graphdef)
        print(references)
        # The shared node is still shared after the round trip ...
        assert g[0] is g[2]
        # ... but the states are not the original objects.
        assert g1[3] is not g[3]
        assert g1[0]['b'] is not g[0]['b']

    def test_unflatten(self):
        """Round trip with ``treefy_state=False`` reuses the original states."""
        a = bst.graph.Dict(a=1, b=bst.ParamState(2))
        g1 = bst.graph.List([a, 3, a, bst.ParamState(4)])

        graphdef, references = bst.graph.flatten(g1, treefy_state=False)
        g = bst.graph.unflatten(graphdef, references)

        print(graphdef)
        print(references)
        assert g[0] is g[2]
        assert g1[3] is g[3]
        assert g1[0]['b'] is g[0]['b']

    def test_unflatten_pytree(self):
        """Plain dict/list containers do not keep identity across split/merge."""
        a = {'a': 1, 'b': bst.ParamState(2)}
        g = [a, 3, a, bst.ParamState(4)]

        graphdef, references = bst.graph.treefy_split(g)
        g = bst.graph.treefy_merge(graphdef, references)

        assert g[0] is not g[2]

    def test_unflatten_empty(self):
        """Unflattening with an empty state mapping raises ``ValueError``."""
        a = Dict({'a': 1, 'b': bst.ParamState(2)})
        g = List([a, 3, a, bst.ParamState(4)])

        graphdef, references = bst.graph.treefy_split(g)

        with self.assertRaisesRegex(ValueError, 'Expected key'):
            bst.graph.unflatten(graphdef, bst.util.NestedDict({}))

    def test_module_list(self):
        """Splitting a list of modules exposes per-module state shapes."""
        ls = [
            bst.nn.Linear(2, 2),
            bst.nn.BatchNorm1d([10, 2]),
        ]
        graphdef, statetree = bst.graph.treefy_split(ls)

        assert statetree[0]['weight'].value['weight'].shape == (2, 2)
        assert statetree[0]['weight'].value['bias'].shape == (2,)
        assert statetree[1]['weight'].value['scale'].shape == (1, 2,)
        assert statetree[1]['weight'].value['bias'].shape == (1, 2,)
        assert statetree[1]['running_mean'].value.shape == (1, 2,)
        assert statetree[1]['running_var'].value.shape == (1, 2)

    def test_shared_variables(self):
        """A state referenced twice is stored once and stays shared on merge."""
        v = bst.ParamState(1)
        g = [v, v]

        graphdef, statetree = bst.graph.treefy_split(g)
        assert len(statetree.to_flat()) == 1

        g2 = bst.graph.treefy_merge(graphdef, statetree)
        assert g2[0] is g2[1]

    def test_tied_weights(self):
        """Two layers sharing one weight state stay tied after split/merge."""

        class Foo(bst.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = bst.nn.Linear(2, 2)
                self.baz = bst.nn.Linear(2, 2)

                # tie the weights
                self.baz.weight = self.bar.weight

        node = Foo()
        graphdef, state = bst.graph.treefy_split(node)

        assert len(state.to_flat()) == 1

        node2 = bst.graph.treefy_merge(graphdef, state)

        assert node2.bar.weight is node2.baz.weight

    def test_tied_weights_example(self):
        """Embedding and output projection share one state; forward still works."""

        class LinearTranspose(bst.nn.Module):
            def __init__(self, dout: int, din: int, ) -> None:
                super().__init__()
                self.kernel = bst.ParamState(bst.init.LecunNormal()((dout, din)))

            def __call__(self, x):
                return x @ self.kernel.value.T

        class Encoder(bst.nn.Module):
            def __init__(self, ) -> None:
                super().__init__()
                self.embed = bst.nn.Embedding(10, 2)
                self.linear_out = LinearTranspose(10, 2)

                # tie the weights
                self.linear_out.kernel = self.embed.weight

            def __call__(self, x):
                x = self.embed(x)
                return self.linear_out(x)

        model = Encoder()
        graphdef, state = bst.graph.treefy_split(model)

        assert len(state.to_flat()) == 1

        x = jax.random.randint(jax.random.key(0), (2,), 0, 10)
        y = model(x)

        assert y.shape == (2, 10)

    def test_state_variables_not_shared_with_graph(self):
        """Split/merge state entries are distinct objects with equal values."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.ParamState(1)

        m = Foo()
        graphdef, statetree = bst.graph.treefy_split(m)

        assert isinstance(m.a, bst.ParamState)
        assert issubclass(statetree.a.type, bst.ParamState)
        assert m.a is not statetree.a
        assert m.a.value == statetree.a.value

        m2 = bst.graph.treefy_merge(graphdef, statetree)

        assert isinstance(m2.a, bst.ParamState)
        assert issubclass(statetree.a.type, bst.ParamState)
        assert m2.a is not statetree.a
        assert m2.a.value == statetree.a.value

    def test_shared_state_variables_not_shared_with_graph(self):
        """A state bound to two attributes is stored once, under ``a`` only."""

        class Foo(bst.graph.Node):
            def __init__(self):
                p = bst.ParamState(1)
                self.a = p
                self.b = p

        m = Foo()
        graphdef, state = bst.graph.treefy_split(m)

        assert isinstance(m.a, bst.ParamState)
        assert isinstance(m.b, bst.ParamState)
        assert issubclass(state.a.type, bst.ParamState)
        assert 'b' not in state
        assert m.a is not state.a
        assert m.b is not state.a
        assert m.a.value == state.a.value
        assert m.b.value == state.a.value

        m2 = bst.graph.treefy_merge(graphdef, state)

        assert isinstance(m2.a, bst.ParamState)
        assert isinstance(m2.b, bst.ParamState)
        assert issubclass(state.a.type, bst.ParamState)
        assert m2.a is not state.a
        assert m2.b is not state.a
        assert m2.a.value == state.a.value
        assert m2.b.value == state.a.value
        # Sharing is restored on the merged node.
        assert m2.a is m2.b

    def test_pytree_node(self):
        """A ``bst.util.dataclass`` attribute is split as a pytree subgraph."""

        @bst.util.dataclass
        class Tree:
            a: bst.ParamState
            b: str = bst.util.field(pytree_node=False)

        class Foo(bst.graph.Node):
            def __init__(self):
                self.tree = Tree(bst.ParamState(1), 'a')

        m = Foo()

        graphdef, state = bst.graph.treefy_split(m)

        assert 'tree' in state
        assert 'a' in state.tree
        assert graphdef.subgraphs['tree'].type.__name__ == 'PytreeType'

        m2 = bst.graph.treefy_merge(graphdef, state)

        assert isinstance(m2.tree, Tree)
        assert m2.tree.a.value == 1
        assert m2.tree.b == 'a'
        assert m2.tree.a is not m.tree.a
        assert m2.tree is not m.tree

    # ``unittest.skip`` is honoured both by pytest and by the
    # ``absltest.main()`` entry point this file uses when run as a script;
    # a pytest-only skip marker is ignored by the latter.
    @unittest.skip('Not implemented')
    def test_cached_unflatten(self):
        """Swap two sub-modules inside ``jax.jit`` and rebuild the same node."""

        class Foo(bst.graph.Node):
            def __init__(self, ):
                self.a = bst.nn.Linear(2, 2)
                self.b = bst.nn.BatchNorm1d([10, 2])

        def f(m: Foo):
            m.a, m.b = m.b, m.a  # type: ignore

        m = Foo()
        a = m.a
        b = m.b

        ref_out_idx_out = bst.graph.RefMap()
        graphdef: bst.graph.GraphDef[Foo]
        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)

        @partial(jax.jit, static_argnums=(0,))
        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
            idx_out_ref_in: dict[int, Any] = {}
            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
            f(m)
            ref_in_idx_in = bst.graph.RefMap[Any, int]()
            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
            return state, static_out

        static_out: bst.graph.Static
        state, static_out = f_pure(graphdef, state)
        idx_out_idx_in: dict[int, int]
        graphdef, idx_out_idx_in = static_out.value
        idx_in_ref_out = bst.graph.compose_mapping_reversed(
            ref_out_idx_out, idx_out_idx_in
        )
        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
        assert m2 is m
        assert m2.a is b
        assert m2.b is a

    @unittest.skip('Not implemented')
    def test_cached_unflatten_swap_variables(self):
        """Swap two states inside ``jax.jit`` and rebuild the same node."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.ParamState(1)
                self.b = bst.ParamState(2)

        def f(m: Foo):
            m.a, m.b = m.b, m.a

        m = Foo()
        a = m.a
        b = m.b

        ref_out_idx_out = bst.graph.RefMap[Any, int]()
        graphdef: bst.graph.GraphDef[Foo]
        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)

        @partial(jax.jit, static_argnums=(0,))
        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
            idx_out_ref_in: dict[int, Any] = {}
            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
            f(m)
            ref_in_idx_in = bst.graph.RefMap[Any, int]()
            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
            return state, static_out

        static_out: bst.graph.Static
        state, static_out = f_pure(graphdef, state)
        idx_out_idx_in: dict[int, int]
        graphdef, idx_out_idx_in = static_out.value
        idx_in_ref_out = bst.graph.compose_mapping_reversed(
            ref_out_idx_out, idx_out_idx_in
        )
        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
        assert m2 is m
        assert m2.a is b
        assert m2.b is a

    @unittest.skip('Not implemented')
    def test_cached_unflatten_add_self_reference(self):
        """Add a self reference inside ``jax.jit`` and rebuild the same node."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.ref = None

        def f(m: Foo):
            m.ref = m

        m = Foo()

        ref_out_idx_out = bst.graph.RefMap()
        graphdef: bst.graph.GraphDef[Foo]
        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)

        @partial(jax.jit, static_argnums=(0,))
        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
            idx_out_ref_in: dict[int, Any] = {}
            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
            f(m)
            ref_in_idx_in = bst.graph.RefMap[Any, int]()
            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
            return state, static_out

        static_out: bst.graph.Static
        state, static_out = f_pure(graphdef, state)
        idx_out_idx_in: dict[int, int]
        graphdef, idx_out_idx_in = static_out.value
        idx_in_ref_out = bst.graph.compose_mapping_reversed(
            ref_out_idx_out, idx_out_idx_in
        )
        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
        assert m2 is m
        assert m2.ref is m2

    def test_call_jit_update(self):
        """``bst.graph.call`` threads the updated state through ``jax.jit``."""

        class Counter(bst.graph.Node):
            def __init__(self):
                self.count = bst.ParamState(jnp.zeros(()))

            def inc(self):
                self.count.value += 1
                return 1

        graph_state = bst.graph.treefy_split(Counter())

        @jax.jit
        def update(graph_state):
            out, graph_state = bst.graph.call(graph_state).inc()
            self.assertEqual(out, 1)
            return graph_state

        graph_state = update(graph_state)
        graph_state = update(graph_state)

        counter = bst.graph.treefy_merge(*graph_state)

        self.assertEqual(counter.count.value, 2)

    def test_stateful_linear(self):
        """Jitted calls advance the split state, not the original module."""
        linear = StatefulLinear(3, 2)
        linear_state = bst.graph.treefy_split(linear)

        @jax.jit
        def forward(x, pure_linear):
            y, pure_linear = bst.graph.call(pure_linear)(x)
            return y, pure_linear

        x = jnp.ones((1, 3))
        y, linear_state = forward(x, linear_state)
        y, linear_state = forward(x, linear_state)

        # The original module's counter is untouched.
        self.assertEqual(linear.count.value, 0)
        new_linear = bst.graph.treefy_merge(*linear_state)
        self.assertEqual(new_linear.count.value, 2)

    def test_getitem(self):
        """Indexing the ``call`` proxy reaches a single node of a dict."""
        nodes = dict(
            a=StatefulLinear(3, 2),
            b=StatefulLinear(2, 1),
        )
        node_state = bst.graph.treefy_split(nodes)
        _, node_state = bst.graph.call(node_state)['b'].increment()

        nodes = bst.graph.treefy_merge(*node_state)

        self.assertEqual(nodes['a'].count.value, 0)
        self.assertEqual(nodes['b'].count.value, 1)

    def test_to_tree_simple(self):
        """``graph_to_tree`` / ``tree_to_graph`` round trip keeps aliasing."""
        m = bst.nn.Linear(2, 3, )
        impure_tree = (m, 1, {'b': m})

        pure_tree = bst.graph.graph_to_tree(impure_tree)

        t1 = pure_tree[0]
        t2 = pure_tree[2]['b']

        self.assertEqual(pure_tree[1], 1)
        # NOTE(review): each isinstance check appears twice (unittest form and
        # bare assert); presumably the bare assert is for static type
        # narrowing -- confirm before removing either.
        self.assertIsInstance(t1, bst.graph.NodeStates)
        assert isinstance(t1, bst.graph.NodeStates)
        self.assertIsInstance(t2, bst.graph.NodeStates)
        assert isinstance(t2, bst.graph.NodeStates)
        # The first occurrence carries the definition, the second a reference.
        self.assertIsInstance(t1.graphdef, bst.graph.NodeDef)
        self.assertIsInstance(t2.graphdef, bst.graph.NodeRef)
        self.assertLen(t1.states[0].to_flat(), 1)
        self.assertLen(t2.states[0].to_flat(), 0)

        impure_tree2 = bst.graph.tree_to_graph(pure_tree)

        m1_out = impure_tree2[0]
        m2_out = impure_tree2[2]['b']

        self.assertIs(m1_out, m2_out)
        self.assertEqual(impure_tree2[1], 1)

    def test_to_tree_consistent_prefix(self):
        """Different prefixes for one aliased node raise ``ValueError``."""
        m = bst.nn.Linear(2, 3, )
        impure_tree = (m, 1, {'b': m})

        # Same prefix for both occurrences of ``m``: expected not to raise.
        prefix = (0, None, 0)
        bst.graph.graph_to_tree(impure_tree, prefix=prefix)

        # Different prefixes for the two occurrences of ``m``: must raise.
        prefix = (0, None, 1)
        with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
            bst.graph.graph_to_tree(impure_tree, prefix=prefix)
566
+
567
+
568
class SimpleModule(bst.nn.Module):
    """Bare ``bst.nn.Module`` subclass that adds nothing of its own."""
570
+
571
+
572
class SimplePyTreeModule(bst.nn.Module):
    """Second bare ``bst.nn.Module`` subclass that adds nothing of its own."""
574
+
575
+
576
class TestThreading(parameterized.TestCase):
    """Check that ``bst.graph.treefy_split`` can run on a non-main thread."""

    @parameterized.parameters(
        (SimpleModule,),
        (SimplePyTreeModule,),
    )
    def test_threading(self, module_fn: Callable[[], bst.nn.Module]):
        """Split a module built on this thread from inside a worker thread.

        Args:
            module_fn: zero-argument callable returning the module to split.

        Raises:
            Exception: whatever ``treefy_split`` raised inside the worker.
        """
        x = module_fn()
        # An exception escaping ``Thread.run`` is reported by the threading
        # machinery but never reaches the thread that calls ``join``, so the
        # test would pass regardless. Capture it and re-raise it below.
        failures: list[Exception] = []

        class MyThread(Thread):

            def run(self) -> None:
                try:
                    bst.graph.treefy_split(x)
                except Exception as exc:
                    failures.append(exc)

        thread = MyThread()
        thread.start()
        thread.join()

        if failures:
            raise failures[0]
593
+
594
+
595
class TestGraphOperation(unittest.TestCase):
    """Tests for flatten / split / merge / states / pop_states on graph nodes."""

    def test1(self):
        """Flatten a node holding modules in attributes, a list and a dict.

        Only prints the resulting state tree; nothing is asserted.
        """

        class MyNode(bst.graph.Node):
            def __init__(self):
                self.a = bst.nn.Linear(2, 3)
                self.b = bst.nn.Linear(3, 2)
                self.c = [bst.nn.Linear(1, 2), bst.nn.Linear(1, 3)]
                self.d = {'x': bst.nn.Linear(1, 3), 'y': bst.nn.Linear(1, 4)}

        graphdef, statetree = bst.graph.flatten(MyNode())
        # print(graphdef)
        print(statetree)
        # print(bst.graph.unflatten(graphdef, statetree))

    def test_split(self):
        """Split a node into ``ParamState`` entries and everything else.

        Only prints the two groups and their shapes; nothing is asserted.
        """

        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.nn.Linear(2, 2)
                self.b = bst.nn.BatchNorm1d([10, 2])

        node = Foo()
        # ``...`` is passed as the second filter; the graph definition is not
        # needed here.
        _, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)

        print(params)
        print(jax.tree.map(jnp.shape, params))

        print(jax.tree.map(jnp.shape, others))

    def test_merge(self):
        """Merging both state groups rebuilds a node of the original types."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.nn.Linear(2, 2)
                self.b = bst.nn.BatchNorm1d([10, 2])

        node = Foo()
        graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)

        new_node = bst.graph.treefy_merge(graphdef, params, others)

        assert isinstance(new_node, Foo)
        assert isinstance(new_node.b, bst.nn.BatchNorm1d)
        assert isinstance(new_node.a, bst.nn.Linear)

    def test_update_states(self):
        """One gradient step with rate 0.1 must lower the squared error."""
        x = jnp.ones((1, 2))
        y = jnp.ones((1, 3))
        model = bst.nn.Linear(2, 3)

        def loss_fn(x, y):
            return jnp.mean((y - model(x)) ** 2)

        def sgd(ps, gs):
            # In-place update of the state's value: p <- p - 0.1 * g.
            updates = jax.tree.map(lambda p, g: p - 0.1 * g, ps.value, gs)
            ps.value = updates

        prev_loss = loss_fn(x, y)
        weights = model.states()
        grads = bst.augment.grad(loss_fn, weights)(x, y)
        for key, val in grads.items():
            sgd(weights[key], val)
        assert loss_fn(x, y) < prev_loss

    def test_pop_states(self):
        """States created under the ``'new'`` tag can be popped off the model."""

        class Model(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = bst.nn.Linear(2, 3)
                self.b = bst.nn.LIF([10, 2])

        model = Model()
        with bst.catch_new_states('new'):
            bst.nn.init_all_states(model)
        # print(model.states())
        # assertEqual reports the actual length when the check fails.
        self.assertEqual(len(model.states()), 2)
        model_states = bst.graph.pop_states(model, 'new')
        print(model_states)
        self.assertEqual(len(model.states()), 1)
        assert not hasattr(model.b, 'V')
        # print(model.states())

    def test_treefy_split(self):
        """Split a small MLP node and print the definition and states."""

        class MLP(bst.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = bst.nn.Linear(din, dmid)
                self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = bst.nn.Linear(dmid, dout)

            def __call__(self, x):
                x = bst.functional.relu(self.input(x))
                for layer in self.layers:
                    x = bst.functional.relu(layer(x))
                return self.output(x)

        model = MLP(2, 1, 3)
        graph_def, treefy_states = bst.graph.treefy_split(model)

        print(graph_def)
        print(treefy_states)

        # states = bst.graph.states(model)
        # print(states)
        # nest_states = states.to_nest()
        # print(nest_states)

    def test_states(self):
        """Collect states from an initialised MLP, flat, nested and filtered.

        Only prints the results; nothing is asserted.
        """

        class MLP(bst.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = bst.nn.Linear(din, dmid)
                self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = bst.nn.LIF(dout)

            def __call__(self, x):
                x = bst.functional.relu(self.input(x))
                for layer in self.layers:
                    x = bst.functional.relu(layer(x))
                return self.output(x)

        model = bst.nn.init_all_states(MLP(2, 1, 3))
        states = bst.graph.states(model)
        print(states)
        nest_states = states.to_nest()
        print(nest_states)

        params, others = bst.graph.states(model, bst.ParamState, bst.ShortTermState)
        print(params)
        print(others)
721
+
722
+
723
# Running this file directly hands control to the absl test runner.
if __name__ == '__main__':
    absltest.main()