brainstate-0.0.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/nn/_poolings.py
@@ -0,0 +1,1228 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import functools
+ from typing import Sequence, Optional
+ from typing import Union, Tuple, Callable, List
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from ._base import DnnLayer, ExplicitInOutSize
+ from .. import environ, math
+ from ..mixin import Mode
+ from ..typing import Size
+
+ __all__ = [
+   'Flatten', 'Unflatten',
+
+   'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
+   'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
+
+   'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
+   'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
+ ]
+
+
+ class Flatten(DnnLayer, ExplicitInOutSize):
+   r"""
+   Flattens a contiguous range of dimensions into a single dimension. For use with :class:`~nn.Sequential`.
+
+   Shape:
+     - Input: :math:`(*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)`,
+       where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
+       number of dimensions including none.
+     - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
+
+   Args:
+     in_size: Sequence of int. The shape of the input tensor.
+     start_dim: first dim to flatten (default = 0).
+     end_dim: last dim to flatten (default = -1).
+
+   Examples::
+     >>> import brainstate as bst
+     >>> inp = bst.random.randn(32, 1, 5, 5)
+     >>> # flatten the dimensions after the batch axis
+     >>> m = Flatten(1)
+     >>> output = m(inp)
+     >>> output.shape
+     (32, 25)
+     >>> # flatten dimensions 0 through 2
+     >>> m = Flatten(0, 2)
+     >>> output = m(inp)
+     >>> output.shape
+     (160, 5)
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(
+       self,
+       start_dim: int = 0,
+       end_dim: int = -1,
+       in_size: Optional[Size] = None
+   ) -> None:
+     super().__init__()
+     self.start_dim = start_dim
+     self.end_dim = end_dim
+
+     if in_size is not None:
+       self.in_size = tuple(in_size)
+       y = jax.eval_shape(functools.partial(math.flatten, start_dim=start_dim, end_dim=end_dim),
+                          jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
+       self.out_size = y.shape
+
+   def update(self, x):
+     if self._in_size is None:
+       start_dim = self.start_dim if self.start_dim >= 0 else x.ndim + self.start_dim
+     else:
+       assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
+       dim_diff = x.ndim - len(self.in_size)
+       if self.in_size != x.shape[dim_diff:]:
+         raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
+       if self.start_dim >= 0:
+         start_dim = self.start_dim + dim_diff
+       else:
+         start_dim = x.ndim + self.start_dim
+     return math.flatten(x, start_dim, self.end_dim)
+
+   def __repr__(self) -> str:
+     return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})'
+
+
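A minimal sketch of the `in_size` bookkeeping above (the shapes and the `bst` alias are illustrative assumptions): when `in_size` is given, `out_size` is inferred at construction through `jax.eval_shape`, and at call time `start_dim` is shifted past any extra leading batch axes.

    >>> import brainstate as bst
    >>> m = Flatten(start_dim=0, in_size=(1, 5, 5))
    >>> m.out_size                # inferred without running the layer
    (25,)
    >>> x = bst.random.randn(32, 1, 5, 5)
    >>> m(x).shape                # start_dim is shifted past the batch axis
    (32, 25)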
+ class Unflatten(DnnLayer, ExplicitInOutSize):
+   r"""
+   Unflattens a tensor dimension, expanding it to a desired shape. For use with :class:`~nn.Sequential`.
+
+   * :attr:`dim` specifies the dimension of the input tensor to be unflattened.
+
+   * :attr:`sizes` is the new shape of the unflattened dimension, given as a `tuple`
+     or `list` of ints whose product must equal the size of dimension :attr:`dim`.
+
+   Shape:
+     - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
+       dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
+     - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
+       :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
+
+   Args:
+     dim: int, Dimension to be unflattened.
+     sizes: Sequence of int. New shape of the unflattened dimension.
+     in_size: Sequence of int. The shape of the input tensor.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(
+       self,
+       dim: int,
+       sizes: Size,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None,
+       in_size: Optional[Size] = None
+   ) -> None:
+     super().__init__(mode=mode, name=name)
+
+     self.dim = dim
+     self.sizes = sizes
+     if isinstance(sizes, (tuple, list)):
+       for idx, elem in enumerate(sizes):
+         if not isinstance(elem, int):
+           raise TypeError(f"unflattened sizes must be a tuple of ints, "
+                           f"but found element of type {type(elem).__name__} at pos {idx}")
+     else:
+       raise TypeError(f"unflattened sizes must be a tuple or list, but found type {type(sizes).__name__}")
+
+     if in_size is not None:
+       self.in_size = tuple(in_size)
+       y = jax.eval_shape(functools.partial(math.unflatten, dim=dim, sizes=sizes),
+                          jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
+       self.out_size = y.shape
+
+   def update(self, x):
+     return math.unflatten(x, self.dim, self.sizes)
+
+   def __repr__(self):
+     return f'{self.__class__.__name__}(dim={self.dim}, sizes={self.sizes})'
+
+
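A small usage sketch (shapes assumed), inverting the `Flatten` example above; it relies on `math.unflatten` expanding dimension `dim` into `sizes`:

    >>> import brainstate as bst
    >>> x = bst.random.randn(32, 25)
    >>> m = Unflatten(1, (5, 5))
    >>> m(x).shape
    (32, 5, 5)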
+ class _MaxPool(DnnLayer, ExplicitInOutSize):
+   def __init__(
+       self,
+       init_value: float,
+       computation: Callable,
+       pool_dim: int,
+       kernel_size: Size,
+       stride: Optional[Union[int, Sequence[int]]] = None,
+       padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+       channel_axis: Optional[int] = -1,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None,
+       in_size: Optional[Size] = None,
+   ):
+     super().__init__(name=name, mode=mode)
+
+     self.init_value = init_value
+     self.computation = computation
+     self.pool_dim = pool_dim
+
+     # kernel_size
+     if isinstance(kernel_size, int):
+       kernel_size = (kernel_size,) * pool_dim
+     elif isinstance(kernel_size, Sequence):
+       assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
+       assert all([isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
+       if len(kernel_size) != pool_dim:
+         raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, but got {len(kernel_size)}')
+     else:
+       raise TypeError(f'kernel_size should be an int or a tuple with {pool_dim} ints.')
+     self.kernel_size = kernel_size
+
+     # stride
+     if stride is None:
+       stride = kernel_size
+     if isinstance(stride, int):
+       stride = (stride,) * pool_dim
+     elif isinstance(stride, Sequence):
+       assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
+       assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
+       if len(stride) != pool_dim:
+         raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
+     else:
+       raise TypeError(f'stride should be an int or a tuple with {pool_dim} ints.')
+     self.stride = stride
+
+     # padding
+     if isinstance(padding, str):
+       if padding not in ("SAME", "VALID"):
+         raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
+     elif isinstance(padding, int):
+       padding = [(padding, padding) for _ in range(pool_dim)]
+     elif isinstance(padding, (list, tuple)):
+       if isinstance(padding[0], int):
+         if len(padding) == pool_dim:
+           padding = [(x, x) for x in padding]
+         else:
+           raise ValueError(f'If padding is a sequence of ints, it '
+                            f'should have length {pool_dim}.')
+       else:
+         if not all([isinstance(x, (tuple, list)) for x in padding]):
+           raise ValueError(f'padding should be a sequence of Tuple[int, int]. {padding}')
+         if not all([len(x) == 2 for x in padding]):
+           raise ValueError(f"Each entry in padding must be a tuple of 2 ints. {padding}")
+         if len(padding) == 1:
+           padding = tuple(padding) * pool_dim
+         assert len(padding) == pool_dim, f'padding should have length {pool_dim}. {padding}'
+     else:
+       raise ValueError(f'padding should be a str, an int, or a sequence of tuples, but got {type(padding)}')
+     self.padding = padding
+
+     # channel_axis
+     assert channel_axis is None or isinstance(channel_axis, int), \
+       f'channel_axis should be an int, but got {channel_axis}'
+     self.channel_axis = channel_axis
+
+     # in & out shapes
+     if in_size is not None:
+       in_size = tuple(in_size)
+       self.in_size = in_size
+       # infer the output shape with a mock batch axis of 128
+       y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
+       self.out_size = y.shape[1:]
+
+   def update(self, x):
+     x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
+     if x.ndim < x_dim:
+       raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
+     window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
+     stride = self._infer_shape(x.ndim, self.stride, 1)
+     padding = (self.padding if isinstance(self.padding, str) else
+                self._infer_shape(x.ndim, self.padding, element=(0, 0)))
+     r = jax.lax.reduce_window(
+       x,
+       init_value=self.init_value,
+       computation=self.computation,
+       window_dimensions=window_shape,
+       window_strides=stride,
+       padding=padding
+     )
+     return r
+
+   def _infer_shape(self, x_dim, inputs, element):
+     channel_axis = self.channel_axis
+     if channel_axis is not None and not 0 <= abs(channel_axis) < x_dim:
+       raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
+     if channel_axis is not None and channel_axis < 0:
+       channel_axis = x_dim + channel_axis
+     all_dims = list(range(x_dim))
+     if channel_axis is not None:
+       all_dims.pop(channel_axis)
+     pool_dims = all_dims[-self.pool_dim:]
+     results = [element] * x_dim
+     for i, dim in enumerate(pool_dims):
+       results[dim] = inputs[i]
+     return results
+
+
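`_infer_shape` broadcasts the per-window parameters to the full input rank, placing them on the trailing non-channel axes and the neutral element everywhere else. A worked sketch with assumed values, using the internal class directly:

    >>> pool = _MaxPool(-jnp.inf, jax.lax.max, pool_dim=2, kernel_size=(3, 2),
    ...                 stride=(2, 1), channel_axis=-1)
    >>> pool._infer_shape(4, pool.kernel_size, 1)   # window dims for (N, H, W, C)
    [1, 3, 2, 1]
    >>> pool._infer_shape(4, pool.stride, 1)        # strides for (N, H, W, C)
    [1, 2, 1, 1]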
+ class _AvgPool(_MaxPool):
+   def update(self, x):
+     x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
+     if x.ndim < x_dim:
+       raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
+     dims = self._infer_shape(x.ndim, self.kernel_size, 1)
+     stride = self._infer_shape(x.ndim, self.stride, 1)
+     padding = (self.padding if isinstance(self.padding, str) else
+                self._infer_shape(x.ndim, self.padding, element=(0, 0)))
+     pooled = jax.lax.reduce_window(x,
+                                    init_value=self.init_value,
+                                    computation=self.computation,
+                                    window_dimensions=dims,
+                                    window_strides=stride,
+                                    padding=padding)
+     if padding == "VALID":
+       # With no padding every window is full, so dividing by the window size
+       # gives the mean and avoids an extra reduce_window.
+       return pooled / np.prod(dims)
+     else:
+       # Count the number of valid entries at each input point, then use that for
+       # computing the average. Assumes that any two arrays of the same shape will be
+       # padded the same.
+       window_counts = jax.lax.reduce_window(jnp.ones_like(x),
+                                             init_value=self.init_value,
+                                             computation=self.computation,
+                                             window_dimensions=dims,
+                                             window_strides=stride,
+                                             padding=padding)
+       assert pooled.shape == window_counts.shape
+       return pooled / window_counts
+
+
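A worked sketch of the count correction above (values chosen for illustration): with `'SAME'` padding, border windows overlap the input in fewer positions, so dividing the windowed sum by a windowed count of ones recovers the exact mean over valid entries:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> x = jnp.ones(4)
    >>> summed = jax.lax.reduce_window(x, 0., jax.lax.add, (3,), (1,), 'SAME')
    >>> counts = jax.lax.reduce_window(jnp.ones_like(x), 0., jax.lax.add, (3,), (1,), 'SAME')
    >>> [float(v) for v in counts]    # border windows see only 2 valid entries
    [2.0, 3.0, 3.0, 2.0]
    >>> [float(v) for v in summed / counts]
    [1.0, 1.0, 1.0, 1.0]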
+ class MaxPool1d(_MaxPool):
+   r"""Applies a 1D max pooling over an input signal composed of several input planes.
+
+   In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
+   and output :math:`(N, L_{out}, C)` can be precisely described as:
+
+   .. math::
+     out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
+       input(N_i, stride \times k + m, C_j)
+
+   If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
+   for :attr:`padding` number of points. This `link`_ has a nice visualization of the pooling parameters.
+
+   Shape:
+     - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+     - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+
+       .. math::
+         L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding}
+           - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
+
+   Examples::
+
+     >>> import brainstate as bst
+     >>> # pool of size=3, stride=2
+     >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
+     >>> input = bst.random.randn(20, 50, 16)
+     >>> output = m(input)
+     >>> output.shape
+     (20, 24, 16)
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   kernel_size: int, sequence of int
+     An integer, or a sequence of integers defining the window to reduce over.
+   stride: int, sequence of int
+     An integer, or a sequence of integers, representing the inter-window stride
+     (default: the same as ``kernel_size``).
+   padding: str, int, sequence of tuple
+     Either the string `'SAME'`, the string `'VALID'`, or a sequence
+     of n `(low, high)` integer pairs that give the padding to apply before
+     and after each spatial dimension.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   mode: Mode
+     The computation mode.
+   name: optional, str
+     The object name.
+
+   .. _link:
+     https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(
+       self,
+       kernel_size: Size,
+       stride: Optional[Union[int, Sequence[int]]] = None,
+       padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+       channel_axis: Optional[int] = -1,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None,
+       in_size: Optional[Size] = None,
+   ):
+     super().__init__(in_size=in_size,
+                      init_value=-jax.numpy.inf,
+                      computation=jax.lax.max,
+                      pool_dim=1,
+                      kernel_size=kernel_size,
+                      stride=stride,
+                      padding=padding,
+                      channel_axis=channel_axis,
+                      name=name,
+                      mode=mode)
+
+
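As a check of the shape formula against the example above: with L_in = 50, no padding, kernel size 3 and stride 2, floor((50 - 3) / 2) + 1 = 24, matching the reported output shape `(20, 24, 16)`.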
+ class MaxPool2d(_MaxPool):
+   r"""Applies a 2D max pooling over an input signal composed of several input planes.
+
+   In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
+   output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
+   can be precisely described as:
+
+   .. math::
+     \begin{aligned}
+       out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
+                               & \text{input}(N_i, \text{stride[0]} \times h + m,
+                                              \text{stride[1]} \times w + n, C_j)
+     \end{aligned}
+
+   If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
+   for :attr:`padding` number of points. This `link`_ has a nice visualization of the pooling parameters.
+
+   Shape:
+     - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
+     - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+
+       .. math::
+         H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
+           \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+       .. math::
+         W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
+           \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+   Examples::
+
+     >>> import brainstate as bst
+     >>> # pool of square window of size=3, stride=2
+     >>> m = MaxPool2d(3, stride=2)
+     >>> # pool of non-square window
+     >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
+     >>> input = bst.random.randn(20, 50, 32, 16)
+     >>> output = m(input)
+     >>> output.shape
+     (20, 24, 31, 16)
+
+   .. _link:
+     https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   kernel_size: int, sequence of int
+     An integer, or a sequence of integers defining the window to reduce over.
+   stride: int, sequence of int
+     An integer, or a sequence of integers, representing the inter-window stride
+     (default: the same as ``kernel_size``).
+   padding: str, int, sequence of tuple
+     Either the string `'SAME'`, the string `'VALID'`, or a sequence
+     of n `(low, high)` integer pairs that give the padding to apply before
+     and after each spatial dimension.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   mode: Mode
+     The computation mode.
+   name: optional, str
+     The object name.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(
+       self,
+       kernel_size: Size,
+       stride: Optional[Union[int, Sequence[int]]] = None,
+       padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+       channel_axis: Optional[int] = -1,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None,
+       in_size: Optional[Size] = None,
+   ):
+     super().__init__(in_size=in_size,
+                      init_value=-jax.numpy.inf,
+                      computation=jax.lax.max,
+                      pool_dim=2,
+                      kernel_size=kernel_size,
+                      stride=stride,
+                      padding=padding,
+                      channel_axis=channel_axis,
+                      name=name, mode=mode)
+
+
+ class MaxPool3d(_MaxPool):
+   r"""Applies a 3D max pooling over an input signal composed of several input planes.
+
+   In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
+   output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
+   can be precisely described as:
+
+   .. math::
+     \begin{aligned}
+       \text{out}(N_i, d, h, w, C_j) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
+                                         & \text{input}(N_i, \text{stride[0]} \times d + k,
+                                                        \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
+     \end{aligned}
+
+   If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
+   for :attr:`padding` number of points. This `link`_ has a nice visualization of the pooling parameters.
+
+   Shape:
+     - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+     - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
+
+       .. math::
+         D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
+           \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+       .. math::
+         H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
+           \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+       .. math::
+         W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
+           \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
+
+   Examples::
+
+     >>> import brainstate as bst
+     >>> # pool of cubic window of size=3, stride=2
+     >>> m = MaxPool3d(3, stride=2)
+     >>> # pool of non-cubic window
+     >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
+     >>> input = bst.random.randn(20, 50, 44, 31, 16)
+     >>> output = m(input)
+     >>> output.shape
+     (20, 24, 43, 15, 16)
+
+   .. _link:
+     https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   kernel_size: int, sequence of int
+     An integer, or a sequence of integers defining the window to reduce over.
+   stride: int, sequence of int
+     An integer, or a sequence of integers, representing the inter-window stride
+     (default: the same as ``kernel_size``).
+   padding: str, int, sequence of tuple
+     Either the string `'SAME'`, the string `'VALID'`, or a sequence
+     of n `(low, high)` integer pairs that give the padding to apply before
+     and after each spatial dimension.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   mode: Mode
+     The computation mode.
+   name: optional, str
+     The object name.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(
+       self,
+       kernel_size: Size,
+       stride: Optional[Union[int, Sequence[int]]] = None,
+       padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+       channel_axis: Optional[int] = -1,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None,
+       in_size: Optional[Size] = None,
+   ):
+     super().__init__(in_size=in_size,
+                      init_value=-jax.numpy.inf,
+                      computation=jax.lax.max,
+                      pool_dim=3,
+                      kernel_size=kernel_size,
+                      stride=stride,
+                      padding=padding,
+                      channel_axis=channel_axis,
+                      name=name, mode=mode)
+
+
+ class AvgPool1d(_AvgPool):
+   r"""Applies a 1D average pooling over an input signal composed of several input planes.
+
+   In the simplest case, the output value of the layer with input size :math:`(N, L, C)`,
+   output :math:`(N, L_{out}, C)` and :attr:`kernel_size` :math:`k`
+   can be precisely described as:
+
+   .. math::
+
+     \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
+       \text{input}(N_i, \text{stride} \times l + m, C_j)
+
+   If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+   for :attr:`padding` number of points.
+
+   Shape:
+     - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+     - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+
+       .. math::
+         L_{out} = \left\lfloor \frac{L_{in} +
+           2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
+
+   Examples::
+
+     >>> import brainstate as bst
+     >>> # pool with window of size=3, stride=2
+     >>> m = AvgPool1d(3, stride=2)
+     >>> input = bst.random.randn(20, 50, 16)
+     >>> m(input).shape
+     (20, 24, 16)
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   kernel_size: int, sequence of int
+     An integer, or a sequence of integers defining the window to reduce over.
+   stride: int, sequence of int
+     An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
+   padding: str, int, sequence of tuple
+     Either the string `'SAME'`, the string `'VALID'`, or a sequence
+     of n `(low, high)` integer pairs that give the padding to apply before
+     and after each spatial dimension.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   mode: Mode
+     The computation mode.
+   name: optional, str
+     The object name.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(
+       self,
+       kernel_size: Size,
+       stride: Union[int, Sequence[int]] = 1,
+       padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+       channel_axis: Optional[int] = -1,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None,
+       in_size: Optional[Size] = None,
+   ):
+     super().__init__(in_size=in_size,
+                      init_value=0.,
+                      computation=jax.lax.add,
+                      pool_dim=1,
+                      kernel_size=kernel_size,
+                      stride=stride,
+                      padding=padding,
+                      channel_axis=channel_axis,
+                      name=name,
+                      mode=mode)
+
+
+ class AvgPool2d(_AvgPool):
+   r"""Applies a 2D average pooling over an input signal composed of several input planes.
+
+   In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
+   output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
+   can be precisely described as:
+
+   .. math::
+
+     out(N_i, h, w, C_j) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
+       input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)
+
+   If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+   for :attr:`padding` number of points.
+
+   Shape:
+     - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
+     - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+
+       .. math::
+         H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
+           \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+       .. math::
+         W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
+           \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+   Examples::
+
+     >>> import brainstate as bst
+     >>> # pool of square window of size=3, stride=2
+     >>> m = AvgPool2d(3, stride=2)
+     >>> # pool of non-square window
+     >>> m = AvgPool2d((3, 2), stride=(2, 1))
+     >>> input = bst.random.randn(20, 50, 32, 16)
+     >>> output = m(input)
+     >>> output.shape
+     (20, 24, 31, 16)
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   kernel_size: int, sequence of int
+     An integer, or a sequence of integers defining the window to reduce over.
+   stride: int, sequence of int
+     An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
+   padding: str, int, sequence of tuple
+     Either the string `'SAME'`, the string `'VALID'`, or a sequence
+     of n `(low, high)` integer pairs that give the padding to apply before
+     and after each spatial dimension.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   mode: Mode
+     The computation mode.
+   name: optional, str
+     The object name.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(
+       self,
+       kernel_size: Size,
+       stride: Union[int, Sequence[int]] = 1,
+       padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+       channel_axis: Optional[int] = -1,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None,
+       in_size: Optional[Size] = None,
+   ):
+     super().__init__(in_size=in_size,
+                      init_value=0.,
+                      computation=jax.lax.add,
+                      pool_dim=2,
+                      kernel_size=kernel_size,
+                      stride=stride,
+                      padding=padding,
+                      channel_axis=channel_axis,
+                      name=name,
+                      mode=mode)
+
+
+ class AvgPool3d(_AvgPool):
+   r"""Applies a 3D average pooling over an input signal composed of several input planes.
+
+   In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
+   output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
+   can be precisely described as:
+
+   .. math::
+     \begin{aligned}
+       \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
+                                         & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
+                                           \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
+                                           {kD \times kH \times kW}
+     \end{aligned}
+
+   If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
+   for :attr:`padding` number of points.
+
+   Shape:
+     - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+     - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
+       :math:`(D_{out}, H_{out}, W_{out}, C)`, where
+
+       .. math::
+         D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
+           \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+       .. math::
+         H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
+           \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+       .. math::
+         W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
+           \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
+
+   Examples::
+
+     >>> import brainstate as bst
+     >>> # pool of cubic window of size=3, stride=2
+     >>> m = AvgPool3d(3, stride=2)
+     >>> # pool of non-cubic window
+     >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
+     >>> input = bst.random.randn(20, 50, 44, 31, 16)
+     >>> output = m(input)
+     >>> output.shape
+     (20, 24, 43, 15, 16)
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   kernel_size: int, sequence of int
+     An integer, or a sequence of integers defining the window to reduce over.
+   stride: int, sequence of int
+     An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
+   padding: str, int, sequence of tuple
+     Either the string `'SAME'`, the string `'VALID'`, or a sequence
+     of n `(low, high)` integer pairs that give the padding to apply before
+     and after each spatial dimension.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   mode: Mode
+     The computation mode.
+   name: optional, str
+     The object name.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(
+       self,
+       kernel_size: Size,
+       stride: Union[int, Sequence[int]] = 1,
+       padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+       channel_axis: Optional[int] = -1,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None,
+       in_size: Optional[Size] = None,
+   ):
+     super().__init__(in_size=in_size,
+                      init_value=0.,
+                      computation=jax.lax.add,
+                      pool_dim=3,
+                      kernel_size=kernel_size,
+                      stride=stride,
+                      padding=padding,
+                      channel_axis=channel_axis,
+                      name=name,
+                      mode=mode)
+
+
+ def _adaptive_pool1d(x, target_size: int, operation: Callable):
+   """Adaptive pool 1D.
+
+   Args:
+     x: The input. Should be a JAX array of shape `(dim,)`.
+     target_size: The shape of the output after the pooling operation `(target_size,)`.
+     operation: The pooling operation to be performed on the input array.
+
+   Returns:
+     A JAX array of shape `(target_size,)`.
+   """
+   size = jnp.size(x)
+   num_head_arrays = size % target_size
+   num_block = size // target_size
+   if num_head_arrays != 0:
+     head_end_index = num_head_arrays * (num_block + 1)
+     heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
+     tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
+     outs = jnp.concatenate([heads, tails])
+   else:
+     outs = jax.vmap(operation)(x.reshape(-1, num_block))
+   return outs
+
+
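A worked example of the head/tail split (numbers assumed): pooling 10 values into 4 bins gives `10 % 4 = 2` head bins of size 3 and the remaining 2 bins of size `10 // 4 = 2`:

    >>> import jax.numpy as jnp
    >>> x = jnp.arange(10.)
    >>> out = _adaptive_pool1d(x, 4, jnp.mean)
    >>> [float(v) for v in out]   # means of [0,1,2], [3,4,5], [6,7], [8,9]
    [1.0, 4.0, 6.5, 8.5]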
+ def _generate_vmap(fun: Callable, map_axes: List[int]):
+   map_axes = sorted(map_axes)
+   for axis in map_axes:
+     fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
+   return fun
+
+
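A sketch of how `_generate_vmap` lifts the 1-D routine to N-D input (shapes assumed): vmapping over every axis except the one being pooled leaves that axis as the function's working dimension:

    >>> import jax.numpy as jnp
    >>> x = jnp.ones((2, 10, 3))
    >>> op = _generate_vmap(_adaptive_pool1d, [0, 2])   # map over axes 0 and 2
    >>> op(x, 4, jnp.mean).shape                        # axis 1 is pooled: 10 -> 4
    (2, 4, 3)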
+ class _AdaptivePool(DnnLayer, ExplicitInOutSize):
+   """General N-dimensional adaptive down-sampling to a target shape.
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   target_size: int, sequence of int
+     The target output shape.
+   num_spatial_dims: int
+     The number of spatial dimensions.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   operation: Callable
+     The down-sampling operation.
+   name: str
+     The class name.
+   mode: Mode
+     The computing mode.
+   """
+
+   def __init__(
+       self,
+       in_size: Optional[Size],
+       target_size: Size,
+       num_spatial_dims: int,
+       operation: Callable,
+       channel_axis: Optional[int] = -1,
+       name: Optional[str] = None,
+       mode: Optional[Mode] = None,
+   ):
+     super().__init__(name=name, mode=mode)
+
+     self.channel_axis = channel_axis
+     self.operation = operation
+     if isinstance(target_size, int):
+       self.target_shape = (target_size,) * num_spatial_dims
+     elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
+       self.target_shape = target_size
+     else:
+       raise ValueError("`target_size` must either be an int or tuple of length "
+                        f"{num_spatial_dims} containing ints.")
+
+     # in & out shapes
+     if in_size is not None:
+       in_size = tuple(in_size)
+       self.in_size = in_size
+       # infer the output shape with a mock batch axis of 128
+       y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
+       self.out_size = y.shape[1:]
+
+   def update(self, x):
+     """Input-output mapping.
+
+     Parameters
+     ----------
+     x: Array
+       Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
+       or `(..., dim_1, dim_2)`.
+     """
+     # channel axis
+     channel_axis = self.channel_axis
+     if channel_axis is not None:
+       if not 0 <= abs(channel_axis) < x.ndim:
+         raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
+       if channel_axis < 0:
+         channel_axis = x.ndim + channel_axis
+     # input dimension
+     if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
+       raise ValueError(f"Invalid input dimension. Expected >= {len(self.target_shape)} "
+                        f"dimensions (channel_axis={self.channel_axis}). "
+                        f"But got {x.ndim} dimensions.")
+     # pooling dimensions
+     pool_dims = list(range(x.ndim))
+     if channel_axis is not None:
+       pool_dims.pop(channel_axis)
+
+     # pooling
+     for i, di in enumerate(pool_dims[-len(self.target_shape):]):
+       map_axes = [j for j in range(x.ndim) if j != di]
+       op = _generate_vmap(_adaptive_pool1d, map_axes)
+       x = op(x, self.target_shape[i], self.operation)
+     return x
+
+
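A sketch of the shape bookkeeping (argument values are assumptions): with a trailing channel axis, only the spatial axes are resized, and `out_size` is inferred with a mock batch axis:

    >>> m = _AdaptivePool(in_size=(8, 9, 64), target_size=(5, 7),
    ...                   num_spatial_dims=2, operation=jnp.mean)
    >>> m.out_size
    (5, 7, 64)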
+ class AdaptiveAvgPool1d(_AdaptivePool):
+   r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
+
+   The output size is :math:`L_{out}`, for any input size.
+   The number of output features is equal to the number of input planes.
+
+   Shape:
+     - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+     - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+       :math:`L_{out}=\text{target\_size}`.
+
+   Examples:
+
+     >>> import brainstate as bst
+     >>> # target output size of 5
+     >>> m = AdaptiveAvgPool1d(5)
+     >>> input = bst.random.randn(1, 64, 8)
+     >>> output = m(input)
+     >>> output.shape
+     (1, 5, 8)
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   target_size: int, sequence of int
+     The target output shape.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   name: str
+     The class name.
+   mode: Mode
+     The computing mode.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(self,
+                target_size: Union[int, Sequence[int]],
+                channel_axis: Optional[int] = -1,
+                name: Optional[str] = None,
+                mode: Optional[Mode] = None,
+                in_size: Optional[Sequence[int]] = None):
+     super().__init__(in_size=in_size,
+                      target_size=target_size,
+                      channel_axis=channel_axis,
+                      num_spatial_dims=1,
+                      operation=jnp.mean,
+                      name=name,
+                      mode=mode)
+
+
+ class AdaptiveAvgPool2d(_AdaptivePool):
+   r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
+
+   The output is of size :math:`H_{out} \times W_{out}`, for any input size.
+   The number of output features is equal to the number of input planes.
+
+   Shape:
+     - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
+     - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+       :math:`(H_{out}, W_{out})=\text{target\_size}`.
+
+   Examples:
+
+     >>> import brainstate as bst
+     >>> # target output size of 5x7
+     >>> m = AdaptiveAvgPool2d((5, 7))
+     >>> input = bst.random.randn(1, 8, 9, 64)
+     >>> output = m(input)
+     >>> output.shape
+     (1, 5, 7, 64)
+     >>> # target output size of 7x7 (square)
+     >>> m = AdaptiveAvgPool2d(7)
+     >>> input = bst.random.randn(1, 10, 9, 64)
+     >>> output = m(input)
+     >>> output.shape
+     (1, 7, 7, 64)
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   target_size: int, sequence of int
+     The target output shape.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   name: str
+     The class name.
+   mode: Mode
+     The computing mode.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(self,
+                target_size: Union[int, Sequence[int]],
+                channel_axis: Optional[int] = -1,
+                name: Optional[str] = None,
+                mode: Optional[Mode] = None,
+                in_size: Optional[Sequence[int]] = None):
+     super().__init__(in_size=in_size,
+                      target_size=target_size,
+                      channel_axis=channel_axis,
+                      num_spatial_dims=2,
+                      operation=jnp.mean,
+                      name=name,
+                      mode=mode)
+
+
+ class AdaptiveAvgPool3d(_AdaptivePool):
+   r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
+
+   The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
+   The number of output features is equal to the number of input planes.
+
+   Shape:
+     - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+     - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
+       where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
+
+   Examples:
+
+     >>> import brainstate as bst
+     >>> # target output size of 5x7x9
+     >>> m = AdaptiveAvgPool3d((5, 7, 9))
+     >>> input = bst.random.randn(1, 8, 9, 10, 64)
+     >>> output = m(input)
+     >>> output.shape
+     (1, 5, 7, 9, 64)
+     >>> # target output size of 7x7x7 (cube)
+     >>> m = AdaptiveAvgPool3d(7)
+     >>> input = bst.random.randn(1, 10, 9, 8, 64)
+     >>> output = m(input)
+     >>> output.shape
+     (1, 7, 7, 7, 64)
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   target_size: int, sequence of int
+     The target output shape.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   name: str
+     The class name.
+   mode: Mode
+     The computing mode.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(self,
+                target_size: Union[int, Sequence[int]],
+                channel_axis: Optional[int] = -1,
+                name: Optional[str] = None,
+                mode: Optional[Mode] = None,
+                in_size: Optional[Sequence[int]] = None):
+     super().__init__(in_size=in_size,
+                      target_size=target_size,
+                      channel_axis=channel_axis,
+                      num_spatial_dims=3,
+                      operation=jnp.mean,
+                      name=name,
+                      mode=mode)
+
+
+ class AdaptiveMaxPool1d(_AdaptivePool):
+   """Adaptive one-dimensional maximum down-sampling.
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   target_size: int, sequence of int
+     The target output shape.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   name: str
+     The class name.
+   mode: Mode
+     The computing mode.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(self,
+                target_size: Union[int, Sequence[int]],
+                channel_axis: Optional[int] = -1,
+                name: Optional[str] = None,
+                mode: Optional[Mode] = None,
+                in_size: Optional[Sequence[int]] = None):
+     super().__init__(in_size=in_size,
+                      target_size=target_size,
+                      channel_axis=channel_axis,
+                      num_spatial_dims=1,
+                      operation=jnp.max,
+                      name=name,
+                      mode=mode)
+
+
+ class AdaptiveMaxPool2d(_AdaptivePool):
+   """Adaptive two-dimensional maximum down-sampling.
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   target_size: int, sequence of int
+     The target output shape.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   name: str
+     The class name.
+   mode: Mode
+     The computing mode.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(self,
+                target_size: Union[int, Sequence[int]],
+                channel_axis: Optional[int] = -1,
+                name: Optional[str] = None,
+                mode: Optional[Mode] = None,
+                in_size: Optional[Sequence[int]] = None):
+     super().__init__(in_size=in_size,
+                      target_size=target_size,
+                      channel_axis=channel_axis,
+                      num_spatial_dims=2,
+                      operation=jnp.max,
+                      name=name,
+                      mode=mode)
+
+
+ class AdaptiveMaxPool3d(_AdaptivePool):
+   """Adaptive three-dimensional maximum down-sampling.
+
+   Parameters
+   ----------
+   in_size: Sequence of int
+     The shape of the input tensor.
+   target_size: int, sequence of int
+     The target output shape.
+   channel_axis: int, optional
+     Axis of the spatial channels for which pooling is skipped.
+     If ``None``, there is no channel axis.
+   name: str
+     The class name.
+   mode: Mode
+     The computing mode.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(self,
+                target_size: Union[int, Sequence[int]],
+                channel_axis: Optional[int] = -1,
+                name: Optional[str] = None,
+                mode: Optional[Mode] = None,
+                in_size: Optional[Sequence[int]] = None):
+     super().__init__(in_size=in_size,
+                      target_size=target_size,
+                      channel_axis=channel_axis,
+                      num_spatial_dims=3,
+                      operation=jnp.max,
+                      name=name,
+                      mode=mode)