brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/_state.py
CHANGED
@@ -33,6 +33,7 @@ from typing import (
|
|
33
33
|
Generator,
|
34
34
|
)
|
35
35
|
|
36
|
+
import brainunit as u
|
36
37
|
import jax
|
37
38
|
import numpy as np
|
38
39
|
from jax.api_util import shaped_abstractify
|
@@ -47,6 +48,8 @@ __all__ = [
|
|
47
48
|
'ShortTermState',
|
48
49
|
'LongTermState',
|
49
50
|
'HiddenState',
|
51
|
+
'HiddenGroupState',
|
52
|
+
'HiddenTreeState',
|
50
53
|
'ParamState',
|
51
54
|
'BatchState',
|
52
55
|
'TreefyState',
|
@@ -213,7 +216,7 @@ class State(Generic[A], PrettyObject):
|
|
213
216
|
tag (Optional[str]): An optional tag for categorizing or grouping states.
|
214
217
|
|
215
218
|
Args:
|
216
|
-
value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
|
219
|
+
value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
|
217
220
|
The initial value for the state. Can be a PyTree of array-like objects
|
218
221
|
or a StateMetadata object.
|
219
222
|
name (Optional[str]): An optional name for the state.
|
@@ -233,13 +236,13 @@ class State(Generic[A], PrettyObject):
|
|
233
236
|
[0. 0. 0.]]
|
234
237
|
|
235
238
|
Note:
|
236
|
-
- Subclasses of :class:`State` (e.g., ShortTermState, LongTermState, ParamState,
|
239
|
+
- Subclasses of :class:`State` (e.g., ShortTermState, LongTermState, ParamState,
|
237
240
|
RandomState) are typically used for specific purposes in a program.
|
238
|
-
- The class integrates with BrainState's tracing system to track state
|
241
|
+
- The class integrates with BrainState's tracing system to track state
|
239
242
|
creation and modifications.
|
240
243
|
|
241
244
|
The typical examples of :py:class:`~.State` subclass are:
|
242
|
-
|
245
|
+
|
243
246
|
- :py:class:`ShortTermState`: The short-term state, which is used to store the short-term data in the program.
|
244
247
|
- :py:class:`LongTermState`: The long-term state, which is used to store the long-term data in the program.
|
245
248
|
- :py:class:`ParamState`: The parameter state, which is used to store the parameters in the program.
|
@@ -271,7 +274,7 @@ class State(Generic[A], PrettyObject):
|
|
271
274
|
handling various input types and metadata.
|
272
275
|
|
273
276
|
Args:
|
274
|
-
value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
|
277
|
+
value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
|
275
278
|
The initial value for the hidden state. Can be a PyTree of array-like objects
|
276
279
|
or a StateMetadata object containing both value and metadata.
|
277
280
|
name (Optional[str], optional): A name for the hidden state. Defaults to None.
|
@@ -738,7 +741,7 @@ class BatchState(LongTermState):
|
|
738
741
|
|
739
742
|
class HiddenState(ShortTermState):
|
740
743
|
"""
|
741
|
-
|
744
|
+
Represents hidden state variables in neurons or synapses.
|
742
745
|
|
743
746
|
This class extends :class:`ShortTermState` and is specifically designed to represent
|
744
747
|
and manage hidden states within dynamic models, such as recurrent neural networks.
|
@@ -751,6 +754,12 @@ class HiddenState(ShortTermState):
|
|
751
754
|
or networks. The latter is used to store the trainable parameters in the model,
|
752
755
|
such as synaptic weights.
|
753
756
|
|
757
|
+
Note:
|
758
|
+
From version 0.2.2, :class:`HiddenState` only supports values of type numpy.ndarray,
|
759
|
+
jax.Array or brainunit.Quantity. Moreover, it is equivalent to :class:`brainscale.ETraceState`.
|
760
|
+
Dynamics models defined with :class:`HiddenState` can be seamlessly integrated with
|
761
|
+
BrainScale online learning.
|
762
|
+
|
754
763
|
Example:
|
755
764
|
>>> lstm_hidden = HiddenState(np.zeros(128), name="lstm_hidden_state")
|
756
765
|
>>> gru_hidden = HiddenState(np.zeros(64), name="gru_hidden_state")
|
@@ -758,6 +767,502 @@ class HiddenState(ShortTermState):
|
|
758
767
|
|
759
768
|
__module__ = 'brainstate'
|
760
769
|
|
770
|
+
value: ArrayLike
|
771
|
+
|
772
|
+
def __init__(self, value: ArrayLike, name: Optional[str] = None):
|
773
|
+
self._check_value(value)
|
774
|
+
super().__init__(value, name=name)
|
775
|
+
|
776
|
+
@property
|
777
|
+
def varshape(self) -> Tuple[int, ...]:
|
778
|
+
"""
|
779
|
+
Get the shape of the hidden state variable.
|
780
|
+
|
781
|
+
This property returns the shape of the hidden state variable stored in the instance.
|
782
|
+
It provides the dimensions of the array representing the hidden state.
|
783
|
+
|
784
|
+
Returns:
|
785
|
+
Tuple[int, ...]: A tuple representing the shape of the hidden state variable.
|
786
|
+
"""
|
787
|
+
return self.value.shape
|
788
|
+
|
789
|
+
@property
|
790
|
+
def num_state(self) -> int:
|
791
|
+
"""
|
792
|
+
Get the number of hidden states.
|
793
|
+
|
794
|
+
This property returns the number of hidden states represented by the instance.
|
795
|
+
For the `HiddenState` class, this is always 1, as it represents a single hidden state.
|
796
|
+
|
797
|
+
Returns:
|
798
|
+
int: The number of hidden states, which is 1 for this class.
|
799
|
+
"""
|
800
|
+
return 1
|
801
|
+
|
802
|
+
def _check_value(self, value: ArrayLike):
|
803
|
+
if not isinstance(value, (np.ndarray, jax.Array, u.Quantity)):
|
804
|
+
raise TypeError(
|
805
|
+
f'Currently, {HiddenState.__name__} only supports '
|
806
|
+
f'numpy.ndarray, jax.Array or brainunit.Quantity. '
|
807
|
+
f'But we got {type(value)}.'
|
808
|
+
)
|
809
|
+
|
810
|
+
|
811
|
+
class HiddenGroupState(HiddenState):
|
812
|
+
"""
|
813
|
+
A group of multiple hidden states for eligibility trace-based learning.
|
814
|
+
|
815
|
+
This class is used to define multiple hidden states within a single instance
|
816
|
+
of :py:class:`ETraceState`. Normally, you should define multiple instances
|
817
|
+
of :py:class:`ETraceState` to represent multiple hidden states. But
|
818
|
+
:py:class:`HiddenGroupState` lets you define multiple hidden states within
|
819
|
+
a single instance.
|
820
|
+
|
821
|
+
The following is the way to initialize the hidden states.
|
822
|
+
|
823
|
+
.. code-block:: python
|
824
|
+
|
825
|
+
import brainunit as u
|
826
|
+
value = np.random.randn(10, 10, 5) * u.mV
|
827
|
+
state = HiddenGroupState(value)
|
828
|
+
|
829
|
+
Then, you can retrieve the hidden state value with the following method.
|
830
|
+
|
831
|
+
.. code-block:: python
|
832
|
+
|
833
|
+
state.get_value(0) # get the first hidden state
|
834
|
+
# or
|
835
|
+
state.get_value('0') # get the hidden state with the name '0'
|
836
|
+
|
837
|
+
You can write the hidden state value with the following method.
|
838
|
+
|
839
|
+
.. code-block:: python
|
840
|
+
|
841
|
+
state.set_value({0: np.random.randn(10, 10) * u.mV}) # set the first hidden state
|
842
|
+
# or
|
843
|
+
state.set_value({'0': np.random.randn(10, 10) * u.mV}) # set the hidden state with the name '0'
|
844
|
+
# or
|
845
|
+
state.value = np.random.randn(10, 10, 5) * u.mV # set all hidden state value
|
846
|
+
|
847
|
+
Args:
|
848
|
+
value: The values of the hidden states. It can be a sequence of hidden states,
|
849
|
+
or a single hidden state with the last dimension as the number of hidden states,
|
850
|
+
or a dictionary of hidden states.
|
851
|
+
"""
|
852
|
+
|
853
|
+
__module__ = 'brainstate'
|
854
|
+
value: ArrayLike
|
855
|
+
name2index: Dict[str, int]
|
856
|
+
|
857
|
+
def __init__(self, value: ArrayLike):
|
858
|
+
value, name2index = self._check_value(value)
|
859
|
+
self.name2index = name2index
|
860
|
+
ShortTermState.__init__(self, value)
|
861
|
+
|
862
|
+
@property
|
863
|
+
def varshape(self) -> Tuple[int, ...]:
|
864
|
+
"""
|
865
|
+
Get the shape of each hidden state variable.
|
866
|
+
|
867
|
+
This property returns the shape of the hidden state variables, excluding
|
868
|
+
the last dimension which represents the number of hidden states.
|
869
|
+
|
870
|
+
Returns:
|
871
|
+
Tuple[int, ...]: A tuple representing the shape of each hidden state variable.
|
872
|
+
"""
|
873
|
+
return self.value.shape[:-1]
|
874
|
+
|
875
|
+
@property
|
876
|
+
def num_state(self) -> int:
|
877
|
+
"""
|
878
|
+
Get the number of hidden states.
|
879
|
+
|
880
|
+
This property returns the number of hidden states represented by the last dimension
|
881
|
+
of the value array.
|
882
|
+
|
883
|
+
Returns:
|
884
|
+
int: The number of hidden states.
|
885
|
+
"""
|
886
|
+
return self.value.shape[-1]
|
887
|
+
|
888
|
+
def _check_value(self, value) -> Tuple[ArrayLike, Dict[str, int]]:
|
889
|
+
"""
|
890
|
+
Validates the input value for hidden states and returns a tuple containing
|
891
|
+
the processed value and a dictionary mapping state names to indices.
|
892
|
+
|
893
|
+
This function ensures that the input value is of a supported type and has
|
894
|
+
the required dimensionality for hidden states. It also constructs a mapping
|
895
|
+
from string representations of indices to their integer counterparts.
|
896
|
+
|
897
|
+
Parameters
|
898
|
+
----------
|
899
|
+
value (ArrayLike): The input value representing hidden states.
|
900
|
+
It must be an instance of numpy.ndarray, jax.Array, or brainunit.Quantity
|
901
|
+
with at least two dimensions.
|
902
|
+
|
903
|
+
Returns
|
904
|
+
-------
|
905
|
+
Tuple[ArrayLike, Dict[str, int]]: A tuple containing:
|
906
|
+
- The validated and possibly modified input value.
|
907
|
+
- A dictionary mapping string representations of indices to integer indices.
|
908
|
+
|
909
|
+
Raises
|
910
|
+
------
|
911
|
+
TypeError: If the input value is not of a supported type.
|
912
|
+
ValueError: If the input value does not have the required number of dimensions.
|
913
|
+
"""
|
914
|
+
if not isinstance(value, (np.ndarray, jax.Array, u.Quantity)):
|
915
|
+
raise TypeError(
|
916
|
+
f'Currently, {self.__class__.__name__} only supports '
|
917
|
+
f'numpy.ndarray, jax.Array or brainunit.Quantity. '
|
918
|
+
f'But we got {type(value)}.'
|
919
|
+
)
|
920
|
+
if value.ndim < 2:
|
921
|
+
raise ValueError(
|
922
|
+
f'Currently, {self.__class__.__name__} only supports '
|
923
|
+
f'hidden states with more than 2 dimensions, where the last '
|
924
|
+
f'dimension is the number of state size and the other dimensions '
|
925
|
+
f'are the hidden shape. '
|
926
|
+
f'But we got {value.ndim} dimensions.'
|
927
|
+
)
|
928
|
+
name2index = {str(i): i for i in range(value.shape[-1])}
|
929
|
+
return value, name2index
|
930
|
+
|
931
|
+
def get_value(self, item: int | str) -> ArrayLike:
|
932
|
+
"""
|
933
|
+
Get the value of the hidden state with the item.
|
934
|
+
|
935
|
+
Args:
|
936
|
+
item: int or str. The index of the hidden state.
|
937
|
+
- If int, the index of the hidden state.
|
938
|
+
- If str, the name of the hidden state.
|
939
|
+
Returns:
|
940
|
+
The value of the hidden state.
|
941
|
+
"""
|
942
|
+
if isinstance(item, int):
|
943
|
+
assert item < self.value.shape[-1], (f'Index {item} out of range. '
|
944
|
+
f'The maximum index is {self.value.shape[-1] - 1}.')
|
945
|
+
return self.value[..., item]
|
946
|
+
elif isinstance(item, str):
|
947
|
+
assert item in self.name2index, (f'Hidden state name {item} not found. '
|
948
|
+
f'Please check the hidden state names.')
|
949
|
+
index = self.name2index[item]
|
950
|
+
return self.value[..., index]
|
951
|
+
else:
|
952
|
+
raise TypeError(
|
953
|
+
f'Currently, {self.__class__.__name__} only supports '
|
954
|
+
f'int or str for getting the hidden state. '
|
955
|
+
f'But we got {type(item)}.'
|
956
|
+
)
|
957
|
+
|
958
|
+
def set_value(
|
959
|
+
self,
|
960
|
+
val: Dict[int | str, ArrayLike] | Sequence[ArrayLike]
|
961
|
+
) -> None:
|
962
|
+
"""
|
963
|
+
Set the value of the hidden state with the specified item.
|
964
|
+
|
965
|
+
This method updates the hidden state values based on the provided dictionary or sequence.
|
966
|
+
The values are set according to the indices or names specified in the input.
|
967
|
+
|
968
|
+
Parameters
|
969
|
+
----------
|
970
|
+
val (Dict[int | str, ArrayLike] | Sequence[ArrayLike]):
|
971
|
+
A dictionary or sequence containing the new values for the hidden states.
|
972
|
+
- If a dictionary, keys can be integers (indices) or strings (names) of the hidden states.
|
973
|
+
- If a sequence, it is converted to a dictionary with indices as keys.
|
974
|
+
|
975
|
+
Returns
|
976
|
+
-------
|
977
|
+
None: This method does not return any value. It updates the hidden state values in place.
|
978
|
+
"""
|
979
|
+
if isinstance(val, (tuple, list)):
|
980
|
+
val = {i: v for i, v in enumerate(val)}
|
981
|
+
assert isinstance(val, dict), (
|
982
|
+
f'Currently, {self.__class__.__name__}.set_value() only supports '
|
983
|
+
f'dictionary of hidden states. But we got {type(val)}.'
|
984
|
+
)
|
985
|
+
indices = []
|
986
|
+
values = []
|
987
|
+
for k, v in val.items():
|
988
|
+
if isinstance(k, str):
|
989
|
+
k = self.name2index[k]
|
990
|
+
assert isinstance(k, int), (
|
991
|
+
f'Key {k} should be int or str. '
|
992
|
+
f'But we got {type(k)}.'
|
993
|
+
)
|
994
|
+
assert v.shape == self.varshape, (
|
995
|
+
f'The shape of the hidden state should be {self.varshape}. '
|
996
|
+
f'But we got {v.shape}.'
|
997
|
+
)
|
998
|
+
indices.append(k)
|
999
|
+
values.append(v)
|
1000
|
+
values = u.math.stack(values, axis=-1)
|
1001
|
+
self.value = self.value.at[..., indices].set(values)
|
1002
|
+
|
1003
|
+
|
1004
|
+
class HiddenTreeState(HiddenGroupState):
|
1005
|
+
"""
|
1006
|
+
A pytree of multiple hidden states for eligibility trace-based learning.
|
1007
|
+
|
1008
|
+
.. note::
|
1009
|
+
|
1010
|
+
The value in this state class behaves like a dictionary/sequence of hidden states.
|
1011
|
+
However, the state is actually stored as a single dimensionless array.
|
1012
|
+
|
1013
|
+
There are two ways to define the hidden states.
|
1014
|
+
|
1015
|
+
1. The first is to define a sequence of hidden states.
|
1016
|
+
|
1017
|
+
.. code-block:: python
|
1018
|
+
|
1019
|
+
import brainunit as u
|
1020
|
+
value = [np.random.randn(10, 10) * u.mV,
|
1021
|
+
np.random.randn(10, 10) * u.mA,
|
1022
|
+
np.random.randn(10, 10) * u.mS]
|
1023
|
+
state = HiddenTreeState(value)
|
1024
|
+
|
1025
|
+
Then, you can retrieve the hidden state value with the following method.
|
1026
|
+
|
1027
|
+
.. code-block:: python
|
1028
|
+
|
1029
|
+
state.get_value(0) # get the first hidden state
|
1030
|
+
# or
|
1031
|
+
state.get_value('0') # get the hidden state with the name '0'
|
1032
|
+
|
1033
|
+
You can write the hidden state value with the following method.
|
1034
|
+
|
1035
|
+
.. code-block:: python
|
1036
|
+
|
1037
|
+
state.set_value({0: np.random.randn(10, 10) * u.mV}) # set the first hidden state
|
1038
|
+
# or
|
1039
|
+
state.set_value({'1': np.random.randn(10, 10) * u.mA}) # set the hidden state with the name '1'
|
1040
|
+
# or
|
1041
|
+
state.set_value([np.random.randn(10, 10) * u.mV,
|
1042
|
+
np.random.randn(10, 10) * u.mA,
|
1043
|
+
np.random.randn(10, 10) * u.mS]) # set all hidden state value
|
1044
|
+
# or
|
1045
|
+
state.set_value({
|
1046
|
+
0: np.random.randn(10, 10) * u.mV,
|
1047
|
+
1: np.random.randn(10, 10) * u.mA,
|
1048
|
+
2: np.random.randn(10, 10) * u.mS
|
1049
|
+
}) # set all hidden state value
|
1050
|
+
|
1051
|
+
2. The second is to define a dictionary of hidden states.
|
1052
|
+
|
1053
|
+
.. code-block:: python
|
1054
|
+
|
1055
|
+
import brainunit as u
|
1056
|
+
value = {'v': np.random.randn(10, 10) * u.mV,
|
1057
|
+
'i': np.random.randn(10, 10) * u.mA,
|
1058
|
+
'g': np.random.randn(10, 10) * u.mS}
|
1059
|
+
state = HiddenTreeState(value)
|
1060
|
+
|
1061
|
+
Then, you can retrieve the hidden state value with the following method.
|
1062
|
+
|
1063
|
+
.. code-block:: python
|
1064
|
+
|
1065
|
+
state.get_value('v') # get the hidden state with the name 'v'
|
1066
|
+
# or
|
1067
|
+
state.get_value('i') # get the hidden state with the name 'i'
|
1068
|
+
|
1069
|
+
You can write the hidden state value with the following method.
|
1070
|
+
|
1071
|
+
.. code-block:: python
|
1072
|
+
|
1073
|
+
state.set_value({'v': np.random.randn(10, 10) * u.mV}) # set the hidden state with the name 'v'
|
1074
|
+
# or
|
1075
|
+
state.set_value({'i': np.random.randn(10, 10) * u.mA}) # set the hidden state with the name 'i'
|
1076
|
+
# or
|
1077
|
+
state.set_value([np.random.randn(10, 10) * u.mV,
|
1078
|
+
np.random.randn(10, 10) * u.mA,
|
1079
|
+
np.random.randn(10, 10) * u.mS]) # set all hidden state value
|
1080
|
+
# or
|
1081
|
+
state.set_value({
|
1082
|
+
'v': np.random.randn(10, 10) * u.mV,
|
1083
|
+
'g': np.random.randn(10, 10) * u.mS,
|
1084
|
+
'i': np.random.randn(10, 10) * u.mA
|
1085
|
+
}) # set all hidden state value
|
1086
|
+
|
1087
|
+
.. note::
|
1088
|
+
|
1089
|
+
Avoid using ``HiddenTreeState.value`` to get the state value, or
|
1090
|
+
``HiddenTreeState.value =`` to assign the state value.
|
1091
|
+
|
1092
|
+
Instead, use ``HiddenTreeState.get_value()`` and ``HiddenTreeState.set_value()``.
|
1093
|
+
This is because ``.value`` loses the hidden state units and other information,
|
1094
|
+
and it is only dimensionless data.
|
1095
|
+
|
1096
|
+
This design aims to ensure that any etrace hidden state has only one array.
|
1097
|
+
|
1098
|
+
|
1099
|
+
Args:
|
1100
|
+
value: The values of the hidden states.
|
1101
|
+
"""
|
1102
|
+
|
1103
|
+
__module__ = 'brainstate'
|
1104
|
+
value: ArrayLike
|
1105
|
+
|
1106
|
+
def __init__(
|
1107
|
+
self,
|
1108
|
+
value: Dict[str, ArrayLike] | Sequence[ArrayLike],
|
1109
|
+
):
|
1110
|
+
value, name2unit, name2index = self._check_value(value)
|
1111
|
+
self.name2unit: Dict[str, u.Unit] = name2unit
|
1112
|
+
self.name2index: Dict[str, int] = name2index
|
1113
|
+
self.index2unit: Dict[int, u.Unit] = {i: v for i, v in enumerate(name2unit.values())}
|
1114
|
+
self.index2name: Dict[int, str] = {v: k for k, v in name2index.items()}
|
1115
|
+
ShortTermState.__init__(self, value)
|
1116
|
+
|
1117
|
+
@property
|
1118
|
+
def varshape(self) -> Tuple[int, ...]:
|
1119
|
+
"""
|
1120
|
+
The shape of each hidden state variable.
|
1121
|
+
"""
|
1122
|
+
return self.value.shape[:-1]
|
1123
|
+
|
1124
|
+
@property
|
1125
|
+
def num_state(self) -> int:
|
1126
|
+
"""
|
1127
|
+
The number of hidden states.
|
1128
|
+
"""
|
1129
|
+
assert self.value.shape[-1] == len(self.name2index), (
|
1130
|
+
f'The number of hidden states '
|
1131
|
+
f'is not equal to the number of hidden state names.'
|
1132
|
+
)
|
1133
|
+
return self.value.shape[-1]
|
1134
|
+
|
1135
|
+
def _check_value(
|
1136
|
+
self,
|
1137
|
+
value: dict | Sequence
|
1138
|
+
) -> Tuple[ArrayLike, Dict[str, u.Unit], Dict[str, int]]:
|
1139
|
+
"""
|
1140
|
+
Validates and processes the input value to ensure it conforms to the expected format
|
1141
|
+
and structure for hidden states.
|
1142
|
+
|
1143
|
+
This function checks if the input value is a dictionary or sequence of hidden states,
|
1144
|
+
verifies that all hidden states have the same shape, and extracts units and indices
|
1145
|
+
for each hidden state.
|
1146
|
+
|
1147
|
+
Args:
|
1148
|
+
value (dict | Sequence): A dictionary or sequence representing hidden states.
|
1149
|
+
- If a sequence, it is converted to a dictionary with string indices as keys.
|
1150
|
+
- Each hidden state should be a numpy.ndarray, jax.Array, or brainunit.Quantity.
|
1151
|
+
|
1152
|
+
Returns:
|
1153
|
+
Tuple[ArrayLike, Dict[str, u.Unit], Dict[str, int]]:
|
1154
|
+
- A stacked array of hidden state magnitudes.
|
1155
|
+
- A dictionary mapping hidden state names to their units.
|
1156
|
+
- A dictionary mapping hidden state names to their indices.
|
1157
|
+
|
1158
|
+
Raises:
|
1159
|
+
TypeError: If any hidden state is not a numpy.ndarray, jax.Array, or brainunit.Quantity.
|
1160
|
+
ValueError: If hidden states do not have the same shape.
|
1161
|
+
"""
|
1162
|
+
if isinstance(value, (tuple, list)):
|
1163
|
+
value = {str(i): v for i, v in enumerate(value)}
|
1164
|
+
assert isinstance(value, dict), (
|
1165
|
+
f'Currently, {self.__class__.__name__} only supports '
|
1166
|
+
f'dictionary/sequence of hidden states. But we got {type(value)}.'
|
1167
|
+
)
|
1168
|
+
shapes = []
|
1169
|
+
for k, v in value.items():
|
1170
|
+
if not isinstance(v, (np.ndarray, jax.Array, u.Quantity)):
|
1171
|
+
raise TypeError(
|
1172
|
+
f'Currently, {self.__class__.__name__} only supports '
|
1173
|
+
f'numpy.ndarray, jax.Array or brainunit.Quantity. '
|
1174
|
+
f'But we got {type(v)} for key {k}.'
|
1175
|
+
)
|
1176
|
+
shapes.append(v.shape)
|
1177
|
+
if len(set(shapes)) > 1:
|
1178
|
+
info = {k: v.shape for k, v in value.items()}
|
1179
|
+
raise ValueError(
|
1180
|
+
f'Currently, {self.__class__.__name__} only supports '
|
1181
|
+
f'hidden states with the same shape. '
|
1182
|
+
f'But we got {info}.'
|
1183
|
+
)
|
1184
|
+
name2unit = {k: u.get_unit(v) for k, v in value.items()}
|
1185
|
+
name2index = {k: i for i, k in enumerate(value.keys())}
|
1186
|
+
value = u.math.stack([u.get_magnitude(v) for v in value.values()], axis=-1)
|
1187
|
+
return value, name2unit, name2index
|
1188
|
+
|
1189
|
+
def get_value(self, item: str | int) -> ArrayLike:
|
1190
|
+
"""
|
1191
|
+
Get the value of the hidden state with the key.
|
1192
|
+
|
1193
|
+
Args:
|
1194
|
+
item: The key of the hidden state.
|
1195
|
+
- If int, the index of the hidden state.
|
1196
|
+
- If str, the name of the hidden state.
|
1197
|
+
"""
|
1198
|
+
if isinstance(item, int):
|
1199
|
+
assert item < self.value.shape[-1], (f'Index {item} out of range. '
|
1200
|
+
f'The maximum index is {self.value.shape[-1] - 1}.')
|
1201
|
+
val = self.value[..., item]
|
1202
|
+
elif isinstance(item, str):
|
1203
|
+
assert item in self.name2index, (f'Hidden state name {item} not found. '
|
1204
|
+
f'Please check the hidden state names.')
|
1205
|
+
item = self.name2index[item]
|
1206
|
+
val = self.value[..., item]
|
1207
|
+
else:
|
1208
|
+
raise TypeError(
|
1209
|
+
f'Currently, {self.__class__.__name__} only supports '
|
1210
|
+
f'int or str for getting the hidden state. '
|
1211
|
+
f'But we got {type(item)}.'
|
1212
|
+
)
|
1213
|
+
if self.index2unit[item].dim.is_dimensionless:
|
1214
|
+
return val
|
1215
|
+
else:
|
1216
|
+
return val * self.index2unit[item]
|
1217
|
+
|
1218
|
+
def set_value(
|
1219
|
+
self,
|
1220
|
+
val: Dict[int | str, ArrayLike] | Sequence[ArrayLike]
|
1221
|
+
) -> None:
|
1222
|
+
"""
|
1223
|
+
Set the value of the hidden state with the specified item.
|
1224
|
+
|
1225
|
+
This method updates the hidden state values based on the provided dictionary or sequence.
|
1226
|
+
The values are set according to the indices or names specified in the input.
|
1227
|
+
|
1228
|
+
Parameters
|
1229
|
+
----------
|
1230
|
+
val (Dict[int | str, ArrayLike] | Sequence[ArrayLike]):
|
1231
|
+
A dictionary or sequence containing the new values for the hidden states.
|
1232
|
+
- If a dictionary, keys can be integers (indices) or strings (names) of the hidden states.
|
1233
|
+
- If a sequence, it is converted to a dictionary with indices as keys.
|
1234
|
+
|
1235
|
+
Returns
|
1236
|
+
-------
|
1237
|
+
None: This method does not return any value. It updates the hidden state values in place.
|
1238
|
+
"""
|
1239
|
+
if isinstance(val, (tuple, list)):
|
1240
|
+
val = {i: v for i, v in enumerate(val)}
|
1241
|
+
assert isinstance(val, dict), (f'Currently, {self.__class__.__name__}.set_value() only supports '
|
1242
|
+
f'dictionary of hidden states. But we got {type(val)}.')
|
1243
|
+
indices = []
|
1244
|
+
values = []
|
1245
|
+
for index, v in val.items():
|
1246
|
+
if isinstance(index, str):
|
1247
|
+
index = self.name2index[index]
|
1248
|
+
assert isinstance(index, int), (f'Key {index} should be int or str. '
|
1249
|
+
f'But we got {type(index)}.')
|
1250
|
+
assert v.shape == self.varshape, (f'The shape of the hidden state should be {self.varshape}. '
|
1251
|
+
f'But we got {v.shape}.')
|
1252
|
+
indices.append(index)
|
1253
|
+
values.append(u.Quantity(v).to(self.index2unit[index]).mantissa)
|
1254
|
+
if len(indices) == 0:
|
1255
|
+
raise ValueError(
|
1256
|
+
f'No hidden state is set. Please check the hidden state names or indices.'
|
1257
|
+
)
|
1258
|
+
if len(indices) == 1:
|
1259
|
+
indices = indices[0]
|
1260
|
+
values = values[0]
|
1261
|
+
else:
|
1262
|
+
indices = np.asarray(indices)
|
1263
|
+
values = u.math.stack(values, axis=-1)
|
1264
|
+
self.value = self.value.at[..., indices].set(values)
|
1265
|
+
|
761
1266
|
|
762
1267
|
class ParamState(LongTermState):
|
763
1268
|
"""
|
@@ -967,17 +1472,17 @@ class StateTraceStack(Generic[A]):
|
|
967
1472
|
def new_arg(self, state: State) -> None:
|
968
1473
|
"""
|
969
1474
|
Apply a transformation to the value of a given state using a predefined function.
|
970
|
-
|
1475
|
+
|
971
1476
|
This method is used internally to transform the value of a state during tracing.
|
972
1477
|
If a transformation function (``_jax_trace_new_arg``) is defined, it applies this
|
973
1478
|
function to each element of the state's value using JAX's tree mapping.
|
974
|
-
|
1479
|
+
|
975
1480
|
Args:
|
976
1481
|
state (State): The State object whose value needs to be transformed.
|
977
|
-
|
1482
|
+
|
978
1483
|
Returns:
|
979
1484
|
None: This function modifies the state in-place and doesn't return anything.
|
980
|
-
|
1485
|
+
|
981
1486
|
Note:
|
982
1487
|
This method is intended for internal use and relies on the presence of
|
983
1488
|
a ``_jax_trace_new_arg`` function, which should be set separately.
|
@@ -997,18 +1502,18 @@ class StateTraceStack(Generic[A]):
|
|
997
1502
|
def read_its_value(self, state: State) -> None:
|
998
1503
|
"""
|
999
1504
|
Record that a state's value has been read during tracing.
|
1000
|
-
|
1505
|
+
|
1001
1506
|
This method marks the given state as having been read in the current
|
1002
1507
|
tracing context. If the state hasn't been encountered before, it adds
|
1003
1508
|
it to the internal tracking structures and applies any necessary
|
1004
1509
|
transformations via the new_arg method.
|
1005
|
-
|
1510
|
+
|
1006
1511
|
Args:
|
1007
1512
|
state (State): The State object whose value is being read.
|
1008
|
-
|
1513
|
+
|
1009
1514
|
Returns:
|
1010
1515
|
None
|
1011
|
-
|
1516
|
+
|
1012
1517
|
Note:
|
1013
1518
|
This method updates the internal tracking of state accesses.
|
1014
1519
|
It doesn't actually read or return the state's value.
|
@@ -1052,11 +1557,11 @@ class StateTraceStack(Generic[A]):
|
|
1052
1557
|
) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
|
1053
1558
|
"""
|
1054
1559
|
Retrieve the values of all states in the StateTraceStack.
|
1055
|
-
|
1560
|
+
|
1056
1561
|
This method returns the values of all states, optionally separating them
|
1057
1562
|
into written and read states, and optionally replacing values with None
|
1058
1563
|
for states that weren't accessed in a particular way.
|
1059
|
-
|
1564
|
+
|
1060
1565
|
Args:
|
1061
1566
|
separate (bool, optional): If True, separate the values into written
|
1062
1567
|
and read states. If False, return all values in a single sequence.
|
@@ -1064,7 +1569,7 @@ class StateTraceStack(Generic[A]):
|
|
1064
1569
|
replace (bool, optional): If True and separate is True, replace values
|
1065
1570
|
with None for states that weren't written/read. If False, only
|
1066
1571
|
include values for states that were written/read. Defaults to False.
|
1067
|
-
|
1572
|
+
|
1068
1573
|
Returns:
|
1069
1574
|
Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
|
1070
1575
|
If separate is False:
|
@@ -1075,7 +1580,7 @@ class StateTraceStack(Generic[A]):
|
|
1075
1580
|
- The second sequence contains values of read states.
|
1076
1581
|
If replace is True, these sequences will have None for
|
1077
1582
|
states that weren't written/read respectively.
|
1078
|
-
|
1583
|
+
|
1079
1584
|
"""
|
1080
1585
|
if separate:
|
1081
1586
|
if replace:
|
@@ -1121,19 +1626,19 @@ class StateTraceStack(Generic[A]):
|
|
1121
1626
|
def merge(self, *traces) -> 'StateTraceStack':
|
1122
1627
|
"""
|
1123
1628
|
Merge other state traces into the current ``StateTraceStack``.
|
1124
|
-
|
1629
|
+
|
1125
1630
|
This method combines the states, their write status, and original values from
|
1126
1631
|
other ``StateTraceStack`` instances into the current one. If a state from another
|
1127
1632
|
trace is not present in the current trace, it is added. If a state is already
|
1128
1633
|
present, its write status is updated if necessary.
|
1129
|
-
|
1634
|
+
|
1130
1635
|
Args:
|
1131
1636
|
*traces: Variable number of ``StateTraceStack`` instances to be merged into
|
1132
1637
|
the current instance.
|
1133
|
-
|
1638
|
+
|
1134
1639
|
Returns:
|
1135
1640
|
StateTraceStack: The current ``StateTraceStack`` instance with merged traces.
|
1136
|
-
|
1641
|
+
|
1137
1642
|
Note:
|
1138
1643
|
This method modifies the current ``StateTraceStack`` in-place and also returns it.
|
1139
1644
|
"""
|
@@ -1152,16 +1657,16 @@ class StateTraceStack(Generic[A]):
|
|
1152
1657
|
def get_read_states(self, replace_writen: bool = False) -> Tuple[State, ...]:
|
1153
1658
|
"""
|
1154
1659
|
Retrieve the states that were read during the function execution.
|
1155
|
-
|
1660
|
+
|
1156
1661
|
This method returns the states that were accessed (read from) during
|
1157
1662
|
the traced function's execution. It can optionally replace written
|
1158
1663
|
states with None.
|
1159
|
-
|
1664
|
+
|
1160
1665
|
Args:
|
1161
1666
|
replace_writen (bool, optional): If True, replace written states with None
|
1162
1667
|
in the returned tuple. If False, exclude written states entirely from
|
1163
1668
|
the result. Defaults to False.
|
1164
|
-
|
1669
|
+
|
1165
1670
|
Returns:
|
1166
1671
|
Tuple[State, ...]: A tuple containing the read states.
|
1167
1672
|
If replace_writen is True, the tuple will have the same length as the
|
@@ -1177,15 +1682,15 @@ class StateTraceStack(Generic[A]):
|
|
1177
1682
|
def get_read_state_values(self, replace_writen: bool = False) -> Tuple[PyTree, ...]:
|
1178
1683
|
"""
|
1179
1684
|
Retrieve the values of states that were read during the function execution.
|
1180
|
-
|
1685
|
+
|
1181
1686
|
This method returns the values of states that were accessed (read from) during
|
1182
1687
|
the traced function's execution. It can optionally replace written states with None.
|
1183
|
-
|
1688
|
+
|
1184
1689
|
Args:
|
1185
1690
|
replace_writen (bool, optional): If True, replace the values of written
|
1186
1691
|
states with None in the returned tuple. If False, exclude written
|
1187
1692
|
states entirely from the result. Defaults to False.
|
1188
|
-
|
1693
|
+
|
1189
1694
|
Returns:
|
1190
1695
|
Tuple[PyTree, ...]: A tuple containing the values of read states.
|
1191
1696
|
If replace_writen is True, the tuple will have the same length as the
|
@@ -1204,16 +1709,16 @@ class StateTraceStack(Generic[A]):
|
|
1204
1709
|
def get_write_states(self, replace_read: bool = False) -> Tuple[State, ...]:
|
1205
1710
|
"""
|
1206
1711
|
Retrieve the states that were written during the function execution.
|
1207
|
-
|
1712
|
+
|
1208
1713
|
This method returns the states that were modified (written to) during
|
1209
1714
|
the traced function's execution. It can optionally replace unwritten (read-only)
|
1210
1715
|
states with None.
|
1211
|
-
|
1716
|
+
|
1212
1717
|
Args:
|
1213
1718
|
replace_read (bool, optional): If True, replace read-only states with None
|
1214
1719
|
in the returned tuple. If False, exclude read-only states entirely from
|
1215
1720
|
the result. Defaults to False.
|
1216
|
-
|
1721
|
+
|
1217
1722
|
Returns:
|
1218
1723
|
Tuple[State, ...]: A tuple containing the written states.
|
1219
1724
|
If replace_read is True, the tuple will have the same length as the
|
@@ -1229,23 +1734,23 @@ class StateTraceStack(Generic[A]):
|
|
1229
1734
|
def get_write_state_values(self, replace_read: bool = False) -> Tuple[PyTree, ...]:
|
1230
1735
|
"""
|
1231
1736
|
Retrieve the values of states that were written during the function execution.
|
1232
|
-
|
1737
|
+
|
1233
1738
|
This method returns the values of states that were modified (written to) during
|
1234
1739
|
the traced function's execution. It can optionally replace unwritten (read-only)
|
1235
1740
|
states with None.
|
1236
|
-
|
1741
|
+
|
1237
1742
|
Args:
|
1238
1743
|
replace_read (bool, optional): If True, replace the values of read-only
|
1239
1744
|
states with None in the returned tuple. If False, exclude read-only
|
1240
1745
|
states entirely from the result. Defaults to False.
|
1241
|
-
|
1746
|
+
|
1242
1747
|
Returns:
|
1243
1748
|
Tuple[PyTree, ...]: A tuple containing the values of written states.
|
1244
1749
|
If replace_read is True, the tuple will have the same length as the
|
1245
1750
|
total number of states, with None for read-only states.
|
1246
1751
|
If replace_read is False, the tuple will only contain values of
|
1247
1752
|
written states.
|
1248
|
-
|
1753
|
+
|
1249
1754
|
"""
|
1250
1755
|
if replace_read:
|
1251
1756
|
return tuple([st.value if been_writen else None for st, been_writen in zip(self.states, self.been_writen)])
|