brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
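
Taken together, the file list describes a package-level restructuring: the brainstate/transform package becomes brainstate/compile (entry 15; the old modules are removed in entries 166-180), the autograd and mapping transforms are split out into the new brainstate/augment package (entries 7-14), and the event-driven operators move from brainstate/nn/event to a top-level brainstate/event package (entries 35-47 and 159-164). A minimal import-migration sketch follows; it assumes the usual package-level re-exports (e.g. jit, grad), which the __init__.py files listed above would define but this summary view does not display.

# Hypothetical migration of user code across this release
# (public names assumed re-exported at package level):
#
# before, with brainstate 0.0.2.post20241010:
#   from brainstate.transform import jit, grad
#   from brainstate.nn import event
#
# after, with brainstate 0.1.0.post20241122:
from brainstate.compile import jit   # compilation transforms: jit, conditions, loops
from brainstate.augment import grad  # function augmentations: autograd, mapping
from brainstate import event        # event-driven ops, now a top-level package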
brainstate/nn/_connections.py (deleted; entry 143)
@@ -1,686 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- from __future__ import annotations
-
- import collections.abc
- import numbers
- from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
-
- import jax
- import jax.numpy as jnp
-
- from ._base import DnnLayer
- from brainstate import init, functional
- from brainstate._state import ParamState
- from brainstate.mixin import Mode
- from brainstate.typing import ArrayLike
-
- T = TypeVar('T')
-
- __all__ = [
-   'Linear', 'ScaledWSLinear', 'SignedWLinear', 'CSRLinear',
-   'Conv1d', 'Conv2d', 'Conv3d',
-   'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
- ]
-
-
- def to_dimension_numbers(
-     num_spatial_dims: int,
-     channels_last: bool,
-     transpose: bool
- ) -> jax.lax.ConvDimensionNumbers:
-   """Create a `lax.ConvDimensionNumbers` for the given inputs."""
-   num_dims = num_spatial_dims + 2
-   if channels_last:
-     spatial_dims = tuple(range(1, num_dims - 1))
-     image_dn = (0, num_dims - 1) + spatial_dims
-   else:
-     spatial_dims = tuple(range(2, num_dims))
-     image_dn = (0, 1) + spatial_dims
-   if transpose:
-     kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
-   else:
-     kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
-   return jax.lax.ConvDimensionNumbers(lhs_spec=image_dn,
-                                       rhs_spec=kernel_dn,
-                                       out_spec=image_dn)
-
-
- def replicate(
-     element: Union[T, Sequence[T]],
-     num_replicate: int,
-     name: str,
- ) -> Tuple[T, ...]:
-   """Replicates entry in `element` `num_replicate` if needed."""
-   if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence):
-     return (element,) * num_replicate
-   elif len(element) == 1:
-     return tuple(list(element) * num_replicate)
-   elif len(element) == num_replicate:
-     return tuple(element)
-   else:
-     raise TypeError(f"{name} must be a scalar or sequence of length 1 or "
-                     f"sequence of length {num_replicate}.")
-
-
- class Linear(DnnLayer):
-   """
-   Linear layer.
-   """
-   __module__ = 'brainstate.nn'
-
-   def __init__(
-       self,
-       in_size: Union[int, Sequence[int]],
-       out_size: Union[int, Sequence[int]],
-       w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
-       b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
-       w_mask: Optional[Union[ArrayLike, Callable]] = None,
-       name: Optional[str] = None,
-       mode: Optional[Mode] = None,
-   ):
-     super().__init__(name=name, mode=mode)
-
-     # input and output shape
-     self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
-     self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
-
-     # w_mask
-     self.w_mask = init.param(w_mask, self.in_size + self.out_size)
-
-     # weights
-     params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
-     if b_init is not None:
-       params['bias'] = init.param(b_init, self.out_size, allow_none=False)
-
-     # weight + op
-     self.W = ParamState(params)
-
-   def update(self, x):
-     params = self.W.value
-     weight = params['weight']
-     if self.w_mask is not None:
-       weight = weight * self.w_mask
-     y = jnp.dot(x, weight)
-     if 'bias' in params:
-       y = y + params['bias']
-     return y
-
-
- class SignedWLinear(DnnLayer):
-   """
-   Linear layer with signed weights.
-   """
-   __module__ = 'brainstate.nn'
-
-   def __init__(
-       self,
-       in_size: Union[int, Sequence[int]],
-       out_size: Union[int, Sequence[int]],
-       w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
-       w_sign: Optional[ArrayLike] = None,
-       name: Optional[str] = None,
-       mode: Optional[Mode] = None
-   ):
-     super().__init__(name=name, mode=mode)
-
-     # input and output shape
-     self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
-     self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
-
-     # w_mask
-     self.w_sign = w_sign
-
-     # weights
-     weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
-     self.W = ParamState(weight)
-
-   def _operation(self, x, w):
-     if self.w_sign is None:
-       return jnp.matmul(x, jnp.abs(w))
-     else:
-       return jnp.matmul(x, jnp.abs(w) * self.w_sign)
-
-   def update(self, x):
-     return self._operation(x, self.W.value)
-
-
- class ScaledWSLinear(DnnLayer):
-   """
-   Linear Layer with Weight Standardization.
-
-   Applies weight standardization to the weights of the linear layer.
-
-   Parameters
-   ----------
-   in_size: int, sequence of int
-     The input size.
-   out_size: int, sequence of int
-     The output size.
-   w_init: Callable, ArrayLike
-     The initializer for the weights.
-   b_init: Callable, ArrayLike
-     The initializer for the bias.
-   w_mask: ArrayLike, Callable
-     The optional mask of the weights.
-   as_etrace_weight: bool
-     Whether to use ETraceParamOp for the weights.
-   ws_gain: bool
-     Whether to use gain for the weights. The default is True.
-   eps: float
-     The epsilon value for the weight standardization.
-   name: str
-     The name of the object.
-   mode: Mode
-     The computation mode of the current object.
-
-   """
-   __module__ = 'brainstate.nn'
-
-   def __init__(
-       self,
-       in_size: Union[int, Sequence[int]],
-       out_size: Union[int, Sequence[int]],
-       w_init: Callable = init.KaimingNormal(),
-       b_init: Callable = init.ZeroInit(),
-       w_mask: Optional[Union[ArrayLike, Callable]] = None,
-       as_etrace_weight: bool = True,
-       full_etrace: bool = False,
-       ws_gain: bool = True,
-       eps: float = 1e-4,
-       name: str = None,
-       mode: Mode = None
-   ):
-     super().__init__(name=name, mode=mode)
-
-     # input and output shape
-     self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
-     self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
-
-     # w_mask
-     self.w_mask = init.param(w_mask, (self.in_size[0], 1))
-
-     # parameters
-     self.eps = eps
-
-     # weights
-     params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
-     if b_init is not None:
-       params['bias'] = init.param(b_init, self.out_size, allow_none=False)
-     # gain
-     if ws_gain:
-       s = params['weight'].shape
-       params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
-
-     # weight operation
-     self.W = ParamState(params)
-
-   def update(self, x):
-     return self._operation(x, self.W.value)
-
-   def _operation(self, x, params):
-     w = params['weight']
-     w = functional.weight_standardization(w, self.eps, params.get('gain', None))
-     if self.w_mask is not None:
-       w = w * self.w_mask
-     y = jnp.dot(x, w)
-     if 'bias' in params:
-       y = y + params['bias']
-     return y
-
-
- class CSRLinear(DnnLayer):
-   __module__ = 'brainstate.nn'
-
-
- class _BaseConv(DnnLayer):
-   # the number of spatial dimensions
-   num_spatial_dims: int
-
-   # the weight and its operations
-   W: ParamState
-
-   def __init__(
-       self,
-       in_size: Sequence[int],
-       out_channels: int,
-       kernel_size: Union[int, Tuple[int, ...]],
-       stride: Union[int, Tuple[int, ...]] = 1,
-       padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
-       lhs_dilation: Union[int, Tuple[int, ...]] = 1,
-       rhs_dilation: Union[int, Tuple[int, ...]] = 1,
-       groups: int = 1,
-       w_mask: Optional[Union[ArrayLike, Callable]] = None,
-       mode: Mode = None,
-       name: str = None,
-   ):
-     super().__init__(name=name, mode=mode)
-
-     # general parameters
-     assert self.num_spatial_dims + 1 == len(in_size)
-     self.in_size = tuple(in_size)
-     self.in_channels = in_size[-1]
-     self.out_channels = out_channels
-     self.stride = replicate(stride, self.num_spatial_dims, 'stride')
-     self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
-     self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
-     self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
-     self.groups = groups
-     self.dimension_numbers = to_dimension_numbers(self.num_spatial_dims, channels_last=True, transpose=False)
-
-     # the padding parameter
-     if isinstance(padding, str):
-       assert padding in ['SAME', 'VALID']
-     elif isinstance(padding, int):
-       padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
-     elif isinstance(padding, (tuple, list)):
-       if isinstance(padding[0], int):
-         padding = (padding,) * self.num_spatial_dims
-       elif isinstance(padding[0], (tuple, list)):
-         if len(padding) == 1:
-           padding = tuple(padding) * self.num_spatial_dims
-         else:
-           if len(padding) != self.num_spatial_dims:
-             raise ValueError(
-               f"Padding {padding} must be a Tuple[int, int], "
-               f"or sequence of Tuple[int, int] with length 1, "
-               f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
-             )
-           padding = tuple(padding)
-     else:
-       raise ValueError
-     self.padding = padding
-
-     # the number of in-/out-channels
-     assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
-     assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
-
-     # kernel shape and w_mask
-     kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
-     self.kernel_shape = kernel_shape
-     self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)
-
-   def _check_input_dim(self, x):
-     if x.ndim == self.num_spatial_dims + 2:
-       x_shape = x.shape[1:]
-     elif x.ndim == self.num_spatial_dims + 1:
-       x_shape = x.shape
-     else:
-       raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
-                        f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
-     if self.in_size != x_shape:
-       raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
-
-   def update(self, x):
-     self._check_input_dim(x)
-     non_batching = False
-     if x.ndim == self.num_spatial_dims + 1:
-       x = jnp.expand_dims(x, 0)
-       non_batching = True
-     y = self._conv_op(x, self.W.value)
-     return y[0] if non_batching else y
-
-   def _conv_op(self, x, params):
-     raise NotImplementedError
-
-   def __repr__(self):
-     return (f'{self.__class__.__name__}('
-             f'in_channels={self.in_channels}, '
-             f'out_channels={self.out_channels}, '
-             f'kernel_size={self.kernel_size}, '
-             f'stride={self.stride}, '
-             f'padding={self.padding}, '
-             f'groups={self.groups})')
-
-
- class _Conv(_BaseConv):
-   num_spatial_dims: int = None
-
-   def __init__(
-       self,
-       in_size: Sequence[int],
-       out_channels: int,
-       kernel_size: Union[int, Tuple[int, ...]],
-       stride: Union[int, Tuple[int, ...]] = 1,
-       padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
-       lhs_dilation: Union[int, Tuple[int, ...]] = 1,
-       rhs_dilation: Union[int, Tuple[int, ...]] = 1,
-       groups: int = 1,
-       w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
-       b_init: Optional[Union[Callable, ArrayLike]] = None,
-       w_mask: Optional[Union[ArrayLike, Callable]] = None,
-       mode: Mode = None,
-       name: str = None,
-   ):
-     super().__init__(in_size=in_size,
-                      out_channels=out_channels,
-                      kernel_size=kernel_size,
-                      stride=stride,
-                      padding=padding,
-                      lhs_dilation=lhs_dilation,
-                      rhs_dilation=rhs_dilation,
-                      groups=groups,
-                      w_mask=w_mask,
-                      name=name,
-                      mode=mode)
-
-     self.w_initializer = w_init
-     self.b_initializer = b_init
-
-     # --- weights --- #
-     weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
-     params = dict(weight=weight)
-     if self.b_initializer is not None:
-       bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
-       bias = init.param(self.b_initializer, bias_shape, allow_none=True)
-       params['bias'] = bias
-
-     # The weight operation
-     self.W = ParamState(params)
-
-     # Evaluate the output shape
-     abstract_y = jax.eval_shape(
-       self._conv_op,
-       jax.ShapeDtypeStruct((128,) + self.in_size, weight.dtype),
-       params
-     )
-     y_shape = abstract_y.shape[1:]
-     self.out_size = y_shape
-
-   def _conv_op(self, x, params):
-     w = params['weight']
-     if self.w_mask is not None:
-       w = w * self.w_mask
-     y = jax.lax.conv_general_dilated(
-       lhs=x,
-       rhs=w,
-       window_strides=self.stride,
-       padding=self.padding,
-       lhs_dilation=self.lhs_dilation,
-       rhs_dilation=self.rhs_dilation,
-       feature_group_count=self.groups,
-       dimension_numbers=self.dimension_numbers
-     )
-     if 'bias' in params:
-       y = y + params['bias']
-     return y
-
-   def __repr__(self):
-     return (f'{self.__class__.__name__}('
-             f'in_channels={self.in_channels}, '
-             f'out_channels={self.out_channels}, '
-             f'kernel_size={self.kernel_size}, '
-             f'stride={self.stride}, '
-             f'padding={self.padding}, '
-             f'groups={self.groups})')
-
-
- class Conv1d(_Conv):
-   """One-dimensional convolution.
-
-   The input should be a 3d array with the shape of ``[B, H, C]``.
-
-   Parameters
-   ----------
-   %s
-   """
-   __module__ = 'brainstate.nn'
-   num_spatial_dims: int = 1
-
-
- class Conv2d(_Conv):
-   """Two-dimensional convolution.
-
-   The input should be a 4d array with the shape of ``[B, H, W, C]``.
-
-   Parameters
-   ----------
-   %s
-   """
-   __module__ = 'brainstate.nn'
-   num_spatial_dims: int = 2
-
-
- class Conv3d(_Conv):
-   """Three-dimensional convolution.
-
-   The input should be a 5d array with the shape of ``[B, H, W, D, C]``.
-
-   Parameters
-   ----------
-   %s
-   """
-   __module__ = 'brainstate.nn'
-   num_spatial_dims: int = 3
-
-
- _conv_doc = '''
- in_size: tuple of int
-   The input shape, without the batch size. This argument is important, since it is
-   used to evaluate the shape of the output.
- out_channels: int
-   The number of output channels.
- kernel_size: int, sequence of int
-   The shape of the convolutional kernel.
-   For 1D convolution, the kernel size can be passed as an integer.
-   For all other cases, it must be a sequence of integers.
- stride: int, sequence of int
-   An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
- padding: str, int, sequence of int, sequence of tuple
-   Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
-   high)` integer pairs that give the padding to apply before and after each
-   spatial dimension.
- lhs_dilation: int, sequence of int
-   An integer or a sequence of `n` integers, giving the
-   dilation factor to apply in each spatial dimension of `inputs`
-   (default: 1). Convolution with input dilation `d` is equivalent to
-   transposed convolution with stride `d`.
- rhs_dilation: int, sequence of int
-   An integer or a sequence of `n` integers, giving the
-   dilation factor to apply in each spatial dimension of the convolution
-   kernel (default: 1). Convolution with kernel dilation
-   is also known as 'atrous convolution'.
- groups: int
-   If specified, divides the input features into groups. default 1.
- w_init: Callable, ArrayLike, Initializer
-   The initializer for the convolutional kernel.
- b_init: Optional, Callable, ArrayLike, Initializer
-   The initializer for the bias.
- w_mask: ArrayLike, Callable, Optional
-   The optional mask of the weights.
- mode: Mode
-   The computation mode of the current object. Default it is `training`.
- name: str, Optional
-   The name of the object.
- '''
-
- Conv1d.__doc__ = Conv1d.__doc__ % _conv_doc
- Conv2d.__doc__ = Conv2d.__doc__ % _conv_doc
- Conv3d.__doc__ = Conv3d.__doc__ % _conv_doc
-
-
- class _ScaledWSConv(_BaseConv):
-   def __init__(
-       self,
-       in_size: Sequence[int],
-       out_channels: int,
-       kernel_size: Union[int, Tuple[int, ...]],
-       stride: Union[int, Tuple[int, ...]] = 1,
-       padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
-       lhs_dilation: Union[int, Tuple[int, ...]] = 1,
-       rhs_dilation: Union[int, Tuple[int, ...]] = 1,
-       groups: int = 1,
-       ws_gain: bool = True,
-       eps: float = 1e-4,
-       w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
-       b_init: Optional[Union[Callable, ArrayLike]] = None,
-       w_mask: Optional[Union[ArrayLike, Callable]] = None,
-       mode: Mode = None,
-       name: str = None,
-   ):
-     super().__init__(in_size=in_size,
-                      out_channels=out_channels,
-                      kernel_size=kernel_size,
-                      stride=stride,
-                      padding=padding,
-                      lhs_dilation=lhs_dilation,
-                      rhs_dilation=rhs_dilation,
-                      groups=groups,
-                      w_mask=w_mask,
-                      name=name,
-                      mode=mode)
-
-     self.w_initializer = w_init
-     self.b_initializer = b_init
-
-     # --- weights --- #
-     weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
-     params = dict(weight=weight)
-     if self.b_initializer is not None:
-       bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
-       bias = init.param(self.b_initializer, bias_shape, allow_none=True)
-       params['bias'] = bias
-
-     # gain
-     if ws_gain:
-       gain_size = (1,) * len(self.kernel_size) + (1, self.out_channels)
-       ws_gain = jnp.ones(gain_size, dtype=params['weight'].dtype)
-       params['gain'] = ws_gain
-
-     # Epsilon, a small constant to avoid dividing by zero.
-     self.eps = eps
-
-     # The weight operation
-     self.W = ParamState(params)
-
-     # Evaluate the output shape
-     abstract_y = jax.eval_shape(
-       self._conv_op,
-       jax.ShapeDtypeStruct((128,) + self.in_size, weight.dtype),
-       params
-     )
-     y_shape = abstract_y.shape[1:]
-     self.out_size = y_shape
-
-   def _conv_op(self, x, params):
-     w = params['weight']
-     w = functional.weight_standardization(w, self.eps, params.get('gain', None))
-     if self.w_mask is not None:
-       w = w * self.w_mask
-     y = jax.lax.conv_general_dilated(
-       lhs=x,
-       rhs=w,
-       window_strides=self.stride,
-       padding=self.padding,
-       lhs_dilation=self.lhs_dilation,
-       rhs_dilation=self.rhs_dilation,
-       feature_group_count=self.groups,
-       dimension_numbers=self.dimension_numbers
-     )
-     if 'bias' in params:
-       y = y + params['bias']
-     return y
-
-
- class ScaledWSConv1d(_ScaledWSConv):
-   """One-dimensional convolution with weight standardization.
-
-   The input should be a 3d array with the shape of ``[B, H, C]``.
-
-   Parameters
-   ----------
-   %s
-   """
-   __module__ = 'brainstate.nn'
-   num_spatial_dims: int = 1
-
-
- class ScaledWSConv2d(_ScaledWSConv):
-   """Two-dimensional convolution with weight standardization.
-
-   The input should be a 4d array with the shape of ``[B, H, W, C]``.
-
-   Parameters
-   ----------
-   %s
-   """
-   __module__ = 'brainstate.nn'
-   num_spatial_dims: int = 2
-
-
- class ScaledWSConv3d(_ScaledWSConv):
-   """Three-dimensional convolution with weight standardization.
-
-   The input should be a 5d array with the shape of ``[B, H, W, D, C]``.
-
-   Parameters
-   ----------
-   %s
-   """
-   __module__ = 'brainstate.nn'
-   num_spatial_dims: int = 3
-
-
- _ws_conv_doc = '''
- in_size: tuple of int
-   The input shape, without the batch size. This argument is important, since it is
-   used to evaluate the shape of the output.
- out_channels: int
-   The number of output channels.
- kernel_size: int, sequence of int
-   The shape of the convolutional kernel.
-   For 1D convolution, the kernel size can be passed as an integer.
-   For all other cases, it must be a sequence of integers.
- stride: int, sequence of int
-   An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
- padding: str, int, sequence of int, sequence of tuple
-   Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
-   high)` integer pairs that give the padding to apply before and after each
-   spatial dimension.
- lhs_dilation: int, sequence of int
-   An integer or a sequence of `n` integers, giving the
-   dilation factor to apply in each spatial dimension of `inputs`
-   (default: 1). Convolution with input dilation `d` is equivalent to
-   transposed convolution with stride `d`.
- rhs_dilation: int, sequence of int
-   An integer or a sequence of `n` integers, giving the
-   dilation factor to apply in each spatial dimension of the convolution
-   kernel (default: 1). Convolution with kernel dilation
-   is also known as 'atrous convolution'.
- groups: int
-   If specified, divides the input features into groups. default 1.
- w_init: Callable, ArrayLike, Initializer
-   The initializer for the convolutional kernel.
- b_init: Optional, Callable, ArrayLike, Initializer
-   The initializer for the bias.
- ws_gain: bool
-   Whether to add a gain term for the weight standarization. The default is `True`.
- eps: float
-   The epsilon value for numerical stability.
- w_mask: ArrayLike, Callable, Optional
-   The optional mask of the weights.
- mode: Mode
-   The computation mode of the current object. Default it is `training`.
- name: str, Optional
-   The name of the object.
-
- '''
-
- ScaledWSConv1d.__doc__ = ScaledWSConv1d.__doc__ % _ws_conv_doc
- ScaledWSConv2d.__doc__ = ScaledWSConv2d.__doc__ % _ws_conv_doc
- ScaledWSConv3d.__doc__ = ScaledWSConv3d.__doc__ % _ws_conv_doc
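
The ScaledWS* layers in the deleted file all defer to functional.weight_standardization(w, eps, gain) before the matmul or convolution. For reference, here is a minimal sketch of what that standardization conventionally computes (zero mean and unit variance over the fan-in axes, followed by an optional per-output-channel gain); this is an approximation consistent with the gain shapes built above, not a copy of brainstate's actual functional.weight_standardization, which may differ in detail.

import jax.numpy as jnp

def weight_standardization_sketch(w, eps=1e-4, gain=None):
  # Standardize the kernel over its fan-in axes (every axis except the last,
  # which indexes output channels). This matches the gain shapes constructed
  # in the deleted code: (1,) * (ndim - 1) + (out,) for the linear layer and
  # (1,) * k + (1, out) for the convolutions.
  fan_in_axes = tuple(range(w.ndim - 1))
  mean = jnp.mean(w, axis=fan_in_axes, keepdims=True)
  var = jnp.var(w, axis=fan_in_axes, keepdims=True)
  w = (w - mean) / jnp.sqrt(var + eps)  # eps avoids division by zero
  if gain is not None:
    w = w * gain  # learnable per-output-channel rescaling
  return w

With the weights standardized this way, the scale of each output unit is controlled by the gain alone, which is why the deleted classes initialize params['gain'] to ones.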