brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +8 -5
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -2
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +16 -5
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -1,443 +0,0 @@
|
|
1
|
-
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
-
# The credit should go to the Flax authors.
|
3
|
-
#
|
4
|
-
# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
|
5
|
-
#
|
6
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
-
# you may not use this file except in compliance with the License.
|
8
|
-
# You may obtain a copy of the License at
|
9
|
-
#
|
10
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
-
#
|
12
|
-
# Unless required by applicable law or agreed to in writing, software
|
13
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
-
# See the License for the specific language governing permissions and
|
16
|
-
# limitations under the License.
|
17
|
-
# ==============================================================================
|
18
|
-
|
19
|
-
from __future__ import annotations
|
20
|
-
|
21
|
-
import jax
|
22
|
-
import contextlib
|
23
|
-
import dataclasses
|
24
|
-
import functools
|
25
|
-
import threading
|
26
|
-
from typing import (Any, Tuple, List, overload, Callable, TypeVar, Mapping)
|
27
|
-
|
28
|
-
from typing_extensions import Unpack
|
29
|
-
|
30
|
-
from brainstate.typing import Filter
|
31
|
-
from brainstate.util import NestedDict, FrozenDict
|
32
|
-
from ._graph_operation import (flatten,
|
33
|
-
unflatten,
|
34
|
-
_split_state,
|
35
|
-
GraphDef,
|
36
|
-
RefMap,
|
37
|
-
NodeDef)
|
38
|
-
|
39
|
-
__all__ = [
    'split_context',
    'merge_context',
    'update_context',
    # Used as public API (``bst.graph.current_update_context``) in the
    # ``update_context`` docstring examples, so export it as well.
    'current_update_context',
]

# Integer index identifying a graph node in the ref/index bookkeeping maps.
Index = int

# Generic type variables used by the signatures in this module.
A = TypeVar('A')
B = TypeVar('B')
C = TypeVar('C')
F = TypeVar('F', bound=Callable)
|
50
|
-
|
51
|
-
|
52
|
-
@dataclasses.dataclass
class GraphContext(threading.local):
    """
    Per-thread registry of the currently active graph contexts.

    Because the class derives from :class:`threading.local`, each thread that
    touches an instance sees its own, independently initialised set of stacks.
    """

    # Update-context tag -> stack of ``UpdateContext`` objects entered under
    # that tag (the innermost one is last).
    update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field(
        default_factory=dict
    )
    # Stack of ``SplitContext`` objects pushed by ``split_context``.
    ref_index_stack: List[SplitContext] = dataclasses.field(
        default_factory=list
    )
    # Stack of ``MergeContext`` objects pushed by ``merge_context``.
    index_ref_stack: List[MergeContext] = dataclasses.field(
        default_factory=list
    )
|
60
|
-
|
61
|
-
|
62
|
-
# Module-level registry of active contexts. GraphContext subclasses
# threading.local, so the stacks it holds are kept separately per thread.
GRAPH_CONTEXT = GraphContext()
|
63
|
-
|
64
|
-
|
65
|
-
@dataclasses.dataclass
class SplitContext:
    """
    Bookkeeping object yielded by :func:`split_context`.

    All :meth:`treefy_split` calls made through one instance share the same
    ``ref_index``, so a node that was already flattened is recognised by a
    later split in the same context.
    """
    ctxtag: str | None
    ref_index: RefMap[Any, Index]

    def treefy_split(
        self,
        node: A,
        *filters: Filter
    ) -> Tuple[GraphDef[A], Unpack[Tuple[NestedDict, ...]]]:
        """
        Flatten ``node`` and split its state with ``filters``.

        Flattening uses this context's shared ``ref_index``. When the context is
        bound to an update context (``ctxtag`` is not ``None``) that currently
        holds an ``index_ref``, a ``NodeDef`` result is annotated with the
        ``index_mapping`` obtained by composing that ``index_ref`` with
        ``ref_index``.

        Returns:
            The ``GraphDef`` followed by the split ``NestedDict`` states.

        Raises:
            ValueError: if ``ctxtag`` is set but no update context with that tag
                is active (raised by ``current_update_context``).
        """
        update_ctx = None
        if self.ctxtag is not None:
            # Looked up before flattening so a missing context fails early.
            update_ctx = current_update_context(self.ctxtag)

        graphdef, state_tree = flatten(node, self.ref_index)
        split_states = _split_state(state_tree, filters)

        has_inner_refs = update_ctx is not None and update_ctx.index_ref is not None
        if has_inner_refs and isinstance(graphdef, NodeDef):
            composed = compose_mapping(update_ctx.index_ref, self.ref_index)
            graphdef = dataclasses.replace(graphdef, index_mapping=FrozenDict(composed))

        return (graphdef, *split_states)
|
89
|
-
|
90
|
-
|
91
|
-
@contextlib.contextmanager
def split_context(ctxtag: str | None = None):
    """
    Context manager yielding a :class:`SplitContext` with a fresh ``RefMap``.

    While active, the split context is pushed onto
    ``GRAPH_CONTEXT.ref_index_stack``. On exit it is popped. If ``ctxtag`` is
    given, the collected ref map is then reported to the matching
    ``UpdateContext`` through ``flatten_end``. Finally the references held by
    the yielded object are deleted.

    Args:
        ctxtag: Optional tag of the active update context to notify on exit.
    """
    ref_index: RefMap[Any, Index] = RefMap()
    split_ctx = SplitContext(ctxtag, ref_index)
    GRAPH_CONTEXT.ref_index_stack.append(split_ctx)

    try:
        yield split_ctx
    finally:
        GRAPH_CONTEXT.ref_index_stack.pop()
        if ctxtag is not None:
            current_update_context(ctxtag).flatten_end(ref_index)
        # Release what the context object holds once the block has exited.
        del split_ctx.ref_index
        del split_ctx.ctxtag
|
109
|
-
|
110
|
-
|
111
|
-
@dataclasses.dataclass
class MergeContext:
    """
    Bookkeeping object yielded by :func:`merge_context`.

    All :meth:`treefy_merge` calls made through one instance write into the
    same ``index_ref`` dictionary, so a later merge can resolve references to
    nodes rebuilt by an earlier one.
    """
    ctxtag: str | None
    index_ref: dict[Index, Any]

    def treefy_merge(
        self,
        graphdef: GraphDef[A],
        state_mapping: NestedDict,
        /,
        *state_mappings: NestedDict
    ) -> A:
        """
        Rebuild a node from ``graphdef`` and one or more state mappings.

        If the context is bound to an update context (``ctxtag`` set) and
        ``graphdef`` is a ``NodeDef`` carrying an ``index_mapping``, this is the
        outer merge (4). The mapping is composed (reversed) with the update
        context's ``ref_index``, and the result is passed to ``unflatten`` as
        ``index_ref_cache``. Otherwise (inner merge (2)) no cache is used.

        Raises:
            ValueError: if ``ctxtag`` is set but no update context with that tag
                is active (raised by ``current_update_context``).
        """
        update_ctx = None if self.ctxtag is None else current_update_context(self.ctxtag)

        is_outer_merge = (
            update_ctx is not None
            and isinstance(graphdef, NodeDef)
            and graphdef.index_mapping is not None
        )

        index_ref_cache = None
        if is_outer_merge:
            assert update_ctx.ref_index is not None
            index_ref_cache = compose_mapping_reversed(
                update_ctx.ref_index, graphdef.index_mapping
            )

        merged_state = NestedDict.merge(state_mapping, *state_mappings)
        return unflatten(
            graphdef,
            merged_state,
            index_ref=self.index_ref,
            index_ref_cache=index_ref_cache,
        )
|
153
|
-
|
154
|
-
|
155
|
-
@contextlib.contextmanager
def merge_context(ctxtag: str | None = None):
    """
    Context manager yielding a :class:`MergeContext` with an empty ``index_ref``.

    While active, the merge context is pushed onto
    ``GRAPH_CONTEXT.index_ref_stack``. On exit it is popped. If ``ctxtag`` is
    given, the collected ``index_ref`` is then reported to the matching
    ``UpdateContext`` through ``unflatten_end``. Finally the references held
    by the yielded object are deleted.

    Args:
        ctxtag: Optional tag of the active update context to notify on exit.
    """
    index_ref: dict[Index, Any] = {}
    merge_ctx = MergeContext(ctxtag, index_ref)
    GRAPH_CONTEXT.index_ref_stack.append(merge_ctx)

    try:
        yield merge_ctx
    finally:
        GRAPH_CONTEXT.index_ref_stack.pop()
        if ctxtag is not None:
            current_update_context(ctxtag).unflatten_end(index_ref)
        # Release what the context object holds once the block has exited.
        del merge_ctx.index_ref
        del merge_ctx.ctxtag
|
174
|
-
|
175
|
-
|
176
|
-
@dataclasses.dataclass
class UpdateContext:
    """
    State shared by the four split/merge calls made under :func:`update_context`.

    ``ref_index`` is recorded by the first (outer) split. ``index_ref`` is
    recorded by every merge and cleared again by a later (inner) split.
    """

    tag: str
    # Outer references -> indices; recorded by the first ``split`` (1).
    ref_index: RefMap[Any, Index] | None
    # Indices -> references; recorded by ``merge``, cleared by an inner ``split`` (3).
    index_ref: dict[Index, Any] | None

    # Every UpdateContext hashes to the same value and compares equal to any
    # other UpdateContext, so instances behave as one opaque constant.
    def __hash__(self):
        return 0

    def __eq__(self, other):
        return isinstance(other, UpdateContext)

    def flatten_end(self, ref_index: RefMap[Any, Index]):
        """Record a finished split: keep the first ref map, else drop ``index_ref``."""
        if self.ref_index is not None:
            # inner split (3): the inner index map is no longer needed
            self.index_ref = None
        else:
            # outer split (1): remember the outer references
            self.ref_index = ref_index

    def unflatten_end(self, index_ref: dict[Index, Any]):
        """Record the index -> reference map produced by a merge."""
        self.index_ref = index_ref

    @overload
    def split(self, graph_node: A, /) -> tuple[GraphDef[A], NestedDict]:
        ...

    @overload
    def split(self, graph_node: A, first: Filter, /) -> tuple[GraphDef[A], NestedDict]:
        ...

    @overload
    def split(
        self,
        graph_node: A,
        first: Filter,
        second: Filter,
        /,
        *filters: Filter,
    ) -> tuple[GraphDef[A], NestedDict, Unpack[tuple[NestedDict, ...]]]:
        ...

    def split(
        self, node: A, *filters: Filter
    ) -> tuple[GraphDef[A], NestedDict, Unpack[tuple[NestedDict, ...]]]:
        """
        Split ``node`` into a :class:`GraphDef` and one or more ``NestedDict`` states.

        The node is flattened with a fresh ``RefMap``. If a previous ``merge``
        left an ``index_ref`` behind (the inner split (3)), a ``NodeDef``
        result is annotated with an ``index_mapping`` composed from that
        ``index_ref`` and the fresh ref map. The fresh ref map is then reported
        through :meth:`flatten_end`.

        Args:
            node: Graph node to split.
            *filters: Optional filters grouping the state into mutually
                exclusive sub-states.

        Returns:
            The ``GraphDef`` followed by the split states (a single state when
            no filter is given).
        """
        local_ref_index: RefMap[Any, Index] = RefMap()
        graphdef, flat_state = flatten(node, local_ref_index)
        split_states = _split_state(flat_state, filters)

        if self.index_ref is not None and isinstance(graphdef, NodeDef):
            composed = compose_mapping(self.index_ref, local_ref_index)
            graphdef = dataclasses.replace(graphdef, index_mapping=FrozenDict(composed))

        self.flatten_end(local_ref_index)
        return (graphdef, *split_states)

    def merge(
        self,
        graphdef: GraphDef[A],
        state: NestedDict,
        *states: NestedDict,
    ) -> A:
        """
        Rebuild a node from ``graphdef`` and the given states.

        When ``graphdef`` carries an ``index_mapping`` (outer merge (4)), it is
        composed (reversed) with ``ref_index`` and passed to ``unflatten`` as
        ``index_ref_cache``. The new index -> reference map produced by
        unflattening is recorded through :meth:`unflatten_end`.

        Raises:
            ValueError: if ``graphdef`` is not a ``NodeDef``, or if no outer
                split has been recorded yet (``ref_index`` is ``None``).
        """
        if not isinstance(graphdef, NodeDef):
            raise ValueError(f'Expected a NodeDef instance, but got {type(graphdef)}.')
        if self.ref_index is None:
            raise ValueError('Cannot merge without ref_index.')

        index_ref_cache = None
        if graphdef.index_mapping is not None:
            # outer merge (4)
            index_ref_cache = compose_mapping_reversed(self.ref_index, graphdef.index_mapping)

        merged_state = NestedDict.merge(state, *states)
        new_index_ref: dict[Index, Any] = {}
        node = unflatten(
            graphdef,
            merged_state,
            index_ref=new_index_ref,
            index_ref_cache=index_ref_cache,
        )
        self.unflatten_end(new_index_ref)
        return node
|
294
|
-
|
295
|
-
|
296
|
-
# Register UpdateContext with JAX as a static pytree type (no leaves), so JAX
# transformations treat instances as static data rather than traced values.
jax.tree_util.register_static(UpdateContext)
|
297
|
-
|
298
|
-
|
299
|
-
@dataclasses.dataclass
class UpdateContextManager:
    """
    Context manager / decorator returned by :func:`update_context`.

    Entering pushes a fresh :class:`UpdateContext` onto the thread-local stack
    kept for ``tag`` and returns it. Exiting pops that context, deletes its
    reference maps, and removes the stack once it is empty.
    """
    tag: str

    def __enter__(self):
        new_ctx = UpdateContext(self.tag, None, None)
        GRAPH_CONTEXT.update_context_stacks.setdefault(self.tag, []).append(new_ctx)
        return new_ctx

    def __exit__(self, *args):
        stacks = GRAPH_CONTEXT.update_context_stacks
        if self.tag not in stacks:
            raise RuntimeError(f'No update context found for tag {self.tag!r}, this is a bug.')

        tag_stack = stacks[self.tag]
        finished = tag_stack.pop()
        # drop the reference maps held by the finished context
        del finished.ref_index
        del finished.index_ref

        if not tag_stack:
            del stacks[self.tag]

    def __call__(self, f: F) -> F:
        """Wrap ``f`` so that every call runs inside this update context."""
        @functools.wraps(f)
        def update_context_manager_wrapper(*args, **kwargs):
            with self:
                return f(*args, **kwargs)

        return update_context_manager_wrapper  # type: ignore
|
331
|
-
|
332
|
-
|
333
|
-
def update_context(tag: str):
    """Create an :class:`UpdateContextManager` for ``tag``.

    The returned object can be used as a context manager (entering it yields a
    fresh :class:`UpdateContext`) or as a function decorator. It handles state
    updates that a plain split/merge round trip cannot, including updates to
    static properties and to the graph structure.

    :class:`UpdateContext` exposes ``split`` and ``merge`` methods that do the
    bookkeeping needed to update the input objects in place from the changes
    made inside a transform. ``split`` and ``merge`` must be called a total of
    4 times. The first and last calls happen outside the transform; the
    second and third calls happen inside it, as shown in the diagram below::

                            idxmap
        (2) merge ─────────────────────────────► split (3)
              ▲                                    │
              │               inside               │
              │. . . . . . . . . . . . . . . . . . │ index_mapping
              │               outside              │
              │                                    ▼
        (1) split──────────────────────────────► merge (4)
                            refmap

    Each call does the following:

    - The first call to split ``(1)`` records a ``refmap`` (stored as
      ``UpdateContext.ref_index``) that tracks the outer references.
    - The first call to merge ``(2)`` records an ``idxmap`` (stored as
      ``UpdateContext.index_ref``) that tracks the inner references.
    - The second call to split ``(3)`` composes the idxmap with the indices
      it assigns while flattening. The result is the ``index_mapping``
      attached to its ``GraphDef``.
    - The last call to merge ``(4)`` combines the index_mapping with the
      refmap. This rebuilds the transform's output on top of the original
      outer objects, which are reused and updated.

    The idxmap is cleared after ``(3)``. Both maps are deleted when the context
    manager exits, so the context does not keep references alive afterwards.

    Here is a simple example showing the use of ``update_context``::

        >>> import brainstate as bst
        >>> import jax
        ...
        >>> m1 = bst.graph.Dict({})
        >>> with bst.graph.update_context('example') as ctx:
        ...     graphdef, state = ctx.split(m1)
        ...     @jax.jit
        ...     def f(graphdef, state):
        ...         m2 = ctx.merge(graphdef, state)
        ...         m2.a = 1
        ...         m2.ref = m2  # create a reference cycle
        ...         return ctx.split(m2)
        ...     graphdef_out, state_out = f(graphdef, state)
        ...     m3 = ctx.merge(graphdef_out, state_out)
        ...
        >>> assert m1 is m3
        >>> assert m1.a == 1
        >>> assert m1.ref is m1

    ``update_context`` takes a ``tag`` argument. The tag is mainly a safety
    mechanism: it reduces the risk of accidentally using the wrong
    UpdateContext when :func:`current_update_context` is used to access the
    active context. ``current_update_context`` lets code reach the active
    context without passing it in as a capture::

        >>> m1 = bst.graph.Dict({})
        >>> @jax.jit
        ... def f(graphdef, state):
        ...     ctx = bst.graph.current_update_context('example')
        ...     m2 = ctx.merge(graphdef, state)
        ...     m2.a = 1  # insert static attribute
        ...     m2.ref = m2  # create a reference cycle
        ...     return ctx.split(m2)
        ...
        >>> @bst.graph.update_context('example')
        ... def g(m1):
        ...     ctx = bst.graph.current_update_context('example')
        ...     graphdef, state = ctx.split(m1)
        ...     graphdef_out, state_out = f(graphdef, state)
        ...     return ctx.merge(graphdef_out, state_out)
        ...
        >>> m3 = g(m1)
        >>> assert m1 is m3
        >>> assert m1.a == 1
        >>> assert m1.ref is m1

    As the code above shows, ``update_context`` can also be used as a
    decorator. It creates and activates an UpdateContext for the duration of
    each call of the decorated function. The context can be accessed with
    :func:`current_update_context`.

    Args:
        tag: A string tag identifying the context.

    Returns:
        An :class:`UpdateContextManager` for ``tag``.
    """
    return UpdateContextManager(tag)
|
425
|
-
|
426
|
-
|
427
|
-
def current_update_context(tag: str) -> UpdateContext:
    """
    Return the innermost active :class:`UpdateContext` registered under ``tag``.

    Raises:
        ValueError: if no update context with this tag is active.
    """
    stacks = GRAPH_CONTEXT.update_context_stacks
    if tag in stacks:
        return stacks[tag][-1]
    raise ValueError(f'No update context found for tag {tag!r}.')
|
432
|
-
|
433
|
-
|
434
|
-
def compose_mapping(
    map_ab: Mapping[A, B], map_bc: Mapping[B, C], /
) -> dict[A, C]:
    """
    Chain two mappings into ``{a: map_bc[map_ab[a]]}``.

    Entries of ``map_ab`` whose value is not a key of ``map_bc`` are dropped.
    The result keeps the iteration order of ``map_ab``.
    """
    composed = {}
    for key_a, key_b in map_ab.items():
        if key_b in map_bc:
            composed[key_a] = map_bc[key_b]
    return composed
|
438
|
-
|
439
|
-
|
440
|
-
def compose_mapping_reversed(
    map_ab: Mapping[A, B], map_bc: Mapping[B, C], /
) -> dict[C, A]:
    """
    Chain two mappings and invert the result: ``{map_bc[map_ab[a]]: a}``.

    Entries of ``map_ab`` whose value is not a key of ``map_bc`` are dropped.
    If several entries map to the same key, the last one in ``map_ab`` wins.
    """
    reversed_map = {}
    for key_a, key_b in map_ab.items():
        if key_b in map_bc:
            reversed_map[map_bc[key_b]] = key_a
    return reversed_map
|
@@ -1,65 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
|
-
from absl.testing import absltest
|
19
|
-
|
20
|
-
import brainstate as bst
|
21
|
-
|
22
|
-
|
23
|
-
class TestGraphUtils(absltest.TestCase):
    """Round-trip tests for ``split_context`` / ``merge_context``."""

    def test_split_merge_context(self):
        linear = bst.nn.Linear(2, 3)
        with bst.graph.split_context() as split_ctx:
            first_def, first_state = split_ctx.treefy_split(linear)
            # Splitting the same object again yields only a reference to it.
            second_def, second_state = split_ctx.treefy_split(linear)

        # The context object drops its references once the block exits.
        self.assertFalse(hasattr(split_ctx, 'ref_index'))
        self.assertIsInstance(first_def, bst.graph.NodeDef)
        self.assertIsInstance(second_def, bst.graph.NodeRef)
        self.assertLen(first_state.to_flat(), 1)
        self.assertLen(second_state.to_flat(), 0)

        with bst.graph.merge_context() as merge_ctx:
            rebuilt_first = merge_ctx.treefy_merge(first_def, first_state)
            rebuilt_second = merge_ctx.treefy_merge(second_def, second_state)

        self.assertIs(rebuilt_first, rebuilt_second)
        self.assertFalse(hasattr(merge_ctx, 'index_ref'))

    def test_split_merge_context_nested(self):
        inner = bst.nn.Linear(2, 3)
        outer = bst.nn.Sequential(inner)
        with bst.graph.split_context() as split_ctx:
            outer_def, outer_state = split_ctx.treefy_split(outer)
            # ``inner`` was already reached through ``outer``.
            inner_def, inner_state = split_ctx.treefy_split(inner)

        self.assertIsInstance(outer_def, bst.graph.NodeDef)
        self.assertIsInstance(inner_def, bst.graph.NodeRef)
        self.assertLen(outer_state.to_flat(), 1)
        self.assertLen(inner_state.to_flat(), 0)

        with bst.graph.merge_context() as merge_ctx:
            outer = merge_ctx.treefy_merge(outer_def, outer_state)
            inner = merge_ctx.treefy_merge(inner_def, inner_state)

        self.assertIs(inner, outer.layers[0])
        self.assertFalse(hasattr(merge_ctx, 'index_ref'))


if __name__ == '__main__':
    absltest.main()
|