brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_poolings.py
CHANGED
@@ -1,2239 +1,2239 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
import functools
|
19
|
-
from typing import Sequence, Optional
|
20
|
-
from typing import Union, Tuple, Callable, List
|
21
|
-
|
22
|
-
import brainunit as u
|
23
|
-
import jax
|
24
|
-
import jax.numpy as jnp
|
25
|
-
import numpy as np
|
26
|
-
|
27
|
-
from brainstate import environ
|
28
|
-
from brainstate.typing import Size
|
29
|
-
from ._module import Module
|
30
|
-
|
31
|
-
# Names exported by ``from brainstate.nn._poolings import *``: the reshaping
# layers (Flatten / Unflatten) followed by the 1-, 2- and 3-D pooling layers.
__all__ = [
    'Flatten', 'Unflatten',

    'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
    'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
    'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
    'LPPool1d', 'LPPool2d', 'LPPool3d',

    'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
    'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
]
|
42
|
-
|
43
|
-
|
44
|
-
class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Parameters
    ----------
    start_axis : int, optional
        First dim to flatten (default = 0).
    end_axis : int, optional
        Last dim to flatten (default = -1).
    in_size : Sequence of int, optional
        The shape of the input tensor. When given, ``out_size`` is inferred from it,
        and non-negative ``start_axis`` / ``end_axis`` are interpreted relative to
        ``in_size``: any extra leading dimensions of the actual input (for example a
        batch axis) are kept as they are.

    Raises
    ------
    ValueError
        In :meth:`update`, when ``in_size`` was given and the input either has fewer
        dimensions than ``in_size`` or its trailing dimensions do not match it.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> inp = brainstate.random.randn(32, 1, 5, 5)
        >>> # With default parameters every axis is flattened
        >>> m = Flatten()
        >>> output = m(inp)
        >>> output.shape
        (800,)
        >>> # Keep the leading axis
        >>> m = Flatten(1)
        >>> output = m(inp)
        >>> output.shape
        (32, 25)
        >>> # With non-default parameters
        >>> m = Flatten(0, 2)
        >>> output = m(inp)
        >>> output.shape
        (160, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        start_axis: int = 0,
        end_axis: int = -1,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__()
        self.start_axis = start_axis
        self.end_axis = end_axis

        if in_size is not None:
            self.in_size = tuple(in_size)
            y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
                               jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = y.shape

    def update(self, x):
        """Flatten ``x`` over the configured axis range and return the result."""
        start_axis = self.start_axis
        end_axis = self.end_axis
        if self._in_size is not None:
            if x.ndim < len(self.in_size):
                raise ValueError('Input tensor has fewer dimensions than the expected shape.')
            # Number of extra leading axes (e.g. batch) in front of ``in_size``.
            dim_diff = x.ndim - len(self.in_size)
            if self.in_size != x.shape[dim_diff:]:
                raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
            # Non-negative axes are relative to ``in_size``: shift BOTH ends past the
            # extra leading axes. Negative axes count from the end and need no shift.
            if start_axis >= 0:
                start_axis += dim_diff
            if end_axis >= 0:
                end_axis += dim_diff
        if start_axis < 0:
            start_axis += x.ndim
        return u.math.flatten(x, start_axis, end_axis)
|
111
|
-
|
112
|
-
|
113
|
-
class Unflatten(Module):
    r"""
    Unflatten one axis of a tensor, expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`axis` specifies the axis of the input tensor to be unflattened.

    * :attr:`sizes` is the new shape of the unflattened axis; it must be a
      ``tuple`` or ``list`` of ints.

    Shape:
        - Input: :math:`(*, S_{\text{axis}}, *)`, where :math:`S_{\text{axis}}` is the size at
          axis :attr:`axis` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
          :math:`\prod_{i=1}^n U_i = S_{\text{axis}}`.

    Parameters
    ----------
    axis : int
        Axis to be unflattened. It is applied to the input exactly as given; it is
        not shifted for extra leading dimensions such as a batch axis.
    sizes : tuple or list of int
        New shape of the unflattened axis.
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor, used to infer ``out_size``.

    Raises
    ------
    TypeError
        If ``sizes`` is not a tuple/list, or one of its elements is not an ``int``.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        axis: int,
        sizes: Size,
        name: Optional[str] = None,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__(name=name)

        self.axis = axis
        self.sizes = sizes
        if not isinstance(sizes, (tuple, list)):
            raise TypeError(f"unflattened sizes must be tuple or list, but found type {type(sizes).__name__}")
        for idx, elem in enumerate(sizes):
            if not isinstance(elem, int):
                raise TypeError(f"unflattened sizes must be tuple of ints, "
                                f"but found element of type {type(elem).__name__} at pos {idx}")

        if in_size is not None:
            self.in_size = tuple(in_size)
            y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
                               jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = y.shape

    def update(self, x):
        """Split axis ``self.axis`` of ``x`` into ``self.sizes``."""
        return u.math.unflatten(x, self.axis, self.sizes)
|
170
|
-
|
171
|
-
|
172
|
-
class _MaxPool(Module):
    """Shared window-pooling implementation built on ``jax.lax.reduce_window``.

    Subclasses choose the reduction through ``init_value`` and ``computation``
    (the MaxPool layers use ``-inf`` with :func:`jax.lax.max`). Pooling is applied
    to the last ``pool_dim`` axes of the input once ``channel_axis`` (if any) is
    excluded; every other axis uses a window and stride of 1.

    Parameters
    ----------
    init_value : float
        Initial value of the reduction. ``reduce_window`` also uses it for padded positions.
    computation : Callable
        Binary reduction applied within each window.
    pool_dim : int
        Number of pooled (spatial) dimensions.
    kernel_size : int or sequence of int
        Window size: one int shared by all pooled axes, or ``pool_dim`` ints.
    stride : int or sequence of int, optional
        Window stride. Defaults to ``kernel_size``.
    padding : str, int, sequence of int or sequence of (int, int)
        ``'SAME'``, ``'VALID'``, one int applied on both sides of every pooled axis,
        ``pool_dim`` ints (symmetric per axis), or ``(low, high)`` pairs. A single
        pair is broadcast to all pooled axes.
    channel_axis : int, optional
        Axis excluded from pooling. ``None`` means there is no channel axis.
    return_indices : bool
        If True, :meth:`update` returns ``(pooled, indices)``. ``indices`` are
        positions of the selected elements in the flattened input array.
    name : str, optional
        Module name.
    in_size : sequence of int, optional
        Shape of one input sample without a leading batch axis; used to infer ``out_size``.
    """

    def __init__(
        self,
        init_value: float,
        computation: Callable,
        pool_dim: int,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        return_indices: bool = False,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(name=name)

        self.init_value = init_value
        self.computation = computation
        self.pool_dim = pool_dim
        self.return_indices = return_indices

        # kernel_size
        kernel_size = self._check_int_tuple('kernel_size', kernel_size, pool_dim)
        self.kernel_size = kernel_size

        # stride (defaults to non-overlapping windows)
        if stride is None:
            stride = kernel_size
        self.stride = self._check_int_tuple('stride', stride, pool_dim)

        # padding
        self.padding = self._check_padding(padding, pool_dim)

        # channel_axis
        assert channel_axis is None or isinstance(channel_axis, int), \
            f'channel_axis should be an int, but got {channel_axis}'
        self.channel_axis = channel_axis

        # in & out shapes
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            # Trace with a dummy leading batch axis and drop it from the result.
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            if isinstance(y, tuple):
                # ``update`` returned ``(pooled, indices)``; the output shape is that of ``pooled``.
                y = y[0]
            self.out_size = y.shape[1:]

    @staticmethod
    def _check_int_tuple(arg_name, value, pool_dim):
        """Normalize an int, or validate a ``pool_dim``-length sequence of ints."""
        if isinstance(value, int):
            return (value,) * pool_dim
        if isinstance(value, Sequence):
            assert isinstance(value, (tuple, list)), f'{arg_name} should be a tuple, but got {type(value)}'
            assert all([isinstance(x, int) for x in value]), f'{arg_name} should be a tuple of ints. {value}'
            if len(value) != pool_dim:
                raise ValueError(f'{arg_name} should a tuple with {pool_dim} ints, but got {len(value)}')
            return value
        raise TypeError(f'{arg_name} should be a int or a tuple with {pool_dim} ints.')

    @staticmethod
    def _check_padding(padding, pool_dim):
        """Validate ``padding``; return the string or a list of ``pool_dim`` (low, high) pairs."""
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
            return padding
        if isinstance(padding, int):
            return [(padding, padding) for _ in range(pool_dim)]
        if isinstance(padding, (list, tuple)):
            if isinstance(padding[0], int):
                if len(padding) != pool_dim:
                    raise ValueError(f'If padding is a sequence of ints, it '
                                     f'should has the length of {pool_dim}.')
                return [(x, x) for x in padding]
            if not all([isinstance(x, (tuple, list)) for x in padding]):
                raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
            if not all([len(x) == 2 for x in padding]):
                raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
            if len(padding) == 1:
                # One (low, high) pair is shared by every pooled axis.
                padding = tuple(padding) * pool_dim
            assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
            return padding
        raise ValueError(f'padding should be a str, an int or a sequence, but got {type(padding)}.')

    def update(self, x):
        """Pool ``x``; returns ``(pooled, indices)`` when ``return_indices`` is set."""
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
        window_shape = tuple(self._infer_shape(x.ndim, self.kernel_size, 1))
        stride = tuple(self._infer_shape(x.ndim, self.stride, 1))
        if isinstance(self.padding, str):
            padding = tuple(jax.lax.padtype_to_pads(x.shape, window_shape, stride, self.padding))
        else:
            padding = tuple(self._infer_shape(x.ndim, self.padding, element=(0, 0)))

        if self.return_indices:
            # For returning indices, we need to use a custom implementation
            return self._pooling_with_indices(x, window_shape, stride, padding)

        return jax.lax.reduce_window(
            x,
            init_value=self.init_value,
            computation=self.computation,
            window_dimensions=window_shape,
            window_strides=stride,
            padding=padding
        )

    def _pooling_with_indices(self, x, window_shape, stride, padding):
        """Perform max pooling and return both pooled values and indices.

        The indices are flat positions into ``x``. Among equal maxima, the smallest
        index wins. Windows that only see padding get index 0.
        """
        total_size = x.size
        flat_indices = jnp.arange(total_size, dtype=jnp.int32).reshape(x.shape)

        init_val = jnp.asarray(self.init_value, dtype=x.dtype)
        # ``total_size`` is an out-of-range sentinel for "no element selected yet".
        init_idx = jnp.array(total_size, dtype=flat_indices.dtype)

        def reducer(acc, operand):
            acc_val, acc_idx = acc
            cur_val, cur_idx = operand

            better = cur_val > acc_val
            best_val = jnp.where(better, cur_val, acc_val)
            best_idx = jnp.where(better, cur_idx, acc_idx)
            # Deterministic tie-break so the reduction order does not matter.
            tie = jnp.logical_and(cur_val == acc_val, cur_idx < acc_idx)
            best_idx = jnp.where(tie, cur_idx, best_idx)

            return best_val, best_idx

        pooled, indices_result = jax.lax.reduce_window(
            (x, flat_indices),
            (init_val, init_idx),
            reducer,
            window_dimensions=window_shape,
            window_strides=stride,
            padding=padding
        )

        indices_result = jnp.where(indices_result == total_size, 0, indices_result)

        return pooled, indices_result.astype(jnp.int32)

    def _infer_shape(self, x_dim, inputs, element):
        """Expand per-pooled-axis ``inputs`` into a list of length ``x_dim``.

        The pooled axes are the last ``pool_dim`` axes after removing
        ``channel_axis``; every other position is filled with ``element``.
        """
        channel_axis = self.channel_axis
        if channel_axis is not None:
            # Any axis in [-x_dim, x_dim) is valid, including 0 and -x_dim.
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis
        all_dims = list(range(x_dim))
        if channel_axis is not None:
            all_dims.pop(channel_axis)
        pool_dims = all_dims[-self.pool_dim:]
        results = [element] * x_dim
        for i, dim in enumerate(pool_dims):
            results[dim] = inputs[i]
        return results
|
328
|
-
|
329
|
-
|
330
|
-
class _AvgPool(_MaxPool):
    """Average pooling that reuses the window/stride/padding handling of :class:`_MaxPool`.

    ``update`` reduces each window with the configured ``computation`` and divides the
    result by a per-window element count. The count is presumably a sum, with
    ``init_value`` 0, as set by the AvgPool subclasses (TODO confirm).
    """

    def update(self, x):
        """Average-pool ``x`` over its pooled axes."""
        min_ndim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < min_ndim:
            raise ValueError(f'Excepted input with >= {min_ndim} dimensions, but got {x.ndim}.')

        window = self._infer_shape(x.ndim, self.kernel_size, 1)
        strides = self._infer_shape(x.ndim, self.stride, 1)
        if isinstance(self.padding, str):
            pads = self.padding
        else:
            pads = self._infer_shape(x.ndim, self.padding, element=(0, 0))

        def window_reduce(operand):
            # Both the data and the count array go through the identical window spec.
            return jax.lax.reduce_window(operand,
                                         init_value=self.init_value,
                                         computation=self.computation,
                                         window_dimensions=window,
                                         window_strides=strides,
                                         padding=pads)

        summed = window_reduce(x)
        if pads == "VALID":
            # No padding: every window holds exactly prod(window) elements,
            # so skip the second reduce_window.
            return summed / np.prod(window)

        # With padding, count the real input entries per window by reducing an
        # all-ones array of the same shape with the same window spec.
        counts = window_reduce(jnp.ones_like(x))
        assert summed.shape == counts.shape
        return summed / counts
|
360
|
-
|
361
|
-
|
362
|
-
class MaxPool1d(_MaxPool):
    r"""Applies a 1D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
    and output :math:`(N, L_{out}, C)` can be precisely described as:

    .. math::
        out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                input(N_i, stride \times k + m, C_j)

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity
    on the padded positions. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + p_{low} + p_{high}
                    - \text{kernel\_size}}{\text{stride}} \right\rfloor + 1

          with :math:`(p_{low}, p_{high})` the padding of the pooled axis (both zero
          for ``'VALID'``). For ``'SAME'`` padding,
          :math:`L_{out} = \lceil L_{in} / \text{stride} \rceil`.

    Parameters
    ----------
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int, sequence of int or sequence of tuple, optional
        The string ``'SAME'`` or ``'VALID'``; an int padded on both sides of every
        pooled axis; one int per pooled axis; or a sequence of ``(low, high)``
        integer pairs giving the padding before and after each pooled axis.
        Default: 'VALID'
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    return_indices : bool, optional
        If True, will return the max indices along with the outputs. The indices are
        positions in the flattened input array. Useful for MaxUnpool1d. Default: False
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # pool of size=3, stride=2
        >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 16)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        return_indices: bool = False,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=1,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         return_indices=return_indices,
                         name=name)
|
443
|
-
|
444
|
-
|
445
|
-
class MaxPool2d(_MaxPool):
    r"""Applies a 2D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
    output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{input}(N_i, \text{stride[0]} \times h + m,
                                                   \text{stride[1]} \times w + n, C_j)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity
    on the padded positions. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + p^{low}_0 + p^{high}_0
                    - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + p^{low}_1 + p^{high}_1
                    - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1

          with :math:`(p^{low}_i, p^{high}_i)` the padding of the :math:`i`-th pooled axis
          (all zero for ``'VALID'``). For ``'SAME'`` padding,
          :math:`H_{out} = \lceil H_{in} / \text{stride[0]} \rceil`, and likewise for :math:`W_{out}`.

    Parameters
    ----------
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int, sequence of int or sequence of tuple, optional
        The string ``'SAME'`` or ``'VALID'``; an int padded on both sides of every
        pooled axis; one int per pooled axis; or a sequence of ``(low, high)``
        integer pairs giving the padding before and after each pooled axis.
        Default: 'VALID'
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    return_indices : bool, optional
        If True, will return the max indices along with the outputs. The indices are
        positions in the flattened input array. Useful for MaxUnpool2d. Default: False
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = MaxPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        return_indices: bool = False,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         return_indices=return_indices,
                         name=name)
|
536
|
-
|
537
|
-
|
538
|
-
class MaxPool3d(_MaxPool):
|
539
|
-
r"""Applies a 3D max pooling over an input signal composed of several input planes.
|
540
|
-
|
541
|
-
In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
|
542
|
-
output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
|
543
|
-
can be precisely described as:
|
544
|
-
|
545
|
-
.. math::
|
546
|
-
\begin{aligned}
|
547
|
-
\text{out}(N_i, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
|
548
|
-
& \text{input}(N_i, \text{stride[0]} \times d + k,
|
549
|
-
\text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
|
550
|
-
\end{aligned}
|
551
|
-
|
552
|
-
If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
|
553
|
-
for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
|
554
|
-
It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
|
555
|
-
|
556
|
-
Shape:
|
557
|
-
- Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
|
558
|
-
- Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
|
559
|
-
|
560
|
-
.. math::
|
561
|
-
D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
|
562
|
-
(\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
|
563
|
-
|
564
|
-
.. math::
|
565
|
-
H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
|
566
|
-
(\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
|
567
|
-
|
568
|
-
.. math::
|
569
|
-
W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
|
570
|
-
(\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
|
571
|
-
|
572
|
-
Parameters
|
573
|
-
----------
|
574
|
-
kernel_size : int or sequence of int
|
575
|
-
An integer, or a sequence of integers defining the window to reduce over.
|
576
|
-
stride : int or sequence of int, optional
|
577
|
-
An integer, or a sequence of integers, representing the inter-window stride.
|
578
|
-
Default: kernel_size
|
579
|
-
padding : str, int or sequence of tuple, optional
|
580
|
-
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
581
|
-
of n `(low, high)` integer pairs that give the padding to apply before
|
582
|
-
and after each spatial dimension. Default: 'VALID'
|
583
|
-
channel_axis : int, optional
|
584
|
-
Axis of the spatial channels for which pooling is skipped.
|
585
|
-
If ``None``, there is no channel axis. Default: -1
|
586
|
-
return_indices : bool, optional
|
587
|
-
If True, will return the max indices along with the outputs.
|
588
|
-
Useful for MaxUnpool3d. Default: False
|
589
|
-
name : str, optional
|
590
|
-
The object name.
|
591
|
-
in_size : Sequence of int, optional
|
592
|
-
The shape of the input tensor.
|
593
|
-
|
594
|
-
Examples
|
595
|
-
--------
|
596
|
-
.. code-block:: python
|
597
|
-
|
598
|
-
>>> import brainstate
|
599
|
-
>>> # pool of square window of size=3, stride=2
|
600
|
-
>>> m = MaxPool3d(3, stride=2)
|
601
|
-
>>> # pool of non-square window
|
602
|
-
>>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
|
603
|
-
>>> input = brainstate.random.randn(20, 50, 44, 31, 16)
|
604
|
-
>>> output = m(input)
|
605
|
-
>>> output.shape
|
606
|
-
(20, 24, 43, 15, 16)
|
607
|
-
|
608
|
-
.. _link:
|
609
|
-
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
610
|
-
"""
|
611
|
-
__module__ = 'brainstate.nn'
|
612
|
-
|
613
|
-
def __init__(
    self,
    kernel_size: Size,
    stride: Union[int, Sequence[int]] = None,
    padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
    channel_axis: Optional[int] = -1,
    return_indices: bool = False,
    name: Optional[str] = None,
    in_size: Optional[Size] = None,
):
    """Configure a 3-D max-pooling layer.

    Every user-facing argument is forwarded unchanged to the pooling base
    class; this constructor only fixes the reduction (element-wise maximum,
    seeded with negative infinity) and the number of pooled dimensions (3).
    """
    # Reduction settings that make this a 3-D *max* pool: ``-inf`` is the
    # identity element of ``max``.
    reduction = dict(
        init_value=-jax.numpy.inf,
        computation=jax.lax.max,
        pool_dim=3,
    )
    super().__init__(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        channel_axis=channel_axis,
        return_indices=return_indices,
        name=name,
        in_size=in_size,
        **reduction,
    )
|
633
|
-
|
634
|
-
|
635
|
-
class _MaxUnpool(Module):
    """Base class for max unpooling operations.

    Subclasses fix ``pool_dim`` (the number of spatial dimensions). Unpooling
    builds a zero-filled array of the target shape and writes each pooled
    value at the position recorded in ``indices``; all other positions stay
    zero.

    Parameters
    ----------
    pool_dim : int
        Number of spatial dimensions.
    kernel_size : int or sequence of int
        Size of the pooling window. An int is broadcast to all spatial dims.
    stride : int or sequence of int, optional
        Stride of the pooling window. Defaults to ``kernel_size``.
    padding : int or sequence of int, optional
        Per-dimension padding that was applied by the pooling layer. An int
        is broadcast to all spatial dims. Default: 0.
    channel_axis : int, optional
        Axis holding the channels; it is excluded from the spatial axes.
        ``None`` means the input has no channel axis. Default: -1.
    name : str, optional
        Name of the module.
    in_size : sequence of int, optional
        Input shape; stored as a tuple on ``self.in_size`` when given.
    """

    def __init__(
        self,
        pool_dim: int,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(name=name)

        self.pool_dim = pool_dim

        # kernel_size: an int is broadcast to every spatial dimension.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * pool_dim
        elif isinstance(kernel_size, Sequence):
            assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
            assert all(isinstance(x, int) for x in kernel_size), f'kernel_size should be a tuple of ints. {kernel_size}'
            if len(kernel_size) != pool_dim:
                raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
        else:
            raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
        self.kernel_size = kernel_size

        # stride: defaults to the (already normalized) kernel size.
        if stride is None:
            stride = kernel_size
        if isinstance(stride, int):
            stride = (stride,) * pool_dim
        elif isinstance(stride, Sequence):
            assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
            assert all(isinstance(x, int) for x in stride), f'stride should be a tuple of ints. {stride}'
            if len(stride) != pool_dim:
                raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
        else:
            raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
        self.stride = stride

        # padding: one symmetric value per spatial dimension.
        if isinstance(padding, int):
            padding = (padding,) * pool_dim
        elif isinstance(padding, (tuple, list)):
            if len(padding) != pool_dim:
                raise ValueError(f'padding should have {pool_dim} values, but got {len(padding)}')
        else:
            raise TypeError(f'padding should be int or tuple of {pool_dim} ints.')
        self.padding = padding

        # channel_axis
        assert channel_axis is None or isinstance(channel_axis, int), \
            f'channel_axis should be an int, but got {channel_axis}'
        self.channel_axis = channel_axis

        # in & out shapes
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size

    def _compute_output_shape(self, input_shape, output_size=None):
        """Compute the spatial output shape after unpooling.

        Returns ``output_size`` unchanged when given; otherwise applies
        ``(in - 1) * stride - 2 * padding + kernel_size`` per spatial dim.
        """
        if output_size is not None:
            return output_size
        return tuple(
            (input_shape[i] - 1) * self.stride[i] - 2 * self.padding[i] + self.kernel_size[i]
            for i in range(self.pool_dim)
        )

    def _get_spatial_axes(self, ndim):
        """Return the indices of the ``pool_dim`` spatial axes of an ``ndim``-D input.

        The channel axis (if any) is removed first; the spatial axes are the
        last ``pool_dim`` of the remaining axes. This is the single source of
        truth used for both reading and writing spatial sizes, so it stays
        correct even when the channel axis lies between spatial axes.
        """
        axes = list(range(ndim))
        if self.channel_axis is not None:
            channel_axis = self.channel_axis if self.channel_axis >= 0 else ndim + self.channel_axis
            axes.pop(channel_axis)
        return axes[-self.pool_dim:]

    def _get_spatial_dims(self, shape):
        """Extract the spatial sizes from an input shape."""
        return tuple(shape[i] for i in self._get_spatial_axes(len(shape)))

    def _get_spatial_start_idx(self, ndim):
        """Get the index of the first spatial axis of an ``ndim``-D input."""
        return self._get_spatial_axes(ndim)[0]

    def _resolve_output_shape(self, in_shape, output_size):
        """Return the full output shape for an input of shape ``in_shape``.

        ``output_size`` may be ``None`` (infer from kernel/stride/padding), an
        int (same size for every spatial dim), a sequence of ``pool_dim``
        spatial sizes, or a sequence of ``len(in_shape)`` giving the full shape.
        """
        ndim = len(in_shape)
        if output_size is None:
            spatial_sizes = self._compute_output_shape(self._get_spatial_dims(in_shape))
        elif isinstance(output_size, (list, tuple)):
            if len(output_size) == ndim:
                # Full output shape provided.
                return tuple(output_size)
            if len(output_size) != self.pool_dim:
                raise ValueError(f"output_size must have {self.pool_dim} spatial dimensions, got {len(output_size)}")
            spatial_sizes = output_size
        else:
            # Single integer provided, use it for every spatial dim.
            spatial_sizes = (output_size,) * self.pool_dim

        output_shape = list(in_shape)
        for axis, size in zip(self._get_spatial_axes(ndim), spatial_sizes):
            output_shape[axis] = size
        return tuple(output_shape)

    def _unpool_nd(self, x, indices, output_size=None):
        """Perform N-dimensional max unpooling.

        Creates a zero array of the resolved output shape and writes every
        element of ``x`` at the matching flat position taken from ``indices``.

        NOTE(review): ``indices`` are used as flat indices into the whole
        flattened output array (batch and channel axes included); confirm this
        matches the ``return_indices`` convention of the pooling layers.

        Raises
        ------
        ValueError
            If ``x`` has too few dimensions, if ``indices`` and ``x`` differ in
            element count, or if ``output_size`` has an unsupported length.
        """
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
        if indices.size != x.size:
            # Without this check a mismatch fails deep inside the scatter or,
            # for a single-element ``x``, may silently broadcast.
            raise ValueError(
                f'indices must have the same number of elements as the input, '
                f'but got {indices.size} and {x.size}.'
            )

        output_shape = self._resolve_output_shape(x.shape, output_size)

        # Scatter the pooled values into a flat zero buffer, then restore the shape.
        flat_output = jnp.zeros(output_shape, dtype=x.dtype).ravel()
        flat_output = flat_output.at[indices.ravel()].set(x.ravel())
        return flat_output.reshape(output_shape)

    def update(self, x, indices, output_size=None):
        """Unpool ``x`` using the indices recorded by the matching max-pooling layer.

        Parameters
        ----------
        x : Array
            Pooled values (output of the matching ``MaxPool1d/2d/3d``).
        indices : Array
            Indices of the maximal values, as returned by the pooling layer
            with ``return_indices=True``. Must have as many elements as ``x``.
        output_size : int or tuple, optional
            Target size: an int for every spatial dim, a tuple of ``pool_dim``
            spatial sizes, or a tuple of ``x.ndim`` giving the full shape.
            Inferred from ``kernel_size``, ``stride`` and ``padding`` if omitted.

        Returns
        -------
        Array
            Unpooled output, zero everywhere except at ``indices``.
        """
        return self._unpool_nd(x, indices, output_size)
|
840
|
-
|
841
|
-
|
842
|
-
class MaxUnpool1d(_MaxUnpool):
    r"""Computes a partial inverse of MaxPool1d.

    MaxPool1d is not fully invertible, since the non-maximal values are lost.
    MaxUnpool1d takes in as input the output of MaxPool1d including the indices
    of the maximal values and computes a partial inverse in which all
    non-maximal values are set to zero.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where

          .. math::
              L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel\_size}

          or as given by :attr:`output_size` in the call operator

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the max pooling window.
    stride : int or tuple, optional
        Stride of the max pooling window. Default: kernel_size
    padding : int or tuple, optional
        Padding that was added to the input. Default: 0
    channel_axis : int, optional
        Axis of the channels. Default: -1
    name : str, optional
        Name of the module.
    in_size : Size, optional
        Input size for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Create pooling and unpooling layers
        >>> pool = MaxPool1d(2, stride=2, return_indices=True, channel_axis=-1)
        >>> unpool = MaxUnpool1d(2, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output, indices = pool(input)  # output has shape (20, 25, 16)
        >>> unpooled = unpool(output, indices)
        >>> # unpooled has shape (20, 50, 16), since (25 - 1) * 2 + 2 = 50,
        >>> # with zeros at non-maximal positions
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
912
|
-
|
913
|
-
|
914
|
-
class MaxUnpool2d(_MaxUnpool):
    r"""Computes a partial inverse of MaxPool2d.

    MaxPool2d is not fully invertible, since the non-maximal values are lost.
    MaxUnpool2d takes in as input the output of MaxPool2d including the indices
    of the maximal values and computes a partial inverse in which all
    non-maximal values are set to zero.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]

          .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]

          or as given by :attr:`output_size` in the call operator

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the max pooling window.
    stride : int or tuple, optional
        Stride of the max pooling window. Default: kernel_size
    padding : int or tuple, optional
        Padding that was added to the input. Default: 0
    channel_axis : int, optional
        Axis of the channels. Default: -1
    name : str, optional
        Name of the module.
    in_size : Size, optional
        Input size for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Create pooling and unpooling layers
        >>> pool = MaxPool2d(2, stride=2, return_indices=True, channel_axis=-1)
        >>> unpool = MaxUnpool2d(2, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(1, 4, 4, 16)
        >>> output, indices = pool(input)  # output has shape (1, 2, 2, 16)
        >>> unpooled = unpool(output, indices)
        >>> # unpooled has shape (1, 4, 4, 16), since (2 - 1) * 2 + 2 = 4,
        >>> # with zeros at non-maximal positions
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
982
|
-
|
983
|
-
|
984
|
-
class MaxUnpool3d(_MaxUnpool):
    r"""Computes a partial inverse of MaxPool3d.

    MaxPool3d is not fully invertible, since the non-maximal values are lost.
    MaxUnpool3d takes in as input the output of MaxPool3d including the indices
    of the maximal values and computes a partial inverse in which all
    non-maximal values are set to zero.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where

          .. math::
              D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]

          .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{kernel\_size}[2]

          or as given by :attr:`output_size` in the call operator

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the max pooling window.
    stride : int or tuple, optional
        Stride of the max pooling window. Default: kernel_size
    padding : int or tuple, optional
        Padding that was added to the input. Default: 0
    channel_axis : int, optional
        Axis of the channels. Default: -1
    name : str, optional
        Name of the module.
    in_size : Size, optional
        Input size for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Create pooling and unpooling layers
        >>> pool = MaxPool3d(2, stride=2, return_indices=True, channel_axis=-1)
        >>> unpool = MaxUnpool3d(2, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(1, 4, 4, 4, 16)
        >>> output, indices = pool(input)  # output has shape (1, 2, 2, 2, 16)
        >>> unpooled = unpool(output, indices)
        >>> # unpooled has shape (1, 4, 4, 4, 16), since (2 - 1) * 2 + 2 = 4,
        >>> # with zeros at non-maximal positions
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
1055
|
-
|
1056
|
-
|
1057
|
-
class AvgPool1d(_AvgPool):
    r"""1-D average pooling over an input made of several channels.

    Given an input of shape :math:`(N, L, C)`, an output of shape
    :math:`(N, L_{out}, C)` and a window length :math:`k`, each output element
    is the mean of the :math:`k` inputs its window covers:

    .. math::

        \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
                               \text{input}(N_i, \text{stride} \times l + m, C_j)

    A non-zero :attr:`padding` implicitly pads both ends of the input with
    :attr:`padding` zeros.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} +
              2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Parameters
    ----------
    kernel_size : int or sequence of int
        Length of the pooling window, as an integer or a sequence of integers.
    stride : int or sequence of int, optional
        Step between neighbouring windows, as an integer or a sequence of
        integers. Default: 1
    padding : str, int or sequence of tuple, optional
        ``'SAME'``, ``'VALID'``, or a sequence of n ``(low, high)`` integer
        pairs giving the padding placed before and after each spatial
        dimension. Default: 'VALID'
    channel_axis : int, optional
        Channel axis, which is left un-pooled. ``None`` means the input has
        no channel axis. Default: -1
    name : str, optional
        Name of the module.
    in_size : Sequence of int, optional
        Shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # window of length 3, moved by 2 each step
        >>> m = AvgPool1d(3, stride=2)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> m(input).shape
        (20, 24, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # The window reduction is a plain sum seeded with 0 over a single
        # spatial dimension (the averaging step is assumed to live in _AvgPool).
        super().__init__(
            pool_dim=1,
            computation=jax.lax.add,
            init_value=0.,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|
1132
|
-
|
1133
|
-
|
1134
|
-
class AvgPool2d(_AvgPool):
    r"""2-D average pooling over an input made of several channels.

    Given an input of shape :math:`(N, H, W, C)`, an output of shape
    :math:`(N, H_{out}, W_{out}, C)` and a window of size :math:`(kH, kW)`,
    each output element is the mean over its window:

    .. math::

        out(N_i, h, w, C_j) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
                               input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)

    A non-zero :attr:`padding` implicitly pads both sides of each spatial
    dimension with :attr:`padding` zeros.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
                \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
                \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Parameters
    ----------
    kernel_size : int or sequence of int
        Size of the pooling window, as an integer or a sequence of integers.
    stride : int or sequence of int, optional
        Step between neighbouring windows, as an integer or a sequence of
        integers. Default: 1
    padding : str, int or sequence of tuple, optional
        ``'SAME'``, ``'VALID'``, or a sequence of n ``(low, high)`` integer
        pairs giving the padding placed before and after each spatial
        dimension. Default: 'VALID'
    channel_axis : int, optional
        Channel axis, which is left un-pooled. ``None`` means the input has
        no channel axis. Default: -1
    name : str, optional
        Name of the module.
    in_size : Sequence of int, optional
        Shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # square 3x3 window, moved by 2 each step
        >>> m = AvgPool2d(3, stride=2)
        >>> # rectangular window with per-dimension strides
        >>> m = AvgPool2d((3, 2), stride=(2, 1))
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Window settings supplied by the caller, passed through verbatim.
        window_options = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding,
            'channel_axis': channel_axis,
        }
        # Sum reduction seeded with 0 across two spatial dimensions.
        super().__init__(
            in_size=in_size,
            name=name,
            init_value=0.,
            computation=jax.lax.add,
            pool_dim=2,
            **window_options,
        )
|
1216
|
-
|
1217
|
-
|
1218
|
-
class AvgPool3d(_AvgPool):
    r"""3-D average pooling over an input made of several channels.

    Given an input of shape :math:`(N, D, H, W, C)`, an output of shape
    :math:`(N, D_{out}, H_{out}, W_{out}, C)` and a window of size
    :math:`(kD, kH, kW)`, each output element is the mean over its window:

    .. math::
        \begin{aligned}
            \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
                                              & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
                                                      \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
                                                     {kD \times kH \times kW}
        \end{aligned}

    A non-zero :attr:`padding` implicitly pads both sides of all three
    spatial dimensions with :attr:`padding` zeros.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
          :math:`(D_{out}, H_{out}, W_{out}, C)`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
                    \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
                    \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
                    \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

    Parameters
    ----------
    kernel_size : int or sequence of int
        Size of the pooling window, as an integer or a sequence of integers.
    stride : int or sequence of int, optional
        Step between neighbouring windows, as an integer or a sequence of
        integers. Default: 1
    padding : str, int or sequence of tuple, optional
        ``'SAME'``, ``'VALID'``, or a sequence of n ``(low, high)`` integer
        pairs giving the padding placed before and after each spatial
        dimension. Default: 'VALID'
    channel_axis : int, optional
        Channel axis, which is left un-pooled. ``None`` means the input has
        no channel axis. Default: -1
    name : str, optional
        Name of the module.
    in_size : Sequence of int, optional
        Shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # cubic 3x3x3 window, moved by 2 each step
        >>> m = AvgPool3d(3, stride=2)
        >>> # non-cubic window with per-dimension strides
        >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 43, 15, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Sum reduction seeded with 0 across three spatial dimensions.
        sum_reduction = {'init_value': 0., 'computation': jax.lax.add, 'pool_dim': 3}
        super().__init__(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size,
            **sum_reduction,
        )
|
1310
|
-
|
1311
|
-
|
1312
|
-
class _LPPool(Module):
|
1313
|
-
"""Base class for Lp pooling operations."""
|
1314
|
-
|
1315
|
-
def __init__(
    self,
    norm_type: float,
    pool_dim: int,
    kernel_size: Size,
    stride: Union[int, Sequence[int]] = None,
    padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
    channel_axis: Optional[int] = -1,
    name: Optional[str] = None,
    in_size: Optional[Size] = None,
):
    """Validate and store the configuration of an Lp pooling layer.

    Parameters
    ----------
    norm_type : float
        The exponent ``p``; must be positive.
    pool_dim : int
        Number of spatial dimensions.
    kernel_size : int or sequence of int
        Pooling window; an int is broadcast to every spatial dim.
    stride : int or sequence of int, optional
        Window stride; defaults to ``kernel_size``.
    padding : str, int or sequence, optional
        ``'SAME'``/``'VALID'``, an int (symmetric, every dim), a sequence of
        ints (symmetric, per dim), or a sequence of ``(low, high)`` pairs.
        Ints are normalized to a list of ``(low, high)`` pairs.
    channel_axis : int, optional
        Channel axis excluded from pooling; ``None`` for no channel axis.
    name : str, optional
        Name of the module.
    in_size : sequence of int, optional
        Input shape without the batch axis. When given, ``out_size`` is
        inferred by tracing ``update`` on a batch of 128.

    Raises
    ------
    ValueError
        If ``norm_type`` is not positive or ``padding`` / sizes are invalid.
    TypeError
        If ``kernel_size`` or ``stride`` has an unsupported type.
    """
    super().__init__(name=name)

    if norm_type <= 0:
        raise ValueError(f"norm_type must be positive, got {norm_type}")
    self.norm_type = norm_type
    self.pool_dim = pool_dim

    # kernel_size: an int is broadcast to every spatial dimension.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size,) * pool_dim
    elif isinstance(kernel_size, Sequence):
        assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
        assert all(isinstance(x, int) for x in kernel_size), f'kernel_size should be a tuple of ints. {kernel_size}'
        if len(kernel_size) != pool_dim:
            raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
    else:
        raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
    self.kernel_size = kernel_size

    # stride: defaults to the (already normalized) kernel size.
    if stride is None:
        stride = kernel_size
    if isinstance(stride, int):
        stride = (stride,) * pool_dim
    elif isinstance(stride, Sequence):
        assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
        assert all(isinstance(x, int) for x in stride), f'stride should be a tuple of ints. {stride}'
        if len(stride) != pool_dim:
            raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
    else:
        raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
    self.stride = stride

    # padding: keep strings as-is, otherwise normalize to (low, high) pairs.
    if isinstance(padding, str):
        if padding not in ("SAME", "VALID"):
            raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
    elif isinstance(padding, int):
        padding = [(padding, padding) for _ in range(pool_dim)]
    elif isinstance(padding, (list, tuple)):
        if len(padding) == 0:
            # Guard against IndexError on ``padding[0]`` below.
            raise ValueError('padding should not be an empty sequence.')
        if isinstance(padding[0], int):
            if len(padding) == pool_dim:
                padding = [(x, x) for x in padding]
            else:
                raise ValueError(f'If padding is a sequence of ints, it '
                                 f'should has the length of {pool_dim}.')
        else:
            if not all(isinstance(x, (tuple, list)) for x in padding):
                raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
            if not all(len(x) == 2 for x in padding):
                raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
            if len(padding) == 1:
                # A single (low, high) pair is reused for every spatial dim.
                padding = tuple(padding) * pool_dim
            assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
    else:
        raise ValueError(
            f"padding should be a str, an int, or a sequence, but got {type(padding)}."
        )
    self.padding = padding

    # channel_axis
    assert channel_axis is None or isinstance(channel_axis, int), \
        f'channel_axis should be an int, but got {channel_axis}'
    self.channel_axis = channel_axis

    # in & out shapes: only inferable when the input shape is known.
    if in_size is not None:
        in_size = tuple(in_size)
        self.in_size = in_size
        # Trace ``update`` abstractly with a dummy batch axis of 128 and drop it.
        y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
        self.out_size = y.shape[1:]
|
1396
|
-
|
1397
|
-
def update(self, x):
    """Apply Lp (power-mean) pooling to ``x``.

    For every pooling window ``W`` the output is::

        (sum(|v| ** p for v in W) / N) ** (1 / p)

    where ``p = self.norm_type`` and ``N = prod(self.kernel_size)``.
    Padded positions (``'SAME'`` or explicit padding) are filled with zero by
    ``jax.lax.reduce_window``; they add nothing to the sum but are still
    counted in ``N``.

    Parameters
    ----------
    x : Array
        Input with at least ``pool_dim`` dimensions, or ``pool_dim + 1`` when
        ``channel_axis`` is not None. The pooled axes are chosen by
        ``_infer_shape``.

    Returns
    -------
    Array
        The pooled array.

    Raises
    ------
    ValueError
        If ``x`` has fewer dimensions than required.
    """
    x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
    if x.ndim < x_dim:
        raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')

    # Expand per-spatial-dim settings to full-rank window/stride/padding specs.
    window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
    stride = self._infer_shape(x.ndim, self.stride, 1)
    padding = (self.padding if isinstance(self.padding, str) else
               self._infer_shape(x.ndim, self.padding, element=(0, 0)))

    # Step 1: |x|^p
    x_pow = jnp.abs(x) ** self.norm_type

    # Step 2: sum |x|^p over each window.
    pooled_sum = jax.lax.reduce_window(
        x_pow,
        init_value=0.,
        computation=jax.lax.add,
        window_dimensions=window_shape,
        window_strides=stride,
        padding=padding
    )

    # Step 3: p-th root times the normalization factor (1/N)^(1/p), so the
    # result is the power mean over the N kernel positions.
    # Assumes self.kernel_size is a sequence of ints (normalized in __init__).
    window_size = np.prod(self.kernel_size)
    norm_factor = window_size ** (-1.0 / self.norm_type)
    return norm_factor * (pooled_sum ** (1.0 / self.norm_type))
|
1432
|
-
|
1433
|
-
def _infer_shape(self, x_dim, inputs, element):
|
1434
|
-
channel_axis = self.channel_axis
|
1435
|
-
if channel_axis and not 0 <= abs(channel_axis) < x_dim:
|
1436
|
-
raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
|
1437
|
-
if channel_axis and channel_axis < 0:
|
1438
|
-
channel_axis = x_dim + channel_axis
|
1439
|
-
all_dims = list(range(x_dim))
|
1440
|
-
if channel_axis is not None:
|
1441
|
-
all_dims.pop(channel_axis)
|
1442
|
-
pool_dims = all_dims[-self.pool_dim:]
|
1443
|
-
results = [element] * x_dim
|
1444
|
-
for i, dim in enumerate(pool_dims):
|
1445
|
-
results[dim] = inputs[i]
|
1446
|
-
return results
|
1447
|
-
|
1448
|
-
|
1449
|
-
class LPPool1d(_LPPool):
    r"""Applies a 1D power-average pooling over an input signal composed of several input planes.

    On each window :math:`X` the power mean of the absolute values is computed:

    .. math::

        f(X) = \sqrt[p]{\frac{1}{|X|} \sum_{x \in X} |x|^{p}}

    where :math:`p` is ``norm_type`` and :math:`|X|` is the number of kernel
    positions. Padded positions contribute zero to the sum but are still
    counted in :math:`|X|`.

    - As :math:`p \to \infty` the result approaches max pooling of :math:`|x|`
      (a mathematical limit; passing ``norm_type=inf`` does not compute it).
    - At :math:`p = 1`, one gets average pooling of the absolute values.
    - At :math:`p = 2`, one gets root mean square (RMS) pooling.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where, for
          symmetric integer padding,

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Parameters
    ----------
    norm_type : float
        The exponent :math:`p` of the power mean. Required (no default).
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, an int, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis holding the channels; pooling is not applied along it.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch dimension).

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # power-average pooling of window of size=3, stride=2 with norm_type=2.0
        >>> m = brainstate.nn.LPPool1d(2, 3, stride=2)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        norm_type: float,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            norm_type=norm_type,
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
1523
|
-
|
1524
|
-
|
1525
|
-
class LPPool2d(_LPPool):
    r"""Applies a 2D power-average pooling over an input signal composed of several input planes.

    On each window :math:`X` the power mean of the absolute values is computed:

    .. math::

        f(X) = \sqrt[p]{\frac{1}{|X|} \sum_{x \in X} |x|^{p}}

    where :math:`p` is ``norm_type`` and :math:`|X|` is the number of kernel
    positions. Padded positions contribute zero to the sum but are still
    counted in :math:`|X|`.

    - As :math:`p \to \infty` the result approaches max pooling of :math:`|x|`
      (a mathematical limit; passing ``norm_type=inf`` does not compute it).
    - At :math:`p = 1`, one gets average pooling of the absolute values.
    - At :math:`p = 2`, one gets root mean square (RMS) pooling.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where,
          for symmetric integer padding,

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Parameters
    ----------
    norm_type : float
        The exponent :math:`p` of the power mean. Required (no default).
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, an int, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis holding the channels; pooling is not applied along it.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch dimension).

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # power-average pooling of square window of size=3, stride=2
        >>> m = brainstate.nn.LPPool2d(2, 3, stride=2)
        >>> # pool of non-square window with norm_type=1.5
        >>> m = brainstate.nn.LPPool2d(1.5, (3, 2), stride=(2, 1), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        norm_type: float,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            norm_type=norm_type,
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
1604
|
-
|
1605
|
-
|
1606
|
-
class LPPool3d(_LPPool):
    r"""Applies a 3D power-average pooling over an input signal composed of several input planes.

    On each window :math:`X` the power mean of the absolute values is computed:

    .. math::

        f(X) = \sqrt[p]{\frac{1}{|X|} \sum_{x \in X} |x|^{p}}

    where :math:`p` is ``norm_type`` and :math:`|X|` is the number of kernel
    positions. Padded positions contribute zero to the sum but are still
    counted in :math:`|X|`.

    - As :math:`p \to \infty` the result approaches max pooling of :math:`|x|`
      (a mathematical limit; passing ``norm_type=inf`` does not compute it).
    - At :math:`p = 1`, one gets average pooling of the absolute values.
    - At :math:`p = 2`, one gets root mean square (RMS) pooling.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
          where, for symmetric integer padding,

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

    Parameters
    ----------
    norm_type : float
        The exponent :math:`p` of the power mean. Required (no default).
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, an int, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis holding the channels; pooling is not applied along it.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch dimension).

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # power-average pooling of cube window of size=3, stride=2
        >>> m = brainstate.nn.LPPool3d(2, 3, stride=2)
        >>> # pool of non-cubic window with norm_type=1.5
        >>> m = brainstate.nn.LPPool3d(1.5, (3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 43, 15, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        norm_type: float,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            norm_type=norm_type,
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
1688
|
-
|
1689
|
-
|
1690
|
-
def _adaptive_pool1d(x, target_size: int, operation: Callable):
|
1691
|
-
"""Adaptive pool 1D.
|
1692
|
-
|
1693
|
-
Args:
|
1694
|
-
x: The input. Should be a JAX array of shape `(dim,)`.
|
1695
|
-
target_size: The shape of the output after the pooling operation `(target_size,)`.
|
1696
|
-
operation: The pooling operation to be performed on the input array.
|
1697
|
-
|
1698
|
-
Returns:
|
1699
|
-
A JAX array of shape `(target_size, )`.
|
1700
|
-
"""
|
1701
|
-
size = jnp.size(x)
|
1702
|
-
num_head_arrays = size % target_size
|
1703
|
-
num_block = size // target_size
|
1704
|
-
if num_head_arrays != 0:
|
1705
|
-
head_end_index = num_head_arrays * (num_block + 1)
|
1706
|
-
heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
|
1707
|
-
tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
|
1708
|
-
outs = jnp.concatenate([heads, tails])
|
1709
|
-
else:
|
1710
|
-
outs = jax.vmap(operation)(x.reshape(-1, num_block))
|
1711
|
-
return outs
|
1712
|
-
|
1713
|
-
|
1714
|
-
def _generate_vmap(fun: Callable, map_axes: List[int]):
|
1715
|
-
map_axes = sorted(map_axes)
|
1716
|
-
for axis in map_axes:
|
1717
|
-
fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
|
1718
|
-
return fun
|
1719
|
-
|
1720
|
-
|
1721
|
-
class _AdaptivePool(Module):
    """General N dimensional adaptive down-sampling to a target shape.

    Each pooled axis of length ``L`` with target size ``T`` is split into
    ``T`` contiguous, non-overlapping segments: the first ``L % T`` segments
    hold ``L // T + 1`` elements and the rest ``L // T``. ``operation`` is
    applied to every segment (see ``_adaptive_pool1d``). The pooled axes are
    the last ``num_spatial_dims`` axes of the input once the channel axis (if
    any) is excluded.

    Parameters
    ----------
    in_size: Sequence of int
      The shape of the input tensor, without the batch dimension. When given,
      ``out_size`` is inferred by tracing ``update`` on a dummy batch.
    target_size: int, sequence of int
      The target output shape. An int is used for all ``num_spatial_dims``
      axes. A ``None`` entry in a sequence leaves that axis unpooled.
    num_spatial_dims: int
      The number of spatial dimensions.
    channel_axis: int, optional
      Axis holding the channels; pooling is not applied along it.
      If ``None``, there is no channel axis.
    operation: Callable
      The down-sampling operation applied to each 1D segment
      (e.g. ``jnp.mean`` or ``jnp.max``).
    name: str
      The class name.
    """

    def __init__(
        self,
        in_size: Size,
        target_size: Size,
        num_spatial_dims: int,
        operation: Callable,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        self.channel_axis = channel_axis
        self.operation = operation
        if isinstance(target_size, int):
            self.target_shape = (target_size,) * num_spatial_dims
        elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
            self.target_shape = target_size
        else:
            raise ValueError("`target_size` must either be an int or tuple of length "
                             f"{num_spatial_dims} containing ints.")

        # in & out shapes
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            # Trace `update` on a dummy batch of 128 to infer the per-sample output shape.
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            self.out_size = y.shape[1:]

    def update(self, x):
        """Input-output mapping.

        Parameters
        ----------
        x: Array
          Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
          or `(..., dim_1, dim_2)`.

        Returns
        -------
        Array
          ``x`` with each pooled axis reduced to its target size. Axes whose
          target is ``None`` are left untouched.

        Raises
        ------
        ValueError
          If the channel axis is out of range or ``x`` has too few dimensions.
        """
        # channel axis (normalized to a non-negative index when present)
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x.ndim <= channel_axis < x.ndim:
                raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
            channel_axis %= x.ndim

        # input dimension
        num_pooled = len(self.target_shape)
        if (x.ndim - (0 if channel_axis is None else 1)) < num_pooled:
            raise ValueError(f"Invalid input dimension. Expected >={num_pooled} "
                             f"dimensions (channel_axis={self.channel_axis}). "
                             f"But got {x.ndim} dimensions.")

        # pooling dimensions: all axes except the channel axis; pool the last ones.
        pool_dims = [d for d in range(x.ndim) if d != channel_axis]

        # pooling, one axis at a time
        for target, dim in zip(self.target_shape, pool_dims[-num_pooled:]):
            if target is None:
                # `None` means this axis is not pooled.
                continue
            map_axes = [j for j in range(x.ndim) if j != dim]
            op = _generate_vmap(_adaptive_pool1d, map_axes)
            x = op(x, target, self.operation)
        return x
|
1802
|
-
|
1803
|
-
|
1804
|
-
class AdaptiveAvgPool1d(_AdaptivePool):
    r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.

    The output size is :math:`L_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    The pooled axis is split into :math:`L_{out}` contiguous, non-overlapping
    segments whose lengths differ by at most one, and each segment is averaged.
    This makes the layer useful for creating fixed-size representations from
    variable-sized inputs.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
          :math:`L_{out}=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or sequence of int
        The target output size. The number of output features for each channel.
        A sequence must have length 1.
    channel_axis : int, optional
        Axis holding the channels; pooling is not applied along it.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch dimension) for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5
        >>> m = brainstate.nn.AdaptiveAvgPool1d(5)
        >>> input = brainstate.random.randn(1, 64, 8)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 8)
        >>> # Can handle variable input sizes; the output size stays the same
        >>> input2 = brainstate.random.randn(1, 32, 8)
        >>> output2 = m(input2)
        >>> output2.shape
        (1, 5, 8)

    See Also
    --------
    AvgPool1d : Non-adaptive 1D average pooling.
    AdaptiveMaxPool1d : Adaptive 1D max pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[int]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=1,
            operation=jnp.mean,
            name=name
        )
|
1869
|
-
|
1870
|
-
|
1871
|
-
class AdaptiveAvgPool2d(_AdaptivePool):
    r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.

    The output is of size :math:`H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Each pooled axis is split into contiguous, non-overlapping segments whose
    lengths differ by at most one, and each segment is averaged. This makes the
    layer useful for creating fixed-size representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
          :math:`(H_{out}, W_{out})=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or tuple of (int or None)
        The target output size. If a single integer is provided, the output will be a square
        of that size. If a tuple is provided, it specifies (H_out, W_out).
        Use None for dimensions that should not be pooled.
    channel_axis : int, optional
        Axis holding the channels; pooling is not applied along it.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch dimension) for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5x7
        >>> m = brainstate.nn.AdaptiveAvgPool2d((5, 7))
        >>> input = brainstate.random.randn(1, 8, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 64)
        >>> # Target output size of 7x7 (square)
        >>> m = brainstate.nn.AdaptiveAvgPool2d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 64)
        >>> # Target output size of 10x7
        >>> m = brainstate.nn.AdaptiveAvgPool2d((None, 7))
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 10, 7, 64)

    See Also
    --------
    AvgPool2d : Non-adaptive 2D average pooling.
    AdaptiveMaxPool2d : Adaptive 2D max pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[Optional[int]]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=2,
            operation=jnp.mean,
            name=name
        )
|
1945
|
-
|
1946
|
-
|
1947
|
-
class AdaptiveAvgPool3d(_AdaptivePool):
    r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.

    The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Each pooled axis is split into contiguous, non-overlapping segments whose
    lengths differ by at most one, and each segment is averaged. This makes the
    layer useful for creating fixed-size representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
          where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or tuple of (int or None)
        The target output size. If a single integer is provided, the output will be a cube
        of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
        Use None for dimensions that should not be pooled.
    channel_axis : int, optional
        Axis holding the channels; pooling is not applied along it.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch dimension) for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5x7x9
        >>> m = brainstate.nn.AdaptiveAvgPool3d((5, 7, 9))
        >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 9, 64)
        >>> # Target output size of 7x7x7 (cube)
        >>> m = brainstate.nn.AdaptiveAvgPool3d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 7, 64)
        >>> # Target output size of 7x9x8
        >>> m = brainstate.nn.AdaptiveAvgPool3d((7, None, None))
        >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 9, 8, 64)

    See Also
    --------
    AvgPool3d : Non-adaptive 3D average pooling.
    AdaptiveMaxPool3d : Adaptive 3D max pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[Optional[int]]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=3,
            operation=jnp.mean,
            name=name
        )
|
2021
|
-
|
2022
|
-
|
2023
|
-
class AdaptiveMaxPool1d(_AdaptivePool):
    r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.

    The output size is :math:`L_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    The pooled axis is split into :math:`L_{out}` contiguous, non-overlapping
    segments whose lengths differ by at most one, and the maximum of each
    segment is taken. This makes the layer useful for creating fixed-size
    representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
          :math:`L_{out}=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or sequence of int
        The target output size. The number of output features for each channel.
        A sequence must have length 1.
    channel_axis : int, optional
        Axis holding the channels; pooling is not applied along it.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch dimension) for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5
        >>> m = brainstate.nn.AdaptiveMaxPool1d(5)
        >>> input = brainstate.random.randn(1, 64, 8)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 8)
        >>> # Can handle variable input sizes; the output size stays the same
        >>> input2 = brainstate.random.randn(1, 32, 8)
        >>> output2 = m(input2)
        >>> output2.shape
        (1, 5, 8)

    See Also
    --------
    MaxPool1d : Non-adaptive 1D max pooling.
    AdaptiveAvgPool1d : Adaptive 1D average pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[int]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=1,
            operation=jnp.max,
            name=name
        )
|
2088
|
-
|
2089
|
-
|
2090
|
-
class AdaptiveMaxPool2d(_AdaptivePool):
    r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.

    The output is of size :math:`H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Each pooled axis is split into contiguous, non-overlapping segments whose
    lengths differ by at most one, and the maximum of each segment is taken.
    This makes the layer useful for creating fixed-size representations from
    variable-sized inputs.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
          :math:`(H_{out}, W_{out})=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or tuple of (int or None)
        The target output size. If a single integer is provided, the output will be a square
        of that size. If a tuple is provided, it specifies (H_out, W_out).
        Use None for dimensions that should not be pooled.
    channel_axis : int, optional
        Axis holding the channels; pooling is not applied along it.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch dimension) for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5x7
        >>> m = brainstate.nn.AdaptiveMaxPool2d((5, 7))
        >>> input = brainstate.random.randn(1, 8, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 64)
        >>> # Target output size of 7x7 (square)
        >>> m = brainstate.nn.AdaptiveMaxPool2d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 64)
        >>> # Target output size of 10x7
        >>> m = brainstate.nn.AdaptiveMaxPool2d((None, 7))
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 10, 7, 64)

    See Also
    --------
    MaxPool2d : Non-adaptive 2D max pooling.
    AdaptiveAvgPool2d : Adaptive 2D average pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[Optional[int]]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=2,
            operation=jnp.max,
            name=name
        )
|
2164
|
-
|
2165
|
-
|
2166
|
-
class AdaptiveMaxPool3d(_AdaptivePool):
    r"""3D adaptive max pooling over an input composed of several input planes.

    Whatever the spatial size of the input, the output has spatial size
    :math:`D_{out} \times H_{out} \times W_{out}`. The number of output
    features equals the number of input planes.

    Kernel size and stride are derived automatically from the input size and
    the requested target size. This makes the layer handy for producing
    fixed-size representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
          where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or tuple of int
        Desired output size. A single integer yields a cube of that size.
        A tuple gives ``(D_out, H_out, W_out)``. Use ``None`` for any
        dimension that should not be pooled.
    channel_axis : int, optional
        Axis holding the channels; pooling is not applied along it.
        ``None`` means the input has no channel axis. Default: -1
    name : str, optional
        Name of the module.
    in_size : Sequence of int, optional
        Input shape, used for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5x7x9
        >>> m = AdaptiveMaxPool3d((5, 7, 9))
        >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 9, 64)
        >>> # Target output size of 7x7x7 (cube)
        >>> m = AdaptiveMaxPool3d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 7, 64)
        >>> # Target output size of 7x9x8
        >>> m = AdaptiveMaxPool3d((7, None, None))
        >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 9, 8, 64)

    See Also
    --------
    MaxPool3d : Non-adaptive 3D max pooling.
    AdaptiveAvgPool3d : Adaptive 3D average pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[int]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        # Three pooled spatial axes, reduced with a maximum.
        super().__init__(
            name=name,
            operation=jnp.max,
            num_spatial_dims=3,
            channel_axis=channel_axis,
            target_size=target_size,
            in_size=in_size,
        )
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
import functools
|
19
|
+
from typing import Sequence, Optional
|
20
|
+
from typing import Union, Tuple, Callable, List
|
21
|
+
|
22
|
+
import brainunit as u
|
23
|
+
import jax
|
24
|
+
import jax.numpy as jnp
|
25
|
+
import numpy as np
|
26
|
+
|
27
|
+
from brainstate import environ
|
28
|
+
from brainstate.typing import Size
|
29
|
+
from ._module import Module
|
30
|
+
|
31
|
+
# Public names exported by this module: flatten/unflatten layers followed by
# the pooling layers, grouped by family (average, max, max-unpool, LP,
# adaptive average, adaptive max).
__all__ = [
    'Flatten', 'Unflatten',

    'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
    'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
    'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
    'LPPool1d', 'LPPool2d', 'LPPool3d',

    'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
    'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
]
|
42
|
+
|
43
|
+
|
44
|
+
class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    When ``in_size`` is given, non-negative ``start_axis`` / ``end_axis`` are
    interpreted relative to ``in_size``. Any extra leading axes of the actual
    input (for example a batch axis) are left untouched. Negative axes always
    count from the end of the input.

    Parameters
    ----------
    start_axis : int, optional
        First dim to flatten (default = 0).
    end_axis : int, optional
        Last dim to flatten, inclusive (default = -1).
    in_size : Sequence of int, optional
        The shape of the input tensor (without extra leading axes). If given,
        ``out_size`` is inferred and inputs are checked against it.

    Raises
    ------
    ValueError
        If ``in_size`` was given and the trailing axes of the input do not
        match it.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> inp = brainstate.random.randn(32, 1, 5, 5)
        >>> # With default parameters every axis is flattened
        >>> m = Flatten()
        >>> output = m(inp)
        >>> output.shape
        (800,)
        >>> # Keep the leading axis
        >>> m = Flatten(start_axis=1)
        >>> m(inp).shape
        (32, 25)
        >>> # Axes relative to ``in_size``; the extra leading axis is preserved
        >>> m = Flatten(in_size=(1, 5, 5))
        >>> m(inp).shape
        (32, 25)
        >>> # With non-default parameters
        >>> m = Flatten(0, 2)
        >>> output = m(inp)
        >>> output.shape
        (160, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        start_axis: int = 0,
        end_axis: int = -1,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__()
        self.start_axis = start_axis
        self.end_axis = end_axis

        if in_size is not None:
            self.in_size = tuple(in_size)
            y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
                               jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = y.shape

    def update(self, x):
        """Flatten ``x`` from ``start_axis`` through ``end_axis`` (inclusive)."""
        start_axis = self.start_axis
        end_axis = self.end_axis
        if self._in_size is not None:
            assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
            dim_diff = x.ndim - len(self.in_size)
            if self.in_size != x.shape[dim_diff:]:
                raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
            # Non-negative axes refer to ``in_size``. Shift BOTH ends past the
            # extra leading axes so the flattened range stays the same.
            if start_axis >= 0:
                start_axis += dim_diff
            if end_axis >= 0:
                end_axis += dim_diff
        if start_axis < 0:
            start_axis += x.ndim
        return u.math.flatten(x, start_axis, end_axis)
|
111
|
+
|
112
|
+
|
113
|
+
class Unflatten(Module):
    r"""
    Unflatten one axis of a tensor, expanding it into a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`axis` specifies the axis of the input tensor to be unflattened.
    * :attr:`sizes` is the new shape of the unflattened axis. It must be a
      ``tuple`` or ``list`` of ints.

    When ``in_size`` is given, a non-negative ``axis`` is interpreted relative
    to ``in_size``. Any extra leading axes of the actual input (for example a
    batch axis) are left untouched. A negative ``axis`` always counts from the
    end of the input.

    Shape:
        - Input: :math:`(*, S_{\text{axis}}, *)`, where :math:`S_{\text{axis}}` is the size at
          axis :attr:`axis` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
          :math:`\prod_{i=1}^n U_i = S_{\text{axis}}`.

    Parameters
    ----------
    axis : int
        Axis to be unflattened.
    sizes : Sequence of int
        New shape of the unflattened axis.
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor (without extra leading axes). If given,
        ``out_size`` is inferred from it.

    Raises
    ------
    TypeError
        If ``sizes`` is not a tuple/list, or contains a non-int element.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        axis: int,
        sizes: Size,
        name: Optional[str] = None,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__(name=name)

        self.axis = axis
        self.sizes = sizes
        if isinstance(sizes, (tuple, list)):
            for idx, elem in enumerate(sizes):
                if not isinstance(elem, int):
                    raise TypeError("unflattened sizes must be tuple of ints, " +
                                    "but found element of type {} at pos {}".format(type(elem).__name__, idx))
        else:
            raise TypeError("unflattened sizes must be tuple or list, but found type {}".format(type(sizes).__name__))

        if in_size is not None:
            self.in_size = tuple(in_size)
            y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
                               jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = y.shape

    def update(self, x):
        """Split axis ``axis`` of ``x`` into ``sizes``."""
        axis = self.axis
        if self._in_size is not None and axis >= 0:
            # ``axis`` refers to ``in_size``; skip extra leading axes
            # (e.g. batch) so the same axis is unflattened as in ``out_size``.
            dim_diff = x.ndim - len(self.in_size)
            if dim_diff > 0:
                axis += dim_diff
        return u.math.unflatten(x, axis, self.sizes)
|
170
|
+
|
171
|
+
|
172
|
+
class _MaxPool(Module):
    """Base class for N-D window pooling layers built on ``jax.lax.reduce_window``.

    Subclasses choose the reduction through ``init_value`` and ``computation``
    (the max-pooling subclasses pass ``-inf`` and ``jax.lax.max``). They choose
    the number of pooled axes through ``pool_dim``. The pooled axes are the
    last ``pool_dim`` axes of the input once the channel axis is removed.

    Parameters
    ----------
    init_value : float
        Initial value of the window reduction.
    computation : Callable
        Binary reduction passed to ``jax.lax.reduce_window``.
    pool_dim : int
        Number of pooled spatial axes.
    kernel_size : int or sequence of int
        Window size; an int is broadcast to all ``pool_dim`` axes.
    stride : int or sequence of int, optional
        Window stride; defaults to ``kernel_size``.
    padding : str, int, sequence of int, or sequence of (int, int)
        One of the following:

        - ``'SAME'`` or ``'VALID'``;
        - a single int, a symmetric pad applied to every pooled axis;
        - ``pool_dim`` ints, one symmetric pad per axis;
        - explicit ``(low, high)`` pairs. A single pair is broadcast to every axis.
    channel_axis : int, optional
        Axis excluded from pooling; ``None`` means no channel axis.
    return_indices : bool
        If True, ``update`` returns ``(pooled, indices)``. ``indices`` holds,
        for each window, the flat index into the input of the selected element.
    name : str, optional
        Module name.
    in_size : sequence of int, optional
        Input shape without the leading batch axis. Used to infer ``out_size``.

    Raises
    ------
    ValueError, TypeError, AssertionError
        On malformed ``kernel_size``, ``stride``, ``padding`` or ``channel_axis``.
    """

    def __init__(
        self,
        init_value: float,
        computation: Callable,
        pool_dim: int,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        return_indices: bool = False,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(name=name)

        self.init_value = init_value
        self.computation = computation
        self.pool_dim = pool_dim
        self.return_indices = return_indices

        # kernel_size: broadcast an int, otherwise validate a tuple/list of `pool_dim` ints
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * pool_dim
        elif isinstance(kernel_size, Sequence):
            assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
            assert all(
                [isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
            if len(kernel_size) != pool_dim:
                raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
        else:
            raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
        self.kernel_size = kernel_size

        # stride: defaults to the kernel size (non-overlapping windows)
        if stride is None:
            stride = kernel_size
        if isinstance(stride, int):
            stride = (stride,) * pool_dim
        elif isinstance(stride, Sequence):
            assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
            assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
            if len(stride) != pool_dim:
                # report the length of the offending argument, i.e. `stride`
                raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
        else:
            raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
        self.stride = stride

        # padding: keep 'SAME'/'VALID' as strings, normalize the rest to (low, high) pairs
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
        elif isinstance(padding, int):
            padding = [(padding, padding) for _ in range(pool_dim)]
        elif isinstance(padding, (list, tuple)):
            if len(padding) == 0:
                raise ValueError(f'padding should not be an empty sequence. {padding}')
            if isinstance(padding[0], int):
                # one symmetric pad per pooled axis
                if len(padding) == pool_dim:
                    padding = [(x, x) for x in padding]
                else:
                    raise ValueError(f'If padding is a sequence of ints, it '
                                     f'should has the length of {pool_dim}.')
            else:
                # explicit (low, high) pairs; a single pair is broadcast to every axis
                if not all([isinstance(x, (tuple, list)) for x in padding]):
                    raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
                if not all([len(x) == 2 for x in padding]):
                    raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
                if len(padding) == 1:
                    padding = tuple(padding) * pool_dim
                assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
        else:
            raise ValueError(f'padding should be a str, an int, or a sequence, but got {type(padding)}.')
        self.padding = padding

        # channel_axis
        assert channel_axis is None or isinstance(channel_axis, int), \
            f'channel_axis should be an int, but got {channel_axis}'
        self.channel_axis = channel_axis

        # in & out shapes (a dummy batch axis of 128 is prepended, then dropped)
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            if isinstance(y, (tuple, list)):
                # update() returned (pooled, indices); the output shape is that of `pooled`
                y = y[0]
            self.out_size = y.shape[1:]

    def update(self, x):
        """Pool ``x``; returns ``(pooled, indices)`` when ``return_indices`` is set."""
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
        window_shape = tuple(self._infer_shape(x.ndim, self.kernel_size, 1))
        stride = tuple(self._infer_shape(x.ndim, self.stride, 1))
        if isinstance(self.padding, str):
            padding = tuple(jax.lax.padtype_to_pads(x.shape, window_shape, stride, self.padding))
        else:
            padding = tuple(self._infer_shape(x.ndim, self.padding, element=(0, 0)))

        if self.return_indices:
            # For returning indices, we need to use a custom implementation
            return self._pooling_with_indices(x, window_shape, stride, padding)

        return jax.lax.reduce_window(
            x,
            init_value=self.init_value,
            computation=self.computation,
            window_dimensions=window_shape,
            window_strides=stride,
            padding=padding
        )

    def _pooling_with_indices(self, x, window_shape, stride, padding):
        """Perform max pooling and return both pooled values and indices.

        Indices are flat positions into ``x`` (row-major). Ties are resolved in
        favor of the smaller index. Windows that contain no real element get
        index 0.
        """
        total_size = x.size
        flat_indices = jnp.arange(total_size, dtype=jnp.int32).reshape(x.shape)

        init_val = jnp.asarray(self.init_value, dtype=x.dtype)
        # `total_size` is an out-of-range sentinel meaning "no element selected yet"
        init_idx = jnp.array(total_size, dtype=flat_indices.dtype)

        def reducer(acc, operand):
            acc_val, acc_idx = acc
            cur_val, cur_idx = operand

            better = cur_val > acc_val
            best_val = jnp.where(better, cur_val, acc_val)
            best_idx = jnp.where(better, cur_idx, acc_idx)
            tie = jnp.logical_and(cur_val == acc_val, cur_idx < acc_idx)
            best_idx = jnp.where(tie, cur_idx, best_idx)

            return best_val, best_idx

        pooled, indices_result = jax.lax.reduce_window(
            (x, flat_indices),
            (init_val, init_idx),
            reducer,
            window_dimensions=window_shape,
            window_strides=stride,
            padding=padding
        )

        indices_result = jnp.where(indices_result == total_size, 0, indices_result)

        return pooled, indices_result.astype(jnp.int32)

    def _infer_shape(self, x_dim, inputs, element):
        """Expand per-pooled-axis ``inputs`` to a list of length ``x_dim``.

        Entries of ``inputs`` are placed at the pooled axes, which are the last
        ``pool_dim`` axes once the channel axis is removed. Every other axis
        gets ``element``.
        """
        channel_axis = self.channel_axis
        if channel_axis is not None:
            # valid axes are -x_dim .. x_dim-1 (inclusive)
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis
        all_dims = list(range(x_dim))
        if channel_axis is not None:
            all_dims.pop(channel_axis)
        pool_dims = all_dims[-self.pool_dim:]
        results = [element] * x_dim
        for i, dim in enumerate(pool_dims):
            results[dim] = inputs[i]
        return results
|
328
|
+
|
329
|
+
|
330
|
+
class _AvgPool(_MaxPool):
    """Averaging window pooling that reuses ``_MaxPool``'s argument handling.

    The windowed reduction (``init_value`` / ``computation`` supplied by the
    subclass, presumably a sum) is normalized as follows:

    - for ``'VALID'`` padding, it is divided by the window volume;
    - otherwise it is divided by the same reduction applied to an all-ones
      array, i.e. the number of real (non-padded) elements in each window.
    """

    def update(self, x):
        n_required = self.pool_dim + (1 if self.channel_axis is not None else 0)
        if x.ndim < n_required:
            raise ValueError(f'Excepted input with >= {n_required} dimensions, but got {x.ndim}.')

        window = self._infer_shape(x.ndim, self.kernel_size, 1)
        strides = self._infer_shape(x.ndim, self.stride, 1)
        if isinstance(self.padding, str):
            pads = self.padding
        else:
            pads = self._infer_shape(x.ndim, self.padding, element=(0, 0))

        def _reduce(operand):
            # identical window configuration for the data and for the counts
            return jax.lax.reduce_window(operand,
                                         init_value=self.init_value,
                                         computation=self.computation,
                                         window_dimensions=window,
                                         window_strides=strides,
                                         padding=pads)

        pooled = _reduce(x)
        if pads == "VALID":
            # every window is fully inside the input: a constant divisor suffices
            return pooled / np.prod(window)

        # padded windows: divide by the count of genuine elements per window
        counts = _reduce(jnp.ones_like(x))
        assert pooled.shape == counts.shape
        return pooled / counts
|
360
|
+
|
361
|
+
|
362
|
+
class MaxPool1d(_MaxPool):
    r"""Applies a 1D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
    and output :math:`(N, L_{out}, C)` can be precisely described as:

    .. math::
        out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                           input(N_i, stride \times k + m, C_j)

    If :attr:`padding` is non-zero, the input is implicitly padded with negative infinity on
    both sides for :attr:`padding` number of points. This `link`_ has a nice visualization
    of the pooling parameters.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`. For an integer
          ``padding`` (``'VALID'`` behaves like ``padding=0``),

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding}
                        - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

          With ``'SAME'`` padding, :math:`L_{out} = \lceil L_{in} / \text{stride} \rceil`.

    Parameters
    ----------
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, an integer, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis holding the channels; it is not pooled.
        If ``None``, there is no channel axis. Default: -1
    return_indices : bool, optional
        If True, will return the max indices along with the outputs. The indices are
        flat positions into the input array. Useful for MaxUnpool1d. Default: False
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch axis).

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # pool of size=3, stride=2
        >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 16)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        return_indices: bool = False,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=1,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         return_indices=return_indices,
                         name=name)
|
443
|
+
|
444
|
+
|
445
|
+
class MaxPool2d(_MaxPool):
    r"""Applies a 2D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
    output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{input}(N_i, \text{stride[0]} \times h + m,
                                      \text{stride[1]} \times w + n, C_j)
        \end{aligned}

    If :attr:`padding` is non-zero, the input is implicitly padded with negative infinity on
    both sides for :attr:`padding` number of points. This `link`_ has a nice visualization
    of the pooling parameters.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`. For
          integer padding per axis (``'VALID'`` behaves like zero padding),

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding[0]}
                        - \text{kernel\_size[0]}}{\text{stride[0]}} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding[1]}
                        - \text{kernel\_size[1]}}{\text{stride[1]}} + 1\right\rfloor

          With ``'SAME'`` padding, each output size is the input size divided by
          the stride, rounded up.

    Parameters
    ----------
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, an integer, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis holding the channels; it is not pooled.
        If ``None``, there is no channel axis. Default: -1
    return_indices : bool, optional
        If True, will return the max indices along with the outputs. The indices are
        flat positions into the input array. Useful for MaxUnpool2d. Default: False
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch axis).

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = MaxPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        return_indices: bool = False,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         return_indices=return_indices,
                         name=name)
|
536
|
+
|
537
|
+
|
538
|
+
class MaxPool3d(_MaxPool):
    r"""Applies a 3D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
    output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            \text{out}(N_i, d, h, w, C_j) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1}
                                                \max_{n=0, \ldots, kW-1} \\
                                              & \text{input}(N_i, \text{stride[0]} \times d + k,
                                                \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
        \end{aligned}

    If :attr:`padding` is non-zero, the input is implicitly padded with negative infinity on
    both sides for :attr:`padding` number of points. This `link`_ has a nice visualization
    of the pooling parameters.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`.
          For integer padding per axis (``'VALID'`` behaves like zero padding),

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0]
                        - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1]
                        - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2]
                        - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

          With ``'SAME'`` padding, each output size is the input size divided by
          the stride, rounded up.

    Parameters
    ----------
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, an integer, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis holding the channels; it is not pooled.
        If ``None``, there is no channel axis. Default: -1
    return_indices : bool, optional
        If True, will return the max indices along with the outputs. The indices are
        flat positions into the input array. Useful for MaxUnpool3d. Default: False
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor (without the batch axis).

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = MaxPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 43, 15, 16)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        return_indices: bool = False,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=3,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         return_indices=return_indices,
                         name=name)
|
633
|
+
|
634
|
+
|
635
|
+
class _MaxUnpool(Module):
|
636
|
+
"""Base class for max unpooling operations."""
|
637
|
+
|
638
|
+
def __init__(
|
639
|
+
self,
|
640
|
+
pool_dim: int,
|
641
|
+
kernel_size: Size,
|
642
|
+
stride: Union[int, Sequence[int]] = None,
|
643
|
+
padding: Union[int, Tuple[int, ...]] = 0,
|
644
|
+
channel_axis: Optional[int] = -1,
|
645
|
+
name: Optional[str] = None,
|
646
|
+
in_size: Optional[Size] = None,
|
647
|
+
):
|
648
|
+
super().__init__(name=name)
|
649
|
+
|
650
|
+
self.pool_dim = pool_dim
|
651
|
+
|
652
|
+
# kernel_size
|
653
|
+
if isinstance(kernel_size, int):
|
654
|
+
kernel_size = (kernel_size,) * pool_dim
|
655
|
+
elif isinstance(kernel_size, Sequence):
|
656
|
+
assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
|
657
|
+
assert all(
|
658
|
+
[isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
|
659
|
+
if len(kernel_size) != pool_dim:
|
660
|
+
raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
|
661
|
+
else:
|
662
|
+
raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
|
663
|
+
self.kernel_size = kernel_size
|
664
|
+
|
665
|
+
# stride
|
666
|
+
if stride is None:
|
667
|
+
stride = kernel_size
|
668
|
+
if isinstance(stride, int):
|
669
|
+
stride = (stride,) * pool_dim
|
670
|
+
elif isinstance(stride, Sequence):
|
671
|
+
assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
|
672
|
+
assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
|
673
|
+
if len(stride) != pool_dim:
|
674
|
+
raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
|
675
|
+
else:
|
676
|
+
raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
|
677
|
+
self.stride = stride
|
678
|
+
|
679
|
+
# padding
|
680
|
+
if isinstance(padding, int):
|
681
|
+
padding = (padding,) * pool_dim
|
682
|
+
elif isinstance(padding, (tuple, list)):
|
683
|
+
if len(padding) != pool_dim:
|
684
|
+
raise ValueError(f'padding should have {pool_dim} values, but got {len(padding)}')
|
685
|
+
else:
|
686
|
+
raise TypeError(f'padding should be int or tuple of {pool_dim} ints.')
|
687
|
+
self.padding = padding
|
688
|
+
|
689
|
+
# channel_axis
|
690
|
+
assert channel_axis is None or isinstance(channel_axis, int), \
|
691
|
+
f'channel_axis should be an int, but got {channel_axis}'
|
692
|
+
self.channel_axis = channel_axis
|
693
|
+
|
694
|
+
# in & out shapes
|
695
|
+
if in_size is not None:
|
696
|
+
in_size = tuple(in_size)
|
697
|
+
self.in_size = in_size
|
698
|
+
|
699
|
+
def _compute_output_shape(self, input_shape, output_size=None):
    """Return the spatial output shape produced by unpooling.

    When ``output_size`` is given it is returned unchanged. Otherwise every
    spatial dimension ``d`` (``0 <= d < pool_dim``) is sized as
    ``(input_shape[d] - 1) * stride[d] - 2 * padding[d] + kernel_size[d]``.
    """
    if output_size is not None:
        return output_size
    return tuple(
        (input_shape[d] - 1) * self.stride[d] - 2 * self.padding[d] + self.kernel_size[d]
        for d in range(self.pool_dim)
    )
|
711
|
+
|
712
|
+
def _unpool_nd(self, x, indices, output_size=None):
    """Scatter pooled values back into a zero-filled array.

    Parameters
    ----------
    x : Array
        Pooled values. Must have at least ``pool_dim`` dimensions, plus one
        more when ``channel_axis`` is not ``None``.
    indices : Array
        Target positions for the entries of ``x``. Both ``x`` and ``indices``
        are flattened, and each value of ``x`` is written at the matching
        position of the *flattened* output. The indices therefore address the
        whole output array. NOTE(review): confirm this matches the index
        layout produced by the corresponding MaxPool layer.
    output_size : int or sequence of int, optional
        ``None`` infers the spatial sizes via ``_compute_output_shape``. An
        int sets every spatial dimension to that size. A sequence of length
        ``x.ndim`` is taken as the full output shape. A sequence of length
        ``pool_dim`` sets only the spatial dimensions.

    Returns
    -------
    Array
        Array with dtype ``x.dtype`` that is zero everywhere except at the
        positions given by ``indices``.

    Raises
    ------
    ValueError
        If ``x`` has too few dimensions, ``channel_axis`` is out of range
        for ``x``, or ``output_size`` has an unsupported length.
    """
    ndim = x.ndim
    x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
    if ndim < x_dim:
        raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {ndim}.')

    # Spatial axes are the last ``pool_dim`` axes once the channel axis is
    # removed (the same rule ``_get_spatial_dims`` uses). Resolving the axes
    # explicitly, instead of assuming one contiguous run, keeps the channel
    # dimension untouched even when it sits between spatial axes.
    axes = list(range(ndim))
    if self.channel_axis is not None:
        channel_axis = self.channel_axis if self.channel_axis >= 0 else ndim + self.channel_axis
        if not 0 <= channel_axis < ndim:
            raise ValueError(
                f'Invalid channel axis {self.channel_axis} for input with {ndim} dimensions.'
            )
        axes.pop(channel_axis)
    spatial_axes = axes[-self.pool_dim:]

    # Determine the output shape.
    output_shape = list(x.shape)
    if output_size is None:
        spatial_sizes = self._compute_output_shape(tuple(x.shape[a] for a in spatial_axes))
    elif isinstance(output_size, (list, tuple)):
        if len(output_size) == ndim:
            # Full output shape provided; nothing left to patch.
            output_shape = list(output_size)
            spatial_sizes = ()
        elif len(output_size) == self.pool_dim:
            spatial_sizes = output_size
        else:
            raise ValueError(f"output_size must have {self.pool_dim} spatial dimensions, got {len(output_size)}")
    else:
        # A single integer applies to every spatial dimension.
        spatial_sizes = (output_size,) * self.pool_dim
    for axis, size in zip(spatial_axes, spatial_sizes):
        output_shape[axis] = size
    output_shape = tuple(output_shape)

    # Scatter on the flattened arrays, then restore the target shape.
    flat_output = jnp.zeros(output_shape, dtype=x.dtype).ravel()
    flat_output = flat_output.at[indices.ravel()].set(x.ravel())
    return flat_output.reshape(output_shape)
|
800
|
+
|
801
|
+
def _get_spatial_dims(self, shape):
    """Return the sizes of the ``pool_dim`` spatial dimensions of ``shape``.

    Without a channel axis these are simply the trailing ``pool_dim`` sizes.
    Otherwise the channel axis (negative values count from the end) is
    dropped first, and the last ``pool_dim`` of the remaining axes are taken.
    """
    if self.channel_axis is None:
        return shape[-self.pool_dim:]
    ndim = len(shape)
    axis = self.channel_axis + ndim if self.channel_axis < 0 else self.channel_axis
    kept_axes = list(range(ndim))
    kept_axes.pop(axis)
    return tuple(shape[a] for a in kept_axes[-self.pool_dim:])
|
810
|
+
|
811
|
+
def _get_spatial_start_idx(self, ndim):
    """Return the index of the first spatial axis for an input of rank ``ndim``.

    The spatial axes are treated as one contiguous run that ends at the last
    axis. If the channel axis falls inside that trailing run, the run is
    shifted one position to the left.
    """
    start = ndim - self.pool_dim
    if self.channel_axis is None:
        return start
    axis = self.channel_axis if self.channel_axis >= 0 else self.channel_axis + ndim
    return start if axis < start else start - 1
|
821
|
+
|
822
|
+
def update(self, x, indices, output_size=None):
    """Forward pass shared by the MaxUnpool1d/2d/3d layers.

    Parameters
    ----------
    x : Array
        Pooled values, i.e. the output of the matching MaxPool layer.
    indices : Array
        Indices of the maximal values, as returned by the matching MaxPool
        layer with ``return_indices=True``.
    output_size : int or tuple, optional
        The targeted output size. An int applies to every spatial
        dimension. A tuple gives either the spatial sizes or the full
        output shape. ``None`` infers it from ``kernel_size``, ``stride``
        and ``padding``.

    Returns
    -------
    Array
        Unpooled output: zeros everywhere except at ``indices``.
    """
    return self._unpool_nd(x, indices, output_size)
|
840
|
+
|
841
|
+
|
842
|
+
class MaxUnpool1d(_MaxUnpool):
    r"""Computes a partial inverse of MaxPool1d.

    MaxPool1d is not fully invertible, since the non-maximal values are lost.
    MaxUnpool1d takes in as input the output of MaxPool1d including the indices
    of the maximal values and computes a partial inverse in which all
    non-maximal values are set to zero.

    Note:
        Values are written with a JAX scatter (``.at[...].set``). If
        ``indices`` contains the same position more than once, which of the
        colliding values ends up in the output is unspecified.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where

          .. math::
              L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel\_size}

          or as given by :attr:`output_size` in the call operator

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the max pooling window.
    stride : int or tuple, optional
        Stride of the max pooling window. Default: kernel_size
    padding : int or tuple, optional
        Padding that was added to the input. Default: 0
    channel_axis : int, optional
        Axis of the channels. Default: -1
    name : str, optional
        Name of the module.
    in_size : Size, optional
        Input size for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Create pooling and unpooling layers
        >>> pool = brainstate.nn.MaxPool1d(2, stride=2, return_indices=True, channel_axis=-1)
        >>> unpool = brainstate.nn.MaxUnpool1d(2, stride=2, channel_axis=-1)
        >>> x = brainstate.random.randn(20, 50, 16)
        >>> output, indices = pool(x)  # output has shape (20, 25, 16)
        >>> unpooled = unpool(output, indices)
        >>> # unpooled has shape (20, 50, 16), since (25 - 1) * 2 - 2 * 0 + 2 = 50,
        >>> # with zeros at non-maximal positions
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
912
|
+
|
913
|
+
|
914
|
+
class MaxUnpool2d(_MaxUnpool):
    r"""Computes a partial inverse of MaxPool2d.

    MaxPool2d is not fully invertible, since the non-maximal values are lost.
    MaxUnpool2d takes in as input the output of MaxPool2d including the indices
    of the maximal values and computes a partial inverse in which all
    non-maximal values are set to zero.

    Note:
        Values are written with a JAX scatter (``.at[...].set``). If
        ``indices`` contains the same position more than once, which of the
        colliding values ends up in the output is unspecified.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]

          .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]

          or as given by :attr:`output_size` in the call operator

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the max pooling window.
    stride : int or tuple, optional
        Stride of the max pooling window. Default: kernel_size
    padding : int or tuple, optional
        Padding that was added to the input. Default: 0
    channel_axis : int, optional
        Axis of the channels. Default: -1
    name : str, optional
        Name of the module.
    in_size : Size, optional
        Input size for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Create pooling and unpooling layers
        >>> pool = brainstate.nn.MaxPool2d(2, stride=2, return_indices=True, channel_axis=-1)
        >>> unpool = brainstate.nn.MaxUnpool2d(2, stride=2, channel_axis=-1)
        >>> x = brainstate.random.randn(1, 4, 4, 16)
        >>> output, indices = pool(x)  # output has shape (1, 2, 2, 16)
        >>> unpooled = unpool(output, indices)
        >>> # unpooled has shape (1, 4, 4, 16), since (2 - 1) * 2 - 2 * 0 + 2 = 4,
        >>> # with zeros at non-maximal positions
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
982
|
+
|
983
|
+
|
984
|
+
class MaxUnpool3d(_MaxUnpool):
    r"""Computes a partial inverse of MaxPool3d.

    MaxPool3d is not fully invertible, since the non-maximal values are lost.
    MaxUnpool3d takes in as input the output of MaxPool3d including the indices
    of the maximal values and computes a partial inverse in which all
    non-maximal values are set to zero.

    Note:
        Values are written with a JAX scatter (``.at[...].set``). If
        ``indices`` contains the same position more than once, which of the
        colliding values ends up in the output is unspecified.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where

          .. math::
              D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]

          .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{kernel\_size}[2]

          or as given by :attr:`output_size` in the call operator

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the max pooling window.
    stride : int or tuple, optional
        Stride of the max pooling window. Default: kernel_size
    padding : int or tuple, optional
        Padding that was added to the input. Default: 0
    channel_axis : int, optional
        Axis of the channels. Default: -1
    name : str, optional
        Name of the module.
    in_size : Size, optional
        Input size for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Create pooling and unpooling layers
        >>> pool = brainstate.nn.MaxPool3d(2, stride=2, return_indices=True, channel_axis=-1)
        >>> unpool = brainstate.nn.MaxUnpool3d(2, stride=2, channel_axis=-1)
        >>> x = brainstate.random.randn(1, 4, 4, 4, 16)
        >>> output, indices = pool(x)  # output has shape (1, 2, 2, 2, 16)
        >>> unpooled = unpool(output, indices)
        >>> # unpooled has shape (1, 4, 4, 4, 16), since (2 - 1) * 2 - 2 * 0 + 2 = 4,
        >>> # with zeros at non-maximal positions
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
1055
|
+
|
1056
|
+
|
1057
|
+
class AvgPool1d(_AvgPool):
    r"""Applies a 1D average pooling over an input signal composed of several input planes.

    For an input of shape :math:`(N, L, C)`, an output of shape
    :math:`(N, L_{out}, C)` and a window of length :math:`k`
    (:attr:`kernel_size`), every output element is the mean of one window:

    .. math::

        \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
                               \text{input}(N_i, \text{stride} \times l + m, C_j)

    A non-zero :attr:`padding` implicitly zero-pads the input on both sides by
    :attr:`padding` points.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} +
              2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Parameters
    ----------
    kernel_size : int or sequence of int
        Window to reduce over, as one integer or a sequence of integers.
    stride : int or sequence of int, optional
        Step between consecutive windows, as one integer or a sequence of
        integers. Default: 1
    padding : str, int or sequence of tuple, optional
        ``'SAME'``, ``'VALID'``, or a sequence of n ``(low, high)`` integer
        pairs giving the padding applied before and after each spatial
        dimension. Default: 'VALID'
    channel_axis : int, optional
        Channel axis, which is never pooled over. ``None`` means the input has
        no channel axis. Default: -1
    name : str, optional
        Name of the object.
    in_size : Sequence of int, optional
        Shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # window of length 3, stride 2
        >>> m = brainstate.nn.AvgPool1d(3, stride=2)
        >>> x = brainstate.random.randn(20, 50, 16)
        >>> m(x).shape
        (20, 24, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Average pooling = windowed sum starting from 0 (normalized by the base class).
        super().__init__(
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            init_value=0.,
            computation=jax.lax.add,
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|
1132
|
+
|
1133
|
+
|
1134
|
+
class AvgPool2d(_AvgPool):
    r"""Applies a 2D average pooling over an input signal composed of several input planes.

    For an input of shape :math:`(N, H, W, C)`, an output of shape
    :math:`(N, H_{out}, W_{out}, C)` and a window :math:`(kH, kW)`
    (:attr:`kernel_size`), every output element is the mean of one window:

    .. math::

        out(N_i, h, w, C_j) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
                               input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)

    A non-zero :attr:`padding` implicitly zero-pads the input on both sides by
    :attr:`padding` points.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
                \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
                \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Parameters
    ----------
    kernel_size : int or sequence of int
        Window to reduce over, as one integer or a sequence of integers.
    stride : int or sequence of int, optional
        Step between consecutive windows, as one integer or a sequence of
        integers. Default: 1
    padding : str, int or sequence of tuple, optional
        ``'SAME'``, ``'VALID'``, or a sequence of n ``(low, high)`` integer
        pairs giving the padding applied before and after each spatial
        dimension. Default: 'VALID'
    channel_axis : int, optional
        Channel axis, which is never pooled over. ``None`` means the input has
        no channel axis. Default: -1
    name : str, optional
        Name of the object.
    in_size : Sequence of int, optional
        Shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # square window of size 3, stride 2
        >>> m = brainstate.nn.AvgPool2d(3, stride=2)
        >>> # non-square window
        >>> m = brainstate.nn.AvgPool2d((3, 2), stride=(2, 1))
        >>> x = brainstate.random.randn(20, 50, 32, 16)
        >>> y = m(x)
        >>> y.shape
        (20, 24, 31, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Average pooling = windowed sum starting from 0 (normalized by the base class).
        super().__init__(
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            init_value=0.,
            computation=jax.lax.add,
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|
1216
|
+
|
1217
|
+
|
1218
|
+
class AvgPool3d(_AvgPool):
    r"""Applies a 3D average pooling over an input signal composed of several input planes.

    For an input of shape :math:`(N, D, H, W, C)`, an output of shape
    :math:`(N, D_{out}, H_{out}, W_{out}, C)` and a window :math:`(kD, kH, kW)`
    (:attr:`kernel_size`), every output element is the mean of one window:

    .. math::
        \begin{aligned}
            \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
                                              & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
                                                      \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
                                                     {kD \times kH \times kW}
        \end{aligned}

    A non-zero :attr:`padding` implicitly zero-pads the input on all three
    spatial sides by :attr:`padding` points.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
          :math:`(D_{out}, H_{out}, W_{out}, C)`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
                    \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
                    \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
                    \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

    Parameters
    ----------
    kernel_size : int or sequence of int
        Window to reduce over, as one integer or a sequence of integers.
    stride : int or sequence of int, optional
        Step between consecutive windows, as one integer or a sequence of
        integers. Default: 1
    padding : str, int or sequence of tuple, optional
        ``'SAME'``, ``'VALID'``, or a sequence of n ``(low, high)`` integer
        pairs giving the padding applied before and after each spatial
        dimension. Default: 'VALID'
    channel_axis : int, optional
        Channel axis, which is never pooled over. ``None`` means the input has
        no channel axis. Default: -1
    name : str, optional
        Name of the object.
    in_size : Sequence of int, optional
        Shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # cubic window of size 3, stride 2
        >>> m = brainstate.nn.AvgPool3d(3, stride=2)
        >>> # non-cubic window
        >>> m = brainstate.nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> x = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> y = m(x)
        >>> y.shape
        (20, 24, 43, 15, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Average pooling = windowed sum starting from 0 (normalized by the base class).
        super().__init__(
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            init_value=0.,
            computation=jax.lax.add,
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|
1310
|
+
|
1311
|
+
|
1312
|
+
class _LPPool(Module):
    """Base class for Lp (power-average) pooling operations.

    Each output element is the power mean of the absolute values inside one
    pooling window::

        out = (sum(|x| ** p) / N) ** (1 / p)

    Here ``p`` is ``norm_type`` and ``N = prod(kernel_size)``. ``N`` is the
    full kernel size even for windows that overlap padding; padded
    positions contribute ``0`` to the sum.

    Parameters
    ----------
    norm_type : float
        Exponent ``p``. Must be positive.
    pool_dim : int
        Number of spatial dimensions that are pooled.
    kernel_size : int or sequence of int
        Window size. An int is broadcast to all ``pool_dim`` dimensions.
    stride : int or sequence of int, optional
        Window stride. Defaults to ``kernel_size`` (non-overlapping windows).
    padding : str, int, sequence of int, or sequence of (int, int)
        Accepts several forms:

        - ``'SAME'`` or ``'VALID'``.
        - An int, applied symmetrically to every dimension.
        - One int per dimension.
        - Explicit ``(low, high)`` pairs. A single pair is broadcast to all
          dimensions.
    channel_axis : int, optional
        Axis excluded from pooling. ``None`` means there is no channel axis.
    name : str, optional
        Module name.
    in_size : sequence of int, optional
        Input shape without the batch axis. When given, ``out_size`` is
        inferred with ``jax.eval_shape``.

    Raises
    ------
    ValueError
        If ``norm_type`` is not positive, or ``kernel_size``, ``stride`` or
        ``padding`` have the wrong length or form.
    TypeError
        If ``kernel_size`` or ``stride`` is neither an int nor a sequence.
    """

    def __init__(
        self,
        norm_type: float,
        pool_dim: int,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(name=name)

        if norm_type <= 0:
            raise ValueError(f"norm_type must be positive, got {norm_type}")
        self.norm_type = norm_type
        self.pool_dim = pool_dim

        # kernel_size: broadcast an int, validate a sequence
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * pool_dim
        elif isinstance(kernel_size, Sequence):
            assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
            assert all(
                [isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
            if len(kernel_size) != pool_dim:
                raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
        else:
            raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
        self.kernel_size = kernel_size

        # stride: defaults to kernel_size (non-overlapping windows)
        if stride is None:
            stride = kernel_size
        if isinstance(stride, int):
            stride = (stride,) * pool_dim
        elif isinstance(stride, Sequence):
            assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
            assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
            if len(stride) != pool_dim:
                raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
        else:
            raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
        self.stride = stride

        # padding: normalized to 'SAME'/'VALID' or a list of (low, high) pairs
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
        elif isinstance(padding, int):
            padding = [(padding, padding) for _ in range(pool_dim)]
        elif isinstance(padding, (list, tuple)):
            if len(padding) == 0:
                # Guard before ``padding[0]``, which would raise an opaque IndexError.
                raise ValueError('padding should not be an empty sequence.')
            if isinstance(padding[0], int):
                if len(padding) == pool_dim:
                    padding = [(x, x) for x in padding]
                else:
                    raise ValueError(f'If padding is a sequence of ints, it '
                                     f'should has the length of {pool_dim}.')
            else:
                if not all([isinstance(x, (tuple, list)) for x in padding]):
                    raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
                if not all([len(x) == 2 for x in padding]):
                    raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
                if len(padding) == 1:
                    padding = tuple(padding) * pool_dim
                assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
        else:
            raise ValueError(
                f'padding should be a str, an int, or a sequence, but got {type(padding)}.'
            )
        self.padding = padding

        # channel_axis
        assert channel_axis is None or isinstance(channel_axis, int), \
            f'channel_axis should be an int, but got {channel_axis}'
        self.channel_axis = channel_axis

        # in & out shapes
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            # Prepend a dummy batch axis of 128 for shape inference only.
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            self.out_size = y.shape[1:]

    def update(self, x):
        """Apply Lp pooling to ``x``.

        Returns the array ``(window_sum(|x| ** p) / N) ** (1 / p)``, with
        ``N = prod(kernel_size)``.

        Raises
        ------
        ValueError
            If ``x`` has fewer than ``pool_dim`` dimensions, or one more when
            a channel axis is configured.
        """
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')

        window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        padding = (self.padding if isinstance(self.padding, str) else
                   self._infer_shape(x.ndim, self.padding, element=(0, 0)))

        # Step 1: |x|^p
        x_pow = jnp.abs(x) ** self.norm_type

        # Step 2: sum over each window (padded positions contribute 0)
        pooled_sum = jax.lax.reduce_window(
            x_pow,
            init_value=0.,
            computation=jax.lax.add,
            window_dimensions=window_shape,
            window_strides=stride,
            padding=padding
        )

        # Step 3: p-th root, scaled by (1/N)^(1/p), where N is the window size.
        # Computed as a plain Python float so it is weakly typed under JAX
        # promotion rules and does not upcast the result dtype.
        window_size = float(np.prod(self.kernel_size))
        norm_factor = window_size ** (-1.0 / self.norm_type)
        return norm_factor * (pooled_sum ** (1.0 / self.norm_type))

    def _infer_shape(self, x_dim, inputs, element):
        """Expand per-spatial-dim ``inputs`` into a list of length ``x_dim``.

        Non-spatial axes (batch-like leading axes and the channel axis) are
        filled with ``element``. Spatial axes are the last ``pool_dim`` axes
        once the channel axis is removed.
        """
        channel_axis = self.channel_axis
        if channel_axis is not None:
            # Accept the full Python/NumPy axis range, including -x_dim.
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis
        all_dims = list(range(x_dim))
        if channel_axis is not None:
            all_dims.pop(channel_axis)
        pool_dims = all_dims[-self.pool_dim:]
        results = [element] * x_dim
        for i, dim in enumerate(pool_dims):
            results[dim] = inputs[i]
        return results
|
1447
|
+
|
1448
|
+
|
1449
|
+
class LPPool1d(_LPPool):
    r"""Applies a 1D power-average pooling over an input signal composed of several input planes.

    On each window, the function computed is:

    .. math::
        f(X) = \left(\frac{1}{N} \sum_{x \in X} |x|^{p}\right)^{1/p}

    where :math:`p` is ``norm_type`` and :math:`N` is the number of elements
    in the pooling window (the product of ``kernel_size``).

    - At :math:`p = \infty`, one gets max pooling (of absolute values)
    - At :math:`p = 1`, one gets average pooling (with absolute values)
    - At :math:`p = 2`, one gets root mean square (RMS) pooling

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Parameters
    ----------
    norm_type : float
        Exponent :math:`p` of the pooling operation (for example ``2.0``).
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # power-average pooling of window of size=3, stride=2 with norm_type=2.0
        >>> m = brainstate.nn.LPPool1d(2, 3, stride=2)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        norm_type: float,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # One spatial dimension is pooled; all other logic lives in _LPPool.
        super().__init__(
            norm_type=norm_type,
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size,
        )
|
1523
|
+
|
1524
|
+
|
1525
|
+
class LPPool2d(_LPPool):
    r"""Applies a 2D power-average pooling over an input signal composed of several input planes.

    On each window, the function computed is:

    .. math::
        f(X) = \left(\frac{1}{N} \sum_{x \in X} |x|^{p}\right)^{1/p}

    where :math:`p` is ``norm_type`` and :math:`N` is the number of elements
    in the pooling window (the product of ``kernel_size``).

    - At :math:`p = \infty`, one gets max pooling (of absolute values)
    - At :math:`p = 1`, one gets average pooling (with absolute values)
    - At :math:`p = 2`, one gets root mean square (RMS) pooling

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Parameters
    ----------
    norm_type : float
        Exponent :math:`p` of the pooling operation (for example ``2.0``).
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # power-average pooling of square window of size=3, stride=2
        >>> m = brainstate.nn.LPPool2d(2, 3, stride=2)
        >>> # pool of non-square window with norm_type=1.5
        >>> m = brainstate.nn.LPPool2d(1.5, (3, 2), stride=(2, 1), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        norm_type: float,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Two spatial dimensions are pooled; all other logic lives in _LPPool.
        super().__init__(
            norm_type=norm_type,
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size,
        )
|
1604
|
+
|
1605
|
+
|
1606
|
+
class LPPool3d(_LPPool):
    r"""Applies a 3D power-average pooling over an input signal composed of several input planes.

    On each window, the function computed is:

    .. math::
        f(X) = \left(\frac{1}{N} \sum_{x \in X} |x|^{p}\right)^{1/p}

    where :math:`p` is ``norm_type`` and :math:`N` is the number of elements
    in the pooling window (the product of ``kernel_size``).

    - At :math:`p = \infty`, one gets max pooling (of absolute values)
    - At :math:`p = 1`, one gets average pooling (with absolute values)
    - At :math:`p = 2`, one gets root mean square (RMS) pooling

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

    Parameters
    ----------
    norm_type : float
        Exponent :math:`p` of the pooling operation (for example ``2.0``).
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # power-average pooling of cube window of size=3, stride=2
        >>> m = brainstate.nn.LPPool3d(2, 3, stride=2)
        >>> # pool of non-cubic window with norm_type=1.5
        >>> m = brainstate.nn.LPPool3d(1.5, (3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 43, 15, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        norm_type: float,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Three spatial dimensions are pooled; all other logic lives in _LPPool.
        super().__init__(
            norm_type=norm_type,
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size,
        )
|
1688
|
+
|
1689
|
+
|
1690
|
+
def _adaptive_pool1d(x, target_size: int, operation: Callable):
|
1691
|
+
"""Adaptive pool 1D.
|
1692
|
+
|
1693
|
+
Args:
|
1694
|
+
x: The input. Should be a JAX array of shape `(dim,)`.
|
1695
|
+
target_size: The shape of the output after the pooling operation `(target_size,)`.
|
1696
|
+
operation: The pooling operation to be performed on the input array.
|
1697
|
+
|
1698
|
+
Returns:
|
1699
|
+
A JAX array of shape `(target_size, )`.
|
1700
|
+
"""
|
1701
|
+
size = jnp.size(x)
|
1702
|
+
num_head_arrays = size % target_size
|
1703
|
+
num_block = size // target_size
|
1704
|
+
if num_head_arrays != 0:
|
1705
|
+
head_end_index = num_head_arrays * (num_block + 1)
|
1706
|
+
heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
|
1707
|
+
tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
|
1708
|
+
outs = jnp.concatenate([heads, tails])
|
1709
|
+
else:
|
1710
|
+
outs = jax.vmap(operation)(x.reshape(-1, num_block))
|
1711
|
+
return outs
|
1712
|
+
|
1713
|
+
|
1714
|
+
def _generate_vmap(fun: Callable, map_axes: List[int]):
|
1715
|
+
map_axes = sorted(map_axes)
|
1716
|
+
for axis in map_axes:
|
1717
|
+
fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
|
1718
|
+
return fun
|
1719
|
+
|
1720
|
+
|
1721
|
+
class _AdaptivePool(Module):
    """General N dimensional adaptive down-sampling to a target shape.

    Parameters
    ----------
    in_size: Sequence of int
        The shape of the input tensor (without batch axis). When given, the
        output shape is inferred eagerly and stored in ``out_size``.
    target_size: int, sequence of int
        The target output shape. An int is repeated for every spatial
        dimension. A sequence must have ``num_spatial_dims`` entries; a
        ``None`` entry leaves the corresponding dimension un-pooled. Every
        non-``None`` entry must lie between 1 and the matching input size.
    num_spatial_dims: int
        The number of spatial dimensions.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    operation: Callable
        The down-sampling operation (applied to each 1-D chunk).
    name: str
        The class name.
    """

    def __init__(
        self,
        in_size: Size,
        target_size: Size,
        num_spatial_dims: int,
        operation: Callable,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        self.channel_axis = channel_axis
        self.operation = operation
        if isinstance(target_size, int):
            self.target_shape = (target_size,) * num_spatial_dims
        elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
            # Store an immutable copy so later mutation by the caller has no effect.
            self.target_shape = tuple(target_size)
        else:
            raise ValueError("`target_size` must either be an int or tuple of length "
                             f"{num_spatial_dims} containing ints.")

        # in & out shapes
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            # A dummy batch axis of 128 is prepended for shape inference only.
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            self.out_size = y.shape[1:]

    def update(self, x):
        """Input-output mapping.

        Parameters
        ----------
        x: Array
            Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
            or `(..., dim_1, dim_2)`.

        Returns
        -------
        Array
            ``x`` with each of its last ``len(target_shape)`` non-channel axes
            pooled to the matching target size (``None`` targets are skipped).

        Raises
        ------
        ValueError
            If the channel axis is invalid for ``x`` or ``x`` has too few
            non-channel dimensions.
        """
        # channel axis (use `is not None`: axis 0 is a valid channel axis)
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x.ndim <= channel_axis < x.ndim:
                raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
            channel_axis = channel_axis % x.ndim

        # input dimension
        num_targets = len(self.target_shape)
        if (x.ndim - (0 if channel_axis is None else 1)) < num_targets:
            raise ValueError(f"Invalid input dimension. Expected >={len(self.target_shape)} "
                             f"dimensions (channel_axis={self.channel_axis}). "
                             f"But got {x.ndim} dimensions.")

        # pooling dimensions: every axis except the channel axis
        pool_dims = [dim for dim in range(x.ndim) if dim != channel_axis]

        # pooling: reduce one axis at a time, vmapping over all other axes
        for target, dim in zip(self.target_shape, pool_dims[-num_targets:]):
            if target is None:
                # `None` means "keep this dimension as is".
                continue
            other_axes = [axis for axis in range(x.ndim) if axis != dim]
            op = _generate_vmap(_adaptive_pool1d, other_axes)
            x = op(x, target, self.operation)
        return x
|
1802
|
+
|
1803
|
+
|
1804
|
+
class AdaptiveAvgPool1d(_AdaptivePool):
    r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.

    The output size is :math:`L_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Adaptive pooling automatically computes the kernel size and stride to achieve the desired
    output size, making it useful for creating fixed-size representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
          :math:`L_{out}=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or sequence of int
        The target output size. The number of output features for each channel.
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5
        >>> m = brainstate.nn.AdaptiveAvgPool1d(5)
        >>> input = brainstate.random.randn(1, 64, 8)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 8)
        >>> # Can handle variable input sizes; the output size stays the same
        >>> input2 = brainstate.random.randn(1, 32, 8)
        >>> output2 = m(input2)
        >>> output2.shape
        (1, 5, 8)

    See Also
    --------
    AvgPool1d : Non-adaptive 1D average pooling.
    AdaptiveMaxPool1d : Adaptive 1D max pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[int]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        # Mean over each adaptive chunk of one spatial dimension.
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=1,
            operation=jnp.mean,
            name=name,
        )
|
1869
|
+
|
1870
|
+
|
1871
|
+
class AdaptiveAvgPool2d(_AdaptivePool):
    r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.

    The output is of size :math:`H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Adaptive pooling automatically computes the kernel size and stride to achieve the desired
    output size, making it useful for creating fixed-size representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
          :math:`(H_{out}, W_{out})=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or tuple of int
        The target output size. If a single integer is provided, the output will be a square
        of that size. If a tuple is provided, it specifies (H_out, W_out).
        Use None for dimensions that should not be pooled.
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5x7
        >>> m = brainstate.nn.AdaptiveAvgPool2d((5, 7))
        >>> input = brainstate.random.randn(1, 8, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 64)
        >>> # Target output size of 7x7 (square)
        >>> m = brainstate.nn.AdaptiveAvgPool2d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 64)
        >>> # Target output size of 10x7
        >>> m = brainstate.nn.AdaptiveAvgPool2d((None, 7))
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 10, 7, 64)

    See Also
    --------
    AvgPool2d : Non-adaptive 2D average pooling.
    AdaptiveMaxPool2d : Adaptive 2D max pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[int]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        # Mean over each adaptive chunk of two spatial dimensions.
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=2,
            operation=jnp.mean,
            name=name,
        )
|
1945
|
+
|
1946
|
+
|
1947
|
+
class AdaptiveAvgPool3d(_AdaptivePool):
    r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.

    The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Adaptive pooling automatically computes the kernel size and stride to achieve the desired
    output size, making it useful for creating fixed-size representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
          where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or tuple of int
        The target output size. If a single integer is provided, the output will be a cube
        of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
        Use None for dimensions that should not be pooled.
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5x7x9
        >>> m = brainstate.nn.AdaptiveAvgPool3d((5, 7, 9))
        >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 9, 64)
        >>> # Target output size of 7x7x7 (cube)
        >>> m = brainstate.nn.AdaptiveAvgPool3d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 7, 64)
        >>> # Target output size of 7x9x8
        >>> m = brainstate.nn.AdaptiveAvgPool3d((7, None, None))
        >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 9, 8, 64)

    See Also
    --------
    AvgPool3d : Non-adaptive 3D average pooling.
    AdaptiveMaxPool3d : Adaptive 3D max pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[int]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        # Mean over each adaptive chunk of three spatial dimensions.
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=3,
            operation=jnp.mean,
            name=name,
        )
|
2021
|
+
|
2022
|
+
|
2023
|
+
class AdaptiveMaxPool1d(_AdaptivePool):
    r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.

    The output size is :math:`L_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Adaptive pooling automatically computes the kernel size and stride to achieve the desired
    output size, making it useful for creating fixed-size representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
          :math:`L_{out}=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or sequence of int
        The target output size. The number of output features for each channel.
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5
        >>> m = brainstate.nn.AdaptiveMaxPool1d(5)
        >>> input = brainstate.random.randn(1, 64, 8)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 8)
        >>> # Can handle variable input sizes; the output size stays the same
        >>> input2 = brainstate.random.randn(1, 32, 8)
        >>> output2 = m(input2)
        >>> output2.shape
        (1, 5, 8)

    See Also
    --------
    MaxPool1d : Non-adaptive 1D max pooling.
    AdaptiveAvgPool1d : Adaptive 1D average pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[int]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        # Maximum over each adaptive chunk of one spatial dimension.
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=1,
            operation=jnp.max,
            name=name,
        )
|
2088
|
+
|
2089
|
+
|
2090
|
+
class AdaptiveMaxPool2d(_AdaptivePool):
    r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.

    The output is of size :math:`H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Adaptive pooling automatically computes the kernel size and stride to achieve the desired
    output size, making it useful for creating fixed-size representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
          :math:`(H_{out}, W_{out})=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or tuple of int
        The target output size. If a single integer is provided, the output will be a square
        of that size. If a tuple is provided, it specifies (H_out, W_out).
        Use None for dimensions that should not be pooled.
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5x7
        >>> m = brainstate.nn.AdaptiveMaxPool2d((5, 7))
        >>> input = brainstate.random.randn(1, 8, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 64)
        >>> # Target output size of 7x7 (square)
        >>> m = brainstate.nn.AdaptiveMaxPool2d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 64)
        >>> # Target output size of 10x7
        >>> m = brainstate.nn.AdaptiveMaxPool2d((None, 7))
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 10, 7, 64)

    See Also
    --------
    MaxPool2d : Non-adaptive 2D max pooling.
    AdaptiveAvgPool2d : Adaptive 2D average pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[int]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        # Maximum over each adaptive chunk of two spatial dimensions.
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=2,
            operation=jnp.max,
            name=name,
        )
|
2164
|
+
|
2165
|
+
|
2166
|
+
class AdaptiveMaxPool3d(_AdaptivePool):
    r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.

    The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Adaptive pooling automatically computes the kernel size and stride to achieve the desired
    output size, making it useful for creating fixed-size representations from variable-sized inputs.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
          where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.

    Parameters
    ----------
    target_size : int or tuple of int
        The target output size. If a single integer is provided, the output will be a cube
        of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
        Use None for dimensions that should not be pooled.
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The name of the module.
    in_size : Sequence of int, optional
        The shape of the input tensor for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Target output size of 5x7x9
        >>> m = brainstate.nn.AdaptiveMaxPool3d((5, 7, 9))
        >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 9, 64)
        >>> # Target output size of 7x7x7 (cube)
        >>> m = brainstate.nn.AdaptiveMaxPool3d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 7, 64)
        >>> # Target output size of 7x9x8
        >>> m = brainstate.nn.AdaptiveMaxPool3d((7, None, None))
        >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 9, 8, 64)

    See Also
    --------
    MaxPool3d : Non-adaptive 3D max pooling.
    AdaptiveAvgPool3d : Adaptive 3D average pooling.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        target_size: Union[int, Sequence[int]],
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Sequence[int]] = None,
    ):
        # Maximum over each adaptive chunk of three spatial dimensions.
        super().__init__(
            in_size=in_size,
            target_size=target_size,
            channel_axis=channel_axis,
            num_spatial_dims=3,
            operation=jnp.max,
            name=name,
        )
|