brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. brainstate/__init__.py +2 -4
  2. brainstate/_deprecation_test.py +2 -24
  3. brainstate/_state.py +540 -35
  4. brainstate/_state_test.py +1085 -8
  5. brainstate/graph/_operation.py +1 -5
  6. brainstate/mixin.py +14 -0
  7. brainstate/nn/__init__.py +42 -33
  8. brainstate/nn/_collective_ops.py +2 -0
  9. brainstate/nn/_common_test.py +0 -20
  10. brainstate/nn/_delay.py +1 -1
  11. brainstate/nn/_dropout_test.py +9 -6
  12. brainstate/nn/_dynamics.py +67 -464
  13. brainstate/nn/_dynamics_test.py +0 -14
  14. brainstate/nn/_embedding.py +7 -7
  15. brainstate/nn/_exp_euler.py +9 -9
  16. brainstate/nn/_linear.py +21 -21
  17. brainstate/nn/_module.py +25 -18
  18. brainstate/nn/_normalizations.py +27 -27
  19. brainstate/random/__init__.py +6 -6
  20. brainstate/random/{_rand_funs.py → _fun.py} +1 -1
  21. brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
  22. brainstate/random/_impl.py +672 -0
  23. brainstate/random/{_rand_seed.py → _seed.py} +1 -1
  24. brainstate/random/{_rand_state.py → _state.py} +121 -418
  25. brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
  26. brainstate/transform/__init__.py +6 -9
  27. brainstate/transform/_conditions.py +2 -2
  28. brainstate/transform/_find_state.py +200 -0
  29. brainstate/transform/_find_state_test.py +84 -0
  30. brainstate/transform/_make_jaxpr.py +221 -61
  31. brainstate/transform/_make_jaxpr_test.py +125 -1
  32. brainstate/transform/_mapping.py +287 -209
  33. brainstate/transform/_mapping_test.py +94 -184
  34. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
  35. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
  36. brainstate/transform/_eval_shape.py +0 -145
  37. brainstate/transform/_eval_shape_test.py +0 -38
  38. brainstate/transform/_random.py +0 -171
  39. /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
  40. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  41. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
  42. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/_state.py CHANGED
@@ -33,6 +33,7 @@ from typing import (
33
33
  Generator,
34
34
  )
35
35
 
36
+ import brainunit as u
36
37
  import jax
37
38
  import numpy as np
38
39
  from jax.api_util import shaped_abstractify
@@ -47,6 +48,8 @@ __all__ = [
47
48
  'ShortTermState',
48
49
  'LongTermState',
49
50
  'HiddenState',
51
+ 'HiddenGroupState',
52
+ 'HiddenTreeState',
50
53
  'ParamState',
51
54
  'BatchState',
52
55
  'TreefyState',
@@ -213,7 +216,7 @@ class State(Generic[A], PrettyObject):
213
216
  tag (Optional[str]): An optional tag for categorizing or grouping states.
214
217
 
215
218
  Args:
216
- value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
219
+ value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
217
220
  The initial value for the state. Can be a PyTree of array-like objects
218
221
  or a StateMetadata object.
219
222
  name (Optional[str]): An optional name for the state.
@@ -233,13 +236,13 @@ class State(Generic[A], PrettyObject):
233
236
  [0. 0. 0.]]
234
237
 
235
238
  Note:
236
- - Subclasses of :class:`State` (e.g., ShortTermState, LongTermState, ParamState,
239
+ - Subclasses of :class:`State` (e.g., ShortTermState, LongTermState, ParamState,
237
240
  RandomState) are typically used for specific purposes in a program.
238
- - The class integrates with BrainState's tracing system to track state
241
+ - The class integrates with BrainState's tracing system to track state
239
242
  creation and modifications.
240
243
 
241
244
  The typical examples of :py:class:`~.State` subclass are:
242
-
245
+
243
246
  - :py:class:`ShortTermState`: The short-term state, which is used to store the short-term data in the program.
244
247
  - :py:class:`LongTermState`: The long-term state, which is used to store the long-term data in the program.
245
248
  - :py:class:`ParamState`: The parameter state, which is used to store the parameters in the program.
@@ -271,7 +274,7 @@ class State(Generic[A], PrettyObject):
271
274
  handling various input types and metadata.
272
275
 
273
276
  Args:
274
- value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
277
+ value (Union[PyTree[ArrayLike], StateMetadata[PyTree[ArrayLike]]]):
275
278
  The initial value for the hidden state. Can be a PyTree of array-like objects
276
279
  or a StateMetadata object containing both value and metadata.
277
280
  name (Optional[str], optional): A name for the hidden state. Defaults to None.
@@ -738,7 +741,7 @@ class BatchState(LongTermState):
738
741
 
739
742
  class HiddenState(ShortTermState):
740
743
  """
741
- Represents hidden state variables in neurons or synapses.
744
+ Represents hidden state variables in neurons or synapses.
742
745
 
743
746
  This class extends :class:`ShortTermState` and is specifically designed to represent
744
747
  and manage hidden states within dynamic models, such as recurrent neural networks.
@@ -751,6 +754,12 @@ class HiddenState(ShortTermState):
751
754
  or networks. The latter is used to store the trainable parameters in the model,
752
755
  such as synaptic weights.
753
756
 
757
+ Note:
758
+ From version 0.2.2, :class:`HiddenState` only supports values of type numpy.ndarray,
759
+ jax.Array or brainunit.Quantity. Moreover, it is equivalent to :class:`brainscale.ETraceState`.
760
+ Dynamics models defined with :class:`HiddenState` can be seamlessly integrated with
761
+ BrainScale online learning.
762
+
754
763
  Example:
755
764
  >>> lstm_hidden = HiddenState(np.zeros(128), name="lstm_hidden_state")
756
765
  >>> gru_hidden = HiddenState(np.zeros(64), name="gru_hidden_state")
@@ -758,6 +767,502 @@ class HiddenState(ShortTermState):
758
767
 
759
768
  __module__ = 'brainstate'
760
769
 
770
+ value: ArrayLike
771
+
772
+ def __init__(self, value: ArrayLike, name: Optional[str] = None):
773
+ self._check_value(value)
774
+ super().__init__(value, name=name)
775
+
776
+ @property
777
+ def varshape(self) -> Tuple[int, ...]:
778
+ """
779
+ Get the shape of the hidden state variable.
780
+
781
+ This property returns the shape of the hidden state variable stored in the instance.
782
+ It provides the dimensions of the array representing the hidden state.
783
+
784
+ Returns:
785
+ Tuple[int, ...]: A tuple representing the shape of the hidden state variable.
786
+ """
787
+ return self.value.shape
788
+
789
+ @property
790
+ def num_state(self) -> int:
791
+ """
792
+ Get the number of hidden states.
793
+
794
+ This property returns the number of hidden states represented by the instance.
795
+ For the `HiddenState` class, this is always 1, as it represents a single hidden state.
796
+
797
+ Returns:
798
+ int: The number of hidden states, which is 1 for this class.
799
+ """
800
+ return 1
801
+
802
+ def _check_value(self, value: ArrayLike):
803
+ if not isinstance(value, (np.ndarray, jax.Array, u.Quantity)):
804
+ raise TypeError(
805
+ f'Currently, {HiddenState.__name__} only supports '
806
+ f'numpy.ndarray, jax.Array or brainunit.Quantity. '
807
+ f'But we got {type(value)}.'
808
+ )
809
+
810
+
811
+ class HiddenGroupState(HiddenState):
812
+ """
813
+ A group of multiple hidden states for eligibility trace-based learning.
814
+
815
+ This class is used to define multiple hidden states within a single instance
816
+ of :py:class:`HiddenState`. Normally, you should define multiple instances
817
+ of :py:class:`HiddenState` to represent multiple hidden states. But
818
+ :py:class:`HiddenGroupState` lets you define multiple hidden states within
819
+ a single instance.
820
+
821
+ The following is the way to initialize the hidden states.
822
+
823
+ .. code-block:: python
824
+
825
+ import brainunit as u
826
+ value = np.random.randn(10, 10, 5) * u.mV
827
+ state = HiddenGroupState(value)
828
+
829
+ Then, you can retrieve the hidden state value with the following method.
830
+
831
+ .. code-block:: python
832
+
833
+ state.get_value(0) # get the first hidden state
834
+ # or
835
+ state.get_value('0') # get the hidden state with the name '0'
836
+
837
+ You can write the hidden state value with the following method.
838
+
839
+ .. code-block:: python
840
+
841
+ state.set_value({0: np.random.randn(10, 10) * u.mV}) # set the first hidden state
842
+ # or
843
+ state.set_value({'0': np.random.randn(10, 10) * u.mV}) # set the hidden state with the name '0'
844
+ # or
845
+ state.value = np.random.randn(10, 10, 5) * u.mV # set all hidden state value
846
+
847
+ Args:
848
+ value: The values of the hidden states. It must be a single numpy.ndarray,
849
+ jax.Array or brainunit.Quantity with at least two dimensions, whose last
850
+ dimension indexes the individual hidden states.
851
+ """
852
+
853
+ __module__ = 'brainstate'
854
+ value: ArrayLike
855
+ name2index: Dict[str, int]
856
+
857
+ def __init__(self, value: ArrayLike):
858
+ value, name2index = self._check_value(value)
859
+ self.name2index = name2index
860
+ ShortTermState.__init__(self, value)
861
+
862
+ @property
863
+ def varshape(self) -> Tuple[int, ...]:
864
+ """
865
+ Get the shape of each hidden state variable.
866
+
867
+ This property returns the shape of the hidden state variables, excluding
868
+ the last dimension which represents the number of hidden states.
869
+
870
+ Returns:
871
+ Tuple[int, ...]: A tuple representing the shape of each hidden state variable.
872
+ """
873
+ return self.value.shape[:-1]
874
+
875
+ @property
876
+ def num_state(self) -> int:
877
+ """
878
+ Get the number of hidden states.
879
+
880
+ This property returns the number of hidden states represented by the last dimension
881
+ of the value array.
882
+
883
+ Returns:
884
+ int: The number of hidden states.
885
+ """
886
+ return self.value.shape[-1]
887
+
888
+ def _check_value(self, value) -> Tuple[ArrayLike, Dict[str, int]]:
889
+ """
890
+ Validates the input value for hidden states and returns a tuple containing
891
+ the processed value and a dictionary mapping state names to indices.
892
+
893
+ This function ensures that the input value is of a supported type and has
894
+ the required dimensionality for hidden states. It also constructs a mapping
895
+ from string representations of indices to their integer counterparts.
896
+
897
+ Parameters
898
+ ----------
899
+ value (ArrayLike): The input value representing hidden states.
900
+ It must be an instance of numpy.ndarray, jax.Array, or brainunit.Quantity
901
+ with at least two dimensions.
902
+
903
+ Returns
904
+ -------
905
+ Tuple[ArrayLike, Dict[str, int]]: A tuple containing:
906
+ - The validated and possibly modified input value.
907
+ - A dictionary mapping string representations of indices to integer indices.
908
+
909
+ Raises
910
+ ------
911
+ TypeError: If the input value is not of a supported type.
912
+ ValueError: If the input value does not have the required number of dimensions.
913
+ """
914
+ if not isinstance(value, (np.ndarray, jax.Array, u.Quantity)):
915
+ raise TypeError(
916
+ f'Currently, {self.__class__.__name__} only supports '
917
+ f'numpy.ndarray, jax.Array or brainunit.Quantity. '
918
+ f'But we got {type(value)}.'
919
+ )
920
+ if value.ndim < 2:
921
+ raise ValueError(
922
+ f'Currently, {self.__class__.__name__} only supports '
923
+ f'hidden states with more than 2 dimensions, where the last '
924
+ f'dimension is the number of state size and the other dimensions '
925
+ f'are the hidden shape. '
926
+ f'But we got {value.ndim} dimensions.'
927
+ )
928
+ name2index = {str(i): i for i in range(value.shape[-1])}
929
+ return value, name2index
930
+
931
+ def get_value(self, item: int | str) -> ArrayLike:
932
+ """
933
+ Get the value of the hidden state with the item.
934
+
935
+ Args:
936
+ item: int or str. The index of the hidden state.
937
+ - If int, the index of the hidden state.
938
+ - If str, the name of the hidden state.
939
+ Returns:
940
+ The value of the hidden state.
941
+ """
942
+ if isinstance(item, int):
943
+ assert item < self.value.shape[-1], (f'Index {item} out of range. '
944
+ f'The maximum index is {self.value.shape[-1] - 1}.')
945
+ return self.value[..., item]
946
+ elif isinstance(item, str):
947
+ assert item in self.name2index, (f'Hidden state name {item} not found. '
948
+ f'Please check the hidden state names.')
949
+ index = self.name2index[item]
950
+ return self.value[..., index]
951
+ else:
952
+ raise TypeError(
953
+ f'Currently, {self.__class__.__name__} only supports '
954
+ f'int or str for getting the hidden state. '
955
+ f'But we got {type(item)}.'
956
+ )
957
+
958
+ def set_value(
959
+ self,
960
+ val: Dict[int | str, ArrayLike] | Sequence[ArrayLike]
961
+ ) -> None:
962
+ """
963
+ Set the value of the hidden state with the specified item.
964
+
965
+ This method updates the hidden state values based on the provided dictionary or sequence.
966
+ The values are set according to the indices or names specified in the input.
967
+
968
+ Parameters
969
+ ----------
970
+ val (Dict[int | str, ArrayLike] | Sequence[ArrayLike]):
971
+ A dictionary or sequence containing the new values for the hidden states.
972
+ - If a dictionary, keys can be integers (indices) or strings (names) of the hidden states.
973
+ - If a sequence, it is converted to a dictionary with indices as keys.
974
+
975
+ Returns
976
+ -------
977
+ None: This method does not return any value. It updates the hidden state values in place.
978
+ """
979
+ if isinstance(val, (tuple, list)):
980
+ val = {i: v for i, v in enumerate(val)}
981
+ assert isinstance(val, dict), (
982
+ f'Currently, {self.__class__.__name__}.set_value() only supports '
983
+ f'dictionary of hidden states. But we got {type(val)}.'
984
+ )
985
+ indices = []
986
+ values = []
987
+ for k, v in val.items():
988
+ if isinstance(k, str):
989
+ k = self.name2index[k]
990
+ assert isinstance(k, int), (
991
+ f'Key {k} should be int or str. '
992
+ f'But we got {type(k)}.'
993
+ )
994
+ assert v.shape == self.varshape, (
995
+ f'The shape of the hidden state should be {self.varshape}. '
996
+ f'But we got {v.shape}.'
997
+ )
998
+ indices.append(k)
999
+ values.append(v)
1000
+ values = u.math.stack(values, axis=-1)
1001
+ self.value = self.value.at[..., indices].set(values)
1002
+
1003
+
1004
+ class HiddenTreeState(HiddenGroupState):
1005
+ """
1006
+ A pytree of multiple hidden states for eligibility trace-based learning.
1007
+
1008
+ .. note::
1009
+
1010
+ The value in this state class behaves like a dictionary/sequence of hidden states.
1011
+ However, the state is actually stored as a single dimensionless array.
1012
+
1013
+ There are two ways to define the hidden states.
1014
+
1015
+ 1. The first is to define a sequence of hidden states.
1016
+
1017
+ .. code-block:: python
1018
+
1019
+ import brainunit as u
1020
+ value = [np.random.randn(10, 10) * u.mV,
1021
+ np.random.randn(10, 10) * u.mA,
1022
+ np.random.randn(10, 10) * u.mS]
1023
+ state = HiddenTreeState(value)
1024
+
1025
+ Then, you can retrieve the hidden state value with the following method.
1026
+
1027
+ .. code-block:: python
1028
+
1029
+ state.get_value(0) # get the first hidden state
1030
+ # or
1031
+ state.get_value('0') # get the hidden state with the name '0'
1032
+
1033
+ You can write the hidden state value with the following method.
1034
+
1035
+ .. code-block:: python
1036
+
1037
+ state.set_value({0: np.random.randn(10, 10) * u.mV}) # set the first hidden state
1038
+ # or
1039
+ state.set_value({'1': np.random.randn(10, 10) * u.mA}) # set the hidden state with the name '1'
1040
+ # or
1041
+ state.set_value([np.random.randn(10, 10) * u.mV,
1042
+ np.random.randn(10, 10) * u.mA,
1043
+ np.random.randn(10, 10) * u.mS]) # set all hidden state value
1044
+ # or
1045
+ state.set_value({
1046
+ 0: np.random.randn(10, 10) * u.mV,
1047
+ 1: np.random.randn(10, 10) * u.mA,
1048
+ 2: np.random.randn(10, 10) * u.mS
1049
+ }) # set all hidden state value
1050
+
1051
+ 2. The second is to define a dictionary of hidden states.
1052
+
1053
+ .. code-block:: python
1054
+
1055
+ import brainunit as u
1056
+ value = {'v': np.random.randn(10, 10) * u.mV,
1057
+ 'i': np.random.randn(10, 10) * u.mA,
1058
+ 'g': np.random.randn(10, 10) * u.mS}
1059
+ state = HiddenTreeState(value)
1060
+
1061
+ Then, you can retrieve the hidden state value with the following method.
1062
+
1063
+ .. code-block:: python
1064
+
1065
+ state.get_value('v') # get the hidden state with the name 'v'
1066
+ # or
1067
+ state.get_value('i') # get the hidden state with the name 'i'
1068
+
1069
+ You can write the hidden state value with the following method.
1070
+
1071
+ .. code-block:: python
1072
+
1073
+ state.set_value({'v': np.random.randn(10, 10) * u.mV}) # set the hidden state with the name 'v'
1074
+ # or
1075
+ state.set_value({'i': np.random.randn(10, 10) * u.mA}) # set the hidden state with the name 'i'
1076
+ # or
1077
+ state.set_value([np.random.randn(10, 10) * u.mV,
1078
+ np.random.randn(10, 10) * u.mA,
1079
+ np.random.randn(10, 10) * u.mS]) # set all hidden state value
1080
+ # or
1081
+ state.set_value({
1082
+ 'v': np.random.randn(10, 10) * u.mV,
1083
+ 'g': np.random.randn(10, 10) * u.mS,
1084
+ 'i': np.random.randn(10, 10) * u.mA
1085
+ }) # set all hidden state value
1086
+
1087
+ .. note::
1088
+
1089
+ Avoid using ``HiddenTreeState.value`` to get the state value, or
1090
+ ``HiddenTreeState.value =`` to assign the state value.
1091
+
1092
+ Instead, use ``HiddenTreeState.get_value()`` and ``HiddenTreeState.set_value()``.
1093
+ This is because ``.value`` loses the hidden state units and other information,
1094
+ and it is only dimensionless data.
1095
+
1096
+ This design aims to ensure that any etrace hidden state has only one array.
1097
+
1098
+
1099
+ Args:
1100
+ value: The values of the hidden states.
1101
+ """
1102
+
1103
+ __module__ = 'brainstate'
1104
+ value: ArrayLike
1105
+
1106
+ def __init__(
1107
+ self,
1108
+ value: Dict[str, ArrayLike] | Sequence[ArrayLike],
1109
+ ):
1110
+ value, name2unit, name2index = self._check_value(value)
1111
+ self.name2unit: Dict[str, u.Unit] = name2unit
1112
+ self.name2index: Dict[str, int] = name2index
1113
+ self.index2unit: Dict[int, u.Unit] = {i: v for i, v in enumerate(name2unit.values())}
1114
+ self.index2name: Dict[int, str] = {v: k for k, v in name2index.items()}
1115
+ ShortTermState.__init__(self, value)
1116
+
1117
+ @property
1118
+ def varshape(self) -> Tuple[int, ...]:
1119
+ """
1120
+ The shape of each hidden state variable.
1121
+ """
1122
+ return self.value.shape[:-1]
1123
+
1124
+ @property
1125
+ def num_state(self) -> int:
1126
+ """
1127
+ The number of hidden states.
1128
+ """
1129
+ assert self.value.shape[-1] == len(self.name2index), (
1130
+ f'The number of hidden states '
1131
+ f'is not equal to the number of hidden state names.'
1132
+ )
1133
+ return self.value.shape[-1]
1134
+
1135
+ def _check_value(
1136
+ self,
1137
+ value: dict | Sequence
1138
+ ) -> Tuple[ArrayLike, Dict[str, u.Unit], Dict[str, int]]:
1139
+ """
1140
+ Validates and processes the input value to ensure it conforms to the expected format
1141
+ and structure for hidden states.
1142
+
1143
+ This function checks if the input value is a dictionary or sequence of hidden states,
1144
+ verifies that all hidden states have the same shape, and extracts units and indices
1145
+ for each hidden state.
1146
+
1147
+ Args:
1148
+ value (dict | Sequence): A dictionary or sequence representing hidden states.
1149
+ - If a sequence, it is converted to a dictionary with string indices as keys.
1150
+ - Each hidden state should be a numpy.ndarray, jax.Array, or brainunit.Quantity.
1151
+
1152
+ Returns:
1153
+ Tuple[ArrayLike, Dict[str, u.Unit], Dict[str, int]]:
1154
+ - A stacked array of hidden state magnitudes.
1155
+ - A dictionary mapping hidden state names to their units.
1156
+ - A dictionary mapping hidden state names to their indices.
1157
+
1158
+ Raises:
1159
+ TypeError: If any hidden state is not a numpy.ndarray, jax.Array, or brainunit.Quantity.
1160
+ ValueError: If hidden states do not have the same shape.
1161
+ """
1162
+ if isinstance(value, (tuple, list)):
1163
+ value = {str(i): v for i, v in enumerate(value)}
1164
+ assert isinstance(value, dict), (
1165
+ f'Currently, {self.__class__.__name__} only supports '
1166
+ f'dictionary/sequence of hidden states. But we got {type(value)}.'
1167
+ )
1168
+ shapes = []
1169
+ for k, v in value.items():
1170
+ if not isinstance(v, (np.ndarray, jax.Array, u.Quantity)):
1171
+ raise TypeError(
1172
+ f'Currently, {self.__class__.__name__} only supports '
1173
+ f'numpy.ndarray, jax.Array or brainunit.Quantity. '
1174
+ f'But we got {type(v)} for key {k}.'
1175
+ )
1176
+ shapes.append(v.shape)
1177
+ if len(set(shapes)) > 1:
1178
+ info = {k: v.shape for k, v in value.items()}
1179
+ raise ValueError(
1180
+ f'Currently, {self.__class__.__name__} only supports '
1181
+ f'hidden states with the same shape. '
1182
+ f'But we got {info}.'
1183
+ )
1184
+ name2unit = {k: u.get_unit(v) for k, v in value.items()}
1185
+ name2index = {k: i for i, k in enumerate(value.keys())}
1186
+ value = u.math.stack([u.get_magnitude(v) for v in value.values()], axis=-1)
1187
+ return value, name2unit, name2index
1188
+
1189
+ def get_value(self, item: str | int) -> ArrayLike:
1190
+ """
1191
+ Get the value of the hidden state with the key.
1192
+
1193
+ Args:
1194
+ item: The key of the hidden state.
1195
+ - If int, the index of the hidden state.
1196
+ - If str, the name of the hidden state.
1197
+ """
1198
+ if isinstance(item, int):
1199
+ assert item < self.value.shape[-1], (f'Index {item} out of range. '
1200
+ f'The maximum index is {self.value.shape[-1] - 1}.')
1201
+ val = self.value[..., item]
1202
+ elif isinstance(item, str):
1203
+ assert item in self.name2index, (f'Hidden state name {item} not found. '
1204
+ f'Please check the hidden state names.')
1205
+ item = self.name2index[item]
1206
+ val = self.value[..., item]
1207
+ else:
1208
+ raise TypeError(
1209
+ f'Currently, {self.__class__.__name__} only supports '
1210
+ f'int or str for getting the hidden state. '
1211
+ f'But we got {type(item)}.'
1212
+ )
1213
+ if self.index2unit[item].dim.is_dimensionless:
1214
+ return val
1215
+ else:
1216
+ return val * self.index2unit[item]
1217
+
1218
+ def set_value(
1219
+ self,
1220
+ val: Dict[int | str, ArrayLike] | Sequence[ArrayLike]
1221
+ ) -> None:
1222
+ """
1223
+ Set the value of the hidden state with the specified item.
1224
+
1225
+ This method updates the hidden state values based on the provided dictionary or sequence.
1226
+ The values are set according to the indices or names specified in the input.
1227
+
1228
+ Parameters
1229
+ ----------
1230
+ val (Dict[int | str, ArrayLike] | Sequence[ArrayLike]):
1231
+ A dictionary or sequence containing the new values for the hidden states.
1232
+ - If a dictionary, keys can be integers (indices) or strings (names) of the hidden states.
1233
+ - If a sequence, it is converted to a dictionary with indices as keys.
1234
+
1235
+ Returns
1236
+ -------
1237
+ None: This method does not return any value. It updates the hidden state values in place.
1238
+ """
1239
+ if isinstance(val, (tuple, list)):
1240
+ val = {i: v for i, v in enumerate(val)}
1241
+ assert isinstance(val, dict), (f'Currently, {self.__class__.__name__}.set_value() only supports '
1242
+ f'dictionary of hidden states. But we got {type(val)}.')
1243
+ indices = []
1244
+ values = []
1245
+ for index, v in val.items():
1246
+ if isinstance(index, str):
1247
+ index = self.name2index[index]
1248
+ assert isinstance(index, int), (f'Key {index} should be int or str. '
1249
+ f'But we got {type(index)}.')
1250
+ assert v.shape == self.varshape, (f'The shape of the hidden state should be {self.varshape}. '
1251
+ f'But we got {v.shape}.')
1252
+ indices.append(index)
1253
+ values.append(u.Quantity(v).to(self.index2unit[index]).mantissa)
1254
+ if len(indices) == 0:
1255
+ raise ValueError(
1256
+ f'No hidden state is set. Please check the hidden state names or indices.'
1257
+ )
1258
+ if len(indices) == 1:
1259
+ indices = indices[0]
1260
+ values = values[0]
1261
+ else:
1262
+ indices = np.asarray(indices)
1263
+ values = u.math.stack(values, axis=-1)
1264
+ self.value = self.value.at[..., indices].set(values)
1265
+
761
1266
 
762
1267
  class ParamState(LongTermState):
763
1268
  """
@@ -967,17 +1472,17 @@ class StateTraceStack(Generic[A]):
967
1472
  def new_arg(self, state: State) -> None:
968
1473
  """
969
1474
  Apply a transformation to the value of a given state using a predefined function.
970
-
1475
+
971
1476
  This method is used internally to transform the value of a state during tracing.
972
1477
  If a transformation function (``_jax_trace_new_arg``) is defined, it applies this
973
1478
  function to each element of the state's value using JAX's tree mapping.
974
-
1479
+
975
1480
  Args:
976
1481
  state (State): The State object whose value needs to be transformed.
977
-
1482
+
978
1483
  Returns:
979
1484
  None: This function modifies the state in-place and doesn't return anything.
980
-
1485
+
981
1486
  Note:
982
1487
  This method is intended for internal use and relies on the presence of
983
1488
  a ``_jax_trace_new_arg`` function, which should be set separately.
@@ -997,18 +1502,18 @@ class StateTraceStack(Generic[A]):
997
1502
  def read_its_value(self, state: State) -> None:
998
1503
  """
999
1504
  Record that a state's value has been read during tracing.
1000
-
1505
+
1001
1506
  This method marks the given state as having been read in the current
1002
1507
  tracing context. If the state hasn't been encountered before, it adds
1003
1508
  it to the internal tracking structures and applies any necessary
1004
1509
  transformations via the new_arg method.
1005
-
1510
+
1006
1511
  Args:
1007
1512
  state (State): The State object whose value is being read.
1008
-
1513
+
1009
1514
  Returns:
1010
1515
  None
1011
-
1516
+
1012
1517
  Note:
1013
1518
  This method updates the internal tracking of state accesses.
1014
1519
  It doesn't actually read or return the state's value.
@@ -1052,11 +1557,11 @@ class StateTraceStack(Generic[A]):
1052
1557
  ) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
1053
1558
  """
1054
1559
  Retrieve the values of all states in the StateTraceStack.
1055
-
1560
+
1056
1561
  This method returns the values of all states, optionally separating them
1057
1562
  into written and read states, and optionally replacing values with None
1058
1563
  for states that weren't accessed in a particular way.
1059
-
1564
+
1060
1565
  Args:
1061
1566
  separate (bool, optional): If True, separate the values into written
1062
1567
  and read states. If False, return all values in a single sequence.
@@ -1064,7 +1569,7 @@ class StateTraceStack(Generic[A]):
1064
1569
  replace (bool, optional): If True and separate is True, replace values
1065
1570
  with None for states that weren't written/read. If False, only
1066
1571
  include values for states that were written/read. Defaults to False.
1067
-
1572
+
1068
1573
  Returns:
1069
1574
  Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
1070
1575
  If separate is False:
@@ -1075,7 +1580,7 @@ class StateTraceStack(Generic[A]):
1075
1580
  - The second sequence contains values of read states.
1076
1581
  If replace is True, these sequences will have None for
1077
1582
  states that weren't written/read respectively.
1078
-
1583
+
1079
1584
  """
1080
1585
  if separate:
1081
1586
  if replace:
@@ -1121,19 +1626,19 @@ class StateTraceStack(Generic[A]):
1121
1626
  def merge(self, *traces) -> 'StateTraceStack':
1122
1627
  """
1123
1628
  Merge other state traces into the current ``StateTraceStack``.
1124
-
1629
+
1125
1630
  This method combines the states, their write status, and original values from
1126
1631
  other ``StateTraceStack`` instances into the current one. If a state from another
1127
1632
  trace is not present in the current trace, it is added. If a state is already
1128
1633
  present, its write status is updated if necessary.
1129
-
1634
+
1130
1635
  Args:
1131
1636
  *traces: Variable number of ``StateTraceStack`` instances to be merged into
1132
1637
  the current instance.
1133
-
1638
+
1134
1639
  Returns:
1135
1640
  StateTraceStack: The current ``StateTraceStack`` instance with merged traces.
1136
-
1641
+
1137
1642
  Note:
1138
1643
  This method modifies the current ``StateTraceStack`` in-place and also returns it.
1139
1644
  """
@@ -1152,16 +1657,16 @@ class StateTraceStack(Generic[A]):
1152
1657
  def get_read_states(self, replace_writen: bool = False) -> Tuple[State, ...]:
1153
1658
  """
1154
1659
  Retrieve the states that were read during the function execution.
1155
-
1660
+
1156
1661
  This method returns the states that were accessed (read from) during
1157
1662
  the traced function's execution. It can optionally replace written
1158
1663
  states with None.
1159
-
1664
+
1160
1665
  Args:
1161
1666
  replace_writen (bool, optional): If True, replace written states with None
1162
1667
  in the returned tuple. If False, exclude written states entirely from
1163
1668
  the result. Defaults to False.
1164
-
1669
+
1165
1670
  Returns:
1166
1671
  Tuple[State, ...]: A tuple containing the read states.
1167
1672
  If replace_writen is True, the tuple will have the same length as the
@@ -1177,15 +1682,15 @@ class StateTraceStack(Generic[A]):
1177
1682
  def get_read_state_values(self, replace_writen: bool = False) -> Tuple[PyTree, ...]:
1178
1683
  """
1179
1684
  Retrieve the values of states that were read during the function execution.
1180
-
1685
+
1181
1686
  This method returns the values of states that were accessed (read from) during
1182
1687
  the traced function's execution. It can optionally replace written states with None.
1183
-
1688
+
1184
1689
  Args:
1185
1690
  replace_writen (bool, optional): If True, replace the values of written
1186
1691
  states with None in the returned tuple. If False, exclude written
1187
1692
  states entirely from the result. Defaults to False.
1188
-
1693
+
1189
1694
  Returns:
1190
1695
  Tuple[PyTree, ...]: A tuple containing the values of read states.
1191
1696
  If replace_writen is True, the tuple will have the same length as the
@@ -1204,16 +1709,16 @@ class StateTraceStack(Generic[A]):
1204
1709
  def get_write_states(self, replace_read: bool = False) -> Tuple[State, ...]:
1205
1710
  """
1206
1711
  Retrieve the states that were written during the function execution.
1207
-
1712
+
1208
1713
  This method returns the states that were modified (written to) during
1209
1714
  the traced function's execution. It can optionally replace unwritten (read-only)
1210
1715
  states with None.
1211
-
1716
+
1212
1717
  Args:
1213
1718
  replace_read (bool, optional): If True, replace read-only states with None
1214
1719
  in the returned tuple. If False, exclude read-only states entirely from
1215
1720
  the result. Defaults to False.
1216
-
1721
+
1217
1722
  Returns:
1218
1723
  Tuple[State, ...]: A tuple containing the written states.
1219
1724
  If replace_read is True, the tuple will have the same length as the
@@ -1229,23 +1734,23 @@ class StateTraceStack(Generic[A]):
1229
1734
  def get_write_state_values(self, replace_read: bool = False) -> Tuple[PyTree, ...]:
1230
1735
  """
1231
1736
  Retrieve the values of states that were written during the function execution.
1232
-
1737
+
1233
1738
  This method returns the values of states that were modified (written to) during
1234
1739
  the traced function's execution. It can optionally replace unwritten (read-only)
1235
1740
  states with None.
1236
-
1741
+
1237
1742
  Args:
1238
1743
  replace_read (bool, optional): If True, replace the values of read-only
1239
1744
  states with None in the returned tuple. If False, exclude read-only
1240
1745
  states entirely from the result. Defaults to False.
1241
-
1746
+
1242
1747
  Returns:
1243
1748
  Tuple[PyTree, ...]: A tuple containing the values of written states.
1244
1749
  If replace_read is True, the tuple will have the same length as the
1245
1750
  total number of states, with None for read-only states.
1246
1751
  If replace_read is False, the tuple will only contain values of
1247
1752
  written states.
1248
-
1753
+
1249
1754
  """
1250
1755
  if replace_read:
1251
1756
  return tuple([st.value if been_writen else None for st, been_writen in zip(self.states, self.been_writen)])