brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,499 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from __future__ import annotations
19
+
20
+ import collections.abc
21
+ from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+
26
+ from brainstate import init, functional
27
+ from brainstate._state import ParamState
28
+ from brainstate.nn._module import Module
29
+ from brainstate.typing import ArrayLike
30
+
31
# Generic element type: lets ``replicate`` express that the returned tuple
# holds the same kind of element it was given.
T = TypeVar('T')

# Public convolution layers exported by this module.
__all__ = [
    'Conv1d', 'Conv2d', 'Conv3d',
    'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
]
37
+
38
+
39
def to_dimension_numbers(
    num_spatial_dims: int,
    channels_last: bool,
    transpose: bool
) -> jax.lax.ConvDimensionNumbers:
    """Build the ``lax.ConvDimensionNumbers`` for an N-d convolution layout.

    Parameters
    ----------
    num_spatial_dims: int
        Number of spatial axes; arrays have ``num_spatial_dims + 2`` axes.
    channels_last: bool
        If True the second entry of the input/output spec is the last axis
        and the spatial axes start at 1; otherwise it is axis 1 and the
        spatial axes start at 2.
    transpose: bool
        Selects the order in which the two trailing kernel axes are listed
        at the front of the kernel spec.

    Returns
    -------
    jax.lax.ConvDimensionNumbers
        ``lhs_spec`` and ``out_spec`` share the same layout; ``rhs_spec``
        always lists the leading kernel axes as spatial axes.
    """
    rank = num_spatial_dims + 2

    # Input / output layout: batch axis, then the channel axis, then spatial axes.
    channel_axis = rank - 1 if channels_last else 1
    first_spatial = 1 if channels_last else 2
    io_spec = (0, channel_axis) + tuple(range(first_spatial, first_spatial + num_spatial_dims))

    # Kernel layout: the two trailing axes come first (their order depends on
    # ``transpose``), followed by the leading spatial axes.
    trailing = (rank - 2, rank - 1) if transpose else (rank - 1, rank - 2)
    kernel_spec = trailing + tuple(range(rank - 2))

    return jax.lax.ConvDimensionNumbers(lhs_spec=io_spec,
                                        rhs_spec=kernel_spec,
                                        out_spec=io_spec)
59
+
60
+
61
def replicate(
    element: Union[T, Sequence[T]],
    num_replicate: int,
    name: str,
) -> Tuple[T, ...]:
    """Expand ``element`` into a tuple with ``num_replicate`` entries.

    Parameters
    ----------
    element: scalar or sequence
        A non-sequence value (strings and bytes count as scalars here), a
        sequence of length 1, or a sequence of length ``num_replicate``.
    num_replicate: int
        The required length of the result.
    name: str
        Argument name used in the error message.

    Returns
    -------
    tuple
        ``element`` repeated ``num_replicate`` times, or ``element`` itself
        converted to a tuple when it already has the required length.

    Raises
    ------
    TypeError
        If ``element`` is a sequence whose length is neither 1 nor
        ``num_replicate``.
    """
    # str/bytes are Sequences too, but are treated as single values.
    treat_as_scalar = (
        isinstance(element, (str, bytes))
        or not isinstance(element, collections.abc.Sequence)
    )
    if treat_as_scalar:
        return (element,) * num_replicate

    size = len(element)
    if size == 1:
        return tuple(element) * num_replicate
    if size == num_replicate:
        return tuple(element)
    raise TypeError(f"{name} must be a scalar or sequence of length 1 or "
                    f"sequence of length {num_replicate}.")
76
+
77
+
78
class _BaseConv(Module):
    """Shared configuration and forward pass of the convolution layers.

    Subclasses are expected to set ``num_spatial_dims``, create
    ``self.weight`` and implement ``_conv_op``. Inputs use a channels-last
    layout: ``in_size`` is the per-sample shape and its last entry is the
    number of input channels.

    Parameters
    ----------
    in_size: sequence of int
        Input shape without the batch axis; must have
        ``num_spatial_dims + 1`` entries.
    out_channels: int
        Number of output channels.
    kernel_size, stride, lhs_dilation, rhs_dilation: int or sequence of int
        Broadcast to one value per spatial dimension with ``replicate``.
    padding: str, int, pair of int, or sequence of pairs
        ``'SAME'`` / ``'VALID'``; an int ``p`` (pads ``(p, p)`` on every
        spatial dimension); one ``(low, high)`` pair applied to every spatial
        dimension; or a sequence holding either a single pair or one pair per
        spatial dimension.
    groups: int
        Must divide both the input and the output channel count.
    w_mask: ArrayLike, Callable, optional
        Optional kernel mask, resolved through ``init.param`` against the
        kernel shape.
    name: str, optional
        Forwarded to ``Module``.

    Raises
    ------
    ValueError
        If ``padding`` has an unsupported type, is an empty sequence, or is a
        sequence of pairs whose length is neither 1 nor ``num_spatial_dims``.
    """
    # the number of spatial dimensions
    num_spatial_dims: int

    # the weight and its operations
    weight: ParamState

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # general parameters
        assert self.num_spatial_dims + 1 == len(in_size), (
            f'"in_size" must have {self.num_spatial_dims + 1} entries '
            f'(spatial dimensions followed by channels), but we got {tuple(in_size)}.'
        )
        self.in_size = tuple(in_size)
        self.in_channels = in_size[-1]
        self.out_channels = out_channels
        self.stride = replicate(stride, self.num_spatial_dims, 'stride')
        self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
        self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
        self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
        self.groups = groups
        self.dimension_numbers = to_dimension_numbers(self.num_spatial_dims, channels_last=True, transpose=False)

        # the padding parameter
        if isinstance(padding, str):
            assert padding in ['SAME', 'VALID'], f'padding must be "SAME" or "VALID", but we got {padding!r}.'
        elif isinstance(padding, int):
            padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
        elif isinstance(padding, (tuple, list)):
            if len(padding) == 0:
                # Without this guard the ``padding[0]`` probes below fail with
                # an unrelated IndexError.
                raise ValueError('Padding must not be an empty sequence.')
            if isinstance(padding[0], int):
                # a single (low, high) pair shared by every spatial dimension
                padding = (padding,) * self.num_spatial_dims
            elif isinstance(padding[0], (tuple, list)):
                if len(padding) == 1:
                    padding = tuple(padding) * self.num_spatial_dims
                else:
                    if len(padding) != self.num_spatial_dims:
                        raise ValueError(
                            f"Padding {padding} must be a Tuple[int, int], "
                            f"or sequence of Tuple[int, int] with length 1, "
                            f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
                        )
                    padding = tuple(padding)
            # Any other element type is passed through unchanged, as before.
        else:
            raise ValueError(
                f"Unsupported padding {padding!r} of type {type(padding).__name__}: expected 'SAME', "
                f"'VALID', an int, a Tuple[int, int], or a sequence of Tuple[int, int]."
            )
        self.padding = padding

        # the number of in-/out-channels
        assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
        assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'

        # kernel shape and w_mask
        kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
        self.kernel_shape = kernel_shape
        self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)

    def _check_input_dim(self, x):
        """Validate that ``x`` is a batched or unbatched input of shape ``in_size``.

        Raises
        ------
        ValueError
            If ``x`` has neither ``num_spatial_dims + 2`` nor
            ``num_spatial_dims + 1`` axes, or if its per-sample shape differs
            from ``in_size``.
        """
        if x.ndim == self.num_spatial_dims + 2:
            # batched input: drop the leading batch axis before comparing
            x_shape = x.shape[1:]
        elif x.ndim == self.num_spatial_dims + 1:
            x_shape = x.shape
        else:
            raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
                             f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
        if self.in_size != x_shape:
            raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")

    def update(self, x):
        """Apply the convolution to ``x``.

        An unbatched input gets a temporary leading batch axis, which is
        removed again from the result.
        """
        self._check_input_dim(x)
        non_batching = False
        if x.ndim == self.num_spatial_dims + 1:
            x = jnp.expand_dims(x, 0)
            non_batching = True
        y = self._conv_op(x, self.weight.value)
        return y[0] if non_batching else y

    def _conv_op(self, x, params):
        """Compute the convolution of a batched ``x`` with ``params``; implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'stride={self.stride}, '
                f'padding={self.padding}, '
                f'groups={self.groups})')
175
+
176
+
177
class _Conv(_BaseConv):
    """Plain N-dimensional convolution with an optional bias.

    The trainable values live in ``self.weight`` as a dict with a
    ``'weight'`` entry (the kernel) and, when ``b_init`` is given, a
    ``'bias'`` entry. ``out_size`` holds the per-sample output shape.
    """
    num_spatial_dims: int = None

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
    ):
        super().__init__(
            in_size=in_size,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            lhs_dilation=lhs_dilation,
            rhs_dilation=rhs_dilation,
            groups=groups,
            w_mask=w_mask,
            name=name,
        )

        self.w_initializer = w_init
        self.b_initializer = b_init

        # Kernel, plus a bias that broadcasts over the spatial axes when requested.
        kernel = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
        params = {'weight': kernel}
        if self.b_initializer is not None:
            bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
            params['bias'] = init.param(self.b_initializer, bias_shape, allow_none=True)

        # All trainable values are stored together in one state.
        self.weight = ParamState(params)

        # Trace the convolution abstractly to learn the output shape. The
        # leading 128 is only a placeholder batch size and is sliced off.
        probe = jax.ShapeDtypeStruct((128,) + self.in_size, kernel.dtype)
        traced = jax.eval_shape(self._conv_op, probe, params)
        self.out_size = traced.shape[1:]

    def _conv_op(self, x, params):
        """Convolve the batched input ``x`` with the (optionally masked) kernel and add the bias, if any."""
        kernel = params['weight']
        if self.w_mask is not None:
            kernel = kernel * self.w_mask
        out = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=kernel,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers
        )
        if 'bias' not in params:
            return out
        return out + params['bias']
246
+
247
+
248
class Conv1d(_Conv):
    """One-dimensional convolution.

    Batched inputs are 3d arrays of shape ``[B, H, C]`` (channels last);
    a single unbatched sample of shape ``[H, C]`` is accepted as well.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 1
    __module__ = 'brainstate.nn'
259
+
260
+
261
class Conv2d(_Conv):
    """Two-dimensional convolution.

    Batched inputs are 4d arrays of shape ``[B, H, W, C]`` (channels last);
    a single unbatched sample of shape ``[H, W, C]`` is accepted as well.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 2
    __module__ = 'brainstate.nn'
272
+
273
+
274
class Conv3d(_Conv):
    """Three-dimensional convolution.

    Batched inputs are 5d arrays of shape ``[B, H, W, D, C]`` (channels last);
    a single unbatched sample of shape ``[H, W, D, C]`` is accepted as well.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 3
    __module__ = 'brainstate.nn'
285
+
286
+
287
# Parameter documentation shared by Conv1d / Conv2d / Conv3d. It is spliced
# into the ``%s`` placeholder of each class docstring below.
_conv_doc = '''
    in_size: tuple of int
        The input shape, without the batch size. This argument is important, since it is
        used to evaluate the shape of the output.
    out_channels: int
        The number of output channels.
    kernel_size: int, sequence of int
        The shape of the convolutional kernel.
        For 1D convolution, the kernel size can be passed as an integer.
        For all other cases, it must be a sequence of integers.
    stride: int, sequence of int
        An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
    padding: str, int, sequence of int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an integer `p` (pads
        `(p, p)` on every spatial dimension), a single `(low, high)` pair applied
        to every spatial dimension, or a sequence of n `(low, high)` integer pairs
        that give the padding to apply before and after each spatial dimension.
    lhs_dilation: int, sequence of int
        An integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`
        (default: 1). Convolution with input dilation `d` is equivalent to
        transposed convolution with stride `d`.
    rhs_dilation: int, sequence of int
        An integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel (default: 1). Convolution with kernel dilation
        is also known as 'atrous convolution'.
    groups: int
        If specified, divides the input features into groups. default 1.
    w_init: Callable, ArrayLike, Initializer
        The initializer for the convolutional kernel.
    b_init: Optional, Callable, ArrayLike, Initializer
        The initializer for the bias.
    w_mask: ArrayLike, Callable, Optional
        The optional mask of the weights.
    name: str, Optional
        The name of the object.
'''

# Docstrings are stripped (``__doc__`` is None) when Python runs with ``-OO``;
# skip the substitution then instead of failing at import time.
for _cls in (Conv1d, Conv2d, Conv3d):
    if _cls.__doc__ is not None:
        _cls.__doc__ = _cls.__doc__ % _conv_doc
del _cls
330
+
331
+
332
class _ScaledWSConv(_BaseConv):
    """N-dimensional convolution whose kernel is weight-standardized before use.

    ``self.weight`` holds a dict with a ``'weight'`` entry (the kernel), an
    optional ``'bias'`` entry (when ``b_init`` is given) and an optional
    ``'gain'`` entry (when ``ws_gain`` is true). ``out_size`` holds the
    per-sample output shape.
    """

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        ws_gain: bool = True,
        eps: float = 1e-4,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
    ):
        super().__init__(
            in_size=in_size,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            lhs_dilation=lhs_dilation,
            rhs_dilation=rhs_dilation,
            groups=groups,
            w_mask=w_mask,
            name=name,
        )

        self.w_initializer = w_init
        self.b_initializer = b_init

        # Kernel, plus a bias that broadcasts over the spatial axes when requested.
        kernel = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
        params = {'weight': kernel}
        n_spatial = len(self.kernel_size)
        if self.b_initializer is not None:
            bias_shape = (1,) * n_spatial + (self.out_channels,)
            params['bias'] = init.param(self.b_initializer, bias_shape, allow_none=True)

        # One gain value per output channel, initialised to ones and shaped to
        # broadcast against the kernel.
        if ws_gain:
            gain_shape = (1,) * n_spatial + (1, self.out_channels)
            params['gain'] = jnp.ones(gain_shape, dtype=kernel.dtype)

        # Epsilon, a small constant to avoid dividing by zero.
        # It must be set before ``_conv_op`` is traced below.
        self.eps = eps

        # All trainable values are stored together in one state.
        self.weight = ParamState(params)

        # Trace the convolution abstractly to learn the output shape. The
        # leading 128 is only a placeholder batch size and is sliced off.
        probe = jax.ShapeDtypeStruct((128,) + self.in_size, kernel.dtype)
        traced = jax.eval_shape(self._conv_op, probe, params)
        self.out_size = traced.shape[1:]

    def _conv_op(self, x, params):
        """Standardize the kernel, apply the optional mask, convolve ``x`` and add the bias, if any."""
        gain = params.get('gain', None)
        kernel = functional.weight_standardization(params['weight'], self.eps, gain)
        if self.w_mask is not None:
            # the mask is applied after standardization
            kernel = kernel * self.w_mask
        out = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=kernel,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers
        )
        if 'bias' not in params:
            return out
        return out + params['bias']
411
+
412
+
413
class ScaledWSConv1d(_ScaledWSConv):
    """One-dimensional convolution with weight standardization.

    Batched inputs are 3d arrays of shape ``[B, H, C]`` (channels last);
    a single unbatched sample of shape ``[H, C]`` is accepted as well.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 1
    __module__ = 'brainstate.nn'
424
+
425
+
426
class ScaledWSConv2d(_ScaledWSConv):
    """Two-dimensional convolution with weight standardization.

    Batched inputs are 4d arrays of shape ``[B, H, W, C]`` (channels last);
    a single unbatched sample of shape ``[H, W, C]`` is accepted as well.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 2
    __module__ = 'brainstate.nn'
437
+
438
+
439
class ScaledWSConv3d(_ScaledWSConv):
    """Three-dimensional convolution with weight standardization.

    Batched inputs are 5d arrays of shape ``[B, H, W, D, C]`` (channels last);
    a single unbatched sample of shape ``[H, W, D, C]`` is accepted as well.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 3
    __module__ = 'brainstate.nn'
450
+
451
+
452
# Parameter documentation shared by ScaledWSConv1d / ScaledWSConv2d /
# ScaledWSConv3d. It is spliced into the ``%s`` placeholder of each class
# docstring below.
_ws_conv_doc = '''
    in_size: tuple of int
        The input shape, without the batch size. This argument is important, since it is
        used to evaluate the shape of the output.
    out_channels: int
        The number of output channels.
    kernel_size: int, sequence of int
        The shape of the convolutional kernel.
        For 1D convolution, the kernel size can be passed as an integer.
        For all other cases, it must be a sequence of integers.
    stride: int, sequence of int
        An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
    padding: str, int, sequence of int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an integer `p` (pads
        `(p, p)` on every spatial dimension), a single `(low, high)` pair applied
        to every spatial dimension, or a sequence of n `(low, high)` integer pairs
        that give the padding to apply before and after each spatial dimension.
    lhs_dilation: int, sequence of int
        An integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`
        (default: 1). Convolution with input dilation `d` is equivalent to
        transposed convolution with stride `d`.
    rhs_dilation: int, sequence of int
        An integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel (default: 1). Convolution with kernel dilation
        is also known as 'atrous convolution'.
    groups: int
        If specified, divides the input features into groups. default 1.
    w_init: Callable, ArrayLike, Initializer
        The initializer for the convolutional kernel.
    b_init: Optional, Callable, ArrayLike, Initializer
        The initializer for the bias.
    ws_gain: bool
        Whether to add a gain term for the weight standardization. The default is `True`.
    eps: float
        The epsilon value for numerical stability.
    w_mask: ArrayLike, Callable, Optional
        The optional mask of the weights.
    name: str, Optional
        The name of the object.
'''

# Docstrings are stripped (``__doc__`` is None) when Python runs with ``-OO``;
# skip the substitution then instead of failing at import time.
for _cls in (ScaledWSConv1d, ScaledWSConv2d, ScaledWSConv3d):
    if _cls.__doc__ is not None:
        _cls.__doc__ = _cls.__doc__ % _ws_conv_doc
del _cls