brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl

This diff compares the published contents of two public releases of the package. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
Files changed (46)
  1. brainstate/_state.py +77 -44
  2. brainstate/_state_test.py +0 -17
  3. brainstate/augment/_eval_shape.py +9 -10
  4. brainstate/augment/_eval_shape_test.py +1 -1
  5. brainstate/augment/_mapping.py +265 -277
  6. brainstate/augment/_mapping_test.py +147 -175
  7. brainstate/compile/_ad_checkpoint.py +6 -4
  8. brainstate/compile/_jit.py +37 -28
  9. brainstate/compile/_loop_collect_return.py +6 -3
  10. brainstate/compile/_loop_no_collection.py +2 -0
  11. brainstate/compile/_make_jaxpr.py +7 -3
  12. brainstate/compile/_progress_bar.py +68 -40
  13. brainstate/compile/_unvmap.py +6 -3
  14. brainstate/event/__init__.py +0 -2
  15. brainstate/event/_csr.py +266 -23
  16. brainstate/event/_csr_test.py +187 -0
  17. brainstate/event/_xla_custom_op.py +7 -3
  18. brainstate/graph/__init__.py +8 -12
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_interaction/_conv.py +4 -2
  24. brainstate/nn/_interaction/_linear.py +84 -10
  25. brainstate/random/_rand_funs.py +9 -2
  26. brainstate/random/_rand_seed.py +12 -2
  27. brainstate/random/_rand_state.py +50 -179
  28. brainstate/surrogate.py +5 -1
  29. brainstate/util/__init__.py +0 -4
  30. brainstate/util/_caller.py +1 -1
  31. brainstate/util/_dict.py +4 -1
  32. brainstate/util/_filter.py +1 -1
  33. brainstate/util/_pretty_repr.py +1 -1
  34. brainstate/util/_struct.py +1 -1
  35. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +40 -46
  37. brainstate/event/_csr_mv_test.py +0 -118
  38. brainstate/graph/_graph_context.py +0 -443
  39. brainstate/graph/_graph_context_test.py +0 -65
  40. brainstate/graph/_graph_convert.py +0 -246
  41. brainstate/util/_tracers.py +0 -68
  42. brainstate/util/_visualization.py +0 -47
  43. brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
  44. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
  45. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
  46. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
--- a/brainstate/event/_csr_test.py
+++ b/brainstate/event/_csr_test.py
@@ -18,6 +18,9 @@
 import unittest
 
 import brainunit as u
+import jax
+import jax.numpy as jnp
+import numpy as np
 
 import brainstate as bst
 
@@ -88,3 +91,187 @@ class TestCSR(unittest.TestCase):
                 v @ csr
             )
         )
+
+
+def _get_csr(n_pre, n_post, prob):
+    n_conn = int(n_post * prob)
+    indptr = np.arange(n_pre + 1) * n_conn
+    indices = np.random.randint(0, n_post, (n_pre * n_conn,))
+    return indptr, indices
+
+
+def vector_csr(x, w, indices, indptr, shape):
+    homo_w = jnp.size(w) == 1
+    post = jnp.zeros((shape[1],))
+    for i_pre in range(x.shape[0]):
+        ids = indices[indptr[i_pre]: indptr[i_pre + 1]]
+        post = post.at[ids].add(w * x[i_pre] if homo_w else w[indptr[i_pre]: indptr[i_pre + 1]] * x[i_pre])
+    return post
+
+
+def matrix_csr(xs, w, indices, indptr, shape):
+    homo_w = jnp.size(w) == 1
+    post = jnp.zeros((xs.shape[0], shape[1]))
+    for i_pre in range(xs.shape[1]):
+        ids = indices[indptr[i_pre]: indptr[i_pre + 1]]
+        post = post.at[:, ids].add(
+            w * xs[:, i_pre: i_pre + 1]
+            if homo_w else
+            (w[indptr[i_pre]: indptr[i_pre + 1]] * xs[:, i_pre: i_pre + 1])
+        )
+    return post
+
+
+def csr_vector(x, w, indices, indptr, shape):
+    homo_w = jnp.size(w) == 1
+    out = jnp.zeros([shape[0]])
+    for i in range(shape[0]):
+        ids = indices[indptr[i]: indptr[i + 1]]
+        ws = w if homo_w else w[indptr[i]: indptr[i + 1]]
+        out = out.at[i].set(jnp.sum(x[ids] * ws))
+    return out
+
+
+def csr_matrix(xs, w, indices, indptr, shape):
+    # CSR @ matrix
+    homo_w = jnp.size(w) == 1
+    out = jnp.zeros([shape[0], xs.shape[1]])
+    for i in range(shape[0]):
+        ids = indices[indptr[i]: indptr[i + 1]]
+        ws = w if homo_w else jnp.expand_dims(w[indptr[i]: indptr[i + 1]], axis=1)
+        out = out.at[i].set(jnp.sum(xs[ids] * ws, axis=0))
+    return out
+
+
+class TestVectorCSR(unittest.TestCase):
+    def test_vector_csr(self, ):
+        m, n = 20, 40
+        x = bst.random.rand(m) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            print(f'homo_w = {homo_w}')
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = x @ csr
+            y2 = vector_csr(x, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+    def test_vector_csr_vmap_vector(self):
+        n_batch, m, n = 10, 20, 40
+        xs = bst.random.rand(n_batch, m) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = jax.vmap(lambda x: x @ csr)(xs)
+            y2 = jax.vmap(lambda x: vector_csr(x, csr.data, indices, indptr, [m, n]))(xs)
+            self.assertTrue(jnp.allclose(y, y2))
+
+
+class TestMatrixCSR(unittest.TestCase):
+    def test_matrix_csr(self):
+        k, m, n = 10, 20, 40
+        x = bst.random.rand(k, m) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = x @ csr
+            y2 = matrix_csr(x, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+
+class TestCSRVector(unittest.TestCase):
+    def test_csr_vector(self):
+        m, n = 20, 40
+        v = bst.random.rand(n) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = csr @ v
+            y2 = csr_vector(v, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+
+class TestCSRMatrix(unittest.TestCase):
+    def test_csr_matrix(self):
+        m, n, k = 20, 40, 10
+        matrix = bst.random.rand(n, k) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = csr @ matrix
+            y2 = csr_matrix(matrix, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+    # @parameterized.product(
+    #     bool_x=[True, False],
+    #     homo_w=[True, False]
+    # )
+    # def test_vjp(self, bool_x, homo_w):
+    #     n_in = 20
+    #     n_out = 30
+    #     if bool_x:
+    #         x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+    #     else:
+    #         x = bst.random.rand(n_in)
+    #
+    #     indptr, indices = _get_csr(n_in, n_out, 0.1)
+    #     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
+    #     w = fn.weight.value
+    #
+    #     def f(x, w):
+    #         fn.weight.value = w
+    #         return fn(x).sum()
+    #
+    #     r = jax.grad(f, argnums=(0, 1))(x, w)
+    #
+    #     # -------------------
+    #     # TRUE gradients
+    #
+    #     def f2(x, w):
+    #         return true_fn(x, w, indices, indptr, n_out).sum()
+    #
+    #     r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+    #     self.assertTrue(jnp.allclose(r[0], r2[0]))
+    #     self.assertTrue(jnp.allclose(r[1], r2[1]))
+    #
+    # @parameterized.product(
+    #     bool_x=[True, False],
+    #     homo_w=[True, False]
+    # )
+    # def test_jvp(self, bool_x, homo_w):
+    #     n_in = 20
+    #     n_out = 30
+    #     if bool_x:
+    #         x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+    #     else:
+    #         x = bst.random.rand(n_in)
+    #
+    #     indptr, indices = _get_csr(n_in, n_out, 0.1)
+    #     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
+    #                              1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
+    #     w = fn.weight.value
+    #
+    #     def f(x, w):
+    #         fn.weight.value = w
+    #         return fn(x)
+    #
+    #     o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+    #
+    #     # -------------------
+    #     # TRUE gradients
+    #
+    #     def f2(x, w):
+    #         return true_fn(x, w, indices, indptr, n_out)
+    #
+    #     o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+    #     self.assertTrue(jnp.allclose(r1, r2))
+    #     self.assertTrue(jnp.allclose(o1, o2))
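The reference kernels added above pin down the semantics of the four matmul forms (vector @ CSR, matrix @ CSR, CSR @ vector, CSR @ matrix): row i of the sparse matrix keeps its column ids in indices[indptr[i]:indptr[i+1]] and, for heterogeneous weights, the matching slice of w. One way to sanity-check them is against a dense scatter of the same (data, indices, indptr) triple; a minimal sketch independent of brainstate, with csr_to_dense a hypothetical helper named here for illustration:

import numpy as np

def csr_to_dense(data, indices, indptr, shape):
    # Scatter each row's (column id, weight) pairs into a dense matrix.
    # data may be a scalar (homogeneous weight) or one value per index;
    # np.add.at accumulates duplicate column ids, matching .at[ids].add.
    dense = np.zeros(shape)
    w = np.broadcast_to(np.asarray(data, dtype=float), indices.shape)
    for i in range(shape[0]):
        ids = indices[indptr[i]: indptr[i + 1]]
        np.add.at(dense[i], ids, w[indptr[i]: indptr[i + 1]])
    return dense

With this helper, vector_csr(x, w, indices, indptr, shape) should agree with x @ csr_to_dense(w, indices, indptr, shape) up to floating-point tolerance, and likewise for the other three forms.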
--- a/brainstate/event/_xla_custom_op.py
+++ b/brainstate/event/_xla_custom_op.py
@@ -9,7 +9,6 @@ from typing import Callable, Sequence, Tuple, Protocol
 import jax
 import numpy as np
 from jax import tree_util
-from jax.core import Primitive
 from jax.interpreters import batching, ad
 from jax.interpreters import xla, mlir
 from jaxlib.hlo_helpers import custom_call
@@ -19,6 +18,11 @@ if jax.__version_info__ < (0, 4, 35):
 else:
     import jax.extend as je
 
+if jax.__version_info__ < (0, 4, 38):
+    from jax.core import Primitive
+else:
+    from jax.extend.core import Primitive
+
 numba_installed = importlib.util.find_spec('numba') is not None
 
 __all__ = [
@@ -164,7 +168,7 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
 
 
 def register_numba_mlir_cpu_translation_rule(
-    primitive: jax.core.Primitive,
+    primitive: Primitive,
    cpu_kernel: Callable,
    debug: bool = False
 ):
@@ -205,7 +209,7 @@ class XLACustomOp:
         transpose_translation: Callable = None,
     ):
         # primitive
-        self.primitive = jax.core.Primitive(name)
+        self.primitive = Primitive(name)
         self.primitive.multiple_results = True
 
         # abstract evaluation
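The functional change in _xla_custom_op.py is the source of Primitive: jax.core.Primitive moved to jax.extend.core in JAX 0.4.38, so the module now gates the import on the running JAX version. Downstream code can use the same pattern; a minimal sketch, where 'my_op' is a hypothetical primitive name used only for illustration:

import jax

# Mirror the version gate introduced above.
if jax.__version_info__ < (0, 4, 38):
    from jax.core import Primitive
else:
    from jax.extend.core import Primitive

prim = Primitive('my_op')       # hypothetical name
prim.multiple_results = True    # same setup as XLACustomOp.__init__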
--- a/brainstate/graph/__init__.py
+++ b/brainstate/graph/__init__.py
@@ -14,20 +14,16 @@
 # ==============================================================================
 
 
-from ._graph_context import *
-from ._graph_context import __all__ as _graph_context__all__
-from ._graph_convert import *
-from ._graph_convert import __all__ as _graph_convert__all__
 from ._graph_node import *
 from ._graph_node import __all__ as _graph_node__all__
 from ._graph_operation import *
 from ._graph_operation import __all__ as _graph_operation__all__
 
-__all__ = (_graph_context__all__ +
-           _graph_convert__all__ +
-           _graph_node__all__ +
-           _graph_operation__all__)
-del (_graph_context__all__,
-     _graph_convert__all__,
-     _graph_node__all__,
-     _graph_operation__all__)
+__all__ = (
+    _graph_node__all__ +
+    _graph_operation__all__
+)
+del (
+    _graph_node__all__,
+    _graph_operation__all__
+)
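With _graph_context.py and _graph_convert.py deleted from the package, brainstate.graph now re-exports only the node and operation symbols. A quick check of the shrunken namespace, assuming the post20250120 wheel is installed:

import brainstate as bst

# Symbols from the removed modules, such as graph_to_tree / tree_to_graph
# used by the deleted tests further below, are no longer exported.
print('graph_to_tree' in bst.graph.__all__)   # expected: False
print('flatten' in bst.graph.__all__)         # expected: True (from _graph_operation)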
--- a/brainstate/graph/_graph_node.py
+++ b/brainstate/graph/_graph_node.py
@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,9 +27,7 @@ import numpy as np
 
 from brainstate._state import State, TreefyState
 from brainstate.typing import Key
-from brainstate.util._error import TraceContextError
 from brainstate.util._pretty_repr import PrettyRepr, pretty_repr_avoid_duplicate, PrettyType, PrettyAttr
-from brainstate.util._tracers import StateJaxTracer
 from ._graph_operation import register_graph_node_type
 
 __all__ = [
@@ -44,7 +42,6 @@ class GraphNodeMeta(ABCMeta):
     if not TYPE_CHECKING:
         def __call__(cls, *args: Any, **kwargs: Any) -> Any:
             node = cls.__new__(cls, *args, **kwargs)
-            vars(node)['_trace_state'] = StateJaxTracer()
             node.__init__(*args, **kwargs)
             return node
 
@@ -64,9 +61,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
 
     graph_invisible_attrs = ()
 
-    if TYPE_CHECKING:
-        _trace_state: StateJaxTracer
-
     def __init_subclass__(cls) -> None:
         super().__init_subclass__()
 
@@ -79,21 +73,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
             clear=_node_clear,
         )
 
-    # if not TYPE_CHECKING:
-    #     def __setattr__(self, name: str, value: Any) -> None:
-    #         self._setattr(name, value)

-    # def _setattr(self, name: str, value: Any) -> None:
-    #     self.check_valid_context(lambda: f"Cannot mutate '{type(self).__name__}' from different trace level")
-    #     object.__setattr__(self, name, value)
-
-    def check_valid_context(self, error_msg: Callable[[], str]) -> None:
-        """
-        Check if the current context is valid for the object to be mutated.
-        """
-        if not self._trace_state.is_valid():
-            raise TraceContextError(error_msg())
-
     def __deepcopy__(self: G, memo=None) -> G:
         """
         Deepcopy the object.
@@ -214,7 +193,6 @@ def _node_create_empty(
 ) -> G:
     node_type, = static
     node = object.__new__(node_type)
-    vars(node).update(_trace_state=StateJaxTracer())
    return node
 
 
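These removals strip the per-node trace bookkeeping: nodes no longer carry a _trace_state slot, and attribute mutation is no longer guarded by check_valid_context / TraceContextError. Node instances remain duplicable through the retained __deepcopy__; a minimal sketch, assuming bst.nn.Linear as used in the package's tests:

import copy

import brainstate as bst

layer = bst.nn.Linear(2, 3)
layer2 = copy.deepcopy(layer)   # graph-aware deepcopy kept in this release
assert layer2 is not layer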
--- a/brainstate/graph/_graph_operation.py
+++ b/brainstate/graph/_graph_operation.py
@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
--- a/brainstate/graph/_graph_operation_test.py
+++ b/brainstate/graph/_graph_operation_test.py
@@ -17,13 +17,10 @@ from __future__ import annotations
 
 import unittest
 from collections.abc import Callable
-from functools import partial
 from threading import Thread
-from typing import Any
 
 import jax
 import jax.numpy as jnp
-import pytest
 from absl.testing import absltest, parameterized
 
 import brainstate as bst
@@ -354,125 +351,6 @@ class TestGraphUtils(absltest.TestCase):
         assert m2.tree.a is not m.tree.a
         assert m2.tree is not m.tree
 
-    @pytest.mark.skip(reason='Not implemented')
-    def test_cached_unflatten(self):
-        class Foo(bst.graph.Node):
-            def __init__(self, ):
-                self.a = bst.nn.Linear(2, 2)
-                self.b = bst.nn.BatchNorm1d([10, 2])
-
-        def f(m: Foo):
-            m.a, m.b = m.b, m.a  # type: ignore
-
-        m = Foo()
-        a = m.a
-        b = m.b
-
-        ref_out_idx_out = bst.graph.RefMap()
-        graphdef: bst.graph.GraphDef[Foo]
-        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
-
-        @partial(jax.jit, static_argnums=(0,))
-        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
-            idx_out_ref_in: dict[int, Any] = {}
-            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
-            f(m)
-            ref_in_idx_in = bst.graph.RefMap[Any, int]()
-            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
-            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
-            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
-            return state, static_out
-
-        static_out: bst.graph.Static
-        state, static_out = f_pure(graphdef, state)
-        idx_out_idx_in: dict[int, int]
-        graphdef, idx_out_idx_in = static_out.value
-        idx_in_ref_out = bst.graph.compose_mapping_reversed(
-            ref_out_idx_out, idx_out_idx_in
-        )
-        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
-        assert m2 is m
-        assert m2.a is b
-        assert m2.b is a
-
-    @pytest.mark.skip(reason='Not implemented')
-    def test_cached_unflatten_swap_variables(self):
-        class Foo(bst.graph.Node):
-            def __init__(self):
-                self.a = bst.ParamState(1)
-                self.b = bst.ParamState(2)
-
-        def f(m: Foo):
-            m.a, m.b = m.b, m.a
-
-        m = Foo()
-        a = m.a
-        b = m.b
-
-        ref_out_idx_out = bst.graph.RefMap[Any, int]()
-        graphdef: bst.graph.GraphDef[Foo]
-        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
-
-        @partial(jax.jit, static_argnums=(0,))
-        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
-            idx_out_ref_in: dict[int, Any] = {}
-            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
-            f(m)
-            ref_in_idx_in = bst.graph.RefMap[Any, int]()
-            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
-            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
-            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
-            return state, static_out
-
-        static_out: bst.graph.Static
-        state, static_out = f_pure(graphdef, state)
-        idx_out_idx_in: dict[int, int]
-        graphdef, idx_out_idx_in = static_out.value
-        idx_in_ref_out = bst.graph.compose_mapping_reversed(
-            ref_out_idx_out, idx_out_idx_in
-        )
-        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
-        assert m2 is m
-        assert m2.a is b
-        assert m2.b is a
-
-    @pytest.mark.skip(reason='Not implemented')
-    def test_cached_unflatten_add_self_reference(self):
-        class Foo(bst.graph.Node):
-            def __init__(self):
-                self.ref = None
-
-        def f(m: Foo):
-            m.ref = m
-
-        m = Foo()
-
-        ref_out_idx_out = bst.graph.RefMap()
-        graphdef: bst.graph.GraphDef[Foo]
-        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
-
-        @partial(jax.jit, static_argnums=(0,))
-        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
-            idx_out_ref_in: dict[int, Any] = {}
-            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
-            f(m)
-            ref_in_idx_in = bst.graph.RefMap[Any, int]()
-            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
-            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
-            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
-            return state, static_out
-
-        static_out: bst.graph.Static
-        state, static_out = f_pure(graphdef, state)
-        idx_out_idx_in: dict[int, int]
-        graphdef, idx_out_idx_in = static_out.value
-        idx_in_ref_out = bst.graph.compose_mapping_reversed(
-            ref_out_idx_out, idx_out_idx_in
-        )
-        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
-        assert m2 is m
-        assert m2.ref is m2
-
     def test_call_jit_update(self):
         class Counter(bst.graph.Node):
             def __init__(self):
@@ -527,43 +405,6 @@ class TestGraphUtils(absltest.TestCase):
         self.assertEqual(nodes['a'].count.value, 0)
         self.assertEqual(nodes['b'].count.value, 1)
 
-    def test_to_tree_simple(self):
-        m = bst.nn.Linear(2, 3, )
-        impure_tree = (m, 1, {'b': m})
-
-        pure_tree = bst.graph.graph_to_tree(impure_tree)
-
-        t1 = pure_tree[0]
-        t2 = pure_tree[2]['b']
-
-        self.assertEqual(pure_tree[1], 1)
-        self.assertIsInstance(t1, bst.graph.NodeStates)
-        assert isinstance(t1, bst.graph.NodeStates)
-        self.assertIsInstance(t2, bst.graph.NodeStates)
-        assert isinstance(t2, bst.graph.NodeStates)
-        self.assertIsInstance(t1.graphdef, bst.graph.NodeDef)
-        self.assertIsInstance(t2.graphdef, bst.graph.NodeRef)
-        self.assertLen(t1.states[0].to_flat(), 1)
-        self.assertLen(t2.states[0].to_flat(), 0)
-
-        impure_tree2 = bst.graph.tree_to_graph(pure_tree)
-
-        m1_out = impure_tree2[0]
-        m2_out = impure_tree2[2]['b']
-
-        self.assertIs(m1_out, m2_out)
-        self.assertEqual(impure_tree2[1], 1)
-
-    def test_to_tree_consistent_prefix(self):
-        m = bst.nn.Linear(2, 3, )
-        impure_tree = (m, 1, {'b': m})
-        prefix = (0, None, 0)
-        pure_tree = bst.graph.graph_to_tree(impure_tree, prefix=prefix)
-
-        prefix = (0, None, 1)
-        with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
-            bst.graph.graph_to_tree(impure_tree, prefix=prefix)
-
 
 class SimpleModule(bst.nn.Module):
     pass
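For reference, the flatten/unflatten round trip that the deleted cached-unflatten tests exercised is still reachable through the surviving API. A minimal sketch, assuming the optional index arguments of unflatten default sensibly when no reference cache is reused:

import brainstate as bst

class Counter(bst.graph.Node):
    def __init__(self):
        self.count = bst.ParamState(0)

m = Counter()
# Split the node into a static graph definition plus a state pytree ...
graphdef, state = bst.graph.flatten(m, ref_index=bst.graph.RefMap())
# ... and rebuild an equivalent node from the two halves.
m2 = bst.graph.unflatten(graphdef, state)
assert m2.count.value == 0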