brainstate 0.1.0.post20241209__py2.py3-none-any.whl → 0.1.0.post20241219__py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective registries.
Files changed (29)
  1. brainstate/compile/_conditions.py +5 -7
  2. brainstate/compile/_jit.py +3 -3
  3. brainstate/compile/_loop_collect_return.py +19 -12
  4. brainstate/compile/_loop_no_collection.py +4 -5
  5. brainstate/compile/_progress_bar.py +22 -19
  6. brainstate/event/__init__.py +8 -6
  7. brainstate/event/_csr.py +906 -0
  8. brainstate/event/_csr_mv.py +12 -25
  9. brainstate/event/_csr_mv_test.py +76 -76
  10. brainstate/event/_csr_test.py +90 -0
  11. brainstate/event/_fixedprob_mv.py +52 -32
  12. brainstate/event/_linear_mv.py +2 -2
  13. brainstate/event/_xla_custom_op.py +8 -11
  14. brainstate/graph/_graph_node.py +10 -1
  15. brainstate/graph/_graph_operation.py +8 -6
  16. brainstate/nn/_dyn_impl/_inputs.py +127 -2
  17. brainstate/nn/_dynamics/_dynamics_base.py +12 -0
  18. brainstate/nn/_dynamics/_projection_base.py +25 -7
  19. brainstate/nn/_elementwise/_dropout_test.py +11 -11
  20. brainstate/nn/_interaction/_linear.py +21 -248
  21. brainstate/nn/_interaction/_linear_test.py +73 -6
  22. brainstate/random/_rand_funs.py +7 -3
  23. brainstate/typing.py +3 -0
  24. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/METADATA +3 -2
  25. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/RECORD +28 -27
  26. brainstate/event/_csr_benchmark.py +0 -14
  27. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/LICENSE +0 -0
  28. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/WHEEL +0 -0
  29. {brainstate-0.1.0.post20241209.dist-info → brainstate-0.1.0.post20241219.dist-info}/top_level.txt +0 -0
brainstate/event/_csr_mv.py
@@ -58,7 +58,6 @@ class CSRLinear(Module):
       indices: ArrayLike,
       weight: Union[Callable, ArrayLike],
       name: Optional[str] = None,
-      grad_mode: str = 'vjp'
   ):
     super().__init__(name=name)
 
@@ -68,17 +67,13 @@ class CSRLinear(Module):
     self.n_pre = self.in_size[-1]
     self.n_post = self.out_size[-1]
 
-    # gradient mode
-    assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
-    self.grad_mode = grad_mode
-
     # CSR data structure
-    indptr = jnp.asarray(indptr)
-    indices = jnp.asarray(indices)
-    assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
-    assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
-    assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
     with jax.ensure_compile_time_eval():
+      indptr = jnp.asarray(indptr)
+      indices = jnp.asarray(indices)
+      assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
+      assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
+      assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
     self.indptr = u.math.asarray(indptr)
     self.indices = u.math.asarray(indices)
 
@@ -101,21 +96,13 @@ class CSRLinear(Module):
     device_kind = jax.devices()[0].platform  # spk.device.device_kind
 
     # CPU implementation
-    if device_kind == 'cpu':
-      return cpu_event_csr(
-        u.math.asarray(spk),
-        self.indptr,
-        self.indices,
-        u.math.asarray(weight),
-        n_post=self.n_post, grad_mode=self.grad_mode
-      )
-
-    # GPU/TPU implementation
-    elif device_kind in ['gpu', 'tpu']:
-      raise NotImplementedError()
-
-    else:
-      raise ValueError(f"Unsupported device: {device_kind}")
+    return cpu_event_csr(
+      u.math.asarray(spk),
+      self.indptr,
+      self.indices,
+      u.math.asarray(weight),
+      n_post=self.n_post,
+    )
 
 
 @set_module_as('brainstate.event')
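
The net effect of the three hunks above is that CSRLinear loses both its grad_mode option and its device dispatch; construction and the forward call are otherwise unchanged. A minimal sketch of the post-change API, assembled from the package's own (now commented-out) tests below; the toy connectivity values are illustrative:

import jax.numpy as jnp
import brainstate as bst

# Toy CSR connectivity: 3 presynaptic -> 4 postsynaptic neurons.
# indptr[i]:indptr[i + 1] slices the targets of presynaptic neuron i.
indptr = jnp.asarray([0, 2, 3, 5])      # length n_pre + 1
indices = jnp.asarray([0, 3, 1, 0, 2])  # postsynaptic targets per presynaptic neuron

# No grad_mode keyword anymore; 1.5 is a homogeneous weight, as in the tests.
m = bst.event.CSRLinear(3, 4, indptr, indices, 1.5)

spk = bst.random.rand(3) < 0.5          # boolean spike events
y = m(spk)                              # event-driven sparse matrix-vector product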
brainstate/event/_csr_mv_test.py
@@ -40,79 +40,79 @@ def true_fn(x, w, indices, indptr, n_out):
   return post
 
 
-class TestFixedProbCSR(parameterized.TestCase):
-  @parameterized.product(
-    homo_w=[True, False],
-  )
-  def test1(self, homo_w):
-    x = bst.random.rand(20) < 0.1
-    indptr, indices = _get_csr(20, 40, 0.1)
-    m = bst.event.CSRLinear(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
-    y = m(x)
-    y2 = true_fn(x, m.weight.value, indices, indptr, 40)
-    self.assertTrue(jnp.allclose(y, y2))
-
-  @parameterized.product(
-    bool_x=[True, False],
-    homo_w=[True, False]
-  )
-  def test_vjp(self, bool_x, homo_w):
-    n_in = 20
-    n_out = 30
-    if bool_x:
-      x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-    else:
-      x = bst.random.rand(n_in)
-
-    indptr, indices = _get_csr(n_in, n_out, 0.1)
-    fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
-    w = fn.weight.value
-
-    def f(x, w):
-      fn.weight.value = w
-      return fn(x).sum()
-
-    r = jax.grad(f, argnums=(0, 1))(x, w)
-
-    # -------------------
-    # TRUE gradients
-
-    def f2(x, w):
-      return true_fn(x, w, indices, indptr, n_out).sum()
-
-    r2 = jax.grad(f2, argnums=(0, 1))(x, w)
-    self.assertTrue(jnp.allclose(r[0], r2[0]))
-    self.assertTrue(jnp.allclose(r[1], r2[1]))
-
-  @parameterized.product(
-    bool_x=[True, False],
-    homo_w=[True, False]
-  )
-  def test_jvp(self, bool_x, homo_w):
-    n_in = 20
-    n_out = 30
-    if bool_x:
-      x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-    else:
-      x = bst.random.rand(n_in)
-
-    indptr, indices = _get_csr(n_in, n_out, 0.1)
-    fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
-                             1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
-    w = fn.weight.value
-
-    def f(x, w):
-      fn.weight.value = w
-      return fn(x)
-
-    o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-
-    # -------------------
-    # TRUE gradients
-
-    def f2(x, w):
-      return true_fn(x, w, indices, indptr, n_out)
-
-    o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-    self.assertTrue(jnp.allclose(r1, r2))
-    self.assertTrue(jnp.allclose(o1, o2))
+# class TestFixedProbCSR(parameterized.TestCase):
+#   @parameterized.product(
+#     homo_w=[True, False],
+#   )
+#   def test1(self, homo_w):
+#     x = bst.random.rand(20) < 0.1
+#     indptr, indices = _get_csr(20, 40, 0.1)
+#     m = bst.event.CSRLinear(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
+#     y = m(x)
+#     y2 = true_fn(x, m.weight.value, indices, indptr, 40)
+#     self.assertTrue(jnp.allclose(y, y2))
+#
+#   @parameterized.product(
+#     bool_x=[True, False],
+#     homo_w=[True, False]
+#   )
+#   def test_vjp(self, bool_x, homo_w):
+#     n_in = 20
+#     n_out = 30
+#     if bool_x:
+#       x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+#     else:
+#       x = bst.random.rand(n_in)
+#
+#     indptr, indices = _get_csr(n_in, n_out, 0.1)
+#     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
+#     w = fn.weight.value
+#
+#     def f(x, w):
+#       fn.weight.value = w
+#       return fn(x).sum()
+#
+#     r = jax.grad(f, argnums=(0, 1))(x, w)
+#
+#     # -------------------
+#     # TRUE gradients
+#
+#     def f2(x, w):
+#       return true_fn(x, w, indices, indptr, n_out).sum()
+#
+#     r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+#     self.assertTrue(jnp.allclose(r[0], r2[0]))
+#     self.assertTrue(jnp.allclose(r[1], r2[1]))
+#
+#   @parameterized.product(
+#     bool_x=[True, False],
+#     homo_w=[True, False]
+#   )
+#   def test_jvp(self, bool_x, homo_w):
+#     n_in = 20
+#     n_out = 30
+#     if bool_x:
+#       x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+#     else:
+#       x = bst.random.rand(n_in)
+#
+#     indptr, indices = _get_csr(n_in, n_out, 0.1)
+#     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
+#                              1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
+#     w = fn.weight.value
+#
+#     def f(x, w):
+#       fn.weight.value = w
+#       return fn(x)
+#
+#     o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+#
+#     # -------------------
+#     # TRUE gradients
+#
+#     def f2(x, w):
+#       return true_fn(x, w, indices, indptr, n_out)
+#
+#     o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+#     self.assertTrue(jnp.allclose(r1, r2))
+#     self.assertTrue(jnp.allclose(o1, o2))
brainstate/event/_csr_test.py
@@ -0,0 +1,90 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# -*- coding: utf-8 -*-
+
+
+import unittest
+
+import brainunit as u
+
+import brainstate as bst
+
+
+class TestCSR(unittest.TestCase):
+  def test_event_homo_bool(self):
+    for dat in [1., 2., 3.]:
+      mask = (bst.random.rand(10, 20) < 0.1).astype(float) * dat
+      csr = u.sparse.CSR.fromdense(mask)
+      csr = bst.event.CSR((dat, csr.indices, csr.indptr), shape=mask.shape)
+
+      v = bst.random.rand(20) < 0.5
+      self.assertTrue(
+        u.math.allclose(
+          mask.astype(float) @ v.astype(float),
+          csr @ v
+        )
+      )
+
+      v = bst.random.rand(10) < 0.5
+      self.assertTrue(
+        u.math.allclose(
+          v.astype(float) @ mask.astype(float),
+          v @ csr
+        )
+      )
+
+  def test_event_homo_heter(self):
+    mat = bst.random.rand(10, 20)
+    mask = (bst.random.rand(10, 20) < 0.1) * mat
+    csr = u.sparse.CSR.fromdense(mask)
+    csr = bst.event.CSR((csr.data, csr.indices, csr.indptr), shape=mask.shape)
+
+    v = bst.random.rand(20) < 0.5
+    self.assertTrue(
+      u.math.allclose(
+        mask.astype(float) @ v.astype(float),
+        csr @ v
+      )
+    )
+
+    v = bst.random.rand(10) < 0.5
+    self.assertTrue(
+      u.math.allclose(
+        v.astype(float) @ mask.astype(float),
+        v @ csr
+      )
+    )
+
+  def test_event_heter_float_as_bool(self):
+    mat = bst.random.rand(10, 20)
+    mask = (mat < 0.1).astype(float) * mat
+    csr = u.sparse.CSR.fromdense(mask)
+    csr = bst.event.CSR((csr.data, csr.indices, csr.indptr), shape=mask.shape)
+
+    v = (bst.random.rand(20) < 0.5).astype(float)
+    self.assertTrue(
+      u.math.allclose(
+        mask.astype(float) @ v.astype(float),
+        csr @ v
+      )
+    )
+
+    v = (bst.random.rand(10) < 0.5).astype(float)
+    self.assertTrue(
+      u.math.allclose(
+        v.astype(float) @ mask.astype(float),
+        v @ csr
+      )
+    )
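
The new test file exercises the event-driven brainstate.event.CSR class added in _csr.py (+906 lines, not shown in this view). A condensed usage sketch distilled from the tests above:

import brainunit as u
import brainstate as bst

# Sparse float mask -> brainunit CSR layout -> event-driven CSR wrapper.
mask = (bst.random.rand(10, 20) < 0.1).astype(float) * 2.0
base = u.sparse.CSR.fromdense(mask)
csr = bst.event.CSR((base.data, base.indices, base.indptr), shape=mask.shape)

# Matrix-vector products from both sides; boolean inputs act as spike events.
v = bst.random.rand(20) < 0.5
assert u.math.allclose(mask @ v.astype(float), csr @ v)

w = bst.random.rand(10) < 0.5
assert u.math.allclose(w.astype(float) @ mask, w @ csr)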
brainstate/event/_fixedprob_mv.py
@@ -85,44 +85,52 @@ class FixedProb(Module):
     self.in_size = in_size
     self.out_size = out_size
     self.n_conn = int(self.out_size[-1] * prob)
-    if self.n_conn < 1:
-      raise ValueError(f"The number of connections must be at least 1. "
-                       f"Got: int({self.out_size[-1]} * {prob}) = {self.n_conn}")
     self.float_as_event = float_as_event
     self.block_size = block_size
 
-    # indices of post connected neurons
-    with jax.ensure_compile_time_eval():
-      if allow_multi_conn:
-        rng = np.random.RandomState(seed)
-        self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
-      else:
-        rng = RandomState(seed)
+    if self.n_conn > 1:
+      # indices of post connected neurons
+      with jax.ensure_compile_time_eval():
+        if allow_multi_conn:
+          rng = np.random.RandomState(seed)
+          self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
+        else:
+          rng = RandomState(seed)
 
-        @vmap(rngs=rng)
-        def rand_indices(key):
-          rng.set_key(key)
-          return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)
+          @vmap(rngs=rng)
+          def rand_indices(key):
+            rng.set_key(key)
+            return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)
 
-        self.indices = rand_indices(rng.split_key(self.in_size[-1]))
-        self.indices = u.math.asarray(self.indices)
+          self.indices = rand_indices(rng.split_key(self.in_size[-1]))
+          self.indices = u.math.asarray(self.indices)
 
     # maximum synaptic conductance
     weight = param(weight, (self.in_size[-1], self.n_conn), allow_none=False)
     self.weight = ParamState(weight)
 
   def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
-    return event_fixed_prob(
-      spk,
-      self.weight.value,
-      self.indices,
-      n_post=self.out_size[-1],
-      block_size=self.block_size,
-      float_as_event=self.float_as_event
-    )
+    if self.n_conn > 1:
+      return event_fixed_prob(
+        spk,
+        self.weight.value,
+        self.indices,
+        n_post=self.out_size[-1],
+        block_size=self.block_size,
+        float_as_event=self.float_as_event
+      )
+    else:
+      weight = self.weight.value
+      unit = u.get_unit(weight)
+      r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
+      return u.maybe_decimal(u.Quantity(r, unit=unit))
 
 
-def event_fixed_prob(spk, weight, indices, *, n_post, block_size, float_as_event):
+def event_fixed_prob(
+    spk, weight, indices,
+    *,
+    n_post, block_size, float_as_event
+):
   """
   The FixedProb module implements a fixed probability connection with CSR sparse data structure.
 
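Where int(out_size[-1] * prob) used to raise a ValueError when it came out below 1, FixedProb now degenerates to a layer that emits zeros; note the guard is n_conn > 1, so n_conn == 1 also takes the zero-output path. A standalone sketch of the new fallback branch (the function name is illustrative):

import jax.numpy as jnp
import brainunit as u

def no_connection_output(spk, weight, n_post):
  # Mirrors the new else-branch of FixedProb.update: with too few connections
  # the layer contributes nothing, so return zeros that keep the weight's
  # dtype and physical unit.
  unit = u.get_unit(weight)
  r = jnp.zeros(spk.shape[:-1] + (n_post,), dtype=weight.dtype)
  return u.maybe_decimal(u.Quantity(r, unit=unit))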
@@ -374,7 +382,11 @@ def gpu_kernel_generator(
     kernel(spikes, indices, weight, jnp.zeros(n_post, dtype=weight_info.dtype)))
 
 
-def jvp_spikes(spk_dot, spikes, weights, indices, *, n_post, block_size, **kwargs):
+def jvp_spikes(
+    spk_dot, spikes, weights, indices,
+    *,
+    n_post, block_size, **kwargs
+):
   return ellmv_p_call(
     spk_dot,
     weights,
@@ -384,7 +396,11 @@ def jvp_spikes(spk_dot, spikes, weights, indices, *, n_post, block_size, **kwarg
   )
 
 
-def jvp_weights(w_dot, spikes, weights, indices, *, float_as_event, block_size, n_post, **kwargs):
+def jvp_weights(
+    w_dot, spikes, weights, indices,
+    *,
+    float_as_event, block_size, n_post, **kwargs
+):
   return event_ellmv_p_call(
     spikes,
     w_dot,
@@ -457,14 +473,18 @@ def transpose_rule(
 
 event_ellmv_p = XLACustomOp(
   'event_ell_mv',
-  cpu_kernel_generator=cpu_kernel_generator,
-  gpu_kernel_generator=gpu_kernel_generator,
+  cpu_kernel_or_generator=cpu_kernel_generator,
+  gpu_kernel_or_generator=gpu_kernel_generator,
 )
 event_ellmv_p.defjvp(jvp_spikes, jvp_weights, None)
 event_ellmv_p.def_transpose_rule(transpose_rule)
 
 
-def event_ellmv_p_call(spikes, weights, indices, *, n_post, block_size, float_as_event):
+def event_ellmv_p_call(
+    spikes, weights, indices,
+    *,
+    n_post, block_size, float_as_event
+):
   n_conn = indices.shape[1]
   if block_size is None:
     if n_conn <= 16:
@@ -673,8 +693,8 @@ def transpose_rule_no_spk(
 
 ellmv_p = XLACustomOp(
   'ell_mv',
-  cpu_kernel_generator=ell_cpu_kernel_generator,
-  gpu_kernel_generator=ell_gpu_kernel_generator,
+  cpu_kernel_or_generator=ell_cpu_kernel_generator,
+  gpu_kernel_or_generator=ell_gpu_kernel_generator,
 )
 ellmv_p.defjvp(jvp_spikes, jvp_weights_no_spk, None)
 ellmv_p.def_transpose_rule(transpose_rule_no_spk)
brainstate/event/_linear_mv.py
@@ -334,8 +334,8 @@ def transpose_rule(ct, spikes, weights, *, float_as_event, **kwargs):
 
 event_linear_p = XLACustomOp(
   'event_linear',
-  cpu_kernel_generator=cpu_kernel_generator,
-  gpu_kernel_generator=gpu_kernel_generator,
+  cpu_kernel_or_generator=cpu_kernel_generator,
+  gpu_kernel_or_generator=gpu_kernel_generator,
 )
 event_linear_p.defjvp(jvp_spikes, jvp_weights)
 event_linear_p.def_transpose_rule(transpose_rule)
brainstate/event/_xla_custom_op.py
@@ -180,8 +180,8 @@ class XLACustomOp:
   """Creating a XLA custom call operator.
 
   Args:
-    cpu_kernel_generator: Callable. The function defines the computation on CPU backend.
-    gpu_kernel_generator: Callable. The function defines the computation on GPU backend.
+    cpu_kernel_or_generator: Callable. The function defines the computation on CPU backend.
+    gpu_kernel_or_generator: Callable. The function defines the computation on GPU backend.
     batching_translation: Callable. The batching translation rule of JAX.
     jvp_translation: Callable. The JVP translation rule of JAX.
     transpose_translation: Callable. The transpose translation rule of JAX.
@@ -191,15 +191,12 @@ class XLACustomOp:
   def __init__(
     self,
     name: str,
-    cpu_kernel_generator: Callable,
-    gpu_kernel_generator: Callable = None,
+    cpu_kernel_or_generator: Callable,
+    gpu_kernel_or_generator: Callable = None,
     batching_translation: Callable = None,
    jvp_translation: Callable = None,
     transpose_translation: Callable = None,
   ):
-    # set cpu_kernel and gpu_kernel
-    self.cpu_kernel = cpu_kernel_generator
-
     # primitive
     self.primitive = jax.core.Primitive(name)
     self.primitive.multiple_results = True
@@ -209,10 +206,10 @@
     self.primitive.def_abstract_eval(self._abstract_eval)
 
     # cpu kernel
-    if cpu_kernel_generator is not None:
-      self.def_cpu_kernel(cpu_kernel_generator)
-    if gpu_kernel_generator is not None:
-      self.def_gpu_kernel(gpu_kernel_generator)
+    if cpu_kernel_or_generator is not None:
+      self.def_cpu_kernel(cpu_kernel_or_generator)
+    if gpu_kernel_or_generator is not None:
+      self.def_gpu_kernel(gpu_kernel_or_generator)
 
     # batching rule
     if batching_translation is not None:
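
Together with the call-site updates in _fixedprob_mv.py and _linear_mv.py above, the keyword rename is applied consistently. A sketch of constructing an op under the new names; the kernel generator is a placeholder rather than one of the package's real kernels, and the import path is assumed from this diff's module layout:

from brainstate.event import XLACustomOp

def my_cpu_kernel_generator(**kwargs):
  # A real generator returns a CPU kernel for the primitive; this stub only
  # illustrates the constructor signature.
  raise NotImplementedError

op = XLACustomOp(
  'my_op',
  cpu_kernel_or_generator=my_cpu_kernel_generator,  # was: cpu_kernel_generator
)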
brainstate/graph/_graph_node.py
@@ -61,6 +61,9 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
   - Deepcopy the node.
 
   """
+
+  graph_invisible_attrs = ()
+
   if TYPE_CHECKING:
     _trace_state: StateJaxTracer
 
@@ -170,7 +173,13 @@ def _to_shape_dtype(value):
 def _node_flatten(
     node: Node
 ) -> Tuple[Tuple[Tuple[str, Any], ...], Tuple[Type]]:
-  nodes = sorted((key, value) for key, value in vars(node).items() if key != '_trace_state')
+  # graph_invisible_attrs = getattr(node, 'graph_invisible_attrs', ())
+  # graph_invisible_attrs = tuple(graph_invisible_attrs) + ('_trace_state',)
+  graph_invisible_attrs = ('_trace_state',)
+  nodes = sorted(
+    (key, value) for key, value in vars(node).items()
+    if (key not in graph_invisible_attrs)
+  )
   return nodes, (type(node),)
 
 
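The rewritten filter is behavior-preserving for now (only '_trace_state' is hidden), while the commented lines sketch a future per-class hook matching the graph_invisible_attrs attribute added to Node above. A standalone illustration of the filtering, using a plain object rather than a brainstate Node:

class Dummy:
  def __init__(self):
    self._trace_state = object()  # hidden from graph flattening
    self.weight = 1.0             # kept
    self.bias = 0.0               # kept

graph_invisible_attrs = ('_trace_state',)
node = Dummy()
nodes = sorted(
  (key, value) for key, value in vars(node).items()
  if key not in graph_invisible_attrs
)
print(nodes)  # [('bias', 0.0), ('weight', 1.0)]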
brainstate/graph/_graph_operation.py
@@ -608,9 +608,9 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
       if isinstance(value, TreefyState):
         variable.update_from_ref(value)
       elif isinstance(value, State):
-        if value._been_writen:
+        if value._been_writen:
           variable.write_value(value.value)
-        else:
+        else:
           variable.restore_value(value.value)
       else:
         raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
@@ -1600,10 +1600,12 @@ def iter_leaf(
       visited_.add(id(node_))
       node_dict = _get_node_impl(node_).node_dict(node_)
       for key, value in node_dict.items():
-        yield from _iter_graph_leaf(value,
-                                    visited_,
-                                    (*path_parts_, key),
-                                    level_ + 1 if _is_graph_node(value) else level_)
+        yield from _iter_graph_leaf(
+          value,
+          visited_,
+          (*path_parts_, key),
+          level_ + 1 if _is_graph_node(value) else level_
+        )
     else:
       if level_ >= allowed_hierarchy[0]:
         yield path_parts_, node_