brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_poolings.py
@@ -1,2239 +1,2239 @@
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- import functools
- from typing import Callable, List, Optional, Sequence, Tuple, Union
-
- import brainunit as u
- import jax
- import jax.numpy as jnp
- import numpy as np
-
- from brainstate import environ
- from brainstate.typing import Size
- from ._module import Module
-
- __all__ = [
-     'Flatten', 'Unflatten',
-
-     'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
-     'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
-     'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
-     'LPPool1d', 'LPPool2d', 'LPPool3d',
-
-     'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
-     'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
- ]
-
-
- class Flatten(Module):
-     r"""
-     Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
-
-     Shape:
-         - Input: :math:`(*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)`,
-           where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
-           number of dimensions including none.
-         - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
-
-     Parameters
-     ----------
-     start_axis : int, optional
-         First dim to flatten (default = 0).
-     end_axis : int, optional
-         Last dim to flatten (default = -1).
-     in_size : Sequence of int, optional
-         The shape of the input tensor.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> inp = brainstate.random.randn(32, 1, 5, 5)
-         >>> # With default parameters
-         >>> m = Flatten()
-         >>> output = m(inp)
-         >>> output.shape
-         (32, 25)
-         >>> # With non-default parameters
-         >>> m = Flatten(0, 2)
-         >>> output = m(inp)
-         >>> output.shape
-         (160, 5)
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         start_axis: int = 0,
-         end_axis: int = -1,
-         in_size: Optional[Size] = None
-     ) -> None:
-         super().__init__()
-         self.start_axis = start_axis
-         self.end_axis = end_axis
-
-         if in_size is not None:
-             self.in_size = tuple(in_size)
-             y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
-                                jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
-             self.out_size = y.shape
-
-     def update(self, x):
-         if self._in_size is None:
-             start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
-         else:
-             assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
-             dim_diff = x.ndim - len(self.in_size)
-             if self.in_size != x.shape[dim_diff:]:
-                 raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
-             if self.start_axis >= 0:
-                 start_axis = self.start_axis + dim_diff
-             else:
-                 start_axis = x.ndim + self.start_axis
-         return u.math.flatten(x, start_axis, self.end_axis)
-
-
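`Flatten` infers `out_size` by calling `jax.eval_shape`, which traces a function over `jax.ShapeDtypeStruct` placeholders without allocating data. A minimal sketch of that pattern outside brainstate (the `flatten` helper below is hypothetical, standing in for `u.math.flatten`):

    import functools
    import jax
    import jax.numpy as jnp

    def flatten(x, start_axis=0, end_axis=-1):
        # Collapse dims [start_axis, end_axis] into a single dimension.
        end_axis = end_axis % x.ndim
        new_shape = x.shape[:start_axis] + (-1,) + x.shape[end_axis + 1:]
        return jnp.reshape(x, new_shape)

    # Shape-only evaluation, as in Flatten.__init__:
    y = jax.eval_shape(functools.partial(flatten, start_axis=1),
                       jax.ShapeDtypeStruct((32, 1, 5, 5), jnp.float32))
    print(y.shape)  # (32, 25)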
- class Unflatten(Module):
-     r"""
-     Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~nn.Sequential`.
-
-     * :attr:`axis` specifies the dimension of the input tensor to be unflattened.
-
-     * :attr:`sizes` is the new shape of the unflattened dimension; it can be a `tuple`
-       or `list` of ints whose product equals the size of dimension :attr:`axis`.
-
-     Shape:
-         - Input: :math:`(*, S_{\text{axis}}, *)`, where :math:`S_{\text{axis}}` is the size at
-           dimension :attr:`axis` and :math:`*` means any number of dimensions including none.
-         - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
-           :math:`\prod_{i=1}^n U_i = S_{\text{axis}}`.
-
-     Parameters
-     ----------
-     axis : int
-         Dimension to be unflattened.
-     sizes : Sequence of int
-         New shape of the unflattened dimension.
-     name : str, optional
-         The name of the module.
-     in_size : Sequence of int, optional
-         The shape of the input tensor.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         axis: int,
-         sizes: Size,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None
-     ) -> None:
-         super().__init__(name=name)
-
-         self.axis = axis
-         self.sizes = sizes
-         if isinstance(sizes, (tuple, list)):
-             for idx, elem in enumerate(sizes):
-                 if not isinstance(elem, int):
-                     raise TypeError(f"unflattened sizes must be a tuple of ints, "
-                                     f"but found element of type {type(elem).__name__} at pos {idx}")
-         else:
-             raise TypeError(f"unflattened sizes must be a tuple or list, but found type {type(sizes).__name__}")
-
-         if in_size is not None:
-             self.in_size = tuple(in_size)
-             y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
-                                jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
-             self.out_size = y.shape
-
-     def update(self, x):
-         return u.math.unflatten(x, self.axis, self.sizes)
-
-
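`u.math.unflatten` is the inverse operation. A rough plain-JAX equivalent (a sketch for illustration, not the brainunit implementation):

    import jax.numpy as jnp

    def unflatten(x, axis, sizes):
        # Expand one dimension into several, preserving the element count.
        axis = axis % x.ndim
        new_shape = x.shape[:axis] + tuple(sizes) + x.shape[axis + 1:]
        return jnp.reshape(x, new_shape)

    x = jnp.arange(12).reshape(3, 4)
    print(unflatten(x, 1, (2, 2)).shape)  # (3, 2, 2)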
- class _MaxPool(Module):
-     def __init__(
-         self,
-         init_value: float,
-         computation: Callable,
-         pool_dim: int,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         return_indices: bool = False,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(name=name)
-
-         self.init_value = init_value
-         self.computation = computation
-         self.pool_dim = pool_dim
-         self.return_indices = return_indices
-
-         # kernel_size
-         if isinstance(kernel_size, int):
-             kernel_size = (kernel_size,) * pool_dim
-         elif isinstance(kernel_size, Sequence):
-             assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
-             assert all(isinstance(x, int) for x in kernel_size), f'kernel_size should be a tuple of ints. {kernel_size}'
-             if len(kernel_size) != pool_dim:
-                 raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, but got {len(kernel_size)}')
-         else:
-             raise TypeError(f'kernel_size should be an int or a tuple with {pool_dim} ints.')
-         self.kernel_size = kernel_size
-
-         # stride
-         if stride is None:
-             stride = kernel_size
-         if isinstance(stride, int):
-             stride = (stride,) * pool_dim
-         elif isinstance(stride, Sequence):
-             assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
-             assert all(isinstance(x, int) for x in stride), f'stride should be a tuple of ints. {stride}'
-             if len(stride) != pool_dim:
-                 raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
-         else:
-             raise TypeError(f'stride should be an int or a tuple with {pool_dim} ints.')
-         self.stride = stride
-
-         # padding
-         if isinstance(padding, str):
-             if padding not in ("SAME", "VALID"):
-                 raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
-         elif isinstance(padding, int):
-             padding = [(padding, padding) for _ in range(pool_dim)]
-         elif isinstance(padding, (list, tuple)):
-             if isinstance(padding[0], int):
-                 if len(padding) == pool_dim:
-                     padding = [(x, x) for x in padding]
-                 else:
-                     raise ValueError(f'If padding is a sequence of ints, it '
-                                      f'should have a length of {pool_dim}.')
-             else:
-                 if not all(isinstance(x, (tuple, list)) for x in padding):
-                     raise ValueError(f'padding should be a sequence of Tuple[int, int]. {padding}')
-                 if not all(len(x) == 2 for x in padding):
-                     raise ValueError(f"Each entry in padding must be a tuple of 2 ints. {padding}")
-                 if len(padding) == 1:
-                     padding = tuple(padding) * pool_dim
-                 assert len(padding) == pool_dim, f'padding should have a length of {pool_dim}. {padding}'
-         else:
-             raise ValueError(f"padding should be 'SAME'/'VALID', an int, or a sequence of "
-                              f"(low, high) pairs, but got {padding}.")
-         self.padding = padding
-
-         # channel_axis
-         assert channel_axis is None or isinstance(channel_axis, int), \
-             f'channel_axis should be an int, but got {channel_axis}'
-         self.channel_axis = channel_axis
-
-         # in & out shapes
-         if in_size is not None:
-             in_size = tuple(in_size)
-             self.in_size = in_size
-             y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
-             self.out_size = y.shape[1:]
-
-     def update(self, x):
-         x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
-         if x.ndim < x_dim:
-             raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
-         window_shape = tuple(self._infer_shape(x.ndim, self.kernel_size, 1))
-         stride = tuple(self._infer_shape(x.ndim, self.stride, 1))
-         if isinstance(self.padding, str):
-             padding = tuple(jax.lax.padtype_to_pads(x.shape, window_shape, stride, self.padding))
-         else:
-             padding = tuple(self._infer_shape(x.ndim, self.padding, element=(0, 0)))
-
-         if self.return_indices:
-             # For returning indices, we need to use a custom implementation
-             return self._pooling_with_indices(x, window_shape, stride, padding)
-
-         return jax.lax.reduce_window(
-             x,
-             init_value=self.init_value,
-             computation=self.computation,
-             window_dimensions=window_shape,
-             window_strides=stride,
-             padding=padding
-         )
-
-     def _pooling_with_indices(self, x, window_shape, stride, padding):
-         """Perform max pooling and return both pooled values and indices."""
-         total_size = x.size
-         flat_indices = jnp.arange(total_size, dtype=jnp.int32).reshape(x.shape)
-
-         init_val = jnp.asarray(self.init_value, dtype=x.dtype)
-         init_idx = jnp.array(total_size, dtype=flat_indices.dtype)
-
-         def reducer(acc, operand):
-             acc_val, acc_idx = acc
-             cur_val, cur_idx = operand
-
-             better = cur_val > acc_val
-             best_val = jnp.where(better, cur_val, acc_val)
-             best_idx = jnp.where(better, cur_idx, acc_idx)
-             # Break ties in favor of the smaller (earlier) flat index.
-             tie = jnp.logical_and(cur_val == acc_val, cur_idx < acc_idx)
-             best_idx = jnp.where(tie, cur_idx, best_idx)
-
-             return best_val, best_idx
-
-         pooled, indices_result = jax.lax.reduce_window(
-             (x, flat_indices),
-             (init_val, init_idx),
-             reducer,
-             window_dimensions=window_shape,
-             window_strides=stride,
-             padding=padding
-         )
-
-         indices_result = jnp.where(indices_result == total_size, 0, indices_result)
-
-         return pooled, indices_result.astype(jnp.int32)
-
-     def _infer_shape(self, x_dim, inputs, element):
-         channel_axis = self.channel_axis
-         if channel_axis is not None and not 0 <= abs(channel_axis) < x_dim:
-             raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
-         if channel_axis is not None and channel_axis < 0:
-             channel_axis = x_dim + channel_axis
-         all_dims = list(range(x_dim))
-         if channel_axis is not None:
-             all_dims.pop(channel_axis)
-         pool_dims = all_dims[-self.pool_dim:]
-         results = [element] * x_dim
-         for i, dim in enumerate(pool_dims):
-             results[dim] = inputs[i]
-         return results
-
-
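Everything in `_MaxPool.update` funnels into `jax.lax.reduce_window`; `_infer_shape` merely pads the per-axis window and stride out to the full rank so batch and channel axes are scanned with window 1. A standalone sketch for an `(N, L, C)` input:

    import jax
    import jax.numpy as jnp

    x = jnp.arange(24, dtype=jnp.float32).reshape(2, 6, 2)  # (N, L, C)

    pooled = jax.lax.reduce_window(
        x,
        init_value=-jnp.inf,          # identity element for max
        computation=jax.lax.max,
        window_dimensions=(1, 3, 1),  # pool only the length axis
        window_strides=(1, 2, 1),
        padding='VALID',
    )
    print(pooled.shape)  # (2, 2, 2): floor((6 - 3) / 2) + 1 = 2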
- class _AvgPool(_MaxPool):
-     def update(self, x):
-         x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
-         if x.ndim < x_dim:
-             raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
-         dims = self._infer_shape(x.ndim, self.kernel_size, 1)
-         stride = self._infer_shape(x.ndim, self.stride, 1)
-         padding = (self.padding if isinstance(self.padding, str) else
-                    self._infer_shape(x.ndim, self.padding, element=(0, 0)))
-         pooled = jax.lax.reduce_window(x,
-                                        init_value=self.init_value,
-                                        computation=self.computation,
-                                        window_dimensions=dims,
-                                        window_strides=stride,
-                                        padding=padding)
-         if padding == "VALID":
-             # Avoid the extra reduce_window.
-             return pooled / np.prod(dims)
-         else:
-             # Count the number of valid entries at each output point, then use that
-             # for computing the average. Assumes that any two arrays of the same
-             # shape will be padded the same.
-             window_counts = jax.lax.reduce_window(jnp.ones_like(x),
-                                                   init_value=self.init_value,
-                                                   computation=self.computation,
-                                                   window_dimensions=dims,
-                                                   window_strides=stride,
-                                                   padding=padding)
-             assert pooled.shape == window_counts.shape
-             return pooled / window_counts
-
-
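The non-`'VALID'` branch of `_AvgPool.update` divides by a per-position count of contributing inputs rather than by the full window size, so zero padding does not drag down averages at the edges. A small check of that idea:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((1, 4, 1))
    args = dict(window_dimensions=(1, 3, 1), window_strides=(1, 1, 1), padding='SAME')
    summed = jax.lax.reduce_window(x, 0., jax.lax.add, **args)
    counts = jax.lax.reduce_window(jnp.ones_like(x), 0., jax.lax.add, **args)
    print((summed / counts).ravel())  # [1. 1. 1. 1.]  -- edge-corrected average
    print((summed / 3.0).ravel())     # [0.667 1. 1. 0.667]  -- naive division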
- class MaxPool1d(_MaxPool):
-     r"""Applies a 1D max pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
-     and output :math:`(N, L_{out}, C)` can be precisely described as:
-
-     .. math::
-         out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
-                 input(N_i, stride \times k + m, C_j)
-
-     If :attr:`padding` is non-zero, then the input is implicitly padded with negative
-     infinity on both sides for :attr:`padding` number of points. This `link`_ has a nice
-     visualization of the pooling parameters.
-
-     Shape:
-         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
-         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
-
-           .. math::
-               L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} -
-                   \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
-
-     Parameters
-     ----------
-     kernel_size : int or sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride : int or sequence of int, optional
-         An integer, or a sequence of integers, representing the inter-window stride.
-         Default: kernel_size
-     padding : str, int or sequence of tuple, optional
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension. Default: 'VALID'
-     channel_axis : int, optional
-         Axis of the channel dimension, which is excluded from pooling.
-         If ``None``, there is no channel axis. Default: -1
-     return_indices : bool, optional
-         If True, will return the max indices along with the outputs.
-         Useful for MaxUnpool1d. Default: False
-     name : str, optional
-         The object name.
-     in_size : Sequence of int, optional
-         The shape of the input tensor.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> # pool of size=3, stride=2
-         >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
-         >>> input = brainstate.random.randn(20, 50, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 16)
-
-     .. _link:
-         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         return_indices: bool = False,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(in_size=in_size,
-                          init_value=-jax.numpy.inf,
-                          computation=jax.lax.max,
-                          pool_dim=1,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          channel_axis=channel_axis,
-                          return_indices=return_indices,
-                          name=name)
-
-
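With `return_indices=True`, `update` returns a `(pooled, indices)` pair whose indices are offsets into the flattened input, which is exactly what the `MaxUnpool*` layers consume. A usage sketch (assuming these classes are re-exported from `brainstate.nn`, as `__module__` suggests):

    import brainstate

    pool = brainstate.nn.MaxPool1d(2, stride=2, return_indices=True, channel_axis=-1)
    x = brainstate.random.randn(4, 10, 3)   # (N, L, C)
    pooled, indices = pool(x)
    print(pooled.shape, indices.shape)      # (4, 5, 3) (4, 5, 3)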
- class MaxPool2d(_MaxPool):
-     r"""Applies a 2D max pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
-     output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
-     can be precisely described as:
-
-     .. math::
-         \begin{aligned}
-             out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
-                                     & \text{input}(N_i, \text{stride}[0] \times h + m,
-                                       \text{stride}[1] \times w + n, C_j)
-         \end{aligned}
-
-     If :attr:`padding` is non-zero, then the input is implicitly padded with negative
-     infinity on both sides for :attr:`padding` number of points. This `link`_ has a nice
-     visualization of the pooling parameters.
-
-     Shape:
-         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
-         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
-
-           .. math::
-               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
-                   \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
-
-           .. math::
-               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
-                   \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
-
-     Parameters
-     ----------
-     kernel_size : int or sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride : int or sequence of int, optional
-         An integer, or a sequence of integers, representing the inter-window stride.
-         Default: kernel_size
-     padding : str, int or sequence of tuple, optional
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension. Default: 'VALID'
-     channel_axis : int, optional
-         Axis of the channel dimension, which is excluded from pooling.
-         If ``None``, there is no channel axis. Default: -1
-     return_indices : bool, optional
-         If True, will return the max indices along with the outputs.
-         Useful for MaxUnpool2d. Default: False
-     name : str, optional
-         The object name.
-     in_size : Sequence of int, optional
-         The shape of the input tensor.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> # pool of square window of size=3, stride=2
-         >>> m = MaxPool2d(3, stride=2)
-         >>> # pool of non-square window
-         >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
-         >>> input = brainstate.random.randn(20, 50, 32, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 31, 16)
-
-     .. _link:
-         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         return_indices: bool = False,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(in_size=in_size,
-                          init_value=-jax.numpy.inf,
-                          computation=jax.lax.max,
-                          pool_dim=2,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          channel_axis=channel_axis,
-                          return_indices=return_indices,
-                          name=name)
-
-
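The docstring shapes follow directly from the VALID-padding formula above; for `MaxPool2d((3, 2), stride=(2, 1))` on a `(20, 50, 32, 16)` input:

    # VALID padding: out = floor((in - kernel) / stride) + 1
    H_out = (50 - 3) // 2 + 1   # 24
    W_out = (32 - 2) // 1 + 1   # 31
    # hence the documented output shape (20, 24, 31, 16)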
- class MaxPool3d(_MaxPool):
-     r"""Applies a 3D max pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
-     output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
-     can be precisely described as:
-
-     .. math::
-         \begin{aligned}
-             \text{out}(N_i, d, h, w, C_j) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
-                                               & \text{input}(N_i, \text{stride}[0] \times d + k,
-                                                 \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)
-         \end{aligned}
-
-     If :attr:`padding` is non-zero, then the input is implicitly padded with negative
-     infinity on both sides for :attr:`padding` number of points. This `link`_ has a nice
-     visualization of the pooling parameters.
-
-     Shape:
-         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
-         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
-
-           .. math::
-               D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
-                   \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
-
-           .. math::
-               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
-                   \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
-
-           .. math::
-               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
-                   \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
-
-     Parameters
-     ----------
-     kernel_size : int or sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride : int or sequence of int, optional
-         An integer, or a sequence of integers, representing the inter-window stride.
-         Default: kernel_size
-     padding : str, int or sequence of tuple, optional
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension. Default: 'VALID'
-     channel_axis : int, optional
-         Axis of the channel dimension, which is excluded from pooling.
-         If ``None``, there is no channel axis. Default: -1
-     return_indices : bool, optional
-         If True, will return the max indices along with the outputs.
-         Useful for MaxUnpool3d. Default: False
-     name : str, optional
-         The object name.
-     in_size : Sequence of int, optional
-         The shape of the input tensor.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> # pool of square window of size=3, stride=2
-         >>> m = MaxPool3d(3, stride=2)
-         >>> # pool of non-square window
-         >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
-         >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 43, 15, 16)
-
-     .. _link:
-         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         return_indices: bool = False,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(in_size=in_size,
-                          init_value=-jax.numpy.inf,
-                          computation=jax.lax.max,
-                          pool_dim=3,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          channel_axis=channel_axis,
-                          return_indices=return_indices,
-                          name=name)
-
-
- class _MaxUnpool(Module):
-     """Base class for max unpooling operations."""
-
-     def __init__(
-         self,
-         pool_dim: int,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[int, Tuple[int, ...]] = 0,
-         channel_axis: Optional[int] = -1,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(name=name)
-
-         self.pool_dim = pool_dim
-
-         # kernel_size
-         if isinstance(kernel_size, int):
-             kernel_size = (kernel_size,) * pool_dim
-         elif isinstance(kernel_size, Sequence):
-             assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
-             assert all(isinstance(x, int) for x in kernel_size), f'kernel_size should be a tuple of ints. {kernel_size}'
-             if len(kernel_size) != pool_dim:
-                 raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, but got {len(kernel_size)}')
-         else:
-             raise TypeError(f'kernel_size should be an int or a tuple with {pool_dim} ints.')
-         self.kernel_size = kernel_size
-
-         # stride
-         if stride is None:
-             stride = kernel_size
-         if isinstance(stride, int):
-             stride = (stride,) * pool_dim
-         elif isinstance(stride, Sequence):
-             assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
-             assert all(isinstance(x, int) for x in stride), f'stride should be a tuple of ints. {stride}'
-             if len(stride) != pool_dim:
-                 raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
-         else:
-             raise TypeError(f'stride should be an int or a tuple with {pool_dim} ints.')
-         self.stride = stride
-
-         # padding
-         if isinstance(padding, int):
-             padding = (padding,) * pool_dim
-         elif isinstance(padding, (tuple, list)):
-             if len(padding) != pool_dim:
-                 raise ValueError(f'padding should have {pool_dim} values, but got {len(padding)}')
-         else:
-             raise TypeError(f'padding should be an int or a tuple of {pool_dim} ints.')
-         self.padding = padding
-
-         # channel_axis
-         assert channel_axis is None or isinstance(channel_axis, int), \
-             f'channel_axis should be an int, but got {channel_axis}'
-         self.channel_axis = channel_axis
-
-         # in & out shapes
-         if in_size is not None:
-             in_size = tuple(in_size)
-             self.in_size = in_size
-
-     def _compute_output_shape(self, input_shape, output_size=None):
-         """Compute the output shape after unpooling."""
-         if output_size is not None:
-             return output_size
-
-         # Calculate output shape based on kernel, stride, and padding
-         output_shape = []
-         for i in range(self.pool_dim):
-             dim_size = (input_shape[i] - 1) * self.stride[i] - 2 * self.padding[i] + self.kernel_size[i]
-             output_shape.append(dim_size)
-
-         return tuple(output_shape)
-
-     def _unpool_nd(self, x, indices, output_size=None):
-         """Perform N-dimensional max unpooling."""
-         x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
-         if x.ndim < x_dim:
-             raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
-
-         # Determine output shape
-         if output_size is None:
-             # Infer output shape from input shape
-             spatial_dims = self._get_spatial_dims(x.shape)
-             output_spatial_shape = self._compute_output_shape(spatial_dims, output_size)
-             output_shape = list(x.shape)
-
-             # Update spatial dimensions in output shape
-             spatial_start = self._get_spatial_start_idx(x.ndim)
-             for i, size in enumerate(output_spatial_shape):
-                 output_shape[spatial_start + i] = size
-             output_shape = tuple(output_shape)
-         else:
-             # Use provided output size
-             if isinstance(output_size, (list, tuple)):
-                 if len(output_size) == x.ndim:
-                     # Full output shape provided
-                     output_shape = tuple(output_size)
-                 else:
-                     # Only spatial dimensions provided
-                     if len(output_size) != self.pool_dim:
-                         raise ValueError(f"output_size must have {self.pool_dim} spatial dimensions, got {len(output_size)}")
-                     output_shape = list(x.shape)
-                     spatial_start = self._get_spatial_start_idx(x.ndim)
-                     for i, size in enumerate(output_size):
-                         output_shape[spatial_start + i] = size
-                     output_shape = tuple(output_shape)
-             else:
-                 # Single integer provided, use for all spatial dims
-                 output_shape = list(x.shape)
-                 spatial_start = self._get_spatial_start_idx(x.ndim)
-                 for i in range(self.pool_dim):
-                     output_shape[spatial_start + i] = output_size
-                 output_shape = tuple(output_shape)
-
-         # Create output array filled with zeros
-         output = jnp.zeros(output_shape, dtype=x.dtype)
-
-         # Scatter the pooled values back to their argmax positions via the
-         # flat indices recorded by _pooling_with_indices.
-         flat_indices = indices.ravel()
-         flat_values = x.ravel()
-         flat_output = output.ravel()
-
-         # Scatter the values
-         flat_output = flat_output.at[flat_indices].set(flat_values)
-
-         # Reshape back to original shape
-         output = flat_output.reshape(output_shape)
-
-         return output
-
-     def _get_spatial_dims(self, shape):
-         """Extract spatial dimensions from input shape."""
-         if self.channel_axis is None:
-             return shape[-self.pool_dim:]
-         else:
-             channel_axis = self.channel_axis if self.channel_axis >= 0 else len(shape) + self.channel_axis
-             all_dims = list(range(len(shape)))
-             all_dims.pop(channel_axis)
-             return tuple(shape[i] for i in all_dims[-self.pool_dim:])
-
-     def _get_spatial_start_idx(self, ndim):
-         """Get the starting index of spatial dimensions."""
-         if self.channel_axis is None:
-             return ndim - self.pool_dim
-         else:
-             channel_axis = self.channel_axis if self.channel_axis >= 0 else ndim + self.channel_axis
-             if channel_axis < ndim - self.pool_dim:
-                 return ndim - self.pool_dim
-             else:
-                 return ndim - self.pool_dim - 1
-
-     def update(self, x, indices, output_size=None):
-         """Forward pass of max unpooling.
-
-         Parameters
-         ----------
-         x : Array
-             Input tensor from the corresponding max pooling layer.
-         indices : Array
-             Indices of maximum values from the corresponding max pooling layer.
-         output_size : int or tuple, optional
-             The targeted output size.
-
-         Returns
-         -------
-         Array
-             Unpooled output.
-         """
-         return self._unpool_nd(x, indices, output_size)
-
-
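`_unpool_nd` reduces unpooling to one flat scatter: allocate zeros of the target shape, then write each pooled value at its remembered flat offset. The core step in isolation (a sketch; real values and indices come from `_pooling_with_indices`):

    import jax.numpy as jnp

    values = jnp.array([5., 7.])    # pooled maxima
    indices = jnp.array([1, 2])     # flat offsets into the unpooled array
    out = jnp.zeros(4).at[indices].set(values)
    print(out)  # [0. 5. 7. 0.]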
- class MaxUnpool1d(_MaxUnpool):
-     r"""Computes a partial inverse of MaxPool1d.
-
-     MaxPool1d is not fully invertible, since the non-maximal values are lost.
-     MaxUnpool1d takes in as input the output of MaxPool1d including the indices
-     of the maximal values and computes a partial inverse in which all
-     non-maximal values are set to zero.
-
-     Shape:
-         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`
-         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
-
-           .. math::
-               L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel\_size}
-
-           or as given by :attr:`output_size` in the call operator
-
-     Parameters
-     ----------
-     kernel_size : int or tuple
-         Size of the max pooling window.
-     stride : int or tuple, optional
-         Stride of the max pooling window. Default: kernel_size
-     padding : int or tuple, optional
-         Padding that was added to the input. Default: 0
-     channel_axis : int, optional
-         Axis of the channels. Default: -1
-     name : str, optional
-         Name of the module.
-     in_size : Size, optional
-         Input size for shape inference.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> import jax.numpy as jnp
-         >>> # Create pooling and unpooling layers
-         >>> pool = MaxPool1d(2, stride=2, return_indices=True, channel_axis=-1)
-         >>> unpool = MaxUnpool1d(2, stride=2, channel_axis=-1)
-         >>> input = brainstate.random.randn(20, 50, 16)
-         >>> output, indices = pool(input)
-         >>> unpooled = unpool(output, indices)
-         >>> # unpooled will have shape (20, 50, 16) with zeros at non-maximal positions
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[int, Tuple[int, ...]] = 0,
-         channel_axis: Optional[int] = -1,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(
-             pool_dim=1,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             channel_axis=channel_axis,
-             name=name,
-             in_size=in_size
-         )
-
-
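A pool/unpool round trip with the 1D pair (again assuming re-export from `brainstate.nn`): the restored array has the original input's shape, with zeros everywhere except the argmax positions.

    import brainstate

    pool = brainstate.nn.MaxPool1d(2, stride=2, return_indices=True, channel_axis=-1)
    unpool = brainstate.nn.MaxUnpool1d(2, stride=2, channel_axis=-1)

    x = brainstate.random.randn(2, 8, 3)
    pooled, idx = pool(x)            # (2, 4, 3)
    restored = unpool(pooled, idx)   # (4 - 1) * 2 - 0 + 2 = 8
    print(restored.shape)            # (2, 8, 3)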
- class MaxUnpool2d(_MaxUnpool):
-     r"""Computes a partial inverse of MaxPool2d.
-
-     MaxPool2d is not fully invertible, since the non-maximal values are lost.
-     MaxUnpool2d takes in as input the output of MaxPool2d including the indices
-     of the maximal values and computes a partial inverse in which all
-     non-maximal values are set to zero.
-
-     Shape:
-         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
-         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
-
-           .. math::
-               H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]
-
-           .. math::
-               W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]
-
-           or as given by :attr:`output_size` in the call operator
-
-     Parameters
-     ----------
-     kernel_size : int or tuple
-         Size of the max pooling window.
-     stride : int or tuple, optional
-         Stride of the max pooling window. Default: kernel_size
-     padding : int or tuple, optional
-         Padding that was added to the input. Default: 0
-     channel_axis : int, optional
-         Axis of the channels. Default: -1
-     name : str, optional
-         Name of the module.
-     in_size : Size, optional
-         Input size for shape inference.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> # Create pooling and unpooling layers
-         >>> pool = MaxPool2d(2, stride=2, return_indices=True, channel_axis=-1)
-         >>> unpool = MaxUnpool2d(2, stride=2, channel_axis=-1)
-         >>> input = brainstate.random.randn(1, 4, 4, 16)
-         >>> output, indices = pool(input)
-         >>> unpooled = unpool(output, indices)
-         >>> # unpooled will have shape (1, 4, 4, 16) with zeros at non-maximal positions
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[int, Tuple[int, ...]] = 0,
-         channel_axis: Optional[int] = -1,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(
-             pool_dim=2,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             channel_axis=channel_axis,
-             name=name,
-             in_size=in_size
-         )
-
-
- class MaxUnpool3d(_MaxUnpool):
-     r"""Computes a partial inverse of MaxPool3d.
-
-     MaxPool3d is not fully invertible, since the non-maximal values are lost.
-     MaxUnpool3d takes in as input the output of MaxPool3d including the indices
-     of the maximal values and computes a partial inverse in which all
-     non-maximal values are set to zero.
-
-     Shape:
-         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`
-         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
-
-           .. math::
-               D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]
-
-           .. math::
-               H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]
-
-           .. math::
-               W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{kernel\_size}[2]
-
-           or as given by :attr:`output_size` in the call operator
-
-     Parameters
-     ----------
-     kernel_size : int or tuple
-         Size of the max pooling window.
-     stride : int or tuple, optional
-         Stride of the max pooling window. Default: kernel_size
-     padding : int or tuple, optional
-         Padding that was added to the input. Default: 0
-     channel_axis : int, optional
-         Axis of the channels. Default: -1
-     name : str, optional
-         Name of the module.
-     in_size : Size, optional
-         Input size for shape inference.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> # Create pooling and unpooling layers
-         >>> pool = MaxPool3d(2, stride=2, return_indices=True, channel_axis=-1)
-         >>> unpool = MaxUnpool3d(2, stride=2, channel_axis=-1)
-         >>> input = brainstate.random.randn(1, 4, 4, 4, 16)
-         >>> output, indices = pool(input)
-         >>> unpooled = unpool(output, indices)
-         >>> # unpooled will have shape (1, 4, 4, 4, 16) with zeros at non-maximal positions
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[int, Tuple[int, ...]] = 0,
-         channel_axis: Optional[int] = -1,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(
-             pool_dim=3,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             channel_axis=channel_axis,
-             name=name,
-             in_size=in_size
-         )
-
-
- class AvgPool1d(_AvgPool):
-     r"""Applies a 1D average pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, L, C)`,
-     output :math:`(N, L_{out}, C)` and :attr:`kernel_size` :math:`k`
-     can be precisely described as:
-
-     .. math::
-
-         \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
-                 \text{input}(N_i, \text{stride} \times l + m, C_j)
-
-     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
-     for :attr:`padding` number of points.
-
-     Shape:
-         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
-         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
-
-           .. math::
-               L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} -
-                   \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
-
-     Parameters
-     ----------
-     kernel_size : int or sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride : int or sequence of int, optional
-         An integer, or a sequence of integers, representing the inter-window stride.
-         Default: 1
-     padding : str, int or sequence of tuple, optional
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension. Default: 'VALID'
-     channel_axis : int, optional
-         Axis of the channel dimension, which is excluded from pooling.
-         If ``None``, there is no channel axis. Default: -1
-     name : str, optional
-         The object name.
-     in_size : Sequence of int, optional
-         The shape of the input tensor.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> # pool with window of size=3, stride=2
-         >>> m = AvgPool1d(3, stride=2)
-         >>> input = brainstate.random.randn(20, 50, 16)
-         >>> m(input).shape
-         (20, 24, 16)
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = 1,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(
-             in_size=in_size,
-             init_value=0.,
-             computation=jax.lax.add,
-             pool_dim=1,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             channel_axis=channel_axis,
-             name=name
-         )
-
-
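For the average pools, the two string paddings differ only in output length: `'VALID'` gives `floor((L - k) / s) + 1`, while `'SAME'` gives `ceil(L / s)`. For `AvgPool1d(3, stride=2)` on a length-50 axis:

    import math
    L, k, s = 50, 3, 2
    print((L - k) // s + 1)   # VALID: 24, as in the docstring example
    print(math.ceil(L / s))   # SAME: 25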
- class AvgPool2d(_AvgPool):
-     r"""Applies a 2D average pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
-     output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
-     can be precisely described as:
-
-     .. math::
-
-         out(N_i, h, w, C_j) = \frac{1}{kH \times kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
-                 input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)
-
-     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
-     for :attr:`padding` number of points.
-
-     Shape:
-         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
-         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
-
-           .. math::
-               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
-                   \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
-
-           .. math::
-               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
-                   \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
-
-     Parameters
-     ----------
-     kernel_size : int or sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride : int or sequence of int, optional
-         An integer, or a sequence of integers, representing the inter-window stride.
-         Default: 1
-     padding : str, int or sequence of tuple, optional
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension. Default: 'VALID'
-     channel_axis : int, optional
-         Axis of the channel dimension, which is excluded from pooling.
-         If ``None``, there is no channel axis. Default: -1
-     name : str, optional
-         The object name.
-     in_size : Sequence of int, optional
-         The shape of the input tensor.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> # pool of square window of size=3, stride=2
-         >>> m = AvgPool2d(3, stride=2)
-         >>> # pool of non-square window
-         >>> m = AvgPool2d((3, 2), stride=(2, 1))
-         >>> input = brainstate.random.randn(20, 50, 32, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 31, 16)
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = 1,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(
-             in_size=in_size,
-             init_value=0.,
-             computation=jax.lax.add,
-             pool_dim=2,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             channel_axis=channel_axis,
-             name=name
-         )
-
-
- class AvgPool3d(_AvgPool):
-     r"""Applies a 3D average pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
-     output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
-     can be precisely described as:
-
-     .. math::
-         \begin{aligned}
-             \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
-                                               & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
-                                                 \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
-                                                 {kD \times kH \times kW}
-         \end{aligned}
-
-     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
-     for :attr:`padding` number of points.
-
-     Shape:
-         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
-         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
-           :math:`(D_{out}, H_{out}, W_{out}, C)`, where
-
-           .. math::
-               D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
-                   \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
-
-           .. math::
-               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
-                   \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
-
-           .. math::
-               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
-                   \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
-
-     Parameters
-     ----------
-     kernel_size : int or sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride : int or sequence of int, optional
-         An integer, or a sequence of integers, representing the inter-window stride.
-         Default: 1
-     padding : str, int or sequence of tuple, optional
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension. Default: 'VALID'
-     channel_axis : int, optional
-         Axis of the channel dimension, which is excluded from pooling.
-         If ``None``, there is no channel axis. Default: -1
-     name : str, optional
-         The object name.
-     in_size : Sequence of int, optional
-         The shape of the input tensor.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> # pool of square window of size=3, stride=2
-         >>> m = AvgPool3d(3, stride=2)
-         >>> # pool of non-square window
-         >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
-         >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 43, 15, 16)
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = 1,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(
-             in_size=in_size,
-             init_value=0.,
-             computation=jax.lax.add,
-             pool_dim=3,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             channel_axis=channel_axis,
-             name=name
-         )
-
-
1312
- class _LPPool(Module):
1313
- """Base class for Lp pooling operations."""
1314
-
1315
- def __init__(
1316
- self,
1317
- norm_type: float,
1318
- pool_dim: int,
1319
- kernel_size: Size,
1320
- stride: Union[int, Sequence[int]] = None,
1321
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
1322
- channel_axis: Optional[int] = -1,
1323
- name: Optional[str] = None,
1324
- in_size: Optional[Size] = None,
1325
- ):
1326
- super().__init__(name=name)
1327
-
1328
- if norm_type <= 0:
1329
- raise ValueError(f"norm_type must be positive, got {norm_type}")
1330
- self.norm_type = norm_type
1331
- self.pool_dim = pool_dim
1332
-
1333
- # kernel_size
1334
- if isinstance(kernel_size, int):
1335
- kernel_size = (kernel_size,) * pool_dim
1336
- elif isinstance(kernel_size, Sequence):
1337
- assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
1338
- assert all(
1339
- [isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
1340
- if len(kernel_size) != pool_dim:
1341
- raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
1342
- else:
1343
- raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
1344
- self.kernel_size = kernel_size
1345
-
1346
- # stride
1347
- if stride is None:
1348
- stride = kernel_size
1349
- if isinstance(stride, int):
1350
- stride = (stride,) * pool_dim
1351
- elif isinstance(stride, Sequence):
1352
- assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
1353
- assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
1354
- if len(stride) != pool_dim:
1355
- raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
1356
- else:
1357
- raise TypeError(f'stride should be an int or a tuple with {pool_dim} ints.')
1358
- self.stride = stride
1359
-
1360
- # padding
1361
- if isinstance(padding, str):
1362
- if padding not in ("SAME", "VALID"):
1363
- raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
1364
- elif isinstance(padding, int):
1365
- padding = [(padding, padding) for _ in range(pool_dim)]
1366
- elif isinstance(padding, (list, tuple)):
1367
- if isinstance(padding[0], int):
1368
- if len(padding) == pool_dim:
1369
- padding = [(x, x) for x in padding]
1370
- else:
1371
- raise ValueError(f'If padding is a sequence of ints, it '
1372
- f'should have a length of {pool_dim}.')
1373
- else:
1374
- if not all([isinstance(x, (tuple, list)) for x in padding]):
1375
- raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
1376
- if not all([len(x) == 2 for x in padding]):
1377
- raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
1378
- if len(padding) == 1:
1379
- padding = tuple(padding) * pool_dim
1380
- assert len(padding) == pool_dim, f'padding should have a length of {pool_dim}. {padding}'
1381
- else:
1382
- raise ValueError(f'Invalid padding: {padding}')
1383
- self.padding = padding
1384
-
1385
- # channel_axis
1386
- assert channel_axis is None or isinstance(channel_axis, int), \
1387
- f'channel_axis should be an int, but got {channel_axis}'
1388
- self.channel_axis = channel_axis
1389
-
1390
- # in & out shapes
1391
- if in_size is not None:
1392
- in_size = tuple(in_size)
1393
- self.in_size = in_size
1394
- y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
1395
- self.out_size = y.shape[1:]
1396
-
1397
- def update(self, x):
1398
- x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
1399
- if x.ndim < x_dim:
1400
- raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
1401
-
1402
- window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
1403
- stride = self._infer_shape(x.ndim, self.stride, 1)
1404
- padding = (self.padding if isinstance(self.padding, str) else
1405
- self._infer_shape(x.ndim, self.padding, element=(0, 0)))
1406
-
1407
- # For Lp pooling, we need to:
1408
- # 1. Take absolute value and raise to power p
1409
- # 2. Sum over the window
1410
- # 3. Take the p-th root
1411
-
1412
- # Step 1: |x|^p
1413
- x_pow = jnp.abs(x) ** self.norm_type
1414
-
1415
- # Step 2: Sum over window
1416
- pooled_sum = jax.lax.reduce_window(
1417
- x_pow,
1418
- init_value=0.,
1419
- computation=jax.lax.add,
1420
- window_dimensions=window_shape,
1421
- window_strides=stride,
1422
- padding=padding
1423
- )
1424
-
1425
- # Step 3: Take p-th root and multiply by normalization factor
1426
- # The normalization factor is (1/N)^(1/p) where N is the window size
1427
- window_size = np.prod(self.kernel_size)
1428
- norm_factor = window_size ** (-1.0 / self.norm_type)
1429
- result = norm_factor * (pooled_sum ** (1.0 / self.norm_type))
1430
-
1431
- return result
1432
-
1433
- def _infer_shape(self, x_dim, inputs, element):
1434
- channel_axis = self.channel_axis
1435
- if channel_axis is not None and not 0 <= abs(channel_axis) < x_dim:
1436
- raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
1437
- if channel_axis is not None and channel_axis < 0:
1438
- channel_axis = x_dim + channel_axis
1439
- all_dims = list(range(x_dim))
1440
- if channel_axis is not None:
1441
- all_dims.pop(channel_axis)
1442
- pool_dims = all_dims[-self.pool_dim:]
1443
- results = [element] * x_dim
1444
- for i, dim in enumerate(pool_dims):
1445
- results[dim] = inputs[i]
1446
- return results
1447
-
1448
-
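A minimal sketch (not part of the package) of the three-step Lp arithmetic in ``_LPPool.update``, assuming a float input, no channel axis, and the default stride equal to ``kernel_size``; at ``norm_type=2`` the normalized result is exactly the window-wise root mean square, matching the ``(1/N)^(1/p)`` factor above:

.. code-block:: python

    import jax.numpy as jnp

    x = jnp.arange(6.0)                 # one spatial axis, no channel axis
    p, k = 2.0, 3                       # norm_type and kernel_size
    windows = x.reshape(-1, k)          # non-overlapping windows (stride == kernel_size)

    # steps 1-3 above: |x|^p -> window sum -> p-th root, scaled by (1/N)^(1/p)
    pooled = (jnp.abs(windows) ** p).sum(axis=1) ** (1.0 / p)
    result = (k ** (-1.0 / p)) * pooled

    rms = jnp.sqrt((windows ** 2).mean(axis=1))
    assert jnp.allclose(result, rms)    # p = 2 gives RMS pooling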
1449
- class LPPool1d(_LPPool):
1450
- r"""Applies a 1D power-average pooling over an input signal composed of several input planes.
1451
-
1452
- On each window, the function computed is:
1453
-
1454
- .. math::
1455
- f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}}
1456
-
1457
- - At :math:`p = \infty`, one gets max pooling
1458
- - At :math:`p = 1`, one gets average pooling (with absolute values)
1459
- - At :math:`p = 2`, one gets root mean square (RMS) pooling
1460
-
1461
- Shape:
1462
- - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
1463
- - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
1464
-
1465
- .. math::
1466
- L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
1467
-
1468
- Parameters
1469
- ----------
1470
- norm_type : float
1471
- Exponent for the pooling operation. Default: 2.0
1472
- kernel_size : int or sequence of int
1473
- An integer, or a sequence of integers defining the window to reduce over.
1474
- stride : int or sequence of int, optional
1475
- An integer, or a sequence of integers, representing the inter-window stride.
1476
- Default: kernel_size
1477
- padding : str, int or sequence of tuple, optional
1478
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
1479
- of n `(low, high)` integer pairs that give the padding to apply before
1480
- and after each spatial dimension. Default: 'VALID'
1481
- channel_axis : int, optional
1482
- Axis of the spatial channels for which pooling is skipped.
1483
- If ``None``, there is no channel axis. Default: -1
1484
- name : str, optional
1485
- The object name.
1486
- in_size : Sequence of int, optional
1487
- The shape of the input tensor.
1488
-
1489
- Examples
1490
- --------
1491
- .. code-block:: python
1492
-
1493
- >>> import brainstate
1494
- >>> # power-average pooling of window of size=3, stride=2 with norm_type=2.0
1495
- >>> m = LPPool1d(2, 3, stride=2)
1496
- >>> input = brainstate.random.randn(20, 50, 16)
1497
- >>> output = m(input)
1498
- >>> output.shape
1499
- (20, 24, 16)
1500
- """
1501
- __module__ = 'brainstate.nn'
1502
-
1503
- def __init__(
1504
- self,
1505
- norm_type: float,
1506
- kernel_size: Size,
1507
- stride: Union[int, Sequence[int]] = None,
1508
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
1509
- channel_axis: Optional[int] = -1,
1510
- name: Optional[str] = None,
1511
- in_size: Optional[Size] = None,
1512
- ):
1513
- super().__init__(
1514
- norm_type=norm_type,
1515
- pool_dim=1,
1516
- kernel_size=kernel_size,
1517
- stride=stride,
1518
- padding=padding,
1519
- channel_axis=channel_axis,
1520
- name=name,
1521
- in_size=in_size
1522
- )
1523
-
1524
-
1525
- class LPPool2d(_LPPool):
1526
- r"""Applies a 2D power-average pooling over an input signal composed of several input planes.
1527
-
1528
- On each window, the function computed is:
1529
-
1530
- .. math::
1531
- f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}}
1532
-
1533
- - At :math:`p = \infty`, one gets max pooling
1534
- - At :math:`p = 1`, one gets average pooling (with absolute values)
1535
- - At :math:`p = 2`, one gets root mean square (RMS) pooling
1536
-
1537
- Shape:
1538
- - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
1539
- - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
1540
-
1541
- .. math::
1542
- H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
1543
-
1544
- .. math::
1545
- W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
1546
-
1547
- Parameters
1548
- ----------
1549
- norm_type : float
1550
- Exponent for the pooling operation. Default: 2.0
1551
- kernel_size : int or sequence of int
1552
- An integer, or a sequence of integers defining the window to reduce over.
1553
- stride : int or sequence of int, optional
1554
- An integer, or a sequence of integers, representing the inter-window stride.
1555
- Default: kernel_size
1556
- padding : str, int or sequence of tuple, optional
1557
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
1558
- of n `(low, high)` integer pairs that give the padding to apply before
1559
- and after each spatial dimension. Default: 'VALID'
1560
- channel_axis : int, optional
1561
- Axis of the spatial channels for which pooling is skipped.
1562
- If ``None``, there is no channel axis. Default: -1
1563
- name : str, optional
1564
- The object name.
1565
- in_size : Sequence of int, optional
1566
- The shape of the input tensor.
1567
-
1568
- Examples
1569
- --------
1570
- .. code-block:: python
1571
-
1572
- >>> import brainstate
1573
- >>> # power-average pooling of square window of size=3, stride=2
1574
- >>> m = LPPool2d(2, 3, stride=2)
1575
- >>> # pool of non-square window with norm_type=1.5
1576
- >>> m = LPPool2d(1.5, (3, 2), stride=(2, 1), channel_axis=-1)
1577
- >>> input = brainstate.random.randn(20, 50, 32, 16)
1578
- >>> output = m(input)
1579
- >>> output.shape
1580
- (20, 24, 31, 16)
1581
- """
1582
- __module__ = 'brainstate.nn'
1583
-
1584
- def __init__(
1585
- self,
1586
- norm_type: float,
1587
- kernel_size: Size,
1588
- stride: Union[int, Sequence[int]] = None,
1589
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
1590
- channel_axis: Optional[int] = -1,
1591
- name: Optional[str] = None,
1592
- in_size: Optional[Size] = None,
1593
- ):
1594
- super().__init__(
1595
- norm_type=norm_type,
1596
- pool_dim=2,
1597
- kernel_size=kernel_size,
1598
- stride=stride,
1599
- padding=padding,
1600
- channel_axis=channel_axis,
1601
- name=name,
1602
- in_size=in_size
1603
- )
1604
-
1605
-
1606
- class LPPool3d(_LPPool):
1607
- r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
1608
-
1609
- On each window, the function computed is:
1610
-
1611
- .. math::
1612
- f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}}
1613
-
1614
- - At :math:`p = \infty`, one gets max pooling
1615
- - At :math:`p = 1`, one gets average pooling (with absolute values)
1616
- - At :math:`p = 2`, one gets root mean square (RMS) pooling
1617
-
1618
- Shape:
1619
- - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
1620
- - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
1621
-
1622
- .. math::
1623
- D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
1624
-
1625
- .. math::
1626
- H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
1627
-
1628
- .. math::
1629
- W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
1630
-
1631
- Parameters
1632
- ----------
1633
- norm_type : float
1634
- Exponent for the pooling operation. Default: 2.0
1635
- kernel_size : int or sequence of int
1636
- An integer, or a sequence of integers defining the window to reduce over.
1637
- stride : int or sequence of int, optional
1638
- An integer, or a sequence of integers, representing the inter-window stride.
1639
- Default: kernel_size
1640
- padding : str, int or sequence of tuple, optional
1641
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
1642
- of n `(low, high)` integer pairs that give the padding to apply before
1643
- and after each spatial dimension. Default: 'VALID'
1644
- channel_axis : int, optional
1645
- Axis of the spatial channels for which pooling is skipped.
1646
- If ``None``, there is no channel axis. Default: -1
1647
- name : str, optional
1648
- The object name.
1649
- in_size : Sequence of int, optional
1650
- The shape of the input tensor.
1651
-
1652
- Examples
1653
- --------
1654
- .. code-block:: python
1655
-
1656
- >>> import brainstate
1657
- >>> # power-average pooling of cube window of size=3, stride=2
1658
- >>> m = LPPool3d(2, 3, stride=2)
1659
- >>> # pool of non-cubic window with norm_type=1.5
1660
- >>> m = LPPool3d(1.5, (3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
1661
- >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
1662
- >>> output = m(input)
1663
- >>> output.shape
1664
- (20, 24, 43, 15, 16)
1665
- """
1666
- __module__ = 'brainstate.nn'
1667
-
1668
- def __init__(
1669
- self,
1670
- norm_type: float,
1671
- kernel_size: Size,
1672
- stride: Union[int, Sequence[int]] = None,
1673
- padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
1674
- channel_axis: Optional[int] = -1,
1675
- name: Optional[str] = None,
1676
- in_size: Optional[Size] = None,
1677
- ):
1678
- super().__init__(
1679
- norm_type=norm_type,
1680
- pool_dim=3,
1681
- kernel_size=kernel_size,
1682
- stride=stride,
1683
- padding=padding,
1684
- channel_axis=channel_axis,
1685
- name=name,
1686
- in_size=in_size
1687
- )
1688
-
1689
-
1690
- def _adaptive_pool1d(x, target_size: int, operation: Callable):
1691
- """Adaptive pool 1D.
1692
-
1693
- Args:
1694
- x: The input. Should be a JAX array of shape `(dim,)`.
1695
- target_size: The shape of the output after the pooling operation `(target_size,)`.
1696
- operation: The pooling operation to be performed on the input array.
1697
-
1698
- Returns:
1699
- A JAX array of shape `(target_size, )`.
1700
- """
1701
- size = jnp.size(x)
1702
- num_head_arrays = size % target_size
1703
- num_block = size // target_size
1704
- if num_head_arrays != 0:
1705
- head_end_index = num_head_arrays * (num_block + 1)
1706
- heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
1707
- tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
1708
- outs = jnp.concatenate([heads, tails])
1709
- else:
1710
- outs = jax.vmap(operation)(x.reshape(-1, num_block))
1711
- return outs
1712
-
1713
-
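A small worked example of the head/tail split above: when ``size % target_size != 0``, the first ``size % target_size`` windows receive one extra element, so a length-10 input pooled to 3 outputs uses windows of sizes 4, 3, 3:

.. code-block:: python

    import jax.numpy as jnp

    x = jnp.arange(10.0)    # size=10, target_size=3 -> num_head_arrays=1, num_block=3
    out = _adaptive_pool1d(x, 3, jnp.mean)
    expected = jnp.array([x[0:4].mean(), x[4:7].mean(), x[7:10].mean()])
    assert jnp.allclose(out, expected)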
1714
- def _generate_vmap(fun: Callable, map_axes: List[int]):
1715
- map_axes = sorted(map_axes)
1716
- for axis in map_axes:
1717
- fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
1718
- return fun
1719
-
1720
-
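``_generate_vmap`` lifts the 1D adaptive kernel over every non-pooled axis by stacking one ``jax.vmap`` per axis. A minimal sketch, assuming a 2D input pooled along its last axis:

.. code-block:: python

    import jax.numpy as jnp

    x = jnp.arange(20.0).reshape(2, 10)
    op = _generate_vmap(_adaptive_pool1d, map_axes=[0])  # map over axis 0
    out = op(x, 5, jnp.mean)    # _adaptive_pool1d only ever sees 1D slices
    assert out.shape == (2, 5)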
1721
- class _AdaptivePool(Module):
1722
- """General N dimensional adaptive down-sampling to a target shape.
1723
-
1724
- Parameters
1725
- ----------
1726
- in_size: Sequence of int
1727
- The shape of the input tensor.
1728
- target_size: int, sequence of int
1729
- The target output shape.
1730
- num_spatial_dims: int
1731
- The number of spatial dimensions.
1732
- channel_axis: int, optional
1733
- Axis of the spatial channels for which pooling is skipped.
1734
- If ``None``, there is no channel axis.
1735
- operation: Callable
1736
- The down-sampling operation.
1737
- name: str
1738
- The class name.
1739
- """
1740
-
1741
- def __init__(
1742
- self,
1743
- in_size: Size,
1744
- target_size: Size,
1745
- num_spatial_dims: int,
1746
- operation: Callable,
1747
- channel_axis: Optional[int] = -1,
1748
- name: Optional[str] = None,
1749
- ):
1750
- super().__init__(name=name)
1751
-
1752
- self.channel_axis = channel_axis
1753
- self.operation = operation
1754
- if isinstance(target_size, int):
1755
- self.target_shape = (target_size,) * num_spatial_dims
1756
- elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
1757
- self.target_shape = target_size
1758
- else:
1759
- raise ValueError("`target_size` must either be an int or tuple of length "
1760
- f"{num_spatial_dims} containing ints.")
1761
-
1762
- # in & out shapes
1763
- if in_size is not None:
1764
- in_size = tuple(in_size)
1765
- self.in_size = in_size
1766
- y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
1767
- self.out_size = y.shape[1:]
1768
-
1769
- def update(self, x):
1770
- """Input-output mapping.
1771
-
1772
- Parameters
1773
- ----------
1774
- x: Array
1775
- Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
1776
- or `(..., dim_1, dim_2)`.
1777
- """
1778
- # channel axis
1779
- channel_axis = self.channel_axis
1780
-
1781
- if channel_axis is not None:
1782
- if not 0 <= abs(channel_axis) < x.ndim:
1783
- raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
1784
- if channel_axis < 0:
1785
- channel_axis = x.ndim + channel_axis
1786
- # input dimension
1787
- if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
1788
- raise ValueError(f"Invalid input dimension. Except >={len(self.target_shape)} "
1789
- f"dimensions (channel_axis={self.channel_axis}). "
1790
- f"But got {x.ndim} dimensions.")
1791
- # pooling dimensions
1792
- pool_dims = list(range(x.ndim))
1793
- if channel_axis is not None:
1794
- pool_dims.pop(channel_axis)
1795
-
1796
- # pooling
1797
- for i, di in enumerate(pool_dims[-len(self.target_shape):]):
1798
- pool_axes = [j for j in range(x.ndim) if j != di]
1799
- op = _generate_vmap(_adaptive_pool1d, pool_axes)
1800
- x = op(x, self.target_shape[i], self.operation)
1801
- return x
1802
-
1803
-
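A value-level check: when the target size divides the input length evenly, adaptive average pooling reduces to plain block averaging. A sketch, assuming the class is exported as ``brainstate.nn.AdaptiveAvgPool1d`` as its ``__module__`` suggests:

.. code-block:: python

    import jax.numpy as jnp
    import brainstate

    m = brainstate.nn.AdaptiveAvgPool1d(2, channel_axis=None)
    x = jnp.arange(8.0)     # 8 splits evenly into 2 windows of 4
    assert jnp.allclose(m(x), jnp.array([x[:4].mean(), x[4:].mean()]))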
1804
- class AdaptiveAvgPool1d(_AdaptivePool):
1805
- r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
1806
-
1807
- The output size is :math:`L_{out}`, for any input size.
1808
- The number of output features is equal to the number of input planes.
1809
-
1810
- Adaptive pooling automatically computes the kernel size and stride to achieve the desired
1811
- output size, making it useful for creating fixed-size representations from variable-sized inputs.
1812
-
1813
- Shape:
1814
- - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
1815
- - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
1816
- :math:`L_{out}=\text{target\_size}`.
1817
-
1818
- Parameters
1819
- ----------
1820
- target_size : int or sequence of int
1821
- The target output size. The number of output features for each channel.
1822
- channel_axis : int, optional
1823
- Axis of the spatial channels for which pooling is skipped.
1824
- If ``None``, there is no channel axis. Default: -1
1825
- name : str, optional
1826
- The name of the module.
1827
- in_size : Sequence of int, optional
1828
- The shape of the input tensor for shape inference.
1829
-
1830
- Examples
1831
- --------
1832
- .. code-block:: python
1833
-
1834
- >>> import brainstate
1835
- >>> # Target output size of 5
1836
- >>> m = AdaptiveAvgPool1d(5)
1837
- >>> input = brainstate.random.randn(1, 64, 8)
1838
- >>> output = m(input)
1839
- >>> output.shape
1840
- (1, 5, 8)
1841
- >>> # Can handle variable input sizes
1842
- >>> input2 = brainstate.random.randn(1, 32, 8)
1843
- >>> output2 = m(input2)
1844
- >>> output2.shape
1845
- (1, 5, 8) # Same output size regardless of input size
1846
-
1847
- See Also
1848
- --------
1849
- AvgPool1d : Non-adaptive 1D average pooling.
1850
- AdaptiveMaxPool1d : Adaptive 1D max pooling.
1851
- """
1852
- __module__ = 'brainstate.nn'
1853
-
1854
- def __init__(
1855
- self,
1856
- target_size: Union[int, Sequence[int]],
1857
- channel_axis: Optional[int] = -1,
1858
- name: Optional[str] = None,
1859
- in_size: Optional[Sequence[int]] = None,
1860
- ):
1861
- super().__init__(
1862
- in_size=in_size,
1863
- target_size=target_size,
1864
- channel_axis=channel_axis,
1865
- num_spatial_dims=1,
1866
- operation=jnp.mean,
1867
- name=name
1868
- )
1869
-
1870
-
1871
- class AdaptiveAvgPool2d(_AdaptivePool):
1872
- r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
1873
-
1874
- The output is of size :math:`H_{out} \times W_{out}`, for any input size.
1875
- The number of output features is equal to the number of input planes.
1876
-
1877
- Adaptive pooling automatically computes the kernel size and stride to achieve the desired
1878
- output size, making it useful for creating fixed-size representations from variable-sized inputs.
1879
-
1880
- Shape:
1881
- - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
1882
- - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
1883
- :math:`(H_{out}, W_{out})=\text{target\_size}`.
1884
-
1885
- Parameters
1886
- ----------
1887
- target_size : int or tuple of int
1888
- The target output size. If a single integer is provided, the output will be a square
1889
- of that size. If a tuple is provided, it specifies (H_out, W_out).
1890
- Use None for dimensions that should not be pooled.
1891
- channel_axis : int, optional
1892
- Axis of the spatial channels for which pooling is skipped.
1893
- If ``None``, there is no channel axis. Default: -1
1894
- name : str, optional
1895
- The name of the module.
1896
- in_size : Sequence of int, optional
1897
- The shape of the input tensor for shape inference.
1898
-
1899
- Examples
1900
- --------
1901
- .. code-block:: python
1902
-
1903
- >>> import brainstate
1904
- >>> # Target output size of 5x7
1905
- >>> m = AdaptiveAvgPool2d((5, 7))
1906
- >>> input = brainstate.random.randn(1, 8, 9, 64)
1907
- >>> output = m(input)
1908
- >>> output.shape
1909
- (1, 5, 7, 64)
1910
- >>> # Target output size of 7x7 (square)
1911
- >>> m = AdaptiveAvgPool2d(7)
1912
- >>> input = brainstate.random.randn(1, 10, 9, 64)
1913
- >>> output = m(input)
1914
- >>> output.shape
1915
- (1, 7, 7, 64)
1916
- >>> # Target output size of 10x7
1917
- >>> m = AdaptiveAvgPool2d((None, 7))
1918
- >>> input = brainstate.random.randn(1, 10, 9, 64)
1919
- >>> output = m(input)
1920
- >>> output.shape
1921
- (1, 10, 7, 64)
1922
-
1923
- See Also
1924
- --------
1925
- AvgPool2d : Non-adaptive 2D average pooling.
1926
- AdaptiveMaxPool2d : Adaptive 2D max pooling.
1927
- """
1928
- __module__ = 'brainstate.nn'
1929
-
1930
- def __init__(
1931
- self,
1932
- target_size: Union[int, Sequence[int]],
1933
- channel_axis: Optional[int] = -1,
1934
- name: Optional[str] = None,
1935
- in_size: Optional[Sequence[int]] = None,
1936
- ):
1937
- super().__init__(
1938
- in_size=in_size,
1939
- target_size=target_size,
1940
- channel_axis=channel_axis,
1941
- num_spatial_dims=2,
1942
- operation=jnp.mean,
1943
- name=name
1944
- )
1945
-
1946
-
1947
- class AdaptiveAvgPool3d(_AdaptivePool):
1948
- r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
1949
-
1950
- The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
1951
- The number of output features is equal to the number of input planes.
1952
-
1953
- Adaptive pooling automatically computes the kernel size and stride to achieve the desired
1954
- output size, making it useful for creating fixed-size representations from variable-sized inputs.
1955
-
1956
- Shape:
1957
- - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
1958
- - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
1959
- where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
1960
-
1961
- Parameters
1962
- ----------
1963
- target_size : int or tuple of int
1964
- The target output size. If a single integer is provided, the output will be a cube
1965
- of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
1966
- Use None for dimensions that should not be pooled.
1967
- channel_axis : int, optional
1968
- Axis of the spatial channels for which pooling is skipped.
1969
- If ``None``, there is no channel axis. Default: -1
1970
- name : str, optional
1971
- The name of the module.
1972
- in_size : Sequence of int, optional
1973
- The shape of the input tensor for shape inference.
1974
-
1975
- Examples
1976
- --------
1977
- .. code-block:: python
1978
-
1979
- >>> import brainstate
1980
- >>> # Target output size of 5x7x9
1981
- >>> m = AdaptiveAvgPool3d((5, 7, 9))
1982
- >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
1983
- >>> output = m(input)
1984
- >>> output.shape
1985
- (1, 5, 7, 9, 64)
1986
- >>> # Target output size of 7x7x7 (cube)
1987
- >>> m = AdaptiveAvgPool3d(7)
1988
- >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1989
- >>> output = m(input)
1990
- >>> output.shape
1991
- (1, 7, 7, 7, 64)
1992
- >>> # Target output size of 7x9x8
1993
- >>> m = AdaptiveAvgPool3d((7, None, None))
1994
- >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1995
- >>> output = m(input)
1996
- >>> output.shape
1997
- (1, 7, 9, 8, 64)
1998
-
1999
- See Also
2000
- --------
2001
- AvgPool3d : Non-adaptive 3D average pooling.
2002
- AdaptiveMaxPool3d : Adaptive 3D max pooling.
2003
- """
2004
- __module__ = 'brainstate.nn'
2005
-
2006
- def __init__(
2007
- self,
2008
- target_size: Union[int, Sequence[int]],
2009
- channel_axis: Optional[int] = -1,
2010
- name: Optional[str] = None,
2011
- in_size: Optional[Sequence[int]] = None,
2012
- ):
2013
- super().__init__(
2014
- in_size=in_size,
2015
- target_size=target_size,
2016
- channel_axis=channel_axis,
2017
- num_spatial_dims=3,
2018
- operation=jnp.mean,
2019
- name=name
2020
- )
2021
-
2022
-
2023
- class AdaptiveMaxPool1d(_AdaptivePool):
2024
- r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
2025
-
2026
- The output size is :math:`L_{out}`, for any input size.
2027
- The number of output features is equal to the number of input planes.
2028
-
2029
- Adaptive pooling automatically computes the kernel size and stride to achieve the desired
2030
- output size, making it useful for creating fixed-size representations from variable-sized inputs.
2031
-
2032
- Shape:
2033
- - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
2034
- - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
2035
- :math:`L_{out}=\text{target\_size}`.
2036
-
2037
- Parameters
2038
- ----------
2039
- target_size : int or sequence of int
2040
- The target output size. The number of output features for each channel.
2041
- channel_axis : int, optional
2042
- Axis of the spatial channels for which pooling is skipped.
2043
- If ``None``, there is no channel axis. Default: -1
2044
- name : str, optional
2045
- The name of the module.
2046
- in_size : Sequence of int, optional
2047
- The shape of the input tensor for shape inference.
2048
-
2049
- Examples
2050
- --------
2051
- .. code-block:: python
2052
-
2053
- >>> import brainstate
2054
- >>> # Target output size of 5
2055
- >>> m = AdaptiveMaxPool1d(5)
2056
- >>> input = brainstate.random.randn(1, 64, 8)
2057
- >>> output = m(input)
2058
- >>> output.shape
2059
- (1, 5, 8)
2060
- >>> # Can handle variable input sizes
2061
- >>> input2 = brainstate.random.randn(1, 32, 8)
2062
- >>> output2 = m(input2)
2063
- >>> output2.shape
2064
- (1, 5, 8) # Same output size regardless of input size
2065
-
2066
- See Also
2067
- --------
2068
- MaxPool1d : Non-adaptive 1D max pooling.
2069
- AdaptiveAvgPool1d : Adaptive 1D average pooling.
2070
- """
2071
- __module__ = 'brainstate.nn'
2072
-
2073
- def __init__(
2074
- self,
2075
- target_size: Union[int, Sequence[int]],
2076
- channel_axis: Optional[int] = -1,
2077
- name: Optional[str] = None,
2078
- in_size: Optional[Sequence[int]] = None,
2079
- ):
2080
- super().__init__(
2081
- in_size=in_size,
2082
- target_size=target_size,
2083
- channel_axis=channel_axis,
2084
- num_spatial_dims=1,
2085
- operation=jnp.max,
2086
- name=name
2087
- )
2088
-
2089
-
2090
- class AdaptiveMaxPool2d(_AdaptivePool):
2091
- r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
2092
-
2093
- The output is of size :math:`H_{out} \times W_{out}`, for any input size.
2094
- The number of output features is equal to the number of input planes.
2095
-
2096
- Adaptive pooling automatically computes the kernel size and stride to achieve the desired
2097
- output size, making it useful for creating fixed-size representations from variable-sized inputs.
2098
-
2099
- Shape:
2100
- - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
2101
- - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
2102
- :math:`(H_{out}, W_{out})=\text{target\_size}`.
2103
-
2104
- Parameters
2105
- ----------
2106
- target_size : int or tuple of int
2107
- The target output size. If a single integer is provided, the output will be a square
2108
- of that size. If a tuple is provided, it specifies (H_out, W_out).
2109
- Use None for dimensions that should not be pooled.
2110
- channel_axis : int, optional
2111
- Axis of the spatial channels for which pooling is skipped.
2112
- If ``None``, there is no channel axis. Default: -1
2113
- name : str, optional
2114
- The name of the module.
2115
- in_size : Sequence of int, optional
2116
- The shape of the input tensor for shape inference.
2117
-
2118
- Examples
2119
- --------
2120
- .. code-block:: python
2121
-
2122
- >>> import brainstate
2123
- >>> # Target output size of 5x7
2124
- >>> m = AdaptiveMaxPool2d((5, 7))
2125
- >>> input = brainstate.random.randn(1, 8, 9, 64)
2126
- >>> output = m(input)
2127
- >>> output.shape
2128
- (1, 5, 7, 64)
2129
- >>> # Target output size of 7x7 (square)
2130
- >>> m = AdaptiveMaxPool2d(7)
2131
- >>> input = brainstate.random.randn(1, 10, 9, 64)
2132
- >>> output = m(input)
2133
- >>> output.shape
2134
- (1, 7, 7, 64)
2135
- >>> # Target output size of 10x7
2136
- >>> m = AdaptiveMaxPool2d((None, 7))
2137
- >>> input = brainstate.random.randn(1, 10, 9, 64)
2138
- >>> output = m(input)
2139
- >>> output.shape
2140
- (1, 10, 7, 64)
2141
-
2142
- See Also
2143
- --------
2144
- MaxPool2d : Non-adaptive 2D max pooling.
2145
- AdaptiveAvgPool2d : Adaptive 2D average pooling.
2146
- """
2147
- __module__ = 'brainstate.nn'
2148
-
2149
- def __init__(
2150
- self,
2151
- target_size: Union[int, Sequence[int]],
2152
- channel_axis: Optional[int] = -1,
2153
- name: Optional[str] = None,
2154
- in_size: Optional[Sequence[int]] = None,
2155
- ):
2156
- super().__init__(
2157
- in_size=in_size,
2158
- target_size=target_size,
2159
- channel_axis=channel_axis,
2160
- num_spatial_dims=2,
2161
- operation=jnp.max,
2162
- name=name
2163
- )
2164
-
2165
-
2166
- class AdaptiveMaxPool3d(_AdaptivePool):
2167
- r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
2168
-
2169
- The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
2170
- The number of output features is equal to the number of input planes.
2171
-
2172
- Adaptive pooling automatically computes the kernel size and stride to achieve the desired
2173
- output size, making it useful for creating fixed-size representations from variable-sized inputs.
2174
-
2175
- Shape:
2176
- - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
2177
- - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
2178
- where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
2179
-
2180
- Parameters
2181
- ----------
2182
- target_size : int or tuple of int
2183
- The target output size. If a single integer is provided, the output will be a cube
2184
- of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
2185
- Use None for dimensions that should not be pooled.
2186
- channel_axis : int, optional
2187
- Axis of the spatial channels for which pooling is skipped.
2188
- If ``None``, there is no channel axis. Default: -1
2189
- name : str, optional
2190
- The name of the module.
2191
- in_size : Sequence of int, optional
2192
- The shape of the input tensor for shape inference.
2193
-
2194
- Examples
2195
- --------
2196
- .. code-block:: python
2197
-
2198
- >>> import brainstate
2199
- >>> # Target output size of 5x7x9
2200
- >>> m = AdaptiveMaxPool3d((5, 7, 9))
2201
- >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
2202
- >>> output = m(input)
2203
- >>> output.shape
2204
- (1, 5, 7, 9, 64)
2205
- >>> # Target output size of 7x7x7 (cube)
2206
- >>> m = AdaptiveMaxPool3d(7)
2207
- >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
2208
- >>> output = m(input)
2209
- >>> output.shape
2210
- (1, 7, 7, 7, 64)
2211
- >>> # Target output size of 7x9x8
2212
- >>> m = AdaptiveMaxPool3d((7, None, None))
2213
- >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
2214
- >>> output = m(input)
2215
- >>> output.shape
2216
- (1, 7, 9, 8, 64)
2217
-
2218
- See Also
2219
- --------
2220
- MaxPool3d : Non-adaptive 3D max pooling.
2221
- AdaptiveAvgPool3d : Adaptive 3D average pooling.
2222
- """
2223
- __module__ = 'brainstate.nn'
2224
-
2225
- def __init__(
2226
- self,
2227
- target_size: Union[int, Sequence[int]],
2228
- channel_axis: Optional[int] = -1,
2229
- name: Optional[str] = None,
2230
- in_size: Optional[Sequence[int]] = None,
2231
- ):
2232
- super().__init__(
2233
- in_size=in_size,
2234
- target_size=target_size,
2235
- channel_axis=channel_axis,
2236
- num_spatial_dims=3,
2237
- operation=jnp.max,
2238
- name=name
2239
- )
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ import functools
19
+ from typing import Sequence, Optional
20
+ from typing import Union, Tuple, Callable, List
21
+
22
+ import brainunit as u
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+ from brainstate import environ
28
+ from brainstate.typing import Size
29
+ from ._module import Module
30
+
31
+ __all__ = [
32
+ 'Flatten', 'Unflatten',
33
+
34
+ 'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
35
+ 'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
36
+ 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
37
+ 'LPPool1d', 'LPPool2d', 'LPPool3d',
38
+
39
+ 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
40
+ 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
41
+ ]
42
+
43
+
44
+ class Flatten(Module):
45
+ r"""
46
+ Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
47
+
48
+ Shape:
49
+ - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
50
+ where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
51
+ number of dimensions including none.
52
+ - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
53
+
54
+ Parameters
55
+ ----------
56
+ start_axis : int, optional
57
+ First dim to flatten (default = 0).
58
+ end_axis : int, optional
59
+ Last dim to flatten (default = -1).
60
+ in_size : Sequence of int, optional
61
+ The shape of the input tensor.
62
+
63
+ Examples
64
+ --------
65
+ .. code-block:: python
66
+
67
+ >>> import brainstate
68
+ >>> inp = brainstate.random.randn(32, 1, 5, 5)
69
+ >>> # With default parameters
70
+ >>> m = Flatten()
71
+ >>> output = m(inp)
72
+ >>> output.shape
73
+ (32, 25)
74
+ >>> # With non-default parameters
75
+ >>> m = Flatten(0, 2)
76
+ >>> output = m(inp)
77
+ >>> output.shape
78
+ (160, 5)
79
+ """
80
+ __module__ = 'brainstate.nn'
81
+
82
+ def __init__(
83
+ self,
84
+ start_axis: int = 0,
85
+ end_axis: int = -1,
86
+ in_size: Optional[Size] = None
87
+ ) -> None:
88
+ super().__init__()
89
+ self.start_axis = start_axis
90
+ self.end_axis = end_axis
91
+
92
+ if in_size is not None:
93
+ self.in_size = tuple(in_size)
94
+ y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
95
+ jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
96
+ self.out_size = y.shape
97
+
98
+ def update(self, x):
99
+ if self._in_size is None:
100
+ start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
101
+ else:
102
+ assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
103
+ dim_diff = x.ndim - len(self.in_size)
104
+ if self.in_size != x.shape[dim_diff:]:
105
+ raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
106
+ if self.start_axis >= 0:
107
+ start_axis = self.start_axis + dim_diff
108
+ else:
109
+ start_axis = x.ndim + self.start_axis
110
+ return u.math.flatten(x, start_axis, self.end_axis)
111
+
112
+
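When ``in_size`` is given, ``update`` shifts ``start_axis`` by the number of extra leading (batch) dimensions, so a module configured for unbatched inputs still flattens only the trailing block. A minimal sketch, assuming the class is exported as ``brainstate.nn.Flatten``:

.. code-block:: python

    import brainstate

    m = brainstate.nn.Flatten(start_axis=0, in_size=(5, 5))
    x = brainstate.random.randn(32, 5, 5)   # one extra batch dimension
    y = m(x)                                # start_axis is shifted from 0 to 1
    assert y.shape == (32, 25)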
113
+ class Unflatten(Module):
114
+ r"""
115
+ Unflatten a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.
116
+
117
+ * :attr:`axis` specifies the dimension of the input tensor to be unflattened, and it is
118
+ given as an `int`.
119
+
120
+ * :attr:`sizes` is the new shape of the unflattened dimension of the tensor, and it can be
121
+ a `tuple` of ints or a `list` of ints whose product equals the size of the
122
+ dimension being unflattened.
123
+
124
+ Shape:
125
+ - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
126
+ dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
127
+ - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
128
+ :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
129
+
130
+ Parameters
131
+ ----------
132
+ axis : int
133
+ Dimension to be unflattened.
134
+ sizes : Sequence of int
135
+ New shape of the unflattened dimension.
136
+ name : str, optional
137
+ The name of the module.
138
+ in_size : Sequence of int, optional
139
+ The shape of the input tensor.
140
+ """
141
+ __module__ = 'brainstate.nn'
142
+
143
+ def __init__(
144
+ self,
145
+ axis: int,
146
+ sizes: Size,
147
+ name: str = None,
148
+ in_size: Optional[Size] = None
149
+ ) -> None:
150
+ super().__init__(name=name)
151
+
152
+ self.axis = axis
153
+ self.sizes = sizes
154
+ if isinstance(sizes, (tuple, list)):
155
+ for idx, elem in enumerate(sizes):
156
+ if not isinstance(elem, int):
157
+ raise TypeError("unflattened sizes must be tuple of ints, " +
158
+ "but found element of type {} at pos {}".format(type(elem).__name__, idx))
159
+ else:
160
+ raise TypeError("unflattened sizes must be tuple or list, but found type {}".format(type(sizes).__name__))
161
+
162
+ if in_size is not None:
163
+ self.in_size = tuple(in_size)
164
+ y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
165
+ jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
166
+ self.out_size = y.shape
167
+
168
+ def update(self, x):
169
+ return u.math.unflatten(x, self.axis, self.sizes)
170
+
171
+
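A usage sketch for ``Unflatten`` (the docstring above carries no Examples section; the export path ``brainstate.nn.Unflatten`` is assumed from ``__all__``). The product of ``sizes`` must equal the size of the unflattened axis:

.. code-block:: python

    import brainstate

    m = brainstate.nn.Unflatten(axis=1, sizes=(5, 5))   # undo a (5, 5) flatten
    x = brainstate.random.randn(32, 25)
    y = m(x)
    assert y.shape == (32, 5, 5)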
172
+ class _MaxPool(Module):
173
+ def __init__(
174
+ self,
175
+ init_value: float,
176
+ computation: Callable,
177
+ pool_dim: int,
178
+ kernel_size: Size,
179
+ stride: Union[int, Sequence[int]] = None,
180
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
181
+ channel_axis: Optional[int] = -1,
182
+ return_indices: bool = False,
183
+ name: Optional[str] = None,
184
+ in_size: Optional[Size] = None,
185
+ ):
186
+ super().__init__(name=name)
187
+
188
+ self.init_value = init_value
189
+ self.computation = computation
190
+ self.pool_dim = pool_dim
191
+ self.return_indices = return_indices
192
+
193
+ # kernel_size
194
+ if isinstance(kernel_size, int):
195
+ kernel_size = (kernel_size,) * pool_dim
196
+ elif isinstance(kernel_size, Sequence):
197
+ assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
198
+ assert all(
199
+ [isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
200
+ if len(kernel_size) != pool_dim:
201
+ raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, but got {len(kernel_size)}')
202
+ else:
203
+ raise TypeError(f'kernel_size should be an int or a tuple with {pool_dim} ints.')
204
+ self.kernel_size = kernel_size
205
+
206
+ # stride
207
+ if stride is None:
208
+ stride = kernel_size
209
+ if isinstance(stride, int):
210
+ stride = (stride,) * pool_dim
211
+ elif isinstance(stride, Sequence):
212
+ assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
213
+ assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
214
+ if len(stride) != pool_dim:
215
+ raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
216
+ else:
217
+ raise TypeError(f'stride should be an int or a tuple with {pool_dim} ints.')
218
+ self.stride = stride
219
+
220
+ # padding
221
+ if isinstance(padding, str):
222
+ if padding not in ("SAME", "VALID"):
223
+ raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
224
+ elif isinstance(padding, int):
225
+ padding = [(padding, padding) for _ in range(pool_dim)]
226
+ elif isinstance(padding, (list, tuple)):
227
+ if isinstance(padding[0], int):
228
+ if len(padding) == pool_dim:
229
+ padding = [(x, x) for x in padding]
230
+ else:
231
+ raise ValueError(f'If padding is a sequence of ints, it '
232
+ f'should have a length of {pool_dim}.')
233
+ else:
234
+ if not all([isinstance(x, (tuple, list)) for x in padding]):
235
+ raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
236
+ if not all([len(x) == 2 for x in padding]):
237
+ raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
238
+ if len(padding) == 1:
239
+ padding = tuple(padding) * pool_dim
240
+ assert len(padding) == pool_dim, f'padding should have a length of {pool_dim}. {padding}'
241
+ else:
242
+ raise ValueError(f'Invalid padding: {padding}')
243
+ self.padding = padding
244
+
245
+ # channel_axis
246
+ assert channel_axis is None or isinstance(channel_axis, int), \
247
+ f'channel_axis should be an int, but got {channel_axis}'
248
+ self.channel_axis = channel_axis
249
+
250
+ # in & out shapes
251
+ if in_size is not None:
252
+ in_size = tuple(in_size)
253
+ self.in_size = in_size
254
+ y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
255
+ self.out_size = y.shape[1:]
256
+
257
+ def update(self, x):
258
+ x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
259
+ if x.ndim < x_dim:
260
+ raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
261
+ window_shape = tuple(self._infer_shape(x.ndim, self.kernel_size, 1))
262
+ stride = tuple(self._infer_shape(x.ndim, self.stride, 1))
263
+ if isinstance(self.padding, str):
264
+ padding = tuple(jax.lax.padtype_to_pads(x.shape, window_shape, stride, self.padding))
265
+ else:
266
+ padding = tuple(self._infer_shape(x.ndim, self.padding, element=(0, 0)))
267
+
268
+ if self.return_indices:
269
+ # For returning indices, we need to use a custom implementation
270
+ return self._pooling_with_indices(x, window_shape, stride, padding)
271
+
272
+ return jax.lax.reduce_window(
273
+ x,
274
+ init_value=self.init_value,
275
+ computation=self.computation,
276
+ window_dimensions=window_shape,
277
+ window_strides=stride,
278
+ padding=padding
279
+ )
280
+
281
+ def _pooling_with_indices(self, x, window_shape, stride, padding):
282
+ """Perform max pooling and return both pooled values and indices."""
283
+ total_size = x.size
284
+ flat_indices = jnp.arange(total_size, dtype=jnp.int32).reshape(x.shape)
285
+
286
+ init_val = jnp.asarray(self.init_value, dtype=x.dtype)
287
+ init_idx = jnp.array(total_size, dtype=flat_indices.dtype)
288
+
289
+ def reducer(acc, operand):
290
+ acc_val, acc_idx = acc
291
+ cur_val, cur_idx = operand
292
+
293
+ better = cur_val > acc_val
294
+ best_val = jnp.where(better, cur_val, acc_val)
295
+ best_idx = jnp.where(better, cur_idx, acc_idx)
296
+ tie = jnp.logical_and(cur_val == acc_val, cur_idx < acc_idx)
297
+ best_idx = jnp.where(tie, cur_idx, best_idx)
298
+
299
+ return best_val, best_idx
300
+
301
+ pooled, indices_result = jax.lax.reduce_window(
302
+ (x, flat_indices),
303
+ (init_val, init_idx),
304
+ reducer,
305
+ window_dimensions=window_shape,
306
+ window_strides=stride,
307
+ padding=padding
308
+ )
309
+
310
+ indices_result = jnp.where(indices_result == total_size, 0, indices_result)
311
+
312
+ return pooled, indices_result.astype(jnp.int32)
313
+
314
+ def _infer_shape(self, x_dim, inputs, element):
315
+ channel_axis = self.channel_axis
316
+ if channel_axis is not None and not 0 <= abs(channel_axis) < x_dim:
317
+ raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
318
+ if channel_axis is not None and channel_axis < 0:
319
+ channel_axis = x_dim + channel_axis
320
+ all_dims = list(range(x_dim))
321
+ if channel_axis is not None:
322
+ all_dims.pop(channel_axis)
323
+ pool_dims = all_dims[-self.pool_dim:]
324
+ results = [element] * x_dim
325
+ for i, dim in enumerate(pool_dims):
326
+ results[dim] = inputs[i]
327
+ return results
328
+
329
+
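The indices produced by ``_pooling_with_indices`` are flat offsets into the whole input array (``x.reshape(-1)``), with ties broken toward the smaller index. A minimal sketch of recovering coordinates from them, assuming the class is exported as ``brainstate.nn.MaxPool1d``:

.. code-block:: python

    import jax.numpy as jnp
    import brainstate

    m = brainstate.nn.MaxPool1d(2, stride=2, return_indices=True)
    x = brainstate.random.randn(4, 8, 3)
    pooled, idx = m(x)                      # idx indexes into x.reshape(-1)
    coords = jnp.unravel_index(idx, x.shape)
    assert jnp.allclose(x[coords], pooled)  # each index points at its window max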
330
+ class _AvgPool(_MaxPool):
331
+ def update(self, x):
332
+ x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
333
+ if x.ndim < x_dim:
334
+ raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
335
+ dims = self._infer_shape(x.ndim, self.kernel_size, 1)
336
+ stride = self._infer_shape(x.ndim, self.stride, 1)
337
+ padding = (self.padding if isinstance(self.padding, str) else
338
+ self._infer_shape(x.ndim, self.padding, element=(0, 0)))
339
+ pooled = jax.lax.reduce_window(x,
340
+ init_value=self.init_value,
341
+ computation=self.computation,
342
+ window_dimensions=dims,
343
+ window_strides=stride,
344
+ padding=padding)
345
+ if padding == "VALID":
346
+ # Avoid the extra reduce_window.
347
+ return pooled / np.prod(dims)
348
+ else:
349
+ # Count the number of valid entries at each input point, then use that for
350
+ # computing average. Assumes that any two arrays of same shape will be
351
+ # padded the same.
352
+ window_counts = jax.lax.reduce_window(jnp.ones_like(x),
353
+ init_value=self.init_value,
354
+ computation=self.computation,
355
+ window_dimensions=dims,
356
+ window_strides=stride,
357
+ padding=padding)
358
+ assert pooled.shape == window_counts.shape
359
+ return pooled / window_counts
360
+
361
+
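Why the second ``reduce_window`` over ``jnp.ones_like(x)``: with ``'SAME'`` padding, border windows contain fewer valid inputs, so dividing by the per-position count rather than the full window size keeps the average unbiased at the edges. A standalone sketch of the effect:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import numpy as np

    x = jnp.ones((1, 4, 1))                 # (batch, length, channels), all ones
    dims, stride = (1, 3, 1), (1, 1, 1)
    sums = jax.lax.reduce_window(x, 0., jax.lax.add, dims, stride, 'SAME')
    counts = jax.lax.reduce_window(jnp.ones_like(x), 0., jax.lax.add, dims, stride, 'SAME')

    print(np.asarray(sums).ravel() / np.prod(dims))   # [0.667 1. 1. 0.667] biased edges
    print(np.asarray(sums / counts).ravel())          # [1. 1. 1. 1.]       corrected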
362
+ class MaxPool1d(_MaxPool):
363
+ r"""Applies a 1D max pooling over an input signal composed of several input planes.
364
+
365
+ In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
366
+ and output :math:`(N, L_{out}, C)` can be precisely described as:
367
+
368
+ .. math::
369
+ out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
370
+ input(N_i, stride \times k + m, C_j)
371
+
372
+ If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
373
+ for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
374
+ sliding window. This `link`_ has a nice visualization of the pooling parameters.
375
+
376
+ Shape:
377
+ - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
378
+ - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
379
+
380
+ .. math::
381
+ L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
382
+ \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
383
+
384
+ Parameters
385
+ ----------
386
+ kernel_size : int or sequence of int
387
+ An integer, or a sequence of integers defining the window to reduce over.
388
+ stride : int or sequence of int, optional
389
+ An integer, or a sequence of integers, representing the inter-window stride.
390
+ Default: kernel_size
391
+ padding : str, int or sequence of tuple, optional
392
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
393
+ of n `(low, high)` integer pairs that give the padding to apply before
394
+ and after each spatial dimension. Default: 'VALID'
395
+ channel_axis : int, optional
396
+ Axis of the spatial channels for which pooling is skipped.
397
+ If ``None``, there is no channel axis. Default: -1
398
+ return_indices : bool, optional
399
+ If True, will return the max indices along with the outputs.
400
+ Useful for MaxUnpool1d. Default: False
401
+ name : str, optional
402
+ The object name.
403
+ in_size : Sequence of int, optional
404
+ The shape of the input tensor.
405
+
406
+ Examples
407
+ --------
408
+ .. code-block:: python
409
+
410
+ >>> import brainstate
411
+ >>> # pool of size=3, stride=2
412
+ >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
413
+ >>> input = brainstate.random.randn(20, 50, 16)
414
+ >>> output = m(input)
415
+ >>> output.shape
416
+ (20, 24, 16)
417
+
418
+ .. _link:
419
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
420
+ """
421
+ __module__ = 'brainstate.nn'
422
+
423
+ def __init__(
424
+ self,
425
+ kernel_size: Size,
426
+ stride: Union[int, Sequence[int]] = None,
427
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
428
+ channel_axis: Optional[int] = -1,
429
+ return_indices: bool = False,
430
+ name: Optional[str] = None,
431
+ in_size: Optional[Size] = None,
432
+ ):
433
+ super().__init__(in_size=in_size,
434
+ init_value=-jax.numpy.inf,
435
+ computation=jax.lax.max,
436
+ pool_dim=1,
437
+ kernel_size=kernel_size,
438
+ stride=stride,
439
+ padding=padding,
440
+ channel_axis=channel_axis,
441
+ return_indices=return_indices,
442
+ name=name)
443
+
444
+
445
+ class MaxPool2d(_MaxPool):
446
+ r"""Applies a 2D max pooling over an input signal composed of several input planes.
447
+
448
+ In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
449
+ output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
450
+ can be precisely described as:
451
+
452
+ .. math::
453
+ \begin{aligned}
454
+ out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
455
+ & \text{input}(N_i, \text{stride[0]} \times h + m,
456
+ \text{stride[1]} \times w + n, C_j)
457
+ \end{aligned}
458
+
459
+ If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
460
+ for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
461
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
462
+
463
+ Shape:
464
+ - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
465
+ - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
466
+
467
+ .. math::
468
+ H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
469
+ \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor
470
+
471
+ .. math::
472
+ W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
473
+ \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
474
+
475
+ Parameters
476
+ ----------
477
+ kernel_size : int or sequence of int
478
+ An integer, or a sequence of integers defining the window to reduce over.
479
+ stride : int or sequence of int, optional
480
+ An integer, or a sequence of integers, representing the inter-window stride.
481
+ Default: kernel_size
482
+ padding : str, int or sequence of tuple, optional
483
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
484
+ of n `(low, high)` integer pairs that give the padding to apply before
485
+ and after each spatial dimension. Default: 'VALID'
486
+ channel_axis : int, optional
487
+ Axis of the spatial channels for which pooling is skipped.
488
+ If ``None``, there is no channel axis. Default: -1
489
+ return_indices : bool, optional
490
+ If True, will return the max indices along with the outputs.
491
+ Useful for MaxUnpool2d. Default: False
492
+ name : str, optional
493
+ The object name.
494
+ in_size : Sequence of int, optional
495
+ The shape of the input tensor.
496
+
497
+ Examples
498
+ --------
499
+ .. code-block:: python
500
+
501
+ >>> import brainstate
502
+ >>> # pool of square window of size=3, stride=2
503
+ >>> m = MaxPool2d(3, stride=2)
504
+ >>> # pool of non-square window
505
+ >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
506
+ >>> input = brainstate.random.randn(20, 50, 32, 16)
507
+ >>> output = m(input)
508
+ >>> output.shape
509
+ (20, 24, 31, 16)
510
+
511
+ .. _link:
512
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
513
+ """
514
+ __module__ = 'brainstate.nn'
515
+
516
+ def __init__(
517
+ self,
518
+ kernel_size: Size,
519
+ stride: Union[int, Sequence[int]] = None,
520
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
521
+ channel_axis: Optional[int] = -1,
522
+ return_indices: bool = False,
523
+ name: Optional[str] = None,
524
+ in_size: Optional[Size] = None,
525
+ ):
526
+ super().__init__(in_size=in_size,
527
+ init_value=-jax.numpy.inf,
528
+ computation=jax.lax.max,
529
+ pool_dim=2,
530
+ kernel_size=kernel_size,
531
+ stride=stride,
532
+ padding=padding,
533
+ channel_axis=channel_axis,
534
+ return_indices=return_indices,
535
+ name=name)
536
+
537
+
538
+ class MaxPool3d(_MaxPool):
539
+ r"""Applies a 3D max pooling over an input signal composed of several input planes.
540
+
541
+ In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
542
+ output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
543
+ can be precisely described as:
544
+
545
+ .. math::
546
+ \begin{aligned}
547
+ \text{out}(N_i, d, h, w, C_j) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
548
+ & \text{input}(N_i, \text{stride[0]} \times d + k,
549
+ \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
550
+ \end{aligned}
551
+
552
+ If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
553
+ for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
554
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
555
+
556
+ Shape:
557
+ - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
558
+ - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
559
+
560
+ .. math::
561
+ D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
562
+ (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
563
+
564
+ .. math::
565
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
566
+ (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
567
+
568
+ .. math::
569
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
570
+ (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
571
+
572
+ Parameters
573
+ ----------
574
+ kernel_size : int or sequence of int
575
+ An integer, or a sequence of integers defining the window to reduce over.
576
+ stride : int or sequence of int, optional
577
+ An integer, or a sequence of integers, representing the inter-window stride.
578
+ Default: kernel_size
579
+ padding : str, int or sequence of tuple, optional
580
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
581
+ of n `(low, high)` integer pairs that give the padding to apply before
582
+ and after each spatial dimension. Default: 'VALID'
583
+ channel_axis : int, optional
584
+ Axis of the spatial channels for which pooling is skipped.
585
+ If ``None``, there is no channel axis. Default: -1
586
+ return_indices : bool, optional
587
+ If True, will return the max indices along with the outputs.
588
+ Useful for MaxUnpool3d. Default: False
589
+ name : str, optional
590
+ The object name.
591
+ in_size : Sequence of int, optional
592
+ The shape of the input tensor.
593
+
594
+ Examples
595
+ --------
596
+ .. code-block:: python
597
+
598
+ >>> import brainstate
599
+ >>> # pool of square window of size=3, stride=2
600
+ >>> m = MaxPool3d(3, stride=2)
601
+ >>> # pool of non-square window
602
+ >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
603
+ >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
604
+ >>> output = m(input)
605
+ >>> output.shape
606
+ (20, 24, 43, 15, 16)
607
+
608
+ .. _link:
609
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
610
+ """
611
+ __module__ = 'brainstate.nn'
612
+
613
+ def __init__(
614
+ self,
615
+ kernel_size: Size,
616
+ stride: Union[int, Sequence[int]] = None,
617
+ padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
618
+ channel_axis: Optional[int] = -1,
619
+ return_indices: bool = False,
620
+ name: Optional[str] = None,
621
+ in_size: Optional[Size] = None,
622
+ ):
623
+ super().__init__(in_size=in_size,
624
+ init_value=-jax.numpy.inf,
625
+ computation=jax.lax.max,
626
+ pool_dim=3,
627
+ kernel_size=kernel_size,
628
+ stride=stride,
629
+ padding=padding,
630
+ channel_axis=channel_axis,
631
+ return_indices=return_indices,
632
+ name=name)
633
+
634
+
635
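+ # Worked example of the shape rule above (a sketch): with D_in = 50,
+ # kernel_size[0] = 3, stride[0] = 2 and no padding,
+ # D_out = floor((50 - 3) / 2) + 1 = 24, matching the docstring example.
+
+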
+ class _MaxUnpool(Module):
+     """Base class for max unpooling operations."""
+
+     def __init__(
+         self,
+         pool_dim: int,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[int, Tuple[int, ...]] = 0,
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(name=name)
+
+         self.pool_dim = pool_dim
+
+         # kernel_size
+         if isinstance(kernel_size, int):
+             kernel_size = (kernel_size,) * pool_dim
+         elif isinstance(kernel_size, Sequence):
+             assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
+             assert all([isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
+             if len(kernel_size) != pool_dim:
+                 raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, but got {len(kernel_size)}')
+         else:
+             raise TypeError(f'kernel_size should be an int or a tuple with {pool_dim} ints.')
+         self.kernel_size = kernel_size
+
+         # stride
+         if stride is None:
+             stride = kernel_size
+         if isinstance(stride, int):
+             stride = (stride,) * pool_dim
+         elif isinstance(stride, Sequence):
+             assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
+             assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
+             if len(stride) != pool_dim:
+                 raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
+         else:
+             raise TypeError(f'stride should be an int or a tuple with {pool_dim} ints.')
+         self.stride = stride
+
+         # padding
+         if isinstance(padding, int):
+             padding = (padding,) * pool_dim
+         elif isinstance(padding, (tuple, list)):
+             if len(padding) != pool_dim:
+                 raise ValueError(f'padding should have {pool_dim} values, but got {len(padding)}')
+         else:
+             raise TypeError(f'padding should be int or tuple of {pool_dim} ints.')
+         self.padding = padding
+
+         # channel_axis
+         assert channel_axis is None or isinstance(channel_axis, int), \
+             f'channel_axis should be an int, but got {channel_axis}'
+         self.channel_axis = channel_axis
+
+         # in & out shapes
+         if in_size is not None:
+             in_size = tuple(in_size)
+         self.in_size = in_size
+
+     def _compute_output_shape(self, input_shape, output_size=None):
+         """Compute the output shape after unpooling."""
+         if output_size is not None:
+             return output_size
+
+         # Calculate output shape based on kernel, stride, and padding
+         output_shape = []
+         for i in range(self.pool_dim):
+             dim_size = (input_shape[i] - 1) * self.stride[i] - 2 * self.padding[i] + self.kernel_size[i]
+             output_shape.append(dim_size)
+
+         return tuple(output_shape)
+
+     def _unpool_nd(self, x, indices, output_size=None):
+         """Perform N-dimensional max unpooling."""
+         x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
+         if x.ndim < x_dim:
+             raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
+
+         # Determine output shape
+         if output_size is None:
+             # Infer output shape from input shape
+             spatial_dims = self._get_spatial_dims(x.shape)
+             output_spatial_shape = self._compute_output_shape(spatial_dims, output_size)
+             output_shape = list(x.shape)
+
+             # Update spatial dimensions in output shape
+             spatial_start = self._get_spatial_start_idx(x.ndim)
+             for i, size in enumerate(output_spatial_shape):
+                 output_shape[spatial_start + i] = size
+             output_shape = tuple(output_shape)
+         else:
+             # Use provided output size
+             if isinstance(output_size, (list, tuple)):
+                 if len(output_size) == x.ndim:
+                     # Full output shape provided
+                     output_shape = tuple(output_size)
+                 else:
+                     # Only spatial dimensions provided
+                     if len(output_size) != self.pool_dim:
+                         raise ValueError(f"output_size must have {self.pool_dim} spatial dimensions, got {len(output_size)}")
+                     output_shape = list(x.shape)
+                     spatial_start = self._get_spatial_start_idx(x.ndim)
+                     for i, size in enumerate(output_size):
+                         output_shape[spatial_start + i] = size
+                     output_shape = tuple(output_shape)
+             else:
+                 # Single integer provided, use for all spatial dims
+                 output_shape = list(x.shape)
+                 spatial_start = self._get_spatial_start_idx(x.ndim)
+                 for i in range(self.pool_dim):
+                     output_shape[spatial_start + i] = output_size
+                 output_shape = tuple(output_shape)
+
+         # Create output array filled with zeros
+         output = jnp.zeros(output_shape, dtype=x.dtype)
+
+         # Scatter the pooled values back to their flat positions; every
+         # position that did not hold a maximum stays zero.
+         flat_indices = indices.ravel()
+         flat_values = x.ravel()
+         flat_output = output.ravel()
+         flat_output = flat_output.at[flat_indices].set(flat_values)
+
+         # Reshape back to the target shape
+         output = flat_output.reshape(output_shape)
+
+         return output
+
+     def _get_spatial_dims(self, shape):
+         """Extract spatial dimensions from input shape."""
+         if self.channel_axis is None:
+             return shape[-self.pool_dim:]
+         else:
+             channel_axis = self.channel_axis if self.channel_axis >= 0 else len(shape) + self.channel_axis
+             all_dims = list(range(len(shape)))
+             all_dims.pop(channel_axis)
+             return tuple(shape[i] for i in all_dims[-self.pool_dim:])
+
+     def _get_spatial_start_idx(self, ndim):
+         """Get the starting index of spatial dimensions."""
+         if self.channel_axis is None:
+             return ndim - self.pool_dim
+         else:
+             channel_axis = self.channel_axis if self.channel_axis >= 0 else ndim + self.channel_axis
+             if channel_axis < ndim - self.pool_dim:
+                 return ndim - self.pool_dim
+             else:
+                 return ndim - self.pool_dim - 1
+
+     def update(self, x, indices, output_size=None):
+         """Forward pass of max unpooling.
+
+         Parameters
+         ----------
+         x : Array
+             Input tensor from the corresponding max pooling layer.
+         indices : Array
+             Indices of the maximum values returned by that pooling layer.
+         output_size : int or tuple, optional
+             The targeted output size.
+
+         Returns
+         -------
+         Array
+             Unpooled output.
+         """
+         return self._unpool_nd(x, indices, output_size)
+
+
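+ # A minimal sketch (toy shapes, not part of the public API) of the scatter
+ # step used in ``_unpool_nd``: pooled maxima are written back to their flat
+ # positions and every other entry stays zero.
+ #
+ #     >>> import jax.numpy as jnp
+ #     >>> values = jnp.array([3., 5.])        # pooled maxima
+ #     >>> flat_indices = jnp.array([1, 4])    # flat positions of the maxima
+ #     >>> jnp.zeros(6).at[flat_indices].set(values)
+ #     Array([0., 3., 0., 0., 5., 0.], dtype=float32)
+
+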
+ class MaxUnpool1d(_MaxUnpool):
+     r"""Computes a partial inverse of MaxPool1d.
+
+     MaxPool1d is not fully invertible, since the non-maximal values are lost.
+     MaxUnpool1d takes in as input the output of MaxPool1d including the indices
+     of the maximal values and computes a partial inverse in which all
+     non-maximal values are set to zero.
+
+     Shape:
+         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`
+         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+
+         .. math::
+             L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel\_size}
+
+         or as given by :attr:`output_size` in the call operator
+
+     Parameters
+     ----------
+     kernel_size : int or tuple
+         Size of the max pooling window.
+     stride : int or tuple, optional
+         Stride of the max pooling window. Default: kernel_size
+     padding : int or tuple, optional
+         Padding that was added to the input. Default: 0
+     channel_axis : int, optional
+         Axis of the channels. Default: -1
+     name : str, optional
+         Name of the module.
+     in_size : Size, optional
+         Input size for shape inference.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>> # Create pooling and unpooling layers
+         >>> pool = MaxPool1d(2, stride=2, return_indices=True, channel_axis=-1)
+         >>> unpool = MaxUnpool1d(2, stride=2, channel_axis=-1)
+         >>> input = brainstate.random.randn(20, 50, 16)
+         >>> output, indices = pool(input)
+         >>> unpooled = unpool(output, indices)
+         >>> # unpooled recovers shape (20, 50, 16), with zeros at non-maximal positions
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[int, Tuple[int, ...]] = 0,
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(
+             pool_dim=1,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             channel_axis=channel_axis,
+             name=name,
+             in_size=in_size
+         )
+
+
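+ # Worked example of the shape rule above (a sketch): pooling L_in = 50 with
+ # kernel_size = 2 and stride = 2 gives 25 outputs, and unpooling restores
+ # L_out = (25 - 1) * 2 - 2 * 0 + 2 = 50, the original length.
+
+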
+ class MaxUnpool2d(_MaxUnpool):
+     r"""Computes a partial inverse of MaxPool2d.
+
+     MaxPool2d is not fully invertible, since the non-maximal values are lost.
+     MaxUnpool2d takes in as input the output of MaxPool2d including the indices
+     of the maximal values and computes a partial inverse in which all
+     non-maximal values are set to zero.
+
+     Shape:
+         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
+         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+
+         .. math::
+             H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]
+
+         .. math::
+             W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]
+
+         or as given by :attr:`output_size` in the call operator
+
+     Parameters
+     ----------
+     kernel_size : int or tuple
+         Size of the max pooling window.
+     stride : int or tuple, optional
+         Stride of the max pooling window. Default: kernel_size
+     padding : int or tuple, optional
+         Padding that was added to the input. Default: 0
+     channel_axis : int, optional
+         Axis of the channels. Default: -1
+     name : str, optional
+         Name of the module.
+     in_size : Size, optional
+         Input size for shape inference.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # Create pooling and unpooling layers
+         >>> pool = MaxPool2d(2, stride=2, return_indices=True, channel_axis=-1)
+         >>> unpool = MaxUnpool2d(2, stride=2, channel_axis=-1)
+         >>> input = brainstate.random.randn(1, 4, 4, 16)
+         >>> output, indices = pool(input)
+         >>> unpooled = unpool(output, indices)
+         >>> # unpooled recovers shape (1, 4, 4, 16), with zeros at non-maximal positions
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[int, Tuple[int, ...]] = 0,
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(
+             pool_dim=2,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             channel_axis=channel_axis,
+             name=name,
+             in_size=in_size
+         )
+
+
+ class MaxUnpool3d(_MaxUnpool):
+     r"""Computes a partial inverse of MaxPool3d.
+
+     MaxPool3d is not fully invertible, since the non-maximal values are lost.
+     MaxUnpool3d takes in as input the output of MaxPool3d including the indices
+     of the maximal values and computes a partial inverse in which all
+     non-maximal values are set to zero.
+
+     Shape:
+         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`
+         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
+
+         .. math::
+             D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]
+
+         .. math::
+             H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]
+
+         .. math::
+             W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{kernel\_size}[2]
+
+         or as given by :attr:`output_size` in the call operator
+
+     Parameters
+     ----------
+     kernel_size : int or tuple
+         Size of the max pooling window.
+     stride : int or tuple, optional
+         Stride of the max pooling window. Default: kernel_size
+     padding : int or tuple, optional
+         Padding that was added to the input. Default: 0
+     channel_axis : int, optional
+         Axis of the channels. Default: -1
+     name : str, optional
+         Name of the module.
+     in_size : Size, optional
+         Input size for shape inference.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # Create pooling and unpooling layers
+         >>> pool = MaxPool3d(2, stride=2, return_indices=True, channel_axis=-1)
+         >>> unpool = MaxUnpool3d(2, stride=2, channel_axis=-1)
+         >>> input = brainstate.random.randn(1, 4, 4, 4, 16)
+         >>> output, indices = pool(input)
+         >>> unpooled = unpool(output, indices)
+         >>> # unpooled recovers shape (1, 4, 4, 4, 16), with zeros at non-maximal positions
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[int, Tuple[int, ...]] = 0,
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(
+             pool_dim=3,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             channel_axis=channel_axis,
+             name=name,
+             in_size=in_size
+         )
+
+
+ class AvgPool1d(_AvgPool):
+     r"""Applies a 1D average pooling over an input signal composed of several input planes.
+
+     In the simplest case, the output value of the layer with input size :math:`(N, L, C)`,
+     output :math:`(N, L_{out}, C)` and :attr:`kernel_size` :math:`k`
+     can be precisely described as:
+
+     .. math::
+
+         \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
+             \text{input}(N_i, \text{stride} \times l + m, C_j)
+
+     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+     for :attr:`padding` number of points.
+
+     Shape:
+         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+
+         .. math::
+             L_{out} = \left\lfloor \frac{L_{in} +
+                 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
+
+     Parameters
+     ----------
+     kernel_size : int or sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride : int or sequence of int, optional
+         An integer, or a sequence of integers, representing the inter-window stride.
+         Default: 1
+     padding : str, int or sequence of tuple, optional
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension. Default: 'VALID'
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The object name.
+     in_size : Sequence of int, optional
+         The shape of the input tensor.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # pool with window of size=3, stride=2
+         >>> m = AvgPool1d(3, stride=2)
+         >>> input = brainstate.random.randn(20, 50, 16)
+         >>> m(input).shape
+         (20, 24, 16)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Union[int, Sequence[int]] = 1,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(
+             in_size=in_size,
+             init_value=0.,
+             computation=jax.lax.add,
+             pool_dim=1,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             channel_axis=channel_axis,
+             name=name
+         )
+
+
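+ # Padding sketch (shapes only, assuming the 'SAME'/'VALID' semantics of
+ # ``jax.lax.reduce_window``): 'VALID' emits only windows that fit entirely
+ # inside the input, while 'SAME' pads so that ceil(L / stride) outputs remain.
+ #
+ #     >>> x = brainstate.random.randn(10, 16)
+ #     >>> AvgPool1d(3, stride=2, padding='VALID')(x).shape
+ #     (4, 16)
+ #     >>> AvgPool1d(3, stride=2, padding='SAME')(x).shape
+ #     (5, 16)
+
+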
+ class AvgPool2d(_AvgPool):
+     r"""Applies a 2D average pooling over an input signal composed of several input planes.
+
+     In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
+     output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
+     can be precisely described as:
+
+     .. math::
+
+         \text{out}(N_i, h, w, C_j) = \frac{1}{kH \times kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
+             \text{input}(N_i, \text{stride}[0] \times h + m, \text{stride}[1] \times w + n, C_j)
+
+     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+     for :attr:`padding` number of points.
+
+     Shape:
+         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+
+         .. math::
+             H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
+                 \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+         .. math::
+             W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
+                 \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+     Parameters
+     ----------
+     kernel_size : int or sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride : int or sequence of int, optional
+         An integer, or a sequence of integers, representing the inter-window stride.
+         Default: 1
+     padding : str, int or sequence of tuple, optional
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension. Default: 'VALID'
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The object name.
+     in_size : Sequence of int, optional
+         The shape of the input tensor.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # pool of square window of size=3, stride=2
+         >>> m = AvgPool2d(3, stride=2)
+         >>> # pool of non-square window
+         >>> m = AvgPool2d((3, 2), stride=(2, 1))
+         >>> input = brainstate.random.randn(20, 50, 32, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 31, 16)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Union[int, Sequence[int]] = 1,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(
+             in_size=in_size,
+             init_value=0.,
+             computation=jax.lax.add,
+             pool_dim=2,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             channel_axis=channel_axis,
+             name=name
+         )
+
+
+ class AvgPool3d(_AvgPool):
+     r"""Applies a 3D average pooling over an input signal composed of several input planes.
+
+     In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
+     output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
+     can be precisely described as:
+
+     .. math::
+         \begin{aligned}
+             \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
+                                               & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
+                                                 \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
+                                                 {kD \times kH \times kW}
+         \end{aligned}
+
+     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
+     for :attr:`padding` number of points.
+
+     Shape:
+         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
+           :math:`(D_{out}, H_{out}, W_{out}, C)`, where
+
+         .. math::
+             D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
+                 \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+         .. math::
+             H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
+                 \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+         .. math::
+             W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
+                 \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
+
+     Parameters
+     ----------
+     kernel_size : int or sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride : int or sequence of int, optional
+         An integer, or a sequence of integers, representing the inter-window stride.
+         Default: 1
+     padding : str, int or sequence of tuple, optional
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension. Default: 'VALID'
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The object name.
+     in_size : Sequence of int, optional
+         The shape of the input tensor.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # pool of cubic window of size=3, stride=2
+         >>> m = AvgPool3d(3, stride=2)
+         >>> # pool of non-cubic window
+         >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
+         >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 43, 15, 16)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         kernel_size: Size,
+         stride: Union[int, Sequence[int]] = 1,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(
+             in_size=in_size,
+             init_value=0.,
+             computation=jax.lax.add,
+             pool_dim=3,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             channel_axis=channel_axis,
+             name=name
+         )
+
+
+ class _LPPool(Module):
+     """Base class for Lp pooling operations."""
+
+     def __init__(
+         self,
+         norm_type: float,
+         pool_dim: int,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(name=name)
+
+         if norm_type <= 0:
+             raise ValueError(f"norm_type must be positive, got {norm_type}")
+         self.norm_type = norm_type
+         self.pool_dim = pool_dim
+
+         # kernel_size
+         if isinstance(kernel_size, int):
+             kernel_size = (kernel_size,) * pool_dim
+         elif isinstance(kernel_size, Sequence):
+             assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
+             assert all([isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
+             if len(kernel_size) != pool_dim:
+                 raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, but got {len(kernel_size)}')
+         else:
+             raise TypeError(f'kernel_size should be an int or a tuple with {pool_dim} ints.')
+         self.kernel_size = kernel_size
+
+         # stride
+         if stride is None:
+             stride = kernel_size
+         if isinstance(stride, int):
+             stride = (stride,) * pool_dim
+         elif isinstance(stride, Sequence):
+             assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
+             assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
+             if len(stride) != pool_dim:
+                 raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
+         else:
+             raise TypeError(f'stride should be an int or a tuple with {pool_dim} ints.')
+         self.stride = stride
+
+         # padding
+         if isinstance(padding, str):
+             if padding not in ("SAME", "VALID"):
+                 raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
+         elif isinstance(padding, int):
+             padding = [(padding, padding) for _ in range(pool_dim)]
+         elif isinstance(padding, (list, tuple)):
+             if isinstance(padding[0], int):
+                 if len(padding) == pool_dim:
+                     padding = [(x, x) for x in padding]
+                 else:
+                     raise ValueError(f'If padding is a sequence of ints, it '
+                                      f'should have the length of {pool_dim}.')
+             else:
+                 if not all([isinstance(x, (tuple, list)) for x in padding]):
+                     raise ValueError(f'padding should be a sequence of Tuple[int, int]. {padding}')
+                 if not all([len(x) == 2 for x in padding]):
+                     raise ValueError(f"Each entry in padding must be a tuple of 2 ints. {padding}")
+                 if len(padding) == 1:
+                     padding = tuple(padding) * pool_dim
+                 assert len(padding) == pool_dim, f'padding should have the length of {pool_dim}. {padding}'
+         else:
+             raise ValueError(f'Invalid padding: {padding}')
+         self.padding = padding
+
+         # channel_axis
+         assert channel_axis is None or isinstance(channel_axis, int), \
+             f'channel_axis should be an int, but got {channel_axis}'
+         self.channel_axis = channel_axis
+
+         # in & out shapes
+         if in_size is not None:
+             in_size = tuple(in_size)
+             self.in_size = in_size
+             y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
+             self.out_size = y.shape[1:]
+
+     def update(self, x):
+         x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
+         if x.ndim < x_dim:
+             raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
+
+         window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
+         stride = self._infer_shape(x.ndim, self.stride, 1)
+         padding = (self.padding if isinstance(self.padding, str) else
+                    self._infer_shape(x.ndim, self.padding, element=(0, 0)))
+
+         # For Lp pooling, we need to:
+         # 1. Take the absolute value and raise to the power p
+         # 2. Sum over the window
+         # 3. Take the p-th root, with a (1/N)^(1/p) normalization
+
+         # Step 1: |x|^p
+         x_pow = jnp.abs(x) ** self.norm_type
+
+         # Step 2: sum over the window
+         pooled_sum = jax.lax.reduce_window(
+             x_pow,
+             init_value=0.,
+             computation=jax.lax.add,
+             window_dimensions=window_shape,
+             window_strides=stride,
+             padding=padding
+         )
+
+         # Step 3: take the p-th root and multiply by the normalization factor
+         # (1/N)^(1/p), where N is the window size
+         window_size = np.prod(self.kernel_size)
+         norm_factor = window_size ** (-1.0 / self.norm_type)
+         result = norm_factor * (pooled_sum ** (1.0 / self.norm_type))
+
+         return result
+
+     def _infer_shape(self, x_dim, inputs, element):
+         channel_axis = self.channel_axis
+         if channel_axis is not None and not 0 <= abs(channel_axis) < x_dim:
+             raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
+         if channel_axis is not None and channel_axis < 0:
+             channel_axis = x_dim + channel_axis
+         all_dims = list(range(x_dim))
+         if channel_axis is not None:
+             all_dims.pop(channel_axis)
+         pool_dims = all_dims[-self.pool_dim:]
+         results = [element] * x_dim
+         for i, dim in enumerate(pool_dims):
+             results[dim] = inputs[i]
+         return results
+
+
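+ # Numerical sanity check (a sketch, not part of the test suite): with
+ # ``norm_type=2`` the normalized Lp pool reduces to RMS pooling,
+ # sqrt(mean(x ** 2)) over each window:
+ #
+ #     >>> import jax.numpy as jnp
+ #     >>> x = jnp.array([3., 4.])                      # one window of size 2
+ #     >>> (2 ** -0.5) * jnp.sqrt(jnp.sum(x ** 2))      # the pooled value
+ #     Array(3.5355339, dtype=float32)
+ #     >>> jnp.sqrt(jnp.mean(x ** 2))                   # RMS of the window
+ #     Array(3.5355339, dtype=float32)
+
+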
+ class LPPool1d(_LPPool):
+     r"""Applies a 1D power-average pooling over an input signal composed of several input planes.
+
+     On each window :math:`X`, the function computed is:
+
+     .. math::
+         f(X) = \sqrt[p]{\frac{1}{|X|} \sum_{x \in X} |x|^{p}}
+
+     - At :math:`p = \infty`, one gets max pooling
+     - At :math:`p = 1`, one gets average pooling (with absolute values)
+     - At :math:`p = 2`, one gets root mean square (RMS) pooling
+
+     Shape:
+         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+
+         .. math::
+             L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
+
+     Parameters
+     ----------
+     norm_type : float
+         Exponent :math:`p` for the pooling operation.
+     kernel_size : int or sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride : int or sequence of int, optional
+         An integer, or a sequence of integers, representing the inter-window stride.
+         Default: kernel_size
+     padding : str, int or sequence of tuple, optional
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension. Default: 'VALID'
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The object name.
+     in_size : Sequence of int, optional
+         The shape of the input tensor.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # power-average pooling of window of size=3, stride=2 with norm_type=2.0
+         >>> m = LPPool1d(2, 3, stride=2)
+         >>> input = brainstate.random.randn(20, 50, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 16)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         norm_type: float,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(
+             norm_type=norm_type,
+             pool_dim=1,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             channel_axis=channel_axis,
+             name=name,
+             in_size=in_size
+         )
+
+
+ class LPPool2d(_LPPool):
+     r"""Applies a 2D power-average pooling over an input signal composed of several input planes.
+
+     On each window :math:`X`, the function computed is:
+
+     .. math::
+         f(X) = \sqrt[p]{\frac{1}{|X|} \sum_{x \in X} |x|^{p}}
+
+     - At :math:`p = \infty`, one gets max pooling
+     - At :math:`p = 1`, one gets average pooling (with absolute values)
+     - At :math:`p = 2`, one gets root mean square (RMS) pooling
+
+     Shape:
+         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
+         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+
+         .. math::
+             H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+         .. math::
+             W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+     Parameters
+     ----------
+     norm_type : float
+         Exponent :math:`p` for the pooling operation.
+     kernel_size : int or sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride : int or sequence of int, optional
+         An integer, or a sequence of integers, representing the inter-window stride.
+         Default: kernel_size
+     padding : str, int or sequence of tuple, optional
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension. Default: 'VALID'
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The object name.
+     in_size : Sequence of int, optional
+         The shape of the input tensor.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # power-average pooling of square window of size=3, stride=2
+         >>> m = LPPool2d(2, 3, stride=2)
+         >>> # pool of non-square window with norm_type=1.5
+         >>> m = LPPool2d(1.5, (3, 2), stride=(2, 1), channel_axis=-1)
+         >>> input = brainstate.random.randn(20, 50, 32, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 31, 16)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         norm_type: float,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(
+             norm_type=norm_type,
+             pool_dim=2,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             channel_axis=channel_axis,
+             name=name,
+             in_size=in_size
+         )
+
+
+ class LPPool3d(_LPPool):
+     r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
+
+     On each window :math:`X`, the function computed is:
+
+     .. math::
+         f(X) = \sqrt[p]{\frac{1}{|X|} \sum_{x \in X} |x|^{p}}
+
+     - At :math:`p = \infty`, one gets max pooling
+     - At :math:`p = 1`, one gets average pooling (with absolute values)
+     - At :math:`p = 2`, one gets root mean square (RMS) pooling
+
+     Shape:
+         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
+
+         .. math::
+             D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+         .. math::
+             H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+         .. math::
+             W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
+
+     Parameters
+     ----------
+     norm_type : float
+         Exponent :math:`p` for the pooling operation.
+     kernel_size : int or sequence of int
+         An integer, or a sequence of integers defining the window to reduce over.
+     stride : int or sequence of int, optional
+         An integer, or a sequence of integers, representing the inter-window stride.
+         Default: kernel_size
+     padding : str, int or sequence of tuple, optional
+         Either the string `'SAME'`, the string `'VALID'`, or a sequence
+         of n `(low, high)` integer pairs that give the padding to apply before
+         and after each spatial dimension. Default: 'VALID'
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The object name.
+     in_size : Sequence of int, optional
+         The shape of the input tensor.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # power-average pooling of cubic window of size=3, stride=2
+         >>> m = LPPool3d(2, 3, stride=2)
+         >>> # pool of non-cubic window with norm_type=1.5
+         >>> m = LPPool3d(1.5, (3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
+         >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
+         >>> output = m(input)
+         >>> output.shape
+         (20, 24, 43, 15, 16)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         norm_type: float,
+         kernel_size: Size,
+         stride: Optional[Union[int, Sequence[int]]] = None,
+         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Size] = None,
+     ):
+         super().__init__(
+             norm_type=norm_type,
+             pool_dim=3,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             channel_axis=channel_axis,
+             name=name,
+             in_size=in_size
+         )
+
+
+ def _adaptive_pool1d(x, target_size: int, operation: Callable):
+     """Adaptive pool 1D.
+
+     Args:
+         x: The input. Should be a JAX array of shape `(dim,)`.
+         target_size: The shape of the output after the pooling operation `(target_size,)`.
+         operation: The pooling operation to be performed on the input array.
+
+     Returns:
+         A JAX array of shape `(target_size,)`.
+     """
+     size = jnp.size(x)
+     num_head_arrays = size % target_size
+     num_block = size // target_size
+     if num_head_arrays != 0:
+         # Uneven split: the first `size % target_size` windows take one extra
+         # element (`num_block + 1` each); the remaining windows take `num_block`.
+         head_end_index = num_head_arrays * (num_block + 1)
+         heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
+         tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
+         outs = jnp.concatenate([heads, tails])
+     else:
+         # Even split: every window has exactly `num_block` elements.
+         outs = jax.vmap(operation)(x.reshape(-1, num_block))
+     return outs
+
+
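+ # Worked example (a sketch): pooling a length-10 array down to 3 gives one
+ # head window of 10 // 3 + 1 = 4 elements and two tail windows of 3.
+ #
+ #     >>> import jax.numpy as jnp
+ #     >>> _adaptive_pool1d(jnp.arange(10.), 3, jnp.mean)
+ #     Array([1.5, 5. , 8. ], dtype=float32)
+
+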
+ def _generate_vmap(fun: Callable, map_axes: List[int]):
+     """Vectorize ``fun`` over every axis in ``map_axes``, leaving the remaining
+     arguments (the target size and the pooling operation) unmapped."""
+     map_axes = sorted(map_axes)
+     for axis in map_axes:
+         fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
+     return fun
+
+
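+ # For example (a sketch): mapping ``_adaptive_pool1d`` over axes 0 and 2 of a
+ # ``(batch, length, channels)`` array pools each batch element and channel
+ # independently along the length axis.
+ #
+ #     >>> import jax.numpy as jnp
+ #     >>> op = _generate_vmap(_adaptive_pool1d, [0, 2])
+ #     >>> op(jnp.ones((4, 10, 3)), 5, jnp.mean).shape
+ #     (4, 5, 3)
+
+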
+ class _AdaptivePool(Module):
+     """General N-dimensional adaptive down-sampling to a target shape.
+
+     Parameters
+     ----------
+     in_size : Sequence of int
+         The shape of the input tensor.
+     target_size : int or sequence of int
+         The target output shape. A ``None`` entry leaves the corresponding
+         dimension unchanged.
+     num_spatial_dims : int
+         The number of spatial dimensions.
+     operation : Callable
+         The down-sampling operation.
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis.
+     name : str, optional
+         The class name.
+     """
+
+     def __init__(
+         self,
+         in_size: Size,
+         target_size: Size,
+         num_spatial_dims: int,
+         operation: Callable,
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         self.channel_axis = channel_axis
+         self.operation = operation
+         if isinstance(target_size, int):
+             self.target_shape = (target_size,) * num_spatial_dims
+         elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
+             self.target_shape = target_size
+         else:
+             raise ValueError("`target_size` must either be an int or tuple of length "
+                              f"{num_spatial_dims} containing ints (or None to leave "
+                              "a dimension unchanged).")
+
+         # in & out shapes
+         if in_size is not None:
+             in_size = tuple(in_size)
+             self.in_size = in_size
+             y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
+             self.out_size = y.shape[1:]
+
+     def update(self, x):
+         """Input-output mapping.
+
+         Parameters
+         ----------
+         x : Array
+             Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
+             or `(..., dim_1, dim_2)`.
+         """
+         # channel axis
+         channel_axis = self.channel_axis
+         if channel_axis is not None:
+             if not 0 <= abs(channel_axis) < x.ndim:
+                 raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
+             if channel_axis < 0:
+                 channel_axis = x.ndim + channel_axis
+         # input dimension
+         if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
+             raise ValueError(f"Invalid input dimension. Expected >= {len(self.target_shape)} "
+                              f"dimensions (channel_axis={self.channel_axis}). "
+                              f"But got {x.ndim} dimensions.")
+         # pooling dimensions
+         pool_dims = list(range(x.ndim))
+         if channel_axis is not None:
+             pool_dims.pop(channel_axis)
+
+         # pooling
+         for i, di in enumerate(pool_dims[-len(self.target_shape):]):
+             if self.target_shape[i] is None:
+                 continue  # ``None`` keeps this dimension unchanged
+             pool_axes = [j for j in range(x.ndim) if j != di]
+             op = _generate_vmap(_adaptive_pool1d, pool_axes)
+             x = op(x, self.target_shape[i], self.operation)
+         return x
+
+
+ class AdaptiveAvgPool1d(_AdaptivePool):
+     r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
+
+     The output size is :math:`L_{out}`, for any input size.
+     The number of output features is equal to the number of input planes.
+
+     Adaptive pooling automatically computes the kernel size and stride to achieve the desired
+     output size, making it useful for creating fixed-size representations from variable-sized inputs.
+
+     Shape:
+         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+           :math:`L_{out}=\text{target\_size}`.
+
+     Parameters
+     ----------
+     target_size : int or sequence of int
+         The target output size. The number of output features for each channel.
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The name of the module.
+     in_size : Sequence of int, optional
+         The shape of the input tensor for shape inference.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # Target output size of 5
+         >>> m = AdaptiveAvgPool1d(5)
+         >>> input = brainstate.random.randn(1, 64, 8)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 5, 8)
+         >>> # Can handle variable input sizes
+         >>> input2 = brainstate.random.randn(1, 32, 8)
+         >>> output2 = m(input2)
+         >>> output2.shape   # same output size regardless of input size
+         (1, 5, 8)
+
+     See Also
+     --------
+     AvgPool1d : Non-adaptive 1D average pooling.
+     AdaptiveMaxPool1d : Adaptive 1D max pooling.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(
+             in_size=in_size,
+             target_size=target_size,
+             channel_axis=channel_axis,
+             num_spatial_dims=1,
+             operation=jnp.mean,
+             name=name
+         )
+
+
+ class AdaptiveAvgPool2d(_AdaptivePool):
+     r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
+
+     The output is of size :math:`H_{out} \times W_{out}`, for any input size.
+     The number of output features is equal to the number of input planes.
+
+     Adaptive pooling automatically computes the kernel size and stride to achieve the desired
+     output size, making it useful for creating fixed-size representations from variable-sized inputs.
+
+     Shape:
+         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+           :math:`(H_{out}, W_{out})=\text{target\_size}`.
+
+     Parameters
+     ----------
+     target_size : int or tuple of int
+         The target output size. If a single integer is provided, the output will be a square
+         of that size. If a tuple is provided, it specifies (H_out, W_out).
+         Use None for dimensions that should not be pooled.
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The name of the module.
+     in_size : Sequence of int, optional
+         The shape of the input tensor for shape inference.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # Target output size of 5x7
+         >>> m = AdaptiveAvgPool2d((5, 7))
+         >>> input = brainstate.random.randn(1, 8, 9, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 5, 7, 64)
+         >>> # Target output size of 7x7 (square)
+         >>> m = AdaptiveAvgPool2d(7)
+         >>> input = brainstate.random.randn(1, 10, 9, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 7, 7, 64)
+         >>> # Target output size of 10x7
+         >>> m = AdaptiveAvgPool2d((None, 7))
+         >>> input = brainstate.random.randn(1, 10, 9, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 10, 7, 64)
+
+     See Also
+     --------
+     AvgPool2d : Non-adaptive 2D average pooling.
+     AdaptiveMaxPool2d : Adaptive 2D max pooling.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(
+             in_size=in_size,
+             target_size=target_size,
+             channel_axis=channel_axis,
+             num_spatial_dims=2,
+             operation=jnp.mean,
+             name=name
+         )
+
+
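+ # Sketch: with ``channel_axis=None`` every trailing axis is treated as
+ # spatial, so an unbatched ``(10, 8)`` input pools straight to ``(5, 4)``.
+ #
+ #     >>> import jax.numpy as jnp
+ #     >>> m = AdaptiveAvgPool2d((5, 4), channel_axis=None)
+ #     >>> m(jnp.ones((10, 8))).shape
+ #     (5, 4)
+
+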
+ class AdaptiveAvgPool3d(_AdaptivePool):
+     r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
+
+     The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
+     The number of output features is equal to the number of input planes.
+
+     Adaptive pooling automatically computes the kernel size and stride to achieve the desired
+     output size, making it useful for creating fixed-size representations from variable-sized inputs.
+
+     Shape:
+         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
+           where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
+
+     Parameters
+     ----------
+     target_size : int or tuple of int
+         The target output size. If a single integer is provided, the output will be a cube
+         of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
+         Use None for dimensions that should not be pooled.
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The name of the module.
+     in_size : Sequence of int, optional
+         The shape of the input tensor for shape inference.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # Target output size of 5x7x9
+         >>> m = AdaptiveAvgPool3d((5, 7, 9))
+         >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 5, 7, 9, 64)
+         >>> # Target output size of 7x7x7 (cube)
+         >>> m = AdaptiveAvgPool3d(7)
+         >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 7, 7, 7, 64)
+         >>> # Target output size of 7x9x8
+         >>> m = AdaptiveAvgPool3d((7, None, None))
+         >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 7, 9, 8, 64)
+
+     See Also
+     --------
+     AvgPool3d : Non-adaptive 3D average pooling.
+     AdaptiveMaxPool3d : Adaptive 3D max pooling.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(
+             in_size=in_size,
+             target_size=target_size,
+             channel_axis=channel_axis,
+             num_spatial_dims=3,
+             operation=jnp.mean,
+             name=name
+         )
+
+
+ class AdaptiveMaxPool1d(_AdaptivePool):
+     r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
+
+     The output size is :math:`L_{out}`, for any input size.
+     The number of output features is equal to the number of input planes.
+
+     Adaptive pooling automatically computes the kernel size and stride to achieve the desired
+     output size, making it useful for creating fixed-size representations from variable-sized inputs.
+
+     Shape:
+         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
+         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
+           :math:`L_{out}=\text{target\_size}`.
+
+     Parameters
+     ----------
+     target_size : int or sequence of int
+         The target output size. The number of output features for each channel.
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The name of the module.
+     in_size : Sequence of int, optional
+         The shape of the input tensor for shape inference.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # Target output size of 5
+         >>> m = AdaptiveMaxPool1d(5)
+         >>> input = brainstate.random.randn(1, 64, 8)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 5, 8)
+         >>> # Can handle variable input sizes
+         >>> input2 = brainstate.random.randn(1, 32, 8)
+         >>> output2 = m(input2)
+         >>> output2.shape   # same output size regardless of input size
+         (1, 5, 8)
+
+     See Also
+     --------
+     MaxPool1d : Non-adaptive 1D max pooling.
+     AdaptiveAvgPool1d : Adaptive 1D average pooling.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(
+             in_size=in_size,
+             target_size=target_size,
+             channel_axis=channel_axis,
+             num_spatial_dims=1,
+             operation=jnp.max,
+             name=name
+         )
+
+
+ class AdaptiveMaxPool2d(_AdaptivePool):
+     r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
+
+     The output is of size :math:`H_{out} \times W_{out}`, for any input size.
+     The number of output features is equal to the number of input planes.
+
+     Adaptive pooling automatically computes the kernel size and stride to achieve the desired
+     output size, making it useful for creating fixed-size representations from variable-sized inputs.
+
+     Shape:
+         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
+           :math:`(H_{out}, W_{out})=\text{target\_size}`.
+
+     Parameters
+     ----------
+     target_size : int or tuple of int
+         The target output size. If a single integer is provided, the output will be a square
+         of that size. If a tuple is provided, it specifies (H_out, W_out).
+         Use None for dimensions that should not be pooled.
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The name of the module.
+     in_size : Sequence of int, optional
+         The shape of the input tensor for shape inference.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # Target output size of 5x7
+         >>> m = AdaptiveMaxPool2d((5, 7))
+         >>> input = brainstate.random.randn(1, 8, 9, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 5, 7, 64)
+         >>> # Target output size of 7x7 (square)
+         >>> m = AdaptiveMaxPool2d(7)
+         >>> input = brainstate.random.randn(1, 10, 9, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 7, 7, 64)
+         >>> # Target output size of 10x7
+         >>> m = AdaptiveMaxPool2d((None, 7))
+         >>> input = brainstate.random.randn(1, 10, 9, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 10, 7, 64)
+
+     See Also
+     --------
+     MaxPool2d : Non-adaptive 2D max pooling.
+     AdaptiveAvgPool2d : Adaptive 2D average pooling.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(
+             in_size=in_size,
+             target_size=target_size,
+             channel_axis=channel_axis,
+             num_spatial_dims=2,
+             operation=jnp.max,
+             name=name
+         )
+
+
+ class AdaptiveMaxPool3d(_AdaptivePool):
+     r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
+
+     The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
+     The number of output features is equal to the number of input planes.
+
+     Adaptive pooling automatically computes the kernel size and stride to achieve the desired
+     output size, making it useful for creating fixed-size representations from variable-sized inputs.
+
+     Shape:
+         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
+         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
+           where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
+
+     Parameters
+     ----------
+     target_size : int or tuple of int
+         The target output size. If a single integer is provided, the output will be a cube
+         of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
+         Use None for dimensions that should not be pooled.
+     channel_axis : int, optional
+         Axis of the spatial channels for which pooling is skipped.
+         If ``None``, there is no channel axis. Default: -1
+     name : str, optional
+         The name of the module.
+     in_size : Sequence of int, optional
+         The shape of the input tensor for shape inference.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> # Target output size of 5x7x9
+         >>> m = AdaptiveMaxPool3d((5, 7, 9))
+         >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 5, 7, 9, 64)
+         >>> # Target output size of 7x7x7 (cube)
+         >>> m = AdaptiveMaxPool3d(7)
+         >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 7, 7, 7, 64)
+         >>> # Target output size of 7x9x8
+         >>> m = AdaptiveMaxPool3d((7, None, None))
+         >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
+         >>> output = m(input)
+         >>> output.shape
+         (1, 7, 9, 8, 64)
+
+     See Also
+     --------
+     MaxPool3d : Non-adaptive 3D max pooling.
+     AdaptiveAvgPool3d : Adaptive 3D average pooling.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         target_size: Union[int, Sequence[int]],
+         channel_axis: Optional[int] = -1,
+         name: Optional[str] = None,
+         in_size: Optional[Sequence[int]] = None,
+     ):
+         super().__init__(
+             in_size=in_size,
+             target_size=target_size,
+             channel_axis=channel_axis,
+             num_spatial_dims=3,
+             operation=jnp.max,
+             name=name
+         )