brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +7 -3
  3. brainstate/_state.py +177 -177
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_eval_shape.py +0 -2
  7. brainstate/augment/_mapping.py +2 -3
  8. brainstate/augment/_random.py +0 -2
  9. brainstate/compile/_ad_checkpoint.py +0 -2
  10. brainstate/compile/_conditions.py +0 -2
  11. brainstate/compile/_error_if.py +0 -2
  12. brainstate/compile/_jit.py +9 -8
  13. brainstate/compile/_loop_collect_return.py +0 -2
  14. brainstate/compile/_loop_no_collection.py +0 -2
  15. brainstate/compile/_make_jaxpr.py +4 -6
  16. brainstate/compile/_progress_bar.py +0 -1
  17. brainstate/compile/_unvmap.py +0 -1
  18. brainstate/compile/_util.py +0 -2
  19. brainstate/environ.py +0 -2
  20. brainstate/functional/_activations.py +0 -2
  21. brainstate/functional/_normalization.py +0 -2
  22. brainstate/functional/_others.py +0 -2
  23. brainstate/functional/_spikes.py +0 -1
  24. brainstate/graph/_graph_node.py +1 -3
  25. brainstate/graph/_graph_operation.py +4 -2
  26. brainstate/init/_base.py +0 -2
  27. brainstate/init/_generic.py +0 -1
  28. brainstate/init/_random_inits.py +0 -1
  29. brainstate/init/_regular_inits.py +0 -2
  30. brainstate/mixin.py +0 -2
  31. brainstate/nn/_collective_ops.py +0 -3
  32. brainstate/nn/_common.py +0 -2
  33. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  34. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  35. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  36. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  37. brainstate/nn/_dyn_impl/_readout.py +0 -1
  38. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  39. brainstate/nn/_dynamics/_projection_base.py +0 -1
  40. brainstate/nn/_dynamics/_state_delay.py +0 -2
  41. brainstate/nn/_dynamics/_synouts.py +0 -2
  42. brainstate/nn/_elementwise/_dropout.py +0 -2
  43. brainstate/nn/_elementwise/_elementwise.py +0 -2
  44. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  45. brainstate/nn/_event/_linear_mv.py +0 -2
  46. brainstate/nn/_exp_euler.py +0 -2
  47. brainstate/nn/_interaction/_conv.py +0 -2
  48. brainstate/nn/_interaction/_embedding.py +0 -1
  49. brainstate/nn/_interaction/_linear.py +0 -2
  50. brainstate/nn/_interaction/_normalizations.py +0 -2
  51. brainstate/nn/_interaction/_poolings.py +0 -2
  52. brainstate/nn/_module.py +0 -1
  53. brainstate/nn/metrics.py +0 -2
  54. brainstate/optim/_base.py +0 -2
  55. brainstate/optim/_lr_scheduler.py +0 -1
  56. brainstate/optim/_optax_optimizer.py +0 -2
  57. brainstate/optim/_sgd_optimizer.py +0 -1
  58. brainstate/random/_rand_funs.py +0 -1
  59. brainstate/random/_rand_seed.py +0 -1
  60. brainstate/random/_rand_state.py +0 -1
  61. brainstate/surrogate.py +0 -1
  62. brainstate/typing.py +0 -2
  63. brainstate/util/_caller.py +4 -6
  64. brainstate/util/_others.py +0 -2
  65. brainstate/util/_pretty_pytree.py +201 -150
  66. brainstate/util/_pretty_repr.py +0 -2
  67. brainstate/util/_pretty_table.py +57 -3
  68. brainstate/util/_scaling.py +0 -2
  69. brainstate/util/_struct.py +0 -2
  70. brainstate/util/filter.py +0 -2
  71. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.1.dist-info}/METADATA +11 -5
  72. brainstate-0.1.1.dist-info/RECORD +133 -0
  73. brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
  74. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.1.dist-info}/LICENSE +0 -0
  75. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.1.dist-info}/WHEEL +0 -0
  76. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.1.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
  A ``State``-based Transformation System for Program Compilation and Augmentation
18
18
  """
19
19
 
20
- __version__ = "0.1.0"
20
+ __version__ = "0.1.1"
21
21
 
22
22
  from . import augment
23
23
  from . import compile
brainstate/_compatible_import.py CHANGED
@@ -35,8 +35,8 @@ __all__ = [
35
35
  'safe_map',
36
36
  'safe_zip',
37
37
  'unzip2',
38
- 'unzip3',
39
38
  'wraps',
39
+ 'Device',
40
40
  ]
41
41
 
42
42
  T = TypeVar("T")
@@ -48,6 +48,11 @@ brainevent_installed = importlib.util.find_spec('brainevent') is not None
48
48
 
49
49
  from jax.core import get_aval, Tracer
50
50
 
51
+ if jax.__version_info__ < (0, 5, 0):
52
+ from jax.lib.xla_client import Device
53
+ else:
54
+ from jax import Device
55
+
51
56
  if jax.__version_info__ < (0, 4, 38):
52
57
  from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
53
58
  else:
@@ -84,8 +89,7 @@ else:
84
89
  return list(zip(*args))
85
90
 
86
91
 
87
- def unzip2(xys: Iterable[tuple[T1, T2]]
88
- ) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
92
+ def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
89
93
  """Unzip sequence of length-2 tuples into two tuples."""
90
94
  # Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
91
95
  # is too permissive about inputs, and does not guarantee a length-2 output.
brainstate/_state.py CHANGED
@@ -133,183 +133,6 @@ def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
133
133
  TRACE_CONTEXT.tree_check.pop()
134
134
 
135
135
 
136
- class StateCatcher(PrettyObject):
137
- """
138
- The catcher to catch and manage new states.
139
-
140
- This class provides functionality to collect and tag new State objects.
141
- It ensures that each state is only added once and assigns a tag to each state.
142
-
143
- Attributes:
144
- state_tag (str): A string identifier used to tag the caught states.
145
- state_ids (set): A set of state IDs to ensure uniqueness.
146
- states (list): A list to store the caught State objects.
147
- """
148
-
149
- def __init__(
150
- self,
151
- state_tag: str,
152
- state_to_exclude: Filter = Nothing()
153
- ):
154
- """
155
- Initialize a new Catcher instance.
156
-
157
- Args:
158
- state_tag (str): The tag to be assigned to caught states.
159
- state_to_exclude (Filter, optional): A filter to exclude states from being caught.
160
- """
161
- if state_to_exclude is None:
162
- state_to_exclude = Nothing()
163
- self.state_to_exclude = state_to_exclude
164
- self.state_tag = state_tag
165
- self.state_ids = set()
166
- self.states = []
167
-
168
- def get_state_values(self) -> List[PyTree]:
169
- """
170
- Get the values of the caught states.
171
-
172
- Returns:
173
- list: A list of values of the caught states.
174
- """
175
- return [state.value for state in self.states]
176
-
177
- def get_states(self) -> List[State]:
178
- """
179
- Get the caught states.
180
-
181
- Returns:
182
- list: A list of the caught states.
183
- """
184
- return self.states
185
-
186
- def append(self, state: State):
187
- """
188
- Add a new state to the catcher if it hasn't been added before.
189
-
190
- This method adds the state to the internal list, records its ID,
191
- and assigns the catcher's tag to the state.
192
-
193
- Args:
194
- state (State): The State object to be added.
195
- """
196
- if self.state_to_exclude((), state):
197
- return
198
- if id(state) not in self.state_ids:
199
- self.state_ids.add(id(state))
200
- self.states.append(state)
201
- state.tag = self.state_tag
202
-
203
- def __iter__(self):
204
- """
205
- Allow iteration over the caught states.
206
-
207
- Returns:
208
- iterator: An iterator over the list of caught states.
209
- """
210
- return iter(self.states)
211
-
212
- def __len__(self):
213
- """
214
- Return the number of caught states.
215
-
216
- Returns:
217
- int: The number of caught states.
218
- """
219
- return len(self.states)
220
-
221
- def __getitem__(self, index):
222
- """
223
- Get a state by index.
224
-
225
- Args:
226
- index (int): The index of the state to retrieve.
227
-
228
- Returns:
229
- State: The state at the specified index.
230
- """
231
- return self.states[index]
232
-
233
- def clear(self):
234
- """
235
- Clear all caught states.
236
- """
237
- self.state_ids.clear()
238
- self.states.clear()
239
-
240
- def get_by_tag(self, tag: str):
241
- """
242
- Get all states with a specific tag.
243
-
244
- Args:
245
- tag (str): The tag to filter by.
246
-
247
- Returns:
248
- list: A list of states with the specified tag.
249
- """
250
- return [state for state in self.states if state.tag == tag]
251
-
252
- def remove(self, state: State):
253
- """
254
- Remove a specific state from the catcher.
255
-
256
- Args:
257
- state (State): The state to remove.
258
- """
259
- if id(state) in self.state_ids:
260
- self.state_ids.remove(id(state))
261
- self.states.remove(state)
262
-
263
- def __contains__(self, state: State):
264
- """
265
- Check if a state is in the catcher.
266
-
267
- Args:
268
- state (State): The state to check for.
269
-
270
- Returns:
271
- bool: True if the state is in the catcher, False otherwise.
272
- """
273
- return id(state) in self.state_ids
274
-
275
-
276
- @contextlib.contextmanager
277
- def catch_new_states(
278
- state_tag: str = None,
279
- state_to_exclude: Filter = Nothing()
280
- ) -> Generator[StateCatcher, None, None]:
281
- """
282
- A context manager that catches and tracks new states created within its scope.
283
-
284
- This function creates a new Catcher object and adds it to the TRACE_CONTEXT's
285
- new_state_catcher list. It allows for tracking and managing new states created
286
- within the context.
287
-
288
- Args:
289
- state_tag (str, optional): A string tag to associate with the caught states.
290
- Defaults to None.
291
- state_to_exclude (Filter, optional): A filter object to specify which states
292
- should be excluded from catching. Defaults to Nothing(), which excludes no states.
293
-
294
- Yields:
295
- Catcher: A Catcher object that can be used to access and manage the
296
- newly created states within the context.
297
-
298
- Example::
299
-
300
- with catch_new_states("my_tag") as catcher:
301
- # Create new states here
302
- # They will be caught and tagged with "my_tag"
303
- # Access caught states through catcher object
304
- """
305
- try:
306
- catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
307
- TRACE_CONTEXT.new_state_catcher.append(catcher)
308
- yield catcher
309
- finally:
310
- TRACE_CONTEXT.new_state_catcher.pop()
311
-
312
-
313
136
  def maybe_state(val: Any) -> Any:
314
137
  """
315
138
  Extracts the value from a State object if given, otherwise returns the input value.
@@ -1660,3 +1483,180 @@ jax.tree_util.register_pytree_with_keys(
1660
1483
  _state_ref_unflatten, # type: ignore
1661
1484
  flatten_func=partial(_state_ref_flatten, with_keys=False), # type: ignore
1662
1485
  )
1486
+
1487
+
1488
+ class StateCatcher(PrettyObject):
1489
+ """
1490
+ The catcher to catch and manage new states.
1491
+
1492
+ This class provides functionality to collect and tag new State objects.
1493
+ It ensures that each state is only added once and assigns a tag to each state.
1494
+
1495
+ Attributes:
1496
+ state_tag (str): A string identifier used to tag the caught states.
1497
+ state_ids (set): A set of state IDs to ensure uniqueness.
1498
+ states (list): A list to store the caught State objects.
1499
+ """
1500
+
1501
+ def __init__(
1502
+ self,
1503
+ state_tag: str,
1504
+ state_to_exclude: Filter = Nothing()
1505
+ ):
1506
+ """
1507
+ Initialize a new Catcher instance.
1508
+
1509
+ Args:
1510
+ state_tag (str): The tag to be assigned to caught states.
1511
+ state_to_exclude (Filter, optional): A filter to exclude states from being caught.
1512
+ """
1513
+ if state_to_exclude is None:
1514
+ state_to_exclude = Nothing()
1515
+ self.state_to_exclude = state_to_exclude
1516
+ self.state_tag = state_tag
1517
+ self.state_ids = set()
1518
+ self.states = []
1519
+
1520
+ def get_state_values(self) -> List[PyTree]:
1521
+ """
1522
+ Get the values of the caught states.
1523
+
1524
+ Returns:
1525
+ list: A list of values of the caught states.
1526
+ """
1527
+ return [state.value for state in self.states]
1528
+
1529
+ def get_states(self) -> List[State]:
1530
+ """
1531
+ Get the caught states.
1532
+
1533
+ Returns:
1534
+ list: A list of the caught states.
1535
+ """
1536
+ return self.states
1537
+
1538
+ def append(self, state: State):
1539
+ """
1540
+ Add a new state to the catcher if it hasn't been added before.
1541
+
1542
+ This method adds the state to the internal list, records its ID,
1543
+ and assigns the catcher's tag to the state.
1544
+
1545
+ Args:
1546
+ state (State): The State object to be added.
1547
+ """
1548
+ if self.state_to_exclude((), state):
1549
+ return
1550
+ if id(state) not in self.state_ids:
1551
+ self.state_ids.add(id(state))
1552
+ self.states.append(state)
1553
+ state.tag = self.state_tag
1554
+
1555
+ def __iter__(self):
1556
+ """
1557
+ Allow iteration over the caught states.
1558
+
1559
+ Returns:
1560
+ iterator: An iterator over the list of caught states.
1561
+ """
1562
+ return iter(self.states)
1563
+
1564
+ def __len__(self):
1565
+ """
1566
+ Return the number of caught states.
1567
+
1568
+ Returns:
1569
+ int: The number of caught states.
1570
+ """
1571
+ return len(self.states)
1572
+
1573
+ def __getitem__(self, index):
1574
+ """
1575
+ Get a state by index.
1576
+
1577
+ Args:
1578
+ index (int): The index of the state to retrieve.
1579
+
1580
+ Returns:
1581
+ State: The state at the specified index.
1582
+ """
1583
+ return self.states[index]
1584
+
1585
+ def clear(self):
1586
+ """
1587
+ Clear all caught states.
1588
+ """
1589
+ self.state_ids.clear()
1590
+ self.states.clear()
1591
+
1592
+ def get_by_tag(self, tag: str):
1593
+ """
1594
+ Get all states with a specific tag.
1595
+
1596
+ Args:
1597
+ tag (str): The tag to filter by.
1598
+
1599
+ Returns:
1600
+ list: A list of states with the specified tag.
1601
+ """
1602
+ return [state for state in self.states if state.tag == tag]
1603
+
1604
+ def remove(self, state: State):
1605
+ """
1606
+ Remove a specific state from the catcher.
1607
+
1608
+ Args:
1609
+ state (State): The state to remove.
1610
+ """
1611
+ if id(state) in self.state_ids:
1612
+ self.state_ids.remove(id(state))
1613
+ self.states.remove(state)
1614
+
1615
+ def __contains__(self, state: State):
1616
+ """
1617
+ Check if a state is in the catcher.
1618
+
1619
+ Args:
1620
+ state (State): The state to check for.
1621
+
1622
+ Returns:
1623
+ bool: True if the state is in the catcher, False otherwise.
1624
+ """
1625
+ return id(state) in self.state_ids
1626
+
1627
+
1628
+ @contextlib.contextmanager
1629
+ def catch_new_states(
1630
+ state_tag: str = None,
1631
+ state_to_exclude: Filter = Nothing()
1632
+ ) -> Generator[StateCatcher, None, None]:
1633
+ """
1634
+ A context manager that catches and tracks new states created within its scope.
1635
+
1636
+ This function creates a new Catcher object and adds it to the TRACE_CONTEXT's
1637
+ new_state_catcher list. It allows for tracking and managing new states created
1638
+ within the context.
1639
+
1640
+ Args:
1641
+ state_tag (str, optional): A string tag to associate with the caught states.
1642
+ Defaults to None.
1643
+ state_to_exclude (Filter, optional): A filter object to specify which states
1644
+ should be excluded from catching. Defaults to Nothing(), which excludes no states.
1645
+
1646
+ Yields:
1647
+ Catcher: A Catcher object that can be used to access and manage the
1648
+ newly created states within the context.
1649
+
1650
+ Example::
1651
+
1652
+ with catch_new_states("my_tag") as catcher:
1653
+ # Create new states here
1654
+ # They will be caught and tagged with "my_tag"
1655
+ # Access caught states through catcher object
1656
+ """
1657
+ try:
1658
+ catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
1659
+ TRACE_CONTEXT.new_state_catcher.append(catcher)
1660
+ yield catcher
1661
+ finally:
1662
+ TRACE_CONTEXT.new_state_catcher.pop()
brainstate/_utils.py CHANGED
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import warnings
19
18
 
brainstate/augment/_autograd.py CHANGED
@@ -27,8 +27,6 @@ The wrapped gradient transformations here are made possible by using the followi
27
27
 
28
28
  """
29
29
 
30
- from __future__ import annotations
31
-
32
30
  from functools import wraps, partial
33
31
  from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
34
32
 
brainstate/augment/_eval_shape.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import functools
19
17
  from typing import Any, TypeVar, Callable, Sequence, Union
20
18
 
brainstate/augment/_mapping.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import functools
19
17
  from typing import (
20
18
  Any,
@@ -33,6 +31,7 @@ from typing import (
33
31
  import jax
34
32
  from jax.interpreters.batching import BatchTracer
35
33
 
34
+ from brainstate._compatible_import import Device
36
35
  from brainstate._state import State, catch_new_states
37
36
  from brainstate.compile import scan, StatefulFunction
38
37
  from brainstate.random import RandomState, DEFAULT
@@ -700,7 +699,7 @@ def pmap(
700
699
  in_axes: Any = 0,
701
700
  out_axes: Any = 0,
702
701
  static_broadcasted_argnums: int | Iterable[int] = (),
703
- devices: Optional[Sequence[jax.Device]] = None, # noqa: F811
702
+ devices: Optional[Sequence[Device]] = None, # noqa: F811
704
703
  backend: Optional[str] = None,
705
704
  axis_size: Optional[int] = None,
706
705
  donate_argnums: int | Iterable[int] = (),
brainstate/augment/_random.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import functools
19
17
  from typing import Callable, Sequence, Union
20
18
 
brainstate/compile/_ad_checkpoint.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import functools
19
17
  from typing import Callable, Tuple, Union
20
18
 
brainstate/compile/_conditions.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from collections.abc import Callable, Sequence
19
17
 
20
18
  import jax
brainstate/compile/_error_if.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import functools
19
17
  from functools import partial
20
18
  from typing import Callable, Union
brainstate/compile/_jit.py CHANGED
@@ -13,16 +13,14 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import functools
19
17
  from collections.abc import Iterable, Sequence
20
18
  from typing import (Any, Callable, Union)
21
19
 
22
20
  import jax
23
21
  from jax._src import sharding_impls
24
- from jax.lib import xla_client as xc
25
22
 
23
+ from brainstate._compatible_import import Device
26
24
  from brainstate._utils import set_module_as
27
25
  from brainstate.typing import Missing
28
26
  from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
@@ -94,8 +92,8 @@ def _get_jitted_fun(
94
92
  read_state_vals = state_trace.get_read_state_values(True)
95
93
 
96
94
  # call the jitted function
97
- # print('Running ...')
98
95
  write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
96
+
99
97
  # write the state values back to the states
100
98
  write_back_state_values(state_trace, read_state_vals, write_state_vals)
101
99
  return outs
@@ -106,8 +104,11 @@ def _get_jitted_fun(
106
104
  """
107
105
  # clear the cache of the stateful function
108
106
  fun.clear_cache()
109
- # clear the cache of the jitted function
110
- jit_fun.clear_cache()
107
+ try:
108
+ # clear the cache of the jitted function
109
+ jit_fun.clear_cache()
110
+ except AttributeError:
111
+ pass
111
112
 
112
113
  def eval_shape():
113
114
  raise NotImplementedError
@@ -165,7 +166,7 @@ def _get_jitted_fun(
165
166
  # compile the jitted function
166
167
  jitted_fun.compile = compile
167
168
 
168
- # trace the jitted
169
+ # trace the jitted function
169
170
  jitted_fun.trace = trace
170
171
 
171
172
  return jitted_fun
@@ -180,7 +181,7 @@ def jit(
180
181
  donate_argnums: int | Sequence[int] | None = None,
181
182
  donate_argnames: str | Iterable[str] | None = None,
182
183
  keep_unused: bool = False,
183
- device: xc.Device | None = None,
184
+ device: Device | None = None,
184
185
  backend: str | None = None,
185
186
  inline: bool = False,
186
187
  abstracted_axes: Any | None = None,
brainstate/compile/_loop_collect_return.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import math
19
17
  from functools import wraps
20
18
  from typing import Callable, Optional, TypeVar, Tuple, Any
brainstate/compile/_loop_no_collection.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import math
19
17
  from typing import Any, Callable, TypeVar
20
18
 
brainstate/compile/_make_jaxpr.py CHANGED
@@ -51,8 +51,6 @@ function.
51
51
 
52
52
  """
53
53
 
54
- from __future__ import annotations
55
-
56
54
  import functools
57
55
  import inspect
58
56
  import operator
@@ -64,7 +62,7 @@ import jax
64
62
  from jax._src import source_info_util
65
63
  from jax._src.linear_util import annotate
66
64
  from jax._src.traceback_util import api_boundary
67
- from jax.api_util import shaped_abstractify
65
+ from jax.api_util import shaped_abstractify, debug_info
68
66
  from jax.extend.linear_util import transformation_with_aux, wrap_init
69
67
  from jax.interpreters import partial_eval as pe
70
68
 
@@ -745,7 +743,7 @@ def _make_jaxpr(
745
743
  @wraps(fun)
746
744
  @api_boundary
747
745
  def make_jaxpr_f(*args, **kwargs):
748
- f = wrap_init(fun)
746
+ f = wrap_init(fun, debug_info=debug_info('make_jaxpr', fun, args, kwargs))
749
747
  if static_argnums:
750
748
  dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
751
749
  f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
@@ -754,12 +752,12 @@ def _make_jaxpr(
754
752
  f, out_tree = _flatten_fun(f, in_tree)
755
753
  f = annotate(f, in_type)
756
754
  if jax.__version_info__ < (0, 5, 0):
757
- debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
755
+ debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
758
756
  with ExitStack() as stack:
759
757
  if axis_env is not None:
760
758
  stack.enter_context(extend_axis_env_nd(axis_env))
761
759
  if jax.__version_info__ < (0, 5, 0):
762
- jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
760
+ jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
763
761
  else:
764
762
  jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
765
763
  closed_jaxpr = ClosedJaxpr(jaxpr, consts)
brainstate/compile/_progress_bar.py CHANGED
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import copy
19
18
  import importlib.util
brainstate/compile/_unvmap.py CHANGED
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from __future__ import annotations
16
15
 
17
16
  import jax
18
17
  import jax.core
brainstate/compile/_util.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from functools import wraps
19
17
  from typing import Sequence, Tuple
20
18
 
brainstate/environ.py CHANGED
@@ -15,8 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
-
20
18
  import contextlib
21
19
  import dataclasses
22
20
  import functools
brainstate/functional/_activations.py CHANGED
@@ -18,8 +18,6 @@
18
18
  Shared neural network activations and other functions.
19
19
  """
20
20
 
21
- from __future__ import annotations
22
-
23
21
  from typing import Any, Union, Sequence
24
22
 
25
23
  import brainunit as u
brainstate/functional/_normalization.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from typing import Optional, Union
19
17
 
20
18
  import brainunit as u
brainstate/functional/_others.py CHANGED
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from functools import partial
19
17
 
20
18
  import jax