brainstate 0.1.0.post20241210__py2.py3-none-any.whl → 0.1.0.post20241220__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- brainstate/compile/_jit.py +20 -14
- brainstate/compile/_loop_collect_return.py +14 -6
- brainstate/compile/_progress_bar.py +5 -3
- brainstate/event/__init__.py +8 -6
- brainstate/event/_csr.py +906 -0
- brainstate/event/_csr_mv.py +12 -25
- brainstate/event/_csr_mv_test.py +76 -76
- brainstate/event/_csr_test.py +90 -0
- brainstate/event/_fixedprob_mv.py +52 -32
- brainstate/event/_linear_mv.py +2 -2
- brainstate/event/_xla_custom_op.py +8 -11
- brainstate/graph/_graph_node.py +10 -1
- brainstate/graph/_graph_operation.py +8 -6
- brainstate/nn/_dyn_impl/_inputs.py +127 -2
- brainstate/nn/_dynamics/_dynamics_base.py +12 -0
- brainstate/nn/_dynamics/_projection_base.py +25 -7
- brainstate/nn/_elementwise/_dropout_test.py +11 -11
- brainstate/nn/_interaction/_linear.py +21 -248
- brainstate/nn/_interaction/_linear_test.py +73 -6
- brainstate/random/_rand_funs.py +7 -3
- brainstate/typing.py +3 -0
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/METADATA +3 -2
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/RECORD +26 -25
- brainstate/event/_csr_benchmark.py +0 -14
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/top_level.txt +0 -0
brainstate/event/_csr_mv.py
CHANGED
@@ -58,7 +58,6 @@ class CSRLinear(Module):
         indices: ArrayLike,
         weight: Union[Callable, ArrayLike],
         name: Optional[str] = None,
-        grad_mode: str = 'vjp'
     ):
         super().__init__(name=name)

@@ -68,17 +67,13 @@ class CSRLinear(Module):
         self.n_pre = self.in_size[-1]
         self.n_post = self.out_size[-1]

-        # gradient mode
-        assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
-        self.grad_mode = grad_mode
-
         # CSR data structure
-        indptr = jnp.asarray(indptr)
-        indices = jnp.asarray(indices)
-        assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
-        assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
-        assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
         with jax.ensure_compile_time_eval():
+            indptr = jnp.asarray(indptr)
+            indices = jnp.asarray(indices)
+            assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
+            assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
+            assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
             self.indptr = u.math.asarray(indptr)
             self.indices = u.math.asarray(indices)

@@ -101,21 +96,13 @@ class CSRLinear(Module):
         device_kind = jax.devices()[0].platform  # spk.device.device_kind

         # CPU implementation
-        if device_kind == 'cpu':
-            return cpu_event_csr(
-                u.math.asarray(spk),
-                self.indptr,
-                self.indices,
-                u.math.asarray(weight),
-                n_post=self.n_post,
-            )
-
-        # GPU/TPU implementation
-        elif device_kind in ['gpu', 'tpu']:
-            raise NotImplementedError()
-
-        else:
-            raise ValueError(f"Unsupported device: {device_kind}")
+        return cpu_event_csr(
+            u.math.asarray(spk),
+            self.indptr,
+            self.indices,
+            u.math.asarray(weight),
+            n_post=self.n_post,
+        )


 @set_module_as('brainstate.event')
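For orientation, a minimal sketch of constructing the class after this change; the argument order `(n_pre, n_post, indptr, indices, weight)` is taken from the (now disabled) tests below, and the tiny index arrays are hypothetical:

    import jax.numpy as jnp
    import brainstate as bst

    # Hypothetical 2x3 CSR connectivity: row 0 -> columns {0, 2}, row 1 -> column {1}
    indptr = jnp.asarray([0, 2, 3])    # n_pre + 1 row boundaries
    indices = jnp.asarray([0, 2, 1])   # column index of each nonzero
    layer = bst.event.CSRLinear(2, 3, indptr, indices, 1.5)  # grad_mode is no longer accepted

    spikes = jnp.asarray([True, False])
    out = layer(spikes)  # event-driven matvec; now always dispatched to the CPU kernel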
brainstate/event/_csr_mv_test.py
CHANGED
@@ -40,79 +40,79 @@ def true_fn(x, w, indices, indptr, n_out):
     return post


-class TestFixedProbCSR(parameterized.TestCase):
-    @parameterized.product(
-        homo_w=[True, False],
-    )
-    def test1(self, homo_w):
-        x = bst.random.rand(20) < 0.1
-        indptr, indices = _get_csr(20, 40, 0.1)
-        m = bst.event.CSRLinear(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
-        y = m(x)
-        y2 = true_fn(x, m.weight.value, indices, indptr, 40)
-        self.assertTrue(jnp.allclose(y, y2))
-
-    @parameterized.product(
-        bool_x=[True, False],
-        homo_w=[True, False]
-    )
-    def test_vjp(self, bool_x, homo_w):
-        n_in = 20
-        n_out = 30
-        if bool_x:
-            x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-        else:
-            x = bst.random.rand(n_in)
-
-        indptr, indices = _get_csr(n_in, n_out, 0.1)
-        fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
-        w = fn.weight.value
-
-        def f(x, w):
-            fn.weight.value = w
-            return fn(x).sum()
-
-        r = jax.grad(f, argnums=(0, 1))(x, w)
-
-        # -------------------
-        # TRUE gradients
-
-        def f2(x, w):
-            return true_fn(x, w, indices, indptr, n_out).sum()
-
-        r2 = jax.grad(f2, argnums=(0, 1))(x, w)
-        self.assertTrue(jnp.allclose(r[0], r2[0]))
-        self.assertTrue(jnp.allclose(r[1], r2[1]))
-
-    @parameterized.product(
-        bool_x=[True, False],
-        homo_w=[True, False]
-    )
-    def test_jvp(self, bool_x, homo_w):
-        n_in = 20
-        n_out = 30
-        if bool_x:
-            x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
-        else:
-            x = bst.random.rand(n_in)
-
-        indptr, indices = _get_csr(n_in, n_out, 0.1)
-        fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
-                                 1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
-        w = fn.weight.value
-
-        def f(x, w):
-            fn.weight.value = w
-            return fn(x)
-
-        o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-
-        # -------------------
-        # TRUE gradients
-
-        def f2(x, w):
-            return true_fn(x, w, indices, indptr, n_out)
-
-        o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
-        self.assertTrue(jnp.allclose(r1, r2))
-        self.assertTrue(jnp.allclose(o1, o2))
+# class TestFixedProbCSR(parameterized.TestCase):
+#     @parameterized.product(
+#         homo_w=[True, False],
+#     )
+#     def test1(self, homo_w):
+#         x = bst.random.rand(20) < 0.1
+#         indptr, indices = _get_csr(20, 40, 0.1)
+#         m = bst.event.CSRLinear(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
+#         y = m(x)
+#         y2 = true_fn(x, m.weight.value, indices, indptr, 40)
+#         self.assertTrue(jnp.allclose(y, y2))
+#
+#     @parameterized.product(
+#         bool_x=[True, False],
+#         homo_w=[True, False]
+#     )
+#     def test_vjp(self, bool_x, homo_w):
+#         n_in = 20
+#         n_out = 30
+#         if bool_x:
+#             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+#         else:
+#             x = bst.random.rand(n_in)
+#
+#         indptr, indices = _get_csr(n_in, n_out, 0.1)
+#         fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
+#         w = fn.weight.value
+#
+#         def f(x, w):
+#             fn.weight.value = w
+#             return fn(x).sum()
+#
+#         r = jax.grad(f, argnums=(0, 1))(x, w)
+#
+#         # -------------------
+#         # TRUE gradients
+#
+#         def f2(x, w):
+#             return true_fn(x, w, indices, indptr, n_out).sum()
+#
+#         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+#         self.assertTrue(jnp.allclose(r[0], r2[0]))
+#         self.assertTrue(jnp.allclose(r[1], r2[1]))
+#
+#     @parameterized.product(
+#         bool_x=[True, False],
+#         homo_w=[True, False]
+#     )
+#     def test_jvp(self, bool_x, homo_w):
+#         n_in = 20
+#         n_out = 30
+#         if bool_x:
+#             x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+#         else:
+#             x = bst.random.rand(n_in)
+#
+#         indptr, indices = _get_csr(n_in, n_out, 0.1)
+#         fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
+#                                  1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
+#         w = fn.weight.value
+#
+#         def f(x, w):
+#             fn.weight.value = w
+#             return fn(x)
+#
+#         o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+#
+#         # -------------------
+#         # TRUE gradients
+#
+#         def f2(x, w):
+#             return true_fn(x, w, indices, indptr, n_out)
+#
+#         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+#         self.assertTrue(jnp.allclose(r1, r2))
+#         self.assertTrue(jnp.allclose(o1, o2))
brainstate/event/_csr_test.py
ADDED
@@ -0,0 +1,90 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# -*- coding: utf-8 -*-
+
+
+import unittest
+
+import brainunit as u
+
+import brainstate as bst
+
+
+class TestCSR(unittest.TestCase):
+    def test_event_homo_bool(self):
+        for dat in [1., 2., 3.]:
+            mask = (bst.random.rand(10, 20) < 0.1).astype(float) * dat
+            csr = u.sparse.CSR.fromdense(mask)
+            csr = bst.event.CSR((dat, csr.indices, csr.indptr), shape=mask.shape)
+
+            v = bst.random.rand(20) < 0.5
+            self.assertTrue(
+                u.math.allclose(
+                    mask.astype(float) @ v.astype(float),
+                    csr @ v
+                )
+            )
+
+            v = bst.random.rand(10) < 0.5
+            self.assertTrue(
+                u.math.allclose(
+                    v.astype(float) @ mask.astype(float),
+                    v @ csr
+                )
+            )
+
+    def test_event_homo_heter(self):
+        mat = bst.random.rand(10, 20)
+        mask = (bst.random.rand(10, 20) < 0.1) * mat
+        csr = u.sparse.CSR.fromdense(mask)
+        csr = bst.event.CSR((csr.data, csr.indices, csr.indptr), shape=mask.shape)
+
+        v = bst.random.rand(20) < 0.5
+        self.assertTrue(
+            u.math.allclose(
+                mask.astype(float) @ v.astype(float),
+                csr @ v
+            )
+        )
+
+        v = bst.random.rand(10) < 0.5
+        self.assertTrue(
+            u.math.allclose(
+                v.astype(float) @ mask.astype(float),
+                v @ csr
+            )
+        )
+
+    def test_event_heter_float_as_bool(self):
+        mat = bst.random.rand(10, 20)
+        mask = (mat < 0.1).astype(float) * mat
+        csr = u.sparse.CSR.fromdense(mask)
+        csr = bst.event.CSR((csr.data, csr.indices, csr.indptr), shape=mask.shape)
+
+        v = (bst.random.rand(20) < 0.5).astype(float)
+        self.assertTrue(
+            u.math.allclose(
+                mask.astype(float) @ v.astype(float),
+                csr @ v
+            )
+        )
+
+        v = (bst.random.rand(10) < 0.5).astype(float)
+        self.assertTrue(
+            u.math.allclose(
+                v.astype(float) @ mask.astype(float),
+                v @ csr
+            )
+        )
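The new tests exercise `bst.event.CSR` built from the `(data, indices, indptr)` triple of `brainunit`'s `u.sparse.CSR`. As a reminder of what that triple encodes, a small hand-worked sketch (the concrete matrix is illustrative only):

    import jax.numpy as jnp
    import brainunit as u

    dense = jnp.asarray([[0.0, 1.5, 0.0],
                         [2.0, 0.0, 3.0]])
    csr = u.sparse.CSR.fromdense(dense)
    # Row i's nonzeros occupy positions indptr[i]:indptr[i+1] of data/indices:
    #   data    == [1.5, 2.0, 3.0]   (nonzero values, row by row)
    #   indices == [1, 0, 2]         (column of each nonzero)
    #   indptr  == [0, 1, 3]         (row boundaries; length is n_rows + 1)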
brainstate/event/_fixedprob_mv.py
CHANGED
@@ -85,44 +85,52 @@ class FixedProb(Module):
         self.in_size = in_size
         self.out_size = out_size
         self.n_conn = int(self.out_size[-1] * prob)
-        if self.n_conn < 1:
-            raise ValueError(f"The number of connections must be at least 1. "
-                             f"Got: int({self.out_size[-1]} * {prob}) = {self.n_conn}")
         self.float_as_event = float_as_event
         self.block_size = block_size

-        # indices of post connected neurons
-        with jax.ensure_compile_time_eval():
-            if allow_multi_conn:
-                rng = np.random.RandomState(seed)
-                self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
-            else:
-                rng = RandomState(seed)
+        if self.n_conn > 1:
+            # indices of post connected neurons
+            with jax.ensure_compile_time_eval():
+                if allow_multi_conn:
+                    rng = np.random.RandomState(seed)
+                    self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
+                else:
+                    rng = RandomState(seed)

-                @vmap(rngs=rng)
-                def rand_indices(key):
-                    rng.set_key(key)
-                    return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)
+                    @vmap(rngs=rng)
+                    def rand_indices(key):
+                        rng.set_key(key)
+                        return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)

-                self.indices = rand_indices(rng.split_key(self.in_size[-1]))
-            self.indices = u.math.asarray(self.indices)
+                    self.indices = rand_indices(rng.split_key(self.in_size[-1]))
+                self.indices = u.math.asarray(self.indices)

         # maximum synaptic conductance
         weight = param(weight, (self.in_size[-1], self.n_conn), allow_none=False)
         self.weight = ParamState(weight)

     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
-        return event_fixed_prob(
-            spk,
-            self.weight.value,
-            self.indices,
-            n_post=self.out_size[-1],
-            block_size=self.block_size,
-            float_as_event=self.float_as_event
-        )
+        if self.n_conn > 1:
+            return event_fixed_prob(
+                spk,
+                self.weight.value,
+                self.indices,
+                n_post=self.out_size[-1],
+                block_size=self.block_size,
+                float_as_event=self.float_as_event
+            )
+        else:
+            weight = self.weight.value
+            unit = u.get_unit(weight)
+            r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
+            return u.maybe_decimal(u.Quantity(r, unit=unit))


-def event_fixed_prob(spk, weight, indices, *, n_post, block_size, float_as_event):
+def event_fixed_prob(
+    spk, weight, indices,
+    *,
+    n_post, block_size, float_as_event
+):
     """
     The FixedProb module implements a fixed probability connection with CSR sparse data structure.

@@ -374,7 +382,11 @@ def gpu_kernel_generator(
         kernel(spikes, indices, weight, jnp.zeros(n_post, dtype=weight_info.dtype)))


-def jvp_spikes(spk_dot, spikes, weights, indices, *, n_post, block_size, **kwargs):
+def jvp_spikes(
+    spk_dot, spikes, weights, indices,
+    *,
+    n_post, block_size, **kwargs
+):
     return ellmv_p_call(
         spk_dot,
         weights,
@@ -384,7 +396,11 @@ def jvp_spikes(spk_dot, spikes, weights, indices, *, n_post, block_size, **kwargs):
     )


-def jvp_weights(w_dot, spikes, weights, indices, *, float_as_event, block_size, n_post, **kwargs):
+def jvp_weights(
+    w_dot, spikes, weights, indices,
+    *,
+    float_as_event, block_size, n_post, **kwargs
+):
     return event_ellmv_p_call(
         spikes,
         w_dot,
@@ -457,14 +473,18 @@ def transpose_rule(

 event_ellmv_p = XLACustomOp(
     'event_ell_mv',
-    cpu_kernel_generator,
-    gpu_kernel_generator,
+    cpu_kernel_or_generator=cpu_kernel_generator,
+    gpu_kernel_or_generator=gpu_kernel_generator,
 )
 event_ellmv_p.defjvp(jvp_spikes, jvp_weights, None)
 event_ellmv_p.def_transpose_rule(transpose_rule)


-def event_ellmv_p_call(spikes, weights, indices, *, n_post, block_size, float_as_event):
+def event_ellmv_p_call(
+    spikes, weights, indices,
+    *,
+    n_post, block_size, float_as_event
+):
     n_conn = indices.shape[1]
     if block_size is None:
         if n_conn <= 16:
@@ -673,8 +693,8 @@ def transpose_rule_no_spk(

 ellmv_p = XLACustomOp(
     'ell_mv',
-    ell_cpu_kernel_generator,
-    ell_gpu_kernel_generator,
+    cpu_kernel_or_generator=ell_cpu_kernel_generator,
+    gpu_kernel_or_generator=ell_gpu_kernel_generator,
 )
 ellmv_p.defjvp(jvp_spikes, jvp_weights_no_spk, None)
 ellmv_p.def_transpose_rule(transpose_rule_no_spk)
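A sketch of the behavioral change above, using hypothetical sizes and assuming the constructor keywords match the attribute names in the diff: a probability small enough that `int(out_size * prob)` rounds to zero no longer raises at construction, and `update` falls through to the zero-output branch:

    import jax.numpy as jnp
    import brainstate as bst

    # int(4 * 0.2) == 0 connections: previously a ValueError at construction
    proj = bst.event.FixedProb(100, 4, prob=0.2, weight=0.5)

    spikes = jnp.zeros(100, dtype=bool)
    out = proj.update(spikes)  # zeros of shape (4,), via the new n_conn <= 1 branch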
brainstate/event/_linear_mv.py
CHANGED
@@ -334,8 +334,8 @@ def transpose_rule(ct, spikes, weights, *, float_as_event, **kwargs):

 event_linear_p = XLACustomOp(
     'event_linear',
-    cpu_kernel_generator,
-    gpu_kernel_generator,
+    cpu_kernel_or_generator=cpu_kernel_generator,
+    gpu_kernel_or_generator=gpu_kernel_generator,
 )
 event_linear_p.defjvp(jvp_spikes, jvp_weights)
 event_linear_p.def_transpose_rule(transpose_rule)
brainstate/event/_xla_custom_op.py
CHANGED
@@ -180,8 +180,8 @@ class XLACustomOp:
     """Creating a XLA custom call operator.

     Args:
-        cpu_kernel_generator: Callable. The function defines the computation on CPU backend.
-        gpu_kernel_generator: Callable. The function defines the computation on GPU backend.
+        cpu_kernel_or_generator: Callable. The function defines the computation on CPU backend.
+        gpu_kernel_or_generator: Callable. The function defines the computation on GPU backend.
         batching_translation: Callable. The batching translation rule of JAX.
         jvp_translation: Callable. The JVP translation rule of JAX.
         transpose_translation: Callable. The transpose translation rule of JAX.
@@ -191,15 +191,12 @@ class XLACustomOp:
     def __init__(
         self,
         name: str,
-        cpu_kernel_generator: Callable,
-        gpu_kernel_generator: Callable = None,
+        cpu_kernel_or_generator: Callable,
+        gpu_kernel_or_generator: Callable = None,
         batching_translation: Callable = None,
         jvp_translation: Callable = None,
         transpose_translation: Callable = None,
     ):
-        # set cpu_kernel and gpu_kernel
-        self.cpu_kernel = cpu_kernel_generator
-
         # primitive
         self.primitive = jax.core.Primitive(name)
         self.primitive.multiple_results = True
@@ -209,10 +206,10 @@ class XLACustomOp:
         self.primitive.def_abstract_eval(self._abstract_eval)

         # cpu kernel
-        if cpu_kernel_generator is not None:
-            self.def_cpu_kernel(cpu_kernel_generator)
-        if gpu_kernel_generator is not None:
-            self.def_gpu_kernel(gpu_kernel_generator)
+        if cpu_kernel_or_generator is not None:
+            self.def_cpu_kernel(cpu_kernel_or_generator)
+        if gpu_kernel_or_generator is not None:
+            self.def_gpu_kernel(gpu_kernel_or_generator)

         # batching rule
         if batching_translation is not None:
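Call sites pass kernels through the renamed keywords, as the `event_ell_mv`, `ell_mv`, and `event_linear` primitives above now do. A minimal sketch, assuming `XLACustomOp` is imported from its defining module, with placeholder kernel generators standing in for real ones:

    from brainstate.event._xla_custom_op import XLACustomOp

    def my_cpu_kernel_generator(**kwargs):
        ...  # placeholder: build and return the CPU kernel for this op

    def my_gpu_kernel_generator(**kwargs):
        ...  # placeholder: build and return the GPU kernel for this op

    my_op = XLACustomOp(
        'my_custom_op',
        cpu_kernel_or_generator=my_cpu_kernel_generator,
        gpu_kernel_or_generator=my_gpu_kernel_generator,  # optional; defaults to None
    )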
brainstate/graph/_graph_node.py
CHANGED
@@ -61,6 +61,9 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
     - Deepcopy the node.

     """
+
+    graph_invisible_attrs = ()
+
     if TYPE_CHECKING:
         _trace_state: StateJaxTracer

@@ -170,7 +173,13 @@ def _to_shape_dtype(value):
 def _node_flatten(
     node: Node
 ) -> Tuple[Tuple[Tuple[str, Any], ...], Tuple[Type]]:
-    nodes = sorted((key, value) for key, value in vars(node).items() if key != '_trace_state')
+    # graph_invisible_attrs = getattr(node, 'graph_invisible_attrs', ())
+    # graph_invisible_attrs = tuple(graph_invisible_attrs) + ('_trace_state',)
+    graph_invisible_attrs = ('_trace_state',)
+    nodes = sorted(
+        (key, value) for key, value in vars(node).items()
+        if (key not in graph_invisible_attrs)
+    )
     return nodes, (type(node),)


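A standalone sketch of what the new flatten rule does; `Toy` is a stand-in class, and only `_trace_state` is filtered in this version (the commented lines hint that the per-class `graph_invisible_attrs` tuple may later feed into this filter):

    # Mirrors _node_flatten: sort a node's attributes, dropping the invisible ones.
    def node_flatten_sketch(node, invisible=('_trace_state',)):
        nodes = sorted(
            (key, value) for key, value in vars(node).items()
            if key not in invisible
        )
        return nodes, (type(node),)

    class Toy:
        def __init__(self):
            self._trace_state = object()  # hidden from graph flattening
            self.weight = 1.0

    leaves, metadata = node_flatten_sketch(Toy())
    assert leaves == [('weight', 1.0)]  # '_trace_state' was filtered out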
brainstate/graph/_graph_operation.py
CHANGED
@@ -608,9 +608,9 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
             if isinstance(value, TreefyState):
                 variable.update_from_ref(value)
             elif isinstance(value, State):
-
+                if value._been_writen:
                     variable.write_value(value.value)
-
+                else:
                     variable.restore_value(value.value)
             else:
                 raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
@@ -1600,10 +1600,12 @@ def iter_leaf(
             visited_.add(id(node_))
             node_dict = _get_node_impl(node_).node_dict(node_)
             for key, value in node_dict.items():
-                yield from _iter_graph_leaf(value,
-                                            visited_,
-                                            (*path_parts_, key),
-                                            level_ + 1 if _is_graph_node(value) else level_)
+                yield from _iter_graph_leaf(
+                    value,
+                    visited_,
+                    (*path_parts_, key),
+                    level_ + 1 if _is_graph_node(value) else level_
+                )
         else:
             if level_ >= allowed_hierarchy[0]:
                 yield path_parts_, node_