brainstate 0.1.3__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +1 -16
- brainstate/compile/_jit.py +14 -5
- brainstate/compile/_make_jaxpr.py +78 -22
- brainstate/compile/_make_jaxpr_test.py +13 -2
- brainstate/graph/_graph_node.py +1 -1
- brainstate/graph/_graph_operation.py +4 -4
- brainstate/mixin.py +31 -2
- brainstate/nn/__init__.py +8 -5
- brainstate/nn/_delay.py +13 -1
- brainstate/nn/_dropout.py +5 -4
- brainstate/nn/_dynamics.py +39 -44
- brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
- brainstate/nn/_linear_mv.py +1 -1
- brainstate/nn/_module.py +5 -5
- brainstate/nn/_projection.py +190 -98
- brainstate/nn/_synapse.py +5 -9
- brainstate/nn/_synaptic_projection.py +376 -86
- brainstate/surrogate.py +1 -1
- brainstate/typing.py +1 -1
- brainstate/util/__init__.py +14 -14
- brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/RECORD +35 -35
- /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
- /brainstate/util/{_caller.py → caller.py} +0 -0
- /brainstate/util/{_error.py → error.py} +0 -0
- /brainstate/util/{_others.py → others.py} +0 -0
- /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
- /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
- /brainstate/util/{_scaling.py → scaling.py} +0 -0
- /brainstate/util/{_struct.py → struct.py} +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_compatible_import.py
CHANGED
@@ -16,10 +16,9 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
18
|
|
19
|
-
import importlib.util
|
20
19
|
from contextlib import contextmanager
|
21
20
|
from functools import partial
|
22
|
-
from typing import Iterable, Hashable, TypeVar, Callable
|
21
|
+
from typing import Iterable, Hashable, TypeVar, Callable
|
23
22
|
|
24
23
|
import jax
|
25
24
|
|
@@ -31,7 +30,6 @@ __all__ = [
|
|
31
30
|
'get_aval',
|
32
31
|
'Tracer',
|
33
32
|
'to_concrete_aval',
|
34
|
-
'brainevent',
|
35
33
|
'safe_map',
|
36
34
|
'safe_zip',
|
37
35
|
'unzip2',
|
@@ -47,8 +45,6 @@ T3 = TypeVar("T3")
|
|
47
45
|
|
48
46
|
from saiunit._compatible_import import wrap_init
|
49
47
|
|
50
|
-
brainevent_installed = importlib.util.find_spec('brainevent') is not None
|
51
|
-
|
52
48
|
from jax.core import get_aval, Tracer
|
53
49
|
|
54
50
|
if jax.__version_info__ < (0, 5, 0):
|
@@ -150,14 +146,3 @@ def to_concrete_aval(aval):
|
|
150
146
|
return aval.to_concrete_value()
|
151
147
|
return aval
|
152
148
|
|
153
|
-
|
154
|
-
if not brainevent_installed:
|
155
|
-
if not TYPE_CHECKING:
|
156
|
-
class BrainEvent:
|
157
|
-
def __getattr__(self, item):
|
158
|
-
raise ImportError('brainevent is not installed, please install brainevent first.')
|
159
|
-
|
160
|
-
brainevent = BrainEvent()
|
161
|
-
|
162
|
-
else:
|
163
|
-
import brainevent
|
brainstate/compile/_jit.py
CHANGED
@@ -51,6 +51,7 @@ def _get_jitted_fun(
|
|
51
51
|
out_shardings,
|
52
52
|
static_argnums,
|
53
53
|
donate_argnums,
|
54
|
+
static_argnames,
|
54
55
|
donate_argnames,
|
55
56
|
keep_unused,
|
56
57
|
device,
|
@@ -59,10 +60,12 @@ def _get_jitted_fun(
|
|
59
60
|
abstracted_axes,
|
60
61
|
**kwargs
|
61
62
|
) -> JittedFunction:
|
62
|
-
static_argnums =
|
63
|
+
static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
|
64
|
+
donate_argnums = tuple() if donate_argnums is None else _ensure_index_tuple(donate_argnums)
|
63
65
|
fun = StatefulFunction(
|
64
66
|
fun,
|
65
67
|
static_argnums=static_argnums,
|
68
|
+
static_argnames=static_argnames,
|
66
69
|
abstracted_axes=abstracted_axes,
|
67
70
|
cache_type='jit',
|
68
71
|
name='jit'
|
@@ -70,7 +73,8 @@ def _get_jitted_fun(
|
|
70
73
|
jit_fun = jax.jit(
|
71
74
|
fun.jaxpr_call,
|
72
75
|
static_argnums=tuple(i + 1 for i in static_argnums),
|
73
|
-
|
76
|
+
static_argnames=static_argnames,
|
77
|
+
donate_argnums=tuple(i + 1 for i in donate_argnums),
|
74
78
|
donate_argnames=donate_argnames,
|
75
79
|
keep_unused=keep_unused,
|
76
80
|
device=device,
|
@@ -179,6 +183,7 @@ def jit(
|
|
179
183
|
out_shardings=sharding_impls.UNSPECIFIED,
|
180
184
|
static_argnums: int | Sequence[int] | None = None,
|
181
185
|
donate_argnums: int | Sequence[int] | None = None,
|
186
|
+
static_argnames: str | Sequence[str] | None = None,
|
182
187
|
donate_argnames: str | Iterable[str] | None = None,
|
183
188
|
keep_unused: bool = False,
|
184
189
|
device: Device | None = None,
|
@@ -190,9 +195,6 @@ def jit(
|
|
190
195
|
"""
|
191
196
|
Sets up ``fun`` for just-in-time compilation with XLA.
|
192
197
|
|
193
|
-
Does not support setting ``static_argnames`` as in ``jax.jit()``.
|
194
|
-
|
195
|
-
|
196
198
|
Args:
|
197
199
|
fun: Function to be jitted.
|
198
200
|
in_shardings: Pytree of structure matching that of arguments to ``fun``,
|
@@ -246,6 +248,11 @@ def jit(
|
|
246
248
|
provided, ``inspect.signature`` is not used, and only actual
|
247
249
|
parameters listed in either ``static_argnums`` or ``static_argnames`` will
|
248
250
|
be treated as static.
|
251
|
+
static_argnames: An optional string or collection of strings specifying
|
252
|
+
which named arguments are treated as static (compile-time constant).
|
253
|
+
Operations that only depend on static arguments will be constant-folded in
|
254
|
+
Python (during tracing), and so the corresponding argument values can be
|
255
|
+
any Python object.
|
249
256
|
donate_argnums: Specify which positional argument buffers are "donated" to
|
250
257
|
the computation. It is safe to donate argument buffers if you no longer
|
251
258
|
need them once the computation has finished. In some cases XLA can make
|
@@ -309,6 +316,7 @@ def jit(
|
|
309
316
|
out_shardings=out_shardings,
|
310
317
|
static_argnums=static_argnums,
|
311
318
|
donate_argnums=donate_argnums,
|
319
|
+
static_argnames=static_argnames,
|
312
320
|
donate_argnames=donate_argnames,
|
313
321
|
keep_unused=keep_unused,
|
314
322
|
device=device,
|
@@ -327,6 +335,7 @@ def jit(
|
|
327
335
|
out_shardings,
|
328
336
|
static_argnums,
|
329
337
|
donate_argnums,
|
338
|
+
static_argnames,
|
330
339
|
donate_argnames,
|
331
340
|
keep_unused,
|
332
341
|
device,
|
@@ -88,6 +88,12 @@ __all__ = [
|
|
88
88
|
]
|
89
89
|
|
90
90
|
|
91
|
+
def _ensure_str(x: str) -> str:
|
92
|
+
if not isinstance(x, str):
|
93
|
+
raise TypeError(f"argument is not a string: {x}")
|
94
|
+
return x
|
95
|
+
|
96
|
+
|
91
97
|
def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
|
92
98
|
"""Convert x to a tuple of indices."""
|
93
99
|
x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
|
@@ -97,6 +103,14 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
|
|
97
103
|
return tuple(safe_map(operator.index, x))
|
98
104
|
|
99
105
|
|
106
|
+
def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
|
107
|
+
"""Convert x to a tuple of strings."""
|
108
|
+
if isinstance(x, str):
|
109
|
+
return (x,)
|
110
|
+
else:
|
111
|
+
return tuple(safe_map(_ensure_str, x))
|
112
|
+
|
113
|
+
|
100
114
|
def _jax_v04_new_arg_fn(frame, trace, aval):
|
101
115
|
"""
|
102
116
|
Transform a new argument to a tracer.
|
@@ -155,6 +169,9 @@ def _init_state_trace_stack(name) -> StateTraceStack:
|
|
155
169
|
return state_trace
|
156
170
|
|
157
171
|
|
172
|
+
default_cache_key = ((), ())
|
173
|
+
|
174
|
+
|
158
175
|
class StatefulFunction(PrettyObject):
|
159
176
|
"""
|
160
177
|
A wrapper class for a function that collects the states that are read and written by the function. The states are
|
@@ -170,6 +187,7 @@ class StatefulFunction(PrettyObject):
|
|
170
187
|
arguments and return value should be arrays, scalars, or standard Python
|
171
188
|
containers (tuple/list/dict) thereof.
|
172
189
|
static_argnums: See the :py:func:`jax.jit` docstring.
|
190
|
+
static_argnames: See the :py:func:`jax.jit` docstring.
|
173
191
|
axis_env: Optional, a sequence of pairs where the first element is an axis
|
174
192
|
name and the second element is a positive integer representing the size of
|
175
193
|
the mapped axis with that name. This parameter is useful when lowering
|
@@ -199,6 +217,7 @@ class StatefulFunction(PrettyObject):
|
|
199
217
|
self,
|
200
218
|
fun: Callable,
|
201
219
|
static_argnums: Union[int, Iterable[int]] = (),
|
220
|
+
static_argnames: Union[str, Iterable[str]] = (),
|
202
221
|
axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
|
203
222
|
abstracted_axes: Optional[Any] = None,
|
204
223
|
state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
|
@@ -207,11 +226,12 @@ class StatefulFunction(PrettyObject):
|
|
207
226
|
):
|
208
227
|
# explicit parameters
|
209
228
|
self.fun = fun
|
210
|
-
self.static_argnums =
|
229
|
+
self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
|
230
|
+
self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
|
211
231
|
self.axis_env = axis_env
|
212
232
|
self.abstracted_axes = abstracted_axes
|
213
233
|
self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
|
214
|
-
assert cache_type in [None, 'jit']
|
234
|
+
assert cache_type in [None, 'jit'], f"Invalid cache type: {cache_type}"
|
215
235
|
self.name = name
|
216
236
|
|
217
237
|
# implicit parameters
|
@@ -226,7 +246,7 @@ class StatefulFunction(PrettyObject):
|
|
226
246
|
return None
|
227
247
|
return k, v
|
228
248
|
|
229
|
-
def get_jaxpr(self, cache_key: Hashable =
|
249
|
+
def get_jaxpr(self, cache_key: Hashable = None) -> ClosedJaxpr:
|
230
250
|
"""
|
231
251
|
Read the JAX Jaxpr representation of the function.
|
232
252
|
|
@@ -236,11 +256,13 @@ class StatefulFunction(PrettyObject):
|
|
236
256
|
Returns:
|
237
257
|
The JAX Jaxpr representation of the function.
|
238
258
|
"""
|
259
|
+
if cache_key is None:
|
260
|
+
cache_key = default_cache_key
|
239
261
|
if cache_key not in self._cached_jaxpr:
|
240
262
|
raise ValueError(f"the function is not called with the static arguments: {cache_key}")
|
241
263
|
return self._cached_jaxpr[cache_key]
|
242
264
|
|
243
|
-
def get_out_shapes(self, cache_key: Hashable =
|
265
|
+
def get_out_shapes(self, cache_key: Hashable = None) -> PyTree:
|
244
266
|
"""
|
245
267
|
Read the output shapes of the function.
|
246
268
|
|
@@ -250,11 +272,13 @@ class StatefulFunction(PrettyObject):
|
|
250
272
|
Returns:
|
251
273
|
The output shapes of the function.
|
252
274
|
"""
|
275
|
+
if cache_key is None:
|
276
|
+
cache_key = default_cache_key
|
253
277
|
if cache_key not in self._cached_out_shapes:
|
254
278
|
raise ValueError(f"the function is not called with the static arguments: {cache_key}")
|
255
279
|
return self._cached_out_shapes[cache_key]
|
256
280
|
|
257
|
-
def get_out_treedef(self, cache_key: Hashable =
|
281
|
+
def get_out_treedef(self, cache_key: Hashable = None) -> PyTree:
|
258
282
|
"""
|
259
283
|
Read the output tree of the function.
|
260
284
|
|
@@ -264,11 +288,13 @@ class StatefulFunction(PrettyObject):
|
|
264
288
|
Returns:
|
265
289
|
The output tree of the function.
|
266
290
|
"""
|
291
|
+
if cache_key is None:
|
292
|
+
cache_key = default_cache_key
|
267
293
|
if cache_key not in self._cached_jaxpr_out_tree:
|
268
294
|
raise ValueError(f"the function is not called with the static arguments: {cache_key}")
|
269
295
|
return self._cached_jaxpr_out_tree[cache_key]
|
270
296
|
|
271
|
-
def get_state_trace(self, cache_key: Hashable =
|
297
|
+
def get_state_trace(self, cache_key: Hashable = None) -> StateTraceStack:
|
272
298
|
"""
|
273
299
|
Read the state trace of the function.
|
274
300
|
|
@@ -278,11 +304,13 @@ class StatefulFunction(PrettyObject):
|
|
278
304
|
Returns:
|
279
305
|
The state trace of the function.
|
280
306
|
"""
|
307
|
+
if cache_key is None:
|
308
|
+
cache_key = default_cache_key
|
281
309
|
if cache_key not in self._cached_state_trace:
|
282
310
|
raise ValueError(f"the function is not called with the static arguments: {cache_key}")
|
283
311
|
return self._cached_state_trace[cache_key]
|
284
312
|
|
285
|
-
def get_states(self, cache_key: Hashable =
|
313
|
+
def get_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
|
286
314
|
"""
|
287
315
|
Read the states that are read and written by the function.
|
288
316
|
|
@@ -292,9 +320,11 @@ class StatefulFunction(PrettyObject):
|
|
292
320
|
Returns:
|
293
321
|
The states that are read and written by the function.
|
294
322
|
"""
|
323
|
+
if cache_key is None:
|
324
|
+
cache_key = default_cache_key
|
295
325
|
return tuple(self.get_state_trace(cache_key).states)
|
296
326
|
|
297
|
-
def get_read_states(self, cache_key: Hashable =
|
327
|
+
def get_read_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
|
298
328
|
"""
|
299
329
|
Read the states that are read by the function.
|
300
330
|
|
@@ -304,9 +334,11 @@ class StatefulFunction(PrettyObject):
|
|
304
334
|
Returns:
|
305
335
|
The states that are read by the function.
|
306
336
|
"""
|
337
|
+
if cache_key is None:
|
338
|
+
cache_key = default_cache_key
|
307
339
|
return self.get_state_trace(cache_key).get_read_states()
|
308
340
|
|
309
|
-
def get_write_states(self, cache_key: Hashable =
|
341
|
+
def get_write_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
|
310
342
|
"""
|
311
343
|
Read the states that are written by the function.
|
312
344
|
|
@@ -316,6 +348,8 @@ class StatefulFunction(PrettyObject):
|
|
316
348
|
Returns:
|
317
349
|
The states that are written by the function.
|
318
350
|
"""
|
351
|
+
if cache_key is None:
|
352
|
+
cache_key = default_cache_key
|
319
353
|
return self.get_state_trace(cache_key).get_write_states()
|
320
354
|
|
321
355
|
def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
|
@@ -323,10 +357,11 @@ class StatefulFunction(PrettyObject):
|
|
323
357
|
Get the static arguments from the arguments.
|
324
358
|
|
325
359
|
Args:
|
326
|
-
|
360
|
+
*args: The arguments to the function.
|
361
|
+
**kwargs: The keyword arguments to the function.
|
327
362
|
|
328
363
|
Returns:
|
329
|
-
The static arguments.
|
364
|
+
The static arguments and keyword arguments as a tuple.
|
330
365
|
"""
|
331
366
|
if self.cache_type == 'jit':
|
332
367
|
static_args, dyn_args = [], []
|
@@ -336,11 +371,18 @@ class StatefulFunction(PrettyObject):
|
|
336
371
|
else:
|
337
372
|
dyn_args.append(arg)
|
338
373
|
dyn_args = jax.tree.map(shaped_abstractify, jax.tree.leaves(dyn_args))
|
339
|
-
dyn_kwargs =
|
340
|
-
|
374
|
+
static_kwargs, dyn_kwargs = [], []
|
375
|
+
for k, v in kwargs.items():
|
376
|
+
if k in self.static_argnames:
|
377
|
+
static_kwargs.append((k, v))
|
378
|
+
else:
|
379
|
+
dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
|
380
|
+
return tuple([tuple(static_args), tuple(dyn_args), tuple(static_kwargs), tuple(dyn_kwargs)])
|
341
381
|
elif self.cache_type is None:
|
342
382
|
num_arg = len(args)
|
343
|
-
|
383
|
+
static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
|
384
|
+
static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
|
385
|
+
return tuple([static_args, static_kwargs])
|
344
386
|
else:
|
345
387
|
raise ValueError(f"Invalid cache type: {self.cache_type}")
|
346
388
|
|
@@ -389,7 +431,7 @@ class StatefulFunction(PrettyObject):
|
|
389
431
|
self._cached_state_trace.clear()
|
390
432
|
|
391
433
|
def _wrapped_fun_to_eval(
|
392
|
-
self, cache_key, *args, return_only_write: bool = False, **
|
434
|
+
self, cache_key, static_kwargs: dict, *args, return_only_write: bool = False, **dyn_kwargs,
|
393
435
|
) -> Tuple[Any, Tuple[State, ...]]:
|
394
436
|
"""
|
395
437
|
Wrap the function and return the states that are read and written by the function and the output of the function.
|
@@ -405,7 +447,7 @@ class StatefulFunction(PrettyObject):
|
|
405
447
|
state_trace = _init_state_trace_stack(self.name)
|
406
448
|
self._cached_state_trace[cache_key] = state_trace
|
407
449
|
with state_trace:
|
408
|
-
out = self.fun(*args, **
|
450
|
+
out = self.fun(*args, **dyn_kwargs, **static_kwargs)
|
409
451
|
state_values = (
|
410
452
|
state_trace.get_write_state_values(True)
|
411
453
|
if return_only_write else
|
@@ -430,8 +472,9 @@ class StatefulFunction(PrettyObject):
|
|
430
472
|
the structure, shape, dtypes, and named shapes of the output of ``fun``.
|
431
473
|
|
432
474
|
Args:
|
433
|
-
|
434
|
-
|
475
|
+
*args: The arguments to the function.
|
476
|
+
**kwargs: The keyword arguments to the function.
|
477
|
+
return_only_write: If True, only return the states that are written by the function.
|
435
478
|
"""
|
436
479
|
|
437
480
|
# static args
|
@@ -440,17 +483,24 @@ class StatefulFunction(PrettyObject):
|
|
440
483
|
if cache_key not in self._cached_state_trace:
|
441
484
|
try:
|
442
485
|
# jaxpr
|
486
|
+
static_kwargs, dyn_kwargs = {}, {}
|
487
|
+
for k, v in kwargs.items():
|
488
|
+
if k in self.static_argnames:
|
489
|
+
static_kwargs[k] = v
|
490
|
+
else:
|
491
|
+
dyn_kwargs[k] = v
|
443
492
|
jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
|
444
493
|
functools.partial(
|
445
494
|
self._wrapped_fun_to_eval,
|
446
495
|
cache_key,
|
496
|
+
static_kwargs,
|
447
497
|
return_only_write=return_only_write
|
448
498
|
),
|
449
499
|
static_argnums=self.static_argnums,
|
450
500
|
axis_env=self.axis_env,
|
451
501
|
return_shape=True,
|
452
502
|
abstracted_axes=self.abstracted_axes
|
453
|
-
)(*args, **
|
503
|
+
)(*args, **dyn_kwargs)
|
454
504
|
# returns
|
455
505
|
self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
|
456
506
|
self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
|
@@ -483,6 +533,7 @@ class StatefulFunction(PrettyObject):
|
|
483
533
|
assert len(state_vals) == len(states), 'State length mismatch.'
|
484
534
|
|
485
535
|
# parameters
|
536
|
+
kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames} # remove static kwargs
|
486
537
|
args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
|
487
538
|
args = jax.tree.flatten((args, kwargs, state_vals))[0]
|
488
539
|
|
@@ -519,12 +570,16 @@ class StatefulFunction(PrettyObject):
|
|
519
570
|
def make_jaxpr(
|
520
571
|
fun: Callable,
|
521
572
|
static_argnums: Union[int, Iterable[int]] = (),
|
573
|
+
static_argnames: Union[str, Iterable[str]] = (),
|
522
574
|
axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
|
523
575
|
return_shape: bool = False,
|
524
576
|
abstracted_axes: Optional[Any] = None,
|
525
577
|
state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
|
526
|
-
) -> Callable[
|
527
|
-
|
578
|
+
) -> Callable[
|
579
|
+
...,
|
580
|
+
(Tuple[ClosedJaxpr, Tuple[State, ...]] |
|
581
|
+
Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
|
582
|
+
]:
|
528
583
|
"""
|
529
584
|
Creates a function that produces its jaxpr given example args.
|
530
585
|
|
@@ -533,6 +588,7 @@ def make_jaxpr(
|
|
533
588
|
arguments and return value should be arrays, scalars, or standard Python
|
534
589
|
containers (tuple/list/dict) thereof.
|
535
590
|
static_argnums: See the :py:func:`jax.jit` docstring.
|
591
|
+
static_argnames: See the :py:func:`jax.jit` docstring.
|
536
592
|
axis_env: Optional, a sequence of pairs where the first element is an axis
|
537
593
|
name and the second element is a positive integer representing the size of
|
538
594
|
the mapped axis with that name. This parameter is useful when lowering
|
@@ -605,11 +661,11 @@ def make_jaxpr(
|
|
605
661
|
stateful_fun = StatefulFunction(
|
606
662
|
fun,
|
607
663
|
static_argnums=static_argnums,
|
664
|
+
static_argnames=static_argnames,
|
608
665
|
axis_env=axis_env,
|
609
666
|
abstracted_axes=abstracted_axes,
|
610
667
|
state_returns=state_returns,
|
611
668
|
name='make_jaxpr'
|
612
|
-
|
613
669
|
)
|
614
670
|
|
615
671
|
@wraps(fun)
|
@@ -88,7 +88,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
88
88
|
self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
89
89
|
f3(jnp.zeros(1))))
|
90
90
|
|
91
|
-
def
|
91
|
+
def test_compare_jax_make_jaxpr2(self):
|
92
92
|
st1 = brainstate.State(jnp.ones(10))
|
93
93
|
|
94
94
|
def fa(x):
|
@@ -108,7 +108,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
108
108
|
print(jaxpr)
|
109
109
|
print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
110
110
|
|
111
|
-
def
|
111
|
+
def test_compare_jax_make_jaxpr3(self):
|
112
112
|
def fa(x):
|
113
113
|
return 1.
|
114
114
|
|
@@ -121,6 +121,17 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
121
121
|
print(jaxpr)
|
122
122
|
# print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
123
123
|
|
124
|
+
def test_static_argnames(self):
|
125
|
+
def func4(a, b): # Arg is a pair
|
126
|
+
temp = a + jnp.sin(b) * 3.
|
127
|
+
c = brainstate.random.rand_like(a)
|
128
|
+
return jnp.sum(temp + c)
|
129
|
+
|
130
|
+
jaxpr, states = brainstate.compile.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
|
131
|
+
print()
|
132
|
+
print(jaxpr)
|
133
|
+
print(states)
|
134
|
+
|
124
135
|
|
125
136
|
def test_return_states():
|
126
137
|
import jax.numpy
|
brainstate/graph/_graph_node.py
CHANGED
@@ -25,7 +25,7 @@ import numpy as np
|
|
25
25
|
|
26
26
|
from brainstate._state import State, TreefyState
|
27
27
|
from brainstate.typing import Key
|
28
|
-
from brainstate.util.
|
28
|
+
from brainstate.util.pretty_pytree import PrettyObject
|
29
29
|
from ._graph_operation import register_graph_node_type
|
30
30
|
|
31
31
|
__all__ = [
|
@@ -30,10 +30,10 @@ from typing_extensions import TypeGuard, Unpack
|
|
30
30
|
from brainstate._state import State, TreefyState
|
31
31
|
from brainstate._utils import set_module_as
|
32
32
|
from brainstate.typing import PathParts, Filter, Predicate, Key
|
33
|
-
from brainstate.util.
|
34
|
-
from brainstate.util.
|
35
|
-
from brainstate.util.
|
36
|
-
from brainstate.util.
|
33
|
+
from brainstate.util.caller import ApplyCaller, CallableProxy, DelayedAccessor
|
34
|
+
from brainstate.util.pretty_pytree import NestedDict, FlattedDict, PrettyDict
|
35
|
+
from brainstate.util.pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
|
36
|
+
from brainstate.util.struct import FrozenDict
|
37
37
|
from brainstate.util.filter import to_predicate
|
38
38
|
|
39
39
|
_max_int = np.iinfo(np.int32).max
|
brainstate/mixin.py
CHANGED
@@ -41,6 +41,14 @@ __all__ = [
|
|
41
41
|
]
|
42
42
|
|
43
43
|
|
44
|
+
def hashable(x):
|
45
|
+
try:
|
46
|
+
hash(x)
|
47
|
+
return True
|
48
|
+
except TypeError:
|
49
|
+
return False
|
50
|
+
|
51
|
+
|
44
52
|
class Mixin(object):
|
45
53
|
"""Base Mixin object.
|
46
54
|
|
@@ -67,6 +75,14 @@ class ParamDesc(Mixin):
|
|
67
75
|
|
68
76
|
|
69
77
|
class HashableDict(dict):
|
78
|
+
def __init__(self, the_dict: dict):
|
79
|
+
out = dict()
|
80
|
+
for k, v in the_dict.items():
|
81
|
+
if not hashable(v):
|
82
|
+
v = str(v) # convert to string if not hashable
|
83
|
+
out[k] = v
|
84
|
+
super().__init__(out)
|
85
|
+
|
70
86
|
def __hash__(self):
|
71
87
|
return hash(tuple(sorted(self.items())))
|
72
88
|
|
@@ -132,7 +148,6 @@ class AlignPost(Mixin):
|
|
132
148
|
raise NotImplementedError
|
133
149
|
|
134
150
|
|
135
|
-
|
136
151
|
class BindCondData(Mixin):
|
137
152
|
"""Bind temporary conductance data.
|
138
153
|
|
@@ -147,12 +162,26 @@ class BindCondData(Mixin):
|
|
147
162
|
self._conductance = None
|
148
163
|
|
149
164
|
|
165
|
+
def not_implemented(func):
|
166
|
+
|
167
|
+
def wrapper(*args, **kwargs):
|
168
|
+
raise NotImplementedError(f'{func.__name__} is not implemented.')
|
169
|
+
|
170
|
+
wrapper.not_implemented = True
|
171
|
+
return wrapper
|
172
|
+
|
173
|
+
|
174
|
+
|
150
175
|
class UpdateReturn(Mixin):
|
176
|
+
@not_implemented
|
151
177
|
def update_return(self) -> PyTree:
|
152
178
|
"""
|
153
179
|
The update function return of the model.
|
154
180
|
|
155
|
-
|
181
|
+
This function requires no parameters and must return a PyTree.
|
182
|
+
|
183
|
+
It is usually used for delay initialization, for example, ``Dynamics.output_delay`` relies on this function to
|
184
|
+
initialize the output delay.
|
156
185
|
|
157
186
|
"""
|
158
187
|
raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
|
brainstate/nn/__init__.py
CHANGED
@@ -33,12 +33,14 @@ from ._embedding import *
|
|
33
33
|
from ._embedding import __all__ as embed_all
|
34
34
|
from ._exp_euler import *
|
35
35
|
from ._exp_euler import __all__ as exp_euler_all
|
36
|
-
from .
|
36
|
+
from ._fixedprob import *
|
37
|
+
from._fixedprob import __all__ as fixedprob_all
|
37
38
|
from ._inputs import *
|
38
39
|
from ._inputs import __all__ as inputs_all
|
39
40
|
from ._linear import *
|
40
41
|
from ._linear import __all__ as linear_all
|
41
|
-
from ._linear_mv import
|
42
|
+
from ._linear_mv import *
|
43
|
+
from ._linear_mv import __all__ as linear_mv_all
|
42
44
|
from ._ltp import *
|
43
45
|
from ._ltp import __all__ as ltp_all
|
44
46
|
from ._module import *
|
@@ -69,9 +71,6 @@ from ._utils import __all__ as utils_all
|
|
69
71
|
__all__ = (
|
70
72
|
[
|
71
73
|
'metrics',
|
72
|
-
'EventLinear',
|
73
|
-
'EventFixedProb',
|
74
|
-
'EventFixedNumConn',
|
75
74
|
]
|
76
75
|
+ collective_ops_all
|
77
76
|
+ common_all
|
@@ -87,6 +86,8 @@ __all__ = (
|
|
87
86
|
+ linear_all
|
88
87
|
+ normalizations_all
|
89
88
|
+ poolings_all
|
89
|
+
+ fixedprob_all
|
90
|
+
+ linear_mv_all
|
90
91
|
+ embed_all
|
91
92
|
+ dropout_all
|
92
93
|
+ elementwise_all
|
@@ -115,6 +116,8 @@ del (
|
|
115
116
|
normalizations_all,
|
116
117
|
poolings_all,
|
117
118
|
embed_all,
|
119
|
+
fixedprob_all,
|
120
|
+
linear_mv_all,
|
118
121
|
dropout_all,
|
119
122
|
elementwise_all,
|
120
123
|
dyn_neuron_all,
|
brainstate/nn/_delay.py
CHANGED
@@ -330,7 +330,14 @@ class Delay(Module):
|
|
330
330
|
indices = (delay_idx,) + indices
|
331
331
|
|
332
332
|
# the delay data
|
333
|
-
|
333
|
+
if self._unit is None:
|
334
|
+
return jax.tree.map(lambda a: a[indices], self.history.value)
|
335
|
+
else:
|
336
|
+
return jax.tree.map(
|
337
|
+
lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
|
338
|
+
self.history.value,
|
339
|
+
self._unit
|
340
|
+
)
|
334
341
|
|
335
342
|
def retrieve_at_time(self, delay_time, *indices) -> PyTree:
|
336
343
|
"""
|
@@ -393,6 +400,9 @@ class Delay(Module):
|
|
393
400
|
"""
|
394
401
|
assert self.history is not None, 'The delay history is not initialized.'
|
395
402
|
|
403
|
+
if self.take_aware_unit and self._unit is None:
|
404
|
+
self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)
|
405
|
+
|
396
406
|
# update the delay data at the rotation index
|
397
407
|
if self.delay_method == _DELAY_ROTATE:
|
398
408
|
i = environ.get(environ.I)
|
@@ -419,6 +429,8 @@ class Delay(Module):
|
|
419
429
|
raise ValueError(f'Unknown updating method "{self.delay_method}"')
|
420
430
|
|
421
431
|
|
432
|
+
|
433
|
+
|
422
434
|
class StateWithDelay(Delay):
|
423
435
|
"""
|
424
436
|
A ``State`` type that defines the state in a differential equation.
|
brainstate/nn/_dropout.py
CHANGED
@@ -409,7 +409,8 @@ class DropoutFixed(ElementWiseBlock):
|
|
409
409
|
self.out_size = in_size
|
410
410
|
|
411
411
|
def init_state(self, batch_size=None, **kwargs):
|
412
|
-
|
412
|
+
if self.prob < 1.:
|
413
|
+
self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
|
413
414
|
|
414
415
|
def update(self, x):
|
415
416
|
dtype = u.math.get_dtype(x)
|
@@ -418,8 +419,8 @@ class DropoutFixed(ElementWiseBlock):
|
|
418
419
|
if self.mask.value.shape != x.shape:
|
419
420
|
raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
|
420
421
|
f"Please call `init_state()` method first.")
|
421
|
-
return
|
422
|
-
|
423
|
-
|
422
|
+
return u.math.where(self.mask.value,
|
423
|
+
u.math.asarray(x / self.prob, dtype=dtype),
|
424
|
+
u.math.asarray(0., dtype=dtype) * u.get_unit(x))
|
424
425
|
else:
|
425
426
|
return x
|