brainstate-0.1.0-py2.py3-none-any.whl → brainstate-0.1.0.post20241122-py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (30)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/augment/_autograd.py +9 -6
  4. brainstate/event/__init__.py +4 -2
  5. brainstate/event/_csr.py +26 -18
  6. brainstate/event/_csr_benchmark.py +14 -0
  7. brainstate/event/_fixed_probability.py +589 -152
  8. brainstate/event/_fixed_probability_benchmark.py +128 -0
  9. brainstate/event/_fixed_probability_test.py +13 -10
  10. brainstate/event/_linear.py +267 -127
  11. brainstate/event/_linear_benckmark.py +82 -0
  12. brainstate/event/_linear_test.py +8 -3
  13. brainstate/event/_xla_custom_op.py +312 -0
  14. brainstate/event/_xla_custom_op_test.py +55 -0
  15. brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
  16. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
  17. brainstate/nn/_dynamics/_projection_base.py +1 -1
  18. brainstate/nn/_exp_euler.py +1 -1
  19. brainstate/nn/_interaction/__init__.py +13 -4
  20. brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
  21. brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
  22. brainstate/nn/_interaction/_linear.py +582 -0
  23. brainstate/nn/_interaction/_linear_test.py +42 -0
  24. brainstate/optim/_lr_scheduler.py +1 -1
  25. brainstate/optim/_optax_optimizer.py +18 -0
  26. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
  27. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
  28. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  29. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  30. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
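
The main change below is a split of brainstate/nn/_interaction/_connections.py: the convolution layers stay in the renamed _conv.py (items 20-21), while the linear layers removed in the diff that follows appear to have been relocated into the new _linear.py and _linear_test.py (items 22-23). A minimal usage sketch, assuming the relocated layers are still re-exported under the public brainstate.nn namespace exactly as before (the removed test at the bottom of this diff used the same entry point):

import brainstate as bst

# Assumed-unchanged public entry point; Linear itself now lives in
# brainstate/nn/_interaction/_linear.py rather than _connections.py.
layer = bst.nn.Linear(10, 20)
x = bst.random.random((5, 10))
y = layer(x)  # matrix multiply over the last axis
assert y.shape == (5, 20)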
brainstate/nn/_interaction/{_connections.py → _conv.py}
@@ -18,10 +18,8 @@
 from __future__ import annotations

 import collections.abc
-import numbers
 from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar

-import brainunit as u
 import jax
 import jax.numpy as jnp

@@ -33,10 +31,8 @@ from brainstate.typing import ArrayLike
 T = TypeVar('T')

 __all__ = [
-  'Linear', 'ScaledWSLinear', 'SignedWLinear', 'CSRLinear',
   'Conv1d', 'Conv2d', 'Conv3d',
   'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
-  'AllToAll',
 ]


@@ -79,168 +75,6 @@ def replicate(
                      f"sequence of length {num_replicate}.")


-class Linear(Module):
-  """
-  Linear layer.
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      in_size: Union[int, Sequence[int]],
-      out_size: Union[int, Sequence[int]],
-      w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
-      b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
-      w_mask: Optional[Union[ArrayLike, Callable]] = None,
-      name: Optional[str] = None,
-  ):
-    super().__init__(name=name)
-
-    # input and output shape
-    self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
-    self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
-
-    # w_mask
-    self.w_mask = init.param(w_mask, self.in_size + self.out_size)
-
-    # weights
-    params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
-    if b_init is not None:
-      params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
-
-    # weight + op
-    self.weight = ParamState(params)
-
-  def update(self, x):
-    params = self.weight.value
-    weight = params['weight']
-    if self.w_mask is not None:
-      weight = weight * self.w_mask
-    y = u.math.dot(x, weight)
-    if 'bias' in params:
-      y = y + params['bias']
-    return y
-
-
-class SignedWLinear(Module):
-  """
-  Linear layer with signed weights.
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      in_size: Union[int, Sequence[int]],
-      out_size: Union[int, Sequence[int]],
-      w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
-      w_sign: Optional[ArrayLike] = None,
-      name: Optional[str] = None,
-
-  ):
-    super().__init__(name=name)
-
-    # input and output shape
-    self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
-    self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
-
-    # w_sign
-    self.w_sign = w_sign
-
-    # weights
-    weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
-    self.weight = ParamState(weight)
-
-  def _operation(self, x, w):
-    if self.w_sign is None:
-      return jnp.matmul(x, jnp.abs(w))
-    else:
-      return jnp.matmul(x, jnp.abs(w) * self.w_sign)
-
-  def update(self, x):
-    return self._operation(x, self.weight.value)
-
-
-class ScaledWSLinear(Module):
-  """
-  Linear Layer with Weight Standardization.
-
-  Applies weight standardization to the weights of the linear layer.
-
-  Parameters
-  ----------
-  in_size: int, sequence of int
-    The input size.
-  out_size: int, sequence of int
-    The output size.
-  w_init: Callable, ArrayLike
-    The initializer for the weights.
-  b_init: Callable, ArrayLike
-    The initializer for the bias.
-  w_mask: ArrayLike, Callable
-    The optional mask of the weights.
-  ws_gain: bool
-    Whether to use gain for the weights. The default is True.
-  eps: float
-    The epsilon value for the weight standardization.
-  name: str
-    The name of the object.
-
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      in_size: Union[int, Sequence[int]],
-      out_size: Union[int, Sequence[int]],
-      w_init: Callable = init.KaimingNormal(),
-      b_init: Callable = init.ZeroInit(),
-      w_mask: Optional[Union[ArrayLike, Callable]] = None,
-      ws_gain: bool = True,
-      eps: float = 1e-4,
-      name: str = None,
-  ):
-    super().__init__(name=name)
-
-    # input and output shape
-    self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
-    self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
-
-    # w_mask
-    self.w_mask = init.param(w_mask, (self.in_size[0], 1))
-
-    # parameters
-    self.eps = eps
-
-    # weights
-    params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
-    if b_init is not None:
-      params['bias'] = init.param(b_init, self.out_size, allow_none=False)
-    # gain
-    if ws_gain:
-      s = params['weight'].shape
-      params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
-
-    # weight operation
-    self.weight = ParamState(params)
-
-  def update(self, x):
-    return self._operation(x, self.weight.value)
-
-  def _operation(self, x, params):
-    w = params['weight']
-    w = functional.weight_standardization(w, self.eps, params.get('gain', None))
-    if self.w_mask is not None:
-      w = w * self.w_mask
-    y = jnp.dot(x, w)
-    if 'bias' in params:
-      y = y + params['bias']
-    return y
-
-
-class CSRLinear(Module):
-  __module__ = 'brainstate.nn'
-
-
 class _BaseConv(Module):
   # the number of spatial dimensions
   num_spatial_dims: int
@@ -663,64 +497,3 @@ _ws_conv_doc = '''
 ScaledWSConv1d.__doc__ = ScaledWSConv1d.__doc__ % _ws_conv_doc
 ScaledWSConv2d.__doc__ = ScaledWSConv2d.__doc__ % _ws_conv_doc
 ScaledWSConv3d.__doc__ = ScaledWSConv3d.__doc__ % _ws_conv_doc
-
-
-class AllToAll(Module):
-  """Synaptic matrix multiplication with All2All connections.
-
-  Args:
-    in_size: int. The number of neurons in the presynaptic neuron group.
-    out_size: int. The number of neurons in the postsynaptic neuron group.
-    w_init: The synaptic weights.
-    include_self: bool. Whether to connect the neuron at the same position.
-    name: str. The object name.
-  """
-
-  def __init__(
-      self,
-      in_size: Union[int, Sequence[int]],
-      out_size: Union[int, Sequence[int]],
-      w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
-      include_self: bool = True,
-      name: Optional[str] = None,
-  ):
-    super().__init__(name=name)
-
-    # input and output shape
-    self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
-    self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
-    assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
-                                                     'and "out_size" must be the same.')
-
-    # weights
-    self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
-
-    # others
-    self.include_self = include_self
-
-  def update(self, pre_val):
-    if u.math.ndim(self.weight.value) == 0:  # weight is a scalar
-      if pre_val.ndim == 1:
-        post_val = u.math.sum(pre_val)
-      else:
-        post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
-      if not self.include_self:
-        if self.in_size == self.out_size:
-          post_val = post_val - pre_val
-        elif self.in_size[-1] > self.out_size[-1]:
-          val = pre_val[..., :self.out_size[-1]]
-          post_val = post_val - val
-        else:
-          size = list(self.out_size)
-          size[-1] = self.out_size[-1] - self.in_size[-1]
-          val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
-          post_val = post_val - val
-      post_val = self.weight.value * post_val
-
-    else:  # weight is a matrix
-      assert u.math.ndim(self.weight.value) == 2, '"weight" must be a 2D matrix.'
-      if not self.include_self:
-        post_val = pre_val @ u.math.fill_diagonal(self.weight.value, 0.)
-      else:
-        post_val = pre_val @ self.weight.value
-    return post_val
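
The scalar-weight branch of the removed AllToAll.update above exploits an algebraic shortcut: a homogeneous all-to-all projection (with the diagonal zeroed when include_self=False) reduces to the sum of the presynaptic values minus each neuron's own contribution, so the weight matrix never has to be materialized. A standalone sketch of that identity for the square case (in_size == out_size), using plain jax.numpy instead of the brainunit wrappers; all2all_scalar is a hypothetical helper, not part of brainstate:

import jax.numpy as jnp

def all2all_scalar(pre, w, include_self=True):
  # Computes pre @ W for W filled with the scalar w
  # (zero diagonal when include_self=False) without building W.
  post = jnp.sum(pre, axis=-1, keepdims=True)
  if not include_self:
    post = post - pre  # drop each neuron's own input
  return w * post

# Agrees with the explicit matrix product:
pre = jnp.arange(4.0)
W = jnp.full((4, 4), 0.5) * (1.0 - jnp.eye(4))
assert jnp.allclose(all2all_scalar(pre, 0.5, include_self=False), pre @ W)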
brainstate/nn/_interaction/{_connections_test.py → _conv_test.py}
@@ -235,20 +235,5 @@ class TestConvTranspose3d(parameterized.TestCase):
     y = conv_transpose_module(x)
     print(y.shape)

-
-class TestDense(parameterized.TestCase):
-  @parameterized.product(
-    size=[(10,),
-          (20, 10),
-          (5, 8, 10)],
-    num_out=[20, ]
-  )
-  def test_Dense1(self, size, num_out):
-    f = bst.nn.Linear(10, num_out)
-    x = bst.random.random(size)
-    y = f(x)
-    self.assertTrue(y.shape == size[:-1] + (num_out,))
-
-
 if __name__ == '__main__':
   absltest.main()
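
The deleted TestDense presumably has a successor in the new _linear_test.py (item 23 in the file list); an equivalent standalone shape check against the public API, written as a plain loop rather than an absl parameterized test:

import brainstate as bst

# The same contract the removed test asserted: Linear maps the last
# axis from in_size to num_out and preserves leading batch dimensions.
f = bst.nn.Linear(10, 20)
for size in [(10,), (20, 10), (5, 8, 10)]:
  x = bst.random.random(size)
  assert f(x).shape == size[:-1] + (20,)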