brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. brainstate/__init__.py +1 -2
  2. brainstate/_state.py +77 -44
  3. brainstate/_state_test.py +0 -17
  4. brainstate/augment/__init__.py +10 -20
  5. brainstate/augment/_eval_shape.py +9 -10
  6. brainstate/augment/_eval_shape_test.py +1 -1
  7. brainstate/augment/_mapping.py +265 -277
  8. brainstate/augment/_mapping_test.py +147 -175
  9. brainstate/compile/__init__.py +18 -37
  10. brainstate/compile/_ad_checkpoint.py +6 -4
  11. brainstate/compile/_jit.py +37 -28
  12. brainstate/compile/_loop_collect_return.py +6 -3
  13. brainstate/compile/_loop_no_collection.py +2 -0
  14. brainstate/compile/_make_jaxpr.py +15 -4
  15. brainstate/compile/_make_jaxpr_test.py +10 -6
  16. brainstate/compile/_progress_bar.py +68 -40
  17. brainstate/compile/_unvmap.py +9 -6
  18. brainstate/graph/__init__.py +12 -16
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  24. brainstate/nn/_interaction/_conv.py +4 -2
  25. brainstate/nn/_interaction/_linear.py +84 -10
  26. brainstate/random/_rand_funs.py +9 -2
  27. brainstate/random/_rand_seed.py +12 -2
  28. brainstate/random/_rand_state.py +50 -179
  29. brainstate/surrogate.py +5 -1
  30. brainstate/util/__init__.py +0 -4
  31. brainstate/util/_caller.py +1 -1
  32. brainstate/util/_dict.py +4 -1
  33. brainstate/util/_filter.py +1 -1
  34. brainstate/util/_pretty_repr.py +1 -1
  35. brainstate/util/_struct.py +1 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
  37. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
  38. brainstate/event/__init__.py +0 -29
  39. brainstate/event/_csr.py +0 -906
  40. brainstate/event/_csr_mv.py +0 -303
  41. brainstate/event/_csr_mv_benchmark.py +0 -14
  42. brainstate/event/_csr_mv_test.py +0 -118
  43. brainstate/event/_csr_test.py +0 -90
  44. brainstate/event/_fixedprob_mv.py +0 -730
  45. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  46. brainstate/event/_fixedprob_mv_test.py +0 -132
  47. brainstate/event/_linear_mv.py +0 -359
  48. brainstate/event/_linear_mv_benckmark.py +0 -82
  49. brainstate/event/_linear_mv_test.py +0 -117
  50. brainstate/event/_misc.py +0 -34
  51. brainstate/event/_xla_custom_op.py +0 -313
  52. brainstate/event/_xla_custom_op_test.py +0 -55
  53. brainstate/graph/_graph_context.py +0 -443
  54. brainstate/graph/_graph_context_test.py +0 -65
  55. brainstate/graph/_graph_convert.py +0 -246
  56. brainstate/util/_tracers.py +0 -68
  57. brainstate/util/_visualization.py +0 -47
  58. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
  59. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
  60. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
@@ -1,443 +0,0 @@
1
- # The file is adapted from the Flax library (https://github.com/google/flax).
2
- # The credit should go to the Flax authors.
3
- #
4
- # Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
- # ==============================================================================
18
-
19
- from __future__ import annotations
20
-
21
- import jax
22
- import contextlib
23
- import dataclasses
24
- import functools
25
- import threading
26
- from typing import (Any, Tuple, List, overload, Callable, TypeVar, Mapping)
27
-
28
- from typing_extensions import Unpack
29
-
30
- from brainstate.typing import Filter
31
- from brainstate.util import NestedDict, FrozenDict
32
- from ._graph_operation import (flatten,
33
- unflatten,
34
- _split_state,
35
- GraphDef,
36
- RefMap,
37
- NodeDef)
38
-
39
- __all__ = [
40
- 'split_context',
41
- 'merge_context',
42
- 'update_context',
43
- ]
44
-
45
- Index = int
46
- A = TypeVar('A')
47
- B = TypeVar('B')
48
- C = TypeVar('C')
49
- F = TypeVar('F', bound=Callable)
50
-
51
-
52
- @dataclasses.dataclass
53
- class GraphContext(threading.local):
54
- """
55
- Thread-local container holding the stacks of active update, split, and merge contexts.
56
- """
57
- update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field(default_factory=dict)
58
- ref_index_stack: List[SplitContext] = dataclasses.field(default_factory=list)
59
- index_ref_stack: List[MergeContext] = dataclasses.field(default_factory=list)
60
-
61
-
62
- GRAPH_CONTEXT = GraphContext()
63
-
64
-
65
- @dataclasses.dataclass
66
- class SplitContext:
67
- """
68
- The context object yielded by :func:`split_context`; it splits graph nodes against one shared reference index.
69
- """
70
- ctxtag: str | None
71
- ref_index: RefMap[Any, Index]
72
-
73
- def treefy_split(
74
- self,
75
- node: A,
76
- *filters: Filter
77
- ) -> Tuple[GraphDef[A], Unpack[Tuple[NestedDict, ...]]]:
78
- ctx = current_update_context(self.ctxtag) if self.ctxtag is not None else None
79
- graphdef, statetree = flatten(node, self.ref_index)
80
- state_mappings = _split_state(statetree, filters)
81
- if ctx is not None:
82
- if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
83
- index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
84
- graphdef = dataclasses.replace(
85
- graphdef,
86
- index_mapping=FrozenDict(index_to_index)
87
- )
88
- return graphdef, *state_mappings
89
-
90
-
91
- @contextlib.contextmanager
92
- def split_context(ctxtag: str | None = None):
93
- """
94
- A context manager for handling graph splitting.
95
- """
96
- index_ref: RefMap[Any, Index] = RefMap()
97
- flatten_ctx = SplitContext(ctxtag, index_ref)
98
- GRAPH_CONTEXT.ref_index_stack.append(flatten_ctx)
99
-
100
- try:
101
- yield flatten_ctx
102
- finally:
103
- GRAPH_CONTEXT.ref_index_stack.pop()
104
- if ctxtag is not None:
105
- ctx = current_update_context(ctxtag)
106
- ctx.flatten_end(index_ref)
107
- del flatten_ctx.ref_index
108
- del flatten_ctx.ctxtag
109
-
110
-
111
- @dataclasses.dataclass
112
- class MergeContext:
113
- """
114
- The context object yielded by :func:`merge_context`; it merges graph definitions against one shared index-to-reference map.
115
- """
116
- ctxtag: str | None
117
- index_ref: dict[Index, Any]
118
-
119
- def treefy_merge(
120
- self,
121
- graphdef: GraphDef[A],
122
- state_mapping: NestedDict,
123
- /,
124
- *state_mappings: NestedDict
125
- ) -> A:
126
- ctx = (
127
- current_update_context(self.ctxtag)
128
- if self.ctxtag is not None
129
- else None
130
- )
131
- if (
132
- ctx is not None
133
- and isinstance(graphdef, NodeDef)
134
- and graphdef.index_mapping is not None
135
- ):
136
- # outer merge (4), create index_ref_cache
137
- assert ctx.ref_index is not None
138
- index_ref_cache = compose_mapping_reversed(
139
- ctx.ref_index, graphdef.index_mapping
140
- )
141
- else:
142
- # inner merge (2)
143
- index_ref_cache = None
144
-
145
- state_mapping = NestedDict.merge(state_mapping, *state_mappings)
146
- node = unflatten(
147
- graphdef,
148
- state_mapping,
149
- index_ref=self.index_ref,
150
- index_ref_cache=index_ref_cache,
151
- )
152
- return node
153
-
154
-
155
- @contextlib.contextmanager
156
- def merge_context(ctxtag: str | None = None):
157
- """
158
- A context manager for handling graph merging.
159
- """
160
- index_ref: dict[Index, Any] = {}
161
-
162
- unflatten_ctx = MergeContext(ctxtag, index_ref)
163
- GRAPH_CONTEXT.index_ref_stack.append(unflatten_ctx)
164
-
165
- try:
166
- yield unflatten_ctx
167
- finally:
168
- GRAPH_CONTEXT.index_ref_stack.pop()
169
- if ctxtag is not None:
170
- ctx = current_update_context(ctxtag)
171
- ctx.unflatten_end(index_ref)
172
- del unflatten_ctx.index_ref
173
- del unflatten_ctx.ctxtag
174
-
175
-
176
- @dataclasses.dataclass
177
- class UpdateContext:
178
- """
179
- A context manager for handling complex state updates.
180
- """
181
-
182
- tag: str
183
- ref_index: RefMap[Any, Index] | None
184
- index_ref: dict[Index, Any] | None
185
-
186
- # define hash and eq to make this an opaque object
187
- def __hash__(self):
188
- return 0
189
-
190
- def __eq__(self, other):
191
- return isinstance(other, UpdateContext)
192
-
193
- def flatten_end(self, ref_index: RefMap[Any, Index]):
194
- if self.ref_index is None:
195
- # outer split (1), store the references
196
- self.ref_index = ref_index
197
- else:
198
- # inner split (3), clear index_ref
199
- self.index_ref = None
200
-
201
- def unflatten_end(self, index_ref: dict[Index, Any]):
202
- self.index_ref = index_ref
203
-
204
- @overload
205
- def split(
206
- self, graph_node: A, /
207
- ) -> tuple[GraphDef[A], NestedDict]:
208
- ...
209
-
210
- @overload
211
- def split(
212
- self, graph_node: A, first: Filter, /
213
- ) -> tuple[GraphDef[A], NestedDict]:
214
- ...
215
-
216
- @overload
217
- def split(
218
- self,
219
- graph_node: A,
220
- first: Filter,
221
- second: Filter,
222
- /,
223
- *filters: Filter,
224
- ) -> tuple[GraphDef[A], NestedDict, Unpack[tuple[NestedDict, ...]]]:
225
- ...
226
-
227
- def split(
228
- self, node: A, *filters: Filter
229
- ) -> tuple[GraphDef[A], NestedDict, Unpack[tuple[NestedDict, ...]]]:
230
- """
231
- Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
232
- a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
233
- contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
234
- to JAX’s ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch
235
- seamlessly between stateful and stateless representations of the graph.
236
-
237
- Arguments:
238
- node: graph node to split.
239
- *filters: some optional filters to group the state into mutually exclusive substates.
240
- Returns:
241
- :class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no
242
- filters are passed, a single :class:`State` is returned.
243
- """
244
- ref_index: RefMap[Any, Index] = RefMap()
245
- graphdef, state = flatten(node, ref_index)
246
- states = _split_state(state, filters)
247
-
248
- if (self.index_ref is not None) and isinstance(graphdef, NodeDef):
249
- index_to_index = compose_mapping(self.index_ref, ref_index)
250
- graphdef = dataclasses.replace(
251
- graphdef,
252
- index_mapping=FrozenDict(index_to_index)
253
- )
254
-
255
- self.flatten_end(ref_index)
256
-
257
- return graphdef, *states
258
-
259
- def merge(
260
- self,
261
- graphdef: GraphDef[A],
262
- state: NestedDict,
263
- *states: NestedDict,
264
- ) -> A:
265
- """merge"""
266
- if not isinstance(graphdef, NodeDef):
267
- raise ValueError(f'Expected a NodeDef instance, but got {type(graphdef)}.' )
268
- if self.ref_index is None:
269
- raise ValueError('Cannot merge without ref_index.')
270
-
271
- if graphdef.index_mapping is not None:
272
- # outer merge (4), create index_ref_cache
273
- assert self.ref_index is not None
274
- index_ref_cache = compose_mapping_reversed(
275
- self.ref_index,
276
- graphdef.index_mapping
277
- )
278
- else:
279
- # inner merge (2)
280
- index_ref_cache = None
281
-
282
- state = NestedDict.merge(state, *states)
283
- index_ref: dict[Index, Any] = {}
284
- node = unflatten(
285
- graphdef,
286
- state,
287
- index_ref=index_ref,
288
- index_ref_cache=index_ref_cache
289
- )
290
-
291
- self.unflatten_end(index_ref)
292
-
293
- return node
294
-
295
-
296
- jax.tree_util.register_static(UpdateContext)
297
-
298
-
299
- @dataclasses.dataclass
300
- class UpdateContextManager:
301
- tag: str
302
-
303
- def __enter__(self):
304
- ctx = UpdateContext(self.tag, None, None)
305
- if self.tag not in GRAPH_CONTEXT.update_context_stacks:
306
- GRAPH_CONTEXT.update_context_stacks[self.tag] = [ctx]
307
- else:
308
- GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx)
309
- return ctx
310
-
311
- def __exit__(self, *args):
312
- if self.tag not in GRAPH_CONTEXT.update_context_stacks:
313
- raise RuntimeError(f'No update context found for tag {self.tag!r}, this is a bug.')
314
- stack = GRAPH_CONTEXT.update_context_stacks[self.tag]
315
-
316
- ctx = stack.pop()
317
- # clear references
318
- del ctx.ref_index
319
- del ctx.index_ref
320
-
321
- if not stack:
322
- del GRAPH_CONTEXT.update_context_stacks[self.tag]
323
-
324
- def __call__(self, f: F) -> F:
325
- @functools.wraps(f)
326
- def update_context_manager_wrapper(*args, **kwargs):
327
- with self:
328
- return f(*args, **kwargs)
329
-
330
- return update_context_manager_wrapper # type: ignore
331
-
332
-
333
- def update_context(tag: str):
334
- """Creates an :class:`UpdateContext` context manager which can be used to handle
335
- more complex state updates beyond what ``nnx.update`` can handle, including
336
- updates to static properties and graph structure.
337
-
338
- UpdateContext exposes a ``split`` and ``merge`` API with the same
339
- signature as ``nnx.split`` / ``nnx.merge`` but performs some bookkeeping
340
- to have the necessary information in order to perfectly update the input
341
- objects based on the changes made inside the transform. The UpdateContext
342
- must call split and merge a total of 4 times, the first
343
- and last calls happen outside the transform and the second and third calls
344
- happen inside the transform as shown in the diagram below::
345
-
346
-
347
- idxmap
348
- (2) merge ─────────────────────────────► split (3)
349
- ▲ │
350
- │ inside │
351
- │. . . . . . . . . . . . . . . . . . │ index_mapping
352
- │ outside │
353
- │ ▼
354
- (1) split──────────────────────────────► merge (4)
355
- refmap
356
-
357
-
358
- The first call to split ``(1)`` creates a ``refmap`` which keeps track of the
359
- outer references, and the first call to merge ``(2)`` creates an ``idxmap`` which
360
- keeps track of the inner references. The second call to split ``(3)`` combines
361
- the refmap and idxmap to produce the ``index_mapping`` which indicates
362
- how the outer references map to the inner references. Finally, the last call to
363
- merge ``(4)`` uses the index_mapping and the refmap to reconstruct the
364
- output of the transform while reusing/updating the inner references. To avoid
365
- memory leaks, the idxmap is cleared after ``(3)`` and the refmap is
366
- cleared after ``(4)``, and both are cleared after the context manager exits.
367
-
368
- Here is a simple example showing the use of ``update_context``::
369
-
370
- >>> import brainstate as bst
371
- >>> import jax
372
- ...
373
- >>> m1 = bst.graph.Dict({})
374
- >>> with bst.graph.update_context('example') as ctx:
375
- ... graphdef, state = ctx.split(m1)
376
- ... @jax.jit
377
- ... def f(graphdef, state):
378
- ... m2 = ctx.merge(graphdef, state)
379
- ... m2.a = 1
380
- ... m2.ref = m2 # create a reference cycle
381
- ... return ctx.split(m2)
382
- ... graphdef_out, state_out = f(graphdef, state)
383
- ... m3 = ctx.merge(graphdef_out, state_out)
384
- ...
385
- >>> assert m1 is m3
386
- >>> assert m1.a == 1
387
- >>> assert m1.ref is m1
388
-
389
- Note that ``update_context`` takes in a ``tag`` argument which is used
390
- primarily as a safety mechanism to reduce the risk of accidentally using the
391
- wrong UpdateContext when using :func:`current_update_context` to access the
392
- current active context. current_update_context can be used as a way of
393
- accessing the current active context without having to pass it as a capture::
394
-
395
- >>> m1 = bst.graph.Dict({})
396
- >>> @jax.jit
397
- ... def f(graphdef, state):
398
- ... ctx = bst.graph.current_update_context('example')
399
- ... m2 = ctx.merge(graphdef, state)
400
- ... m2.a = 1 # insert static attribute
401
- ... m2.ref = m2 # create a reference cycle
402
- ... return ctx.split(m2)
403
- ...
404
- >>> @bst.graph.update_context('example')
405
- ... def g(m1):
406
- ... ctx = bst.graph.current_update_context('example')
407
- ... graphdef, state = ctx.split(m1)
408
- ... graphdef_out, state_out = f(graphdef, state)
409
- ... return ctx.merge(graphdef_out, state_out)
410
- ...
411
- >>> m3 = g(m1)
412
- >>> assert m1 is m3
413
- >>> assert m1.a == 1
414
- >>> assert m1.ref is m1
415
-
416
- As shown in the code above, ``update_context`` can also be used as a
417
- decorator that creates/activates an UpdateContext context for the
418
- duration of the function. The context can be accessed using
419
- :func:`current_update_context`.
420
-
421
- Args:
422
- tag: A string tag to identify the context.
423
- """
424
- return UpdateContextManager(tag)
425
-
426
-
427
- def current_update_context(tag: str) -> UpdateContext:
428
- """Returns the current active :class:`UpdateContext` for the given tag."""
429
- if tag not in GRAPH_CONTEXT.update_context_stacks:
430
- raise ValueError(f'No update context found for tag {tag!r}.')
431
- return GRAPH_CONTEXT.update_context_stacks[tag][-1]
432
-
433
-
434
- def compose_mapping(
435
- map_ab: Mapping[A, B], map_bc: Mapping[B, C], /
436
- ) -> dict[A, C]:
437
- return {a: map_bc[b] for a, b in map_ab.items() if b in map_bc}
438
-
439
-
440
- def compose_mapping_reversed(
441
- map_ab: Mapping[A, B], map_bc: Mapping[B, C], /
442
- ) -> dict[C, A]:
443
- return {map_bc[b]: a for a, b in map_ab.items() if b in map_bc}
@@ -1,65 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
-
18
- from absl.testing import absltest
19
-
20
- import brainstate as bst
21
-
22
-
23
- class TestGraphUtils(absltest.TestCase):
24
- def test_split_merge_context(self):
25
- m = bst.nn.Linear(2, 3, )
26
- with bst.graph.split_context() as ctx:
27
- graphdef1, state1 = ctx.treefy_split(m)
28
- graphdef2, state2 = ctx.treefy_split(m)
29
- pass
30
-
31
- self.assertFalse(hasattr(ctx, 'ref_index'))
32
- self.assertIsInstance(graphdef1, bst.graph.NodeDef)
33
- self.assertIsInstance(graphdef2, bst.graph.NodeRef)
34
- self.assertLen(state1.to_flat(), 1)
35
- self.assertLen(state2.to_flat(), 0)
36
-
37
- with bst.graph.merge_context() as ctx:
38
- m1 = ctx.treefy_merge(graphdef1, state1)
39
- m2 = ctx.treefy_merge(graphdef2, state2)
40
-
41
- self.assertIs(m1, m2)
42
- self.assertFalse(hasattr(ctx, 'index_ref'))
43
-
44
- def test_split_merge_context_nested(self):
45
- m2 = bst.nn.Linear(2, 3, )
46
- m1 = bst.nn.Sequential(m2)
47
- with bst.graph.split_context() as ctx:
48
- graphdef1, state1 = ctx.treefy_split(m1)
49
- graphdef2, state2 = ctx.treefy_split(m2)
50
-
51
- self.assertIsInstance(graphdef1, bst.graph.NodeDef)
52
- self.assertIsInstance(graphdef2, bst.graph.NodeRef)
53
- self.assertLen(state1.to_flat(), 1)
54
- self.assertLen(state2.to_flat(), 0)
55
-
56
- with bst.graph.merge_context() as ctx:
57
- m1 = ctx.treefy_merge(graphdef1, state1)
58
- m2 = ctx.treefy_merge(graphdef2, state2)
59
-
60
- self.assertIs(m2, m1.layers[0])
61
- self.assertFalse(hasattr(ctx, 'index_ref'))
62
-
63
-
64
- if __name__ == '__main__':
65
- absltest.main()