brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -2
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/__init__.py +10 -20
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/__init__.py +18 -37
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +6 -3
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +15 -4
- brainstate/compile/_make_jaxpr_test.py +10 -6
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +9 -6
- brainstate/graph/__init__.py +12 -16
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_elementwise/_dropout_test.py +1 -1
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
- brainstate/event/__init__.py +0 -29
- brainstate/event/_csr.py +0 -906
- brainstate/event/_csr_mv.py +0 -303
- brainstate/event/_csr_mv_benchmark.py +0 -14
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/event/_csr_test.py +0 -90
- brainstate/event/_fixedprob_mv.py +0 -730
- brainstate/event/_fixedprob_mv_benchmark.py +0 -128
- brainstate/event/_fixedprob_mv_test.py +0 -132
- brainstate/event/_linear_mv.py +0 -359
- brainstate/event/_linear_mv_benckmark.py +0 -82
- brainstate/event/_linear_mv_test.py +0 -117
- brainstate/event/_misc.py +0 -34
- brainstate/event/_xla_custom_op.py +0 -313
- brainstate/event/_xla_custom_op_test.py +0 -55
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
@@ -18,12 +18,16 @@ from __future__ import annotations
|
|
18
18
|
import unittest
|
19
19
|
|
20
20
|
import jax
|
21
|
-
import jax.extend as je
|
22
21
|
import jax.numpy as jnp
|
23
22
|
import pytest
|
24
23
|
|
25
24
|
import brainstate as bst
|
26
25
|
|
26
|
+
if jax.__version_info__ < (0, 4, 38):
|
27
|
+
from jax.core import jaxpr_as_fun
|
28
|
+
else:
|
29
|
+
from jax.extend.core import jaxpr_as_fun
|
30
|
+
|
27
31
|
|
28
32
|
class TestMakeJaxpr(unittest.TestCase):
|
29
33
|
def test_compar_jax_make_jaxpr(self):
|
@@ -85,7 +89,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
85
89
|
print(jaxpr)
|
86
90
|
jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
|
87
91
|
print(jaxpr)
|
88
|
-
self.assertTrue(jnp.allclose(
|
92
|
+
self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
89
93
|
f3(jnp.zeros(1))))
|
90
94
|
|
91
95
|
def test_compar_jax_make_jaxpr2(self):
|
@@ -103,10 +107,10 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
103
107
|
print()
|
104
108
|
print(jaxpr)
|
105
109
|
print(states)
|
106
|
-
print(
|
110
|
+
print(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
|
107
111
|
jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
|
108
112
|
print(jaxpr)
|
109
|
-
print(
|
113
|
+
print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
110
114
|
|
111
115
|
def test_compar_jax_make_jaxpr3(self):
|
112
116
|
def fa(x):
|
@@ -116,10 +120,10 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
116
120
|
print()
|
117
121
|
print(jaxpr)
|
118
122
|
print(states)
|
119
|
-
# print(
|
123
|
+
# print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
120
124
|
jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
|
121
125
|
print(jaxpr)
|
122
|
-
# print(
|
126
|
+
# print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
123
127
|
|
124
128
|
|
125
129
|
def test_return_states():
|
@@ -16,34 +16,59 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import copy
|
19
|
-
|
19
|
+
import importlib.util
|
20
|
+
from typing import Optional, Callable, Any, Tuple, Dict
|
20
21
|
|
21
22
|
import jax
|
22
23
|
|
23
|
-
|
24
|
-
from tqdm.auto import tqdm
|
25
|
-
except (ImportError, ModuleNotFoundError):
|
26
|
-
tqdm = None
|
24
|
+
tqdm_installed = importlib.util.find_spec('tqdm') is not None
|
27
25
|
|
28
26
|
__all__ = [
|
29
27
|
'ProgressBar',
|
30
28
|
]
|
31
29
|
|
30
|
+
Index = int
|
31
|
+
Carray = Any
|
32
|
+
Output = Any
|
33
|
+
|
32
34
|
|
33
35
|
class ProgressBar(object):
|
36
|
+
"""
|
37
|
+
A progress bar for tracking the progress of a jitted for-loop computation.
|
38
|
+
"""
|
34
39
|
__module__ = "brainstate.compile"
|
35
40
|
|
36
|
-
def __init__(
|
41
|
+
def __init__(
|
42
|
+
self,
|
43
|
+
freq: Optional[int] = None,
|
44
|
+
count: Optional[int] = None,
|
45
|
+
desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
|
46
|
+
**kwargs
|
47
|
+
):
|
48
|
+
# print rate
|
37
49
|
self.print_freq = freq
|
38
50
|
if isinstance(freq, int):
|
39
51
|
assert freq > 0, "Print rate should be > 0."
|
52
|
+
|
53
|
+
# print count
|
40
54
|
self.print_count = count
|
41
55
|
if self.print_freq is not None and self.print_count is not None:
|
42
56
|
raise ValueError("Cannot specify both count and freq.")
|
57
|
+
|
58
|
+
# other parameters
|
43
59
|
for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
|
44
60
|
kwargs.pop(kwarg, None)
|
45
61
|
self.kwargs = kwargs
|
46
|
-
|
62
|
+
|
63
|
+
# description
|
64
|
+
if desc is not None:
|
65
|
+
assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
|
66
|
+
assert isinstance(desc[0], str), 'Description should be a string.'
|
67
|
+
assert callable(desc[1]), 'Description should be a callable.'
|
68
|
+
self.desc = desc
|
69
|
+
|
70
|
+
# check if tqdm is installed
|
71
|
+
if not tqdm_installed:
|
47
72
|
raise ImportError("tqdm is not installed.")
|
48
73
|
|
49
74
|
def init(self, n: int):
|
@@ -67,15 +92,22 @@ class ProgressBar(object):
|
|
67
92
|
raise ValueError("Print rate should be less than the "
|
68
93
|
f"number of steps {n}, got {freq}")
|
69
94
|
remainder = n % freq
|
70
|
-
|
71
|
-
message =
|
72
|
-
return ProgressBarRunner(n,
|
95
|
+
|
96
|
+
message = f"Running for {n:,} iterations" if self.desc is None else self.desc
|
97
|
+
return ProgressBarRunner(n, freq, remainder, message, **kwargs)
|
73
98
|
|
74
99
|
|
75
100
|
class ProgressBarRunner(object):
|
76
101
|
__module__ = "brainstate.compile"
|
77
102
|
|
78
|
-
def __init__(
|
103
|
+
def __init__(
|
104
|
+
self,
|
105
|
+
n: int,
|
106
|
+
print_freq: int,
|
107
|
+
remainder: int,
|
108
|
+
message: str | Tuple[str, Callable[[Dict], Dict]],
|
109
|
+
**kwargs
|
110
|
+
):
|
79
111
|
self.tqdm_bars = {}
|
80
112
|
self.kwargs = kwargs
|
81
113
|
self.n = n
|
@@ -83,50 +115,46 @@ class ProgressBarRunner(object):
|
|
83
115
|
self.remainder = remainder
|
84
116
|
self.message = message
|
85
117
|
|
86
|
-
def _define_tqdm(self):
|
118
|
+
def _define_tqdm(self, x: dict):
|
119
|
+
from tqdm.auto import tqdm
|
87
120
|
self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
|
88
|
-
|
121
|
+
if isinstance(self.message, str):
|
122
|
+
self.tqdm_bars[0].set_description(self.message, refresh=False)
|
123
|
+
else:
|
124
|
+
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
89
125
|
|
90
|
-
def _update_tqdm(self):
|
126
|
+
def _update_tqdm(self, x: dict):
|
91
127
|
self.tqdm_bars[0].update(self.print_freq)
|
128
|
+
if not isinstance(self.message, str):
|
129
|
+
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
92
130
|
|
93
|
-
def _close_tqdm(self):
|
131
|
+
def _close_tqdm(self, x: dict):
|
94
132
|
if self.remainder > 0:
|
95
133
|
self.tqdm_bars[0].update(self.remainder)
|
134
|
+
if not isinstance(self.message, str):
|
135
|
+
self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
|
96
136
|
self.tqdm_bars[0].close()
|
97
137
|
|
98
|
-
def
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
if is_print:
|
103
|
-
self.tqdm_bars[0].update(self.print_freq)
|
104
|
-
if is_final:
|
105
|
-
if self.remainder > 0:
|
106
|
-
self.tqdm_bars[0].update(self.remainder)
|
107
|
-
self.tqdm_bars[0].close()
|
108
|
-
|
109
|
-
def __call__(self, iter_num, *args, **kwargs):
|
110
|
-
# jax.debug.callback(
|
111
|
-
# self._tqdm,
|
112
|
-
# iter_num == 0,
|
113
|
-
# (iter_num + 1) % self.print_freq == 0,
|
114
|
-
# iter_num == self.n - 1
|
115
|
-
# )
|
138
|
+
def __call__(self, iter_num, **kwargs):
|
139
|
+
data = dict(i=iter_num, **kwargs)
|
140
|
+
data = dict() if isinstance(self.message, str) else self.message[1](data)
|
141
|
+
assert isinstance(data, dict), 'Description function should return a dictionary.'
|
116
142
|
|
117
143
|
_ = jax.lax.cond(
|
118
144
|
iter_num == 0,
|
119
|
-
lambda: jax.debug.callback(self._define_tqdm, ordered=True),
|
120
|
-
lambda: None,
|
145
|
+
lambda x: jax.debug.callback(self._define_tqdm, x, ordered=True),
|
146
|
+
lambda x: None,
|
147
|
+
data
|
121
148
|
)
|
122
149
|
_ = jax.lax.cond(
|
123
150
|
iter_num % self.print_freq == (self.print_freq - 1),
|
124
|
-
lambda: jax.debug.callback(self._update_tqdm, ordered=True),
|
125
|
-
lambda: None,
|
151
|
+
lambda x: jax.debug.callback(self._update_tqdm, x, ordered=True),
|
152
|
+
lambda x: None,
|
153
|
+
data
|
126
154
|
)
|
127
155
|
_ = jax.lax.cond(
|
128
156
|
iter_num == self.n - 1,
|
129
|
-
lambda: jax.debug.callback(self._close_tqdm, ordered=True),
|
130
|
-
lambda: None,
|
157
|
+
lambda x: jax.debug.callback(self._close_tqdm, x, ordered=True),
|
158
|
+
lambda x: None,
|
159
|
+
data
|
131
160
|
)
|
132
|
-
|
brainstate/compile/_unvmap.py
CHANGED
@@ -16,13 +16,16 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
import jax
|
18
18
|
import jax.core
|
19
|
-
import jax.extend as je
|
20
19
|
import jax.interpreters.batching as batching
|
21
20
|
import jax.interpreters.mlir as mlir
|
22
21
|
import jax.numpy as jnp
|
23
|
-
|
24
22
|
from brainstate._utils import set_module_as
|
25
23
|
|
24
|
+
if jax.__version_info__ < (0, 4, 38):
|
25
|
+
from jax.core import Primitive
|
26
|
+
else:
|
27
|
+
from jax.extend.core import Primitive
|
28
|
+
|
26
29
|
__all__ = [
|
27
30
|
"unvmap",
|
28
31
|
]
|
@@ -44,7 +47,7 @@ def unvmap(x, op: str = 'any'):
|
|
44
47
|
|
45
48
|
# unvmap_all
|
46
49
|
|
47
|
-
unvmap_all_p =
|
50
|
+
unvmap_all_p = Primitive("unvmap_all")
|
48
51
|
|
49
52
|
|
50
53
|
def unvmap_all(x):
|
@@ -75,7 +78,7 @@ mlir.register_lowering(
|
|
75
78
|
|
76
79
|
# unvmap_any
|
77
80
|
|
78
|
-
unvmap_any_p =
|
81
|
+
unvmap_any_p = Primitive("unvmap_any")
|
79
82
|
|
80
83
|
|
81
84
|
def unvmap_any(x):
|
@@ -106,7 +109,7 @@ mlir.register_lowering(
|
|
106
109
|
|
107
110
|
# unvmap_max
|
108
111
|
|
109
|
-
unvmap_max_p =
|
112
|
+
unvmap_max_p = Primitive("unvmap_max")
|
110
113
|
|
111
114
|
|
112
115
|
def unvmap_max(x):
|
@@ -153,7 +156,7 @@ def _without_vmap_batch(x, batch_axes):
|
|
153
156
|
return _without_vmap(x), batching.not_mapped
|
154
157
|
|
155
158
|
|
156
|
-
_no_vmap_prim =
|
159
|
+
_no_vmap_prim = Primitive('no_vmap')
|
157
160
|
_no_vmap_prim.def_impl(_without_vmap_imp)
|
158
161
|
_no_vmap_prim.def_abstract_eval(_without_vmap_abs)
|
159
162
|
batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
|
brainstate/graph/__init__.py
CHANGED
@@ -14,20 +14,16 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from .
|
18
|
-
from .
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
from ._graph_operation import *
|
24
|
-
from ._graph_operation import __all__ as _graph_operation__all__
|
17
|
+
from ._graph_node import Node, Dict, List, Sequential
|
18
|
+
from ._graph_operation import (
|
19
|
+
pop_states, nodes, states, treefy_states, update_states, flatten, unflatten,
|
20
|
+
treefy_split, treefy_merge, iter_leaf, iter_node, clone, graphdef,
|
21
|
+
call, RefMap, GraphDef, NodeRef, NodeDef
|
22
|
+
)
|
25
23
|
|
26
|
-
__all__ =
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
_graph_node__all__,
|
33
|
-
_graph_operation__all__)
|
24
|
+
__all__ = [
|
25
|
+
'Node', 'Dict', 'List', 'Sequential',
|
26
|
+
'pop_states', 'nodes', 'states', 'treefy_states', 'update_states', 'flatten', 'unflatten',
|
27
|
+
'treefy_split', 'treefy_merge', 'iter_leaf', 'iter_node', 'clone', 'graphdef',
|
28
|
+
'call', 'RefMap', 'GraphDef', 'NodeRef', 'NodeDef',
|
29
|
+
]
|
brainstate/graph/_graph_node.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
@@ -27,9 +27,7 @@ import numpy as np
|
|
27
27
|
|
28
28
|
from brainstate._state import State, TreefyState
|
29
29
|
from brainstate.typing import Key
|
30
|
-
from brainstate.util._error import TraceContextError
|
31
30
|
from brainstate.util._pretty_repr import PrettyRepr, pretty_repr_avoid_duplicate, PrettyType, PrettyAttr
|
32
|
-
from brainstate.util._tracers import StateJaxTracer
|
33
31
|
from ._graph_operation import register_graph_node_type
|
34
32
|
|
35
33
|
__all__ = [
|
@@ -44,7 +42,6 @@ class GraphNodeMeta(ABCMeta):
|
|
44
42
|
if not TYPE_CHECKING:
|
45
43
|
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
46
44
|
node = cls.__new__(cls, *args, **kwargs)
|
47
|
-
vars(node)['_trace_state'] = StateJaxTracer()
|
48
45
|
node.__init__(*args, **kwargs)
|
49
46
|
return node
|
50
47
|
|
@@ -64,9 +61,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
|
|
64
61
|
|
65
62
|
graph_invisible_attrs = ()
|
66
63
|
|
67
|
-
if TYPE_CHECKING:
|
68
|
-
_trace_state: StateJaxTracer
|
69
|
-
|
70
64
|
def __init_subclass__(cls) -> None:
|
71
65
|
super().__init_subclass__()
|
72
66
|
|
@@ -79,21 +73,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
|
|
79
73
|
clear=_node_clear,
|
80
74
|
)
|
81
75
|
|
82
|
-
# if not TYPE_CHECKING:
|
83
|
-
# def __setattr__(self, name: str, value: Any) -> None:
|
84
|
-
# self._setattr(name, value)
|
85
|
-
|
86
|
-
# def _setattr(self, name: str, value: Any) -> None:
|
87
|
-
# self.check_valid_context(lambda: f"Cannot mutate '{type(self).__name__}' from different trace level")
|
88
|
-
# object.__setattr__(self, name, value)
|
89
|
-
|
90
|
-
def check_valid_context(self, error_msg: Callable[[], str]) -> None:
|
91
|
-
"""
|
92
|
-
Check if the current context is valid for the object to be mutated.
|
93
|
-
"""
|
94
|
-
if not self._trace_state.is_valid():
|
95
|
-
raise TraceContextError(error_msg())
|
96
|
-
|
97
76
|
def __deepcopy__(self: G, memo=None) -> G:
|
98
77
|
"""
|
99
78
|
Deepcopy the object.
|
@@ -214,7 +193,6 @@ def _node_create_empty(
|
|
214
193
|
) -> G:
|
215
194
|
node_type, = static
|
216
195
|
node = object.__new__(node_type)
|
217
|
-
vars(node).update(_trace_state=StateJaxTracer())
|
218
196
|
return node
|
219
197
|
|
220
198
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
@@ -17,13 +17,10 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import unittest
|
19
19
|
from collections.abc import Callable
|
20
|
-
from functools import partial
|
21
20
|
from threading import Thread
|
22
|
-
from typing import Any
|
23
21
|
|
24
22
|
import jax
|
25
23
|
import jax.numpy as jnp
|
26
|
-
import pytest
|
27
24
|
from absl.testing import absltest, parameterized
|
28
25
|
|
29
26
|
import brainstate as bst
|
@@ -354,125 +351,6 @@ class TestGraphUtils(absltest.TestCase):
|
|
354
351
|
assert m2.tree.a is not m.tree.a
|
355
352
|
assert m2.tree is not m.tree
|
356
353
|
|
357
|
-
@pytest.mark.skip(reason='Not implemented')
|
358
|
-
def test_cached_unflatten(self):
|
359
|
-
class Foo(bst.graph.Node):
|
360
|
-
def __init__(self, ):
|
361
|
-
self.a = bst.nn.Linear(2, 2)
|
362
|
-
self.b = bst.nn.BatchNorm1d([10, 2])
|
363
|
-
|
364
|
-
def f(m: Foo):
|
365
|
-
m.a, m.b = m.b, m.a # type: ignore
|
366
|
-
|
367
|
-
m = Foo()
|
368
|
-
a = m.a
|
369
|
-
b = m.b
|
370
|
-
|
371
|
-
ref_out_idx_out = bst.graph.RefMap()
|
372
|
-
graphdef: bst.graph.GraphDef[Foo]
|
373
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
|
374
|
-
|
375
|
-
@partial(jax.jit, static_argnums=(0,))
|
376
|
-
def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
|
377
|
-
idx_out_ref_in: dict[int, Any] = {}
|
378
|
-
m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
|
379
|
-
f(m)
|
380
|
-
ref_in_idx_in = bst.graph.RefMap[Any, int]()
|
381
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
|
382
|
-
idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
|
383
|
-
static_out = bst.graph.Static((graphdef, idx_out_idx_in))
|
384
|
-
return state, static_out
|
385
|
-
|
386
|
-
static_out: bst.graph.Static
|
387
|
-
state, static_out = f_pure(graphdef, state)
|
388
|
-
idx_out_idx_in: dict[int, int]
|
389
|
-
graphdef, idx_out_idx_in = static_out.value
|
390
|
-
idx_in_ref_out = bst.graph.compose_mapping_reversed(
|
391
|
-
ref_out_idx_out, idx_out_idx_in
|
392
|
-
)
|
393
|
-
m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
|
394
|
-
assert m2 is m
|
395
|
-
assert m2.a is b
|
396
|
-
assert m2.b is a
|
397
|
-
|
398
|
-
@pytest.mark.skip(reason='Not implemented')
|
399
|
-
def test_cached_unflatten_swap_variables(self):
|
400
|
-
class Foo(bst.graph.Node):
|
401
|
-
def __init__(self):
|
402
|
-
self.a = bst.ParamState(1)
|
403
|
-
self.b = bst.ParamState(2)
|
404
|
-
|
405
|
-
def f(m: Foo):
|
406
|
-
m.a, m.b = m.b, m.a
|
407
|
-
|
408
|
-
m = Foo()
|
409
|
-
a = m.a
|
410
|
-
b = m.b
|
411
|
-
|
412
|
-
ref_out_idx_out = bst.graph.RefMap[Any, int]()
|
413
|
-
graphdef: bst.graph.GraphDef[Foo]
|
414
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
|
415
|
-
|
416
|
-
@partial(jax.jit, static_argnums=(0,))
|
417
|
-
def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
|
418
|
-
idx_out_ref_in: dict[int, Any] = {}
|
419
|
-
m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
|
420
|
-
f(m)
|
421
|
-
ref_in_idx_in = bst.graph.RefMap[Any, int]()
|
422
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
|
423
|
-
idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
|
424
|
-
static_out = bst.graph.Static((graphdef, idx_out_idx_in))
|
425
|
-
return state, static_out
|
426
|
-
|
427
|
-
static_out: bst.graph.Static
|
428
|
-
state, static_out = f_pure(graphdef, state)
|
429
|
-
idx_out_idx_in: dict[int, int]
|
430
|
-
graphdef, idx_out_idx_in = static_out.value
|
431
|
-
idx_in_ref_out = bst.graph.compose_mapping_reversed(
|
432
|
-
ref_out_idx_out, idx_out_idx_in
|
433
|
-
)
|
434
|
-
m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
|
435
|
-
assert m2 is m
|
436
|
-
assert m2.a is b
|
437
|
-
assert m2.b is a
|
438
|
-
|
439
|
-
@pytest.mark.skip(reason='Not implemented')
|
440
|
-
def test_cached_unflatten_add_self_reference(self):
|
441
|
-
class Foo(bst.graph.Node):
|
442
|
-
def __init__(self):
|
443
|
-
self.ref = None
|
444
|
-
|
445
|
-
def f(m: Foo):
|
446
|
-
m.ref = m
|
447
|
-
|
448
|
-
m = Foo()
|
449
|
-
|
450
|
-
ref_out_idx_out = bst.graph.RefMap()
|
451
|
-
graphdef: bst.graph.GraphDef[Foo]
|
452
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
|
453
|
-
|
454
|
-
@partial(jax.jit, static_argnums=(0,))
|
455
|
-
def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
|
456
|
-
idx_out_ref_in: dict[int, Any] = {}
|
457
|
-
m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
|
458
|
-
f(m)
|
459
|
-
ref_in_idx_in = bst.graph.RefMap[Any, int]()
|
460
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
|
461
|
-
idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
|
462
|
-
static_out = bst.graph.Static((graphdef, idx_out_idx_in))
|
463
|
-
return state, static_out
|
464
|
-
|
465
|
-
static_out: bst.graph.Static
|
466
|
-
state, static_out = f_pure(graphdef, state)
|
467
|
-
idx_out_idx_in: dict[int, int]
|
468
|
-
graphdef, idx_out_idx_in = static_out.value
|
469
|
-
idx_in_ref_out = bst.graph.compose_mapping_reversed(
|
470
|
-
ref_out_idx_out, idx_out_idx_in
|
471
|
-
)
|
472
|
-
m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
|
473
|
-
assert m2 is m
|
474
|
-
assert m2.ref is m2
|
475
|
-
|
476
354
|
def test_call_jit_update(self):
|
477
355
|
class Counter(bst.graph.Node):
|
478
356
|
def __init__(self):
|
@@ -527,43 +405,6 @@ class TestGraphUtils(absltest.TestCase):
|
|
527
405
|
self.assertEqual(nodes['a'].count.value, 0)
|
528
406
|
self.assertEqual(nodes['b'].count.value, 1)
|
529
407
|
|
530
|
-
def test_to_tree_simple(self):
|
531
|
-
m = bst.nn.Linear(2, 3, )
|
532
|
-
impure_tree = (m, 1, {'b': m})
|
533
|
-
|
534
|
-
pure_tree = bst.graph.graph_to_tree(impure_tree)
|
535
|
-
|
536
|
-
t1 = pure_tree[0]
|
537
|
-
t2 = pure_tree[2]['b']
|
538
|
-
|
539
|
-
self.assertEqual(pure_tree[1], 1)
|
540
|
-
self.assertIsInstance(t1, bst.graph.NodeStates)
|
541
|
-
assert isinstance(t1, bst.graph.NodeStates)
|
542
|
-
self.assertIsInstance(t2, bst.graph.NodeStates)
|
543
|
-
assert isinstance(t2, bst.graph.NodeStates)
|
544
|
-
self.assertIsInstance(t1.graphdef, bst.graph.NodeDef)
|
545
|
-
self.assertIsInstance(t2.graphdef, bst.graph.NodeRef)
|
546
|
-
self.assertLen(t1.states[0].to_flat(), 1)
|
547
|
-
self.assertLen(t2.states[0].to_flat(), 0)
|
548
|
-
|
549
|
-
impure_tree2 = bst.graph.tree_to_graph(pure_tree)
|
550
|
-
|
551
|
-
m1_out = impure_tree2[0]
|
552
|
-
m2_out = impure_tree2[2]['b']
|
553
|
-
|
554
|
-
self.assertIs(m1_out, m2_out)
|
555
|
-
self.assertEqual(impure_tree2[1], 1)
|
556
|
-
|
557
|
-
def test_to_tree_consistent_prefix(self):
|
558
|
-
m = bst.nn.Linear(2, 3, )
|
559
|
-
impure_tree = (m, 1, {'b': m})
|
560
|
-
prefix = (0, None, 0)
|
561
|
-
pure_tree = bst.graph.graph_to_tree(impure_tree, prefix=prefix)
|
562
|
-
|
563
|
-
prefix = (0, None, 1)
|
564
|
-
with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
|
565
|
-
bst.graph.graph_to_tree(impure_tree, prefix=prefix)
|
566
|
-
|
567
408
|
|
568
409
|
class SimpleModule(bst.nn.Module):
|
569
410
|
pass
|