brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1179 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import functools
+ from typing import Callable, List, Optional, Sequence, Tuple, Union
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate import environ
+ from brainstate.nn._module import Module
+ from brainstate.typing import Size
+
+ __all__ = [
+     'Flatten', 'Unflatten',
+
+     'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
+     'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
+
+     'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
+     'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
+ ]
+
+
+ class Flatten(Module):
+     r"""
+     Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
+
+     Shape:
+         - Input: :math:`(*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)`,
+           where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
+           number of dimensions including none.
+         - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
+
+     Args:
+         in_size: Sequence of int. The shape of the input tensor.
+         start_axis: first dim to flatten (default = 0).
+         end_axis: last dim to flatten (default = -1).
+
+     Examples::
+
+         >>> import brainstate as bst
+         >>> inp = bst.random.randn(32, 1, 5, 5)
+         >>> # flatten everything after the batch axis
+         >>> m = Flatten(1)
+         >>> output = m(inp)
+         >>> output.shape
+         (32, 25)
+         >>> # flatten a custom range of axes
+         >>> m = Flatten(0, 2)
+         >>> output = m(inp)
+         >>> output.shape
+         (160, 5)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         start_axis: int = 0,
+         end_axis: int = -1,
+         in_size: Optional[Size] = None
+     ) -> None:
+         super().__init__()
+         self.start_axis = start_axis
+         self.end_axis = end_axis
+
+         if in_size is not None:
+             self.in_size = tuple(in_size)
+             y = jax.eval_shape(
+                 functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
+                 jax.ShapeDtypeStruct(self.in_size, environ.dftype())
+             )
+             self.out_size = y.shape
+
+     def update(self, x):
+         if self._in_size is None:
+             start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
+         else:
+             assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
+             dim_diff = x.ndim - len(self.in_size)
+             if self.in_size != x.shape[dim_diff:]:
+                 raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
+             if self.start_axis >= 0:
+                 start_axis = self.start_axis + dim_diff
+             else:
+                 start_axis = x.ndim + self.start_axis
+         return u.math.flatten(x, start_axis, self.end_axis)
+
+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
+
+
+ class Unflatten(Module):
+     r"""
+     Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~nn.Sequential`.
+
+     * :attr:`axis` specifies the dimension of the input tensor to be unflattened.
+
+     * :attr:`sizes` is the new shape of the unflattened dimension and can be
+       a `tuple` of ints or a `list` of ints.
+
+     Shape:
+         - Input: :math:`(*, S_{\text{axis}}, *)`, where :math:`S_{\text{axis}}` is the size at
+           dimension :attr:`axis` and :math:`*` means any number of dimensions including none.
+         - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
+           :math:`\prod_{i=1}^n U_i = S_{\text{axis}}`.
+
+     Args:
+         axis: int, Dimension to be unflattened.
+         sizes: Sequence of int. New shape of the unflattened dimension.
+         in_size: Sequence of int. The shape of the input tensor.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         axis: int,
+         sizes: Size,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None
+     ) -> None:
+         super().__init__(name=name)
+
+         self.axis = axis
+         self.sizes = sizes
+         if isinstance(sizes, (tuple, list)):
+             for idx, elem in enumerate(sizes):
+                 if not isinstance(elem, int):
+                     raise TypeError(f"unflattened sizes must be a tuple of ints, "
+                                     f"but found element of type {type(elem).__name__} at pos {idx}")
+         else:
+             raise TypeError(f"unflattened sizes must be a tuple or list, but found type {type(sizes).__name__}")
+
+         if in_size is not None:
+             self.in_size = tuple(in_size)
+             y = jax.eval_shape(
+                 functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
+                 jax.ShapeDtypeStruct(self.in_size, environ.dftype())
+             )
+             self.out_size = y.shape
+
+     def update(self, x):
+         return u.math.unflatten(x, self.axis, self.sizes)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
+
+
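For orientation, a minimal round-trip sketch of the two layers above (assuming they are re-exported as `brainstate.nn.Flatten` and `brainstate.nn.Unflatten`, as their `__module__` indicates):

import brainstate as bst

x = bst.random.randn(4, 3, 8)                    # (batch, groups, features)
flat = bst.nn.Flatten(start_axis=1)              # merge axes 1..-1 -> (4, 24)
unflat = bst.nn.Unflatten(axis=1, sizes=(3, 8))  # split axis 1 back -> (4, 3, 8)
y = unflat(flat(x))
assert y.shape == x.shape
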
+ class _MaxPool(Module):
+     def __init__(
+         self,
+         init_value: float,
+         computation: Callable,
+         pool_dim: int,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(name=name)
+
+         self.init_value = init_value
+         self.computation = computation
+         self.pool_dim = pool_dim
+
+         # kernel_size
+         if isinstance(kernel_size, int):
+             kernel_size = (kernel_size,) * pool_dim
+         elif isinstance(kernel_size, Sequence):
+             assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
+             assert all(isinstance(x, int) for x in kernel_size), f'kernel_size should be a tuple of ints. {kernel_size}'
+             if len(kernel_size) != pool_dim:
+                 raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, but got {len(kernel_size)}')
+         else:
+             raise TypeError(f'kernel_size should be an int or a tuple with {pool_dim} ints.')
+         self.kernel_size = kernel_size
+
+         # stride
+         if stride is None:
+             stride = kernel_size
+         if isinstance(stride, int):
+             stride = (stride,) * pool_dim
+         elif isinstance(stride, Sequence):
+             assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
+             assert all(isinstance(x, int) for x in stride), f'stride should be a tuple of ints. {stride}'
+             if len(stride) != pool_dim:
+                 raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
+         else:
+             raise TypeError(f'stride should be an int or a tuple with {pool_dim} ints.')
+         self.stride = stride
+
+         # padding
+         if isinstance(padding, str):
+             if padding not in ("SAME", "VALID"):
+                 raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
+         elif isinstance(padding, int):
+             padding = [(padding, padding) for _ in range(pool_dim)]
+         elif isinstance(padding, (list, tuple)):
+             if isinstance(padding[0], int):
+                 if len(padding) == pool_dim:
+                     padding = [(x, x) for x in padding]
+                 else:
+                     raise ValueError(f'If padding is a sequence of ints, it '
+                                      f'should have length {pool_dim}.')
+             else:
+                 if not all(isinstance(x, (tuple, list)) for x in padding):
+                     raise ValueError(f'padding should be a sequence of Tuple[int, int]. {padding}')
+                 if not all(len(x) == 2 for x in padding):
+                     raise ValueError(f"Each entry in padding must be a tuple of 2 ints. {padding}")
+                 if len(padding) == 1:
+                     padding = tuple(padding) * pool_dim
+                 assert len(padding) == pool_dim, f'padding should have length {pool_dim}. {padding}'
+         else:
+             raise ValueError(f'padding should be a str, an int, or a sequence of tuples, but got {padding}')
+         self.padding = padding
+
+         # channel_axis
+         assert channel_axis is None or isinstance(channel_axis, int), \
+             f'channel_axis should be an int, but got {channel_axis}'
+         self.channel_axis = channel_axis
+
+         # in & out shapes
+         if in_size is not None:
+             in_size = tuple(in_size)
+             self.in_size = in_size
+             y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
+             self.out_size = y.shape[1:]
+
+     def update(self, x):
+         x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
+         if x.ndim < x_dim:
+             raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
+         window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
+         stride = self._infer_shape(x.ndim, self.stride, 1)
+         padding = (self.padding if isinstance(self.padding, str) else
+                    self._infer_shape(x.ndim, self.padding, element=(0, 0)))
+         r = jax.lax.reduce_window(
+             x,
+             init_value=self.init_value,
+             computation=self.computation,
+             window_dimensions=window_shape,
+             window_strides=stride,
+             padding=padding
+         )
+         return r
+
+     def _infer_shape(self, x_dim, inputs, element):
+         # Broadcast the per-pool-axis values (kernel, stride, padding) to all
+         # input axes, leaving batch and channel axes at the neutral `element`.
+         channel_axis = self.channel_axis
+         if channel_axis is not None and not 0 <= abs(channel_axis) < x_dim:
+             raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
+         if channel_axis is not None and channel_axis < 0:
+             channel_axis = x_dim + channel_axis
+         all_dims = list(range(x_dim))
+         if channel_axis is not None:
+             all_dims.pop(channel_axis)
+         pool_dims = all_dims[-self.pool_dim:]
+         results = [element] * x_dim
+         for i, dim in enumerate(pool_dims):
+             results[dim] = inputs[i]
+         return results
+
+
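Both pooling families reduce to a single `jax.lax.reduce_window` call; the only difference is the monoid (`max` with identity `-inf` versus `add` with identity `0`). A minimal 1D sketch of the call that `_MaxPool.update` builds:

import jax
import jax.numpy as jnp

# 1D max pooling of a (length,) signal with window 3 and stride 2,
# written directly against the primitive used above.
x = jnp.array([1., 5., 2., 0., 3., 4., 1.])
pooled = jax.lax.reduce_window(
    x,
    init_value=-jnp.inf,          # identity element for `max`
    computation=jax.lax.max,
    window_dimensions=(3,),
    window_strides=(2,),
    padding='VALID',
)
print(pooled)  # [5. 3. 4.]
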
+ class _AvgPool(_MaxPool):
+     def update(self, x):
+         x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
+         if x.ndim < x_dim:
+             raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
+         dims = self._infer_shape(x.ndim, self.kernel_size, 1)
+         stride = self._infer_shape(x.ndim, self.stride, 1)
+         padding = (self.padding if isinstance(self.padding, str) else
+                    self._infer_shape(x.ndim, self.padding, element=(0, 0)))
+         pooled = jax.lax.reduce_window(x,
+                                        init_value=self.init_value,
+                                        computation=self.computation,
+                                        window_dimensions=dims,
+                                        window_strides=stride,
+                                        padding=padding)
+         if padding == "VALID":
+             # Avoid the extra reduce_window.
+             return pooled / np.prod(dims)
+         else:
+             # Count the number of valid entries at each input point, then use that
+             # for computing the average. Assumes that any two arrays of the same
+             # shape will be padded the same way.
+             window_counts = jax.lax.reduce_window(jnp.ones_like(x),
+                                                   init_value=self.init_value,
+                                                   computation=self.computation,
+                                                   window_dimensions=dims,
+                                                   window_strides=stride,
+                                                   padding=padding)
+             assert pooled.shape == window_counts.shape
+             return pooled / window_counts
+
+
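The count-normalization branch above only matters for `'SAME'` padding, where edge windows overlap fewer real entries; dividing by a per-position count instead of `prod(dims)` keeps edge averages unbiased. A plain-JAX sketch of the same trick:

import jax
import jax.numpy as jnp

x = jnp.ones((4,))
sums = jax.lax.reduce_window(x, 0., jax.lax.add,
                             window_dimensions=(3,), window_strides=(1,),
                             padding='SAME')
counts = jax.lax.reduce_window(jnp.ones_like(x), 0., jax.lax.add,
                               window_dimensions=(3,), window_strides=(1,),
                               padding='SAME')
print(sums / counts)  # [1. 1. 1. 1.]: edge windows divide by 2, interior by 3
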
+ class MaxPool1d(_MaxPool):
+     r"""Applies a 1D max pooling over an input signal composed of several input planes.
+
+     In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
+     and output :math:`(N, L_{out}, C)` can be precisely described as:
+
+     .. math::
+         out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
+                 input(N_i, stride \times k + m, C_j)
+
+     If :attr:`padding` is non-zero, then the input is implicitly padded with negative
+     infinity on both sides for :attr:`padding` number of points. This `link`_ has a
+     nice visualization of the pooling parameters.
+
+     Shape:
+         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+
+           .. math::
+               L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding}
+                     - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
+
+     Examples::
+
+         >>> import brainstate as bst
+         >>> # pool of size=3, stride=2
+         >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
+         >>> input = bst.random.randn(20, 50, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 16)
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     kernel_size: int, sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride: int, sequence of int
+         An integer, or a sequence of integers, representing the inter-window stride (default: the kernel size).
+     padding: str, int, sequence of tuple
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+
+     .. _link:
+         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          init_value=-jnp.inf,
+                          computation=jax.lax.max,
+                          pool_dim=1,
+                          kernel_size=kernel_size,
+                          stride=stride,
+                          padding=padding,
+                          channel_axis=channel_axis,
+                          name=name)
+
+
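When `in_size` is given, each layer derives `out_size` by tracing its own `update` with `jax.eval_shape` on a dummy batch of 128, so output shapes are known without running any computation. The same idea in isolation (the `pool1d` helper below is illustrative, not part of the API):

import jax
import jax.numpy as jnp

def pool1d(x):
    # the reduce_window call MaxPool1d(3, stride=2, channel_axis=-1) would make
    return jax.lax.reduce_window(x, -jnp.inf, jax.lax.max,
                                 (1, 3, 1), (1, 2, 1), 'VALID')

# abstract evaluation: shapes and dtypes only, nothing is executed
y = jax.eval_shape(pool1d, jax.ShapeDtypeStruct((128, 50, 16), jnp.float32))
print(y.shape[1:])  # (24, 16) -> becomes the layer's `out_size`
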
+ class MaxPool2d(_MaxPool):
+     r"""Applies a 2D max pooling over an input signal composed of several input planes.
+
+     In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
+     output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
+     can be precisely described as:
+
+     .. math::
+         \begin{aligned}
+             out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
+                                     & \text{input}(N_i, \text{stride[0]} \times h + m,
+                                                    \text{stride[1]} \times w + n, C_j)
+         \end{aligned}
+
+     If :attr:`padding` is non-zero, then the input is implicitly padded with negative
+     infinity on both sides for :attr:`padding` number of points. This `link`_ has a
+     nice visualization of the pooling parameters.
+
+     Shape:
+         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
+         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+
+           .. math::
+               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
+                     \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+           .. math::
+               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
+                     \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+     Examples::
+
+         >>> import brainstate as bst
+         >>> # pool of square window of size=3, stride=2
+         >>> m = MaxPool2d(3, stride=2)
+         >>> # pool of non-square window
+         >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
+         >>> input = bst.random.randn(20, 50, 32, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 31, 16)
+
+     .. _link:
+         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     kernel_size: int, sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride: int, sequence of int
+         An integer, or a sequence of integers, representing the inter-window stride (default: the kernel size).
+     padding: str, int, sequence of tuple
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          init_value=-jnp.inf,
+                          computation=jax.lax.max,
+                          pool_dim=2,
+                          kernel_size=kernel_size,
+                          stride=stride,
+                          padding=padding,
+                          channel_axis=channel_axis,
+                          name=name)
+
+
+ class MaxPool3d(_MaxPool):
+     r"""Applies a 3D max pooling over an input signal composed of several input planes.
+
+     In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
+     output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
+     can be precisely described as:
+
+     .. math::
+         \begin{aligned}
+             \text{out}(N_i, d, h, w, C_j) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
+                                               & \text{input}(N_i, \text{stride[0]} \times d + k,
+                                                              \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
+         \end{aligned}
+
+     If :attr:`padding` is non-zero, then the input is implicitly padded with negative
+     infinity on both sides for :attr:`padding` number of points. This `link`_ has a
+     nice visualization of the pooling parameters.
+
+     Shape:
+         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
+
+           .. math::
+               D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
+                     \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+           .. math::
+               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
+                     \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+           .. math::
+               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
+                     \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
+
+     Examples::
+
+         >>> import brainstate as bst
+         >>> # pool of cubic window of size=3, stride=2
+         >>> m = MaxPool3d(3, stride=2)
+         >>> # pool of non-cubic window
+         >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
+         >>> input = bst.random.randn(20, 50, 44, 31, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 43, 15, 16)
+
+     .. _link:
+         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     kernel_size: int, sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride: int, sequence of int
+         An integer, or a sequence of integers, representing the inter-window stride (default: the kernel size).
+     padding: str, int, sequence of tuple
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          init_value=-jnp.inf,
+                          computation=jax.lax.max,
+                          pool_dim=3,
+                          kernel_size=kernel_size,
+                          stride=stride,
+                          padding=padding,
+                          channel_axis=channel_axis,
+                          name=name)
+
+
+ class AvgPool1d(_AvgPool):
+     r"""Applies a 1D average pooling over an input signal composed of several input planes.
+
+     In the simplest case, the output value of the layer with input size :math:`(N, L, C)`,
+     output :math:`(N, L_{out}, C)` and :attr:`kernel_size` :math:`k`
+     can be precisely described as:
+
+     .. math::
+
+         \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
+                                \text{input}(N_i, \text{stride} \times l + m, C_j)
+
+     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+     for :attr:`padding` number of points.
+
+     Shape:
+         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+
+           .. math::
+               L_{out} = \left\lfloor \frac{L_{in} +
+                     2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
+
+     Examples::
+
+         >>> import brainstate as bst
+         >>> # pool with window of size=3, stride=2
+         >>> m = AvgPool1d(3, stride=2)
+         >>> input = bst.random.randn(20, 50, 16)
+         >>> m(input).shape
+         (20, 24, 16)
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     kernel_size: int, sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride: int, sequence of int
+         An integer, or a sequence of integers, representing the inter-window stride (default: 1).
+     padding: str, int, sequence of tuple
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Union[int, Sequence[int]] = 1,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          init_value=0.,
+                          computation=jax.lax.add,
+                          pool_dim=1,
+                          kernel_size=kernel_size,
+                          stride=stride,
+                          padding=padding,
+                          channel_axis=channel_axis,
+                          name=name)
+
+
+ class AvgPool2d(_AvgPool):
+     r"""Applies a 2D average pooling over an input signal composed of several input planes.
+
+     In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
+     output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
+     can be precisely described as:
+
+     .. math::
+
+         out(N_i, h, w, C_j) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
+                                input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)
+
+     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+     for :attr:`padding` number of points.
+
+     Shape:
+         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+
+           .. math::
+               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
+                     \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+           .. math::
+               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
+                     \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+     Examples::
+
+         >>> import brainstate as bst
+         >>> # pool of square window of size=3, stride=2
+         >>> m = AvgPool2d(3, stride=2)
+         >>> # pool of non-square window
+         >>> m = AvgPool2d((3, 2), stride=(2, 1))
+         >>> input = bst.random.randn(20, 50, 32, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 31, 16)
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     kernel_size: int, sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride: int, sequence of int
+         An integer, or a sequence of integers, representing the inter-window stride (default: 1).
+     padding: str, int, sequence of tuple
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Union[int, Sequence[int]] = 1,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          init_value=0.,
+                          computation=jax.lax.add,
+                          pool_dim=2,
+                          kernel_size=kernel_size,
+                          stride=stride,
+                          padding=padding,
+                          channel_axis=channel_axis,
+                          name=name)
+
+
+ class AvgPool3d(_AvgPool):
+     r"""Applies a 3D average pooling over an input signal composed of several input planes.
+
+     In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
+     output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
+     can be precisely described as:
+
+     .. math::
+         \begin{aligned}
+             \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
+                                               & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
+                                                       \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
+                                                      {kD \times kH \times kW}
+         \end{aligned}
+
+     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
+     for :attr:`padding` number of points.
+
+     Shape:
+         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
+           :math:`(D_{out}, H_{out}, W_{out}, C)`, where
+
+           .. math::
+               D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
+                     \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+           .. math::
+               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
+                     \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+           .. math::
+               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
+                     \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
+
+     Examples::
+
+         >>> import brainstate as bst
+         >>> # pool of cubic window of size=3, stride=2
+         >>> m = AvgPool3d(3, stride=2)
+         >>> # pool of non-cubic window
+         >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
+         >>> input = bst.random.randn(20, 50, 44, 31, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 43, 15, 16)
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     kernel_size: int, sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride: int, sequence of int
+         An integer, or a sequence of integers, representing the inter-window stride (default: 1).
+     padding: str, int, sequence of tuple
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Union[int, Sequence[int]] = 1,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          init_value=0.,
+                          computation=jax.lax.add,
+                          pool_dim=3,
+                          kernel_size=kernel_size,
+                          stride=stride,
+                          padding=padding,
+                          channel_axis=channel_axis,
+                          name=name)
+
+
+ def _adaptive_pool1d(x, target_size: int, operation: Callable):
+     """Adaptive pool 1D.
+
+     Args:
+         x: The input. Should be a JAX array of shape `(dim,)`.
+         target_size: The shape of the output after the pooling operation `(target_size,)`.
+         operation: The pooling operation to be performed on the input array.
+
+     Returns:
+         A JAX array of shape `(target_size,)`.
+     """
+     size = jnp.size(x)
+     num_head_arrays = size % target_size
+     num_block = size // target_size
+     if num_head_arrays != 0:
+         # The first `num_head_arrays` output cells reduce over `num_block + 1`
+         # inputs each; the remaining cells reduce over `num_block` inputs each.
+         head_end_index = num_head_arrays * (num_block + 1)
+         heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
+         tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
+         outs = jnp.concatenate([heads, tails])
+     else:
+         outs = jax.vmap(operation)(x.reshape(-1, num_block))
+     return outs
+
+
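Concretely, for `size = 10` and `target_size = 4`: `num_block = 2` and `num_head_arrays = 2`, so the first two output cells reduce windows of width 3 and the last two reduce windows of width 2 (reusing `_adaptive_pool1d` from above):

import jax.numpy as jnp

x = jnp.arange(10.)                      # size=10, target=4
# heads: [0 1 2] [3 4 5]   tails: [6 7] [8 9]
print(_adaptive_pool1d(x, 4, jnp.mean))  # [1.  4.  6.5 8.5]
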
+ def _generate_vmap(fun: Callable, map_axes: List[int]):
+     # Vectorize `fun(x, target_size, operation)` over every axis in `map_axes`,
+     # so that `fun` itself only ever receives the remaining (pooled) axis.
+     map_axes = sorted(map_axes)
+     for axis in map_axes:
+         fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
+     return fun
+
+
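`_generate_vmap` builds the N-D kernel from the 1D one by stacking one `jax.vmap` per non-pooled axis, mapping only the array argument. A self-contained sketch of the pattern (the `kernel` here is a stand-in, not the real pooling kernel):

import jax
import jax.numpy as jnp

def kernel(row, target_size, op):
    # placeholder 1D kernel: just take the first `target_size` entries
    return row[:target_size]

# map over axis 0 of the array, broadcasting the other two arguments
vmapped = jax.vmap(kernel, in_axes=(0, None, None), out_axes=0)
x = jnp.ones((5, 7))
print(vmapped(x, 3, None).shape)  # (5, 3)
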
+ class _AdaptivePool(Module):
+     """General N dimensional adaptive down-sampling to a target shape.
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     target_size: int, sequence of int
+         The target output shape.
+     num_spatial_dims: int
+         The number of spatial dimensions.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     operation: Callable
+         The down-sampling operation.
+     name: str, optional
+         The object name.
+     """
+
+     def __init__(
+         self,
+         in_size: Size,
+         target_size: Size,
+         num_spatial_dims: int,
+         operation: Callable,
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         self.channel_axis = channel_axis
+         self.operation = operation
+         if isinstance(target_size, int):
+             self.target_shape = (target_size,) * num_spatial_dims
+         elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
+             self.target_shape = target_size
+         else:
+             raise ValueError("`target_size` must either be an int or a tuple of length "
+                              f"{num_spatial_dims} containing ints.")
+
+         # in & out shapes
+         if in_size is not None:
+             in_size = tuple(in_size)
+             self.in_size = in_size
+             y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
+             self.out_size = y.shape[1:]
+
+     def update(self, x):
+         """Input-output mapping.
+
+         Parameters
+         ----------
+         x: Array
+             Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
+             or `(..., dim_1, dim_2)`.
+         """
+         # channel axis
+         channel_axis = self.channel_axis
+         if channel_axis is not None:
+             if not 0 <= abs(channel_axis) < x.ndim:
+                 raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
+             if channel_axis < 0:
+                 channel_axis = x.ndim + channel_axis
+         # input dimension
+         if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
+             raise ValueError(f"Invalid input dimension. Expected >= {len(self.target_shape)} "
+                              f"dimensions (channel_axis={self.channel_axis}). "
+                              f"But got {x.ndim} dimensions.")
+         # pooling dimensions
+         pool_dims = list(range(x.ndim))
+         if channel_axis is not None:
+             pool_dims.pop(channel_axis)
+
+         # pooling: apply the 1D adaptive kernel along each spatial axis in turn,
+         # vmapping over all remaining axes
+         for i, di in enumerate(pool_dims[-len(self.target_shape):]):
+             pool_axes = [j for j in range(x.ndim) if j != di]
+             op = _generate_vmap(_adaptive_pool1d, pool_axes)
+             x = op(x, self.target_shape[i], self.operation)
+         return x
+
+
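The concrete classes below differ only in `num_spatial_dims` and `operation` (`jnp.mean` or `jnp.max`). A quick sanity check of the machinery, assuming the `brainstate.nn` re-export:

import brainstate as bst
import jax.numpy as jnp

x = bst.random.randn(1, 10, 4)            # (batch, length, channels)
m = bst.nn.AdaptiveAvgPool1d(5, channel_axis=-1)
y = m(x)
print(y.shape)                            # (1, 5, 4)
# with 10 inputs and 5 cells, each cell averages 2 consecutive entries
assert jnp.allclose(y[0, 0], x[0, :2].mean(axis=0))
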
+ class AdaptiveAvgPool1d(_AdaptivePool):
+     r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
+
+     The output size is :math:`L_{out}`, for any input size.
+     The number of output features is equal to the number of input planes.
+
+     Shape:
+         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+           :math:`L_{out}=\text{output\_size}`.
+
+     Examples:
+
+         >>> import brainstate as bst
+         >>> # target output size of 5
+         >>> m = AdaptiveAvgPool1d(5)
+         >>> input = bst.random.randn(1, 64, 8)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 5, 8)
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     target_size: int, sequence of int
+         The target output shape.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          target_size=target_size,
+                          channel_axis=channel_axis,
+                          num_spatial_dims=1,
+                          operation=jnp.mean,
+                          name=name)
+
+
+ class AdaptiveAvgPool2d(_AdaptivePool):
+     r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
+
+     The output is of size :math:`H_{out} \times W_{out}`, for any input size.
+     The number of output features is equal to the number of input planes.
+
+     Shape:
+         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+           :math:`(H_{out}, W_{out})=\text{output\_size}`.
+
+     Examples:
+
+         >>> import brainstate as bst
+         >>> # target output size of 5x7
+         >>> m = AdaptiveAvgPool2d((5, 7))
+         >>> input = bst.random.randn(1, 8, 9, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 5, 7, 64)
+         >>> # target output size of 7x7 (square)
+         >>> m = AdaptiveAvgPool2d(7)
+         >>> input = bst.random.randn(1, 10, 9, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 7, 7, 64)
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     target_size: int, sequence of int
+         The target output shape.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          target_size=target_size,
+                          channel_axis=channel_axis,
+                          num_spatial_dims=2,
+                          operation=jnp.mean,
+                          name=name)
+
+
+ class AdaptiveAvgPool3d(_AdaptivePool):
+     r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
+
+     The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
+     The number of output features is equal to the number of input planes.
+
+     Shape:
+         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
+           where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`.
+
+     Examples:
+
+         >>> import brainstate as bst
+         >>> # target output size of 5x7x9
+         >>> m = AdaptiveAvgPool3d((5, 7, 9))
+         >>> input = bst.random.randn(1, 8, 9, 10, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 5, 7, 9, 64)
+         >>> # target output size of 7x7x7 (cube)
+         >>> m = AdaptiveAvgPool3d(7)
+         >>> input = bst.random.randn(1, 10, 9, 8, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 7, 7, 7, 64)
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     target_size: int, sequence of int
+         The target output shape.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          target_size=target_size,
+                          channel_axis=channel_axis,
+                          num_spatial_dims=3,
+                          operation=jnp.mean,
+                          name=name)
+
+
+ class AdaptiveMaxPool1d(_AdaptivePool):
+     """Adaptive one-dimensional maximum down-sampling.
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     target_size: int, sequence of int
+         The target output shape.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          target_size=target_size,
+                          channel_axis=channel_axis,
+                          num_spatial_dims=1,
+                          operation=jnp.max,
+                          name=name)
+
+
+ class AdaptiveMaxPool2d(_AdaptivePool):
+     """Adaptive two-dimensional maximum down-sampling.
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     target_size: int, sequence of int
+         The target output shape.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          target_size=target_size,
+                          channel_axis=channel_axis,
+                          num_spatial_dims=2,
+                          operation=jnp.max,
+                          name=name)
+
+
+ class AdaptiveMaxPool3d(_AdaptivePool):
+     """Adaptive three-dimensional maximum down-sampling.
+
+     Parameters
+     ----------
+     in_size: Sequence of int
+         The shape of the input tensor.
+     target_size: int, sequence of int
+         The target output shape.
+     channel_axis: int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name: str, optional
+         The object name.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(in_size=in_size,
+                          target_size=target_size,
+                          channel_axis=channel_axis,
+                          num_spatial_dims=3,
+                          operation=jnp.max,
+                          name=name)