brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +6 -3
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -3
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_xla_custom_op.py +7 -3
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +40 -46
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/event/_csr_test.py
CHANGED
@@ -18,6 +18,9 @@
|
|
18
18
|
import unittest
|
19
19
|
|
20
20
|
import brainunit as u
|
21
|
+
import jax
|
22
|
+
import jax.numpy as jnp
|
23
|
+
import numpy as np
|
21
24
|
|
22
25
|
import brainstate as bst
|
23
26
|
|
@@ -88,3 +91,187 @@ class TestCSR(unittest.TestCase):
|
|
88
91
|
v @ csr
|
89
92
|
)
|
90
93
|
)
|
94
|
+
|
95
|
+
|
96
|
+
def _get_csr(n_pre, n_post, prob):
    """Build random CSR connectivity with the same number of entries per row.

    Each of the ``n_pre`` rows gets ``int(n_post * prob)`` column indices drawn
    uniformly (with replacement) from ``[0, n_post)`` using NumPy's global RNG,
    so a row may contain repeated columns.

    Returns:
        ``(indptr, indices)``: ``indptr`` has ``n_pre + 1`` evenly spaced row
        offsets, and ``indices`` holds ``n_pre * int(n_post * prob)`` column ids.
    """
    per_row = int(n_post * prob)
    row_offsets = per_row * np.arange(n_pre + 1)
    col_ids = np.random.randint(0, n_post, (n_pre * per_row,))
    return row_offsets, col_ids
|
101
|
+
|
102
|
+
|
103
|
+
def vector_csr(x, w, indices, indptr, shape):
    """Loop-based reference for ``x @ CSR``.

    Row ``i`` of the CSR matrix scatters ``x[i]`` (scaled by its weights) into
    the output slots listed in ``indices[indptr[i]:indptr[i + 1]]``. Repeated
    column ids accumulate.

    Args:
        x: vector with one entry per CSR row; the loop runs over ``x.shape[0]``.
        w: a single shared weight when ``jnp.size(w) == 1``; otherwise one
            weight per stored entry, sliced per row through ``indptr``.
        indices: column id of every stored entry.
        indptr: row offsets into ``indices`` (and into ``w`` when per-entry).
        shape: matrix shape; only ``shape[1]`` is read, as the output length.

    Returns:
        A length-``shape[1]`` array.
    """
    scalar_weight = jnp.size(w) == 1
    result = jnp.zeros((shape[1],))
    for row in range(x.shape[0]):
        start, stop = indptr[row], indptr[row + 1]
        if scalar_weight:
            contribution = w * x[row]
        else:
            contribution = w[start:stop] * x[row]
        result = result.at[indices[start:stop]].add(contribution)
    return result
|
110
|
+
|
111
|
+
|
112
|
+
def matrix_csr(xs, w, indices, indptr, shape):
    """Loop-based reference for ``xs @ CSR`` with a batch of left operands.

    Args:
        xs: 2-D array whose second axis runs over CSR rows; the loop visits
            ``xs.shape[1]`` rows and ``xs.shape[0]`` is the batch size.
        w: a single shared weight when ``jnp.size(w) == 1``; otherwise one
            weight per stored entry, sliced per row through ``indptr``.
        indices: column id of every stored entry.
        indptr: row offsets into ``indices`` (and into ``w`` when per-entry).
        shape: matrix shape; only ``shape[1]`` is read, as the output width.

    Returns:
        An array of shape ``(xs.shape[0], shape[1])``.
    """
    scalar_weight = jnp.size(w) == 1
    result = jnp.zeros((xs.shape[0], shape[1]))
    for row in range(xs.shape[1]):
        start, stop = indptr[row], indptr[row + 1]
        # Keep a trailing axis so the batch column broadcasts across the row's entries.
        column = xs[:, row: row + 1]
        weights = w if scalar_weight else w[start:stop]
        result = result.at[:, indices[start:stop]].add(weights * column)
    return result
|
123
|
+
|
124
|
+
|
125
|
+
def csr_vector(x, w, indices, indptr, shape):
    """Loop-based reference for ``CSR @ x``.

    Output entry ``i`` is the weighted sum of ``x`` gathered at the column ids
    ``indices[indptr[i]:indptr[i + 1]]``. Empty rows produce zero.

    Args:
        x: vector indexed by column id.
        w: a single shared weight when ``jnp.size(w) == 1``; otherwise one
            weight per stored entry, sliced per row through ``indptr``.
        indices: column id of every stored entry.
        indptr: row offsets into ``indices`` (and into ``w`` when per-entry).
        shape: matrix shape; only ``shape[0]`` (the row count) is read.

    Returns:
        A length-``shape[0]`` array.
    """
    scalar_weight = jnp.size(w) == 1
    result = jnp.zeros([shape[0]])
    for row in range(shape[0]):
        start, stop = indptr[row], indptr[row + 1]
        row_weights = w if scalar_weight else w[start:stop]
        result = result.at[row].set(jnp.sum(x[indices[start:stop]] * row_weights))
    return result
|
133
|
+
|
134
|
+
|
135
|
+
def csr_matrix(xs, w, indices, indptr, shape):
    """Loop-based reference for ``CSR @ xs`` (sparse matrix times dense matrix).

    Output row ``i`` is the weighted sum of the rows of ``xs`` selected by
    ``indices[indptr[i]:indptr[i + 1]]``. Empty rows produce zeros.

    Args:
        xs: 2-D array whose first axis is indexed by column id.
        w: a single shared weight when ``jnp.size(w) == 1``; otherwise one
            weight per stored entry, sliced per row through ``indptr``.
        indices: column id of every stored entry.
        indptr: row offsets into ``indices`` (and into ``w`` when per-entry).
        shape: matrix shape; only ``shape[0]`` (the row count) is read.

    Returns:
        An array of shape ``(shape[0], xs.shape[1])``.
    """
    scalar_weight = jnp.size(w) == 1
    result = jnp.zeros([shape[0], xs.shape[1]])
    for row in range(shape[0]):
        start, stop = indptr[row], indptr[row + 1]
        if scalar_weight:
            row_weights = w
        else:
            # Column vector, so each weight scales a whole gathered row of ``xs``.
            row_weights = w[start:stop][:, None]
        gathered = xs[indices[start:stop]]
        result = result.at[row].set(jnp.sum(gathered * row_weights, axis=0))
    return result
|
144
|
+
|
145
|
+
|
146
|
+
class TestVectorCSR(unittest.TestCase):
    """Compare ``vector @ CSR`` against the loop-based ``vector_csr`` reference."""

    def test_vector_csr(self):
        """Single event vector, for both homogeneous and per-entry weights."""
        m, n = 20, 40
        x = bst.random.rand(m) < 0.1
        indptr, indices = _get_csr(m, n, 0.1)

        for homo_w in [True, False]:
            # subTest labels a failure with the weight mode, replacing the old debug print.
            with self.subTest(homo_w=homo_w):
                data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
                csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
                y = x @ csr
                y2 = vector_csr(x, csr.data, indices, indptr, [m, n])
                self.assertTrue(jnp.allclose(y, y2))

    def test_vector_csr_vmap_vector(self):
        """Batch of event vectors mapped with ``jax.vmap`` over the leading axis."""
        n_batch, m, n = 10, 20, 40
        xs = bst.random.rand(n_batch, m) < 0.1
        indptr, indices = _get_csr(m, n, 0.1)

        for homo_w in [True, False]:
            with self.subTest(homo_w=homo_w):
                data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
                csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
                y = jax.vmap(lambda x: x @ csr)(xs)
                y2 = jax.vmap(lambda x: vector_csr(x, csr.data, indices, indptr, [m, n]))(xs)
                self.assertTrue(jnp.allclose(y, y2))
|
171
|
+
|
172
|
+
|
173
|
+
class TestMatrixCSR(unittest.TestCase):
    """Compare ``matrix @ CSR`` against the loop-based ``matrix_csr`` reference."""

    def test_matrix_csr(self):
        """Dense ``(k, m)`` event matrix times an ``(m, n)`` CSR matrix."""
        k, m, n = 10, 20, 40
        x = bst.random.rand(k, m) < 0.1
        indptr, indices = _get_csr(m, n, 0.1)

        for homo_w in [True, False]:
            # subTest reports which weight mode failed and still runs the other.
            with self.subTest(homo_w=homo_w):
                data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
                csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
                y = x @ csr
                y2 = matrix_csr(x, csr.data, indices, indptr, [m, n])
                self.assertTrue(jnp.allclose(y, y2))
|
185
|
+
|
186
|
+
|
187
|
+
class TestCSRVector(unittest.TestCase):
    """Compare ``CSR @ vector`` against the loop-based ``csr_vector`` reference."""

    def test_csr_vector(self):
        """``(m, n)`` CSR matrix times a length-``n`` event vector."""
        m, n = 20, 40
        v = bst.random.rand(n) < 0.1
        indptr, indices = _get_csr(m, n, 0.1)

        for homo_w in [True, False]:
            # subTest reports which weight mode failed and still runs the other.
            with self.subTest(homo_w=homo_w):
                data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
                csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
                y = csr @ v
                y2 = csr_vector(v, csr.data, indices, indptr, [m, n])
                self.assertTrue(jnp.allclose(y, y2))
|
199
|
+
|
200
|
+
|
201
|
+
class TestCSRMatrix(unittest.TestCase):
    """Compare ``CSR @ matrix`` against the loop-based ``csr_matrix`` reference."""

    def test_csr_matrix(self):
        """``(m, n)`` CSR matrix times an ``(n, k)`` event matrix."""
        m, n, k = 20, 40, 10
        matrix = bst.random.rand(n, k) < 0.1
        indptr, indices = _get_csr(m, n, 0.1)

        for homo_w in [True, False]:
            # subTest reports which weight mode failed and still runs the other.
            with self.subTest(homo_w=homo_w):
                data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
                csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
                y = csr @ matrix
                y2 = csr_matrix(matrix, csr.data, indices, indptr, [m, n])
                self.assertTrue(jnp.allclose(y, y2))

    # NOTE(review): the gradient tests below are disabled. They rely on a
    # `parameterized` decorator, `bst.event.CSRLinear`, and a `true_fn`
    # reference implementation. Confirm all three are available in this file
    # before re-enabling.
    # @parameterized.product(
    #     bool_x=[True, False],
    #     homo_w=[True, False]
    # )
    # def test_vjp(self, bool_x, homo_w):
    #     n_in = 20
    #     n_out = 30
    #     if bool_x:
    #         x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
    #     else:
    #         x = bst.random.rand(n_in)
    #
    #     indptr, indices = _get_csr(n_in, n_out, 0.1)
    #     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
    #     w = fn.weight.value
    #
    #     def f(x, w):
    #         fn.weight.value = w
    #         return fn(x).sum()
    #
    #     r = jax.grad(f, argnums=(0, 1))(x, w)
    #
    #     # -------------------
    #     # TRUE gradients
    #
    #     def f2(x, w):
    #         return true_fn(x, w, indices, indptr, n_out).sum()
    #
    #     r2 = jax.grad(f2, argnums=(0, 1))(x, w)
    #     self.assertTrue(jnp.allclose(r[0], r2[0]))
    #     self.assertTrue(jnp.allclose(r[1], r2[1]))
    #
    # @parameterized.product(
    #     bool_x=[True, False],
    #     homo_w=[True, False]
    # )
    # def test_jvp(self, bool_x, homo_w):
    #     n_in = 20
    #     n_out = 30
    #     if bool_x:
    #         x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
    #     else:
    #         x = bst.random.rand(n_in)
    #
    #     indptr, indices = _get_csr(n_in, n_out, 0.1)
    #     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
    #                              1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
    #     w = fn.weight.value
    #
    #     def f(x, w):
    #         fn.weight.value = w
    #         return fn(x)
    #
    #     o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
    #
    #     # -------------------
    #     # TRUE gradients
    #
    #     def f2(x, w):
    #         return true_fn(x, w, indices, indptr, n_out)
    #
    #     o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
    #     self.assertTrue(jnp.allclose(r1, r2))
    #     self.assertTrue(jnp.allclose(o1, o2))
|
@@ -9,7 +9,6 @@ from typing import Callable, Sequence, Tuple, Protocol
|
|
9
9
|
import jax
|
10
10
|
import numpy as np
|
11
11
|
from jax import tree_util
|
12
|
-
from jax.core import Primitive
|
13
12
|
from jax.interpreters import batching, ad
|
14
13
|
from jax.interpreters import xla, mlir
|
15
14
|
from jaxlib.hlo_helpers import custom_call
|
@@ -19,6 +18,11 @@ if jax.__version_info__ < (0, 4, 35):
|
|
19
18
|
else:
|
20
19
|
import jax.extend as je
|
21
20
|
|
21
|
+
if jax.__version_info__ < (0, 4, 38):
|
22
|
+
from jax.core import Primitive
|
23
|
+
else:
|
24
|
+
from jax.extend.core import Primitive
|
25
|
+
|
22
26
|
numba_installed = importlib.util.find_spec('numba') is not None
|
23
27
|
|
24
28
|
__all__ = [
|
@@ -164,7 +168,7 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
|
|
164
168
|
|
165
169
|
|
166
170
|
def register_numba_mlir_cpu_translation_rule(
|
167
|
-
primitive:
|
171
|
+
primitive: Primitive,
|
168
172
|
cpu_kernel: Callable,
|
169
173
|
debug: bool = False
|
170
174
|
):
|
@@ -205,7 +209,7 @@ class XLACustomOp:
|
|
205
209
|
transpose_translation: Callable = None,
|
206
210
|
):
|
207
211
|
# primitive
|
208
|
-
self.primitive =
|
212
|
+
self.primitive = Primitive(name)
|
209
213
|
self.primitive.multiple_results = True
|
210
214
|
|
211
215
|
# abstract evaluation
|
brainstate/graph/__init__.py
CHANGED
@@ -14,20 +14,16 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from ._graph_context import *
|
18
|
-
from ._graph_context import __all__ as _graph_context__all__
|
19
|
-
from ._graph_convert import *
|
20
|
-
from ._graph_convert import __all__ as _graph_convert__all__
|
21
17
|
from ._graph_node import *
|
22
18
|
from ._graph_node import __all__ as _graph_node__all__
|
23
19
|
from ._graph_operation import *
|
24
20
|
from ._graph_operation import __all__ as _graph_operation__all__
|
25
21
|
|
26
|
-
__all__ = (
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
del (
|
31
|
-
|
32
|
-
|
33
|
-
|
22
|
+
__all__ = (
|
23
|
+
_graph_node__all__ +
|
24
|
+
_graph_operation__all__
|
25
|
+
)
|
26
|
+
del (
|
27
|
+
_graph_node__all__,
|
28
|
+
_graph_operation__all__
|
29
|
+
)
|
brainstate/graph/_graph_node.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
@@ -27,9 +27,7 @@ import numpy as np
|
|
27
27
|
|
28
28
|
from brainstate._state import State, TreefyState
|
29
29
|
from brainstate.typing import Key
|
30
|
-
from brainstate.util._error import TraceContextError
|
31
30
|
from brainstate.util._pretty_repr import PrettyRepr, pretty_repr_avoid_duplicate, PrettyType, PrettyAttr
|
32
|
-
from brainstate.util._tracers import StateJaxTracer
|
33
31
|
from ._graph_operation import register_graph_node_type
|
34
32
|
|
35
33
|
__all__ = [
|
@@ -44,7 +42,6 @@ class GraphNodeMeta(ABCMeta):
|
|
44
42
|
if not TYPE_CHECKING:
|
45
43
|
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
46
44
|
node = cls.__new__(cls, *args, **kwargs)
|
47
|
-
vars(node)['_trace_state'] = StateJaxTracer()
|
48
45
|
node.__init__(*args, **kwargs)
|
49
46
|
return node
|
50
47
|
|
@@ -64,9 +61,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
|
|
64
61
|
|
65
62
|
graph_invisible_attrs = ()
|
66
63
|
|
67
|
-
if TYPE_CHECKING:
|
68
|
-
_trace_state: StateJaxTracer
|
69
|
-
|
70
64
|
def __init_subclass__(cls) -> None:
|
71
65
|
super().__init_subclass__()
|
72
66
|
|
@@ -79,21 +73,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
|
|
79
73
|
clear=_node_clear,
|
80
74
|
)
|
81
75
|
|
82
|
-
# if not TYPE_CHECKING:
|
83
|
-
# def __setattr__(self, name: str, value: Any) -> None:
|
84
|
-
# self._setattr(name, value)
|
85
|
-
|
86
|
-
# def _setattr(self, name: str, value: Any) -> None:
|
87
|
-
# self.check_valid_context(lambda: f"Cannot mutate '{type(self).__name__}' from different trace level")
|
88
|
-
# object.__setattr__(self, name, value)
|
89
|
-
|
90
|
-
def check_valid_context(self, error_msg: Callable[[], str]) -> None:
|
91
|
-
"""
|
92
|
-
Check if the current context is valid for the object to be mutated.
|
93
|
-
"""
|
94
|
-
if not self._trace_state.is_valid():
|
95
|
-
raise TraceContextError(error_msg())
|
96
|
-
|
97
76
|
def __deepcopy__(self: G, memo=None) -> G:
|
98
77
|
"""
|
99
78
|
Deepcopy the object.
|
@@ -214,7 +193,6 @@ def _node_create_empty(
|
|
214
193
|
) -> G:
|
215
194
|
node_type, = static
|
216
195
|
node = object.__new__(node_type)
|
217
|
-
vars(node).update(_trace_state=StateJaxTracer())
|
218
196
|
return node
|
219
197
|
|
220
198
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
@@ -17,13 +17,10 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import unittest
|
19
19
|
from collections.abc import Callable
|
20
|
-
from functools import partial
|
21
20
|
from threading import Thread
|
22
|
-
from typing import Any
|
23
21
|
|
24
22
|
import jax
|
25
23
|
import jax.numpy as jnp
|
26
|
-
import pytest
|
27
24
|
from absl.testing import absltest, parameterized
|
28
25
|
|
29
26
|
import brainstate as bst
|
@@ -354,125 +351,6 @@ class TestGraphUtils(absltest.TestCase):
|
|
354
351
|
assert m2.tree.a is not m.tree.a
|
355
352
|
assert m2.tree is not m.tree
|
356
353
|
|
357
|
-
@pytest.mark.skip(reason='Not implemented')
|
358
|
-
def test_cached_unflatten(self):
|
359
|
-
class Foo(bst.graph.Node):
|
360
|
-
def __init__(self, ):
|
361
|
-
self.a = bst.nn.Linear(2, 2)
|
362
|
-
self.b = bst.nn.BatchNorm1d([10, 2])
|
363
|
-
|
364
|
-
def f(m: Foo):
|
365
|
-
m.a, m.b = m.b, m.a # type: ignore
|
366
|
-
|
367
|
-
m = Foo()
|
368
|
-
a = m.a
|
369
|
-
b = m.b
|
370
|
-
|
371
|
-
ref_out_idx_out = bst.graph.RefMap()
|
372
|
-
graphdef: bst.graph.GraphDef[Foo]
|
373
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
|
374
|
-
|
375
|
-
@partial(jax.jit, static_argnums=(0,))
|
376
|
-
def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
|
377
|
-
idx_out_ref_in: dict[int, Any] = {}
|
378
|
-
m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
|
379
|
-
f(m)
|
380
|
-
ref_in_idx_in = bst.graph.RefMap[Any, int]()
|
381
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
|
382
|
-
idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
|
383
|
-
static_out = bst.graph.Static((graphdef, idx_out_idx_in))
|
384
|
-
return state, static_out
|
385
|
-
|
386
|
-
static_out: bst.graph.Static
|
387
|
-
state, static_out = f_pure(graphdef, state)
|
388
|
-
idx_out_idx_in: dict[int, int]
|
389
|
-
graphdef, idx_out_idx_in = static_out.value
|
390
|
-
idx_in_ref_out = bst.graph.compose_mapping_reversed(
|
391
|
-
ref_out_idx_out, idx_out_idx_in
|
392
|
-
)
|
393
|
-
m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
|
394
|
-
assert m2 is m
|
395
|
-
assert m2.a is b
|
396
|
-
assert m2.b is a
|
397
|
-
|
398
|
-
@pytest.mark.skip(reason='Not implemented')
|
399
|
-
def test_cached_unflatten_swap_variables(self):
|
400
|
-
class Foo(bst.graph.Node):
|
401
|
-
def __init__(self):
|
402
|
-
self.a = bst.ParamState(1)
|
403
|
-
self.b = bst.ParamState(2)
|
404
|
-
|
405
|
-
def f(m: Foo):
|
406
|
-
m.a, m.b = m.b, m.a
|
407
|
-
|
408
|
-
m = Foo()
|
409
|
-
a = m.a
|
410
|
-
b = m.b
|
411
|
-
|
412
|
-
ref_out_idx_out = bst.graph.RefMap[Any, int]()
|
413
|
-
graphdef: bst.graph.GraphDef[Foo]
|
414
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
|
415
|
-
|
416
|
-
@partial(jax.jit, static_argnums=(0,))
|
417
|
-
def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
|
418
|
-
idx_out_ref_in: dict[int, Any] = {}
|
419
|
-
m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
|
420
|
-
f(m)
|
421
|
-
ref_in_idx_in = bst.graph.RefMap[Any, int]()
|
422
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
|
423
|
-
idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
|
424
|
-
static_out = bst.graph.Static((graphdef, idx_out_idx_in))
|
425
|
-
return state, static_out
|
426
|
-
|
427
|
-
static_out: bst.graph.Static
|
428
|
-
state, static_out = f_pure(graphdef, state)
|
429
|
-
idx_out_idx_in: dict[int, int]
|
430
|
-
graphdef, idx_out_idx_in = static_out.value
|
431
|
-
idx_in_ref_out = bst.graph.compose_mapping_reversed(
|
432
|
-
ref_out_idx_out, idx_out_idx_in
|
433
|
-
)
|
434
|
-
m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
|
435
|
-
assert m2 is m
|
436
|
-
assert m2.a is b
|
437
|
-
assert m2.b is a
|
438
|
-
|
439
|
-
@pytest.mark.skip(reason='Not implemented')
|
440
|
-
def test_cached_unflatten_add_self_reference(self):
|
441
|
-
class Foo(bst.graph.Node):
|
442
|
-
def __init__(self):
|
443
|
-
self.ref = None
|
444
|
-
|
445
|
-
def f(m: Foo):
|
446
|
-
m.ref = m
|
447
|
-
|
448
|
-
m = Foo()
|
449
|
-
|
450
|
-
ref_out_idx_out = bst.graph.RefMap()
|
451
|
-
graphdef: bst.graph.GraphDef[Foo]
|
452
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
|
453
|
-
|
454
|
-
@partial(jax.jit, static_argnums=(0,))
|
455
|
-
def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
|
456
|
-
idx_out_ref_in: dict[int, Any] = {}
|
457
|
-
m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
|
458
|
-
f(m)
|
459
|
-
ref_in_idx_in = bst.graph.RefMap[Any, int]()
|
460
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
|
461
|
-
idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
|
462
|
-
static_out = bst.graph.Static((graphdef, idx_out_idx_in))
|
463
|
-
return state, static_out
|
464
|
-
|
465
|
-
static_out: bst.graph.Static
|
466
|
-
state, static_out = f_pure(graphdef, state)
|
467
|
-
idx_out_idx_in: dict[int, int]
|
468
|
-
graphdef, idx_out_idx_in = static_out.value
|
469
|
-
idx_in_ref_out = bst.graph.compose_mapping_reversed(
|
470
|
-
ref_out_idx_out, idx_out_idx_in
|
471
|
-
)
|
472
|
-
m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
|
473
|
-
assert m2 is m
|
474
|
-
assert m2.ref is m2
|
475
|
-
|
476
354
|
def test_call_jit_update(self):
|
477
355
|
class Counter(bst.graph.Node):
|
478
356
|
def __init__(self):
|
@@ -527,43 +405,6 @@ class TestGraphUtils(absltest.TestCase):
|
|
527
405
|
self.assertEqual(nodes['a'].count.value, 0)
|
528
406
|
self.assertEqual(nodes['b'].count.value, 1)
|
529
407
|
|
530
|
-
def test_to_tree_simple(self):
|
531
|
-
m = bst.nn.Linear(2, 3, )
|
532
|
-
impure_tree = (m, 1, {'b': m})
|
533
|
-
|
534
|
-
pure_tree = bst.graph.graph_to_tree(impure_tree)
|
535
|
-
|
536
|
-
t1 = pure_tree[0]
|
537
|
-
t2 = pure_tree[2]['b']
|
538
|
-
|
539
|
-
self.assertEqual(pure_tree[1], 1)
|
540
|
-
self.assertIsInstance(t1, bst.graph.NodeStates)
|
541
|
-
assert isinstance(t1, bst.graph.NodeStates)
|
542
|
-
self.assertIsInstance(t2, bst.graph.NodeStates)
|
543
|
-
assert isinstance(t2, bst.graph.NodeStates)
|
544
|
-
self.assertIsInstance(t1.graphdef, bst.graph.NodeDef)
|
545
|
-
self.assertIsInstance(t2.graphdef, bst.graph.NodeRef)
|
546
|
-
self.assertLen(t1.states[0].to_flat(), 1)
|
547
|
-
self.assertLen(t2.states[0].to_flat(), 0)
|
548
|
-
|
549
|
-
impure_tree2 = bst.graph.tree_to_graph(pure_tree)
|
550
|
-
|
551
|
-
m1_out = impure_tree2[0]
|
552
|
-
m2_out = impure_tree2[2]['b']
|
553
|
-
|
554
|
-
self.assertIs(m1_out, m2_out)
|
555
|
-
self.assertEqual(impure_tree2[1], 1)
|
556
|
-
|
557
|
-
def test_to_tree_consistent_prefix(self):
|
558
|
-
m = bst.nn.Linear(2, 3, )
|
559
|
-
impure_tree = (m, 1, {'b': m})
|
560
|
-
prefix = (0, None, 0)
|
561
|
-
pure_tree = bst.graph.graph_to_tree(impure_tree, prefix=prefix)
|
562
|
-
|
563
|
-
prefix = (0, None, 1)
|
564
|
-
with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
|
565
|
-
bst.graph.graph_to_tree(impure_tree, prefix=prefix)
|
566
|
-
|
567
408
|
|
568
409
|
class SimpleModule(bst.nn.Module):
|
569
410
|
pass
|