brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,724 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import unittest
19
+ from collections.abc import Callable
20
+ from functools import partial
21
+ from threading import Thread
22
+ from typing import Any
23
+
24
+ import jax
25
+ import jax.numpy as jnp
26
+ import pytest
27
+ from absl.testing import absltest, parameterized
28
+
29
+ import brainstate as bst
30
+
31
+
32
class TestIter(unittest.TestCase):
    """Exercise ``bst.graph.iter_leaf`` and ``bst.graph.iter_node`` on small graphs."""

    def test1(self):
        """Smoke test: walk a nested model with several ``allowed_hierarchy`` settings.

        Nothing is asserted; the test only checks that iteration completes and
        prints each ``(path, node)`` pair.
        """

        class Model(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = bst.nn.Linear(1, 2)
                self.b = bst.nn.Linear(2, 3)
                self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
                self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
                self.b.a = bst.nn.LIF(2)

        def dump(pairs):
            # Print every (path, node) pair produced by the iterator.
            for path, node in pairs:
                print(path, node)

        # A fresh model is built for every traversal, as in the original test.
        dump(bst.graph.iter_leaf(Model()))
        dump(bst.graph.iter_node(Model()))
        dump(bst.graph.iter_node(Model(), allowed_hierarchy=(1, 1)))
        dump(bst.graph.iter_node(Model(), allowed_hierarchy=(2, 2)))

    def test_iter_leaf_v1(self):
        """``iter_leaf`` over a list holding the same module twice yields three leaves."""

        class Linear(bst.nn.Module):
            def __init__(self, din, dout):
                super().__init__()
                self.weight = bst.ParamState(bst.random.randn(din, dout))
                self.bias = bst.ParamState(bst.random.randn(dout))
                self.a = 1

        shared = Linear(3, 4)
        graph = [shared, shared]

        seen = []
        for path, value in bst.graph.iter_leaf(graph):
            print(path, type(value).__name__)
            seen.append(path)

        assert len(seen) == 3

    def test_iter_node_v1(self):
        """``iter_node`` over a list holding the same model twice yields eight nodes."""

        class Model(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = bst.nn.Linear(1, 2)
                self.b = bst.nn.Linear(2, 3)
                self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
                self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
                self.b.a = bst.nn.LIF(2)

        model = Model()

        seen = []
        for path, node in bst.graph.iter_node([model, model]):
            print(path, node.__class__.__name__)
            seen.append(path)

        assert len(seen) == 8
87
+
88
+
89
class List(bst.nn.Module):
    """Module that stores a copy of an iterable as a list and forwards indexing to it."""

    def __init__(self, items):
        super().__init__()
        # Copy into a new list so the module owns its container.
        self.items = [*items]

    def __getitem__(self, idx):
        """Return the element stored at ``idx``."""
        return self.items[idx]

    def __setitem__(self, idx, value):
        """Store ``value`` at ``idx``."""
        self.items[idx] = value
99
+
100
+
101
class Dict(bst.nn.Module):
    """Module that builds a dict from its arguments and forwards key access to it."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Accepts the same positional/keyword forms as the built-in ``dict``.
        self.items = dict(*args, **kwargs)

    def __getitem__(self, key):
        """Return the value stored under ``key``."""
        return self.items[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``."""
        self.items[key] = value
111
+
112
+
113
class StatefulLinear(bst.nn.Module):
    """Affine layer that also counts how many times it has been called or incremented.

    Attributes (all set in ``__init__``):
      w: ``ParamState`` created from ``bst.random.rand(din, dout)``.
      b: ``ParamState`` of zeros with shape ``(dout,)``.
      count: ``State`` holding a ``uint32`` scalar, initially 0.
    """

    def __init__(self, din, dout):
        super().__init__()
        self.w = bst.ParamState(bst.random.rand(din, dout))
        self.b = bst.ParamState(jnp.zeros((dout,)))
        self.count = bst.State(jnp.array(0, dtype=jnp.uint32))

    def increment(self):
        """Add one to the counter."""
        self.count.value += 1

    def __call__(self, x):
        """Bump the counter, then return ``x @ w + b``."""
        self.increment()
        weight, bias = self.w.value, self.b.value
        return x @ weight + bias
126
+
127
+
128
class TestGraphUtils(absltest.TestCase):
    """Tests for ``bst.graph`` flatten/unflatten, treefy split/merge and graph<->tree conversion."""

    def test_flatten_treey_state(self):
        """Flatten with ``treefy_state=True``: leaves are ``TreefyState`` and two refs are indexed."""
        shared = {'a': 1, 'b': bst.ParamState(2)}
        graph = [shared, 3, shared, bst.ParamState(4)]

        ref_index = bst.graph.RefMap()
        _, states = bst.graph.flatten(graph, ref_index=ref_index, treefy_state=True)

        states[0]['b'].value = 2
        states[3].value = 4

        assert isinstance(states[0]['b'], bst.TreefyState)
        assert isinstance(states[3], bst.TreefyState)
        assert isinstance(states, bst.util.NestedDict)
        # The reference map must hold exactly two entries, including both ParamStates.
        assert len(ref_index) == 2
        assert shared['b'] in ref_index
        assert graph[3] in ref_index

    def test_flatten(self):
        """Flatten with ``treefy_state=False``: leaves are ``State`` objects."""
        shared = {'a': 1, 'b': bst.ParamState(2)}
        graph = [shared, 3, shared, bst.ParamState(4)]

        ref_index = bst.graph.RefMap()
        _, states = bst.graph.flatten(graph, ref_index=ref_index, treefy_state=False)

        states[0]['b'].value = 2
        states[3].value = 4

        assert isinstance(states[0]['b'], bst.State)
        assert isinstance(states[3], bst.State)
        assert len(ref_index) == 2
        assert shared['b'] in ref_index
        assert graph[3] in ref_index

    def test_unflatten_treey_state(self):
        """Round trip with ``treefy_state=True`` keeps sharing but creates new state objects."""
        shared = bst.graph.Dict(a=1, b=bst.ParamState(2))
        original = bst.graph.List([shared, 3, shared, bst.ParamState(4)])

        graphdef, references = bst.graph.flatten(original, treefy_state=True)
        rebuilt = bst.graph.unflatten(graphdef, references)

        print(graphdef)
        print(references)
        assert rebuilt[0] is rebuilt[2]
        assert original[3] is not rebuilt[3]
        assert original[0]['b'] is not rebuilt[0]['b']

    def test_unflatten(self):
        """Round trip with ``treefy_state=False`` keeps sharing and reuses the same state objects."""
        shared = bst.graph.Dict(a=1, b=bst.ParamState(2))
        original = bst.graph.List([shared, 3, shared, bst.ParamState(4)])

        graphdef, references = bst.graph.flatten(original, treefy_state=False)
        rebuilt = bst.graph.unflatten(graphdef, references)

        print(graphdef)
        print(references)
        assert rebuilt[0] is rebuilt[2]
        assert original[3] is rebuilt[3]
        assert original[0]['b'] is rebuilt[0]['b']

    def test_unflatten_pytree(self):
        """Plain dicts in a list are not shared after ``treefy_split`` / ``treefy_merge``."""
        shared = {'a': 1, 'b': bst.ParamState(2)}
        graph = [shared, 3, shared, bst.ParamState(4)]

        graphdef, references = bst.graph.treefy_split(graph)
        merged = bst.graph.treefy_merge(graphdef, references)

        assert merged[0] is not merged[2]

    def test_unflatten_empty(self):
        """Unflattening with an empty ``NestedDict`` raises ``ValueError`` matching 'Expected key'."""
        shared = Dict({'a': 1, 'b': bst.ParamState(2)})
        graph = List([shared, 3, shared, bst.ParamState(4)])

        graphdef, _ = bst.graph.treefy_split(graph)

        with self.assertRaisesRegex(ValueError, 'Expected key'):
            bst.graph.unflatten(graphdef, bst.util.NestedDict({}))

    def test_module_list(self):
        """Splitting a list of modules exposes each module's state shapes by index."""
        modules = [
            bst.nn.Linear(2, 2),
            bst.nn.BatchNorm1d([10, 2]),
        ]
        _, statetree = bst.graph.treefy_split(modules)

        linear_states, norm_states = statetree[0], statetree[1]
        assert linear_states['weight'].value['weight'].shape == (2, 2)
        assert linear_states['weight'].value['bias'].shape == (2,)
        assert norm_states['weight'].value['scale'].shape == (1, 2,)
        assert norm_states['weight'].value['bias'].shape == (1, 2,)
        assert norm_states['running_mean'].value.shape == (1, 2,)
        assert norm_states['running_var'].value.shape == (1, 2)

    def test_shared_variables(self):
        """A state listed twice is stored once and is shared again after merging."""
        v = bst.ParamState(1)
        graph = [v, v]

        graphdef, statetree = bst.graph.treefy_split(graph)
        assert len(statetree.to_flat()) == 1

        merged = bst.graph.treefy_merge(graphdef, statetree)
        assert merged[0] is merged[1]

    def test_tied_weights(self):
        """Two layers with tied weights yield one flat state and stay tied after merging."""

        class Foo(bst.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = bst.nn.Linear(2, 2)
                self.baz = bst.nn.Linear(2, 2)

                # tie the weights
                self.baz.weight = self.bar.weight

        node = Foo()
        graphdef, state = bst.graph.treefy_split(node)

        assert len(state.to_flat()) == 1

        rebuilt = bst.graph.treefy_merge(graphdef, state)

        assert rebuilt.bar.weight is rebuilt.baz.weight

    def test_tied_weights_example(self):
        """Embedding weight tied to an output projection: one flat state, output shape (2, 10)."""

        class LinearTranspose(bst.nn.Module):
            def __init__(self, dout: int, din: int, ) -> None:
                super().__init__()
                self.kernel = bst.ParamState(bst.init.LecunNormal()((dout, din)))

            def __call__(self, x):
                return x @ self.kernel.value.T

        class Encoder(bst.nn.Module):
            def __init__(self, ) -> None:
                super().__init__()
                self.embed = bst.nn.Embedding(10, 2)
                self.linear_out = LinearTranspose(10, 2)

                # tie the weights
                self.linear_out.kernel = self.embed.weight

            def __call__(self, x):
                embedded = self.embed(x)
                return self.linear_out(embedded)

        model = Encoder()
        _, state = bst.graph.treefy_split(model)

        assert len(state.to_flat()) == 1

        tokens = jax.random.randint(jax.random.key(0), (2,), 0, 10)
        out = model(tokens)

        assert out.shape == (2, 10)

    def test_state_variables_not_shared_with_graph(self):
        """The split state tree holds copies: same values, different objects from the node's states."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.ParamState(1)

        m = Foo()
        graphdef, statetree = bst.graph.treefy_split(m)

        assert isinstance(m.a, bst.ParamState)
        assert issubclass(statetree.a.type, bst.ParamState)
        assert m.a is not statetree.a
        assert m.a.value == statetree.a.value

        m2 = bst.graph.treefy_merge(graphdef, statetree)

        assert isinstance(m2.a, bst.ParamState)
        assert issubclass(statetree.a.type, bst.ParamState)
        assert m2.a is not statetree.a
        assert m2.a.value == statetree.a.value

    def test_shared_state_variables_not_shared_with_graph(self):
        """A state bound to two attributes is stored once (under ``a``) and re-shared on merge."""

        class Foo(bst.graph.Node):
            def __init__(self):
                p = bst.ParamState(1)
                self.a = p
                self.b = p

        m = Foo()
        graphdef, state = bst.graph.treefy_split(m)

        assert isinstance(m.a, bst.ParamState)
        assert isinstance(m.b, bst.ParamState)
        assert issubclass(state.a.type, bst.ParamState)
        assert 'b' not in state
        assert m.a is not state.a
        assert m.b is not state.a
        assert m.a.value == state.a.value
        assert m.b.value == state.a.value

        m2 = bst.graph.treefy_merge(graphdef, state)

        assert isinstance(m2.a, bst.ParamState)
        assert isinstance(m2.b, bst.ParamState)
        assert issubclass(state.a.type, bst.ParamState)
        assert m2.a is not state.a
        assert m2.b is not state.a
        assert m2.a.value == state.a.value
        assert m2.b.value == state.a.value
        assert m2.a is m2.b

    def test_pytree_node(self):
        """A ``bst.util.dataclass`` attribute is split as a pytree subgraph and rebuilt as a new object."""

        @bst.util.dataclass
        class Tree:
            a: bst.ParamState
            b: str = bst.util.field(pytree_node=False)

        class Foo(bst.graph.Node):
            def __init__(self):
                self.tree = Tree(bst.ParamState(1), 'a')

        m = Foo()

        graphdef, state = bst.graph.treefy_split(m)

        assert 'tree' in state
        assert 'a' in state.tree
        assert graphdef.subgraphs['tree'].type.__name__ == 'PytreeType'

        m2 = bst.graph.treefy_merge(graphdef, state)

        assert isinstance(m2.tree, Tree)
        assert m2.tree.a.value == 1
        assert m2.tree.b == 'a'
        assert m2.tree.a is not m.tree.a
        assert m2.tree is not m.tree

    @pytest.mark.skip(reason='Not implemented')
    def test_cached_unflatten(self):
        """Skipped: cached unflatten should return the same node with ``a`` and ``b`` swapped."""

        class Foo(bst.graph.Node):
            def __init__(self, ):
                self.a = bst.nn.Linear(2, 2)
                self.b = bst.nn.BatchNorm1d([10, 2])

        def f(m: Foo):
            m.a, m.b = m.b, m.a  # type: ignore

        m = Foo()
        a = m.a
        b = m.b

        ref_out_idx_out = bst.graph.RefMap()
        graphdef: bst.graph.GraphDef[Foo]
        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)

        @partial(jax.jit, static_argnums=(0,))
        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
            idx_out_ref_in: dict[int, Any] = {}
            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
            f(m)
            ref_in_idx_in = bst.graph.RefMap[Any, int]()
            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
            return state, static_out

        static_out: bst.graph.Static
        state, static_out = f_pure(graphdef, state)
        idx_out_idx_in: dict[int, int]
        graphdef, idx_out_idx_in = static_out.value
        idx_in_ref_out = bst.graph.compose_mapping_reversed(
            ref_out_idx_out, idx_out_idx_in
        )
        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
        assert m2 is m
        assert m2.a is b
        assert m2.b is a

    @pytest.mark.skip(reason='Not implemented')
    def test_cached_unflatten_swap_variables(self):
        """Skipped: same as ``test_cached_unflatten`` but swapping two ``ParamState`` attributes."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.ParamState(1)
                self.b = bst.ParamState(2)

        def f(m: Foo):
            m.a, m.b = m.b, m.a

        m = Foo()
        a = m.a
        b = m.b

        ref_out_idx_out = bst.graph.RefMap[Any, int]()
        graphdef: bst.graph.GraphDef[Foo]
        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)

        @partial(jax.jit, static_argnums=(0,))
        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
            idx_out_ref_in: dict[int, Any] = {}
            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
            f(m)
            ref_in_idx_in = bst.graph.RefMap[Any, int]()
            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
            return state, static_out

        static_out: bst.graph.Static
        state, static_out = f_pure(graphdef, state)
        idx_out_idx_in: dict[int, int]
        graphdef, idx_out_idx_in = static_out.value
        idx_in_ref_out = bst.graph.compose_mapping_reversed(
            ref_out_idx_out, idx_out_idx_in
        )
        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
        assert m2 is m
        assert m2.a is b
        assert m2.b is a

    @pytest.mark.skip(reason='Not implemented')
    def test_cached_unflatten_add_self_reference(self):
        """Skipped: a self-reference added inside the jitted function should survive cached unflatten."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.ref = None

        def f(m: Foo):
            m.ref = m

        m = Foo()

        ref_out_idx_out = bst.graph.RefMap()
        graphdef: bst.graph.GraphDef[Foo]
        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)

        @partial(jax.jit, static_argnums=(0,))
        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
            idx_out_ref_in: dict[int, Any] = {}
            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
            f(m)
            ref_in_idx_in = bst.graph.RefMap[Any, int]()
            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
            return state, static_out

        static_out: bst.graph.Static
        state, static_out = f_pure(graphdef, state)
        idx_out_idx_in: dict[int, int]
        graphdef, idx_out_idx_in = static_out.value
        idx_in_ref_out = bst.graph.compose_mapping_reversed(
            ref_out_idx_out, idx_out_idx_in
        )
        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
        assert m2 is m
        assert m2.ref is m2

    def test_call_jit_update(self):
        """Calling a method through ``bst.graph.call`` under ``jax.jit`` twice leaves the count at 2."""

        class Counter(bst.graph.Node):
            def __init__(self):
                self.count = bst.ParamState(jnp.zeros(()))

            def inc(self):
                self.count.value += 1
                return 1

        graph_state = bst.graph.treefy_split(Counter())

        @jax.jit
        def update(graph_state):
            out, graph_state = bst.graph.call(graph_state).inc()
            self.assertEqual(out, 1)
            return graph_state

        # Apply the jitted update twice, threading the split state through.
        for _ in range(2):
            graph_state = update(graph_state)

        counter = bst.graph.treefy_merge(*graph_state)

        self.assertEqual(counter.count.value, 2)

    def test_stateful_linear(self):
        """Two jitted forward passes on the split state leave the original module's count at 0."""
        linear = StatefulLinear(3, 2)
        linear_state = bst.graph.treefy_split(linear)

        @jax.jit
        def forward(x, pure_linear):
            y, pure_linear = bst.graph.call(pure_linear)(x)
            return y, pure_linear

        x = jnp.ones((1, 3))
        for _ in range(2):
            y, linear_state = forward(x, linear_state)

        # The original module is untouched; only the merged copy has counted the calls.
        self.assertEqual(linear.count.value, 0)
        new_linear = bst.graph.treefy_merge(*linear_state)
        self.assertEqual(new_linear.count.value, 2)

    def test_getitem(self):
        """``bst.graph.call(...)['b']`` calls the method on the node stored under key ``'b'`` only."""
        nodes = {
            'a': StatefulLinear(3, 2),
            'b': StatefulLinear(2, 1),
        }
        node_state = bst.graph.treefy_split(nodes)
        _, node_state = bst.graph.call(node_state)['b'].increment()

        nodes = bst.graph.treefy_merge(*node_state)

        self.assertEqual(nodes['a'].count.value, 0)
        self.assertEqual(nodes['b'].count.value, 1)

    def test_to_tree_simple(self):
        """``graph_to_tree`` emits a ``NodeDef`` then a ``NodeRef`` for a repeated module; the round trip restores sharing."""
        m = bst.nn.Linear(2, 3, )
        impure_tree = (m, 1, {'b': m})

        pure_tree = bst.graph.graph_to_tree(impure_tree)

        t1 = pure_tree[0]
        t2 = pure_tree[2]['b']

        self.assertEqual(pure_tree[1], 1)
        self.assertIsInstance(t1, bst.graph.NodeStates)
        # NOTE(review): the bare asserts repeat the checks above, presumably for static type narrowing.
        assert isinstance(t1, bst.graph.NodeStates)
        self.assertIsInstance(t2, bst.graph.NodeStates)
        assert isinstance(t2, bst.graph.NodeStates)
        self.assertIsInstance(t1.graphdef, bst.graph.NodeDef)
        self.assertIsInstance(t2.graphdef, bst.graph.NodeRef)
        self.assertLen(t1.states[0].to_flat(), 1)
        self.assertLen(t2.states[0].to_flat(), 0)

        impure_tree2 = bst.graph.tree_to_graph(pure_tree)

        m1_out = impure_tree2[0]
        m2_out = impure_tree2[2]['b']

        self.assertIs(m1_out, m2_out)
        self.assertEqual(impure_tree2[1], 1)

    def test_to_tree_consistent_prefix(self):
        """A consistent prefix is accepted; a prefix that differs for the same module raises ``ValueError``."""
        m = bst.nn.Linear(2, 3, )
        impure_tree = (m, 1, {'b': m})

        # Both occurrences of ``m`` get prefix 0: this call must not raise.
        prefix = (0, None, 0)
        pure_tree = bst.graph.graph_to_tree(impure_tree, prefix=prefix)

        # The two occurrences of ``m`` now get different prefixes (0 and 1).
        prefix = (0, None, 1)
        with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
            bst.graph.graph_to_tree(impure_tree, prefix=prefix)
566
+
567
+
568
class SimpleModule(bst.nn.Module):
    """Empty ``bst.nn.Module`` subclass used as a parameter of the threading test."""
570
+
571
+
572
class SimplePyTreeModule(bst.nn.Module):
    """Second empty ``bst.nn.Module`` subclass used as a parameter of the threading test."""
574
+
575
+
576
class TestThreading(parameterized.TestCase):
    """Check that ``bst.graph.treefy_split`` can be called from a non-main thread."""

    @parameterized.parameters(
        (SimpleModule,),
        (SimplePyTreeModule,),
    )
    def test_threading(self, module_fn: Callable[[], bst.nn.Module]):
        """Split a module inside a worker thread and fail if the worker raised.

        Args:
          module_fn: zero-argument factory returning the module to split.

        Raises:
          Exception: whatever ``treefy_split`` raised inside the worker thread.
        """
        x = module_fn()
        errors: list[Exception] = []

        class MyThread(Thread):

            def run(self) -> None:
                # An exception escaping ``run`` is not raised in the joining
                # thread, so without this capture the test could never fail.
                try:
                    bst.graph.treefy_split(x)
                except Exception as exc:
                    errors.append(exc)

        thread = MyThread()
        thread.start()
        thread.join()

        # Re-raise the worker's failure in the test thread.
        if errors:
            raise errors[0]
593
+
594
+
595
class TestGraphOperation(unittest.TestCase):
    """Tests for flatten, split/merge, state popping and state collection in ``bst.graph``."""

    def test1(self):
        """Smoke test: flatten a node holding modules in attributes, a list and a dict."""

        class MyNode(bst.graph.Node):
            def __init__(self):
                self.a = bst.nn.Linear(2, 3)
                self.b = bst.nn.Linear(3, 2)
                self.c = [bst.nn.Linear(1, 2), bst.nn.Linear(1, 3)]
                self.d = {'x': bst.nn.Linear(1, 3), 'y': bst.nn.Linear(1, 4)}

        graphdef, statetree = bst.graph.flatten(MyNode())
        print(statetree)

    def test_split(self):
        """Smoke test: split a node into ``ParamState`` states and the remaining states."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.nn.Linear(2, 2)
                self.b = bst.nn.BatchNorm1d([10, 2])

        node = Foo()
        graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)

        print(params)
        print(jax.tree.map(jnp.shape, params))

        print(jax.tree.map(jnp.shape, others))

    def test_merge(self):
        """Merging the split pieces rebuilds a node with the original attribute types."""

        class Foo(bst.graph.Node):
            def __init__(self):
                self.a = bst.nn.Linear(2, 2)
                self.b = bst.nn.BatchNorm1d([10, 2])

        node = Foo()
        graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)

        new_node = bst.graph.treefy_merge(graphdef, params, others)

        self.assertIsInstance(new_node, Foo)
        self.assertIsInstance(new_node.b, bst.nn.BatchNorm1d)
        self.assertIsInstance(new_node.a, bst.nn.Linear)

    def test_update_states(self):
        """One manual SGD step on a linear model's states lowers the squared-error loss."""
        x = jnp.ones((1, 2))
        y = jnp.ones((1, 3))
        model = bst.nn.Linear(2, 3)

        def loss_fn(x, y):
            return jnp.mean((y - model(x)) ** 2)

        def sgd(ps, gs):
            # In-place update of the state's value with a fixed step of 0.1.
            updates = jax.tree.map(lambda p, g: p - 0.1 * g, ps.value, gs)
            ps.value = updates

        prev_loss = loss_fn(x, y)
        weights = model.states()
        grads = bst.augment.grad(loss_fn, weights)(x, y)
        for key, val in grads.items():
            sgd(weights[key], val)
        self.assertLess(loss_fn(x, y), prev_loss)

    def test_pop_states(self):
        """States created under the ``'new'`` tag are removed from the model by ``pop_states``."""

        class Model(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = bst.nn.Linear(2, 3)
                self.b = bst.nn.LIF([10, 2])

        model = Model()
        with bst.catch_new_states('new'):
            bst.nn.init_all_states(model)
        self.assertEqual(len(model.states()), 2)
        model_states = bst.graph.pop_states(model, 'new')
        print(model_states)
        self.assertEqual(len(model.states()), 1)
        self.assertFalse(hasattr(model.b, 'V'))

    def test_treefy_split(self):
        """Smoke test: ``treefy_split`` on an MLP whose hidden layers live in a list."""

        class MLP(bst.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = bst.nn.Linear(din, dmid)
                self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = bst.nn.Linear(dmid, dout)

            def __call__(self, x):
                x = bst.functional.relu(self.input(x))
                for layer in self.layers:
                    x = bst.functional.relu(layer(x))
                return self.output(x)

        model = MLP(2, 1, 3)
        graph_def, treefy_states = bst.graph.treefy_split(model)

        print(graph_def)
        print(treefy_states)

    def test_states(self):
        """Smoke test: collect all states, nest them, and split them by two state types."""

        class MLP(bst.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = bst.nn.Linear(din, dmid)
                self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = bst.nn.LIF(dout)

            def __call__(self, x):
                x = bst.functional.relu(self.input(x))
                for layer in self.layers:
                    x = bst.functional.relu(layer(x))
                return self.output(x)

        model = bst.nn.init_all_states(MLP(2, 1, 3))
        states = bst.graph.states(model)
        print(states)
        nest_states = states.to_nest()
        print(nest_states)

        params, others = bst.graph.states(model, bst.ParamState, bst.ShortTermState)
        print(params)
        print(others)
721
+
722
+
723
# Script entry point: hand control to absl's test runner.
if __name__ == '__main__':
    absltest.main()