brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241125__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd.py +121 -120
  5. brainstate/augment/_autograd_test.py +97 -0
  6. brainstate/event/__init__.py +10 -8
  7. brainstate/event/_csr_benchmark.py +14 -0
  8. brainstate/event/{_csr.py → _csr_mv.py} +26 -18
  9. brainstate/event/_csr_mv_benchmark.py +14 -0
  10. brainstate/event/_fixedprob_mv.py +708 -0
  11. brainstate/event/_fixedprob_mv_benchmark.py +128 -0
  12. brainstate/event/{_fixed_probability_test.py → _fixedprob_mv_test.py} +13 -10
  13. brainstate/event/_linear_mv.py +359 -0
  14. brainstate/event/_linear_mv_benckmark.py +82 -0
  15. brainstate/event/{_linear_test.py → _linear_mv_test.py} +9 -4
  16. brainstate/event/_xla_custom_op.py +309 -0
  17. brainstate/event/_xla_custom_op_test.py +55 -0
  18. brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
  19. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
  20. brainstate/nn/_dynamics/_projection_base.py +1 -1
  21. brainstate/nn/_exp_euler.py +1 -1
  22. brainstate/nn/_interaction/__init__.py +13 -4
  23. brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
  24. brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
  25. brainstate/nn/_interaction/_linear.py +582 -0
  26. brainstate/nn/_interaction/_linear_test.py +42 -0
  27. brainstate/optim/_lr_scheduler.py +1 -1
  28. brainstate/optim/_optax_optimizer.py +19 -0
  29. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/METADATA +2 -2
  30. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/RECORD +34 -24
  31. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/top_level.txt +1 -0
  32. brainstate/event/_fixed_probability.py +0 -271
  33. brainstate/event/_linear.py +0 -219
  34. /brainstate/event/{_csr_test.py → _csr_mv_test.py} +0 -0
  35. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/LICENSE +0 -0
  36. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/WHEEL +0 -0
@@ -21,12 +21,11 @@ import jax
21
21
  import jax.numpy as jnp
22
22
  import numpy as np
23
23
 
24
- from brainstate._state import ParamState, State
24
+ from brainstate._state import ParamState
25
25
  from brainstate._utils import set_module_as
26
26
  from brainstate.init import param
27
27
  from brainstate.nn._module import Module
28
- from brainstate.typing import ArrayLike
29
- from ._misc import IntScalar
28
+ from brainstate.typing import ArrayLike, Size
30
29
 
31
30
  __all__ = [
32
31
  'CSRLinear',
@@ -39,12 +38,12 @@ class CSRLinear(Module):
39
38
 
40
39
  Parameters
41
40
  ----------
42
- n_pre : int
43
- Number of pre-synaptic neurons.
44
- n_post : int
45
- Number of post-synaptic neurons.
41
+ in_size : Size
42
+ Number of pre-synaptic neurons, i.e., input size.
43
+ out_size : Size
44
+ Number of post-synaptic neurons, i.e., output size.
46
45
  weight : float or callable or jax.Array or brainunit.Quantity
47
- Maximum synaptic conductance.
46
+ Maximum synaptic conductance or a function that returns the maximum synaptic conductance.
48
47
  name : str, optional
49
48
  Name of the module.
50
49
  """
@@ -53,8 +52,8 @@ class CSRLinear(Module):
53
52
 
54
53
  def __init__(
55
54
  self,
56
- n_pre: IntScalar,
57
- n_post: IntScalar,
55
+ in_size: Size,
56
+ out_size: Size,
58
57
  indptr: ArrayLike,
59
58
  indices: ArrayLike,
60
59
  weight: Union[Callable, ArrayLike],
@@ -63,10 +62,11 @@ class CSRLinear(Module):
63
62
  ):
64
63
  super().__init__(name=name)
65
64
 
66
- self.in_size = n_pre
67
- self.out_size = n_post
68
- self.n_pre = n_pre
69
- self.n_post = n_post
65
+ # network size
66
+ self.in_size = in_size
67
+ self.out_size = out_size
68
+ self.n_pre = self.in_size[-1]
69
+ self.n_post = self.out_size[-1]
70
70
 
71
71
  # gradient mode
72
72
  assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
@@ -77,9 +77,10 @@ class CSRLinear(Module):
77
77
  indices = jnp.asarray(indices)
78
78
  assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
79
79
  assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
80
- assert indptr.size == n_pre + 1, f"indptr must have size {n_pre + 1}. Got: {indptr.size}"
81
- self.indptr = u.math.asarray(indptr)
82
- self.indices = u.math.asarray(indices)
80
+ assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
81
+ with jax.ensure_compile_time_eval():
82
+ self.indptr = u.math.asarray(indptr)
83
+ self.indices = u.math.asarray(indices)
83
84
 
84
85
  # maximum synaptic conductance
85
86
  weight = param(weight, (len(indices),), allow_none=False)
@@ -88,7 +89,9 @@ class CSRLinear(Module):
88
89
  self.weight = ParamState(weight)
89
90
 
90
91
  def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
91
- weight = self.weight.value if isinstance(self.weight, State) else self.weight
92
+ weight = self.weight.value
93
+
94
+ # return zero if no pre-synaptic neurons
92
95
  if len(self.indices) == 0:
93
96
  r = u.math.zeros(spk.shape[:-1] + (self.n_post,),
94
97
  dtype=weight.dtype,
@@ -96,6 +99,8 @@ class CSRLinear(Module):
96
99
  return u.maybe_decimal(r)
97
100
 
98
101
  device_kind = jax.devices()[0].platform # spk.device.device_kind
102
+
103
+ # CPU implementation
99
104
  if device_kind == 'cpu':
100
105
  return cpu_event_csr(
101
106
  u.math.asarray(spk),
@@ -104,8 +109,11 @@ class CSRLinear(Module):
104
109
  u.math.asarray(weight),
105
110
  n_post=self.n_post, grad_mode=self.grad_mode
106
111
  )
112
+
113
+ # GPU/TPU implementation
107
114
  elif device_kind in ['gpu', 'tpu']:
108
115
  raise NotImplementedError()
116
+
109
117
  else:
110
118
  raise ValueError(f"Unsupported device: {device_kind}")
111
119
 
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================