brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240612__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +43 -5
- brainstate/_state.py +17 -0
- brainstate/environ.py +2 -1
- brainstate/functional/__init__.py +3 -2
- brainstate/functional/_activations.py +1 -1
- brainstate/functional/_normalization.py +3 -0
- brainstate/functional/_others.py +49 -0
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_base.py +10 -7
- brainstate/nn/_dynamics.py +20 -0
- brainstate/nn/_embedding.py +66 -0
- brainstate/nn/_rate_rnns.py +17 -0
- brainstate/nn/_readout.py +6 -0
- brainstate/optim/_lr_scheduler_test.py +13 -0
- brainstate/transform/_jit.py +47 -21
- brainstate/transform/_make_jaxpr.py +165 -3
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/METADATA +8 -6
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/RECORD +21 -29
- brainstate/nn/functional/__init__.py +0 -25
- brainstate/nn/functional/_activations.py +0 -754
- brainstate/nn/functional/_normalization.py +0 -69
- brainstate/nn/functional/_spikes.py +0 -90
- brainstate/nn/init/__init__.py +0 -26
- brainstate/nn/init/_base.py +0 -36
- brainstate/nn/init/_generic.py +0 -175
- brainstate/nn/init/_random_inits.py +0 -489
- brainstate/nn/init/_regular_inits.py +0 -109
- brainstate/nn/surrogate.py +0 -1740
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/top_level.txt +0 -0
brainstate/transform/_make_jaxpr.py

@@ -54,20 +54,27 @@ function.
 from __future__ import annotations

 import functools
+import inspect
 import operator
 from collections.abc import Hashable, Iterable, Sequence
+from contextlib import ExitStack
 from typing import Any, Callable, Tuple, Union, Dict, Optional

 import jax
 from jax._src import source_info_util
+from jax._src.linear_util import annotate
+from jax._src.traceback_util import api_boundary
+from jax.extend.linear_util import transformation_with_aux, wrap_init
 from jax.interpreters import partial_eval as pe
-from jax.util import wraps
 from jax.interpreters.xla import abstractify
+from jax.util import wraps

 from brainstate._state import State, StateTrace
 from brainstate._utils import set_module_as

 PyTree = Any
+AxisName = Hashable
+

 __all__ = [
   "StatefulFunction",
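
The newly imported `wrap_init` and `transformation_with_aux` come from `jax.extend.linear_util`, JAX's generator-based function-transformation machinery, which the `_flatten_fun` helper added further down relies on. A minimal standalone sketch of that protocol (a hypothetical toy transformation, not part of this diff; behaviour as of the JAX releases current in mid-2024, when this wheel was built):

from jax.extend import linear_util as lu

@lu.transformation_with_aux
def _double_and_tag(*args):
  # First yield: hand (args, kwargs) to the wrapped function and receive its answer.
  ans = yield args, {}
  # Second yield: (transformed answer, auxiliary value stored for later retrieval).
  yield ans * 2, ('doubled',)

f = lu.wrap_init(lambda x: x + 1)    # plain callable -> WrappedFun
f2, aux_thunk = _double_and_tag(f)   # stack the transformation, get an aux thunk back
print(f2.call_wrapped(3))            # (3 + 1) * 2 -> 8
print(aux_thunk())                   # ('doubled',), retrieved the same way out_tree() is used below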
@@ -393,7 +400,8 @@ class StatefulFunction(object):
     if cache_key not in self._state_trace:
       try:
         # jaxpr
-        jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        # jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
          functools.partial(self._wrapped_fun_to_eval, cache_key),
          static_argnums=self.static_argnums,
          axis_env=self.axis_env,
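
The replaced call and its `_make_jaxpr` substitute share `jax.make_jaxpr`'s calling convention: with `return_shape=True` (presumably passed just below this hunk) the traced callable returns a `ClosedJaxpr` together with a shape pytree, which is what the tuple unpacking above relies on. A small sketch of that convention using the public `jax.make_jaxpr`; `eval_fn` here is a hypothetical stand-in for `self._wrapped_fun_to_eval`:

import functools
import jax
import jax.numpy as jnp

def eval_fn(scale, x):
  # Returns (main output, auxiliary outputs), mirroring the (out_shapes, state_shapes) pair above.
  out = jnp.tanh(scale * x)
  return out, {'aux': scale * x}

jaxpr, (out_shape, aux_shapes) = jax.make_jaxpr(
  functools.partial(eval_fn, 2.0),  # cf. functools.partial(self._wrapped_fun_to_eval, cache_key)
  return_shape=True,
)(jnp.ones((3,)))
print(jaxpr)                  # the ClosedJaxpr traced from eval_fn
print(out_shape, aux_shapes)  # jax.ShapeDtypeStruct leaves describing the outputs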
@@ -474,7 +482,8 @@ def make_jaxpr(
     state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
 ) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
                     Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
-  """
+  """
+  Creates a function that produces its jaxpr given example args.

   Args:
     fun: The function whose ``jaxpr`` is to be computed. Its positional
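
This is the public entry point of the module: applied to example arguments, the wrapped function returns the `ClosedJaxpr` plus the `State` objects encountered during tracing (and additionally a shape pytree when `return_shape=True`). A minimal usage sketch, assuming `brainstate.State` exposes its array as `.value` and that the function is re-exported as `brainstate.transform.make_jaxpr`:

import jax.numpy as jnp
import brainstate as bst

weight = bst.State(jnp.ones(3))   # a State read inside the traced function

def loss(x):
  return (weight.value * x).sum()

jaxpr, states = bst.transform.make_jaxpr(loss)(jnp.arange(3.0))
print(jaxpr)    # ClosedJaxpr of `loss`, with the read State lifted into the trace inputs
print(states)   # tuple of State instances touched while tracing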
@@ -571,3 +580,156 @@ def make_jaxpr(
   if hasattr(fun, "__name__"):
     make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
   return make_jaxpr_f
+
+
+def _check_callable(fun):
+  # In Python 3.10+, the only thing stopping us from supporting staticmethods
+  # is that we can't take weak references to them, which the C++ JIT requires.
+  if isinstance(fun, staticmethod):
+    raise TypeError(f"staticmethod arguments are not supported, got {fun}")
+  if not callable(fun):
+    raise TypeError(f"Expected a callable value, got {fun}")
+  if inspect.isgeneratorfunction(fun):
+    raise TypeError(f"Expected a function, got a generator function: {fun}")
+
+
+def _broadcast_prefix(
+    prefix_tree: Any,
+    full_tree: Any,
+    is_leaf: Callable[[Any], bool] | None = None
+) -> list[Any]:
+  # If prefix_tree is not a tree prefix of full_tree, this code can raise a
+  # ValueError; use prefix_errors to find disagreements and raise more precise
+  # error messages.
+  result = []
+  num_leaves = lambda t: jax.tree.structure(t).num_leaves
+  add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
+  jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
+  return result
+
+
+def _flat_axes_specs(
+    abstracted_axes, *args, **kwargs
+) -> list[pe.AbstractedAxesSpec]:
+  if kwargs:
+    raise NotImplementedError
+
+  def ax_leaf(l):
+    return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
+            isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))
+
+  return _broadcast_prefix(abstracted_axes, args, ax_leaf)
+
+
+@transformation_with_aux
+def _flatten_fun(in_tree, *args_flat):
+  py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
+  ans = yield py_args, py_kwargs
+  yield jax.tree.flatten(ans)
+
+
+def _make_jaxpr(
+    fun: Callable,
+    static_argnums: int | Iterable[int] = (),
+    axis_env: Sequence[tuple[AxisName, int]] | None = None,
+    return_shape: bool = False,
+    abstracted_axes: Any | None = None,
+) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
+  """Creates a function that produces its jaxpr given example args.
+
+  Args:
+    fun: The function whose ``jaxpr`` is to be computed. Its positional
+      arguments and return value should be arrays, scalars, or standard Python
+      containers (tuple/list/dict) thereof.
+    static_argnums: See the :py:func:`jax.jit` docstring.
+    axis_env: Optional, a sequence of pairs where the first element is an axis
+      name and the second element is a positive integer representing the size of
+      the mapped axis with that name. This parameter is useful when lowering
+      functions that involve parallel communication collectives, and it
+      specifies the axis name/size environment that would be set up by
+      applications of :py:func:`jax.pmap`.
+    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
+      wrapped function returns a pair where the first element is the
+      ``ClosedJaxpr`` representation of ``fun`` and the second element is a
+      pytree with the same structure as the output of ``fun`` and where the
+      leaves are objects with ``shape``, ``dtype``, and ``named_shape``
+      attributes representing the corresponding types of the output leaves.
+
+  Returns:
+    A wrapped version of ``fun`` that when applied to example arguments returns
+    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
+    argument ``return_shape`` is ``True``, then the returned function instead
+    returns a pair where the first element is the ``ClosedJaxpr``
+    representation of ``fun`` and the second element is a pytree representing
+    the structure, shape, dtypes, and named shapes of the output of ``fun``.
+
+  A ``jaxpr`` is JAX's intermediate representation for program traces. The
+  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
+  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
+  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
+  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
+  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
+
+  We do not describe the semantics of the ``jaxpr`` language in detail here, but
+  instead give a few examples.
+
+  >>> import jax
+  >>>
+  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
+  >>> print(f(3.0))
+  -0.83602
+  >>> _make_jaxpr(f)(3.0)
+  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
+  >>> _make_jaxpr(jax.grad(f))(3.0)
+  { lambda ; a:f32[]. let
+      b:f32[] = cos a
+      c:f32[] = sin a
+      _:f32[] = sin b
+      d:f32[] = cos b
+      e:f32[] = mul 1.0 d
+      f:f32[] = neg e
+      g:f32[] = mul f c
+    in (g,) }
+  """
+  _check_callable(fun)
+  static_argnums = _ensure_index_tuple(static_argnums)
+
+  def _abstractify(args, kwargs):
+    flat_args, in_tree = jax.tree.flatten((args, kwargs))
+    if abstracted_axes is None:
+      return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
+    else:
+      axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
+      in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
+      in_avals, keep_inputs = jax.util.unzip2(in_type)
+      return in_avals, in_tree, keep_inputs
+
+  @wraps(fun)
+  @api_boundary
+  def make_jaxpr_f(*args, **kwargs):
+    f = wrap_init(fun)
+    if static_argnums:
+      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
+      f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
+    in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
+    in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
+    f, out_tree = _flatten_fun(f, in_tree)
+    f = annotate(f, in_type)
+    debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
+    with ExitStack() as stack:
+      for axis_name, size in axis_env or []:
+        stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
+      jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
+    closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
+    if return_shape:
+      out_avals, _ = jax.util.unzip2(out_type)
+      out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
+      return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
+    return closed_jaxpr
+
+  make_jaxpr_f.__module__ = "brainstate.transform"
+  if hasattr(fun, "__qualname__"):
+    make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
+  if hasattr(fun, "__name__"):
+    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
+  return make_jaxpr_f
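
Of the vendored helpers above, `_broadcast_prefix` is the least self-explanatory: each leaf of a prefix tree (for example an `abstracted_axes` spec) is repeated once per leaf of the corresponding subtree of the full argument tree. A standalone copy for illustration, with the same body as the private helper above, renamed so it can run outside the module:

import jax
import jax.numpy as jnp

def broadcast_prefix(prefix_tree, full_tree, is_leaf=None):
  # Repeat each prefix leaf once per leaf of the matching subtree in full_tree.
  result = []
  num_leaves = lambda t: jax.tree.structure(t).num_leaves
  add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
  jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
  return result

prefix = {'b': 'batch', 'w': 'time'}                              # one spec per top-level entry
full = {'b': jnp.zeros(1), 'w': (jnp.ones(2), jnp.ones((2, 3)))}  # 'w' subtree has two leaves
print(broadcast_prefix(prefix, full))                             # ['batch', 'time', 'time']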
{brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/METADATA

@@ -1,8 +1,8 @@
 Metadata-Version: 2.1
 Name: brainstate
-Version: 0.0.1
+Version: 0.0.1.post20240612
 Summary: A State-based Transformation System for Brain Dynamics Programming.
-Home-page: https://github.com/brainpy/
+Home-page: https://github.com/brainpy/brainstate
 Author: BrainPy Team
 Author-email: BrainPy Team <chao.brain@qq.com>
 License: Apache-2.0 license
@@ -90,12 +90,14 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt

 ## See also the BDP ecosystem

-- [``
+- [``brainstate``](https://github.com/brainpy/brainstate): A ``State``-based transformation system for brain dynamics programming.

-- [``
+- [``brainunit``](https://github.com/brainpy/brainunit): The unit system for brain dynamics programming.

-- [``
+- [``braintaichi``](https://github.com/brainpy/braintaichi): Leveraging Taichi Lang to customize brain dynamics operators.

-- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning for biological
+- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning framework for biological neural networks.
+
+- [``braintools``](https://github.com/brainpy/braintools): The toolbox for the brain dynamics simulation, training and analysis.


{brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/RECORD

@@ -1,10 +1,10 @@
 brainstate/__init__.py,sha256=3R0I9oLpIS6Z9iwmx5ODzSlZGX3MbYWvCWsFHaCiaG4,1436
-brainstate/_module.py,sha256=
+brainstate/_module.py,sha256=R3pBeNvqR_mEfquZU60uWj7JmopOCUciF2BgcJyA0aw,48151
 brainstate/_module_test.py,sha256=4tqtp2-j5mSoUmCITY0mVZEcXzxXCWJ_02Jdt1fxYJg,4502
-brainstate/_state.py,sha256=
+brainstate/_state.py,sha256=RWnLjMeaidxWXNAA0X-8mxj4i61j3T8w5KhugACUYhI,11422
 brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
 brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
-brainstate/environ.py,sha256=
+brainstate/environ.py,sha256=RMDUACuixwk2ZTHf0UGLhcd5DCraW-l1j9T3wc2wcFc,10242
 brainstate/mixin.py,sha256=V75vjMTzYcCMlPo5wekgRZZ9o6-xN8kJocQgEliu5yI,10738
 brainstate/mixin_test.py,sha256=qDYqhHbHw3aBFW8aHQdPhov29013Eo9TJDF7RW2dapE,2919
 brainstate/random.py,sha256=Mi5i0kAsR8C-VoI8LMuIbPPr6YFzq6NBxhJ5K0w2qW4,186392
@@ -12,9 +12,10 @@ brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,178
 brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
 brainstate/typing.py,sha256=Ooweu7c17nYP686fyIeKNomChodSxx_OEpu8QRoB9cY,2180
 brainstate/util.py,sha256=FrBN_OZAPlWxfNK8c9Z1d-bbIa8qwMrcOsSJZJS8xOE,19878
-brainstate/functional/__init__.py,sha256=
-brainstate/functional/_activations.py,sha256=
-brainstate/functional/_normalization.py,sha256=
+brainstate/functional/__init__.py,sha256=Z-43coOHFAsQK0u5amlr4l0fNNPc7dVcuKXfNY4Gj_s,1107
+brainstate/functional/_activations.py,sha256=IfZ6Zy8SAwyxo166E3NmCZMUHnG_rBFAUaLTyxG5FgA,18490
+brainstate/functional/_normalization.py,sha256=IxE580waloZylZVXcpUUK4bWQdlE6oSPfafaKYfDkbg,2169
+brainstate/functional/_others.py,sha256=ifB-l82y7ZB632yLUJOEcpkRY-yOoiJ0mtDOxNilp4M,1711
 brainstate/functional/_spikes.py,sha256=uAln_Q87pr1codLxeDck3PUA9jpk7S5LifNps1kdyrU,2576
 brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
 brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
@@ -28,52 +29,43 @@ brainstate/math/_einops_parsing_test.py,sha256=JPn73yld300481J6E9cL7jHWn63Vr21VV
 brainstate/math/_einops_test.py,sha256=xj-DDTL0EsW1Obm64KCnT7eqELWjjj04Ozdwk0839Tw,13289
 brainstate/math/_misc.py,sha256=jDtREP4ojxHyj6lXcLcYLGVsLA0HFZcrs8cdlnA7aK8,7863
 brainstate/math/_misc_test.py,sha256=V41YV-RiEbukKQlzq54174cpSalOhMjaHOoVH8o82eI,2443
-brainstate/nn/__init__.py,sha256=
-brainstate/nn/_base.py,sha256=
+brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
+brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
 brainstate/nn/_connections.py,sha256=GSOW2IbpJRHdPyF4nFJ2RPgO8y6SVHT1Gn-pbri9pMk,22970
-brainstate/nn/_dynamics.py,sha256=
+brainstate/nn/_dynamics.py,sha256=OeYYXv1dqjUDcCsRhZo1XS7SP2li1vlH9uhME_PE9v0,13205
 brainstate/nn/_elementwise.py,sha256=T1oCu47t11Ki7LPaL-hHk4W8bKP_Q3HLJcGngcmGK0Y,43552
+brainstate/nn/_embedding.py,sha256=WbgrIaM_14abN8zBDr0xipBOsFc8dXP2m7Z_aRLAfmU,2249
 brainstate/nn/_misc.py,sha256=Z7gdJraJ18gVNHyNOk_KmE67M3OM4z3QT4RN6al5JMc,3766
 brainstate/nn/_normalizations.py,sha256=9yVDORAEpqEkL9MYSPU4m7C4q8Qj5UNsPh9sKmIt5gQ,14329
 brainstate/nn/_others.py,sha256=8PYmCiUNzru4kmm58HY0RzCs-32dnwNFDZdTTPixaqo,4492
 brainstate/nn/_poolings.py,sha256=cNZ1PyMIaViP-_AUkEbpy3ZfHo--ib1hAhL0bEAmXIQ,45688
 brainstate/nn/_poolings_test.py,sha256=iE0NgvOIWVgwmcvP4wazhGG4RJQdU2eeagdJ1sDXIBQ,7260
-brainstate/nn/_rate_rnns.py,sha256=
-brainstate/nn/_readout.py,sha256=
+brainstate/nn/_rate_rnns.py,sha256=Cebhy57UWzfwrCfq0v2qLDegmb__mXL5ht750y4aTro,14457
+brainstate/nn/_readout.py,sha256=jsQwhVnrJICKw4wFq-Du2AORPb_XXz_tZ4cURcckU-E,4240
 brainstate/nn/_synouts.py,sha256=gi3EyKlzt4UoyghwvNIr03r7YabZyl1idbq9aYG8zYM,4379
-brainstate/nn/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
 brainstate/nn/_projection/__init__.py,sha256=L6svNHTb8BDh2rdX2eYmcx_NdscSdKykkQbzpdCSkTA,1207
 brainstate/nn/_projection/_align_post.py,sha256=dAdNsuNf7jo8Qsh3uHXLonv9iDi3J9AnWzmFaG3b3bo,20655
 brainstate/nn/_projection/_align_pre.py,sha256=R2U6_RQ_o8y6PWXpozeWE2cx_oQ7WMhhrBR9hZtEBTs,24597
 brainstate/nn/_projection/_delta.py,sha256=KT8ySo3n_Q_7swzOH-ISDf0x9rjMkiv99H-vqeQZDR8,7122
 brainstate/nn/_projection/_utils.py,sha256=UcmELOqsINgqJr7eC5BSNNteyZ--1lyGjhUTJfxyMmA,813
 brainstate/nn/_projection/_vanilla.py,sha256=_bh_DLtF0o33SBtj6IGL8CTanFEtJwfjBrgxBEAmIlg,3397
-brainstate/nn/functional/__init__.py,sha256=6aMwiOl2UGEd6b5SBpGaB0s2QG1iF8rRFoOfvScJ4Dg,1020
-brainstate/nn/functional/_activations.py,sha256=8mgmctDoK_oO8D3JV-YBlLMrsg5Dry0OPcyBz_2Ur0I,18498
-brainstate/nn/functional/_normalization.py,sha256=aLe_CsVl8gZCMTcTVmrI3mEtaLmBWqJyIHQRDgslSHw,2090
-brainstate/nn/functional/_spikes.py,sha256=uAln_Q87pr1codLxeDck3PUA9jpk7S5LifNps1kdyrU,2576
-brainstate/nn/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
-brainstate/nn/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
-brainstate/nn/init/_generic.py,sha256=OJFS7DHYmZV0JogdsgjnUseUfvTUrAUYiXZynCQqmG4,5163
-brainstate/nn/init/_random_inits.py,sha256=STbX-mrHwNuICXkw7EldtJLdUUsWOAcGkEzx2ycV-Yc,15321
-brainstate/nn/init/_regular_inits.py,sha256=n-vF-51FM1UcUh-8h5lUk5Jhjrn04KPcGXgGhUGFAAk,3065
 brainstate/optim/__init__.py,sha256=1xH5_peSWKuZ4tOU295r9EKAv0a-cBMABx6XV3faDJI,919
 brainstate/optim/_lr_scheduler.py,sha256=emKnA52UVqOfUcX7LJqwP-FVDVlGGzTQi2djYmbCWUo,15627
-brainstate/optim/_lr_scheduler_test.py,sha256=
+brainstate/optim/_lr_scheduler_test.py,sha256=OwF8Iz-PorEbO0gO--A7IIgQEytqEfYWbPucAgzqL90,1598
 brainstate/optim/_sgd_optimizer.py,sha256=7-jMfP_Hol0XGEA6_4wVqygpLTqI1646F6eeLtwtNFY,45760
 brainstate/transform/__init__.py,sha256=9S9TLp1sF6nWRmW6jFtu6_dLmOc43V88Ruh073Z8I50,1460
 brainstate/transform/_autograd.py,sha256=sFGJ6oAhlSr54Hb1c1aNc5Q2St7eIr_X77lupc31YXg,23964
 brainstate/transform/_autograd_test.py,sha256=epQ2z97fAp_dQ_CwWGZD7sgw-p9o9fGfSeOUAJiiDY0,38658
 brainstate/transform/_control.py,sha256=NWceTIuLlj2uGTdNcqBAXgnaLuChOGgAtIXtFn5vdLU,26837
 brainstate/transform/_controls_test.py,sha256=mPUa_qmXXVxDziAJrPWRBwsGnc3cHR9co08eJB_fJwA,7648
-brainstate/transform/_jit.py,sha256=
+brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
 brainstate/transform/_jit_error.py,sha256=lO_e5AdhkjozHjM10q0b57OaXbeZ9gQkVmZMN6VQVCw,4450
 brainstate/transform/_jit_test.py,sha256=lVXvScfXExhXwFi8jnvEY6stNVulZHCzriamajFqzrY,2891
-brainstate/transform/_make_jaxpr.py,sha256=
+brainstate/transform/_make_jaxpr.py,sha256=MTeBpPO1thu5yDytWoJijySHV7-nWmUoBMC0RCbdzcY,29972
 brainstate/transform/_make_jaxpr_test.py,sha256=4nEwZv_ebgUZgV86vOJFO_qC69mw2F3rogViF2SC1Qs,3823
 brainstate/transform/_progress_bar.py,sha256=myrAkBcUfuVGFLVwFzeSe5vdg1z49ARKqTlccG92maA,3536
-brainstate-0.0.1.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
-brainstate-0.0.1.dist-info/METADATA,sha256=
-brainstate-0.0.1.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
-brainstate-0.0.1.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
-brainstate-0.0.1.dist-info/RECORD,,
+brainstate-0.0.1.post20240612.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
+brainstate-0.0.1.post20240612.dist-info/METADATA,sha256=VRHXnO0TBRcoo_M4iFHsywjCJbhonpStzYSmRXkR_wM,4254
+brainstate-0.0.1.post20240612.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
+brainstate-0.0.1.post20240612.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
+brainstate-0.0.1.post20240612.dist-info/RECORD,,
brainstate/nn/functional/__init__.py (deleted)

@@ -1,25 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-from ._activations import *
-from ._activations import __all__ as __activations_all__
-from ._normalization import *
-from ._normalization import __all__ as __others_all__
-from ._spikes import *
-from ._spikes import __all__ as __spikes_all__
-
-__all__ = __spikes_all__ + __others_all__ + __activations_all__
-
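
Together with the RECORD changes above, removing this `__init__.py` (and the rest of `brainstate/nn/functional`, `brainstate/nn/init`, and `brainstate/nn/surrogate.py`) drops the sub-packages that duplicated the top-level modules; the matching checksums in the old RECORD (for example `_spikes.py` and `surrogate.py`) show they were byte-for-byte copies. A migration sketch, assuming the top-level modules keep the same public names:

# 0.0.1 (before this diff): the duplicated sub-packages under brainstate.nn
#   import brainstate.nn.functional as F
#   import brainstate.nn.init as init
#   import brainstate.nn.surrogate as surrogate

# 0.0.1.post20240612 (after this diff): import the top-level modules instead
import brainstate.functional as F
import brainstate.init as init
import brainstate.surrogate as surrogate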