brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,161 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- import unittest
20
-
21
- import brainunit as u
22
- import jax
23
- import jax.numpy as jnp
24
-
25
- import brainstate
26
- from brainstate.nn import IF, LIF, ALIF
27
-
28
-
29
class TestNeuron(unittest.TestCase):
    """Construction, forward-pass shape and reset checks for the IF, LIF and ALIF models."""

    def setUp(self):
        # Population size, batch size and number of simulated steps shared by the tests.
        self.in_size = 10
        self.batch_size = 5
        self.time_steps = 100

    def test_neuron_base_class(self):
        """``get_spike`` on the base ``Neuron`` is expected to raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            brainstate.nn.Neuron(self.in_size).get_spike()

    def generate_input(self):
        """Return random input of shape (time_steps, batch_size, in_size), scaled by ``u.mA``."""
        shape = (self.time_steps, self.batch_size, self.in_size)
        return brainstate.random.randn(*shape) * u.mA

    def test_if_neuron(self):
        """IF: sizes are normalised to tuples, outputs keep the batch shape, spikes lie in [0, 1]."""
        expected_shape = (self.batch_size, self.in_size)
        with brainstate.environ.context(dt=0.1 * u.ms):
            model = IF(self.in_size)
            currents = self.generate_input()

            # Initialization.
            self.assertEqual(model.in_size, (self.in_size,))
            self.assertEqual(model.out_size, (self.in_size,))

            # Forward pass, called eagerly (no jit).
            model.init_state(self.batch_size)
            for step in range(self.time_steps):
                result = model(currents[step])
                self.assertEqual(result.shape, expected_shape)

            # Spike generation.
            potentials = jnp.linspace(-1, 1, 100) * u.mV
            spk = model.get_spike(potentials)
            self.assertTrue(jnp.all((spk >= 0) & (spk <= 1)))

    def test_lif_neuron(self):
        """LIF: ``tau`` is stored as given and the jitted forward pass keeps the batch shape."""
        expected_shape = (self.batch_size, self.in_size)
        with brainstate.environ.context(dt=0.1 * u.ms):
            tau = 20.0 * u.ms
            model = LIF(self.in_size, tau=tau)
            currents = self.generate_input()

            # Initialization.
            self.assertEqual(model.in_size, (self.in_size,))
            self.assertEqual(model.out_size, (self.in_size,))
            self.assertEqual(model.tau, tau)

            # Forward pass through the jitted module.
            model.init_state(self.batch_size)
            step_fn = brainstate.compile.jit(model)
            for step in range(self.time_steps):
                result = step_fn(currents[step])
                self.assertEqual(result.shape, expected_shape)

    def test_alif_neuron(self):
        """ALIF: both time constants are stored and the jitted forward pass keeps the batch shape."""
        tau = 20.0 * u.ms
        tau_ada = 100.0 * u.ms
        expected_shape = (self.batch_size, self.in_size)
        model = ALIF(self.in_size, tau=tau, tau_a=tau_ada)
        currents = self.generate_input()

        # Initialization.
        self.assertEqual(model.in_size, (self.in_size,))
        self.assertEqual(model.out_size, (self.in_size,))
        self.assertEqual(model.tau, tau)
        self.assertEqual(model.tau_a, tau_ada)

        # Forward pass; only the stepping itself runs inside the ``dt`` context.
        model.init_state(self.batch_size)
        step_fn = brainstate.compile.jit(model)
        with brainstate.environ.context(dt=0.1 * u.ms):
            for step in range(self.time_steps):
                result = step_fn(currents[step])
                self.assertEqual(result.shape, expected_shape)

    def test_spike_function(self):
        """``get_spike`` output lies in [0, 1] for every model."""
        for model_cls in (IF, LIF, ALIF):
            model = model_cls(self.in_size)
            model.init_state()
            potentials = jnp.linspace(-1, 1, self.in_size) * u.mV
            spk = model.get_spike(potentials)
            self.assertTrue(jnp.all((spk >= 0) & (spk <= 1)))

    def test_soft_reset(self):
        """With ``spk_reset='soft'`` the membrane potential never exceeds ``V_th`` after a step."""
        for model_cls in (IF, LIF, ALIF):
            model = model_cls(self.in_size, spk_reset='soft')
            currents = self.generate_input()
            model.init_state(self.batch_size)
            step_fn = brainstate.compile.jit(model)
            with brainstate.environ.context(dt=0.1 * u.ms):
                for step in range(self.time_steps):
                    step_fn(currents[step])
                    self.assertTrue(jnp.all(model.V.value <= model.V_th))

    def test_hard_reset(self):
        """With ``spk_reset='hard'`` the potential is below ``V_th`` or exactly zero after a step."""
        for model_cls in (IF, LIF, ALIF):
            model = model_cls(self.in_size, spk_reset='hard')
            currents = self.generate_input()
            model.init_state(self.batch_size)
            step_fn = brainstate.compile.jit(model)
            with brainstate.environ.context(dt=0.1 * u.ms):
                for step in range(self.time_steps):
                    step_fn(currents[step])
                    below_threshold = model.V.value < model.V_th
                    at_zero = model.V.value == 0. * u.mV
                    self.assertTrue(jnp.all(below_threshold | at_zero))

    def test_detach_spike(self):
        """The first leaf of every output must not be weakly typed."""
        for model_cls in (IF, LIF, ALIF):
            model = model_cls(self.in_size)
            currents = self.generate_input()
            model.init_state(self.batch_size)
            step_fn = brainstate.compile.jit(model)
            with brainstate.environ.context(dt=0.1 * u.ms):
                for step in range(self.time_steps):
                    result = step_fn(currents[step])
                    first_leaf = jax.tree_util.tree_leaves(result)[0]
                    self.assertFalse(first_leaf.aval.weak_type)

    def test_keep_size(self):
        """Multi-dimensional ``in_size`` tuples are preserved by sizes and by outputs."""
        in_size = (2, 3)
        expected_shape = (self.batch_size, *in_size)
        for model_cls in (IF, LIF, ALIF):
            model = model_cls(in_size)
            self.assertEqual(model.in_size, in_size)
            self.assertEqual(model.out_size, in_size)

            currents = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mA
            model.init_state(self.batch_size)
            step_fn = brainstate.compile.jit(model)
            with brainstate.environ.context(dt=0.1 * u.ms):
                for step in range(self.time_steps):
                    result = step_fn(currents[step])
                    self.assertEqual(result.shape, expected_shape)
157
-
158
-
159
if __name__ == '__main__':
    # Script entry point: run the whole suite inside an environment context that
    # sets ``dt`` to 0.1 (a plain number here, unlike the unit-carrying ``dt``
    # the individual tests set themselves).
    with brainstate.environ.context(dt=0.1):
        unittest.main()
brainstate/nn/_others.py DELETED
@@ -1,46 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from functools import partial
17
-
18
- import jax
19
- import jax.numpy as jnp
20
-
21
- from brainstate.typing import PyTree
22
-
23
- __all__ = [
24
- 'clip_grad_norm',
25
- ]
26
-
27
-
28
- def clip_grad_norm(
29
- grad: PyTree,
30
- max_norm: float | jax.Array,
31
- norm_type: int | str | None = None
32
- ):
33
- """
34
- Clips gradient norm of an iterable of parameters.
35
-
36
- The norm is computed over all gradients together, as if they were
37
- concatenated into a single vector. Gradients are modified in-place.
38
-
39
- Args:
40
- grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
41
- max_norm (float): max norm of the gradients.
42
- norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
43
- """
44
- norm_fn = partial(jnp.linalg.norm, ord=norm_type)
45
- norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
46
- return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
@@ -1,486 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Callable, Union
17
- from typing import Optional
18
-
19
- import brainevent
20
- import brainunit as u
21
-
22
- from brainstate._state import State
23
- from brainstate.mixin import BindCondData, JointTypes
24
- from brainstate.mixin import ParamDescriber, AlignPost
25
- from brainstate.util.others import get_unique_name
26
- from ._collective_ops import call_order
27
- from ._dynamics import Dynamics, Projection, maybe_init_prefetch, Prefetch, PrefetchDelayAt
28
- from ._module import Module
29
- from ._stp import ShortTermPlasticity
30
- from ._synapse import Synapse
31
- from ._synouts import SynOut
32
-
33
- __all__ = [
34
- 'AlignPostProj',
35
- 'DeltaProj',
36
- 'CurrentProj',
37
-
38
- 'align_pre_projection',
39
- 'align_post_projection',
40
- ]
41
-
42
-
43
- def _check_modules(*modules):
44
- # checking modules
45
- for module in modules:
46
- if not callable(module) and not isinstance(module, State):
47
- raise TypeError(
48
- f'The module should be a callable function or a brainstate.State, but got {module}.'
49
- )
50
- return tuple(modules)
51
-
52
-
53
def call_module(module, *args, **kwargs):
    """
    Evaluate one pipeline entry.

    A callable is invoked with ``*args`` and ``**kwargs`` and its result is
    returned. A non-callable ``State`` ignores the arguments and yields its
    ``value``.

    Raises:
        TypeError: if ``module`` is neither callable nor a ``State``.
    """
    if not callable(module):
        if isinstance(module, State):
            return module.value
        raise TypeError(
            f'The module should be a callable function or a brainstate.State, but got {module}.'
        )
    return module(*args, **kwargs)
62
-
63
-
64
def is_instance(x, cls) -> bool:
    """Return the result of ``isinstance(x, cls)``; kept as a named helper for the type checks in this module."""
    matched = isinstance(x, cls)
    return matched
66
-
67
-
68
def get_post_repr(label, syn, out):
    """
    Build the string key that identifies a synapse/output pair.

    The key is ``'<syn.identifier> // <out.identifier>'``; when ``label`` is
    not ``None`` it is prepended directly, with no separator.
    """
    pair = f'{syn.identifier} // {out.identifier}'
    return pair if label is None else f'{label}{pair}'
73
-
74
-
75
def align_post_add_bef_update(
    syn_desc: ParamDescriber[AlignPost],
    out_desc: ParamDescriber[BindCondData],
    post: Dynamics,
    proj_name: str,
    label: str,
):
    """
    Fetch, or create on first use, the synapse/output pair registered on ``post``.

    The pair is keyed by ``get_post_repr(label, syn_desc, out_desc)``. When
    ``post`` has no before-update entry under that key, both describers are
    instantiated, the output is registered as a current input of ``post``
    under ``proj_name``, and an ``_AlignPost`` wrapping the two is added as a
    before-update. When the entry already exists nothing new is created.

    Returns:
        The ``(syn, out)`` objects held by the before-update entry.
    """
    key = get_post_repr(label, syn_desc, out_desc)

    if not post._has_before_update(key):
        # First projection using this synapse/output description on ``post``.
        new_syn = syn_desc()
        new_out = out_desc()
        post.add_current_input(proj_name, new_out, label=label)
        post._add_before_update(key, _AlignPost(new_syn, new_out))

    # Always read back from ``post`` so later projections get the registered pair.
    aligned_syn = post._get_before_update(key).syn
    aligned_out = post._get_before_update(key).out
    return aligned_syn, aligned_out
94
-
95
-
96
class _AlignPost(Module):
    """Couples a synapse with an output: each update feeds the synapse result into the output."""

    def __init__(
        self,
        syn: Dynamics,
        out: BindCondData
    ):
        super().__init__()
        self.syn, self.out = syn, out

    def update(self, *args, **kwargs):
        # Run the synapse and bind its return value on the output.
        conductance = self.syn(*args, **kwargs)
        self.out.bind_cond(conductance)
108
-
109
-
110
class AlignPostProj(Projection):
    """
    Align-post projection of the neural network.

    Inputs are passed through the optional ``modules`` and the communication
    model ``comm``; the result is added as a delta input to the synapse.
    Two modes exist:

    - *merging*: ``syn`` and ``out`` are both describers. The actual synapse
      and output objects are created once per ``post`` and shared by all
      projections with the same description and label.
    - *non-merging*: ``syn`` and ``out`` are instances. The projection calls
      the synapse itself on every update and binds its result on ``out``.

    Parameters
    ----------
    *modules : State or callable
        Optional modules applied, in order, to the input before ``comm``.
    comm : callable
        Communication model applied to the (pre-processed) input.
    syn : ParamDescriber[AlignPost] or AlignPost
        The synapse, given as a describer (merging) or as an instance.
    out : ParamDescriber[SynOut] or SynOut
        The output model; must be a describer exactly when ``syn`` is one.
    post : Dynamics
        Post-synaptic population that receives the current input.
    label : Optional[str], default=None
        Label forwarded when registering the current input on ``post``.

    Raises
    ------
    TypeError
        If ``comm`` is not callable, if ``syn``/``out`` are not a matching
        describer/instance pair, or if ``post`` is not a ``Dynamics``.

    Examples
    --------

    Here is an example of using the `AlignPostProj` to create a synaptic projection.
    Note that this projection needs the manual input of pre-synaptic spikes.

    >>> import brainstate
    >>> import brainunit as u
    >>> n_exc = 3200
    >>> n_inh = 800
    >>> num = n_exc + n_inh
    >>> pop = brainstate.nn.LIFRef(
    ...     num,
    ...     V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
    ...     tau=20. * u.ms, tau_ref=5. * u.ms,
    ...     V_initializer=brainstate.init.Normal(-55., 2., unit=u.mV)
    ... )
    >>> pop.init_state()
    >>> E = brainstate.nn.AlignPostProj(
    ...     comm=brainstate.nn.FixedNumConn(n_exc, num, prob=80 / num, weight=1.62 * u.mS),
    ...     syn=brainstate.nn.Expon.desc(num, tau=5. * u.ms),
    ...     out=brainstate.nn.CUBA.desc(scale=u.volt),
    ...     post=pop
    ... )
    >>> exe_current = E(pop.get_spike())

    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        *modules,
        comm: Callable,
        syn: Union[ParamDescriber[AlignPost], AlignPost],
        out: Union[ParamDescriber[SynOut], SynOut],
        post: Dynamics,
        label: Optional[str] = None,
    ):
        super().__init__(name=get_unique_name(self.__class__.__name__))

        # checking modules
        self.modules = _check_modules(*modules)

        # checking communication model
        if not callable(comm):
            raise TypeError(
                f'The communication should be an instance of callable function, but got {comm}.'
            )

        # checking synapse and output models: both describers, or both instances
        if is_instance(syn, ParamDescriber[AlignPost]):
            if not is_instance(out, ParamDescriber[SynOut]):
                if is_instance(out, ParamDescriber):
                    # a describer, but not one describing a SynOut
                    raise TypeError(
                        f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
                        f'the synapse is an instance of {AlignPost}, but got {out}.'
                    )
                raise TypeError(
                    f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
                    f'the synapse is a describer, but we got {out}.'
                )
            merging = True
        else:
            if is_instance(syn, ParamDescriber):
                # a describer, but not one describing an AlignPost synapse
                raise TypeError(
                    f'The synapse should be an instance of describer {ParamDescriber[AlignPost]}, but got {syn}.'
                )
            if not is_instance(out, SynOut):
                raise TypeError(
                    f'The output should be an instance of {SynOut} when the synapse is '
                    f'not a describer, but we got {out}.'
                )
            merging = False
        self.merging = merging

        # checking post model
        if not is_instance(post, Dynamics):
            raise TypeError(
                f'The post should be an instance of {Dynamics}, but got {post}.'
            )

        if merging:
            # synapse and output initialization (shared per post/description/label)
            syn, out = align_post_add_bef_update(syn_desc=syn,
                                                 out_desc=out,
                                                 post=post,
                                                 proj_name=self.name,
                                                 label=label)
        else:
            # Forward the label here as well, so that a user-supplied label is
            # honoured in both modes instead of being silently dropped.
            post.add_current_input(self.name, out, label=label)

        # references
        self.comm = comm
        self.syn: JointTypes[Dynamics, AlignPost] = syn
        self.out: BindCondData = out
        self.post: Dynamics = post

    @call_order(2)
    def init_state(self, *args, **kwargs):
        # Give every pre-processing module the chance to initialise its prefetch.
        for module in self.modules:
            maybe_init_prefetch(module, *args, **kwargs)

    def update(self, *args):
        # call all modules; each result becomes the single argument of the next
        for module in self.modules:
            args = (call_module(module, *args),)
        # communication module
        x = self.comm(*args)
        # add synapse input
        self.syn.add_delta_input(self.name, x)
        if not self.merging:
            # synapse and output interaction; in merging mode this is done by
            # the before-update entry registered on ``post``
            self.out.bind_cond(self.syn())
230
-
231
-
232
class DeltaProj(Projection):
    """
    Delta-based projection of the neural network.

    The input is passed through the optional prefetch modules and the
    communication model, and the result is added directly as a delta input
    to the post-synaptic population; no synaptic dynamics are involved.

    Parameters
    ----------
    *prefetch : State or callable
        Optional prefetch modules to process input before communication.
    comm : callable
        Communication model that determines how signals are transmitted.
    post : Dynamics
        Post-synaptic neural population to receive the delta inputs.
    label : Optional[str], default=None
        Optional label passed along with the delta input to ``post``.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> n_neurons = 100
    >>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_th=-50*u.mV)
    >>> pop.init_state()
    >>> delta_input = brainstate.nn.DeltaProj(
    ...     comm=lambda x: x * 10.0*u.mV,
    ...     post=pop
    ... )
    >>> delta_input(1.0)  # Apply voltage increment directly
    """
    __module__ = 'brainstate.nn'

    def __init__(self, *prefetch, comm: Callable, post: Dynamics, label=None):
        super().__init__(name=get_unique_name(self.__class__.__name__))

        self.label = label

        # every prefetch entry must be callable or a State
        self.prefetches = _check_modules(*prefetch)

        # the communication model must be callable
        if not callable(comm):
            raise TypeError(
                f'The communication should be an instance of callable function, but got {comm}.'
            )
        self.comm = comm

        # the target must be a Dynamics
        if not isinstance(post, Dynamics):
            raise TypeError(
                f'The post should be an instance of {Dynamics}, but got {post}.'
            )
        self.post = post

    @call_order(2)
    def init_state(self, *args, **kwargs):
        for entry in self.prefetches:
            maybe_init_prefetch(entry, *args, **kwargs)

    def update(self, *x):
        # Thread the input through the prefetch chain; each stage yields one value.
        values = x
        for stage in self.prefetches:
            values = (call_module(stage, *values),)
        assert len(values) == 1, f'The output of the modules should be a single value, but got {values}.'
        delta = self.comm(values[0])
        self.post.add_delta_input(self.name, delta, label=self.label)
299
-
300
-
301
class CurrentProj(Projection):
    """
    Current-based projection of the neural network.

    The input is passed through the optional prefetch modules and the
    communication model; the result is bound on the output model, which is
    registered as a current input of the post-synaptic population at
    construction time. No separate synaptic dynamics are involved.

    Parameters
    ----------
    *prefetch : State or callable
        Optional prefetch modules to process input before communication.
        The last element must be an instance of Prefetch or PrefetchDelayAt if any are provided.
    comm : callable
        Communication model that determines how signals are transmitted.
    out : SynOut
        Output model that converts communication results to post-synaptic currents.
    post : Dynamics
        Post-synaptic neural population to receive the currents.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> n_neurons = 100
    >>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_th=-50*u.mV)
    >>> pop.init_state()
    >>> current_input = brainstate.nn.CurrentProj(
    ...     comm=lambda x: x * 0.5,
    ...     out=brainstate.nn.CUBA(scale=1.0*u.nA),
    ...     post=pop
    ... )
    >>> current_input(0.2)  # Apply external current
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        *prefetch,
        comm: Callable,
        out: SynOut,
        post: Dynamics,
    ):
        super().__init__(name=get_unique_name(self.__class__.__name__))

        # prefetch chain: only the final entry is type-checked
        self.prefetch = prefetch
        if len(self.prefetch) > 0:
            tail = prefetch[-1]
            if not isinstance(tail, (Prefetch, PrefetchDelayAt)):
                raise TypeError(
                    f'The last element of prefetch should be an instance of {Prefetch} or {PrefetchDelayAt}, '
                    f'but got {tail}.'
                )

        # output model
        if not isinstance(out, SynOut):
            raise TypeError(f'The out should be a SynOut, but got {out}.')
        self.out = out

        # post-synaptic population; the output becomes one of its current inputs
        if not isinstance(post, Dynamics):
            raise TypeError(f'The post should be a Dynamics, but got {post}.')
        self.post = post
        post.add_current_input(self.name, out)

        # communication model (stored without validation)
        self.comm = comm

    @call_order(2)
    def init_state(self, *args, **kwargs):
        for entry in self.prefetch:
            maybe_init_prefetch(entry, *args, **kwargs)

    def update(self, *x):
        # Thread the input through the prefetch chain, then through ``comm``.
        values = x
        for stage in self.prefetch:
            values = (call_module(stage, *values),)
        current = self.comm(*values)
        self.out.bind_cond(current)
378
-
379
-
380
class align_pre_projection(Projection):
    """
    Represents a pre-synaptic alignment projection mechanism.

    On every update the input is passed through the spike generators, wrapped
    as a ``brainevent.BinaryArray``, optionally transformed by short-term
    plasticity, passed through the synapse ``syn`` and finally handed to an
    internal ``CurrentProj`` built from ``comm``, ``out`` and ``post``.

    Args:
        *spike_generator: Callables or ``State`` objects applied in order to the
            input. A ``State`` contributes its current ``value``.
        syn (Dynamics): The synapse applied before communication.
        comm (Callable): Communication function for the projection.
        out (SynOut): The synaptic output object.
        post (Dynamics): The post-synaptic dynamics object.
        stp (ShortTermPlasticity, optional): The short-term plasticity object,
            defaults to None.

    Attributes:
        spike_generator (tuple): The validated spike generators.
        projection (CurrentProj): The current projection object handling communication,
            output, and post-synaptic dynamics.
        syn (Dynamics): The synapse passed at construction.
        stp (ShortTermPlasticity, optional): The short-term plasticity object, or None.
    """

    def __init__(
        self,
        *spike_generator,
        syn: Dynamics,
        comm: Callable,
        out: SynOut,
        post: Dynamics,
        stp: Optional[ShortTermPlasticity] = None,
    ):
        super().__init__()

        self.spike_generator = _check_modules(*spike_generator)
        self.projection = CurrentProj(comm=comm, out=out, post=post)
        self.syn = syn
        self.stp = stp

    @call_order(2)
    def init_state(self, *args, **kwargs):
        for module in self.spike_generator:
            maybe_init_prefetch(module, *args, **kwargs)

    def update(self, *x):
        for fun in self.spike_generator:
            # ``_check_modules`` also admits State entries, which are not
            # callable; ``call_module`` handles both kinds.
            x = call_module(fun, *x)
            # Normalise to a tuple so the result can be splatted into the next stage.
            if isinstance(x, (tuple, list)):
                x = tuple(x)
            else:
                x = (x,)
        assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
        x = brainevent.BinaryArray(x[0])  # Wrap the spikes as a BinaryArray
        if self.stp is not None:
            x = brainevent.MaskedFloat(self.stp(x))  # Wrap the STP output as a MaskedFloat
        x = self.syn(x)  # Apply the synapse before communication
        return self.projection(x)
434
-
435
-
436
class align_post_projection(Projection):
    """
    Represents a post-synaptic alignment projection mechanism.

    On every update the input is passed through the spike generators, wrapped
    as a ``brainevent.BinaryArray``, optionally transformed by short-term
    plasticity, and handed to an internal ``AlignPostProj`` built from
    ``comm``, ``syn``, ``out`` and ``post``.

    Args:
        *spike_generator: Callables or ``State`` objects applied in order to the
            input. A ``State`` contributes its current ``value``.
        comm (Callable): Communication function for the projection.
        syn (Union[AlignPost, ParamDescriber[AlignPost]]): The post-synaptic alignment object or its parameter describer.
        out (Union[SynOut, ParamDescriber[SynOut]]): The synaptic output object or its parameter describer.
        post (Dynamics): The post-synaptic dynamics object.
        stp (ShortTermPlasticity, optional): The short-term plasticity object, defaults to None.

    """

    def __init__(
        self,
        *spike_generator,
        comm: Callable,
        syn: Union[AlignPost, ParamDescriber[AlignPost]],
        out: Union[SynOut, ParamDescriber[SynOut]],
        post: Dynamics,
        stp: Optional[ShortTermPlasticity] = None,
    ):
        super().__init__()

        self.spike_generator = _check_modules(*spike_generator)
        self.projection = AlignPostProj(comm=comm, syn=syn, out=out, post=post)
        self.stp = stp

    @call_order(2)
    def init_state(self, *args, **kwargs):
        for module in self.spike_generator:
            maybe_init_prefetch(module, *args, **kwargs)

    def update(self, *x):
        for fun in self.spike_generator:
            # ``_check_modules`` also admits State entries, which are not
            # callable; ``call_module`` handles both kinds.
            x = call_module(fun, *x)
            # Normalise to a tuple so the result can be splatted into the next stage.
            if isinstance(x, (tuple, list)):
                x = tuple(x)
            else:
                x = (x,)
        assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
        x = brainevent.BinaryArray(x[0])  # Wrap the spikes as a BinaryArray
        if self.stp is not None:
            x = brainevent.MaskedFloat(self.stp(x))  # Wrap the STP output as a MaskedFloat
        return self.projection(x)