brainstate 0.1.0.post20250413__py2.py3-none-any.whl → 0.1.0.post20250422__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_compatible_import.py +73 -0
- brainstate/_state.py +5 -4
- brainstate/_state_test.py +2 -1
- brainstate/augment/_autograd_test.py +3 -2
- brainstate/augment/_eval_shape.py +2 -1
- brainstate/augment/_mapping.py +0 -1
- brainstate/augment/_mapping_test.py +1 -0
- brainstate/compile/_ad_checkpoint.py +2 -1
- brainstate/compile/_conditions.py +4 -2
- brainstate/compile/_conditions_test.py +2 -1
- brainstate/compile/_error_if.py +2 -1
- brainstate/compile/_error_if_test.py +2 -1
- brainstate/compile/_jit.py +3 -2
- brainstate/compile/_jit_test.py +2 -1
- brainstate/compile/_loop_collect_return.py +2 -2
- brainstate/compile/_loop_collect_return_test.py +2 -1
- brainstate/compile/_loop_no_collection.py +1 -1
- brainstate/compile/_make_jaxpr.py +10 -13
- brainstate/compile/_make_jaxpr_test.py +3 -6
- brainstate/compile/_progress_bar.py +2 -1
- brainstate/compile/_unvmap.py +1 -5
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +2 -1
- brainstate/functional/_activations.py +2 -1
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +2 -1
- brainstate/functional/_others.py +2 -1
- brainstate/graph/_graph_operation.py +3 -2
- brainstate/graph/_graph_operation_test.py +4 -3
- brainstate/init/_base.py +2 -1
- brainstate/init/_generic.py +2 -1
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +1 -0
- brainstate/nn/_collective_ops_test.py +0 -4
- brainstate/nn/_common.py +0 -1
- brainstate/nn/_dyn_impl/__init__.py +0 -4
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
- brainstate/nn/_dyn_impl/_inputs.py +236 -29
- brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
- brainstate/nn/_dyn_impl/_readout.py +91 -8
- brainstate/nn/_dyn_impl/_readout_test.py +2 -1
- brainstate/nn/_dynamics/_dynamics_base.py +676 -96
- brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
- brainstate/nn/_dynamics/_projection_base.py +29 -30
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +2 -1
- brainstate/nn/_elementwise/_dropout.py +3 -2
- brainstate/nn/_elementwise/_dropout_test.py +2 -1
- brainstate/nn/_elementwise/_elementwise.py +2 -1
- brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
- brainstate/nn/_event/_fixedprob_mv.py +169 -0
- brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
- brainstate/nn/_event/_linear_mv.py +85 -0
- brainstate/nn/_event/_linear_mv_test.py +121 -0
- brainstate/nn/_exp_euler.py +2 -1
- brainstate/nn/_exp_euler_test.py +2 -1
- brainstate/nn/_interaction/_conv.py +2 -1
- brainstate/nn/_interaction/_linear.py +2 -1
- brainstate/nn/_interaction/_linear_test.py +2 -1
- brainstate/nn/_interaction/_normalizations.py +2 -1
- brainstate/nn/_interaction/_poolings.py +4 -3
- brainstate/nn/_module_test.py +2 -1
- brainstate/nn/metrics.py +4 -3
- brainstate/optim/_lr_scheduler.py +2 -1
- brainstate/optim/_lr_scheduler_test.py +2 -1
- brainstate/optim/_optax_optimizer_test.py +2 -1
- brainstate/optim/_sgd_optimizer.py +3 -2
- brainstate/random/_rand_funs.py +2 -1
- brainstate/random/_rand_funs_test.py +3 -2
- brainstate/random/_rand_seed.py +3 -2
- brainstate/random/_rand_seed_test.py +2 -1
- brainstate/random/_rand_state.py +4 -3
- brainstate/surrogate.py +1 -5
- brainstate/typing.py +4 -4
- brainstate/util/_caller.py +2 -1
- brainstate/util/_others.py +4 -4
- brainstate/util/_pretty_pytree.py +1 -1
- brainstate/util/_pretty_pytree_test.py +2 -1
- brainstate/util/_pretty_table.py +43 -43
- brainstate/util/_struct.py +2 -1
- brainstate/util/filter.py +0 -1
- {brainstate-0.1.0.post20250413.dist-info → brainstate-0.1.0.post20250422.dist-info}/METADATA +3 -3
- brainstate-0.1.0.post20250422.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250413.dist-info/RECORD +0 -128
- {brainstate-0.1.0.post20250413.dist-info → brainstate-0.1.0.post20250422.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250413.dist-info → brainstate-0.1.0.post20250422.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250413.dist-info → brainstate-0.1.0.post20250422.dist-info}/top_level.txt +0 -0
brainstate/nn/_exp_euler.py
CHANGED
@@ -16,9 +16,10 @@
|
|
16
16
|
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
+
from typing import Callable
|
20
|
+
|
19
21
|
import brainunit as u
|
20
22
|
import jax.numpy as jnp
|
21
|
-
from typing import Callable
|
22
23
|
|
23
24
|
from brainstate import environ, random
|
24
25
|
from brainstate.augment import vector_grad
|
brainstate/nn/_exp_euler_test.py
CHANGED
@@ -18,9 +18,10 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
import collections.abc
|
21
|
+
from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
|
22
|
+
|
21
23
|
import jax
|
22
24
|
import jax.numpy as jnp
|
23
|
-
from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
|
24
25
|
|
25
26
|
from brainstate import init, functional
|
26
27
|
from brainstate._state import ParamState
|
@@ -17,9 +17,10 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
+
from typing import Callable, Union, Optional
|
21
|
+
|
20
22
|
import brainunit as u
|
21
23
|
import jax.numpy as jnp
|
22
|
-
from typing import Callable, Union, Optional
|
23
24
|
|
24
25
|
from brainstate import init, functional
|
25
26
|
from brainstate._state import ParamState
|
@@ -17,9 +17,10 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
+
from typing import Callable, Union, Sequence, Optional, Any
|
21
|
+
|
20
22
|
import jax
|
21
23
|
import jax.numpy as jnp
|
22
|
-
from typing import Callable, Union, Sequence, Optional, Any
|
23
24
|
|
24
25
|
from brainstate import environ, init
|
25
26
|
from brainstate._state import ParamState, BatchState
|
@@ -17,13 +17,14 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
import brainunit as u
|
21
20
|
import functools
|
21
|
+
from typing import Sequence, Optional
|
22
|
+
from typing import Union, Tuple, Callable, List
|
23
|
+
|
24
|
+
import brainunit as u
|
22
25
|
import jax
|
23
26
|
import jax.numpy as jnp
|
24
27
|
import numpy as np
|
25
|
-
from typing import Sequence, Optional
|
26
|
-
from typing import Union, Tuple, Callable, List
|
27
28
|
|
28
29
|
from brainstate import environ
|
29
30
|
from brainstate.nn._module import Module
|
brainstate/nn/_module_test.py
CHANGED
brainstate/nn/metrics.py
CHANGED
@@ -16,13 +16,14 @@
|
|
16
16
|
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
import jax
|
20
|
-
import jax.numpy as jnp
|
21
|
-
import numpy as np
|
22
19
|
import typing as tp
|
23
20
|
from dataclasses import dataclass
|
24
21
|
from functools import partial
|
25
22
|
|
23
|
+
import jax
|
24
|
+
import jax.numpy as jnp
|
25
|
+
import numpy as np
|
26
|
+
|
26
27
|
from brainstate._state import State
|
27
28
|
|
28
29
|
__all__ = [
|
@@ -16,10 +16,11 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
+
from typing import Sequence, Union
|
20
|
+
|
19
21
|
import jax
|
20
22
|
import jax.numpy as jnp
|
21
23
|
import numpy as np
|
22
|
-
from typing import Sequence, Union
|
23
24
|
|
24
25
|
from brainstate import environ
|
25
26
|
from brainstate._state import State, LongTermState
|
@@ -16,11 +16,12 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
import brainunit as u
|
20
19
|
import functools
|
20
|
+
from typing import Union, Dict, Optional, Tuple, Any, TypeVar
|
21
|
+
|
22
|
+
import brainunit as u
|
21
23
|
import jax
|
22
24
|
import jax.numpy as jnp
|
23
|
-
from typing import Union, Dict, Optional, Tuple, Any, TypeVar
|
24
25
|
|
25
26
|
from brainstate import environ
|
26
27
|
from brainstate._state import State, LongTermState, StateDictManager
|
brainstate/random/_rand_funs.py
CHANGED
@@ -17,9 +17,10 @@
|
|
17
17
|
# -*- coding: utf-8 -*-
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
import numpy as np
|
21
20
|
from typing import Optional
|
22
21
|
|
22
|
+
import numpy as np
|
23
|
+
|
23
24
|
from brainstate.typing import DTypeLike, Size, SeedOrKey
|
24
25
|
from ._rand_state import RandomState, DEFAULT
|
25
26
|
|
@@ -15,12 +15,13 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
import platform
|
19
|
+
import unittest
|
20
|
+
|
18
21
|
import jax.numpy as jnp
|
19
22
|
import jax.random as jr
|
20
23
|
import numpy as np
|
21
|
-
import platform
|
22
24
|
import pytest
|
23
|
-
import unittest
|
24
25
|
|
25
26
|
import brainstate as bst
|
26
27
|
|
brainstate/random/_rand_seed.py
CHANGED
@@ -14,11 +14,12 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
import jax
|
18
|
-
import numpy as np
|
19
17
|
from contextlib import contextmanager
|
20
18
|
from typing import Optional
|
21
19
|
|
20
|
+
import jax
|
21
|
+
import numpy as np
|
22
|
+
|
22
23
|
from brainstate.typing import SeedOrKey
|
23
24
|
from ._rand_state import RandomState, DEFAULT, use_prng_key
|
24
25
|
|
brainstate/random/_rand_state.py
CHANGED
@@ -16,16 +16,17 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
+
from functools import partial
|
20
|
+
from operator import index
|
21
|
+
from typing import Optional
|
22
|
+
|
19
23
|
import brainunit as u
|
20
24
|
import jax
|
21
25
|
import jax.numpy as jnp
|
22
26
|
import jax.random as jr
|
23
27
|
import numpy as np
|
24
|
-
from functools import partial
|
25
28
|
from jax import jit, vmap
|
26
29
|
from jax import lax, core, dtypes
|
27
|
-
from operator import index
|
28
|
-
from typing import Optional
|
29
30
|
|
30
31
|
from brainstate import environ
|
31
32
|
from brainstate._state import State
|
brainstate/surrogate.py
CHANGED
@@ -21,13 +21,9 @@ import jax.numpy as jnp
|
|
21
21
|
import jax.scipy as sci
|
22
22
|
from jax.interpreters import batching, ad, mlir
|
23
23
|
|
24
|
+
from brainstate._compatible_import import Primitive
|
24
25
|
from brainstate.util._pretty_pytree import PrettyObject
|
25
26
|
|
26
|
-
if jax.__version_info__ < (0, 4, 38):
|
27
|
-
from jax.core import Primitive
|
28
|
-
else:
|
29
|
-
from jax.extend.core import Primitive
|
30
|
-
|
31
27
|
__all__ = [
|
32
28
|
'Surrogate',
|
33
29
|
'Sigmoid',
|
brainstate/typing.py
CHANGED
@@ -16,18 +16,18 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import builtins
|
19
|
-
|
20
|
-
import brainunit as u
|
21
19
|
import functools as ft
|
22
20
|
import importlib
|
23
21
|
import inspect
|
24
|
-
import jax
|
25
|
-
import numpy as np
|
26
22
|
from typing import (
|
27
23
|
Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
|
28
24
|
runtime_checkable, TYPE_CHECKING, Generic, Sequence
|
29
25
|
)
|
30
26
|
|
27
|
+
import brainunit as u
|
28
|
+
import jax
|
29
|
+
import numpy as np
|
30
|
+
|
31
31
|
tp = importlib.import_module("typing")
|
32
32
|
|
33
33
|
__all__ = [
|
brainstate/util/_caller.py
CHANGED
brainstate/util/_others.py
CHANGED
@@ -15,17 +15,17 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
import gc
|
19
|
-
|
20
18
|
import copy
|
21
19
|
import functools
|
22
|
-
import jax
|
20
|
+
import gc
|
23
21
|
import threading
|
24
22
|
import types
|
25
23
|
from collections.abc import Iterable
|
26
|
-
from jax.lib import xla_bridge
|
27
24
|
from typing import Any, Callable, Tuple, Union, Dict
|
28
25
|
|
26
|
+
import jax
|
27
|
+
from jax.lib import xla_bridge
|
28
|
+
|
29
29
|
from brainstate._utils import set_module_as
|
30
30
|
|
31
31
|
__all__ = [
|
@@ -18,9 +18,9 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
from collections import abc
|
21
|
+
from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
|
21
22
|
|
22
23
|
import jax
|
23
|
-
from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
|
24
24
|
|
25
25
|
from brainstate.typing import Filter, PathParts
|
26
26
|
from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
|
brainstate/util/_pretty_table.py
CHANGED
@@ -422,14 +422,14 @@ class PrettyTable:
|
|
422
422
|
def _column_specific_args(self):
|
423
423
|
# Column specific arguments, use property.setters
|
424
424
|
for attr in (
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
425
|
+
"align",
|
426
|
+
"valign",
|
427
|
+
"max_width",
|
428
|
+
"min_width",
|
429
|
+
"int_format",
|
430
|
+
"float_format",
|
431
|
+
"custom_format",
|
432
|
+
"none_format",
|
433
433
|
):
|
434
434
|
setattr(
|
435
435
|
self, attr, (self._kwargs[attr] or {}) if attr in self._kwargs else {}
|
@@ -517,15 +517,15 @@ class PrettyTable:
|
|
517
517
|
elif option == "none_format":
|
518
518
|
self._validate_none_format(val)
|
519
519
|
elif option in (
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
520
|
+
"start",
|
521
|
+
"end",
|
522
|
+
"max_width",
|
523
|
+
"min_width",
|
524
|
+
"min_table_width",
|
525
|
+
"max_table_width",
|
526
|
+
"padding_width",
|
527
|
+
"left_padding_width",
|
528
|
+
"right_padding_width",
|
529
529
|
):
|
530
530
|
self._validate_nonnegative_int(option, val)
|
531
531
|
elif option == "sortby":
|
@@ -539,16 +539,16 @@ class PrettyTable:
|
|
539
539
|
elif option == "fields":
|
540
540
|
self._validate_all_field_names(option, val)
|
541
541
|
elif option in (
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
542
|
+
"header",
|
543
|
+
"border",
|
544
|
+
"preserve_internal_border",
|
545
|
+
"reversesort",
|
546
|
+
"xhtml",
|
547
|
+
"format",
|
548
|
+
"print_empty",
|
549
|
+
"oldsortslice",
|
550
|
+
"escape_header",
|
551
|
+
"escape_data",
|
552
552
|
):
|
553
553
|
self._validate_true_or_false(option, val)
|
554
554
|
elif option == "header_style":
|
@@ -561,18 +561,18 @@ class PrettyTable:
|
|
561
561
|
for k, formatter in val.items():
|
562
562
|
self._validate_function(f"{option}.{k}", formatter)
|
563
563
|
elif option in (
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
564
|
+
"vertical_char",
|
565
|
+
"horizontal_char",
|
566
|
+
"horizontal_align_char",
|
567
|
+
"junction_char",
|
568
|
+
"top_junction_char",
|
569
|
+
"bottom_junction_char",
|
570
|
+
"right_junction_char",
|
571
|
+
"left_junction_char",
|
572
|
+
"top_right_junction_char",
|
573
|
+
"top_left_junction_char",
|
574
|
+
"bottom_right_junction_char",
|
575
|
+
"bottom_left_junction_char",
|
576
576
|
):
|
577
577
|
self._validate_single_char(option, val)
|
578
578
|
elif option == "attributes":
|
@@ -1998,8 +1998,8 @@ class PrettyTable:
|
|
1998
1998
|
if options["header"]:
|
1999
1999
|
lines.append(self._stringify_header(options))
|
2000
2000
|
elif options["border"] and options["hrules"] in (
|
2001
|
-
|
2002
|
-
|
2001
|
+
HRuleStyle.ALL,
|
2002
|
+
HRuleStyle.FRAME,
|
2003
2003
|
):
|
2004
2004
|
lines.append(self._stringify_hrule(options, where="top_"))
|
2005
2005
|
if title and options["vrules"] in (VRuleStyle.ALL, VRuleStyle.FRAME):
|
@@ -2115,8 +2115,8 @@ class PrettyTable:
|
|
2115
2115
|
if options["hrules"] in (HRuleStyle.ALL, HRuleStyle.FRAME):
|
2116
2116
|
bits.append(self._stringify_hrule(options, "top_"))
|
2117
2117
|
if options["title"] and options["vrules"] in (
|
2118
|
-
|
2119
|
-
|
2118
|
+
VRuleStyle.ALL,
|
2119
|
+
VRuleStyle.FRAME,
|
2120
2120
|
):
|
2121
2121
|
left_j_len = len(self.left_junction_char)
|
2122
2122
|
right_j_len = len(self.right_junction_char)
|
brainstate/util/_struct.py
CHANGED
@@ -21,10 +21,11 @@ from __future__ import annotations
|
|
21
21
|
|
22
22
|
import collections
|
23
23
|
import dataclasses
|
24
|
-
import jax
|
25
24
|
from collections.abc import Hashable, Mapping
|
26
25
|
from types import MappingProxyType
|
27
26
|
from typing import Any, TypeVar
|
27
|
+
|
28
|
+
import jax
|
28
29
|
from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
|
29
30
|
|
30
31
|
__all__ = [
|
brainstate/util/filter.py
CHANGED
{brainstate-0.1.0.post20250413.dist-info → brainstate-0.1.0.post20250422.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20250413
|
3
|
+
Version: 0.1.0.post20250422
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -15,7 +15,6 @@ Classifier: Development Status :: 4 - Beta
|
|
15
15
|
Classifier: Intended Audience :: Developers
|
16
16
|
Classifier: Intended Audience :: Science/Research
|
17
17
|
Classifier: Programming Language :: Python :: 3
|
18
|
-
Classifier: Programming Language :: Python :: 3.9
|
19
18
|
Classifier: Programming Language :: Python :: 3.10
|
20
19
|
Classifier: Programming Language :: Python :: 3.11
|
21
20
|
Classifier: Programming Language :: Python :: 3.12
|
@@ -33,6 +32,7 @@ Requires-Dist: jax
|
|
33
32
|
Requires-Dist: jaxlib
|
34
33
|
Requires-Dist: numpy
|
35
34
|
Requires-Dist: brainunit (>=0.0.4)
|
35
|
+
Requires-Dist: brainevent
|
36
36
|
Provides-Extra: cpu
|
37
37
|
Requires-Dist: jaxlib ; extra == 'cpu'
|
38
38
|
Provides-Extra: cuda12
|
@@ -56,7 +56,7 @@ Requires-Dist: jaxlib[tpu] ; extra == 'tpu'
|
|
56
56
|
<p align="center">
|
57
57
|
<a href="https://pypi.org/project/brainstate/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/brainstate"></a>
|
58
58
|
<a href="https://github.com/chaobrain/brainstate/blob/main/LICENSE"><img alt="LICENSE" src="https://img.shields.io/badge/License-Apache%202.0-blue.svg"></a>
|
59
|
-
<a href='https://brainstate.readthedocs.io
|
59
|
+
<a href='https://brainstate.readthedocs.io/?badge=latest'>
|
60
60
|
<img src='https://readthedocs.org/projects/brainstate/badge/?version=latest' alt='Documentation Status' />
|
61
61
|
</a>
|
62
62
|
<a href="https://badge.fury.io/py/brainstate"><img alt="PyPI version" src="https://badge.fury.io/py/brainstate.svg"></a>
|