brainstate 0.1.2__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +0 -15
  3. brainstate/compile/_jit.py +14 -5
  4. brainstate/compile/_make_jaxpr.py +78 -22
  5. brainstate/compile/_make_jaxpr_test.py +13 -2
  6. brainstate/graph/_graph_node.py +1 -1
  7. brainstate/graph/_graph_operation.py +4 -4
  8. brainstate/mixin.py +30 -14
  9. brainstate/nn/__init__.py +84 -17
  10. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  11. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +19 -3
  12. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +6 -5
  13. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +137 -21
  14. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  15. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  16. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob.py} +96 -25
  17. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  18. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  19. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +2 -2
  20. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  21. brainstate/nn/_module.py +5 -5
  22. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  23. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  24. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  25. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +1 -1
  26. brainstate/nn/_projection.py +486 -0
  27. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  28. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  29. brainstate/nn/_stp.py +236 -0
  30. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +19 -212
  31. brainstate/nn/_synaptic_projection.py +423 -0
  32. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  33. brainstate/surrogate.py +1 -1
  34. brainstate/typing.py +1 -1
  35. brainstate/util/__init__.py +14 -14
  36. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  37. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
  38. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/RECORD +61 -63
  39. brainstate/nn/_dyn_impl/__init__.py +0 -42
  40. brainstate/nn/_dynamics/__init__.py +0 -37
  41. brainstate/nn/_dynamics/_projection_base.py +0 -362
  42. brainstate/nn/_elementwise/__init__.py +0 -22
  43. brainstate/nn/_interaction/__init__.py +0 -41
  44. /brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +0 -0
  45. /brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +0 -0
  46. /brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +0 -0
  47. /brainstate/nn/{_elementwise/_elementwise_test.py → _elementwise_test.py} +0 -0
  48. /brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  49. /brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -0
  50. /brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +0 -0
  51. /brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +0 -0
  52. /brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +0 -0
  53. /brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +0 -0
  54. /brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +0 -0
  55. /brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +0 -0
  56. /brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +0 -0
  57. /brainstate/util/{_caller.py → caller.py} +0 -0
  58. /brainstate/util/{_error.py → error.py} +0 -0
  59. /brainstate/util/{_others.py → others.py} +0 -0
  60. /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  61. /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  62. /brainstate/util/{_scaling.py → scaling.py} +0 -0
  63. /brainstate/util/{_struct.py → struct.py} +0 -0
  64. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
  65. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
  66. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
  A ``State``-based Transformation System for Program Compilation and Augmentation
18
18
  """
19
19
 
20
- __version__ = "0.1.2"
20
+ __version__ = "0.1.4"
21
21
 
22
22
  from . import augment
23
23
  from . import compile
brainstate/_compatible_import.py CHANGED
@@ -16,7 +16,6 @@
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
18
 
19
- import importlib.util
20
19
  from contextlib import contextmanager
21
20
  from functools import partial
22
21
  from typing import Iterable, Hashable, TypeVar, Callable
@@ -31,7 +30,6 @@ __all__ = [
31
30
  'get_aval',
32
31
  'Tracer',
33
32
  'to_concrete_aval',
34
- 'brainevent',
35
33
  'safe_map',
36
34
  'safe_zip',
37
35
  'unzip2',
@@ -45,9 +43,7 @@ T1 = TypeVar("T1")
45
43
  T2 = TypeVar("T2")
46
44
  T3 = TypeVar("T3")
47
45
 
48
-
49
46
  from saiunit._compatible_import import wrap_init
50
- brainevent_installed = importlib.util.find_spec('brainevent') is not None
51
47
 
52
48
  from jax.core import get_aval, Tracer
53
49
 
@@ -150,14 +146,3 @@ def to_concrete_aval(aval):
150
146
  return aval.to_concrete_value()
151
147
  return aval
152
148
 
153
-
154
- if brainevent_installed:
155
- import brainevent
156
- else:
157
-
158
- class BrainEvent:
159
- def __getattr__(self, item):
160
- raise ImportError('brainevent is not installed, please install brainevent first.')
161
-
162
-
163
- brainevent = BrainEvent()
brainstate/compile/_jit.py CHANGED
@@ -51,6 +51,7 @@ def _get_jitted_fun(
51
51
  out_shardings,
52
52
  static_argnums,
53
53
  donate_argnums,
54
+ static_argnames,
54
55
  donate_argnames,
55
56
  keep_unused,
56
57
  device,
@@ -59,10 +60,12 @@ def _get_jitted_fun(
59
60
  abstracted_axes,
60
61
  **kwargs
61
62
  ) -> JittedFunction:
62
- static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
63
+ static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
64
+ donate_argnums = tuple() if donate_argnums is None else _ensure_index_tuple(donate_argnums)
63
65
  fun = StatefulFunction(
64
66
  fun,
65
67
  static_argnums=static_argnums,
68
+ static_argnames=static_argnames,
66
69
  abstracted_axes=abstracted_axes,
67
70
  cache_type='jit',
68
71
  name='jit'
@@ -70,7 +73,8 @@ def _get_jitted_fun(
70
73
  jit_fun = jax.jit(
71
74
  fun.jaxpr_call,
72
75
  static_argnums=tuple(i + 1 for i in static_argnums),
73
- donate_argnums=donate_argnums,
76
+ static_argnames=static_argnames,
77
+ donate_argnums=tuple(i + 1 for i in donate_argnums),
74
78
  donate_argnames=donate_argnames,
75
79
  keep_unused=keep_unused,
76
80
  device=device,
@@ -179,6 +183,7 @@ def jit(
179
183
  out_shardings=sharding_impls.UNSPECIFIED,
180
184
  static_argnums: int | Sequence[int] | None = None,
181
185
  donate_argnums: int | Sequence[int] | None = None,
186
+ static_argnames: str | Sequence[str] | None = None,
182
187
  donate_argnames: str | Iterable[str] | None = None,
183
188
  keep_unused: bool = False,
184
189
  device: Device | None = None,
@@ -190,9 +195,6 @@ def jit(
190
195
  """
191
196
  Sets up ``fun`` for just-in-time compilation with XLA.
192
197
 
193
- Does not support setting ``static_argnames`` as in ``jax.jit()``.
194
-
195
-
196
198
  Args:
197
199
  fun: Function to be jitted.
198
200
  in_shardings: Pytree of structure matching that of arguments to ``fun``,
@@ -246,6 +248,11 @@ def jit(
246
248
  provided, ``inspect.signature`` is not used, and only actual
247
249
  parameters listed in either ``static_argnums`` or ``static_argnames`` will
248
250
  be treated as static.
251
+ static_argnames: An optional string or collection of strings specifying
252
+ which named arguments are treated as static (compile-time constant).
253
+ Operations that only depend on static arguments will be constant-folded in
254
+ Python (during tracing), and so the corresponding argument values can be
255
+ any Python object.
249
256
  donate_argnums: Specify which positional argument buffers are "donated" to
250
257
  the computation. It is safe to donate argument buffers if you no longer
251
258
  need them once the computation has finished. In some cases XLA can make
@@ -309,6 +316,7 @@ def jit(
309
316
  out_shardings=out_shardings,
310
317
  static_argnums=static_argnums,
311
318
  donate_argnums=donate_argnums,
319
+ static_argnames=static_argnames,
312
320
  donate_argnames=donate_argnames,
313
321
  keep_unused=keep_unused,
314
322
  device=device,
@@ -327,6 +335,7 @@ def jit(
327
335
  out_shardings,
328
336
  static_argnums,
329
337
  donate_argnums,
338
+ static_argnames,
330
339
  donate_argnames,
331
340
  keep_unused,
332
341
  device,
brainstate/compile/_make_jaxpr.py CHANGED
@@ -88,6 +88,12 @@ __all__ = [
88
88
  ]
89
89
 
90
90
 
91
+ def _ensure_str(x: str) -> str:
92
+ if not isinstance(x, str):
93
+ raise TypeError(f"argument is not a string: {x}")
94
+ return x
95
+
96
+
91
97
  def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
92
98
  """Convert x to a tuple of indices."""
93
99
  x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
@@ -97,6 +103,14 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
97
103
  return tuple(safe_map(operator.index, x))
98
104
 
99
105
 
106
+ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
107
+ """Convert x to a tuple of strings."""
108
+ if isinstance(x, str):
109
+ return (x,)
110
+ else:
111
+ return tuple(safe_map(_ensure_str, x))
112
+
113
+
100
114
  def _jax_v04_new_arg_fn(frame, trace, aval):
101
115
  """
102
116
  Transform a new argument to a tracer.
@@ -155,6 +169,9 @@ def _init_state_trace_stack(name) -> StateTraceStack:
155
169
  return state_trace
156
170
 
157
171
 
172
+ default_cache_key = ((), ())
173
+
174
+
158
175
  class StatefulFunction(PrettyObject):
159
176
  """
160
177
  A wrapper class for a function that collects the states that are read and written by the function. The states are
@@ -170,6 +187,7 @@ class StatefulFunction(PrettyObject):
170
187
  arguments and return value should be arrays, scalars, or standard Python
171
188
  containers (tuple/list/dict) thereof.
172
189
  static_argnums: See the :py:func:`jax.jit` docstring.
190
+ static_argnames: See the :py:func:`jax.jit` docstring.
173
191
  axis_env: Optional, a sequence of pairs where the first element is an axis
174
192
  name and the second element is a positive integer representing the size of
175
193
  the mapped axis with that name. This parameter is useful when lowering
@@ -199,6 +217,7 @@ class StatefulFunction(PrettyObject):
199
217
  self,
200
218
  fun: Callable,
201
219
  static_argnums: Union[int, Iterable[int]] = (),
220
+ static_argnames: Union[str, Iterable[str]] = (),
202
221
  axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
203
222
  abstracted_axes: Optional[Any] = None,
204
223
  state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
@@ -207,11 +226,12 @@ class StatefulFunction(PrettyObject):
207
226
  ):
208
227
  # explicit parameters
209
228
  self.fun = fun
210
- self.static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
229
+ self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
230
+ self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
211
231
  self.axis_env = axis_env
212
232
  self.abstracted_axes = abstracted_axes
213
233
  self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
214
- assert cache_type in [None, 'jit']
234
+ assert cache_type in [None, 'jit'], f"Invalid cache type: {cache_type}"
215
235
  self.name = name
216
236
 
217
237
  # implicit parameters
@@ -226,7 +246,7 @@ class StatefulFunction(PrettyObject):
226
246
  return None
227
247
  return k, v
228
248
 
229
- def get_jaxpr(self, cache_key: Hashable = ()) -> ClosedJaxpr:
249
+ def get_jaxpr(self, cache_key: Hashable = None) -> ClosedJaxpr:
230
250
  """
231
251
  Read the JAX Jaxpr representation of the function.
232
252
 
@@ -236,11 +256,13 @@ class StatefulFunction(PrettyObject):
236
256
  Returns:
237
257
  The JAX Jaxpr representation of the function.
238
258
  """
259
+ if cache_key is None:
260
+ cache_key = default_cache_key
239
261
  if cache_key not in self._cached_jaxpr:
240
262
  raise ValueError(f"the function is not called with the static arguments: {cache_key}")
241
263
  return self._cached_jaxpr[cache_key]
242
264
 
243
- def get_out_shapes(self, cache_key: Hashable = ()) -> PyTree:
265
+ def get_out_shapes(self, cache_key: Hashable = None) -> PyTree:
244
266
  """
245
267
  Read the output shapes of the function.
246
268
 
@@ -250,11 +272,13 @@ class StatefulFunction(PrettyObject):
250
272
  Returns:
251
273
  The output shapes of the function.
252
274
  """
275
+ if cache_key is None:
276
+ cache_key = default_cache_key
253
277
  if cache_key not in self._cached_out_shapes:
254
278
  raise ValueError(f"the function is not called with the static arguments: {cache_key}")
255
279
  return self._cached_out_shapes[cache_key]
256
280
 
257
- def get_out_treedef(self, cache_key: Hashable = ()) -> PyTree:
281
+ def get_out_treedef(self, cache_key: Hashable = None) -> PyTree:
258
282
  """
259
283
  Read the output tree of the function.
260
284
 
@@ -264,11 +288,13 @@ class StatefulFunction(PrettyObject):
264
288
  Returns:
265
289
  The output tree of the function.
266
290
  """
291
+ if cache_key is None:
292
+ cache_key = default_cache_key
267
293
  if cache_key not in self._cached_jaxpr_out_tree:
268
294
  raise ValueError(f"the function is not called with the static arguments: {cache_key}")
269
295
  return self._cached_jaxpr_out_tree[cache_key]
270
296
 
271
- def get_state_trace(self, cache_key: Hashable = ()) -> StateTraceStack:
297
+ def get_state_trace(self, cache_key: Hashable = None) -> StateTraceStack:
272
298
  """
273
299
  Read the state trace of the function.
274
300
 
@@ -278,11 +304,13 @@ class StatefulFunction(PrettyObject):
278
304
  Returns:
279
305
  The state trace of the function.
280
306
  """
307
+ if cache_key is None:
308
+ cache_key = default_cache_key
281
309
  if cache_key not in self._cached_state_trace:
282
310
  raise ValueError(f"the function is not called with the static arguments: {cache_key}")
283
311
  return self._cached_state_trace[cache_key]
284
312
 
285
- def get_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
313
+ def get_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
286
314
  """
287
315
  Read the states that are read and written by the function.
288
316
 
@@ -292,9 +320,11 @@ class StatefulFunction(PrettyObject):
292
320
  Returns:
293
321
  The states that are read and written by the function.
294
322
  """
323
+ if cache_key is None:
324
+ cache_key = default_cache_key
295
325
  return tuple(self.get_state_trace(cache_key).states)
296
326
 
297
- def get_read_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
327
+ def get_read_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
298
328
  """
299
329
  Read the states that are read by the function.
300
330
 
@@ -304,9 +334,11 @@ class StatefulFunction(PrettyObject):
304
334
  Returns:
305
335
  The states that are read by the function.
306
336
  """
337
+ if cache_key is None:
338
+ cache_key = default_cache_key
307
339
  return self.get_state_trace(cache_key).get_read_states()
308
340
 
309
- def get_write_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
341
+ def get_write_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
310
342
  """
311
343
  Read the states that are written by the function.
312
344
 
@@ -316,6 +348,8 @@ class StatefulFunction(PrettyObject):
316
348
  Returns:
317
349
  The states that are written by the function.
318
350
  """
351
+ if cache_key is None:
352
+ cache_key = default_cache_key
319
353
  return self.get_state_trace(cache_key).get_write_states()
320
354
 
321
355
  def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
@@ -323,10 +357,11 @@ class StatefulFunction(PrettyObject):
323
357
  Get the static arguments from the arguments.
324
358
 
325
359
  Args:
326
- *args: The arguments to the function.
360
+ *args: The arguments to the function.
361
+ **kwargs: The keyword arguments to the function.
327
362
 
328
363
  Returns:
329
- The static arguments.
364
+ The static arguments and keyword arguments as a tuple.
330
365
  """
331
366
  if self.cache_type == 'jit':
332
367
  static_args, dyn_args = [], []
@@ -336,11 +371,18 @@ class StatefulFunction(PrettyObject):
336
371
  else:
337
372
  dyn_args.append(arg)
338
373
  dyn_args = jax.tree.map(shaped_abstractify, jax.tree.leaves(dyn_args))
339
- dyn_kwargs = jax.tree.map(shaped_abstractify, jax.tree.leaves(kwargs))
340
- return tuple([tuple(static_args), tuple(dyn_args), tuple(dyn_kwargs)])
374
+ static_kwargs, dyn_kwargs = [], []
375
+ for k, v in kwargs.items():
376
+ if k in self.static_argnames:
377
+ static_kwargs.append((k, v))
378
+ else:
379
+ dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
380
+ return tuple([tuple(static_args), tuple(dyn_args), tuple(static_kwargs), tuple(dyn_kwargs)])
341
381
  elif self.cache_type is None:
342
382
  num_arg = len(args)
343
- return tuple(args[i] for i in self.static_argnums if i < num_arg)
383
+ static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
384
+ static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
385
+ return tuple([static_args, static_kwargs])
344
386
  else:
345
387
  raise ValueError(f"Invalid cache type: {self.cache_type}")
346
388
 
@@ -389,7 +431,7 @@ class StatefulFunction(PrettyObject):
389
431
  self._cached_state_trace.clear()
390
432
 
391
433
  def _wrapped_fun_to_eval(
392
- self, cache_key, *args, return_only_write: bool = False, **kwargs,
434
+ self, cache_key, static_kwargs: dict, *args, return_only_write: bool = False, **dyn_kwargs,
393
435
  ) -> Tuple[Any, Tuple[State, ...]]:
394
436
  """
395
437
  Wrap the function and return the states that are read and written by the function and the output of the function.
@@ -405,7 +447,7 @@ class StatefulFunction(PrettyObject):
405
447
  state_trace = _init_state_trace_stack(self.name)
406
448
  self._cached_state_trace[cache_key] = state_trace
407
449
  with state_trace:
408
- out = self.fun(*args, **kwargs)
450
+ out = self.fun(*args, **dyn_kwargs, **static_kwargs)
409
451
  state_values = (
410
452
  state_trace.get_write_state_values(True)
411
453
  if return_only_write else
@@ -430,8 +472,9 @@ class StatefulFunction(PrettyObject):
430
472
  the structure, shape, dtypes, and named shapes of the output of ``fun``.
431
473
 
432
474
  Args:
433
- *args: The arguments to the function.
434
- **kwargs: The keyword arguments to the function.
475
+ *args: The arguments to the function.
476
+ **kwargs: The keyword arguments to the function.
477
+ return_only_write: If True, only return the states that are written by the function.
435
478
  """
436
479
 
437
480
  # static args
@@ -440,17 +483,24 @@ class StatefulFunction(PrettyObject):
440
483
  if cache_key not in self._cached_state_trace:
441
484
  try:
442
485
  # jaxpr
486
+ static_kwargs, dyn_kwargs = {}, {}
487
+ for k, v in kwargs.items():
488
+ if k in self.static_argnames:
489
+ static_kwargs[k] = v
490
+ else:
491
+ dyn_kwargs[k] = v
443
492
  jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
444
493
  functools.partial(
445
494
  self._wrapped_fun_to_eval,
446
495
  cache_key,
496
+ static_kwargs,
447
497
  return_only_write=return_only_write
448
498
  ),
449
499
  static_argnums=self.static_argnums,
450
500
  axis_env=self.axis_env,
451
501
  return_shape=True,
452
502
  abstracted_axes=self.abstracted_axes
453
- )(*args, **kwargs)
503
+ )(*args, **dyn_kwargs)
454
504
  # returns
455
505
  self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
456
506
  self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
@@ -483,6 +533,7 @@ class StatefulFunction(PrettyObject):
483
533
  assert len(state_vals) == len(states), 'State length mismatch.'
484
534
 
485
535
  # parameters
536
+ kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames} # remove static kwargs
486
537
  args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
487
538
  args = jax.tree.flatten((args, kwargs, state_vals))[0]
488
539
 
@@ -519,12 +570,16 @@ class StatefulFunction(PrettyObject):
519
570
  def make_jaxpr(
520
571
  fun: Callable,
521
572
  static_argnums: Union[int, Iterable[int]] = (),
573
+ static_argnames: Union[str, Iterable[str]] = (),
522
574
  axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
523
575
  return_shape: bool = False,
524
576
  abstracted_axes: Optional[Any] = None,
525
577
  state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
526
- ) -> Callable[..., (Tuple[ClosedJaxpr, Tuple[State, ...]] |
527
- Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])]:
578
+ ) -> Callable[
579
+ ...,
580
+ (Tuple[ClosedJaxpr, Tuple[State, ...]] |
581
+ Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
582
+ ]:
528
583
  """
529
584
  Creates a function that produces its jaxpr given example args.
530
585
 
@@ -533,6 +588,7 @@ def make_jaxpr(
533
588
  arguments and return value should be arrays, scalars, or standard Python
534
589
  containers (tuple/list/dict) thereof.
535
590
  static_argnums: See the :py:func:`jax.jit` docstring.
591
+ static_argnames: See the :py:func:`jax.jit` docstring.
536
592
  axis_env: Optional, a sequence of pairs where the first element is an axis
537
593
  name and the second element is a positive integer representing the size of
538
594
  the mapped axis with that name. This parameter is useful when lowering
@@ -605,11 +661,11 @@ def make_jaxpr(
605
661
  stateful_fun = StatefulFunction(
606
662
  fun,
607
663
  static_argnums=static_argnums,
664
+ static_argnames=static_argnames,
608
665
  axis_env=axis_env,
609
666
  abstracted_axes=abstracted_axes,
610
667
  state_returns=state_returns,
611
668
  name='make_jaxpr'
612
-
613
669
  )
614
670
 
615
671
  @wraps(fun)
brainstate/compile/_make_jaxpr_test.py CHANGED
@@ -88,7 +88,7 @@ class TestMakeJaxpr(unittest.TestCase):
88
88
  self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
89
89
  f3(jnp.zeros(1))))
90
90
 
91
- def test_compar_jax_make_jaxpr2(self):
91
+ def test_compare_jax_make_jaxpr2(self):
92
92
  st1 = brainstate.State(jnp.ones(10))
93
93
 
94
94
  def fa(x):
@@ -108,7 +108,7 @@ class TestMakeJaxpr(unittest.TestCase):
108
108
  print(jaxpr)
109
109
  print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
110
110
 
111
- def test_compar_jax_make_jaxpr3(self):
111
+ def test_compare_jax_make_jaxpr3(self):
112
112
  def fa(x):
113
113
  return 1.
114
114
 
@@ -121,6 +121,17 @@ class TestMakeJaxpr(unittest.TestCase):
121
121
  print(jaxpr)
122
122
  # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
123
123
 
124
+ def test_static_argnames(self):
125
+ def func4(a, b): # Arg is a pair
126
+ temp = a + jnp.sin(b) * 3.
127
+ c = brainstate.random.rand_like(a)
128
+ return jnp.sum(temp + c)
129
+
130
+ jaxpr, states = brainstate.compile.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
131
+ print()
132
+ print(jaxpr)
133
+ print(states)
134
+
124
135
 
125
136
  def test_return_states():
126
137
  import jax.numpy
brainstate/graph/_graph_node.py CHANGED
@@ -25,7 +25,7 @@ import numpy as np
25
25
 
26
26
  from brainstate._state import State, TreefyState
27
27
  from brainstate.typing import Key
28
- from brainstate.util._pretty_pytree import PrettyObject
28
+ from brainstate.util.pretty_pytree import PrettyObject
29
29
  from ._graph_operation import register_graph_node_type
30
30
 
31
31
  __all__ = [
brainstate/graph/_graph_operation.py CHANGED
@@ -30,10 +30,10 @@ from typing_extensions import TypeGuard, Unpack
30
30
  from brainstate._state import State, TreefyState
31
31
  from brainstate._utils import set_module_as
32
32
  from brainstate.typing import PathParts, Filter, Predicate, Key
33
- from brainstate.util._caller import ApplyCaller, CallableProxy, DelayedAccessor
34
- from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
35
- from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
36
- from brainstate.util._struct import FrozenDict
33
+ from brainstate.util.caller import ApplyCaller, CallableProxy, DelayedAccessor
34
+ from brainstate.util.pretty_pytree import NestedDict, FlattedDict, PrettyDict
35
+ from brainstate.util.pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
36
+ from brainstate.util.struct import FrozenDict
37
37
  from brainstate.util.filter import to_predicate
38
38
 
39
39
  _max_int = np.iinfo(np.int32).max
brainstate/mixin.py CHANGED
@@ -41,6 +41,14 @@ __all__ = [
41
41
  ]
42
42
 
43
43
 
44
+ def hashable(x):
45
+ try:
46
+ hash(x)
47
+ return True
48
+ except TypeError:
49
+ return False
50
+
51
+
44
52
  class Mixin(object):
45
53
  """Base Mixin object.
46
54
 
@@ -67,6 +75,14 @@ class ParamDesc(Mixin):
67
75
 
68
76
 
69
77
  class HashableDict(dict):
78
+ def __init__(self, the_dict: dict):
79
+ out = dict()
80
+ for k, v in the_dict.items():
81
+ if not hashable(v):
82
+ v = str(v) # convert to string if not hashable
83
+ out[k] = v
84
+ super().__init__(out)
85
+
70
86
  def __hash__(self):
71
87
  return hash(tuple(sorted(self.items())))
72
88
 
@@ -146,29 +162,29 @@ class BindCondData(Mixin):
146
162
  self._conductance = None
147
163
 
148
164
 
149
- class UpdateReturn(Mixin):
165
+ def not_implemented(func):
150
166
 
151
- def update_return(self) -> PyTree:
152
- """
153
- The update function return of the model.
167
+ def wrapper(*args, **kwargs):
168
+ raise NotImplementedError(f'{func.__name__} is not implemented.')
154
169
 
155
- It should be a pytree, with each element as a ``jax.ShapeDtypeStruct`` or ``jax.core.ShapedArray``.
170
+ wrapper.not_implemented = True
171
+ return wrapper
156
172
 
157
- """
158
- raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
159
173
 
160
- def update_return_info(self) -> PyTree:
174
+
175
+ class UpdateReturn(Mixin):
176
+ @not_implemented
177
+ def update_return(self) -> PyTree:
161
178
  """
162
- The update return information of the model.
179
+ The update function return of the model.
163
180
 
164
- It should be a pytree, with each element as a ``jax.Array``.
181
+ This function requires no parameters and must return a PyTree.
165
182
 
166
- .. note::
167
- Should not include the batch axis and batch in_size.
168
- These information will be inferred from the ``mode`` attribute.
183
+ It is usually used for delay initialization, for example, ``Dynamics.output_delay`` relies on this function to
184
+ initialize the output delay.
169
185
 
170
186
  """
171
- raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
187
+ raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
172
188
 
173
189
 
174
190
  class _MetaUnionType(type):