brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +7 -3
- brainstate/_state.py +177 -177
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_make_jaxpr.py +4 -6
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_regular_inits.py +0 -2
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_module.py +0 -1
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.1.dist-info}/METADATA +11 -5
- brainstate-0.1.1.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.1.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.1.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_compatible_import.py
CHANGED
@@ -35,8 +35,8 @@ __all__ = [
|
|
35
35
|
'safe_map',
|
36
36
|
'safe_zip',
|
37
37
|
'unzip2',
|
38
|
-
'unzip3',
|
39
38
|
'wraps',
|
39
|
+
'Device',
|
40
40
|
]
|
41
41
|
|
42
42
|
T = TypeVar("T")
|
@@ -48,6 +48,11 @@ brainevent_installed = importlib.util.find_spec('brainevent') is not None
|
|
48
48
|
|
49
49
|
from jax.core import get_aval, Tracer
|
50
50
|
|
51
|
+
if jax.__version_info__ < (0, 5, 0):
|
52
|
+
from jax.lib.xla_client import Device
|
53
|
+
else:
|
54
|
+
from jax import Device
|
55
|
+
|
51
56
|
if jax.__version_info__ < (0, 4, 38):
|
52
57
|
from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
|
53
58
|
else:
|
@@ -84,8 +89,7 @@ else:
|
|
84
89
|
return list(zip(*args))
|
85
90
|
|
86
91
|
|
87
|
-
def unzip2(xys: Iterable[tuple[T1, T2]]
|
88
|
-
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
92
|
+
def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
89
93
|
"""Unzip sequence of length-2 tuples into two tuples."""
|
90
94
|
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
|
91
95
|
# is too permissive about inputs, and does not guarantee a length-2 output.
|
brainstate/_state.py
CHANGED
@@ -133,183 +133,6 @@ def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
|
|
133
133
|
TRACE_CONTEXT.tree_check.pop()
|
134
134
|
|
135
135
|
|
136
|
-
class StateCatcher(PrettyObject):
|
137
|
-
"""
|
138
|
-
The catcher to catch and manage new states.
|
139
|
-
|
140
|
-
This class provides functionality to collect and tag new State objects.
|
141
|
-
It ensures that each state is only added once and assigns a tag to each state.
|
142
|
-
|
143
|
-
Attributes:
|
144
|
-
state_tag (str): A string identifier used to tag the caught states.
|
145
|
-
state_ids (set): A set of state IDs to ensure uniqueness.
|
146
|
-
states (list): A list to store the caught State objects.
|
147
|
-
"""
|
148
|
-
|
149
|
-
def __init__(
|
150
|
-
self,
|
151
|
-
state_tag: str,
|
152
|
-
state_to_exclude: Filter = Nothing()
|
153
|
-
):
|
154
|
-
"""
|
155
|
-
Initialize a new Catcher instance.
|
156
|
-
|
157
|
-
Args:
|
158
|
-
state_tag (str): The tag to be assigned to caught states.
|
159
|
-
state_to_exclude (Filter, optional): A filter to exclude states from being caught.
|
160
|
-
"""
|
161
|
-
if state_to_exclude is None:
|
162
|
-
state_to_exclude = Nothing()
|
163
|
-
self.state_to_exclude = state_to_exclude
|
164
|
-
self.state_tag = state_tag
|
165
|
-
self.state_ids = set()
|
166
|
-
self.states = []
|
167
|
-
|
168
|
-
def get_state_values(self) -> List[PyTree]:
|
169
|
-
"""
|
170
|
-
Get the values of the caught states.
|
171
|
-
|
172
|
-
Returns:
|
173
|
-
list: A list of values of the caught states.
|
174
|
-
"""
|
175
|
-
return [state.value for state in self.states]
|
176
|
-
|
177
|
-
def get_states(self) -> List[State]:
|
178
|
-
"""
|
179
|
-
Get the caught states.
|
180
|
-
|
181
|
-
Returns:
|
182
|
-
list: A list of the caught states.
|
183
|
-
"""
|
184
|
-
return self.states
|
185
|
-
|
186
|
-
def append(self, state: State):
|
187
|
-
"""
|
188
|
-
Add a new state to the catcher if it hasn't been added before.
|
189
|
-
|
190
|
-
This method adds the state to the internal list, records its ID,
|
191
|
-
and assigns the catcher's tag to the state.
|
192
|
-
|
193
|
-
Args:
|
194
|
-
state (State): The State object to be added.
|
195
|
-
"""
|
196
|
-
if self.state_to_exclude((), state):
|
197
|
-
return
|
198
|
-
if id(state) not in self.state_ids:
|
199
|
-
self.state_ids.add(id(state))
|
200
|
-
self.states.append(state)
|
201
|
-
state.tag = self.state_tag
|
202
|
-
|
203
|
-
def __iter__(self):
|
204
|
-
"""
|
205
|
-
Allow iteration over the caught states.
|
206
|
-
|
207
|
-
Returns:
|
208
|
-
iterator: An iterator over the list of caught states.
|
209
|
-
"""
|
210
|
-
return iter(self.states)
|
211
|
-
|
212
|
-
def __len__(self):
|
213
|
-
"""
|
214
|
-
Return the number of caught states.
|
215
|
-
|
216
|
-
Returns:
|
217
|
-
int: The number of caught states.
|
218
|
-
"""
|
219
|
-
return len(self.states)
|
220
|
-
|
221
|
-
def __getitem__(self, index):
|
222
|
-
"""
|
223
|
-
Get a state by index.
|
224
|
-
|
225
|
-
Args:
|
226
|
-
index (int): The index of the state to retrieve.
|
227
|
-
|
228
|
-
Returns:
|
229
|
-
State: The state at the specified index.
|
230
|
-
"""
|
231
|
-
return self.states[index]
|
232
|
-
|
233
|
-
def clear(self):
|
234
|
-
"""
|
235
|
-
Clear all caught states.
|
236
|
-
"""
|
237
|
-
self.state_ids.clear()
|
238
|
-
self.states.clear()
|
239
|
-
|
240
|
-
def get_by_tag(self, tag: str):
|
241
|
-
"""
|
242
|
-
Get all states with a specific tag.
|
243
|
-
|
244
|
-
Args:
|
245
|
-
tag (str): The tag to filter by.
|
246
|
-
|
247
|
-
Returns:
|
248
|
-
list: A list of states with the specified tag.
|
249
|
-
"""
|
250
|
-
return [state for state in self.states if state.tag == tag]
|
251
|
-
|
252
|
-
def remove(self, state: State):
|
253
|
-
"""
|
254
|
-
Remove a specific state from the catcher.
|
255
|
-
|
256
|
-
Args:
|
257
|
-
state (State): The state to remove.
|
258
|
-
"""
|
259
|
-
if id(state) in self.state_ids:
|
260
|
-
self.state_ids.remove(id(state))
|
261
|
-
self.states.remove(state)
|
262
|
-
|
263
|
-
def __contains__(self, state: State):
|
264
|
-
"""
|
265
|
-
Check if a state is in the catcher.
|
266
|
-
|
267
|
-
Args:
|
268
|
-
state (State): The state to check for.
|
269
|
-
|
270
|
-
Returns:
|
271
|
-
bool: True if the state is in the catcher, False otherwise.
|
272
|
-
"""
|
273
|
-
return id(state) in self.state_ids
|
274
|
-
|
275
|
-
|
276
|
-
@contextlib.contextmanager
|
277
|
-
def catch_new_states(
|
278
|
-
state_tag: str = None,
|
279
|
-
state_to_exclude: Filter = Nothing()
|
280
|
-
) -> Generator[StateCatcher, None, None]:
|
281
|
-
"""
|
282
|
-
A context manager that catches and tracks new states created within its scope.
|
283
|
-
|
284
|
-
This function creates a new Catcher object and adds it to the TRACE_CONTEXT's
|
285
|
-
new_state_catcher list. It allows for tracking and managing new states created
|
286
|
-
within the context.
|
287
|
-
|
288
|
-
Args:
|
289
|
-
state_tag (str, optional): A string tag to associate with the caught states.
|
290
|
-
Defaults to None.
|
291
|
-
state_to_exclude (Filter, optional): A filter object to specify which states
|
292
|
-
should be excluded from catching. Defaults to Nothing(), which excludes no states.
|
293
|
-
|
294
|
-
Yields:
|
295
|
-
Catcher: A Catcher object that can be used to access and manage the
|
296
|
-
newly created states within the context.
|
297
|
-
|
298
|
-
Example::
|
299
|
-
|
300
|
-
with catch_new_states("my_tag") as catcher:
|
301
|
-
# Create new states here
|
302
|
-
# They will be caught and tagged with "my_tag"
|
303
|
-
# Access caught states through catcher object
|
304
|
-
"""
|
305
|
-
try:
|
306
|
-
catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
|
307
|
-
TRACE_CONTEXT.new_state_catcher.append(catcher)
|
308
|
-
yield catcher
|
309
|
-
finally:
|
310
|
-
TRACE_CONTEXT.new_state_catcher.pop()
|
311
|
-
|
312
|
-
|
313
136
|
def maybe_state(val: Any) -> Any:
|
314
137
|
"""
|
315
138
|
Extracts the value from a State object if given, otherwise returns the input value.
|
@@ -1660,3 +1483,180 @@ jax.tree_util.register_pytree_with_keys(
|
|
1660
1483
|
_state_ref_unflatten, # type: ignore
|
1661
1484
|
flatten_func=partial(_state_ref_flatten, with_keys=False), # type: ignore
|
1662
1485
|
)
|
1486
|
+
|
1487
|
+
|
1488
|
+
class StateCatcher(PrettyObject):
|
1489
|
+
"""
|
1490
|
+
The catcher to catch and manage new states.
|
1491
|
+
|
1492
|
+
This class provides functionality to collect and tag new State objects.
|
1493
|
+
It ensures that each state is only added once and assigns a tag to each state.
|
1494
|
+
|
1495
|
+
Attributes:
|
1496
|
+
state_tag (str): A string identifier used to tag the caught states.
|
1497
|
+
state_ids (set): A set of state IDs to ensure uniqueness.
|
1498
|
+
states (list): A list to store the caught State objects.
|
1499
|
+
"""
|
1500
|
+
|
1501
|
+
def __init__(
|
1502
|
+
self,
|
1503
|
+
state_tag: str,
|
1504
|
+
state_to_exclude: Filter = Nothing()
|
1505
|
+
):
|
1506
|
+
"""
|
1507
|
+
Initialize a new Catcher instance.
|
1508
|
+
|
1509
|
+
Args:
|
1510
|
+
state_tag (str): The tag to be assigned to caught states.
|
1511
|
+
state_to_exclude (Filter, optional): A filter to exclude states from being caught.
|
1512
|
+
"""
|
1513
|
+
if state_to_exclude is None:
|
1514
|
+
state_to_exclude = Nothing()
|
1515
|
+
self.state_to_exclude = state_to_exclude
|
1516
|
+
self.state_tag = state_tag
|
1517
|
+
self.state_ids = set()
|
1518
|
+
self.states = []
|
1519
|
+
|
1520
|
+
def get_state_values(self) -> List[PyTree]:
|
1521
|
+
"""
|
1522
|
+
Get the values of the caught states.
|
1523
|
+
|
1524
|
+
Returns:
|
1525
|
+
list: A list of values of the caught states.
|
1526
|
+
"""
|
1527
|
+
return [state.value for state in self.states]
|
1528
|
+
|
1529
|
+
def get_states(self) -> List[State]:
|
1530
|
+
"""
|
1531
|
+
Get the caught states.
|
1532
|
+
|
1533
|
+
Returns:
|
1534
|
+
list: A list of the caught states.
|
1535
|
+
"""
|
1536
|
+
return self.states
|
1537
|
+
|
1538
|
+
def append(self, state: State):
|
1539
|
+
"""
|
1540
|
+
Add a new state to the catcher if it hasn't been added before.
|
1541
|
+
|
1542
|
+
This method adds the state to the internal list, records its ID,
|
1543
|
+
and assigns the catcher's tag to the state.
|
1544
|
+
|
1545
|
+
Args:
|
1546
|
+
state (State): The State object to be added.
|
1547
|
+
"""
|
1548
|
+
if self.state_to_exclude((), state):
|
1549
|
+
return
|
1550
|
+
if id(state) not in self.state_ids:
|
1551
|
+
self.state_ids.add(id(state))
|
1552
|
+
self.states.append(state)
|
1553
|
+
state.tag = self.state_tag
|
1554
|
+
|
1555
|
+
def __iter__(self):
|
1556
|
+
"""
|
1557
|
+
Allow iteration over the caught states.
|
1558
|
+
|
1559
|
+
Returns:
|
1560
|
+
iterator: An iterator over the list of caught states.
|
1561
|
+
"""
|
1562
|
+
return iter(self.states)
|
1563
|
+
|
1564
|
+
def __len__(self):
|
1565
|
+
"""
|
1566
|
+
Return the number of caught states.
|
1567
|
+
|
1568
|
+
Returns:
|
1569
|
+
int: The number of caught states.
|
1570
|
+
"""
|
1571
|
+
return len(self.states)
|
1572
|
+
|
1573
|
+
def __getitem__(self, index):
|
1574
|
+
"""
|
1575
|
+
Get a state by index.
|
1576
|
+
|
1577
|
+
Args:
|
1578
|
+
index (int): The index of the state to retrieve.
|
1579
|
+
|
1580
|
+
Returns:
|
1581
|
+
State: The state at the specified index.
|
1582
|
+
"""
|
1583
|
+
return self.states[index]
|
1584
|
+
|
1585
|
+
def clear(self):
|
1586
|
+
"""
|
1587
|
+
Clear all caught states.
|
1588
|
+
"""
|
1589
|
+
self.state_ids.clear()
|
1590
|
+
self.states.clear()
|
1591
|
+
|
1592
|
+
def get_by_tag(self, tag: str):
|
1593
|
+
"""
|
1594
|
+
Get all states with a specific tag.
|
1595
|
+
|
1596
|
+
Args:
|
1597
|
+
tag (str): The tag to filter by.
|
1598
|
+
|
1599
|
+
Returns:
|
1600
|
+
list: A list of states with the specified tag.
|
1601
|
+
"""
|
1602
|
+
return [state for state in self.states if state.tag == tag]
|
1603
|
+
|
1604
|
+
def remove(self, state: State):
|
1605
|
+
"""
|
1606
|
+
Remove a specific state from the catcher.
|
1607
|
+
|
1608
|
+
Args:
|
1609
|
+
state (State): The state to remove.
|
1610
|
+
"""
|
1611
|
+
if id(state) in self.state_ids:
|
1612
|
+
self.state_ids.remove(id(state))
|
1613
|
+
self.states.remove(state)
|
1614
|
+
|
1615
|
+
def __contains__(self, state: State):
|
1616
|
+
"""
|
1617
|
+
Check if a state is in the catcher.
|
1618
|
+
|
1619
|
+
Args:
|
1620
|
+
state (State): The state to check for.
|
1621
|
+
|
1622
|
+
Returns:
|
1623
|
+
bool: True if the state is in the catcher, False otherwise.
|
1624
|
+
"""
|
1625
|
+
return id(state) in self.state_ids
|
1626
|
+
|
1627
|
+
|
1628
|
+
@contextlib.contextmanager
|
1629
|
+
def catch_new_states(
|
1630
|
+
state_tag: str = None,
|
1631
|
+
state_to_exclude: Filter = Nothing()
|
1632
|
+
) -> Generator[StateCatcher, None, None]:
|
1633
|
+
"""
|
1634
|
+
A context manager that catches and tracks new states created within its scope.
|
1635
|
+
|
1636
|
+
This function creates a new Catcher object and adds it to the TRACE_CONTEXT's
|
1637
|
+
new_state_catcher list. It allows for tracking and managing new states created
|
1638
|
+
within the context.
|
1639
|
+
|
1640
|
+
Args:
|
1641
|
+
state_tag (str, optional): A string tag to associate with the caught states.
|
1642
|
+
Defaults to None.
|
1643
|
+
state_to_exclude (Filter, optional): A filter object to specify which states
|
1644
|
+
should be excluded from catching. Defaults to Nothing(), which excludes no states.
|
1645
|
+
|
1646
|
+
Yields:
|
1647
|
+
Catcher: A Catcher object that can be used to access and manage the
|
1648
|
+
newly created states within the context.
|
1649
|
+
|
1650
|
+
Example::
|
1651
|
+
|
1652
|
+
with catch_new_states("my_tag") as catcher:
|
1653
|
+
# Create new states here
|
1654
|
+
# They will be caught and tagged with "my_tag"
|
1655
|
+
# Access caught states through catcher object
|
1656
|
+
"""
|
1657
|
+
try:
|
1658
|
+
catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
|
1659
|
+
TRACE_CONTEXT.new_state_catcher.append(catcher)
|
1660
|
+
yield catcher
|
1661
|
+
finally:
|
1662
|
+
TRACE_CONTEXT.new_state_catcher.pop()
|
brainstate/_utils.py
CHANGED
brainstate/augment/_autograd.py
CHANGED
@@ -27,8 +27,6 @@ The wrapped gradient transformations here are made possible by using the followi
|
|
27
27
|
|
28
28
|
"""
|
29
29
|
|
30
|
-
from __future__ import annotations
|
31
|
-
|
32
30
|
from functools import wraps, partial
|
33
31
|
from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
|
34
32
|
|
brainstate/augment/_mapping.py
CHANGED
@@ -13,8 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import functools
|
19
17
|
from typing import (
|
20
18
|
Any,
|
@@ -33,6 +31,7 @@ from typing import (
|
|
33
31
|
import jax
|
34
32
|
from jax.interpreters.batching import BatchTracer
|
35
33
|
|
34
|
+
from brainstate._compatible_import import Device
|
36
35
|
from brainstate._state import State, catch_new_states
|
37
36
|
from brainstate.compile import scan, StatefulFunction
|
38
37
|
from brainstate.random import RandomState, DEFAULT
|
@@ -700,7 +699,7 @@ def pmap(
|
|
700
699
|
in_axes: Any = 0,
|
701
700
|
out_axes: Any = 0,
|
702
701
|
static_broadcasted_argnums: int | Iterable[int] = (),
|
703
|
-
devices: Optional[Sequence[
|
702
|
+
devices: Optional[Sequence[Device]] = None, # noqa: F811
|
704
703
|
backend: Optional[str] = None,
|
705
704
|
axis_size: Optional[int] = None,
|
706
705
|
donate_argnums: int | Iterable[int] = (),
|
brainstate/augment/_random.py
CHANGED
brainstate/compile/_error_if.py
CHANGED
brainstate/compile/_jit.py
CHANGED
@@ -13,16 +13,14 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import functools
|
19
17
|
from collections.abc import Iterable, Sequence
|
20
18
|
from typing import (Any, Callable, Union)
|
21
19
|
|
22
20
|
import jax
|
23
21
|
from jax._src import sharding_impls
|
24
|
-
from jax.lib import xla_client as xc
|
25
22
|
|
23
|
+
from brainstate._compatible_import import Device
|
26
24
|
from brainstate._utils import set_module_as
|
27
25
|
from brainstate.typing import Missing
|
28
26
|
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
|
@@ -94,8 +92,8 @@ def _get_jitted_fun(
|
|
94
92
|
read_state_vals = state_trace.get_read_state_values(True)
|
95
93
|
|
96
94
|
# call the jitted function
|
97
|
-
# print('Running ...')
|
98
95
|
write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
|
96
|
+
|
99
97
|
# write the state values back to the states
|
100
98
|
write_back_state_values(state_trace, read_state_vals, write_state_vals)
|
101
99
|
return outs
|
@@ -106,8 +104,11 @@ def _get_jitted_fun(
|
|
106
104
|
"""
|
107
105
|
# clear the cache of the stateful function
|
108
106
|
fun.clear_cache()
|
109
|
-
|
110
|
-
|
107
|
+
try:
|
108
|
+
# clear the cache of the jitted function
|
109
|
+
jit_fun.clear_cache()
|
110
|
+
except AttributeError:
|
111
|
+
pass
|
111
112
|
|
112
113
|
def eval_shape():
|
113
114
|
raise NotImplementedError
|
@@ -165,7 +166,7 @@ def _get_jitted_fun(
|
|
165
166
|
# compile the jitted function
|
166
167
|
jitted_fun.compile = compile
|
167
168
|
|
168
|
-
# trace the jitted
|
169
|
+
# trace the jitted function
|
169
170
|
jitted_fun.trace = trace
|
170
171
|
|
171
172
|
return jitted_fun
|
@@ -180,7 +181,7 @@ def jit(
|
|
180
181
|
donate_argnums: int | Sequence[int] | None = None,
|
181
182
|
donate_argnames: str | Iterable[str] | None = None,
|
182
183
|
keep_unused: bool = False,
|
183
|
-
device:
|
184
|
+
device: Device | None = None,
|
184
185
|
backend: str | None = None,
|
185
186
|
inline: bool = False,
|
186
187
|
abstracted_axes: Any | None = None,
|
@@ -13,8 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import math
|
19
17
|
from functools import wraps
|
20
18
|
from typing import Callable, Optional, TypeVar, Tuple, Any
|
@@ -51,8 +51,6 @@ function.
|
|
51
51
|
|
52
52
|
"""
|
53
53
|
|
54
|
-
from __future__ import annotations
|
55
|
-
|
56
54
|
import functools
|
57
55
|
import inspect
|
58
56
|
import operator
|
@@ -64,7 +62,7 @@ import jax
|
|
64
62
|
from jax._src import source_info_util
|
65
63
|
from jax._src.linear_util import annotate
|
66
64
|
from jax._src.traceback_util import api_boundary
|
67
|
-
from jax.api_util import shaped_abstractify
|
65
|
+
from jax.api_util import shaped_abstractify, debug_info
|
68
66
|
from jax.extend.linear_util import transformation_with_aux, wrap_init
|
69
67
|
from jax.interpreters import partial_eval as pe
|
70
68
|
|
@@ -745,7 +743,7 @@ def _make_jaxpr(
|
|
745
743
|
@wraps(fun)
|
746
744
|
@api_boundary
|
747
745
|
def make_jaxpr_f(*args, **kwargs):
|
748
|
-
f = wrap_init(fun)
|
746
|
+
f = wrap_init(fun, debug_info=debug_info('make_jaxpr', fun, args, kwargs))
|
749
747
|
if static_argnums:
|
750
748
|
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
|
751
749
|
f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
|
@@ -754,12 +752,12 @@ def _make_jaxpr(
|
|
754
752
|
f, out_tree = _flatten_fun(f, in_tree)
|
755
753
|
f = annotate(f, in_type)
|
756
754
|
if jax.__version_info__ < (0, 5, 0):
|
757
|
-
|
755
|
+
debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
|
758
756
|
with ExitStack() as stack:
|
759
757
|
if axis_env is not None:
|
760
758
|
stack.enter_context(extend_axis_env_nd(axis_env))
|
761
759
|
if jax.__version_info__ < (0, 5, 0):
|
762
|
-
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=
|
760
|
+
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
|
763
761
|
else:
|
764
762
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
|
765
763
|
closed_jaxpr = ClosedJaxpr(jaxpr, consts)
|
brainstate/compile/_unvmap.py
CHANGED
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from __future__ import annotations
|
16
15
|
|
17
16
|
import jax
|
18
17
|
import jax.core
|
brainstate/compile/_util.py
CHANGED
brainstate/environ.py
CHANGED