brainstate-0.1.3-py2.py3-none-any.whl → brainstate-0.1.4-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +1 -16
  3. brainstate/compile/_jit.py +14 -5
  4. brainstate/compile/_make_jaxpr.py +78 -22
  5. brainstate/compile/_make_jaxpr_test.py +13 -2
  6. brainstate/graph/_graph_node.py +1 -1
  7. brainstate/graph/_graph_operation.py +4 -4
  8. brainstate/mixin.py +31 -2
  9. brainstate/nn/__init__.py +8 -5
  10. brainstate/nn/_delay.py +13 -1
  11. brainstate/nn/_dropout.py +5 -4
  12. brainstate/nn/_dynamics.py +39 -44
  13. brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
  14. brainstate/nn/_linear_mv.py +1 -1
  15. brainstate/nn/_module.py +5 -5
  16. brainstate/nn/_projection.py +190 -98
  17. brainstate/nn/_synapse.py +5 -9
  18. brainstate/nn/_synaptic_projection.py +376 -86
  19. brainstate/surrogate.py +1 -1
  20. brainstate/typing.py +1 -1
  21. brainstate/util/__init__.py +14 -14
  22. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  23. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
  24. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/RECORD +35 -35
  25. /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  26. /brainstate/util/{_caller.py → caller.py} +0 -0
  27. /brainstate/util/{_error.py → error.py} +0 -0
  28. /brainstate/util/{_others.py → others.py} +0 -0
  29. /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  30. /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  31. /brainstate/util/{_scaling.py → scaling.py} +0 -0
  32. /brainstate/util/{_struct.py → struct.py} +0 -0
  33. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
  34. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
  35. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
  A ``State``-based Transformation System for Program Compilation and Augmentation
  """

- __version__ = "0.1.3"
+ __version__ = "0.1.4"

  from . import augment
  from . import compile
brainstate/_compatible_import.py CHANGED
@@ -16,10 +16,9 @@
  # -*- coding: utf-8 -*-


- import importlib.util
  from contextlib import contextmanager
  from functools import partial
- from typing import Iterable, Hashable, TypeVar, Callable, TYPE_CHECKING
+ from typing import Iterable, Hashable, TypeVar, Callable

  import jax

@@ -31,7 +30,6 @@ __all__ = [
      'get_aval',
      'Tracer',
      'to_concrete_aval',
-     'brainevent',
      'safe_map',
      'safe_zip',
      'unzip2',
@@ -47,8 +45,6 @@ T3 = TypeVar("T3")

  from saiunit._compatible_import import wrap_init

- brainevent_installed = importlib.util.find_spec('brainevent') is not None
-
  from jax.core import get_aval, Tracer

  if jax.__version_info__ < (0, 5, 0):
@@ -150,14 +146,3 @@ def to_concrete_aval(aval):
          return aval.to_concrete_value()
      return aval

-
- if not brainevent_installed:
-     if not TYPE_CHECKING:
-         class BrainEvent:
-             def __getattr__(self, item):
-                 raise ImportError('brainevent is not installed, please install brainevent first.')
-
-         brainevent = BrainEvent()
-
-     else:
-         import brainevent
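
Note: with the shim above removed, `brainstate._compatible_import` no longer re-exports `brainevent`; code that relied on it should import `brainevent` directly. For reference, the deleted code used the deferred-ImportError pattern for optional dependencies; a generic sketch of that pattern (the `optional_module` helper is hypothetical, not brainstate API):

```python
import importlib
import importlib.util

def optional_module(name: str):
    """Return the real module if installed, else a stub that raises on first use."""
    if importlib.util.find_spec(name) is not None:
        return importlib.import_module(name)

    class _MissingModule:
        def __getattr__(self, item):
            raise ImportError(f'{name} is not installed, please install {name} first.')

    return _MissingModule()
```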
brainstate/compile/_jit.py CHANGED
@@ -51,6 +51,7 @@ def _get_jitted_fun(
      out_shardings,
      static_argnums,
      donate_argnums,
+     static_argnames,
      donate_argnames,
      keep_unused,
      device,
@@ -59,10 +60,12 @@ def _get_jitted_fun(
      abstracted_axes,
      **kwargs
  ) -> JittedFunction:
-     static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
+     static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
+     donate_argnums = tuple() if donate_argnums is None else _ensure_index_tuple(donate_argnums)
      fun = StatefulFunction(
          fun,
          static_argnums=static_argnums,
+         static_argnames=static_argnames,
          abstracted_axes=abstracted_axes,
          cache_type='jit',
          name='jit'
@@ -70,7 +73,8 @@ def _get_jitted_fun(
      jit_fun = jax.jit(
          fun.jaxpr_call,
          static_argnums=tuple(i + 1 for i in static_argnums),
-         donate_argnums=donate_argnums,
+         static_argnames=static_argnames,
+         donate_argnums=tuple(i + 1 for i in donate_argnums),
          donate_argnames=donate_argnames,
          keep_unused=keep_unused,
          device=device,
@@ -179,6 +183,7 @@ def jit(
      out_shardings=sharding_impls.UNSPECIFIED,
      static_argnums: int | Sequence[int] | None = None,
      donate_argnums: int | Sequence[int] | None = None,
+     static_argnames: str | Sequence[str] | None = None,
      donate_argnames: str | Iterable[str] | None = None,
      keep_unused: bool = False,
      device: Device | None = None,
@@ -190,9 +195,6 @@ def jit(
      """
      Sets up ``fun`` for just-in-time compilation with XLA.

-     Does not support setting ``static_argnames`` as in ``jax.jit()``.
-
-

      Args:
        fun: Function to be jitted.
@@ -246,6 +248,11 @@ def jit(
          provided, ``inspect.signature`` is not used, and only actual
          parameters listed in either ``static_argnums`` or ``static_argnames`` will
          be treated as static.
+       static_argnames: An optional string or collection of strings specifying
+         which named arguments are treated as static (compile-time constant).
+         Operations that only depend on static arguments will be constant-folded in
+         Python (during tracing), and so the corresponding argument values can be
+         any Python object.
        donate_argnums: Specify which positional argument buffers are "donated" to
          the computation. It is safe to donate argument buffers if you no longer
          need them once the computation has finished. In some cases XLA can make
@@ -309,6 +316,7 @@ def jit(
          out_shardings=out_shardings,
          static_argnums=static_argnums,
          donate_argnums=donate_argnums,
+         static_argnames=static_argnames,
          donate_argnames=donate_argnames,
          keep_unused=keep_unused,
          device=device,
@@ -327,6 +335,7 @@ def jit(
          out_shardings,
          static_argnums,
          donate_argnums,
+         static_argnames,
          donate_argnames,
          keep_unused,
          device,
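
With the hunks above, `brainstate.compile.jit` now accepts `static_argnames` and forwards it both to `StatefulFunction` and to the underlying `jax.jit` call, shifting `donate_argnums` by one for the injected state-values argument just as it already did for `static_argnums`. A minimal usage sketch (assuming the `jit(fun, ...)` call form; the function and values are illustrative):

```python
import jax.numpy as jnp
import brainstate

def step(x, scale):
    # `scale` is static: Python control flow on it is resolved at trace time.
    return x * scale if scale > 0 else -x

jitted_step = brainstate.compile.jit(step, static_argnames='scale')

x = jnp.ones(4)
print(jitted_step(x, scale=2.0))   # traces and compiles for scale=2.0
print(jitted_step(x, scale=2.0))   # same static value: cached trace is reused
print(jitted_step(x, scale=-1.0))  # new static value: retraces
```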
brainstate/compile/_make_jaxpr.py CHANGED
@@ -88,6 +88,12 @@ __all__ = [
  ]


+ def _ensure_str(x: str) -> str:
+     if not isinstance(x, str):
+         raise TypeError(f"argument is not a string: {x}")
+     return x
+
+
  def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
      """Convert x to a tuple of indices."""
      x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
@@ -97,6 +103,14 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
      return tuple(safe_map(operator.index, x))


+ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
+     """Convert x to a tuple of strings."""
+     if isinstance(x, str):
+         return (x,)
+     else:
+         return tuple(safe_map(_ensure_str, x))
+
+
  def _jax_v04_new_arg_fn(frame, trace, aval):
      """
      Transform a new argument to a tracer.
@@ -155,6 +169,9 @@ def _init_state_trace_stack(name) -> StateTraceStack:
      return state_trace


+ default_cache_key = ((), ())
+
+
  class StatefulFunction(PrettyObject):
      """
      A wrapper class for a function that collects the states that are read and written by the function. The states are
@@ -170,6 +187,7 @@ class StatefulFunction(PrettyObject):
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring.
+     static_argnames: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
@@ -199,6 +217,7 @@ class StatefulFunction(PrettyObject):
          self,
          fun: Callable,
          static_argnums: Union[int, Iterable[int]] = (),
+         static_argnames: Union[str, Iterable[str]] = (),
          axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
          abstracted_axes: Optional[Any] = None,
          state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
@@ -207,11 +226,12 @@ class StatefulFunction(PrettyObject):
      ):
          # explicit parameters
          self.fun = fun
-         self.static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
+         self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
+         self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
          self.axis_env = axis_env
          self.abstracted_axes = abstracted_axes
          self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
-         assert cache_type in [None, 'jit']
+         assert cache_type in [None, 'jit'], f"Invalid cache type: {cache_type}"
          self.name = name

          # implicit parameters
@@ -226,7 +246,7 @@ class StatefulFunction(PrettyObject):
              return None
          return k, v

-     def get_jaxpr(self, cache_key: Hashable = ()) -> ClosedJaxpr:
+     def get_jaxpr(self, cache_key: Hashable = None) -> ClosedJaxpr:
          """
          Read the JAX Jaxpr representation of the function.

@@ -236,11 +256,13 @@ class StatefulFunction(PrettyObject):
          Returns:
              The JAX Jaxpr representation of the function.
          """
+         if cache_key is None:
+             cache_key = default_cache_key
          if cache_key not in self._cached_jaxpr:
              raise ValueError(f"the function is not called with the static arguments: {cache_key}")
          return self._cached_jaxpr[cache_key]

-     def get_out_shapes(self, cache_key: Hashable = ()) -> PyTree:
+     def get_out_shapes(self, cache_key: Hashable = None) -> PyTree:
          """
          Read the output shapes of the function.

@@ -250,11 +272,13 @@ class StatefulFunction(PrettyObject):
          Returns:
              The output shapes of the function.
          """
+         if cache_key is None:
+             cache_key = default_cache_key
          if cache_key not in self._cached_out_shapes:
              raise ValueError(f"the function is not called with the static arguments: {cache_key}")
          return self._cached_out_shapes[cache_key]

-     def get_out_treedef(self, cache_key: Hashable = ()) -> PyTree:
+     def get_out_treedef(self, cache_key: Hashable = None) -> PyTree:
          """
          Read the output tree of the function.

@@ -264,11 +288,13 @@ class StatefulFunction(PrettyObject):
          Returns:
              The output tree of the function.
          """
+         if cache_key is None:
+             cache_key = default_cache_key
          if cache_key not in self._cached_jaxpr_out_tree:
              raise ValueError(f"the function is not called with the static arguments: {cache_key}")
          return self._cached_jaxpr_out_tree[cache_key]

-     def get_state_trace(self, cache_key: Hashable = ()) -> StateTraceStack:
+     def get_state_trace(self, cache_key: Hashable = None) -> StateTraceStack:
          """
          Read the state trace of the function.

@@ -278,11 +304,13 @@ class StatefulFunction(PrettyObject):
          Returns:
              The state trace of the function.
          """
+         if cache_key is None:
+             cache_key = default_cache_key
          if cache_key not in self._cached_state_trace:
              raise ValueError(f"the function is not called with the static arguments: {cache_key}")
          return self._cached_state_trace[cache_key]

-     def get_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
+     def get_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
          """
          Read the states that are read and written by the function.

@@ -292,9 +320,11 @@ class StatefulFunction(PrettyObject):
          Returns:
              The states that are read and written by the function.
          """
+         if cache_key is None:
+             cache_key = default_cache_key
          return tuple(self.get_state_trace(cache_key).states)

-     def get_read_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
+     def get_read_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
          """
          Read the states that are read by the function.

@@ -304,9 +334,11 @@ class StatefulFunction(PrettyObject):
          Returns:
              The states that are read by the function.
          """
+         if cache_key is None:
+             cache_key = default_cache_key
          return self.get_state_trace(cache_key).get_read_states()

-     def get_write_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
+     def get_write_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
          """
          Read the states that are written by the function.

@@ -316,6 +348,8 @@ class StatefulFunction(PrettyObject):
          Returns:
              The states that are written by the function.
          """
+         if cache_key is None:
+             cache_key = default_cache_key
          return self.get_state_trace(cache_key).get_write_states()

      def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
@@ -323,10 +357,11 @@ class StatefulFunction(PrettyObject):
          Get the static arguments from the arguments.

          Args:
-             *args: The arguments to the function.
+             *args: The arguments to the function.
+             **kwargs: The keyword arguments to the function.

          Returns:
-             The static arguments.
+             The static arguments and keyword arguments as a tuple.
          """
          if self.cache_type == 'jit':
              static_args, dyn_args = [], []
@@ -336,11 +371,18 @@ class StatefulFunction(PrettyObject):
              else:
                  dyn_args.append(arg)
              dyn_args = jax.tree.map(shaped_abstractify, jax.tree.leaves(dyn_args))
-             dyn_kwargs = jax.tree.map(shaped_abstractify, jax.tree.leaves(kwargs))
-             return tuple([tuple(static_args), tuple(dyn_args), tuple(dyn_kwargs)])
+             static_kwargs, dyn_kwargs = [], []
+             for k, v in kwargs.items():
+                 if k in self.static_argnames:
+                     static_kwargs.append((k, v))
+                 else:
+                     dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
+             return tuple([tuple(static_args), tuple(dyn_args), tuple(static_kwargs), tuple(dyn_kwargs)])
          elif self.cache_type is None:
              num_arg = len(args)
-             return tuple(args[i] for i in self.static_argnums if i < num_arg)
+             static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
+             static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
+             return tuple([static_args, static_kwargs])
          else:
              raise ValueError(f"Invalid cache type: {self.cache_type}")

@@ -389,7 +431,7 @@ class StatefulFunction(PrettyObject):
          self._cached_state_trace.clear()

      def _wrapped_fun_to_eval(
-         self, cache_key, *args, return_only_write: bool = False, **kwargs,
+         self, cache_key, static_kwargs: dict, *args, return_only_write: bool = False, **dyn_kwargs,
      ) -> Tuple[Any, Tuple[State, ...]]:
          """
          Wrap the function and return the states that are read and written by the function and the output of the function.
@@ -405,7 +447,7 @@ class StatefulFunction(PrettyObject):
          state_trace = _init_state_trace_stack(self.name)
          self._cached_state_trace[cache_key] = state_trace
          with state_trace:
-             out = self.fun(*args, **kwargs)
+             out = self.fun(*args, **dyn_kwargs, **static_kwargs)
          state_values = (
              state_trace.get_write_state_values(True)
              if return_only_write else
@@ -430,8 +472,9 @@ class StatefulFunction(PrettyObject):
          the structure, shape, dtypes, and named shapes of the output of ``fun``.

          Args:
-             *args: The arguments to the function.
-             **kwargs: The keyword arguments to the function.
+             *args: The arguments to the function.
+             **kwargs: The keyword arguments to the function.
+             return_only_write: If True, only return the states that are written by the function.
          """

          # static args
@@ -440,17 +483,24 @@ class StatefulFunction(PrettyObject):
          if cache_key not in self._cached_state_trace:
              try:
                  # jaxpr
+                 static_kwargs, dyn_kwargs = {}, {}
+                 for k, v in kwargs.items():
+                     if k in self.static_argnames:
+                         static_kwargs[k] = v
+                     else:
+                         dyn_kwargs[k] = v
                  jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
                      functools.partial(
                          self._wrapped_fun_to_eval,
                          cache_key,
+                         static_kwargs,
                          return_only_write=return_only_write
                      ),
                      static_argnums=self.static_argnums,
                      axis_env=self.axis_env,
                      return_shape=True,
                      abstracted_axes=self.abstracted_axes
-                 )(*args, **kwargs)
+                 )(*args, **dyn_kwargs)
                  # returns
                  self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
                  self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
@@ -483,6 +533,7 @@ class StatefulFunction(PrettyObject):
          assert len(state_vals) == len(states), 'State length mismatch.'

          # parameters
+         kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames}  # remove static kwargs
          args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
          args = jax.tree.flatten((args, kwargs, state_vals))[0]

@@ -519,12 +570,16 @@
  def make_jaxpr(
      fun: Callable,
      static_argnums: Union[int, Iterable[int]] = (),
+     static_argnames: Union[str, Iterable[str]] = (),
      axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
      return_shape: bool = False,
      abstracted_axes: Optional[Any] = None,
      state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
- ) -> Callable[..., (Tuple[ClosedJaxpr, Tuple[State, ...]] |
-                     Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])]:
+ ) -> Callable[
+     ...,
+     (Tuple[ClosedJaxpr, Tuple[State, ...]] |
+      Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
+ ]:
      """
      Creates a function that produces its jaxpr given example args.

@@ -533,6 +588,7 @@ def make_jaxpr(
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring.
+     static_argnames: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
@@ -605,11 +661,11 @@ def make_jaxpr(
      stateful_fun = StatefulFunction(
          fun,
          static_argnums=static_argnums,
+         static_argnames=static_argnames,
          axis_env=axis_env,
          abstracted_axes=abstracted_axes,
          state_returns=state_returns,
          name='make_jaxpr'
-
      )

      @wraps(fun)
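
Taken together, these hunks make `StatefulFunction` split keyword arguments into a static group (matched against `static_argnames`, folded into the cache key, and re-bound when the wrapped function is evaluated) and a dynamic group (abstractified like positional arguments). The new module-level `default_cache_key = ((), ())` also lets the `get_jaxpr`/`get_states`-style accessors be called with no argument. A small sketch of the resulting workflow, modeled on the new test below (`f` and `shift` are illustrative):

```python
import jax.numpy as jnp
import brainstate

def f(x, shift):
    return jnp.sum(x + shift)

# `shift` is static, so it is constant-folded into the traced jaxpr.
jaxpr, states = brainstate.compile.make_jaxpr(f, static_argnames='shift')(
    jnp.zeros(4), shift=1.0
)
print(jaxpr)   # ClosedJaxpr with `shift` baked in
print(states)  # states read/written by f (none here)
```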
brainstate/compile/_make_jaxpr_test.py CHANGED
@@ -88,7 +88,7 @@ class TestMakeJaxpr(unittest.TestCase):
          self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
                                       f3(jnp.zeros(1))))

-     def test_compar_jax_make_jaxpr2(self):
+     def test_compare_jax_make_jaxpr2(self):
          st1 = brainstate.State(jnp.ones(10))

          def fa(x):
@@ -108,7 +108,7 @@ class TestMakeJaxpr(unittest.TestCase):
          print(jaxpr)
          print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))

-     def test_compar_jax_make_jaxpr3(self):
+     def test_compare_jax_make_jaxpr3(self):
          def fa(x):
              return 1.

@@ -121,6 +121,17 @@ class TestMakeJaxpr(unittest.TestCase):
          print(jaxpr)
          # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))

+     def test_static_argnames(self):
+         def func4(a, b):  # Arg is a pair
+             temp = a + jnp.sin(b) * 3.
+             c = brainstate.random.rand_like(a)
+             return jnp.sum(temp + c)
+
+         jaxpr, states = brainstate.compile.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
+         print()
+         print(jaxpr)
+         print(states)
+

  def test_return_states():
      import jax.numpy
brainstate/graph/_graph_node.py CHANGED
@@ -25,7 +25,7 @@ import numpy as np

  from brainstate._state import State, TreefyState
  from brainstate.typing import Key
- from brainstate.util._pretty_pytree import PrettyObject
+ from brainstate.util.pretty_pytree import PrettyObject
  from ._graph_operation import register_graph_node_type

  __all__ = [
brainstate/graph/_graph_operation.py CHANGED
@@ -30,10 +30,10 @@ from typing_extensions import TypeGuard, Unpack
  from brainstate._state import State, TreefyState
  from brainstate._utils import set_module_as
  from brainstate.typing import PathParts, Filter, Predicate, Key
- from brainstate.util._caller import ApplyCaller, CallableProxy, DelayedAccessor
- from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
- from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
- from brainstate.util._struct import FrozenDict
+ from brainstate.util.caller import ApplyCaller, CallableProxy, DelayedAccessor
+ from brainstate.util.pretty_pytree import NestedDict, FlattedDict, PrettyDict
+ from brainstate.util.pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
+ from brainstate.util.struct import FrozenDict
  from brainstate.util.filter import to_predicate

  _max_int = np.iinfo(np.int32).max
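
These two graph hunks simply chase the `brainstate/util/_*.py → *.py` renames in the file list above (items 21-22 and 26-32), which promote the private util modules to public names. Downstream code needs the same one-line change, e.g.:

```python
# brainstate <= 0.1.3
from brainstate.util._pretty_pytree import PrettyObject

# brainstate >= 0.1.4
from brainstate.util.pretty_pytree import PrettyObject
```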
brainstate/mixin.py CHANGED
@@ -41,6 +41,14 @@ __all__ = [
  ]


+ def hashable(x):
+     try:
+         hash(x)
+         return True
+     except TypeError:
+         return False
+
+
  class Mixin(object):
      """Base Mixin object.

@@ -67,6 +75,14 @@ class ParamDesc(Mixin):


  class HashableDict(dict):
+     def __init__(self, the_dict: dict):
+         out = dict()
+         for k, v in the_dict.items():
+             if not hashable(v):
+                 v = str(v)  # convert to string if not hashable
+             out[k] = v
+         super().__init__(out)
+
      def __hash__(self):
          return hash(tuple(sorted(self.items())))

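The new `__init__` makes hashing total: any unhashable value is replaced by its string form before storage, so `__hash__` (which sorts and hashes the items) can no longer fail. A quick sketch of the effect (the values are illustrative):

```python
d = HashableDict({'tau': 10.0, 'conn': [1, 2, 3]})  # a list is unhashable
print(d['conn'])  # '[1, 2, 3]' — stored as its string form
print(hash(d))    # now well-defined for every entry
```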
@@ -132,7 +148,6 @@ class AlignPost(Mixin):
          raise NotImplementedError


-
  class BindCondData(Mixin):
      """Bind temporary conductance data.

@@ -147,12 +162,26 @@ class BindCondData(Mixin):
          self._conductance = None


+ def not_implemented(func):
+
+     def wrapper(*args, **kwargs):
+         raise NotImplementedError(f'{func.__name__} is not implemented.')
+
+     wrapper.not_implemented = True
+     return wrapper
+
+
+
  class UpdateReturn(Mixin):
+     @not_implemented
      def update_return(self) -> PyTree:
          """
          The update function return of the model.

-         It should be a pytree, with each element as a ``jax.ShapeDtypeStruct`` or ``jax.core.ShapedArray``.
+         This function requires no parameters and must return a PyTree.
+
+         It is usually used for delay initialization, for example, ``Dynamics.output_delay`` relies on this function to
+         initialize the output delay.

          """
          raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
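
Because the decorated method raises on call and the wrapper carries a `not_implemented = True` attribute, framework code can now detect that a subclass never overrode `update_return()` without invoking it. A minimal sketch of such a check (the `has_update_return` helper is hypothetical, not brainstate API):

```python
def has_update_return(obj) -> bool:
    # An overriding method lacks the `not_implemented` marker set by the decorator.
    return not getattr(type(obj).update_return, 'not_implemented', False)
```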
brainstate/nn/__init__.py CHANGED
@@ -33,12 +33,14 @@ from ._embedding import *
  from ._embedding import __all__ as embed_all
  from ._exp_euler import *
  from ._exp_euler import __all__ as exp_euler_all
- from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
+ from ._fixedprob import *
+ from._fixedprob import __all__ as fixedprob_all
  from ._inputs import *
  from ._inputs import __all__ as inputs_all
  from ._linear import *
  from ._linear import __all__ as linear_all
- from ._linear_mv import EventLinear
+ from ._linear_mv import *
+ from ._linear_mv import __all__ as linear_mv_all
  from ._ltp import *
  from ._ltp import __all__ as ltp_all
  from ._module import *
@@ -69,9 +71,6 @@ from ._utils import __all__ as utils_all
  __all__ = (
      [
          'metrics',
-         'EventLinear',
-         'EventFixedProb',
-         'EventFixedNumConn',
      ]
      + collective_ops_all
      + common_all
@@ -87,6 +86,8 @@ __all__ = (
      + linear_all
      + normalizations_all
      + poolings_all
+     + fixedprob_all
+     + linear_mv_all
      + embed_all
      + dropout_all
      + elementwise_all
@@ -115,6 +116,8 @@ del (
      normalizations_all,
      poolings_all,
      embed_all,
+     fixedprob_all,
+     linear_mv_all,
      dropout_all,
      elementwise_all,
      dyn_neuron_all,
brainstate/nn/_delay.py CHANGED
@@ -330,7 +330,14 @@ class Delay(Module):
              indices = (delay_idx,) + indices

          # the delay data
-         return jax.tree.map(lambda a: a[indices], self.history.value)
+         if self._unit is None:
+             return jax.tree.map(lambda a: a[indices], self.history.value)
+         else:
+             return jax.tree.map(
+                 lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
+                 self.history.value,
+                 self._unit
+             )

      def retrieve_at_time(self, delay_time, *indices) -> PyTree:
          """
@@ -393,6 +400,9 @@ class Delay(Module):
          """
          assert self.history is not None, 'The delay history is not initialized.'

+         if self.take_aware_unit and self._unit is None:
+             self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)
+
          # update the delay data at the rotation index
          if self.delay_method == _DELAY_ROTATE:
              i = environ.get(environ.I)
@@ -419,6 +429,8 @@ class Delay(Module):
              raise ValueError(f'Unknown updating method "{self.delay_method}"')


+
+
  class StateWithDelay(Delay):
      """
      A ``State`` type that defines the state in a differential equation.
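
With these two hunks, the delay buffer itself stays unit-less: on the first `update()`, the unit of the incoming current is recorded (when `take_aware_unit` is set), and `retrieve_at_step()` re-attaches it via `u.maybe_decimal(hist[indices] * unit)`. A condensed sketch of the record/reattach pattern outside the class, assuming `saiunit` imported as `u` (as brainstate does):

```python
import jax.numpy as jnp
import saiunit as u

history = jnp.array([0.1, 0.2, 0.3])  # raw magnitudes stored in the buffer
unit = u.mV                           # unit recorded from the first update

# Retrieval re-attaches the unit; maybe_decimal strips it again when dimensionless.
value = u.maybe_decimal(history[1] * unit)
print(value)  # 0.2 * mV
```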
brainstate/nn/_dropout.py CHANGED
@@ -409,7 +409,8 @@ class DropoutFixed(ElementWiseBlock):
          self.out_size = in_size

      def init_state(self, batch_size=None, **kwargs):
-         self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
+         if self.prob < 1.:
+             self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))

      def update(self, x):
          dtype = u.math.get_dtype(x)
@@ -418,8 +419,8 @@ class DropoutFixed(ElementWiseBlock):
              if self.mask.value.shape != x.shape:
                  raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
                                   f"Please call `init_state()` method first.")
-             return jnp.where(self.mask.value,
-                              jnp.asarray(x / self.prob, dtype=dtype),
-                              jnp.asarray(0., dtype=dtype))
+             return u.math.where(self.mask.value,
+                                 u.math.asarray(x / self.prob, dtype=dtype),
+                                 u.math.asarray(0., dtype=dtype) * u.get_unit(x))
          else:
              return x