brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. brainstate/__init__.py +1 -2
  2. brainstate/_state.py +77 -44
  3. brainstate/_state_test.py +0 -17
  4. brainstate/augment/__init__.py +10 -20
  5. brainstate/augment/_eval_shape.py +9 -10
  6. brainstate/augment/_eval_shape_test.py +1 -1
  7. brainstate/augment/_mapping.py +265 -277
  8. brainstate/augment/_mapping_test.py +147 -175
  9. brainstate/compile/__init__.py +18 -37
  10. brainstate/compile/_ad_checkpoint.py +6 -4
  11. brainstate/compile/_jit.py +37 -28
  12. brainstate/compile/_loop_collect_return.py +6 -3
  13. brainstate/compile/_loop_no_collection.py +2 -0
  14. brainstate/compile/_make_jaxpr.py +15 -4
  15. brainstate/compile/_make_jaxpr_test.py +10 -6
  16. brainstate/compile/_progress_bar.py +68 -40
  17. brainstate/compile/_unvmap.py +9 -6
  18. brainstate/graph/__init__.py +12 -16
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  24. brainstate/nn/_interaction/_conv.py +4 -2
  25. brainstate/nn/_interaction/_linear.py +84 -10
  26. brainstate/random/_rand_funs.py +9 -2
  27. brainstate/random/_rand_seed.py +12 -2
  28. brainstate/random/_rand_state.py +50 -179
  29. brainstate/surrogate.py +5 -1
  30. brainstate/util/__init__.py +0 -4
  31. brainstate/util/_caller.py +1 -1
  32. brainstate/util/_dict.py +4 -1
  33. brainstate/util/_filter.py +1 -1
  34. brainstate/util/_pretty_repr.py +1 -1
  35. brainstate/util/_struct.py +1 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
  37. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
  38. brainstate/event/__init__.py +0 -29
  39. brainstate/event/_csr.py +0 -906
  40. brainstate/event/_csr_mv.py +0 -303
  41. brainstate/event/_csr_mv_benchmark.py +0 -14
  42. brainstate/event/_csr_mv_test.py +0 -118
  43. brainstate/event/_csr_test.py +0 -90
  44. brainstate/event/_fixedprob_mv.py +0 -730
  45. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  46. brainstate/event/_fixedprob_mv_test.py +0 -132
  47. brainstate/event/_linear_mv.py +0 -359
  48. brainstate/event/_linear_mv_benckmark.py +0 -82
  49. brainstate/event/_linear_mv_test.py +0 -117
  50. brainstate/event/_misc.py +0 -34
  51. brainstate/event/_xla_custom_op.py +0 -313
  52. brainstate/event/_xla_custom_op_test.py +0 -55
  53. brainstate/graph/_graph_context.py +0 -443
  54. brainstate/graph/_graph_context_test.py +0 -65
  55. brainstate/graph/_graph_convert.py +0 -246
  56. brainstate/util/_tracers.py +0 -68
  57. brainstate/util/_visualization.py +0 -47
  58. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
  59. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
  60. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
brainstate/compile/_make_jaxpr_test.py

@@ -18,12 +18,16 @@ from __future__ import annotations
 import unittest
 
 import jax
-import jax.extend as je
 import jax.numpy as jnp
 import pytest
 
 import brainstate as bst
 
+if jax.__version_info__ < (0, 4, 38):
+    from jax.core import jaxpr_as_fun
+else:
+    from jax.extend.core import jaxpr_as_fun
+
 
 class TestMakeJaxpr(unittest.TestCase):
     def test_compar_jax_make_jaxpr(self):
@@ -85,7 +89,7 @@ class TestMakeJaxpr(unittest.TestCase):
         print(jaxpr)
         jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
         print(jaxpr)
-        self.assertTrue(jnp.allclose(je.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
+        self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
                                      f3(jnp.zeros(1))))
 
     def test_compar_jax_make_jaxpr2(self):
@@ -103,10 +107,10 @@ class TestMakeJaxpr(unittest.TestCase):
         print()
         print(jaxpr)
         print(states)
-        print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
+        print(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
         jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
         print(jaxpr)
-        print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
+        print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
 
     def test_compar_jax_make_jaxpr3(self):
         def fa(x):
@@ -116,10 +120,10 @@ class TestMakeJaxpr(unittest.TestCase):
         print()
         print(jaxpr)
         print(states)
-        # print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
+        # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
         jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
         print(jaxpr)
-        # print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
+        # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
 
 
 def test_return_states():
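Note on the change above: `jaxpr_as_fun` is now imported once at module top, gated on the jax version, because the diff resolves it from `jax.core` before jax 0.4.38 and from `jax.extend.core` afterwards. A minimal sketch of the same pattern outside the test suite (the helper `roundtrip_through_jaxpr` is illustrative, not part of brainstate):

import jax
import jax.numpy as jnp

# Version-gated import mirroring the hunk above.
if jax.__version_info__ < (0, 4, 38):
    from jax.core import jaxpr_as_fun
else:
    from jax.extend.core import jaxpr_as_fun

def roundtrip_through_jaxpr(fn, *args):
    # Trace `fn` into a ClosedJaxpr, then evaluate that jaxpr on the same arguments.
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    return jaxpr_as_fun(closed_jaxpr)(*args)

# Prints a list with one array equal to sin(0) + 1 over three elements.
print(roundtrip_through_jaxpr(lambda x: jnp.sin(x) + 1.0, jnp.zeros(3)))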
brainstate/compile/_progress_bar.py

@@ -16,34 +16,59 @@
 from __future__ import annotations
 
 import copy
-from typing import Optional
+import importlib.util
+from typing import Optional, Callable, Any, Tuple, Dict
 
 import jax
 
-try:
-    from tqdm.auto import tqdm
-except (ImportError, ModuleNotFoundError):
-    tqdm = None
+tqdm_installed = importlib.util.find_spec('tqdm') is not None
 
 __all__ = [
     'ProgressBar',
 ]
 
+Index = int
+Carray = Any
+Output = Any
+
 
 class ProgressBar(object):
+    """
+    A progress bar for tracking the progress of a jitted for-loop computation.
+    """
     __module__ = "brainstate.compile"
 
-    def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
+    def __init__(
+        self,
+        freq: Optional[int] = None,
+        count: Optional[int] = None,
+        desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
+        **kwargs
+    ):
+        # print rate
         self.print_freq = freq
         if isinstance(freq, int):
             assert freq > 0, "Print rate should be > 0."
+
+        # print count
         self.print_count = count
         if self.print_freq is not None and self.print_count is not None:
             raise ValueError("Cannot specify both count and freq.")
+
+        # other parameters
         for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
             kwargs.pop(kwarg, None)
         self.kwargs = kwargs
-        if tqdm is None:
+
+        # description
+        if desc is not None:
+            assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
+            assert isinstance(desc[0], str), 'Description should be a string.'
+            assert callable(desc[1]), 'Description should be a callable.'
+        self.desc = desc
+
+        # check if tqdm is installed
+        if not tqdm_installed:
             raise ImportError("tqdm is not installed.")
 
     def init(self, n: int):
@@ -67,15 +92,22 @@ class ProgressBar(object):
             raise ValueError("Print rate should be less than the "
                              f"number of steps {n}, got {freq}")
         remainder = n % freq
-        desc = kwargs.pop("desc", f"Running for {n:,} iterations")
-        message = kwargs.pop("message", desc)
-        return ProgressBarRunner(n, message, freq, remainder, **kwargs)
+
+        message = f"Running for {n:,} iterations" if self.desc is None else self.desc
+        return ProgressBarRunner(n, freq, remainder, message, **kwargs)
 
 
 class ProgressBarRunner(object):
     __module__ = "brainstate.compile"
 
-    def __init__(self, n: int, message, print_freq: int, remainder: int, **kwargs):
+    def __init__(
+        self,
+        n: int,
+        print_freq: int,
+        remainder: int,
+        message: str | Tuple[str, Callable[[Dict], Dict]],
+        **kwargs
+    ):
         self.tqdm_bars = {}
         self.kwargs = kwargs
         self.n = n
@@ -83,50 +115,46 @@ class ProgressBarRunner(object):
         self.remainder = remainder
         self.message = message
 
-    def _define_tqdm(self):
+    def _define_tqdm(self, x: dict):
+        from tqdm.auto import tqdm
         self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
-        self.tqdm_bars[0].set_description(self.message, refresh=False)
+        if isinstance(self.message, str):
+            self.tqdm_bars[0].set_description(self.message, refresh=False)
+        else:
+            self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
 
-    def _update_tqdm(self):
+    def _update_tqdm(self, x: dict):
         self.tqdm_bars[0].update(self.print_freq)
+        if not isinstance(self.message, str):
+            self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
 
-    def _close_tqdm(self):
+    def _close_tqdm(self, x: dict):
         if self.remainder > 0:
             self.tqdm_bars[0].update(self.remainder)
+        if not isinstance(self.message, str):
+            self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
         self.tqdm_bars[0].close()
 
-    def _tqdm(self, is_init, is_print, is_final):
-        if is_init:
-            self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
-            self.tqdm_bars[0].set_description(self.message, refresh=False)
-        if is_print:
-            self.tqdm_bars[0].update(self.print_freq)
-        if is_final:
-            if self.remainder > 0:
-                self.tqdm_bars[0].update(self.remainder)
-            self.tqdm_bars[0].close()
-
-    def __call__(self, iter_num, *args, **kwargs):
-        # jax.debug.callback(
-        #     self._tqdm,
-        #     iter_num == 0,
-        #     (iter_num + 1) % self.print_freq == 0,
-        #     iter_num == self.n - 1
-        # )
+    def __call__(self, iter_num, **kwargs):
+        data = dict(i=iter_num, **kwargs)
+        data = dict() if isinstance(self.message, str) else self.message[1](data)
+        assert isinstance(data, dict), 'Description function should return a dictionary.'
 
         _ = jax.lax.cond(
             iter_num == 0,
-            lambda: jax.debug.callback(self._define_tqdm, ordered=True),
-            lambda: None,
+            lambda x: jax.debug.callback(self._define_tqdm, x, ordered=True),
+            lambda x: None,
+            data
         )
         _ = jax.lax.cond(
             iter_num % self.print_freq == (self.print_freq - 1),
-            lambda: jax.debug.callback(self._update_tqdm, ordered=True),
-            lambda: None,
+            lambda x: jax.debug.callback(self._update_tqdm, x, ordered=True),
+            lambda x: None,
+            data
        )
         _ = jax.lax.cond(
             iter_num == self.n - 1,
-            lambda: jax.debug.callback(self._close_tqdm, ordered=True),
-            lambda: None,
+            lambda x: jax.debug.callback(self._close_tqdm, x, ordered=True),
+            lambda x: None,
+            data
         )
-
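The new `desc` argument of ProgressBar is a `(format_string, callable)` pair: on every refresh the callable receives a dict containing the loop index `i` (plus any keyword data forwarded to the runner) and must return the dict used to fill the format string. A hedged usage sketch, assuming tqdm is installed and that `brainstate.compile.for_loop` forwards a `pbar` argument as elsewhere in this release:

import brainstate as bst
import jax.numpy as jnp

def body(x):
    # toy per-step computation
    return jnp.sin(x)

# Static description, equivalent to the pre-change behaviour.
pbar_static = bst.compile.ProgressBar(freq=10)

# Dynamic description: the callable maps {'i': iter_num, ...} to the fields
# referenced by the format string.
pbar_dynamic = bst.compile.ProgressBar(
    freq=10,
    desc=("processing step {i}", lambda data: {"i": data["i"]}),
)

outs = bst.compile.for_loop(body, jnp.arange(1000), pbar=pbar_dynamic)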
brainstate/compile/_unvmap.py

@@ -16,13 +16,16 @@ from __future__ import annotations
 
 import jax
 import jax.core
-import jax.extend as je
 import jax.interpreters.batching as batching
 import jax.interpreters.mlir as mlir
 import jax.numpy as jnp
-
 from brainstate._utils import set_module_as
 
+if jax.__version_info__ < (0, 4, 38):
+    from jax.core import Primitive
+else:
+    from jax.extend.core import Primitive
+
 __all__ = [
     "unvmap",
 ]
@@ -44,7 +47,7 @@ def unvmap(x, op: str = 'any'):
 
 # unvmap_all
 
-unvmap_all_p = je.core.Primitive("unvmap_all")
+unvmap_all_p = Primitive("unvmap_all")
 
 
 def unvmap_all(x):
@@ -75,7 +78,7 @@ mlir.register_lowering(
 
 # unvmap_any
 
-unvmap_any_p = jax.core.Primitive("unvmap_any")
+unvmap_any_p = Primitive("unvmap_any")
 
 
 def unvmap_any(x):
@@ -106,7 +109,7 @@ mlir.register_lowering(
 
 # unvmap_max
 
-unvmap_max_p = jax.core.Primitive("unvmap_max")
+unvmap_max_p = Primitive("unvmap_max")
 
 
 def unvmap_max(x):
@@ -153,7 +156,7 @@ def _without_vmap_batch(x, batch_axes):
     return _without_vmap(x), batching.not_mapped
 
 
-_no_vmap_prim = jax.core.Primitive('no_vmap')
+_no_vmap_prim = Primitive('no_vmap')
 _no_vmap_prim.def_impl(_without_vmap_imp)
 _no_vmap_prim.def_abstract_eval(_without_vmap_abs)
 batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
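All four primitives are now built from a single `Primitive` name resolved by the same version gate as above. A minimal sketch of registering a toy primitive with that import (the `toy_identity` primitive is illustrative only and not part of brainstate):

import jax
import jax.numpy as jnp
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir

# Primitive lives in jax.core before 0.4.38 and in jax.extend.core afterwards.
if jax.__version_info__ < (0, 4, 38):
    from jax.core import Primitive
else:
    from jax.extend.core import Primitive

# Identity primitive that passes its operand (and batch axis) straight through,
# mirroring the registration pattern of the unvmap_* primitives in this module.
toy_identity_p = Primitive("toy_identity")
toy_identity_p.def_impl(lambda x: x)
toy_identity_p.def_abstract_eval(lambda x: x)
batching.primitive_batchers[toy_identity_p] = lambda args, axes: (args[0], axes[0])
mlir.register_lowering(toy_identity_p, mlir.lower_fun(lambda x: x, multiple_results=False))

print(jax.jit(lambda x: toy_identity_p.bind(x))(jnp.arange(3)))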
brainstate/graph/__init__.py

@@ -14,20 +14,16 @@
 # ==============================================================================
 
 
-from ._graph_context import *
-from ._graph_context import __all__ as _graph_context__all__
-from ._graph_convert import *
-from ._graph_convert import __all__ as _graph_convert__all__
-from ._graph_node import *
-from ._graph_node import __all__ as _graph_node__all__
-from ._graph_operation import *
-from ._graph_operation import __all__ as _graph_operation__all__
+from ._graph_node import Node, Dict, List, Sequential
+from ._graph_operation import (
+    pop_states, nodes, states, treefy_states, update_states, flatten, unflatten,
+    treefy_split, treefy_merge, iter_leaf, iter_node, clone, graphdef,
+    call, RefMap, GraphDef, NodeRef, NodeDef
+)
 
-__all__ = (_graph_context__all__ +
-           _graph_convert__all__ +
-           _graph_node__all__ +
-           _graph_operation__all__)
-del (_graph_context__all__,
-     _graph_convert__all__,
-     _graph_node__all__,
-     _graph_operation__all__)
+__all__ = [
+    'Node', 'Dict', 'List', 'Sequential',
+    'pop_states', 'nodes', 'states', 'treefy_states', 'update_states', 'flatten', 'unflatten',
+    'treefy_split', 'treefy_merge', 'iter_leaf', 'iter_node', 'clone', 'graphdef',
+    'call', 'RefMap', 'GraphDef', 'NodeRef', 'NodeDef',
+]
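`brainstate.graph` now re-exports an explicit name list instead of star imports, so the public surface is exactly the `__all__` shown above. A small sketch of a split/merge round trip with two of those names, assuming the pair keeps its flax.nnx-style semantics of separating static structure from state:

import brainstate as bst

class Model(bst.graph.Node):
    def __init__(self):
        self.w = bst.ParamState(1.0)

m = Model()
graphdef, state = bst.graph.treefy_split(m)   # static structure + state pytree
m2 = bst.graph.treefy_merge(graphdef, state)  # rebuild an equivalent node
assert isinstance(m2, Model)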
brainstate/graph/_graph_node.py

@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,9 +27,7 @@ import numpy as np
 
 from brainstate._state import State, TreefyState
 from brainstate.typing import Key
-from brainstate.util._error import TraceContextError
 from brainstate.util._pretty_repr import PrettyRepr, pretty_repr_avoid_duplicate, PrettyType, PrettyAttr
-from brainstate.util._tracers import StateJaxTracer
 from ._graph_operation import register_graph_node_type
 
 __all__ = [
@@ -44,7 +42,6 @@ class GraphNodeMeta(ABCMeta):
     if not TYPE_CHECKING:
         def __call__(cls, *args: Any, **kwargs: Any) -> Any:
             node = cls.__new__(cls, *args, **kwargs)
-            vars(node)['_trace_state'] = StateJaxTracer()
             node.__init__(*args, **kwargs)
             return node
 
@@ -64,9 +61,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
 
     graph_invisible_attrs = ()
 
-    if TYPE_CHECKING:
-        _trace_state: StateJaxTracer
-
 
     def __init_subclass__(cls) -> None:
         super().__init_subclass__()
@@ -79,21 +73,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
             clear=_node_clear,
         )
 
-    # if not TYPE_CHECKING:
-    #     def __setattr__(self, name: str, value: Any) -> None:
-    #         self._setattr(name, value)
-
-    # def _setattr(self, name: str, value: Any) -> None:
-    #     self.check_valid_context(lambda: f"Cannot mutate '{type(self).__name__}' from different trace level")
-    #     object.__setattr__(self, name, value)
-
-    def check_valid_context(self, error_msg: Callable[[], str]) -> None:
-        """
-        Check if the current context is valid for the object to be mutated.
-        """
-        if not self._trace_state.is_valid():
-            raise TraceContextError(error_msg())
-
 
     def __deepcopy__(self: G, memo=None) -> G:
         """
@@ -214,7 +193,6 @@ def _node_create_empty(
 ) -> G:
     node_type, = static
     node = object.__new__(node_type)
-    vars(node).update(_trace_state=StateJaxTracer())
     return node
 
 
brainstate/graph/_graph_operation.py

@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
brainstate/graph/_graph_operation_test.py

@@ -17,13 +17,10 @@ from __future__ import annotations
 
 import unittest
 from collections.abc import Callable
-from functools import partial
 from threading import Thread
-from typing import Any
 
 import jax
 import jax.numpy as jnp
-import pytest
 from absl.testing import absltest, parameterized
 
 import brainstate as bst
@@ -354,125 +351,6 @@ class TestGraphUtils(absltest.TestCase):
         assert m2.tree.a is not m.tree.a
         assert m2.tree is not m.tree
 
-    @pytest.mark.skip(reason='Not implemented')
-    def test_cached_unflatten(self):
-        class Foo(bst.graph.Node):
-            def __init__(self, ):
-                self.a = bst.nn.Linear(2, 2)
-                self.b = bst.nn.BatchNorm1d([10, 2])
-
-        def f(m: Foo):
-            m.a, m.b = m.b, m.a  # type: ignore
-
-        m = Foo()
-        a = m.a
-        b = m.b
-
-        ref_out_idx_out = bst.graph.RefMap()
-        graphdef: bst.graph.GraphDef[Foo]
-        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
-
-        @partial(jax.jit, static_argnums=(0,))
-        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
-            idx_out_ref_in: dict[int, Any] = {}
-            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
-            f(m)
-            ref_in_idx_in = bst.graph.RefMap[Any, int]()
-            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
-            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
-            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
-            return state, static_out
-
-        static_out: bst.graph.Static
-        state, static_out = f_pure(graphdef, state)
-        idx_out_idx_in: dict[int, int]
-        graphdef, idx_out_idx_in = static_out.value
-        idx_in_ref_out = bst.graph.compose_mapping_reversed(
-            ref_out_idx_out, idx_out_idx_in
-        )
-        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
-        assert m2 is m
-        assert m2.a is b
-        assert m2.b is a
-
-    @pytest.mark.skip(reason='Not implemented')
-    def test_cached_unflatten_swap_variables(self):
-        class Foo(bst.graph.Node):
-            def __init__(self):
-                self.a = bst.ParamState(1)
-                self.b = bst.ParamState(2)
-
-        def f(m: Foo):
-            m.a, m.b = m.b, m.a
-
-        m = Foo()
-        a = m.a
-        b = m.b
-
-        ref_out_idx_out = bst.graph.RefMap[Any, int]()
-        graphdef: bst.graph.GraphDef[Foo]
-        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
-
-        @partial(jax.jit, static_argnums=(0,))
-        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
-            idx_out_ref_in: dict[int, Any] = {}
-            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
-            f(m)
-            ref_in_idx_in = bst.graph.RefMap[Any, int]()
-            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
-            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
-            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
-            return state, static_out
-
-        static_out: bst.graph.Static
-        state, static_out = f_pure(graphdef, state)
-        idx_out_idx_in: dict[int, int]
-        graphdef, idx_out_idx_in = static_out.value
-        idx_in_ref_out = bst.graph.compose_mapping_reversed(
-            ref_out_idx_out, idx_out_idx_in
-        )
-        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
-        assert m2 is m
-        assert m2.a is b
-        assert m2.b is a
-
-    @pytest.mark.skip(reason='Not implemented')
-    def test_cached_unflatten_add_self_reference(self):
-        class Foo(bst.graph.Node):
-            def __init__(self):
-                self.ref = None
-
-        def f(m: Foo):
-            m.ref = m
-
-        m = Foo()
-
-        ref_out_idx_out = bst.graph.RefMap()
-        graphdef: bst.graph.GraphDef[Foo]
-        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
-
-        @partial(jax.jit, static_argnums=(0,))
-        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
-            idx_out_ref_in: dict[int, Any] = {}
-            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
-            f(m)
-            ref_in_idx_in = bst.graph.RefMap[Any, int]()
-            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
-            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
-            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
-            return state, static_out
-
-        static_out: bst.graph.Static
-        state, static_out = f_pure(graphdef, state)
-        idx_out_idx_in: dict[int, int]
-        graphdef, idx_out_idx_in = static_out.value
-        idx_in_ref_out = bst.graph.compose_mapping_reversed(
-            ref_out_idx_out, idx_out_idx_in
-        )
-        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
-        assert m2 is m
-        assert m2.ref is m2
-
 
     def test_call_jit_update(self):
         class Counter(bst.graph.Node):
@@ -527,43 +405,6 @@
         self.assertEqual(nodes['a'].count.value, 0)
         self.assertEqual(nodes['b'].count.value, 1)
 
-    def test_to_tree_simple(self):
-        m = bst.nn.Linear(2, 3, )
-        impure_tree = (m, 1, {'b': m})
-
-        pure_tree = bst.graph.graph_to_tree(impure_tree)
-
-        t1 = pure_tree[0]
-        t2 = pure_tree[2]['b']
-
-        self.assertEqual(pure_tree[1], 1)
-        self.assertIsInstance(t1, bst.graph.NodeStates)
-        assert isinstance(t1, bst.graph.NodeStates)
-        self.assertIsInstance(t2, bst.graph.NodeStates)
-        assert isinstance(t2, bst.graph.NodeStates)
-        self.assertIsInstance(t1.graphdef, bst.graph.NodeDef)
-        self.assertIsInstance(t2.graphdef, bst.graph.NodeRef)
-        self.assertLen(t1.states[0].to_flat(), 1)
-        self.assertLen(t2.states[0].to_flat(), 0)
-
-        impure_tree2 = bst.graph.tree_to_graph(pure_tree)
-
-        m1_out = impure_tree2[0]
-        m2_out = impure_tree2[2]['b']
-
-        self.assertIs(m1_out, m2_out)
-        self.assertEqual(impure_tree2[1], 1)
-
-    def test_to_tree_consistent_prefix(self):
-        m = bst.nn.Linear(2, 3, )
-        impure_tree = (m, 1, {'b': m})
-        prefix = (0, None, 0)
-        pure_tree = bst.graph.graph_to_tree(impure_tree, prefix=prefix)
-
-        prefix = (0, None, 1)
-        with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
-            bst.graph.graph_to_tree(impure_tree, prefix=prefix)
-
 
 class SimpleModule(bst.nn.Module):
     pass