brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,194 +1,104 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
1
|
import unittest
|
18
2
|
|
19
3
|
import jax
|
20
4
|
import jax.numpy as jnp
|
21
|
-
from jax import vmap
|
22
|
-
from jax.lax import psum, pmean, pmax
|
23
|
-
|
24
|
-
import brainstate
|
25
|
-
import brainstate.transform
|
26
|
-
from brainstate._error import BatchAxisError
|
27
5
|
|
6
|
+
import brainstate as bst
|
7
|
+
from brainstate.transform import StatefulMapping, vmap, vmap_new_states, pmap, map as bst_map
|
8
|
+
from brainstate.util import filter as state_filter
|
28
9
|
|
29
10
|
|
30
11
|
class TestMap(unittest.TestCase):
|
31
|
-
def
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
'max': max_val,
|
63
|
-
'original': x
|
64
|
-
}
|
65
|
-
|
66
|
-
batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
67
|
-
print("Input batch data:", batch_data)
|
68
|
-
|
69
|
-
# vmap with axis name 'batch'
|
70
|
-
vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
|
71
|
-
result_jax = vectorized_stats_jax(batch_data)
|
72
|
-
|
73
|
-
# vmap with axis name 'batch'
|
74
|
-
vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
|
75
|
-
result = vectorized_stats(batch_data)
|
76
|
-
|
77
|
-
# vmap with axis name 'batch'
|
78
|
-
vectorized_stats_v2 = brainstate.transform.jit(
|
79
|
-
brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
|
12
|
+
def test_map_matches_vectorized(self):
|
13
|
+
xs = jnp.arange(6.0).reshape(6, 1)
|
14
|
+
|
15
|
+
def fn(x):
|
16
|
+
return x + 1.0
|
17
|
+
|
18
|
+
expected = jax.vmap(fn)(xs)
|
19
|
+
result = bst_map(fn, xs)
|
20
|
+
self.assertTrue(jnp.allclose(result, expected))
|
21
|
+
|
22
|
+
def test_map_multiple_inputs_and_batch_size(self):
|
23
|
+
xs = jnp.arange(5.0)
|
24
|
+
ys = jnp.ones_like(xs) * 2.0
|
25
|
+
|
26
|
+
def fn(a, b):
|
27
|
+
return a * a + b
|
28
|
+
|
29
|
+
expected = jax.vmap(fn)(xs, ys)
|
30
|
+
result = bst_map(fn, xs, ys, batch_size=2)
|
31
|
+
self.assertTrue(jnp.allclose(result, expected))
|
32
|
+
|
33
|
+
|
34
|
+
class TestVmapIntegration(unittest.TestCase):
|
35
|
+
def test_decorator_batched_stateful_function(self):
|
36
|
+
counter = bst.ShortTermState(jnp.zeros(3))
|
37
|
+
|
38
|
+
@vmap(
|
39
|
+
in_axes=0,
|
40
|
+
out_axes=0,
|
41
|
+
state_in_axes={0: state_filter.OfType(bst.ShortTermState)},
|
42
|
+
state_out_axes={0: state_filter.OfType(bst.ShortTermState)},
|
80
43
|
)
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
return
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
print("Parameters:", params)
|
141
|
-
print("Batch X:", batch_x)
|
142
|
-
print("Batch Y:", batch_y)
|
143
|
-
|
144
|
-
# Compute individual gradients first
|
145
|
-
grad_fn = jax.grad(loss_function, argnums=0)
|
146
|
-
individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
147
|
-
print("Individual gradients:", individual_grads)
|
148
|
-
|
149
|
-
# Now compute averaged gradients using axis names
|
150
|
-
averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
|
151
|
-
print("Averaged gradients:", averaged_grads)
|
152
|
-
|
153
|
-
return individual_grads, averaged_grads
|
154
|
-
|
155
|
-
def _gradient_averaging_simulation_jax(self):
|
156
|
-
def loss_function(params, x, y):
|
157
|
-
"""Simple quadratic loss"""
|
158
|
-
pred = params * x
|
159
|
-
return (pred - y) ** 2
|
160
|
-
|
161
|
-
def compute_gradients_with_averaging(params, batch_x, batch_y):
|
162
|
-
"""Compute gradients and average them across the batch"""
|
163
|
-
# Compute per-sample gradients
|
164
|
-
grad_fn = jax.grad(loss_function, argnums=0)
|
165
|
-
per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
166
|
-
|
167
|
-
# Average gradients across batch using named axis
|
168
|
-
def average_grads(grads):
|
169
|
-
return pmean(grads, axis_name='batch')
|
170
|
-
|
171
|
-
# Apply averaging with named axis
|
172
|
-
averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
|
173
|
-
return averaged_grads
|
174
|
-
|
175
|
-
# Example data
|
176
|
-
params = 2.0
|
177
|
-
batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
|
178
|
-
batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
|
179
|
-
|
180
|
-
print("Parameters:", params)
|
181
|
-
print("Batch X:", batch_x)
|
182
|
-
print("Batch Y:", batch_y)
|
183
|
-
|
184
|
-
# Compute individual gradients first
|
185
|
-
grad_fn = jax.grad(loss_function, argnums=0)
|
186
|
-
individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
|
187
|
-
print("Individual gradients:", individual_grads)
|
188
|
-
|
189
|
-
# Now compute averaged gradients using axis names
|
190
|
-
averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
|
191
|
-
print("Averaged gradients:", averaged_grads)
|
192
|
-
|
193
|
-
return individual_grads, averaged_grads
|
44
|
+
def accumulate(x):
|
45
|
+
counter.value = counter.value + x
|
46
|
+
return counter.value
|
47
|
+
|
48
|
+
xs = jnp.asarray([1.0, 2.0, 3.0])
|
49
|
+
result = accumulate(xs)
|
50
|
+
self.assertTrue(jnp.allclose(result, xs))
|
51
|
+
self.assertTrue(jnp.allclose(counter.value, xs))
|
52
|
+
|
53
|
+
def test_vmap_partial_returns_stateful_mapping(self):
|
54
|
+
builder = vmap(in_axes=0, out_axes=0)
|
55
|
+
|
56
|
+
def fn(x):
|
57
|
+
return x * 2.0
|
58
|
+
|
59
|
+
mapped = builder(fn)
|
60
|
+
self.assertIsInstance(mapped, StatefulMapping)
|
61
|
+
xs = jnp.arange(3.0)
|
62
|
+
self.assertTrue(jnp.allclose(mapped(xs), xs * 2.0))
|
63
|
+
|
64
|
+
|
65
|
+
class TestVmapNewStates(unittest.TestCase):
|
66
|
+
def test_new_states_are_vectorized(self):
|
67
|
+
@vmap_new_states(in_axes=0, out_axes=0)
|
68
|
+
def build(x):
|
69
|
+
scratch = bst.ShortTermState(jnp.array(0.0), tag='scratch')
|
70
|
+
scratch.value = scratch.value + x
|
71
|
+
return scratch.value
|
72
|
+
|
73
|
+
xs = jnp.arange(4.0)
|
74
|
+
result_first = build(xs)
|
75
|
+
result_second = build(xs)
|
76
|
+
self.assertTrue(jnp.allclose(result_first, xs))
|
77
|
+
self.assertTrue(jnp.allclose(result_second, xs))
|
78
|
+
|
79
|
+
|
80
|
+
class TestPmapIntegration(unittest.TestCase):
|
81
|
+
@unittest.skipIf(jax.local_device_count() < 2, "Requires at least 2 devices")
|
82
|
+
def test_pmap_stateful_execution(self):
|
83
|
+
param = bst.ParamState(jnp.ones((4,)))
|
84
|
+
|
85
|
+
@pmap(
|
86
|
+
in_axes=0,
|
87
|
+
out_axes=0,
|
88
|
+
axis_name='devices',
|
89
|
+
state_in_axes={0: state_filter.OfType(bst.ParamState)},
|
90
|
+
state_out_axes={0: state_filter.OfType(bst.ParamState)},
|
91
|
+
)
|
92
|
+
def update(delta):
|
93
|
+
param.value = param.value + delta
|
94
|
+
return param.value
|
95
|
+
|
96
|
+
device_count = jax.local_device_count()
|
97
|
+
deltas = jnp.arange(device_count * 4.0, dtype=param.value.dtype).reshape(device_count, 4)
|
98
|
+
updated = update(deltas)
|
99
|
+
self.assertEqual(updated.shape, (device_count, 4))
|
100
|
+
self.assertTrue(jnp.all(updated >= 1.0))
|
101
|
+
|
194
102
|
|
103
|
+
if __name__ == "__main__":
|
104
|
+
unittest.main()
|
@@ -1,55 +1,55 @@
|
|
1
|
-
brainstate/__init__.py,sha256=
|
1
|
+
brainstate/__init__.py,sha256=XZZctdLIf_6-IJIbeTyoEHR6r3pyKF2n6MfikCffASA,5915
|
2
2
|
brainstate/_compatible_import.py,sha256=7thV_2F0FD5AF2DETjBfmtNb_2ZQzki8NxFgC62frg0,11037
|
3
3
|
brainstate/_compatible_import_test.py,sha256=6ka26Sa_Kk6F-Ar1HR6UaKJTHquXcUCWglgXBUOovcg,22762
|
4
4
|
brainstate/_deprecation.py,sha256=gSh36_TWLBgQAo0gNfOzscV9ssa26k3te9y25BG6O2w,8381
|
5
|
-
brainstate/_deprecation_test.py,sha256=
|
5
|
+
brainstate/_deprecation_test.py,sha256=vnmcZ7l_FmQRgsMYUA9wMiWL1ltPc01plUTK-pWKCzk,87550
|
6
6
|
brainstate/_error.py,sha256=6A5ILy17ZMMZIjS8LkajTZBDRnwv_Qait5x__h2Levo,1522
|
7
|
-
brainstate/_state.py,sha256=
|
8
|
-
brainstate/_state_test.py,sha256=
|
7
|
+
brainstate/_state.py,sha256=G5SCjxi42Lb1EoqaIZcyy0fmjNonwrQLLb4ujKt_99E,77867
|
8
|
+
brainstate/_state_test.py,sha256=ptul4-BHagOIhnAvnWmb2BFyXnIfWpK0rQClLxhzLBY,39895
|
9
9
|
brainstate/_utils.py,sha256=cmUyO9ds1etrrpV4ucp1G8mDqE15g4ZtbivblH_cD9o,1613
|
10
10
|
brainstate/environ.py,sha256=BmQsvo1aZaMpckHXlJ45dZh9DUdnHHP9Q9JdLNCa9wA,42169
|
11
11
|
brainstate/environ_test.py,sha256=RdVmeP7irbk3_qNjwWoy-DSdpkTRpxqAABtNGZcbB2w,42418
|
12
|
-
brainstate/mixin.py,sha256=
|
12
|
+
brainstate/mixin.py,sha256=QoZPhBs4hGXTzTfXaF0M8XhsbcbOtT4Ylm54D7pCOSA,45124
|
13
13
|
brainstate/mixin_test.py,sha256=6WmqJf34kT3Z5WaiCNDo3OV3ci0DIHCQp474zECDxEU,34718
|
14
14
|
brainstate/typing.py,sha256=pYiNI-9oHpH7HfjRKYxugK03KGiamCwweagMyO0rsi4,26301
|
15
15
|
brainstate/typing_test.py,sha256=2mmMW0uAzIo3_VXpT5Boq79BohxYkzBlHexBysFUGII,26240
|
16
16
|
brainstate/graph/__init__.py,sha256=kGVtHAnkiWR5MqDYQU0G3AobWnioGeDqjILA--RyDz8,846
|
17
17
|
brainstate/graph/_node.py,sha256=_XH8xx6_glsCK4KCsQnarACK8meyhCdfh3nWfUDko0k,6407
|
18
18
|
brainstate/graph/_node_test.py,sha256=sD2DS0AhDKOU5ZQm0cYz0llnJ6D60ftNfDpztk4i8cM,18687
|
19
|
-
brainstate/graph/_operation.py,sha256=
|
19
|
+
brainstate/graph/_operation.py,sha256=5EN_qL08M_BIV-L3YkVP3OwlsDuNKmVV5NaCLvrl7kw,53971
|
20
20
|
brainstate/graph/_operation_test.py,sha256=IVyrJh4io3sDgtrTEIAItGzNs2XEf7rO1rvI1r_KiII,39119
|
21
|
-
brainstate/nn/__init__.py,sha256=
|
21
|
+
brainstate/nn/__init__.py,sha256=JcaVVt01LNbAwxHymNmBUFAZwchtucyHq_sT70rvYc4,5245
|
22
22
|
brainstate/nn/_activations.py,sha256=6jHR67obYR1lpo-imVXmfd3m_NDyU0XZb8t-pVYDvUU,26917
|
23
23
|
brainstate/nn/_activations_test.py,sha256=Ikr8RYBaIpApVKUhY-XAWr6llEG7vYWS8YuqDHyTtBY,13438
|
24
|
-
brainstate/nn/_collective_ops.py,sha256=
|
24
|
+
brainstate/nn/_collective_ops.py,sha256=AHda68XjoNUyLOwxRhwQLnXr4zLBqnB8UDLAkfx2eAg,21282
|
25
25
|
brainstate/nn/_collective_ops_test.py,sha256=8mKQkfTjfwuO7DA1i_Yr4QD1yg2ZJgVpX7zhHFo9CuQ,25600
|
26
26
|
brainstate/nn/_common.py,sha256=UyJMJoVF9KfrToOX5Dbv-2s3CD49SsroLRfL17DLl4Q,7184
|
27
|
-
brainstate/nn/_common_test.py,sha256=
|
27
|
+
brainstate/nn/_common_test.py,sha256=xztvEPYCurvY9LiaZOtHjWcFrwJaYknA8_RLy_XHI1Y,4895
|
28
28
|
brainstate/nn/_conv.py,sha256=3cGToc5UoGN5jp4BUHlrK6md8O_0IcWMCBXFHBQg7nE,82106
|
29
29
|
brainstate/nn/_conv_test.py,sha256=65FlrteUxLQb8ckUUJFhaPZDsQZTlGVAp0HUGUJtt1M,30173
|
30
|
-
brainstate/nn/_delay.py,sha256=
|
30
|
+
brainstate/nn/_delay.py,sha256=qssPPKgWeS-vcxKJG6ewn2bQ1KvdHHBrbI6bIJytSfI,22332
|
31
31
|
brainstate/nn/_delay_test.py,sha256=FzBb8vXfse8HEcEid83hpa6aag6oj90mtcHYsDs0DOE,10376
|
32
32
|
brainstate/nn/_dropout.py,sha256=UotjW0PQO4gypfhtSqzkR4UVGkY0kNBncEB_poGL7Sc,22555
|
33
|
-
brainstate/nn/_dropout_test.py,sha256=
|
34
|
-
brainstate/nn/_dynamics.py,sha256=
|
35
|
-
brainstate/nn/_dynamics_test.py,sha256=
|
33
|
+
brainstate/nn/_dropout_test.py,sha256=87ETtMZ7Aeaa0les7Z5ICO9KVKMewm_WWd-ITCFDo_w,21589
|
34
|
+
brainstate/nn/_dynamics.py,sha256=H83WUFWv6ZLBEn01U30qInjgfh8WAgZ5jwgXWA9TEkg,29161
|
35
|
+
brainstate/nn/_dynamics_test.py,sha256=XwTBCXXKN8hAaLFtQjmWEOf-ry0MJZmfgx9dAju-rXA,1744
|
36
36
|
brainstate/nn/_elementwise.py,sha256=4kKzrbKn5luwnpY8n7IeaMOtBVVie8oPHEc64hSn8-w,34858
|
37
37
|
brainstate/nn/_elementwise_test.py,sha256=sbWlUyTB8oiu3PRHObTvmUaob99EjIGYE6k6bEpg6K8,27296
|
38
|
-
brainstate/nn/_embedding.py,sha256=
|
38
|
+
brainstate/nn/_embedding.py,sha256=MXkGja6uCB2OsbwqArEQ8ez_GgwZ2s3DINaUnFyeZWk,14991
|
39
39
|
brainstate/nn/_embedding_test.py,sha256=Gc0y6gHMEagaDrBJoAYQZMDTd47TYQNVrenWUwLWK_w,6242
|
40
40
|
brainstate/nn/_event_fixedprob.py,sha256=ZEnIyjDksxtUWWG5GXLcF-RHR1S33DZTw-rN2lHKs0g,9395
|
41
41
|
brainstate/nn/_event_fixedprob_test.py,sha256=rvTKxEzKwvctQc8-AxXjJ4p4D-if1va2O8KqQYW5nxY,3836
|
42
42
|
brainstate/nn/_event_linear.py,sha256=d0J54Sf9zBl926BzXoy4Oc1p96h9veU3f8YhWrZLRPk,2554
|
43
43
|
brainstate/nn/_event_linear_test.py,sha256=qzcGplDIwxTnZOs4JzD5GX_oNBYtcYFNaf3GpWo8pZY,3765
|
44
|
-
brainstate/nn/_exp_euler.py,sha256=
|
44
|
+
brainstate/nn/_exp_euler.py,sha256=deKqWu5RvN_Jvj_QHg5j1xkKhlysBFhi3Qvo_Z-QNDo,8698
|
45
45
|
brainstate/nn/_exp_euler_test.py,sha256=21qomGOo96YLmlEQok2hCByAcRpmqREhSr3kxmhKOm8,13014
|
46
|
-
brainstate/nn/_linear.py,sha256=
|
46
|
+
brainstate/nn/_linear.py,sha256=LLWBE6eBXjpXUvx5mjdwOjFnolNRGBxNb_ZbhsYOu4M,24201
|
47
47
|
brainstate/nn/_linear_test.py,sha256=5fHx4v4_54dH3Bsyapl2cobvVtMEu3DFR-jDKLbkJFw,17876
|
48
48
|
brainstate/nn/_metrics.py,sha256=TgALwv6i9La4Dm1WAkWDWxvxr9rkd7CJLGFy2sOGQbQ,36481
|
49
49
|
brainstate/nn/_metrics_test.py,sha256=XZiRndchRgEH0X8zsHzg0fsHNMxj43mnd883QChfSik,24104
|
50
|
-
brainstate/nn/_module.py,sha256=
|
50
|
+
brainstate/nn/_module.py,sha256=W9iXXmLxaX_QyNxukTxajHJfah7s9GqJXZRHQhZqCfU,13076
|
51
51
|
brainstate/nn/_module_test.py,sha256=znjB7FU5evJENQ1Pqw7ZlOGC5faQe4-4VpjW60H8UWI,1414
|
52
|
-
brainstate/nn/_normalizations.py,sha256=
|
52
|
+
brainstate/nn/_normalizations.py,sha256=3ivJdQUhhWeMhDpzWhXhRQnu8dRR4zDzM1cOlPvPo2s,50374
|
53
53
|
brainstate/nn/_normalizations_test.py,sha256=y5n7aTaUHRkyAAjq4Oj8dfButJC3ehm4KdUB7226Bow,23350
|
54
54
|
brainstate/nn/_paddings.py,sha256=3u3dbRFtPSlIsLMBYZmHcYDQ6HFl0u_d-yP7ZoYcCrA,32415
|
55
55
|
brainstate/nn/_paddings_test.py,sha256=uY9CRexf9sM5V6AzTfzm8214y4IoGqEbahy92yaKsWM,27409
|
@@ -61,36 +61,36 @@ brainstate/nn/_utils.py,sha256=VK-Se53e1q-Ip4AtMOZ3SUzYw8u2UllLJRLRtEFRCRE,7403
|
|
61
61
|
brainstate/nn/_utils_test.py,sha256=uim2SkfNHrBZzNDvN0WOK8qeZC1kaeOd-UQDvrn_M24,14266
|
62
62
|
brainstate/nn/init.py,sha256=7iLHrL-ZHpU-g5d0PlusaqmtkWO7X_KNr2eLx90oHrc,25656
|
63
63
|
brainstate/nn/init_test.py,sha256=bfby6kovvhbc7CCEaohtQawUQj3w3GvaMSDAdTiP2ps,6200
|
64
|
-
brainstate/random/__init__.py,sha256=
|
65
|
-
brainstate/random/
|
66
|
-
brainstate/random/
|
67
|
-
brainstate/random/
|
68
|
-
brainstate/random/
|
69
|
-
brainstate/random/
|
70
|
-
brainstate/random/
|
71
|
-
brainstate/
|
64
|
+
brainstate/random/__init__.py,sha256=yeWQ3RUcFXtXcDkFhaOYa_nwQ_M7hlH_W69YEcGF3Oo,8351
|
65
|
+
brainstate/random/_fun.py,sha256=fW2bc2i15sQNodmKT66ZCfXJSIZ8Ygxmn_Xf80Zjygw,135334
|
66
|
+
brainstate/random/_fun_test.py,sha256=EcHOcCOsmZPbm1n5TUyiTDxS13Pn3MJbq5IG-twbnPI,22866
|
67
|
+
brainstate/random/_impl.py,sha256=A79IK8YNZuN8RGLftASKHcZI0Cdsep7mGYNtCfKl2fQ,21830
|
68
|
+
brainstate/random/_seed.py,sha256=mLHqOu-lJQjsXo4nODACY78SSKgYayHx3n5IBZ9L6J0,24922
|
69
|
+
brainstate/random/_seed_test.py,sha256=Y2VCAkUzciDaCfYZWPe_Ewmi3MylK-WzfPA7TzorV8Q,1491
|
70
|
+
brainstate/random/_state.py,sha256=JPK6-jqFwrzsrb3lrxZ4GahP0TtXycUfWrJSdjHIEg8,42946
|
71
|
+
brainstate/random/_state_test.py,sha256=OfW0WxTpJZm_kT_7bZjJJ8ZtLM_X8NBF90h1sDSZxmw,19221
|
72
|
+
brainstate/transform/__init__.py,sha256=P7MAmt4pJYcpxLO30gFoT4BO_AHHM3VNOoG2j3OkNRk,2126
|
72
73
|
brainstate/transform/_ad_checkpoint.py,sha256=4dcNCEQVV_CPMSkE32URERDMpQHbyfdGeLT_Nvhyd4o,6912
|
73
74
|
brainstate/transform/_ad_checkpoint_test.py,sha256=fPXBjDxsLHbL2mhIU3x_F5BpitkXLpgIsRCRgm2Us6w,1697
|
74
75
|
brainstate/transform/_autograd.py,sha256=4zGSYa9TMn6bqzPJNLfU9UZGZRyYxmwMXTKZWO4w3QQ,39991
|
75
76
|
brainstate/transform/_autograd_test.py,sha256=saWG1_k3cRXpsyQDzQkOLGvsF7IIxG9aGnjrf5B3HNk,44112
|
76
|
-
brainstate/transform/_conditions.py,sha256=
|
77
|
+
brainstate/transform/_conditions.py,sha256=nLc_m0bLlybEPTi42feSlz2zTDIHB-BknSqPTIL_I0w,11376
|
77
78
|
brainstate/transform/_conditions_test.py,sha256=MEuqRq6IFmyORRDi0qWvNo4pWKFyc8aNrW1v9Saqxj0,8493
|
78
79
|
brainstate/transform/_error_if.py,sha256=e9tp3wT5p4bEyjn_Za_SrPNOG3OIoPBMIrvG2CsZzvw,2680
|
79
80
|
brainstate/transform/_error_if_test.py,sha256=yn-qcZ6lZUWciIif4fJOpdpKzJFAAdfgzZm6FfPeq7U,1848
|
80
|
-
brainstate/transform/
|
81
|
-
brainstate/transform/
|
81
|
+
brainstate/transform/_find_state.py,sha256=nUrVp_DUP78E2H8UHyhH8RL03kwcKjepnkDTqbISFmI,7405
|
82
|
+
brainstate/transform/_find_state_test.py,sha256=KahI4HSj6MTlC4ccCPyLiyERy3MryZd_iczGwFXb2bM,2913
|
82
83
|
brainstate/transform/_jit.py,sha256=qYsL3Z9nAAW0UyQe_AyvBEuJvqAan6iw3lN49o1oC0A,15421
|
83
84
|
brainstate/transform/_jit_test.py,sha256=ecw54dGQYdJq2J94itPrXBQrSDNvCY_htUD7z7y4HUM,4013
|
84
85
|
brainstate/transform/_loop_collect_return.py,sha256=HhjC2gq6qzliw4ofP16VxdtR5hW-NmDZdeHxuiLdYGk,25899
|
85
86
|
brainstate/transform/_loop_collect_return_test.py,sha256=BVK-b3CuDtTXciRaA_8t4751N4taQOnIPNzAelSts-k,1753
|
86
87
|
brainstate/transform/_loop_no_collection.py,sha256=ArPpNemMh4jJsq_vUWPxuagCnxTlONN--3P_-44qYq8,10156
|
87
88
|
brainstate/transform/_loop_no_collection_test.py,sha256=3bRo9_Oaypbw3asEevrgTK0WksxDAIZKJgeaWpt7nl8,1371
|
88
|
-
brainstate/transform/_make_jaxpr.py,sha256=
|
89
|
-
brainstate/transform/_make_jaxpr_test.py,sha256=
|
90
|
-
brainstate/transform/_mapping.py,sha256=
|
91
|
-
brainstate/transform/_mapping_test.py,sha256=
|
89
|
+
brainstate/transform/_make_jaxpr.py,sha256=_pSw1oumQ504Emvcgy7eXBTxOw1c1pbnvbcdveKMG8s,80444
|
90
|
+
brainstate/transform/_make_jaxpr_test.py,sha256=EfY-ZKR-u-aLfjOTPY6QiwEzZU-QpIUsB5oCCZi5Z-I,53014
|
91
|
+
brainstate/transform/_mapping.py,sha256=wBTvmXSg8TW-ZjIhKU5nQLLb6BIkLTCaJWwESZ72Hks,22021
|
92
|
+
brainstate/transform/_mapping_test.py,sha256=BkoL9peJGiJvkp7aLErrB6GnDasSviMbBkBqh1yD8LM,3255
|
92
93
|
brainstate/transform/_progress_bar.py,sha256=kZ-mI5hbUQXhqKFVyo0qeKG_LvrR9ZIar7WkXyOeET0,8961
|
93
|
-
brainstate/transform/_random.py,sha256=ZTH5Smx5SvFy6El7qk-ihoYE9WII6TJXcsC9Vm7VgjA,5259
|
94
94
|
brainstate/transform/_unvmap.py,sha256=cW6fjs5Iy1YBB6Nx2mxlM4IzV8U99bEX5QjT8rBRDho,6319
|
95
95
|
brainstate/transform/_util.py,sha256=IYqJj7oyAYzm_m3d9WEsUQRKDdVLQaAWpwM5O8PD4YQ,11304
|
96
96
|
brainstate/util/__init__.py,sha256=anHdG5BIsMqcBQy7gt8lErKInZv1wf2NOLJGUqltyAQ,1154
|
@@ -104,8 +104,8 @@ brainstate/util/filter.py,sha256=wY_XUF3OhrXSV1bZkTcVhlEPba4HP1l9N5aRW2zgxqQ,274
|
|
104
104
|
brainstate/util/filter_test.py,sha256=ZfrEeOc1yMHYzrcSR3p4jbZGj7c_tXC8VcPq7H13q8E,31653
|
105
105
|
brainstate/util/struct.py,sha256=LYPLGDGfPuw14hhx5k4rb8msSH3yZPdAu_0CvjxPWwE,24505
|
106
106
|
brainstate/util/struct_test.py,sha256=q_fWsUH1ON35DKjUUAMq6VtYglqTDyBvd6WMVGD89EI,16526
|
107
|
-
brainstate-0.2.
|
108
|
-
brainstate-0.2.
|
109
|
-
brainstate-0.2.
|
110
|
-
brainstate-0.2.
|
111
|
-
brainstate-0.2.
|
107
|
+
brainstate-0.2.2.dist-info/licenses/LICENSE,sha256=RJ40fox7u2in2H8wvIS5DsPGlNHaA7JI024thFUlaZE,11348
|
108
|
+
brainstate-0.2.2.dist-info/METADATA,sha256=tbhbVxomU2orW-T3WaEjVtYAyZSn4T4uJxAXzPmU9JY,4421
|
109
|
+
brainstate-0.2.2.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
|
110
|
+
brainstate-0.2.2.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
111
|
+
brainstate-0.2.2.dist-info/RECORD,,
|
@@ -1,145 +0,0 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import functools
|
17
|
-
from typing import Any, TypeVar, Callable, Sequence, Union
|
18
|
-
|
19
|
-
import jax
|
20
|
-
|
21
|
-
from brainstate import random
|
22
|
-
from brainstate._utils import set_module_as
|
23
|
-
from brainstate.graph import Node, flatten, unflatten
|
24
|
-
from ._random import restore_rngs
|
25
|
-
|
26
|
-
__all__ = [
|
27
|
-
'abstract_init',
|
28
|
-
]
|
29
|
-
|
30
|
-
A = TypeVar('A')
|
31
|
-
|
32
|
-
|
33
|
-
@set_module_as('brainstate.transform')
|
34
|
-
def abstract_init(
|
35
|
-
fn: Callable[..., A],
|
36
|
-
*args: Any,
|
37
|
-
rngs: Union[random.RandomState, Sequence[random.RandomState]] = random.DEFAULT,
|
38
|
-
**kwargs: Any,
|
39
|
-
) -> A:
|
40
|
-
"""
|
41
|
-
Compute the shape/dtype of ``fn`` without any FLOPs.
|
42
|
-
|
43
|
-
This function evaluates the shape and dtype of the output of a function without
|
44
|
-
actually executing the computational operations. It's particularly useful for
|
45
|
-
initializing neural network models to understand their structure and parameter
|
46
|
-
shapes without performing expensive computations.
|
47
|
-
|
48
|
-
Parameters
|
49
|
-
----------
|
50
|
-
fn : callable
|
51
|
-
The function whose output shape should be evaluated.
|
52
|
-
*args
|
53
|
-
Positional argument tuple of arrays, scalars, or (nested) standard
|
54
|
-
Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
|
55
|
-
those types. Since only the ``shape`` and ``dtype`` attributes are
|
56
|
-
accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
|
57
|
-
that duck-types as ndarrays (note however that duck-typed objects cannot
|
58
|
-
be namedtuples because those are treated as standard Python containers).
|
59
|
-
rngs : RandomState or sequence of RandomState, default random.DEFAULT
|
60
|
-
A :class:`RandomState` or a sequence of :class:`RandomState` objects
|
61
|
-
representing the random number generators to use. If not provided, the
|
62
|
-
default random number generator will be used.
|
63
|
-
**kwargs
|
64
|
-
Keyword argument dict of arrays, scalars, or (nested) standard
|
65
|
-
Python containers (pytrees) of those types. As in ``args``, array values
|
66
|
-
need only be duck-typed to have ``shape`` and ``dtype`` attributes.
|
67
|
-
|
68
|
-
Returns
|
69
|
-
-------
|
70
|
-
A
|
71
|
-
A nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves,
|
72
|
-
representing the structure and shape/dtype information of the function output.
|
73
|
-
|
74
|
-
Examples
|
75
|
-
--------
|
76
|
-
Basic usage with neural network initialization:
|
77
|
-
|
78
|
-
.. code-block:: python
|
79
|
-
|
80
|
-
>>> import brainstate
|
81
|
-
>>> import jax.numpy as jnp
|
82
|
-
>>>
|
83
|
-
>>> class MLP:
|
84
|
-
... def __init__(self, n_in, n_mid, n_out):
|
85
|
-
... self.dense1 = brainstate.nn.Linear(n_in, n_mid)
|
86
|
-
... self.dense2 = brainstate.nn.Linear(n_mid, n_out)
|
87
|
-
>>>
|
88
|
-
>>> # Get shape information without actual computation
|
89
|
-
>>> model_shape = brainstate.transform.abstract_init(lambda: MLP(1, 2, 3))
|
90
|
-
|
91
|
-
With function arguments:
|
92
|
-
|
93
|
-
.. code-block:: python
|
94
|
-
|
95
|
-
>>> def create_model(input_size, hidden_size, output_size):
|
96
|
-
... return brainstate.nn.Sequential([
|
97
|
-
... brainstate.nn.Linear(input_size, hidden_size),
|
98
|
-
... brainstate.nn.ReLU(),
|
99
|
-
... brainstate.nn.Linear(hidden_size, output_size)
|
100
|
-
... ])
|
101
|
-
>>>
|
102
|
-
>>> # Abstract initialization with arguments
|
103
|
-
>>> model_shape = brainstate.transform.abstract_init(
|
104
|
-
... create_model, 784, 256, 10
|
105
|
-
... )
|
106
|
-
|
107
|
-
Using custom random number generators:
|
108
|
-
|
109
|
-
.. code-block:: python
|
110
|
-
|
111
|
-
>>> import brainstate.random as random
|
112
|
-
>>>
|
113
|
-
>>> # Create custom RNG
|
114
|
-
>>> rng = random.RandomState(42)
|
115
|
-
>>>
|
116
|
-
>>> def init_with_custom_weights():
|
117
|
-
... return brainstate.nn.Linear(10, 5)
|
118
|
-
>>>
|
119
|
-
>>> model_shape = brainstate.transform.abstract_init(
|
120
|
-
... init_with_custom_weights, rngs=rng
|
121
|
-
... )
|
122
|
-
|
123
|
-
Evaluating function with array inputs:
|
124
|
-
|
125
|
-
.. code-block:: python
|
126
|
-
|
127
|
-
>>> def model_forward(x):
|
128
|
-
... layer = brainstate.nn.Linear(x.shape[-1], 128)
|
129
|
-
... return layer(x)
|
130
|
-
>>>
|
131
|
-
>>> # Use ShapeDtypeStruct to represent input without actual data
|
132
|
-
>>> input_shape = jax.ShapeDtypeStruct((32, 784), jnp.float32)
|
133
|
-
>>> output_shape = brainstate.transform.abstract_init(model_forward, input_shape)
|
134
|
-
"""
|
135
|
-
|
136
|
-
@functools.wraps(fn)
|
137
|
-
@restore_rngs(rngs=rngs)
|
138
|
-
def _eval_shape_fn(*args_, **kwargs_):
|
139
|
-
out = fn(*args_, **kwargs_)
|
140
|
-
assert isinstance(out, Node), 'The output of the function must be Node'
|
141
|
-
graph_def, treefy_states = flatten(out)
|
142
|
-
return graph_def, treefy_states
|
143
|
-
|
144
|
-
graph_def_, treefy_states_ = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
|
145
|
-
return unflatten(graph_def_, treefy_states_)
|
@@ -1,38 +0,0 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import unittest
|
18
|
-
|
19
|
-
import brainstate
|
20
|
-
|
21
|
-
|
22
|
-
class TestEvalShape(unittest.TestCase):
|
23
|
-
def test1(self):
|
24
|
-
class MLP(brainstate.nn.Module):
|
25
|
-
def __init__(self, n_in, n_mid, n_out):
|
26
|
-
super().__init__()
|
27
|
-
self.dense1 = brainstate.nn.Linear(n_in, n_mid)
|
28
|
-
self.dense2 = brainstate.nn.Linear(n_mid, n_out)
|
29
|
-
|
30
|
-
def __call__(self, x):
|
31
|
-
x = self.dense1(x)
|
32
|
-
x = brainstate.functional.relu(x)
|
33
|
-
x = self.dense2(x)
|
34
|
-
return x
|
35
|
-
|
36
|
-
r = brainstate.augment.abstract_init(lambda: MLP(1, 2, 3))
|
37
|
-
print(r)
|
38
|
-
print(brainstate.random.DEFAULT)
|