brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -3
- brainstate/_state.py +178 -178
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +30 -17
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_activations_test.py +61 -61
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits.py +0 -2
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module.py +0 -1
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_seed_test.py +10 -12
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
- brainstate-0.1.2.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_compatible_import.py
CHANGED
@@ -35,8 +35,9 @@ __all__ = [
|
|
35
35
|
'safe_map',
|
36
36
|
'safe_zip',
|
37
37
|
'unzip2',
|
38
|
-
'unzip3',
|
39
38
|
'wraps',
|
39
|
+
'Device',
|
40
|
+
'wrap_init',
|
40
41
|
]
|
41
42
|
|
42
43
|
T = TypeVar("T")
|
@@ -44,10 +45,17 @@ T1 = TypeVar("T1")
|
|
44
45
|
T2 = TypeVar("T2")
|
45
46
|
T3 = TypeVar("T3")
|
46
47
|
|
48
|
+
|
49
|
+
from saiunit._compatible_import import wrap_init
|
47
50
|
brainevent_installed = importlib.util.find_spec('brainevent') is not None
|
48
51
|
|
49
52
|
from jax.core import get_aval, Tracer
|
50
53
|
|
54
|
+
if jax.__version_info__ < (0, 5, 0):
|
55
|
+
from jax.lib.xla_client import Device
|
56
|
+
else:
|
57
|
+
from jax import Device
|
58
|
+
|
51
59
|
if jax.__version_info__ < (0, 4, 38):
|
52
60
|
from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
|
53
61
|
else:
|
@@ -84,8 +92,7 @@ else:
|
|
84
92
|
return list(zip(*args))
|
85
93
|
|
86
94
|
|
87
|
-
def unzip2(xys: Iterable[tuple[T1, T2]]
|
88
|
-
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
95
|
+
def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
89
96
|
"""Unzip sequence of length-2 tuples into two tuples."""
|
90
97
|
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
|
91
98
|
# is too permissive about inputs, and does not guarantee a length-2 output.
|
brainstate/_state.py
CHANGED
@@ -133,183 +133,6 @@ def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
|
|
133
133
|
TRACE_CONTEXT.tree_check.pop()
|
134
134
|
|
135
135
|
|
136
|
-
class StateCatcher(PrettyObject):
|
137
|
-
"""
|
138
|
-
The catcher to catch and manage new states.
|
139
|
-
|
140
|
-
This class provides functionality to collect and tag new State objects.
|
141
|
-
It ensures that each state is only added once and assigns a tag to each state.
|
142
|
-
|
143
|
-
Attributes:
|
144
|
-
state_tag (str): A string identifier used to tag the caught states.
|
145
|
-
state_ids (set): A set of state IDs to ensure uniqueness.
|
146
|
-
states (list): A list to store the caught State objects.
|
147
|
-
"""
|
148
|
-
|
149
|
-
def __init__(
|
150
|
-
self,
|
151
|
-
state_tag: str,
|
152
|
-
state_to_exclude: Filter = Nothing()
|
153
|
-
):
|
154
|
-
"""
|
155
|
-
Initialize a new Catcher instance.
|
156
|
-
|
157
|
-
Args:
|
158
|
-
state_tag (str): The tag to be assigned to caught states.
|
159
|
-
state_to_exclude (Filter, optional): A filter to exclude states from being caught.
|
160
|
-
"""
|
161
|
-
if state_to_exclude is None:
|
162
|
-
state_to_exclude = Nothing()
|
163
|
-
self.state_to_exclude = state_to_exclude
|
164
|
-
self.state_tag = state_tag
|
165
|
-
self.state_ids = set()
|
166
|
-
self.states = []
|
167
|
-
|
168
|
-
def get_state_values(self) -> List[PyTree]:
|
169
|
-
"""
|
170
|
-
Get the values of the caught states.
|
171
|
-
|
172
|
-
Returns:
|
173
|
-
list: A list of values of the caught states.
|
174
|
-
"""
|
175
|
-
return [state.value for state in self.states]
|
176
|
-
|
177
|
-
def get_states(self) -> List[State]:
|
178
|
-
"""
|
179
|
-
Get the caught states.
|
180
|
-
|
181
|
-
Returns:
|
182
|
-
list: A list of the caught states.
|
183
|
-
"""
|
184
|
-
return self.states
|
185
|
-
|
186
|
-
def append(self, state: State):
|
187
|
-
"""
|
188
|
-
Add a new state to the catcher if it hasn't been added before.
|
189
|
-
|
190
|
-
This method adds the state to the internal list, records its ID,
|
191
|
-
and assigns the catcher's tag to the state.
|
192
|
-
|
193
|
-
Args:
|
194
|
-
state (State): The State object to be added.
|
195
|
-
"""
|
196
|
-
if self.state_to_exclude((), state):
|
197
|
-
return
|
198
|
-
if id(state) not in self.state_ids:
|
199
|
-
self.state_ids.add(id(state))
|
200
|
-
self.states.append(state)
|
201
|
-
state.tag = self.state_tag
|
202
|
-
|
203
|
-
def __iter__(self):
|
204
|
-
"""
|
205
|
-
Allow iteration over the caught states.
|
206
|
-
|
207
|
-
Returns:
|
208
|
-
iterator: An iterator over the list of caught states.
|
209
|
-
"""
|
210
|
-
return iter(self.states)
|
211
|
-
|
212
|
-
def __len__(self):
|
213
|
-
"""
|
214
|
-
Return the number of caught states.
|
215
|
-
|
216
|
-
Returns:
|
217
|
-
int: The number of caught states.
|
218
|
-
"""
|
219
|
-
return len(self.states)
|
220
|
-
|
221
|
-
def __getitem__(self, index):
|
222
|
-
"""
|
223
|
-
Get a state by index.
|
224
|
-
|
225
|
-
Args:
|
226
|
-
index (int): The index of the state to retrieve.
|
227
|
-
|
228
|
-
Returns:
|
229
|
-
State: The state at the specified index.
|
230
|
-
"""
|
231
|
-
return self.states[index]
|
232
|
-
|
233
|
-
def clear(self):
|
234
|
-
"""
|
235
|
-
Clear all caught states.
|
236
|
-
"""
|
237
|
-
self.state_ids.clear()
|
238
|
-
self.states.clear()
|
239
|
-
|
240
|
-
def get_by_tag(self, tag: str):
|
241
|
-
"""
|
242
|
-
Get all states with a specific tag.
|
243
|
-
|
244
|
-
Args:
|
245
|
-
tag (str): The tag to filter by.
|
246
|
-
|
247
|
-
Returns:
|
248
|
-
list: A list of states with the specified tag.
|
249
|
-
"""
|
250
|
-
return [state for state in self.states if state.tag == tag]
|
251
|
-
|
252
|
-
def remove(self, state: State):
|
253
|
-
"""
|
254
|
-
Remove a specific state from the catcher.
|
255
|
-
|
256
|
-
Args:
|
257
|
-
state (State): The state to remove.
|
258
|
-
"""
|
259
|
-
if id(state) in self.state_ids:
|
260
|
-
self.state_ids.remove(id(state))
|
261
|
-
self.states.remove(state)
|
262
|
-
|
263
|
-
def __contains__(self, state: State):
|
264
|
-
"""
|
265
|
-
Check if a state is in the catcher.
|
266
|
-
|
267
|
-
Args:
|
268
|
-
state (State): The state to check for.
|
269
|
-
|
270
|
-
Returns:
|
271
|
-
bool: True if the state is in the catcher, False otherwise.
|
272
|
-
"""
|
273
|
-
return id(state) in self.state_ids
|
274
|
-
|
275
|
-
|
276
|
-
@contextlib.contextmanager
|
277
|
-
def catch_new_states(
|
278
|
-
state_tag: str = None,
|
279
|
-
state_to_exclude: Filter = Nothing()
|
280
|
-
) -> Generator[StateCatcher, None, None]:
|
281
|
-
"""
|
282
|
-
A context manager that catches and tracks new states created within its scope.
|
283
|
-
|
284
|
-
This function creates a new Catcher object and adds it to the TRACE_CONTEXT's
|
285
|
-
new_state_catcher list. It allows for tracking and managing new states created
|
286
|
-
within the context.
|
287
|
-
|
288
|
-
Args:
|
289
|
-
state_tag (str, optional): A string tag to associate with the caught states.
|
290
|
-
Defaults to None.
|
291
|
-
state_to_exclude (Filter, optional): A filter object to specify which states
|
292
|
-
should be excluded from catching. Defaults to Nothing(), which excludes no states.
|
293
|
-
|
294
|
-
Yields:
|
295
|
-
Catcher: A Catcher object that can be used to access and manage the
|
296
|
-
newly created states within the context.
|
297
|
-
|
298
|
-
Example::
|
299
|
-
|
300
|
-
with catch_new_states("my_tag") as catcher:
|
301
|
-
# Create new states here
|
302
|
-
# They will be caught and tagged with "my_tag"
|
303
|
-
# Access caught states through catcher object
|
304
|
-
"""
|
305
|
-
try:
|
306
|
-
catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
|
307
|
-
TRACE_CONTEXT.new_state_catcher.append(catcher)
|
308
|
-
yield catcher
|
309
|
-
finally:
|
310
|
-
TRACE_CONTEXT.new_state_catcher.pop()
|
311
|
-
|
312
|
-
|
313
136
|
def maybe_state(val: Any) -> Any:
|
314
137
|
"""
|
315
138
|
Extracts the value from a State object if given, otherwise returns the input value.
|
@@ -1226,7 +1049,7 @@ class StateTraceStack(Generic[A]):
|
|
1226
1049
|
"""
|
1227
1050
|
if self._jax_trace_new_arg is not None:
|
1228
1051
|
# internal use
|
1229
|
-
state._value = jax.tree.map(
|
1052
|
+
state._value = jax.tree.map(self._jax_trace_new_arg, state._value)
|
1230
1053
|
|
1231
1054
|
def __enter__(self) -> 'StateTraceStack':
|
1232
1055
|
TRACE_CONTEXT.state_stack.append(self)
|
@@ -1660,3 +1483,180 @@ jax.tree_util.register_pytree_with_keys(
|
|
1660
1483
|
_state_ref_unflatten, # type: ignore
|
1661
1484
|
flatten_func=partial(_state_ref_flatten, with_keys=False), # type: ignore
|
1662
1485
|
)
|
1486
|
+
|
1487
|
+
|
1488
|
+
class StateCatcher(PrettyObject):
|
1489
|
+
"""
|
1490
|
+
The catcher to catch and manage new states.
|
1491
|
+
|
1492
|
+
This class provides functionality to collect and tag new State objects.
|
1493
|
+
It ensures that each state is only added once and assigns a tag to each state.
|
1494
|
+
|
1495
|
+
Attributes:
|
1496
|
+
state_tag (str): A string identifier used to tag the caught states.
|
1497
|
+
state_ids (set): A set of state IDs to ensure uniqueness.
|
1498
|
+
states (list): A list to store the caught State objects.
|
1499
|
+
"""
|
1500
|
+
|
1501
|
+
def __init__(
|
1502
|
+
self,
|
1503
|
+
state_tag: str,
|
1504
|
+
state_to_exclude: Filter = Nothing()
|
1505
|
+
):
|
1506
|
+
"""
|
1507
|
+
Initialize a new Catcher instance.
|
1508
|
+
|
1509
|
+
Args:
|
1510
|
+
state_tag (str): The tag to be assigned to caught states.
|
1511
|
+
state_to_exclude (Filter, optional): A filter to exclude states from being caught.
|
1512
|
+
"""
|
1513
|
+
if state_to_exclude is None:
|
1514
|
+
state_to_exclude = Nothing()
|
1515
|
+
self.state_to_exclude = state_to_exclude
|
1516
|
+
self.state_tag = state_tag
|
1517
|
+
self.state_ids = set()
|
1518
|
+
self.states = []
|
1519
|
+
|
1520
|
+
def get_state_values(self) -> List[PyTree]:
|
1521
|
+
"""
|
1522
|
+
Get the values of the caught states.
|
1523
|
+
|
1524
|
+
Returns:
|
1525
|
+
list: A list of values of the caught states.
|
1526
|
+
"""
|
1527
|
+
return [state.value for state in self.states]
|
1528
|
+
|
1529
|
+
def get_states(self) -> List[State]:
|
1530
|
+
"""
|
1531
|
+
Get the caught states.
|
1532
|
+
|
1533
|
+
Returns:
|
1534
|
+
list: A list of the caught states.
|
1535
|
+
"""
|
1536
|
+
return self.states
|
1537
|
+
|
1538
|
+
def append(self, state: State):
|
1539
|
+
"""
|
1540
|
+
Add a new state to the catcher if it hasn't been added before.
|
1541
|
+
|
1542
|
+
This method adds the state to the internal list, records its ID,
|
1543
|
+
and assigns the catcher's tag to the state.
|
1544
|
+
|
1545
|
+
Args:
|
1546
|
+
state (State): The State object to be added.
|
1547
|
+
"""
|
1548
|
+
if self.state_to_exclude((), state):
|
1549
|
+
return
|
1550
|
+
if id(state) not in self.state_ids:
|
1551
|
+
self.state_ids.add(id(state))
|
1552
|
+
self.states.append(state)
|
1553
|
+
state.tag = self.state_tag
|
1554
|
+
|
1555
|
+
def __iter__(self):
|
1556
|
+
"""
|
1557
|
+
Allow iteration over the caught states.
|
1558
|
+
|
1559
|
+
Returns:
|
1560
|
+
iterator: An iterator over the list of caught states.
|
1561
|
+
"""
|
1562
|
+
return iter(self.states)
|
1563
|
+
|
1564
|
+
def __len__(self):
|
1565
|
+
"""
|
1566
|
+
Return the number of caught states.
|
1567
|
+
|
1568
|
+
Returns:
|
1569
|
+
int: The number of caught states.
|
1570
|
+
"""
|
1571
|
+
return len(self.states)
|
1572
|
+
|
1573
|
+
def __getitem__(self, index):
|
1574
|
+
"""
|
1575
|
+
Get a state by index.
|
1576
|
+
|
1577
|
+
Args:
|
1578
|
+
index (int): The index of the state to retrieve.
|
1579
|
+
|
1580
|
+
Returns:
|
1581
|
+
State: The state at the specified index.
|
1582
|
+
"""
|
1583
|
+
return self.states[index]
|
1584
|
+
|
1585
|
+
def clear(self):
|
1586
|
+
"""
|
1587
|
+
Clear all caught states.
|
1588
|
+
"""
|
1589
|
+
self.state_ids.clear()
|
1590
|
+
self.states.clear()
|
1591
|
+
|
1592
|
+
def get_by_tag(self, tag: str):
|
1593
|
+
"""
|
1594
|
+
Get all states with a specific tag.
|
1595
|
+
|
1596
|
+
Args:
|
1597
|
+
tag (str): The tag to filter by.
|
1598
|
+
|
1599
|
+
Returns:
|
1600
|
+
list: A list of states with the specified tag.
|
1601
|
+
"""
|
1602
|
+
return [state for state in self.states if state.tag == tag]
|
1603
|
+
|
1604
|
+
def remove(self, state: State):
|
1605
|
+
"""
|
1606
|
+
Remove a specific state from the catcher.
|
1607
|
+
|
1608
|
+
Args:
|
1609
|
+
state (State): The state to remove.
|
1610
|
+
"""
|
1611
|
+
if id(state) in self.state_ids:
|
1612
|
+
self.state_ids.remove(id(state))
|
1613
|
+
self.states.remove(state)
|
1614
|
+
|
1615
|
+
def __contains__(self, state: State):
|
1616
|
+
"""
|
1617
|
+
Check if a state is in the catcher.
|
1618
|
+
|
1619
|
+
Args:
|
1620
|
+
state (State): The state to check for.
|
1621
|
+
|
1622
|
+
Returns:
|
1623
|
+
bool: True if the state is in the catcher, False otherwise.
|
1624
|
+
"""
|
1625
|
+
return id(state) in self.state_ids
|
1626
|
+
|
1627
|
+
|
1628
|
+
@contextlib.contextmanager
|
1629
|
+
def catch_new_states(
|
1630
|
+
state_tag: str = None,
|
1631
|
+
state_to_exclude: Filter = Nothing()
|
1632
|
+
) -> Generator[StateCatcher, None, None]:
|
1633
|
+
"""
|
1634
|
+
A context manager that catches and tracks new states created within its scope.
|
1635
|
+
|
1636
|
+
This function creates a new Catcher object and adds it to the TRACE_CONTEXT's
|
1637
|
+
new_state_catcher list. It allows for tracking and managing new states created
|
1638
|
+
within the context.
|
1639
|
+
|
1640
|
+
Args:
|
1641
|
+
state_tag (str, optional): A string tag to associate with the caught states.
|
1642
|
+
Defaults to None.
|
1643
|
+
state_to_exclude (Filter, optional): A filter object to specify which states
|
1644
|
+
should be excluded from catching. Defaults to Nothing(), which excludes no states.
|
1645
|
+
|
1646
|
+
Yields:
|
1647
|
+
Catcher: A Catcher object that can be used to access and manage the
|
1648
|
+
newly created states within the context.
|
1649
|
+
|
1650
|
+
Example::
|
1651
|
+
|
1652
|
+
with catch_new_states("my_tag") as catcher:
|
1653
|
+
# Create new states here
|
1654
|
+
# They will be caught and tagged with "my_tag"
|
1655
|
+
# Access caught states through catcher object
|
1656
|
+
"""
|
1657
|
+
try:
|
1658
|
+
catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
|
1659
|
+
TRACE_CONTEXT.new_state_catcher.append(catcher)
|
1660
|
+
yield catcher
|
1661
|
+
finally:
|
1662
|
+
TRACE_CONTEXT.new_state_catcher.pop()
|
brainstate/_utils.py
CHANGED
brainstate/augment/_autograd.py
CHANGED
@@ -27,8 +27,6 @@ The wrapped gradient transformations here are made possible by using the followi
|
|
27
27
|
|
28
28
|
"""
|
29
29
|
|
30
|
-
from __future__ import annotations
|
31
|
-
|
32
30
|
from functools import wraps, partial
|
33
31
|
from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
|
34
32
|
|