brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240612__py2.py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
brainstate/_module.py CHANGED
@@ -46,7 +46,6 @@ For handling the delays:
 
 """
 
-import inspect
 import math
 import numbers
 from collections import namedtuple
@@ -58,12 +57,12 @@ import jax.numpy as jnp
 import numpy as np
 
 from . import environ
-from ._utils import set_module_as
 from ._state import State, StateDictManager, visible_state_dict
-from .util import unique_name, DictManager, get_unique_name, DotDict
+from ._utils import set_module_as
 from .math import get_dtype
 from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
 from .transform._jit_error import jit_error
+from .util import unique_name, DictManager, get_unique_name
 
 Shape = Union[int, Sequence[int]]
 PyTree = Any
@@ -92,7 +91,7 @@ __all__ = [
   'call_order',
 
   # state processing
-  'init_states', 'load_states', 'save_states', 'assign_state_values',
+  'init_states', 'reset_states', 'load_states', 'save_states', 'assign_state_values',
 ]
 
 
@@ -271,6 +270,12 @@ class Module(object):
     """
     pass
 
+  def reset_state(self, *args, **kwargs):
+    """
+    State resetting function.
+    """
+    pass
+
   def save_state(self, **kwargs) -> Dict:
     """Save states as a dictionary. """
     return self.states(include_self=True, level=0, method='absolute')
@@ -1115,6 +1120,12 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
     fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
     self.history = State(jax.tree.map(fun, self.target_info))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    if batch_size is not None:
+      assert self.mode.has(Batching), 'The mode should have Batching behavior when batch_size is not None.'
+    fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
+    self.history.value = jax.tree.map(fun, self.target_info)
+
   def register_entry(
       self,
       entry: str,
@@ -1344,7 +1355,7 @@ def call_order(level: int = 0):
 @set_module_as('brainstate')
 def init_states(target: Module, *args, **kwargs) -> Module:
   """
-  Reset states of all children nodes in the given target.
+  Initialize states of all children nodes in the given target.
 
   Args:
     target: The target Module.
@@ -1368,6 +1379,33 @@ def init_states(target: Module, *args, **kwargs) -> Module:
   return target
 
 
+@set_module_as('brainstate')
+def reset_states(target: Module, *args, **kwargs) -> Module:
+  """
+  Reset states of all children nodes in the given target.
+
+  Args:
+    target: The target Module.
+
+  Returns:
+    The target Module.
+  """
+  nodes_with_order = []
+
+  # reset the nodes whose `reset_state` has no `call_order`
+  for node in list(target.nodes().values()):
+    if not hasattr(node.reset_state, 'call_order'):
+      node.reset_state(*args, **kwargs)
+    else:
+      nodes_with_order.append(node)
+
+  # reset the node's states
+  for node in sorted(nodes_with_order, key=lambda x: x.reset_state.call_order):
+    node.reset_state(*args, **kwargs)
+
+  return target
+
+
 @set_module_as('brainstate')
 def load_states(target: Module, state_dict: Dict, **kwargs):
   """Copy parameters and buffers from :attr:`state_dict` into
brainstate/_state.py CHANGED
@@ -59,6 +59,23 @@ _global_context_to_check_state_tree = [False]
 def check_state_value_tree() -> None:
   """
   The context manager to check whether the tree structure of the state value keeps consistent.
+
+  Once a :py:class:`~.State` is created, the tree structure of the value is fixed. By default,
+  the tree structure of the value is not checked, to avoid repeated evaluation.
+  If you want the tree structure of the value to be checked when a new value is assigned,
+  you can use this context manager.
+
+  Example::
+
+    ```python
+    state = brainstate.ShortTermState(jnp.zeros((2, 3)))
+    with check_state_value_tree():
+      state.value = jnp.zeros((2, 3))
+
+      # The following code will raise an error.
+      state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
+    ```
+
   """
   try:
     _global_context_to_check_state_tree.append(True)
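A runnable version of the docstring example above, assuming `check_state_value_tree` is re-exported at the package top level; the structure check is only active inside the context manager:

```python
import jax.numpy as jnp
import brainstate as bst

state = bst.ShortTermState(jnp.zeros((2, 3)))

with bst.check_state_value_tree():  # assumed top-level re-export
    state.value = jnp.ones((2, 3))  # same tree structure: accepted
    try:
        # a tuple of two arrays has a different tree structure and is rejected
        state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
    except Exception as err:
        print('rejected:', err)
```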
brainstate/environ.py CHANGED
@@ -18,7 +18,8 @@ from .util import MemScaling, IdMemScaling
 __all__ = [
   'set', 'context', 'get', 'all',
   'set_host_device_count', 'set_platform',
-  'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
+  'get_host_device_count', 'get_platform',
+  'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
   'tolerance',
   'dftype', 'ditype', 'dutype', 'dctype',
 ]
@@ -20,6 +20,7 @@ from ._normalization import *
 from ._normalization import __all__ as __others_all__
 from ._spikes import *
 from ._spikes import __all__ as __spikes_all__
+from ._others import *
+from ._others import __all__ as __others_all__
 
-__all__ = __spikes_all__ + __others_all__ + __activations_all__
-
+__all__ = __spikes_all__ + __others_all__ + __activations_all__ + __others_all__
@@ -27,7 +27,7 @@ import jax.numpy as jnp
 from jax.scipy.special import logsumexp
 from jax.typing import ArrayLike
 
-from brainstate import math, random
+from .. import math, random
 
 __all__ = [
   "tanh",
@@ -20,11 +20,14 @@ from typing import Optional
 import jax
 import jax.numpy as jnp
 
+from .._utils import set_module_as
+
 __all__ = [
   'weight_standardization',
 ]
 
 
+@set_module_as('brainstate.functional')
 def weight_standardization(
     w: jax.typing.ArrayLike,
     eps: float = 1e-4,
@@ -0,0 +1,49 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+
+PyTree = Any
+
+__all__ = [
+  'clip_grad_norm',
+]
+
+
+def clip_grad_norm(
+    grad: PyTree,
+    max_norm: float | jax.Array,
+    norm_type: int | str | None = None
+):
+  """
+  Clips gradient norm of an iterable of parameters.
+
+  The norm is computed over all gradients together, as if they were
+  concatenated into a single vector. The clipped gradients are returned as a new PyTree.
+
+  Args:
+    grad (PyTree): an iterable of arrays or a single array whose gradient norm will be clipped.
+    max_norm (float): max norm of the gradients.
+    norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+  """
+  norm_fn = partial(jnp.linalg.norm, ord=norm_type)
+  norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
+  return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
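A usage sketch for the new helper. The import path is an assumption based on the `from ._others import *` hunk above, which suggests the function is re-exported as `brainstate.functional.clip_grad_norm`:

```python
import jax
import jax.numpy as jnp
import brainstate as bst

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}

def loss(p, x):
    return jnp.sum((x @ p['w'] + p['b']) ** 2)

grads = jax.grad(loss)(params, jnp.ones((4, 3)))
# norm_type=None falls through to jnp.linalg.norm defaults (2-norm / Frobenius per leaf)
clipped = bst.functional.clip_grad_norm(grads, max_norm=1.0)
```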
brainstate/nn/__init__.py CHANGED
@@ -21,6 +21,8 @@ from ._dynamics import *
 from ._dynamics import __all__ as dynamics_all
 from ._elementwise import *
 from ._elementwise import __all__ as elementwise_all
+from ._embedding import *
+from ._embedding import __all__ as embed_all
 from ._misc import *
 from ._misc import __all__ as _misc_all
 from ._normalizations import *
@@ -43,6 +45,7 @@ __all__ = (
   connections_all +
   dynamics_all +
   elementwise_all +
+  embed_all +
   normalizations_all +
   others_all +
   poolings_all +
@@ -58,6 +61,7 @@ del (
   connections_all,
   dynamics_all,
   elementwise_all,
+  embed_all,
   normalizations_all,
   others_all,
   poolings_all,
brainstate/nn/_base.py CHANGED
@@ -55,22 +55,24 @@ class ExplicitInOutSize(Mixin):
 
   @property
   def in_size(self) -> Tuple[int, ...]:
-    if self._in_size is None:
-      raise ValueError(f"The input shape is not set in this node: {self} ")
     return self._in_size
 
   @in_size.setter
-  def in_size(self, in_size: Sequence[int]):
+  def in_size(self, in_size: Sequence[int] | int):
+    if isinstance(in_size, int):
+      in_size = (in_size,)
+    assert isinstance(in_size, (tuple, list)), f"Invalid type of in_size: {type(in_size)}"
     self._in_size = tuple(in_size)
 
   @property
   def out_size(self) -> Tuple[int, ...]:
-    if self._out_size is None:
-      raise ValueError(f"The output shape is not set in this node: {self}")
     return self._out_size
 
   @out_size.setter
-  def out_size(self, out_size: Sequence[int]):
+  def out_size(self, out_size: Sequence[int] | int):
+    if isinstance(out_size, int):
+      out_size = (out_size,)
+    assert isinstance(out_size, (tuple, list)), f"Invalid type of out_size: {type(out_size)}"
     self._out_size = tuple(out_size)
 
 
@@ -152,7 +154,8 @@ class Sequential(Module, UpdateReturn, Container, ExplicitInOutSize):
     self.children = visible_module_dict(self.format_elements(object, first, *tuple_modules, **dict_modules))
 
     # the input and output shape
-    self.in_size = tuple(first.in_size)
+    if first.in_size is not None:
+      self.in_size = first.in_size
     self.out_size = tuple(in_size)
 
   def _format_module(self, module, in_size):
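The effect of the relaxed `in_size`/`out_size` setters, sketched with a bare subclass; this assumes the mixin can be instantiated without constructor arguments, which the hunk above does not show:

```python
from brainstate.nn._base import ExplicitInOutSize  # module path from the file header above

class _Probe(ExplicitInOutSize):
    """Minimal subclass used only to exercise the setters."""

p = _Probe()
p.in_size = 3          # an int is now promoted to (3,)
p.out_size = (4, 5)    # sequences behave as before
assert p.in_size == (3,) and p.out_size == (4, 5)
```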
@@ -103,6 +103,9 @@ class IF(Neuron):
   def init_state(self, batch_size: int = None, **kwargs):
     self.V = ShortTermState(init.param(jnp.zeros, self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.V.value = init.param(jnp.zeros, self.varshape, batch_size)
+
   def get_spike(self, V=None):
     V = self.V.value if V is None else V
     v_scaled = (V - self.V_th) / self.V_th
@@ -160,6 +163,9 @@ class LIF(Neuron):
   def init_state(self, batch_size: int = None, **kwargs):
     self.V = ShortTermState(init.param(init.Constant(self.V_reset), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.V.value = init.param(init.Constant(self.V_reset), self.varshape, batch_size)
+
   def get_spike(self, V=None):
     V = self.V.value if V is None else V
     v_scaled = (V - self.V_th) / self.V_th
@@ -214,6 +220,10 @@ class ALIF(Neuron):
     self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
     self.a = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
+    self.a.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   def get_spike(self, V=None, a=None):
     V = self.V.value if V is None else V
     a = self.a.value if a is None else a
@@ -275,6 +285,9 @@ class Expon(Synapse):
   def init_state(self, batch_size: int = None, **kwargs):
     self.g = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.g.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   def update(self, x=None):
     self.g.value = exp_euler_step(self.dg, self.g.value, environ.get('t'))
     if x is not None:
@@ -325,6 +338,10 @@ class STP(Synapse):
     self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))
     self.u = ShortTermState(init.param(init.Constant(self.U), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+    self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
+
   def du(self, u, t):
     return self.U - u / self.tau_f
 
@@ -390,6 +407,9 @@ class STD(Synapse):
   def init_state(self, batch_size: int = None, **kwargs):
     self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+
   def update(self, pre_spike):
     t = environ.get('t')
     x = exp_euler_step(self.dx, self.x.value, t)
@@ -0,0 +1,66 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Optional, Callable, Union
+
+from ._base import DnnLayer
+from .. import init
+from .._state import ParamState
+from ..mixin import Mode, Training
+from ..typing import ArrayLike
+
+__all__ = [
+  'Embedding',
+]
+
+
+class Embedding(DnnLayer):
+  r"""
+  A simple lookup table that stores embeddings of a fixed size.
+
+  Args:
+    num_embeddings: Size of embedding dictionary. Must be non-negative.
+    embedding_size: Size of each embedding vector. Must be non-negative.
+    embed_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
+
+  """
+
+  def __init__(
+      self,
+      num_embeddings: int,
+      embedding_size: int,
+      embed_init: Union[Callable, ArrayLike] = init.LecunUniform(),
+      name: Optional[str] = None,
+      mode: Optional[Mode] = None,
+  ):
+    super().__init__(name=name, mode=mode)
+    if num_embeddings < 0:
+      raise ValueError("num_embeddings must not be negative.")
+    if embedding_size < 0:
+      raise ValueError("embedding_size must not be negative.")
+    self.num_embeddings = num_embeddings
+    self.embedding_size = embedding_size
+    self.out_size = (embedding_size,)
+
+    weight = init.param(embed_init, (self.num_embeddings, self.embedding_size))
+    if self.mode.has(Training):
+      self.weight = ParamState(weight)
+    else:
+      self.weight = weight
+
+  def update(self, indices: ArrayLike):
+    if self.mode.has(Training):
+      return self.weight.value[indices]
+    return self.weight[indices]
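A hypothetical usage of the new layer, which the `brainstate/nn/__init__.py` hunk above re-exports as part of `brainstate.nn`:

```python
import jax.numpy as jnp
import brainstate as bst

emb = bst.nn.Embedding(num_embeddings=1000, embedding_size=8)
tokens = jnp.array([[1, 5, 42],
                    [7, 7, 0]])
vectors = emb.update(tokens)  # table lookup; result shape (2, 3, 8)
print(vectors.shape)
```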
@@ -90,6 +90,9 @@ class ValinaRNNCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, self.num_out, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, self.num_out, batch_size)
+
   def update(self, x):
     xh = jnp.concatenate([x, self.h.value], axis=-1)
     h = self.W(xh)
@@ -147,6 +150,9 @@ class GRUCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     old_h = self.h.value
     xh = jnp.concatenate([x, old_h], axis=-1)
@@ -224,6 +230,9 @@ class MGUCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     old_h = self.h.value
     xh = jnp.concatenate([x, old_h], axis=-1)
@@ -327,6 +336,10 @@ class LSTMCell(RNNCell):
     self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     h, c = self.h.value, self.c.value
     xh = jnp.concat([x, h], axis=-1)
@@ -379,6 +392,10 @@ class URLSTMCell(RNNCell):
     self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x: ArrayLike) -> ArrayLike:
     h, c = self.h.value, self.c.value
     xh = jnp.concat([x, h], axis=-1)
brainstate/nn/_readout.py CHANGED
@@ -66,6 +66,9 @@ class LeakyRateReadout(DnnLayer):
   def init_state(self, batch_size=None, **kwargs):
     self.r = ShortTermState(init.param(init.Constant(0.), self.out_size, batch_size))
 
+  def reset_state(self, batch_size=None, **kwargs):
+    self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)
+
   def update(self, x):
     r = self.decay * self.r.value + x @ self.weight.value
     self.r.value = r
@@ -109,6 +112,9 @@ class LeakySpikeReadout(Neuron):
   def init_state(self, batch_size, **kwargs):
     self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size, **kwargs):
+    self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   @property
   def spike(self):
     return self.get_spike(self.V.value)
@@ -34,3 +34,16 @@ class TestMultiStepLR(unittest.TestCase):
         self.assertTrue(jnp.allclose(r, 0.001))
       else:
         self.assertTrue(jnp.allclose(r, 0.0001))
+
+  def test2(self):
+    lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
+    for i in range(40):
+      r = lr(i)
+      if i < 10:
+        self.assertEqual(r, 0.1)
+      elif i < 20:
+        self.assertTrue(jnp.allclose(r, 0.01))
+      elif i < 30:
+        self.assertTrue(jnp.allclose(r, 0.001))
+      else:
+        self.assertTrue(jnp.allclose(r, 0.0001))
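The schedule exercised by `test2`, driven outside the test harness (constructor arguments copied from the test above):

```python
import brainstate as bst

# decay the learning rate by `gamma` at epochs 10, 20 and 30
lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
print(lr(5), lr(15), lr(25), lr(35))  # 0.1, 0.01, 0.001, 0.0001

# the test also checks that the schedule still works under bst.transform.jit
jitted_lr = bst.transform.jit(lr)
print(jitted_lr(15))  # 0.01
```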
@@ -23,8 +23,8 @@ import jax
 from jax._src import sharding_impls
 from jax.lib import xla_client as xc
 
-from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 from brainstate._utils import set_module_as
+from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 
 __all__ = ['jit']
 
@@ -33,10 +33,13 @@ class JittedFunction(Callable):
   """
   A wrapped version of ``fun``, set up for just-in-time compilation.
   """
-  origin_fun: Callable # the original function
+  origin_fun: Callable  # the original function
   stateful_fun: StatefulFunction  # the stateful function for extracting states
   jitted_fun: jax.stages.Wrapped  # the jitted function
-  clear_cache: Callable # clear the cache of the jitted function
+  clear_cache: Callable  # clear the cache of the jitted function
+
+  def __call__(self, *args, **kwargs):
+    pass
 
 
 def _get_jitted_fun(
@@ -85,12 +88,16 @@ def _get_jitted_fun(
     jit_fun.clear_cache()
 
   jitted_fun: JittedFunction
+
   # the original function
   jitted_fun.origin_fun = fun.fun
+
   # the stateful function for extracting states
   jitted_fun.stateful_fun = fun
+
   # the jitted function
   jitted_fun.jitted_fun = jit_fun
+
   # clear cache
   jitted_fun.clear_cache = clear_cache
 
@@ -99,18 +106,18 @@ def _get_jitted_fun(
 
 @set_module_as('brainstate.transform')
 def jit(
-  fun: Callable = None,
-  in_shardings=sharding_impls.UNSPECIFIED,
-  out_shardings=sharding_impls.UNSPECIFIED,
-  static_argnums: int | Sequence[int] | None = None,
-  donate_argnums: int | Sequence[int] | None = None,
-  donate_argnames: str | Iterable[str] | None = None,
-  keep_unused: bool = False,
-  device: xc.Device | None = None,
-  backend: str | None = None,
-  inline: bool = False,
-  abstracted_axes: Any | None = None,
-  **kwargs
+    fun: Callable = None,
+    in_shardings=sharding_impls.UNSPECIFIED,
+    out_shardings=sharding_impls.UNSPECIFIED,
+    static_argnums: int | Sequence[int] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,
+    **kwargs
 ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
   """
   Sets up ``fun`` for just-in-time compilation with XLA.
@@ -228,12 +235,31 @@ def jit(
 
   if fun is None:
     def wrapper(fun_again: Callable) -> JittedFunction:
-      return _get_jitted_fun(fun_again, in_shardings, out_shardings, static_argnums,
-                             donate_argnums, donate_argnames, keep_unused,
-                             device, backend, inline, abstracted_axes, **kwargs)
+      return _get_jitted_fun(fun_again,
+                             in_shardings,
+                             out_shardings,
+                             static_argnums,
+                             donate_argnums,
+                             donate_argnames,
+                             keep_unused,
+                             device,
+                             backend,
+                             inline,
+                             abstracted_axes,
+                             **kwargs)
+
     return wrapper
 
   else:
-    return _get_jitted_fun(fun, in_shardings, out_shardings, static_argnums,
-                           donate_argnums, donate_argnames, keep_unused,
-                           device, backend, inline, abstracted_axes, **kwargs)
+    return _get_jitted_fun(fun,
+                           in_shardings,
+                           out_shardings,
+                           static_argnums,
+                           donate_argnums,
+                           donate_argnames,
+                           keep_unused,
+                           device,
+                           backend,
+                           inline,
+                           abstracted_axes,
+                           **kwargs)
+ **kwargs)