brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,686 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from __future__ import annotations
19
-
20
- import collections.abc
21
- import numbers
22
- from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
23
-
24
- import jax
25
- import jax.numpy as jnp
26
-
27
- from ._base import DnnLayer
28
- from brainstate import init, functional
29
- from brainstate._state import ParamState
30
- from brainstate.mixin import Mode
31
- from brainstate.typing import ArrayLike
32
-
33
- T = TypeVar('T')
34
-
35
- __all__ = [
36
- 'Linear', 'ScaledWSLinear', 'SignedWLinear', 'CSRLinear',
37
- 'Conv1d', 'Conv2d', 'Conv3d',
38
- 'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
39
- ]
40
-
41
-
42
- def to_dimension_numbers(
43
- num_spatial_dims: int,
44
- channels_last: bool,
45
- transpose: bool
46
- ) -> jax.lax.ConvDimensionNumbers:
47
- """Create a `lax.ConvDimensionNumbers` for the given inputs."""
48
- num_dims = num_spatial_dims + 2
49
- if channels_last:
50
- spatial_dims = tuple(range(1, num_dims - 1))
51
- image_dn = (0, num_dims - 1) + spatial_dims
52
- else:
53
- spatial_dims = tuple(range(2, num_dims))
54
- image_dn = (0, 1) + spatial_dims
55
- if transpose:
56
- kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
57
- else:
58
- kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
59
- return jax.lax.ConvDimensionNumbers(lhs_spec=image_dn,
60
- rhs_spec=kernel_dn,
61
- out_spec=image_dn)
62
-
63
-
64
- def replicate(
65
- element: Union[T, Sequence[T]],
66
- num_replicate: int,
67
- name: str,
68
- ) -> Tuple[T, ...]:
69
- """Replicates entry in `element` `num_replicate` if needed."""
70
- if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence):
71
- return (element,) * num_replicate
72
- elif len(element) == 1:
73
- return tuple(list(element) * num_replicate)
74
- elif len(element) == num_replicate:
75
- return tuple(element)
76
- else:
77
- raise TypeError(f"{name} must be a scalar or sequence of length 1 or "
78
- f"sequence of length {num_replicate}.")
79
-
80
-
81
- class Linear(DnnLayer):
82
- """
83
- Linear layer.
84
- """
85
- __module__ = 'brainstate.nn'
86
-
87
- def __init__(
88
- self,
89
- in_size: Union[int, Sequence[int]],
90
- out_size: Union[int, Sequence[int]],
91
- w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
92
- b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
93
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
94
- name: Optional[str] = None,
95
- mode: Optional[Mode] = None,
96
- ):
97
- super().__init__(name=name, mode=mode)
98
-
99
- # input and output shape
100
- self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
101
- self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
102
-
103
- # w_mask
104
- self.w_mask = init.param(w_mask, self.in_size + self.out_size)
105
-
106
- # weights
107
- params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
108
- if b_init is not None:
109
- params['bias'] = init.param(b_init, self.out_size, allow_none=False)
110
-
111
- # weight + op
112
- self.W = ParamState(params)
113
-
114
- def update(self, x):
115
- params = self.W.value
116
- weight = params['weight']
117
- if self.w_mask is not None:
118
- weight = weight * self.w_mask
119
- y = jnp.dot(x, weight)
120
- if 'bias' in params:
121
- y = y + params['bias']
122
- return y
123
-
124
-
125
- class SignedWLinear(DnnLayer):
126
- """
127
- Linear layer with signed weights.
128
- """
129
- __module__ = 'brainstate.nn'
130
-
131
- def __init__(
132
- self,
133
- in_size: Union[int, Sequence[int]],
134
- out_size: Union[int, Sequence[int]],
135
- w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
136
- w_sign: Optional[ArrayLike] = None,
137
- name: Optional[str] = None,
138
- mode: Optional[Mode] = None
139
- ):
140
- super().__init__(name=name, mode=mode)
141
-
142
- # input and output shape
143
- self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
144
- self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
145
-
146
- # w_mask
147
- self.w_sign = w_sign
148
-
149
- # weights
150
- weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
151
- self.W = ParamState(weight)
152
-
153
- def _operation(self, x, w):
154
- if self.w_sign is None:
155
- return jnp.matmul(x, jnp.abs(w))
156
- else:
157
- return jnp.matmul(x, jnp.abs(w) * self.w_sign)
158
-
159
- def update(self, x):
160
- return self._operation(x, self.W.value)
161
-
162
-
163
- class ScaledWSLinear(DnnLayer):
164
- """
165
- Linear Layer with Weight Standardization.
166
-
167
- Applies weight standardization to the weights of the linear layer.
168
-
169
- Parameters
170
- ----------
171
- in_size: int, sequence of int
172
- The input size.
173
- out_size: int, sequence of int
174
- The output size.
175
- w_init: Callable, ArrayLike
176
- The initializer for the weights.
177
- b_init: Callable, ArrayLike
178
- The initializer for the bias.
179
- w_mask: ArrayLike, Callable
180
- The optional mask of the weights.
181
- as_etrace_weight: bool
182
- Whether to use ETraceParamOp for the weights.
183
- ws_gain: bool
184
- Whether to use gain for the weights. The default is True.
185
- eps: float
186
- The epsilon value for the weight standardization.
187
- name: str
188
- The name of the object.
189
- mode: Mode
190
- The computation mode of the current object.
191
-
192
- """
193
- __module__ = 'brainstate.nn'
194
-
195
- def __init__(
196
- self,
197
- in_size: Union[int, Sequence[int]],
198
- out_size: Union[int, Sequence[int]],
199
- w_init: Callable = init.KaimingNormal(),
200
- b_init: Callable = init.ZeroInit(),
201
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
202
- as_etrace_weight: bool = True,
203
- full_etrace: bool = False,
204
- ws_gain: bool = True,
205
- eps: float = 1e-4,
206
- name: str = None,
207
- mode: Mode = None
208
- ):
209
- super().__init__(name=name, mode=mode)
210
-
211
- # input and output shape
212
- self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
213
- self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
214
-
215
- # w_mask
216
- self.w_mask = init.param(w_mask, (self.in_size[0], 1))
217
-
218
- # parameters
219
- self.eps = eps
220
-
221
- # weights
222
- params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
223
- if b_init is not None:
224
- params['bias'] = init.param(b_init, self.out_size, allow_none=False)
225
- # gain
226
- if ws_gain:
227
- s = params['weight'].shape
228
- params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
229
-
230
- # weight operation
231
- self.W = ParamState(params)
232
-
233
- def update(self, x):
234
- return self._operation(x, self.W.value)
235
-
236
- def _operation(self, x, params):
237
- w = params['weight']
238
- w = functional.weight_standardization(w, self.eps, params.get('gain', None))
239
- if self.w_mask is not None:
240
- w = w * self.w_mask
241
- y = jnp.dot(x, w)
242
- if 'bias' in params:
243
- y = y + params['bias']
244
- return y
245
-
246
-
247
- class CSRLinear(DnnLayer):
248
- __module__ = 'brainstate.nn'
249
-
250
-
251
- class _BaseConv(DnnLayer):
252
- # the number of spatial dimensions
253
- num_spatial_dims: int
254
-
255
- # the weight and its operations
256
- W: ParamState
257
-
258
- def __init__(
259
- self,
260
- in_size: Sequence[int],
261
- out_channels: int,
262
- kernel_size: Union[int, Tuple[int, ...]],
263
- stride: Union[int, Tuple[int, ...]] = 1,
264
- padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
265
- lhs_dilation: Union[int, Tuple[int, ...]] = 1,
266
- rhs_dilation: Union[int, Tuple[int, ...]] = 1,
267
- groups: int = 1,
268
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
269
- mode: Mode = None,
270
- name: str = None,
271
- ):
272
- super().__init__(name=name, mode=mode)
273
-
274
- # general parameters
275
- assert self.num_spatial_dims + 1 == len(in_size)
276
- self.in_size = tuple(in_size)
277
- self.in_channels = in_size[-1]
278
- self.out_channels = out_channels
279
- self.stride = replicate(stride, self.num_spatial_dims, 'stride')
280
- self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
281
- self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
282
- self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
283
- self.groups = groups
284
- self.dimension_numbers = to_dimension_numbers(self.num_spatial_dims, channels_last=True, transpose=False)
285
-
286
- # the padding parameter
287
- if isinstance(padding, str):
288
- assert padding in ['SAME', 'VALID']
289
- elif isinstance(padding, int):
290
- padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
291
- elif isinstance(padding, (tuple, list)):
292
- if isinstance(padding[0], int):
293
- padding = (padding,) * self.num_spatial_dims
294
- elif isinstance(padding[0], (tuple, list)):
295
- if len(padding) == 1:
296
- padding = tuple(padding) * self.num_spatial_dims
297
- else:
298
- if len(padding) != self.num_spatial_dims:
299
- raise ValueError(
300
- f"Padding {padding} must be a Tuple[int, int], "
301
- f"or sequence of Tuple[int, int] with length 1, "
302
- f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
303
- )
304
- padding = tuple(padding)
305
- else:
306
- raise ValueError
307
- self.padding = padding
308
-
309
- # the number of in-/out-channels
310
- assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
311
- assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
312
-
313
- # kernel shape and w_mask
314
- kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
315
- self.kernel_shape = kernel_shape
316
- self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)
317
-
318
- def _check_input_dim(self, x):
319
- if x.ndim == self.num_spatial_dims + 2:
320
- x_shape = x.shape[1:]
321
- elif x.ndim == self.num_spatial_dims + 1:
322
- x_shape = x.shape
323
- else:
324
- raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
325
- f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
326
- if self.in_size != x_shape:
327
- raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
328
-
329
- def update(self, x):
330
- self._check_input_dim(x)
331
- non_batching = False
332
- if x.ndim == self.num_spatial_dims + 1:
333
- x = jnp.expand_dims(x, 0)
334
- non_batching = True
335
- y = self._conv_op(x, self.W.value)
336
- return y[0] if non_batching else y
337
-
338
- def _conv_op(self, x, params):
339
- raise NotImplementedError
340
-
341
- def __repr__(self):
342
- return (f'{self.__class__.__name__}('
343
- f'in_channels={self.in_channels}, '
344
- f'out_channels={self.out_channels}, '
345
- f'kernel_size={self.kernel_size}, '
346
- f'stride={self.stride}, '
347
- f'padding={self.padding}, '
348
- f'groups={self.groups})')
349
-
350
-
351
- class _Conv(_BaseConv):
352
- num_spatial_dims: int = None
353
-
354
- def __init__(
355
- self,
356
- in_size: Sequence[int],
357
- out_channels: int,
358
- kernel_size: Union[int, Tuple[int, ...]],
359
- stride: Union[int, Tuple[int, ...]] = 1,
360
- padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
361
- lhs_dilation: Union[int, Tuple[int, ...]] = 1,
362
- rhs_dilation: Union[int, Tuple[int, ...]] = 1,
363
- groups: int = 1,
364
- w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
365
- b_init: Optional[Union[Callable, ArrayLike]] = None,
366
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
367
- mode: Mode = None,
368
- name: str = None,
369
- ):
370
- super().__init__(in_size=in_size,
371
- out_channels=out_channels,
372
- kernel_size=kernel_size,
373
- stride=stride,
374
- padding=padding,
375
- lhs_dilation=lhs_dilation,
376
- rhs_dilation=rhs_dilation,
377
- groups=groups,
378
- w_mask=w_mask,
379
- name=name,
380
- mode=mode)
381
-
382
- self.w_initializer = w_init
383
- self.b_initializer = b_init
384
-
385
- # --- weights --- #
386
- weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
387
- params = dict(weight=weight)
388
- if self.b_initializer is not None:
389
- bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
390
- bias = init.param(self.b_initializer, bias_shape, allow_none=True)
391
- params['bias'] = bias
392
-
393
- # The weight operation
394
- self.W = ParamState(params)
395
-
396
- # Evaluate the output shape
397
- abstract_y = jax.eval_shape(
398
- self._conv_op,
399
- jax.ShapeDtypeStruct((128,) + self.in_size, weight.dtype),
400
- params
401
- )
402
- y_shape = abstract_y.shape[1:]
403
- self.out_size = y_shape
404
-
405
- def _conv_op(self, x, params):
406
- w = params['weight']
407
- if self.w_mask is not None:
408
- w = w * self.w_mask
409
- y = jax.lax.conv_general_dilated(
410
- lhs=x,
411
- rhs=w,
412
- window_strides=self.stride,
413
- padding=self.padding,
414
- lhs_dilation=self.lhs_dilation,
415
- rhs_dilation=self.rhs_dilation,
416
- feature_group_count=self.groups,
417
- dimension_numbers=self.dimension_numbers
418
- )
419
- if 'bias' in params:
420
- y = y + params['bias']
421
- return y
422
-
423
- def __repr__(self):
424
- return (f'{self.__class__.__name__}('
425
- f'in_channels={self.in_channels}, '
426
- f'out_channels={self.out_channels}, '
427
- f'kernel_size={self.kernel_size}, '
428
- f'stride={self.stride}, '
429
- f'padding={self.padding}, '
430
- f'groups={self.groups})')
431
-
432
-
433
- class Conv1d(_Conv):
434
- """One-dimensional convolution.
435
-
436
- The input should be a 3d array with the shape of ``[B, H, C]``.
437
-
438
- Parameters
439
- ----------
440
- %s
441
- """
442
- __module__ = 'brainstate.nn'
443
- num_spatial_dims: int = 1
444
-
445
-
446
- class Conv2d(_Conv):
447
- """Two-dimensional convolution.
448
-
449
- The input should be a 4d array with the shape of ``[B, H, W, C]``.
450
-
451
- Parameters
452
- ----------
453
- %s
454
- """
455
- __module__ = 'brainstate.nn'
456
- num_spatial_dims: int = 2
457
-
458
-
459
- class Conv3d(_Conv):
460
- """Three-dimensional convolution.
461
-
462
- The input should be a 5d array with the shape of ``[B, H, W, D, C]``.
463
-
464
- Parameters
465
- ----------
466
- %s
467
- """
468
- __module__ = 'brainstate.nn'
469
- num_spatial_dims: int = 3
470
-
471
-
472
- _conv_doc = '''
473
- in_size: tuple of int
474
- The input shape, without the batch size. This argument is important, since it is
475
- used to evaluate the shape of the output.
476
- out_channels: int
477
- The number of output channels.
478
- kernel_size: int, sequence of int
479
- The shape of the convolutional kernel.
480
- For 1D convolution, the kernel size can be passed as an integer.
481
- For all other cases, it must be a sequence of integers.
482
- stride: int, sequence of int
483
- An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
484
- padding: str, int, sequence of int, sequence of tuple
485
- Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
486
- high)` integer pairs that give the padding to apply before and after each
487
- spatial dimension.
488
- lhs_dilation: int, sequence of int
489
- An integer or a sequence of `n` integers, giving the
490
- dilation factor to apply in each spatial dimension of `inputs`
491
- (default: 1). Convolution with input dilation `d` is equivalent to
492
- transposed convolution with stride `d`.
493
- rhs_dilation: int, sequence of int
494
- An integer or a sequence of `n` integers, giving the
495
- dilation factor to apply in each spatial dimension of the convolution
496
- kernel (default: 1). Convolution with kernel dilation
497
- is also known as 'atrous convolution'.
498
- groups: int
499
- If specified, divides the input features into groups. default 1.
500
- w_init: Callable, ArrayLike, Initializer
501
- The initializer for the convolutional kernel.
502
- b_init: Optional, Callable, ArrayLike, Initializer
503
- The initializer for the bias.
504
- w_mask: ArrayLike, Callable, Optional
505
- The optional mask of the weights.
506
- mode: Mode
507
- The computation mode of the current object. Default it is `training`.
508
- name: str, Optional
509
- The name of the object.
510
- '''
511
-
512
- Conv1d.__doc__ = Conv1d.__doc__ % _conv_doc
513
- Conv2d.__doc__ = Conv2d.__doc__ % _conv_doc
514
- Conv3d.__doc__ = Conv3d.__doc__ % _conv_doc
515
-
516
-
517
- class _ScaledWSConv(_BaseConv):
518
- def __init__(
519
- self,
520
- in_size: Sequence[int],
521
- out_channels: int,
522
- kernel_size: Union[int, Tuple[int, ...]],
523
- stride: Union[int, Tuple[int, ...]] = 1,
524
- padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
525
- lhs_dilation: Union[int, Tuple[int, ...]] = 1,
526
- rhs_dilation: Union[int, Tuple[int, ...]] = 1,
527
- groups: int = 1,
528
- ws_gain: bool = True,
529
- eps: float = 1e-4,
530
- w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
531
- b_init: Optional[Union[Callable, ArrayLike]] = None,
532
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
533
- mode: Mode = None,
534
- name: str = None,
535
- ):
536
- super().__init__(in_size=in_size,
537
- out_channels=out_channels,
538
- kernel_size=kernel_size,
539
- stride=stride,
540
- padding=padding,
541
- lhs_dilation=lhs_dilation,
542
- rhs_dilation=rhs_dilation,
543
- groups=groups,
544
- w_mask=w_mask,
545
- name=name,
546
- mode=mode)
547
-
548
- self.w_initializer = w_init
549
- self.b_initializer = b_init
550
-
551
- # --- weights --- #
552
- weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
553
- params = dict(weight=weight)
554
- if self.b_initializer is not None:
555
- bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
556
- bias = init.param(self.b_initializer, bias_shape, allow_none=True)
557
- params['bias'] = bias
558
-
559
- # gain
560
- if ws_gain:
561
- gain_size = (1,) * len(self.kernel_size) + (1, self.out_channels)
562
- ws_gain = jnp.ones(gain_size, dtype=params['weight'].dtype)
563
- params['gain'] = ws_gain
564
-
565
- # Epsilon, a small constant to avoid dividing by zero.
566
- self.eps = eps
567
-
568
- # The weight operation
569
- self.W = ParamState(params)
570
-
571
- # Evaluate the output shape
572
- abstract_y = jax.eval_shape(
573
- self._conv_op,
574
- jax.ShapeDtypeStruct((128,) + self.in_size, weight.dtype),
575
- params
576
- )
577
- y_shape = abstract_y.shape[1:]
578
- self.out_size = y_shape
579
-
580
- def _conv_op(self, x, params):
581
- w = params['weight']
582
- w = functional.weight_standardization(w, self.eps, params.get('gain', None))
583
- if self.w_mask is not None:
584
- w = w * self.w_mask
585
- y = jax.lax.conv_general_dilated(
586
- lhs=x,
587
- rhs=w,
588
- window_strides=self.stride,
589
- padding=self.padding,
590
- lhs_dilation=self.lhs_dilation,
591
- rhs_dilation=self.rhs_dilation,
592
- feature_group_count=self.groups,
593
- dimension_numbers=self.dimension_numbers
594
- )
595
- if 'bias' in params:
596
- y = y + params['bias']
597
- return y
598
-
599
-
600
- class ScaledWSConv1d(_ScaledWSConv):
601
- """One-dimensional convolution with weight standardization.
602
-
603
- The input should be a 3d array with the shape of ``[B, H, C]``.
604
-
605
- Parameters
606
- ----------
607
- %s
608
- """
609
- __module__ = 'brainstate.nn'
610
- num_spatial_dims: int = 1
611
-
612
-
613
- class ScaledWSConv2d(_ScaledWSConv):
614
- """Two-dimensional convolution with weight standardization.
615
-
616
- The input should be a 4d array with the shape of ``[B, H, W, C]``.
617
-
618
- Parameters
619
- ----------
620
- %s
621
- """
622
- __module__ = 'brainstate.nn'
623
- num_spatial_dims: int = 2
624
-
625
-
626
- class ScaledWSConv3d(_ScaledWSConv):
627
- """Three-dimensional convolution with weight standardization.
628
-
629
- The input should be a 5d array with the shape of ``[B, H, W, D, C]``.
630
-
631
- Parameters
632
- ----------
633
- %s
634
- """
635
- __module__ = 'brainstate.nn'
636
- num_spatial_dims: int = 3
637
-
638
-
639
- _ws_conv_doc = '''
640
- in_size: tuple of int
641
- The input shape, without the batch size. This argument is important, since it is
642
- used to evaluate the shape of the output.
643
- out_channels: int
644
- The number of output channels.
645
- kernel_size: int, sequence of int
646
- The shape of the convolutional kernel.
647
- For 1D convolution, the kernel size can be passed as an integer.
648
- For all other cases, it must be a sequence of integers.
649
- stride: int, sequence of int
650
- An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
651
- padding: str, int, sequence of int, sequence of tuple
652
- Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
653
- high)` integer pairs that give the padding to apply before and after each
654
- spatial dimension.
655
- lhs_dilation: int, sequence of int
656
- An integer or a sequence of `n` integers, giving the
657
- dilation factor to apply in each spatial dimension of `inputs`
658
- (default: 1). Convolution with input dilation `d` is equivalent to
659
- transposed convolution with stride `d`.
660
- rhs_dilation: int, sequence of int
661
- An integer or a sequence of `n` integers, giving the
662
- dilation factor to apply in each spatial dimension of the convolution
663
- kernel (default: 1). Convolution with kernel dilation
664
- is also known as 'atrous convolution'.
665
- groups: int
666
- If specified, divides the input features into groups. default 1.
667
- w_init: Callable, ArrayLike, Initializer
668
- The initializer for the convolutional kernel.
669
- b_init: Optional, Callable, ArrayLike, Initializer
670
- The initializer for the bias.
671
- ws_gain: bool
672
- Whether to add a gain term for the weight standarization. The default is `True`.
673
- eps: float
674
- The epsilon value for numerical stability.
675
- w_mask: ArrayLike, Callable, Optional
676
- The optional mask of the weights.
677
- mode: Mode
678
- The computation mode of the current object. Default it is `training`.
679
- name: str, Optional
680
- The name of the object.
681
-
682
- '''
683
-
684
- ScaledWSConv1d.__doc__ = ScaledWSConv1d.__doc__ % _ws_conv_doc
685
- ScaledWSConv2d.__doc__ = ScaledWSConv2d.__doc__ % _ws_conv_doc
686
- ScaledWSConv3d.__doc__ = ScaledWSConv3d.__doc__ % _ws_conv_doc