brainstate 0.1.2__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +10 -10
  3. brainstate/mixin.py +1 -14
  4. brainstate/nn/__init__.py +81 -17
  5. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  6. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
  7. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
  8. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
  9. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  10. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  11. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
  12. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  13. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  14. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
  15. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  16. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  17. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  18. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  19. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +1 -1
  20. brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
  21. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  22. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  23. brainstate/nn/_stp.py +236 -0
  24. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
  25. brainstate/nn/_synaptic_projection.py +133 -0
  26. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  27. {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
  28. {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/RECORD +44 -46
  29. brainstate/nn/_dyn_impl/__init__.py +0 -42
  30. brainstate/nn/_dynamics/__init__.py +0 -37
  31. brainstate/nn/_elementwise/__init__.py +0 -22
  32. brainstate/nn/_interaction/__init__.py +0 -41
  33. /brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +0 -0
  34. /brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +0 -0
  35. /brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +0 -0
  36. /brainstate/nn/{_elementwise/_elementwise_test.py → _elementwise_test.py} +0 -0
  37. /brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -0
  38. /brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -0
  39. /brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +0 -0
  40. /brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +0 -0
  41. /brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +0 -0
  42. /brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +0 -0
  43. /brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +0 -0
  44. /brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +0 -0
  45. /brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +0 -0
  46. {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
  47. {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
  48. {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
  A ``State``-based Transformation System for Program Compilation and Augmentation
18
18
  """
19
19
 
20
- __version__ = "0.1.2"
20
+ __version__ = "0.1.3"
21
21
 
22
22
  from . import augment
23
23
  from . import compile
@@ -19,7 +19,7 @@
19
19
  import importlib.util
20
20
  from contextlib import contextmanager
21
21
  from functools import partial
22
- from typing import Iterable, Hashable, TypeVar, Callable
22
+ from typing import Iterable, Hashable, TypeVar, Callable, TYPE_CHECKING
23
23
 
24
24
  import jax
25
25
 
@@ -45,8 +45,8 @@ T1 = TypeVar("T1")
45
45
  T2 = TypeVar("T2")
46
46
  T3 = TypeVar("T3")
47
47
 
48
-
49
48
  from saiunit._compatible_import import wrap_init
49
+
50
50
  brainevent_installed = importlib.util.find_spec('brainevent') is not None
51
51
 
52
52
  from jax.core import get_aval, Tracer
@@ -151,13 +151,13 @@ def to_concrete_aval(aval):
151
151
  return aval
152
152
 
153
153
 
154
- if brainevent_installed:
155
- import brainevent
156
- else:
154
+ if not brainevent_installed:
155
+ if not TYPE_CHECKING:
156
+ class BrainEvent:
157
+ def __getattr__(self, item):
158
+ raise ImportError('brainevent is not installed, please install brainevent first.')
157
159
 
158
- class BrainEvent:
159
- def __getattr__(self, item):
160
- raise ImportError('brainevent is not installed, please install brainevent first.')
160
+ brainevent = BrainEvent()
161
161
 
162
-
163
- brainevent = BrainEvent()
162
+ else:
163
+ import brainevent
brainstate/mixin.py CHANGED
@@ -132,6 +132,7 @@ class AlignPost(Mixin):
132
132
  raise NotImplementedError
133
133
 
134
134
 
135
+
135
136
  class BindCondData(Mixin):
136
137
  """Bind temporary conductance data.
137
138
 
@@ -147,7 +148,6 @@ class BindCondData(Mixin):
147
148
 
148
149
 
149
150
  class UpdateReturn(Mixin):
150
-
151
151
  def update_return(self) -> PyTree:
152
152
  """
153
153
  The update function return of the model.
@@ -157,19 +157,6 @@ class UpdateReturn(Mixin):
157
157
  """
158
158
  raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
159
159
 
160
- def update_return_info(self) -> PyTree:
161
- """
162
- The update return information of the model.
163
-
164
- It should be a pytree, with each element as a ``jax.Array``.
165
-
166
- .. note::
167
- Should not include the batch axis and batch in_size.
168
- These information will be inferred from the ``mode`` attribute.
169
-
170
- """
171
- raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
172
-
173
160
 
174
161
  class _MetaUnionType(type):
175
162
  def __new__(cls, name, bases, dct):
brainstate/nn/__init__.py CHANGED
@@ -19,46 +19,110 @@ from ._collective_ops import *
19
19
  from ._collective_ops import __all__ as collective_ops_all
20
20
  from ._common import *
21
21
  from ._common import __all__ as common_all
22
- from ._dyn_impl import *
23
- from ._dyn_impl import __all__ as dyn_impl_all
22
+ from ._conv import *
23
+ from ._conv import __all__ as conv_all
24
+ from ._delay import *
25
+ from ._delay import __all__ as state_delay_all
26
+ from ._dropout import *
27
+ from ._dropout import __all__ as dropout_all
24
28
  from ._dynamics import *
25
- from ._dynamics import __all__ as dynamics_all
29
+ from ._dynamics import __all__ as dyn_all
26
30
  from ._elementwise import *
27
31
  from ._elementwise import __all__ as elementwise_all
28
- from ._event import *
29
- from ._event import __all__ as _event_all
32
+ from ._embedding import *
33
+ from ._embedding import __all__ as embed_all
30
34
  from ._exp_euler import *
31
35
  from ._exp_euler import __all__ as exp_euler_all
32
- from ._interaction import *
33
- from ._interaction import __all__ as interaction_all
36
+ from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
37
+ from ._inputs import *
38
+ from ._inputs import __all__ as inputs_all
39
+ from ._linear import *
40
+ from ._linear import __all__ as linear_all
41
+ from ._linear_mv import EventLinear
42
+ from ._ltp import *
43
+ from ._ltp import __all__ as ltp_all
34
44
  from ._module import *
35
45
  from ._module import __all__ as module_all
46
+ from ._neuron import *
47
+ from ._neuron import __all__ as dyn_neuron_all
48
+ from ._normalizations import *
49
+ from ._normalizations import __all__ as normalizations_all
50
+ from ._poolings import *
51
+ from ._poolings import __all__ as poolings_all
52
+ from ._projection import *
53
+ from ._projection import __all__ as projection_all
54
+ from ._rate_rnns import *
55
+ from ._rate_rnns import __all__ as rate_rnns
56
+ from ._readout import *
57
+ from ._readout import __all__ as readout_all
58
+ from ._stp import *
59
+ from ._stp import __all__ as stp_all
60
+ from ._synapse import *
61
+ from ._synapse import __all__ as dyn_synapse_all
62
+ from ._synaptic_projection import *
63
+ from ._synaptic_projection import __all__ as _syn_proj_all
64
+ from ._synouts import *
65
+ from ._synouts import __all__ as synouts_all
36
66
  from ._utils import *
37
67
  from ._utils import __all__ as utils_all
38
68
 
39
69
  __all__ = (
40
- ['metrics']
70
+ [
71
+ 'metrics',
72
+ 'EventLinear',
73
+ 'EventFixedProb',
74
+ 'EventFixedNumConn',
75
+ ]
41
76
  + collective_ops_all
42
77
  + common_all
43
- + dyn_impl_all
44
- + dynamics_all
45
78
  + elementwise_all
46
79
  + module_all
47
80
  + exp_euler_all
48
- + interaction_all
49
81
  + utils_all
50
- + _event_all
82
+ + dyn_all
83
+ + projection_all
84
+ + state_delay_all
85
+ + synouts_all
86
+ + conv_all
87
+ + linear_all
88
+ + normalizations_all
89
+ + poolings_all
90
+ + embed_all
91
+ + dropout_all
92
+ + elementwise_all
93
+ + dyn_neuron_all
94
+ + dyn_synapse_all
95
+ + inputs_all
96
+ + rate_rnns
97
+ + readout_all
98
+ + stp_all
99
+ + ltp_all
100
+ + _syn_proj_all
51
101
  )
52
102
 
53
103
  del (
54
104
  collective_ops_all,
55
105
  common_all,
56
- dyn_impl_all,
57
- dynamics_all,
58
- elementwise_all,
59
106
  module_all,
60
107
  exp_euler_all,
61
- interaction_all,
62
108
  utils_all,
63
- _event_all,
109
+ dyn_all,
110
+ projection_all,
111
+ state_delay_all,
112
+ synouts_all,
113
+ conv_all,
114
+ linear_all,
115
+ normalizations_all,
116
+ poolings_all,
117
+ embed_all,
118
+ dropout_all,
119
+ elementwise_all,
120
+ dyn_neuron_all,
121
+ dyn_synapse_all,
122
+ inputs_all,
123
+ readout_all,
124
+ rate_rnns,
125
+ stp_all,
126
+ ltp_all,
127
+ _syn_proj_all,
64
128
  )
@@ -23,8 +23,8 @@ import jax.numpy as jnp
23
23
 
24
24
  from brainstate import init, functional
25
25
  from brainstate._state import ParamState
26
- from brainstate.nn._module import Module
27
26
  from brainstate.typing import ArrayLike
27
+ from ._module import Module
28
28
 
29
29
  T = TypeVar('T')
30
30
 
@@ -27,9 +27,9 @@ from brainstate import environ
27
27
  from brainstate._state import ShortTermState, State
28
28
  from brainstate.compile import jit_error_if
29
29
  from brainstate.graph import Node
30
- from brainstate.nn._collective_ops import call_order
31
- from brainstate.nn._module import Module
32
30
  from brainstate.typing import ArrayLike, PyTree
31
+ from ._collective_ops import call_order
32
+ from ._module import Module
33
33
 
34
34
  __all__ = [
35
35
  'Delay', 'DelayAccess', 'StateWithDelay',
@@ -135,6 +135,7 @@ class Delay(Module):
135
135
  entries: Optional[Dict] = None, # delay access entry
136
136
  delay_method: Optional[str] = _DELAY_ROTATE, # delay method
137
137
  interp_method: str = _INTERP_LINEAR, # interpolation method
138
+ take_aware_unit: bool = False
138
139
  ):
139
140
  # target information
140
141
  self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
@@ -170,6 +171,9 @@ class Delay(Module):
170
171
  for entry, delay_time in entries.items():
171
172
  self.register_entry(entry, delay_time)
172
173
 
174
+ self.take_aware_unit = take_aware_unit
175
+ self._unit = None
176
+
173
177
  @property
174
178
  def history(self):
175
179
  return self._history
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import random, environ, init
24
24
  from brainstate._state import ShortTermState
25
- from brainstate.nn._module import ElementWiseBlock
26
25
  from brainstate.typing import Size
26
+ from ._module import ElementWiseBlock
27
27
 
28
28
  __all__ = [
29
29
  'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
@@ -41,13 +41,14 @@ import numpy as np
41
41
  from brainstate import environ
42
42
  from brainstate._state import State
43
43
  from brainstate.graph import Node
44
- from brainstate.mixin import ParamDescriber
45
- from brainstate.nn._module import Module
44
+ from brainstate.mixin import ParamDescriber, UpdateReturn
46
45
  from brainstate.typing import Size, ArrayLike
47
- from ._state_delay import StateWithDelay, Delay
46
+ from ._delay import StateWithDelay, Delay
47
+ from ._module import Module
48
48
 
49
49
  __all__ = [
50
- 'DynamicsGroup', 'Projection', 'Dynamics', 'Prefetch',
50
+ 'DynamicsGroup', 'Projection', 'Dynamics',
51
+ 'Prefetch', 'PrefetchDelay', 'PrefetchDelayAt', 'OutputDelayAt',
51
52
  ]
52
53
 
53
54
  T = TypeVar('T')
@@ -99,7 +100,7 @@ class Projection(Module):
99
100
  raise ValueError('Do not implement the update() function.')
100
101
 
101
102
 
102
- class Dynamics(Module):
103
+ class Dynamics(Module, UpdateReturn):
103
104
  """
104
105
  Base class for implementing neural dynamics models in BrainState.
105
106
 
@@ -821,6 +822,41 @@ class Dynamics(Module):
821
822
  else:
822
823
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
823
824
 
825
+ def prefetch_delay(self, state: str, delay: Optional[ArrayLike] = None) -> 'PrefetchDelayAt':
826
+ """
827
+ Create a reference to a delayed state or variable in the module.
828
+
829
+ This method simplifies the process of accessing a delayed version of a state or variable
830
+ within the module. It first creates a prefetch reference to the specified state,
831
+ then specifies the delay time for accessing this state.
832
+
833
+ Args:
834
+ state (str): The name of the state or variable to reference.
835
+ delay (Optional[ArrayLike]): The amount of time to delay the variable access,
836
+ typically in time units (e.g., milliseconds). Defaults to None.
837
+
838
+ Returns:
839
+ PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
840
+ """
841
+ return self.prefetch(state).delay.at(delay)
842
+
843
+ def output_delay(self, delay: Optional[ArrayLike] = None) -> 'OutputDelayAt':
844
+ """
845
+ Create a reference to the delayed output of the module.
846
+
847
+ This method simplifies the process of accessing a delayed version of the module's output.
848
+ It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
849
+ at the specified delay time.
850
+
851
+ Args:
852
+ delay (Optional[ArrayLike]): The amount of time to delay the output access,
853
+ typically in time units (e.g., milliseconds). Defaults to None.
854
+
855
+ Returns:
856
+ OutputDelayAt: An object that provides access to the module's output at the specified delay time.
857
+ """
858
+ return OutputDelayAt(self, delay)
859
+
824
860
 
825
861
  class Prefetch(Node):
826
862
  """
@@ -885,6 +921,7 @@ class Prefetch(Node):
885
921
  An object that provides access to delayed versions of the prefetched item.
886
922
  """
887
923
  return PrefetchDelay(self.module, self.item)
924
+ # return PrefetchDelayAt(self.module, self.item, time)
888
925
 
889
926
  def __call__(self, *args, **kwargs):
890
927
  """
@@ -1007,7 +1044,7 @@ class PrefetchDelayAt(Node):
1007
1044
  self,
1008
1045
  module: Dynamics,
1009
1046
  item: str,
1010
- time: ArrayLike
1047
+ time: ArrayLike = None,
1011
1048
  ):
1012
1049
  """
1013
1050
  Initialize a PrefetchDelayAt object.
@@ -1026,14 +1063,16 @@ class PrefetchDelayAt(Node):
1026
1063
  self.module = module
1027
1064
  self.item = item
1028
1065
  self.time = time
1029
- self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1030
1066
 
1031
- # register the delay
1032
- key = _get_delay_key(item)
1033
- if not module._has_after_update(key):
1034
- module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
1035
- self.state_delay: StateWithDelay = module._get_after_update(key)
1036
- self.state_delay.register_delay(time)
1067
+ if time is not None:
1068
+ self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1069
+
1070
+ # register the delay
1071
+ key = _get_prefetch_delay_key(item)
1072
+ if not module._has_after_update(key):
1073
+ module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
1074
+ self.state_delay: StateWithDelay = module._get_after_update(key)
1075
+ self.state_delay.register_delay(time)
1037
1076
 
1038
1077
  def __call__(self, *args, **kwargs):
1039
1078
  """
@@ -1044,12 +1083,94 @@ class PrefetchDelayAt(Node):
1044
1083
  Any
1045
1084
  The value of the state or variable at the specified delay time.
1046
1085
  """
1047
- # return self.state_delay.retrieve_at_time(self.time)
1048
- return self.state_delay.retrieve_at_step(self.step)
1086
+ if self.time is None:
1087
+ return _get_prefetch_item(self).value
1088
+ else:
1089
+ return self.state_delay.retrieve_at_step(self.step)
1090
+
1091
+
1092
+ class OutputDelayAt(Node):
1093
+ """
1094
+ Provides access to a specific delayed state or variable value at the specific time.
1095
+
1096
+ This class represents the final step in the prefetch delay chain, providing
1097
+ actual access to state values at a specific delay time. It converts the
1098
+ specified time delay into steps and registers the delay with the appropriate
1099
+ StateWithDelay handler.
1100
+
1101
+ Parameters
1102
+ ----------
1103
+ module : Dynamics
1104
+ The dynamics module that contains the referenced state or variable.
1105
+ item : str
1106
+ The name of the state or variable to access with delay.
1107
+ time : ArrayLike
1108
+ The amount of time to delay access by, typically in time units (e.g., milliseconds).
1109
+
1110
+ Examples
1111
+ --------
1112
+ >>> import brainstate
1113
+ >>> import brainunit as u
1114
+ >>> neuron = brainstate.nn.LIF(10)
1115
+ >>> # Create a reference to voltage delayed by 5ms
1116
+ >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
1117
+ >>> # Get the delayed value
1118
+ >>> v_value = delayed_v()
1119
+ """
1120
+
1121
+ def __init__(
1122
+ self,
1123
+ module: Dynamics,
1124
+ time: Optional[ArrayLike] = None,
1125
+ ):
1126
+ """
1127
+ Initialize a PrefetchDelayAt object.
1128
+
1129
+ Parameters
1130
+ ----------
1131
+ module : AlignPre, Module
1132
+ The dynamics module that contains the referenced state or variable.
1133
+ time : ArrayLike
1134
+ The amount of time to delay access by, typically in time units.
1135
+ """
1136
+ super().__init__()
1137
+ assert isinstance(module, UpdateReturn), 'The module should implement the `update_return` method.'
1138
+ assert isinstance(module, Module), 'The module should be an instance of Module.'
1139
+ self.module = module
1140
+ self.time = time
1141
+ if time is not None:
1142
+ self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1143
+
1144
+ # register the delay
1145
+ key = _get_output_delay_key()
1146
+ if not module._has_after_update(key):
1147
+ # TODO: unit processing
1148
+ delay = Delay(module.update_return(), time)
1149
+ module._add_after_update(key, receive_update_output(delay))
1150
+ self.out_delay: Delay = module._get_after_update(key)
1151
+ self.out_delay.register_delay(time)
1152
+
1153
+ def __call__(self, *args, **kwargs):
1154
+ """
1155
+ Retrieve the value of the state at the specified delay time.
1156
+
1157
+ Returns
1158
+ -------
1159
+ Any
1160
+ The value of the state or variable at the specified delay time.
1161
+ """
1162
+ if self.time is None:
1163
+ return self.module.update_return()
1164
+ else:
1165
+ return self.out_delay.retrieve_at_step(self.step)
1166
+
1167
+
1168
+ def _get_prefetch_delay_key(item) -> str:
1169
+ return f'{item}-prefetch-delay'
1049
1170
 
1050
1171
 
1051
- def _get_delay_key(item) -> str:
1052
- return f'{item}-delay'
1172
+ def _get_output_delay_key() -> str:
1173
+ return f'output-delay'
1053
1174
 
1054
1175
 
1055
1176
  def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
@@ -1064,7 +1185,7 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
1064
1185
  f'The target module should be an instance '
1065
1186
  f'of Dynamics. But got {target.module}.'
1066
1187
  )
1067
- delay = target.module._get_after_update(_get_delay_key(target.item))
1188
+ delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
1068
1189
  if not isinstance(delay, StateWithDelay):
1069
1190
  raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
1070
1191
  f'its delay. But got {delay}.')
@@ -23,8 +23,8 @@ import jax.typing
23
23
 
24
24
  from brainstate import random, functional as F
25
25
  from brainstate._state import ParamState
26
- from brainstate.nn._module import ElementWiseBlock
27
26
  from brainstate.typing import ArrayLike
27
+ from ._module import ElementWiseBlock
28
28
 
29
29
  __all__ = [
30
30
  # activation functions
@@ -17,8 +17,8 @@ from typing import Optional, Callable, Union
17
17
 
18
18
  from brainstate import init
19
19
  from brainstate._state import ParamState
20
- from brainstate.nn._module import Module
21
20
  from brainstate.typing import ArrayLike
21
+ from ._module import Module
22
22
 
23
23
  __all__ = [
24
24
  'Embedding',
@@ -25,8 +25,8 @@ from brainstate import random, augment, environ, init
25
25
  from brainstate._compatible_import import brainevent
26
26
  from brainstate._state import ParamState
27
27
  from brainstate.compile import for_loop
28
- from brainstate.nn._module import Module
29
28
  from brainstate.typing import Size, ArrayLike
29
+ from ._module import Module
30
30
 
31
31
  __all__ = [
32
32
  'EventFixedNumConn',
@@ -20,12 +20,11 @@ import jax
20
20
  import numpy as np
21
21
 
22
22
  from brainstate import environ, init, random
23
- from brainstate._state import ShortTermState
24
- from brainstate._state import State, maybe_state
23
+ from brainstate._state import ShortTermState, State, maybe_state
25
24
  from brainstate.compile import while_loop
26
- from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
27
- from brainstate.nn._module import Module
28
25
  from brainstate.typing import ArrayLike, Size, DTypeLike
26
+ from ._dynamics import Dynamics, Prefetch
27
+ from ._module import Module
29
28
 
30
29
  __all__ = [
31
30
  'SpikeTime',
@@ -134,7 +133,7 @@ class PoissonSpike(Dynamics):
134
133
  self.freqs = init.param(freqs, self.varshape, allow_none=False)
135
134
 
136
135
  def update(self):
137
- spikes = random.rand(self.varshape) <= (self.freqs * environ.get_dt())
136
+ spikes = random.rand(*self.varshape) <= (self.freqs * environ.get_dt())
138
137
  spikes = u.math.asarray(spikes, dtype=self.spk_type)
139
138
  return spikes
140
139
 
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import init, functional
24
24
  from brainstate._state import ParamState
25
- from brainstate.nn._module import Module
26
25
  from brainstate.typing import ArrayLike, Size
26
+ from ._module import Module
27
27
 
28
28
  __all__ = [
29
29
  'Linear',
@@ -350,10 +350,7 @@ class OneToOne(Module):
350
350
  self.weight = param_type(param)
351
351
 
352
352
  def update(self, pre_val):
353
- pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
354
- w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
355
- post_val = pre_val * w_val
356
- post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
353
+ post_val = pre_val * self.weight.value['weight']
357
354
  if 'bias' in self.weight.value:
358
355
  post_val = post_val + self.weight.value['bias']
359
356
  return post_val
@@ -21,8 +21,8 @@ import jax
21
21
  from brainstate import init
22
22
  from brainstate._compatible_import import brainevent
23
23
  from brainstate._state import ParamState
24
- from brainstate.nn._module import Module
25
24
  from brainstate.typing import Size, ArrayLike
25
+ from ._module import Module
26
26
 
27
27
  __all__ = [
28
28
  'EventLinear',
@@ -16,11 +16,13 @@
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
18
 
19
- from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
20
- from ._linear_mv import EventLinear
19
+ from ._synapse import Synapse
21
20
 
22
21
  __all__ = [
23
- 'EventLinear',
24
- 'EventFixedProb',
25
- 'EventFixedNumConn',
22
+ 'LongTermPlasticity',
26
23
  ]
24
+
25
+
26
+ class LongTermPlasticity(Synapse):
27
+ pass
28
+
@@ -22,9 +22,9 @@ import jax
22
22
 
23
23
  from brainstate import init, surrogate, environ
24
24
  from brainstate._state import HiddenState, ShortTermState
25
- from brainstate.nn._dynamics._dynamics_base import Dynamics
26
- from brainstate.nn._exp_euler import exp_euler_step
27
25
  from brainstate.typing import ArrayLike, Size
26
+ from ._dynamics import Dynamics
27
+ from ._exp_euler import exp_euler_step
28
28
 
29
29
  __all__ = [
30
30
  'Neuron', 'IF', 'LIF', 'LIFRef', 'ALIF',
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import environ, init
24
24
  from brainstate._state import ParamState, BatchState
25
- from brainstate.nn._module import Module
26
25
  from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
26
+ from ._module import Module
27
27
 
28
28
  __all__ = [
29
29
  'BatchNorm0d',
@@ -25,8 +25,8 @@ import jax.numpy as jnp
25
25
  import numpy as np
26
26
 
27
27
  from brainstate import environ
28
- from brainstate.nn._module import Module
29
28
  from brainstate.typing import Size
29
+ from ._module import Module
30
30
 
31
31
  __all__ = [
32
32
  'Flatten', 'Unflatten',
@@ -103,7 +103,7 @@ class TestPool(parameterized.TestCase):
103
103
  for target_size in [10, 9, 8, 7, 6]
104
104
  )
105
105
  def test_adaptive_pool1d(self, target_size):
106
- from brainstate.nn._interaction._poolings import _adaptive_pool1d
106
+ from brainstate.nn._poolings import _adaptive_pool1d
107
107
 
108
108
  arr = brainstate.random.rand(100)
109
109
  op = jax.numpy.mean