brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/nn/_poolings.py
CHANGED
@@ -1,1177 +1,1177 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
import functools
|
19
|
-
from typing import Sequence, Optional
|
20
|
-
from typing import Union, Tuple, Callable, List
|
21
|
-
|
22
|
-
import brainunit as u
|
23
|
-
import jax
|
24
|
-
import jax.numpy as jnp
|
25
|
-
import numpy as np
|
26
|
-
|
27
|
-
from brainstate import environ
|
28
|
-
from brainstate.typing import Size
|
29
|
-
from ._module import Module
|
30
|
-
|
31
|
-
__all__ = [
    # reshaping layers
    'Flatten',
    'Unflatten',

    # fixed-window pooling
    'AvgPool1d',
    'AvgPool2d',
    'AvgPool3d',
    'MaxPool1d',
    'MaxPool2d',
    'MaxPool3d',

    # adaptive (output-size driven) pooling
    'AdaptiveAvgPool1d',
    'AdaptiveAvgPool2d',
    'AdaptiveAvgPool3d',
    'AdaptiveMaxPool1d',
    'AdaptiveMaxPool2d',
    'AdaptiveMaxPool3d',
]
|
40
|
-
|
41
|
-
|
42
|
-
class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_axis: first dim to flatten (default = 0).
        end_axis: last dim to flatten, inclusive (default = -1).
        in_size: Sequence of int, optional. The shape of a single input, without any extra
            leading (batch) dimensions. When given, ``out_size`` is inferred at construction,
            the trailing dimensions of every input are checked against it, and non-negative
            ``start_axis`` / ``end_axis`` are interpreted relative to ``in_size``, so they are
            shifted past any extra leading dimensions of the input.

    Examples::

        >>> import brainstate as brainstate
        >>> inp = brainstate.random.randn(32, 1, 5, 5)
        >>> # With default parameters every axis is flattened
        >>> m = Flatten()
        >>> output = m(inp)
        >>> output.shape
        (800,)
        >>> # Keep the leading axis
        >>> m = Flatten(start_axis=1)
        >>> m(inp).shape
        (32, 25)
        >>> # With non-default parameters
        >>> m = Flatten(0, 2)
        >>> output = m(inp)
        >>> output.shape
        (160, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        start_axis: int = 0,
        end_axis: int = -1,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__()
        self.start_axis = start_axis
        self.end_axis = end_axis

        if in_size is not None:
            self.in_size = tuple(in_size)
            # Infer the output shape abstractly (no data is materialized).
            y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
                               jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = y.shape

    def update(self, x):
        """Flatten ``x`` from ``start_axis`` to ``end_axis`` (inclusive).

        Raises:
            ValueError: If ``in_size`` was given and the trailing dims of ``x`` do not match it.
        """
        if self._in_size is None:
            start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
            end_axis = self.end_axis
        else:
            assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
            dim_diff = x.ndim - len(self.in_size)
            if self.in_size != x.shape[dim_diff:]:
                raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
            # Non-negative axes count from the start of ``in_size``; both ends of the
            # range must be shifted past the extra leading (batch) dimensions, otherwise
            # the flattened range does not match the ``out_size`` inferred in ``__init__``.
            if self.start_axis >= 0:
                start_axis = self.start_axis + dim_diff
            else:
                start_axis = x.ndim + self.start_axis
            end_axis = self.end_axis + dim_diff if self.end_axis >= 0 else self.end_axis
        return u.math.flatten(x, start_axis, end_axis)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
|
105
|
-
|
106
|
-
|
107
|
-
class Unflatten(Module):
    r"""
    Unflatten a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`axis` specifies the dimension of the input tensor to be unflattened.

    * :attr:`sizes` is the new shape of the unflattened dimension. It must be a `tuple` or a
      `list` of ints.

    Shape:
        - Input: :math:`(*, S_{\text{axis}}, *)`, where :math:`S_{\text{axis}}` is the size at
          dimension :attr:`axis` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
          :math:`\prod_{i=1}^n U_i = S_{\text{axis}}`.

    Args:
        axis: int, Dimension to be unflattened. When ``in_size`` is given, a non-negative
            ``axis`` is relative to ``in_size`` and is shifted past any extra leading
            (batch) dimensions of the input.
        sizes: Sequence of int. New shape of the unflattened dimension.
        name: str, optional. The module name.
        in_size: Sequence of int, optional. The shape of a single input (without extra
            leading batch dimensions); used to infer ``out_size``.

    Raises:
        TypeError: If ``sizes`` is not a tuple/list, or contains a non-int element.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        axis: int,
        sizes: Size,
        name: Optional[str] = None,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__(name=name)

        if isinstance(sizes, (tuple, list)):
            for idx, elem in enumerate(sizes):
                if not isinstance(elem, int):
                    raise TypeError("unflattened sizes must be tuple of ints, " +
                                    "but found element of type {} at pos {}".format(type(elem).__name__, idx))
        else:
            raise TypeError("unflattened sizes must be tuple or list, but found type {}".format(type(sizes).__name__))
        self.axis = axis
        self.sizes = sizes

        if in_size is not None:
            self.in_size = tuple(in_size)
            # Infer the output shape abstractly (no data is materialized).
            y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
                               jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = y.shape

    def update(self, x):
        """Split axis ``self.axis`` of ``x`` into ``self.sizes``."""
        axis = self.axis
        if self._in_size is not None and axis >= 0:
            # ``axis`` was specified relative to ``in_size``; skip any extra leading
            # (batch) dims so runtime behavior matches the inferred ``out_size``.
            axis += max(x.ndim - len(self.in_size), 0)
        return u.math.unflatten(x, axis, self.sizes)

    def __repr__(self):
        return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
|
161
|
-
|
162
|
-
|
163
|
-
class _MaxPool(Module):
    """Generic windowed-reduction layer shared by the max- and average-pooling modules.

    The reduction is computed with :func:`jax.lax.reduce_window`. The last ``pool_dim``
    axes that are not ``channel_axis`` are pooled. Every other axis (leading batch axes
    and the channel axis) uses a window and stride of 1 and no padding.

    Args:
        init_value: Initial value of the reduction (e.g. ``-inf`` for max pooling).
        computation: Binary reduction passed to ``jax.lax.reduce_window`` (e.g. ``jax.lax.max``).
        pool_dim: Number of pooled (spatial) axes.
        kernel_size: An int, or a sequence of ``pool_dim`` ints, giving the window size.
        stride: An int, or a sequence of ``pool_dim`` ints. ``None`` means ``kernel_size``.
        padding: ``'SAME'``, ``'VALID'``, an int (same low/high padding on every pooled axis),
            a sequence of ``pool_dim`` ints, or a sequence of ``(low, high)`` pairs (a single
            pair is repeated for every pooled axis).
        channel_axis: Axis excluded from pooling, or ``None`` if there is no channel axis.
        name: The module name.
        in_size: Shape of a single input without the batch axis. When given, ``out_size``
            is inferred by tracing :meth:`update` on a dummy batch.
    """

    def __init__(
        self,
        init_value: float,
        computation: Callable,
        pool_dim: int,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(name=name)

        self.init_value = init_value
        self.computation = computation
        self.pool_dim = pool_dim

        # kernel_size
        kernel_size = self._normalize_ints(kernel_size, 'kernel_size', pool_dim)
        self.kernel_size = kernel_size

        # stride: defaults to the (normalized) kernel size, i.e. non-overlapping windows
        if stride is None:
            stride = kernel_size
        self.stride = self._normalize_ints(stride, 'stride', pool_dim)

        # padding
        self.padding = self._normalize_padding(padding, pool_dim)

        # channel_axis
        assert channel_axis is None or isinstance(channel_axis, int), \
            f'channel_axis should be an int, but got {channel_axis}'
        self.channel_axis = channel_axis

        # in & out shapes
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            # Trace ``update`` on a dummy batch of 128 samples; drop the batch axis afterwards.
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            self.out_size = y.shape[1:]

    @staticmethod
    def _normalize_ints(value, arg_name, pool_dim):
        """Validate an int / sequence-of-ints argument and expand an int to ``pool_dim`` entries.

        Tuples and lists of the right length are returned unchanged.
        """
        if isinstance(value, int):
            return (value,) * pool_dim
        if isinstance(value, Sequence):
            assert isinstance(value, (tuple, list)), f'{arg_name} should be a tuple, but got {type(value)}'
            assert all(isinstance(x, int) for x in value), f'{arg_name} should be a tuple of ints. {value}'
            if len(value) != pool_dim:
                # Report the length of the offending argument itself.
                raise ValueError(f'{arg_name} should a tuple with {pool_dim} ints, but got {len(value)}')
            return value
        raise TypeError(f'{arg_name} should be a int or a tuple with {pool_dim} ints.')

    @staticmethod
    def _normalize_padding(padding, pool_dim):
        """Validate ``padding`` and convert int forms to a list of ``(low, high)`` pairs.

        Strings (``'SAME'``/``'VALID'``) are returned as-is. A sequence of pairs with
        ``pool_dim`` entries is returned unchanged, and a single pair is repeated.
        """
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
            return padding
        if isinstance(padding, int):
            return [(padding, padding) for _ in range(pool_dim)]
        if isinstance(padding, (list, tuple)):
            if len(padding) == 0:
                raise ValueError(f'padding should not be an empty sequence. {padding}')
            if isinstance(padding[0], int):
                if len(padding) == pool_dim:
                    return [(x, x) for x in padding]
                raise ValueError(f'If padding is a sequence of ints, it '
                                 f'should has the length of {pool_dim}.')
            if not all(isinstance(x, (tuple, list)) for x in padding):
                raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
            if not all(len(x) == 2 for x in padding):
                raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
            if len(padding) == 1:
                padding = tuple(padding) * pool_dim
            assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
            return padding
        raise ValueError(f"padding should be a str, an int, or a sequence, but got {type(padding)}.")

    def update(self, x):
        """Reduce ``x`` over the pooled axes with ``self.computation``.

        Raises:
            ValueError: If ``x`` has fewer dims than ``pool_dim`` (+1 when a channel axis is set),
                or if ``channel_axis`` is out of range for ``x``.
        """
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
        window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        padding = (self.padding if isinstance(self.padding, str) else
                   self._infer_shape(x.ndim, self.padding, element=(0, 0)))
        r = jax.lax.reduce_window(
            x,
            init_value=self.init_value,
            computation=self.computation,
            window_dimensions=window_shape,
            window_strides=stride,
            padding=padding
        )
        return r

    def _infer_shape(self, x_dim, inputs, element):
        """Build a per-axis list of length ``x_dim``.

        The pooled axes (the last ``pool_dim`` axes other than ``channel_axis``) take the
        entries of ``inputs`` in order; every other axis gets ``element``.
        """
        channel_axis = self.channel_axis
        if channel_axis is not None:
            # Any valid Python index in [-x_dim, x_dim) is accepted, including 0 and -x_dim.
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis
        all_dims = list(range(x_dim))
        if channel_axis is not None:
            all_dims.pop(channel_axis)
        pool_dims = all_dims[-self.pool_dim:]
        results = [element] * x_dim
        for i, dim in enumerate(pool_dims):
            results[dim] = inputs[i]
        return results
|
278
|
-
|
279
|
-
|
280
|
-
class _AvgPool(_MaxPool):
    """Average pooling built on the windowed-sum machinery of :class:`_MaxPool`.

    The window sum is divided by the window volume for ``'VALID'`` padding. For any
    other padding it is divided by the number of non-padded elements in each window.
    """

    def update(self, x):
        """Average ``x`` over the pooled axes.

        Raises:
            ValueError: If ``x`` has fewer dims than ``pool_dim`` (+1 when a channel axis is set).
        """
        required = self.pool_dim + (1 if self.channel_axis is not None else 0)
        if x.ndim < required:
            raise ValueError(f'Excepted input with >= {required} dimensions, but got {x.ndim}.')

        ndim = x.ndim
        window = self._infer_shape(ndim, self.kernel_size, 1)
        strides = self._infer_shape(ndim, self.stride, 1)
        if isinstance(self.padding, str):
            pads = self.padding
        else:
            pads = self._infer_shape(ndim, self.padding, element=(0, 0))

        def window_sum(operand):
            # Same window geometry for the data and for the element-count map.
            return jax.lax.reduce_window(operand,
                                         init_value=self.init_value,
                                         computation=self.computation,
                                         window_dimensions=window,
                                         window_strides=strides,
                                         padding=pads)

        total = window_sum(x)
        if pads != "VALID":
            # Summing a tensor of ones yields, per output position, how many real (non-padded)
            # input elements fell inside the window; dividing by it excludes padding from the mean.
            counts = window_sum(jnp.ones_like(x))
            assert total.shape == counts.shape
            return total / counts
        # Every window is full under VALID padding, so a constant divisor suffices.
        return total / np.prod(window)
|
310
|
-
|
311
|
-
|
312
|
-
class MaxPool1d(_MaxPool):
    r"""Applies a 1D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
    and output :math:`(N, L_{out}, C)` can be precisely described as:

    .. math::
        out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                input(N_i, stride \times k + m, C_j)

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where, for symmetric integer
          padding (``'VALID'`` corresponds to ``padding = 0``),

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding}
                    - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor


    Examples::

        >>> import brainstate as brainstate
        >>> # pool of size=3, stride=2
        >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 16)

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Defaults to ``kernel_size`` (non-overlapping windows).
    padding: str, int, sequence of int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an int applied to both sides of
        every pooled dimension, a sequence of ints (one per pooled dimension), or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Defaults to `'VALID'`.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Defaults to ``-1``.
    name: optional, str
        The object name.
    in_size: Sequence of int, optional
        The shape of a single (unbatched) input; used to infer ``out_size``.

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=1,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
|
386
|
-
|
387
|
-
|
388
|
-
class MaxPool2d(_MaxPool):
    r"""Applies a 2D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
    output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{input}(N_i, \text{stride[0]} \times h + m,
                                                   \text{stride[1]} \times w + n, C_j)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where, for
          symmetric integer padding (``'VALID'`` corresponds to ``padding = 0``),

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding[0]}
                    - \text{kernel\_size[0]}}{\text{stride[0]}} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding[1]}
                    - \text{kernel\_size[1]}}{\text{stride[1]}} + 1\right\rfloor

    Examples::

        >>> import brainstate as brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = MaxPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Defaults to ``kernel_size`` (non-overlapping windows).
    padding: str, int, sequence of int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an int applied to both sides of
        every pooled dimension, a sequence of ints (one per pooled dimension), or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Defaults to `'VALID'`.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Defaults to ``-1``.
    name: optional, str
        The object name.
    in_size: Sequence of int, optional
        The shape of a single (unbatched) input; used to infer ``out_size``.

    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
|
473
|
-
|
474
|
-
|
475
|
-
class MaxPool3d(_MaxPool):
    r"""Applies a 3D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
    output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            \text{out}(N_i, d, h, w, C_j) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                              & \text{input}(N_i, \text{stride[0]} \times d + k,
                                                             \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
          where, for symmetric integer padding (``'VALID'`` corresponds to ``padding = 0``),

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0]
                    - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1]
                    - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2]
                    - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

    Examples::

        >>> import brainstate as brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = MaxPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 43, 15, 16)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Defaults to ``kernel_size`` (non-overlapping windows).
    padding: str, int, sequence of int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an int applied to both sides of
        every pooled dimension, a sequence of ints (one per pooled dimension), or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Defaults to `'VALID'`.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Defaults to ``-1``.
    name: optional, str
        The object name.
    in_size: Sequence of int, optional
        The shape of a single (unbatched) input; used to infer ``out_size``.

    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=3,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
|
564
|
-
|
565
|
-
|
566
|
-
class AvgPool1d(_AvgPool):
|
567
|
-
r"""Applies a 1D average pooling over an input signal composed of several input planes.
|
568
|
-
|
569
|
-
In the simplest case, the output value of the layer with input size :math:`(N, L, C)`,
|
570
|
-
output :math:`(N, L_{out}, C)` and :attr:`kernel_size` :math:`k`
|
571
|
-
can be precisely described as:
|
572
|
-
|
573
|
-
.. math::
|
574
|
-
|
575
|
-
\text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
|
576
|
-
\text{input}(N_i, \text{stride} \times l + m, C_j)
|
577
|
-
|
578
|
-
If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
|
579
|
-
for :attr:`padding` number of points.
|
580
|
-
|
581
|
-
Shape:
|
582
|
-
- Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
|
583
|
-
- Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
|
584
|
-
|
585
|
-
.. math::
|
586
|
-
L_{out} = \left\lfloor \frac{L_{in} +
|
587
|
-
2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
|
588
|
-
|
589
|
-
Examples::
|
590
|
-
|
591
|
-
>>> import brainstate as brainstate
|
592
|
-
>>> # pool with window of size=3, stride=2
|
593
|
-
>>> m = AvgPool1d(3, stride=2)
|
594
|
-
>>> input = brainstate.random.randn(20, 50, 16)
|
595
|
-
>>> m(input).shape
|
596
|
-
(20, 24, 16)
|
597
|
-
|
598
|
-
Parameters
|
599
|
-
----------
|
600
|
-
in_size: Sequence of int
|
601
|
-
The shape of the input tensor.
|
602
|
-
kernel_size: int, sequence of int
|
603
|
-
An integer, or a sequence of integers defining the window to reduce over.
|
604
|
-
stride: int, sequence of int
|
605
|
-
An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
|
606
|
-
padding: str, int, sequence of tuple
|
607
|
-
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
608
|
-
of n `(low, high)` integer pairs that give the padding to apply before
|
609
|
-
and after each spatial dimension.
|
610
|
-
channel_axis: int, optional
|
611
|
-
Axis of the spatial channels for which pooling is skipped.
|
612
|
-
If ``None``, there is no channel axis.
|
613
|
-
name: optional, str
|
614
|
-
The object name.
|
615
|
-
|
616
|
-
"""
|
617
|
-
__module__ = 'brainstate.nn'
|
618
|
-
|
619
|
-
def __init__(
    self,
    kernel_size: Size,
    stride: Union[int, Sequence[int]] = 1,
    padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
    channel_axis: Optional[int] = -1,
    name: Optional[str] = None,
    in_size: Optional[Size] = None,
):
    """Build a one-dimensional average-pooling layer.

    Every window argument (``kernel_size``, ``stride``, ``padding``,
    ``channel_axis``) plus ``name`` and ``in_size`` is forwarded unchanged to
    the parent pooling class. This constructor only fixes the reduction: a
    window sum (``jax.lax.add`` seeded with ``0.``) over exactly one spatial
    dimension.

    NOTE(review): turning the window sum into an average is presumably done
    by the parent class's ``update`` (not visible here) -- confirm.
    """
    # Reduction configuration shared by all 1D average pools: sum from zero.
    reduction = dict(init_value=0., computation=jax.lax.add, pool_dim=1)
    super().__init__(
        in_size=in_size,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        channel_axis=channel_axis,
        name=name,
        **reduction,
    )
|
637
|
-
|
638
|
-
|
639
|
-
class AvgPool2d(_AvgPool):
    r"""Applies a 2D average pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
    output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::

        \text{out}(N_i, h, w, C_j) = \frac{1}{kH \times kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
        \text{input}(N_i, \text{stride}[0] \times h + m, \text{stride}[1] \times w + n, C_j)

    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
    for :attr:`padding` number of points.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
                \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
                \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Examples::

        >>> import brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = brainstate.nn.AvgPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = brainstate.nn.AvgPool2d((3, 2), stride=(2, 1))
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride
        (default: ``1``, i.e. `(1, ..., 1)`).
    padding: str, int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an integer, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: optional, str
        The object name.
    in_size: Sequence of int
        The shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Window sum seeded at zero over the two trailing spatial axes.
        super().__init__(in_size=in_size,
                         init_value=0.,
                         computation=jax.lax.add,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
|
716
|
-
|
717
|
-
|
718
|
-
class AvgPool3d(_AvgPool):
    r"""Three-dimensional average pooling over a stack of volumetric input planes.

    For an input of shape :math:`(N, D, H, W, C)` and a window of size
    :math:`(kD, kH, kW)`, every output element is the mean of one window:

    .. math::

        \text{out}(N_i, d, h, w, C_j) = \frac{1}{kD \times kH \times kW}
        \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
        \text{input}(N_i, \text{stride}[0] \times d + k,
        \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)

    A non-zero :attr:`padding` implicitly zero-pads the input on both sides of
    each of the three spatial dimensions by that many points.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
          :math:`(D_{out}, H_{out}, W_{out}, C)`, where, for spatial axis
          :math:`i \in \{0, 1, 2\}` with input length :math:`X_{in}`,

          .. math::
              X_{out} = \left\lfloor\frac{X_{in} + 2 \times \text{padding}[i] -
                \text{kernel\_size}[i]}{\text{stride}[i]} + 1\right\rfloor

    Examples::

        >>> import brainstate
        >>> # cubic window of size 3 moved with stride 2
        >>> m = brainstate.nn.AvgPool3d(3, stride=2)
        >>> # non-cubic window with per-axis strides
        >>> m = brainstate.nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> m(input).shape
        (20, 24, 43, 15, 16)

    Parameters
    ----------
    kernel_size: int, sequence of int
        Window extent: one integer for every spatial axis, or one per axis.
    stride: int, sequence of int
        Step between neighbouring windows (default: ``1`` on every axis).
    padding: str, int, sequence of tuple
        `'SAME'`, `'VALID'`, an integer, or n `(low, high)` integer pairs
        giving the padding before and after each spatial dimension.
    channel_axis: int, optional
        Axis that is never pooled; ``None`` means the input has no channel axis.
    name: optional, str
        Name of this module.
    in_size: Sequence of int
        Shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Three spatial axes, reduced by summation starting from zero.
        reduction = dict(init_value=0., computation=jax.lax.add, pool_dim=3)
        super().__init__(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size,
            **reduction,
        )
|
805
|
-
|
806
|
-
|
807
|
-
def _adaptive_pool1d(x, target_size: int, operation: Callable):
|
808
|
-
"""Adaptive pool 1D.
|
809
|
-
|
810
|
-
Args:
|
811
|
-
x: The input. Should be a JAX array of shape `(dim,)`.
|
812
|
-
target_size: The shape of the output after the pooling operation `(target_size,)`.
|
813
|
-
operation: The pooling operation to be performed on the input array.
|
814
|
-
|
815
|
-
Returns:
|
816
|
-
A JAX array of shape `(target_size, )`.
|
817
|
-
"""
|
818
|
-
size = jnp.size(x)
|
819
|
-
num_head_arrays = size % target_size
|
820
|
-
num_block = size // target_size
|
821
|
-
if num_head_arrays != 0:
|
822
|
-
head_end_index = num_head_arrays * (num_block + 1)
|
823
|
-
heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
|
824
|
-
tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
|
825
|
-
outs = jnp.concatenate([heads, tails])
|
826
|
-
else:
|
827
|
-
outs = jax.vmap(operation)(x.reshape(-1, num_block))
|
828
|
-
return outs
|
829
|
-
|
830
|
-
|
831
|
-
def _generate_vmap(fun: Callable, map_axes: List[int]):
|
832
|
-
map_axes = sorted(map_axes)
|
833
|
-
for axis in map_axes:
|
834
|
-
fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
|
835
|
-
return fun
|
836
|
-
|
837
|
-
|
838
|
-
class _AdaptivePool(Module):
    """General N dimensional adaptive down-sampling to a target shape.

    The last ``num_spatial_dims`` non-channel axes of the input are pooled one
    axis at a time. Each axis is split into contiguous segments and every
    segment is reduced with ``operation`` (see ``_adaptive_pool1d``).

    Parameters
    ----------
    in_size: Sequence of int, optional
        The shape of a single (unbatched) input. When given, ``out_size`` is
        inferred with ``jax.eval_shape`` on a dummy batch of 128 samples.
    target_size: int, sequence of int
        The target output shape. An int is used for every spatial dimension.
        A sequence must have ``num_spatial_dims`` entries. An entry of ``None``
        leaves the corresponding spatial dimension unpooled (its input size is
        kept).
    num_spatial_dims: int
        The number of spatial dimensions.
    operation: Callable
        The down-sampling operation applied to each 1D segment.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The class name.

    Raises
    ------
    ValueError
        If ``target_size`` is neither an int nor a sequence of
        ``num_spatial_dims`` ints/``None``.
    """

    def __init__(
        self,
        in_size: Size,
        target_size: Size,
        num_spatial_dims: int,
        operation: Callable,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        self.channel_axis = channel_axis
        self.operation = operation
        if isinstance(target_size, int):
            self.target_shape = (target_size,) * num_spatial_dims
        elif (isinstance(target_size, Sequence)
              and len(target_size) == num_spatial_dims
              and all(t is None or isinstance(t, int) for t in target_size)):
            self.target_shape = target_size
        else:
            raise ValueError("`target_size` must either be an int or tuple of length "
                             f"{num_spatial_dims} containing ints or None.")

        # in & out shapes
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            self.out_size = y.shape[1:]

    def update(self, x):
        """Input-output mapping.

        Parameters
        ----------
        x: Array
            Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
            or `(..., dim_1, dim_2)`.

        Raises
        ------
        ValueError
            If ``channel_axis`` is out of range for ``x`` or ``x`` has too few
            non-channel dimensions.
        """
        # Normalise the channel axis. `is not None` (rather than truthiness)
        # so that axis 0 is handled like any other axis.
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x.ndim <= channel_axis < x.ndim:
                raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
            if channel_axis < 0:
                channel_axis = x.ndim + channel_axis

        # input dimension
        if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
            raise ValueError(f"Invalid input dimension. Expected >={len(self.target_shape)} "
                             f"dimensions (channel_axis={self.channel_axis}). "
                             f"But got {x.ndim} dimensions.")

        # Pooling happens over the trailing non-channel axes.
        pool_dims = list(range(x.ndim))
        if channel_axis is not None:
            pool_dims.pop(channel_axis)

        for target, axis in zip(self.target_shape, pool_dims[-len(self.target_shape):]):
            if target is None:
                # `None` keeps this spatial dimension at its input size.
                continue
            # Map over every other axis so `_adaptive_pool1d` sees 1D slices along `axis`.
            other_axes = [j for j in range(x.ndim) if j != axis]
            op = _generate_vmap(_adaptive_pool1d, other_axes)
            x = op(x, target, self.operation)
        return x
|
919
|
-
|
920
|
-
|
921
|
-
class AdaptiveAvgPool1d(_AdaptivePool):
    r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.

    The output size is :math:`L_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
          :math:`L_{out}=\text{target\_size}`.

    Examples:

        >>> import brainstate
        >>> # target output size of 5
        >>> m = brainstate.nn.AdaptiveAvgPool1d(5)
        >>> input = brainstate.random.randn(1, 64, 8)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 8)

    Parameters
    ----------
    target_size: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The class name.
    in_size: Sequence of int
        The shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None, ):
        # Each segment of the single spatial axis is reduced with its mean.
        super().__init__(in_size=in_size,
                         target_size=target_size,
                         channel_axis=channel_axis,
                         num_spatial_dims=1,
                         operation=jnp.mean,
                         name=name)
|
967
|
-
|
968
|
-
|
969
|
-
class AdaptiveAvgPool2d(_AdaptivePool):
    r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.

    The output is of size :math:`H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
          :math:`(H_{out}, W_{out})=\text{target\_size}`.

    Examples:

        >>> import brainstate
        >>> # target output size of 5x7
        >>> m = brainstate.nn.AdaptiveAvgPool2d((5, 7))
        >>> input = brainstate.random.randn(1, 8, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 64)
        >>> # target output size of 7x7 (square)
        >>> m = brainstate.nn.AdaptiveAvgPool2d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 64)

    Parameters
    ----------
    target_size: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The class name.
    in_size: Sequence of int
        The shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,

                 in_size: Optional[Sequence[int]] = None, ):
        # Both spatial axes are reduced segment-wise with the mean.
        super().__init__(in_size=in_size,
                         target_size=target_size,
                         channel_axis=channel_axis,
                         num_spatial_dims=2,
                         operation=jnp.mean,
                         name=name)
|
1028
|
-
|
1029
|
-
|
1030
|
-
class AdaptiveAvgPool3d(_AdaptivePool):
    r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.

    The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
          where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.

    Examples:

        >>> import brainstate
        >>> # target output size of 5x7x9
        >>> m = brainstate.nn.AdaptiveAvgPool3d((5, 7, 9))
        >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 9, 64)
        >>> # target output size of 7x7x7 (cube)
        >>> m = brainstate.nn.AdaptiveAvgPool3d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 7, 64)

    Parameters
    ----------
    target_size: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The class name.
    in_size: Sequence of int
        The shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None, ):
        # All three spatial axes are reduced segment-wise with the mean.
        super().__init__(in_size=in_size,
                         target_size=target_size,
                         channel_axis=channel_axis,
                         num_spatial_dims=3,
                         operation=jnp.mean,
                         name=name)
|
1088
|
-
|
1089
|
-
|
1090
|
-
class AdaptiveMaxPool1d(_AdaptivePool):
    """Adaptive max pooling along one spatial axis.

    The pooled axis is cut into ``target_size`` contiguous segments and the
    maximum of every segment is kept, whatever the input length.

    Examples:

        >>> import brainstate
        >>> m = brainstate.nn.AdaptiveMaxPool1d(5)
        >>> m(brainstate.random.randn(1, 64, 8)).shape
        (1, 5, 8)

    Parameters
    ----------
    target_size: int, sequence of int
        Desired length of the pooled axis.
    channel_axis: int, optional
        Axis that is never pooled; ``None`` means there is no channel axis.
    name: str
        Name of this module.
    in_size: Sequence of int
        Shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None, ):
        super().__init__(
            target_size=target_size,
            num_spatial_dims=1,
            operation=jnp.max,  # keep the largest value of every segment
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|
1118
|
-
|
1119
|
-
|
1120
|
-
class AdaptiveMaxPool2d(_AdaptivePool):
    """Adaptive max pooling over two spatial axes.

    Each of the two pooled axes is cut into contiguous segments matching
    ``target_size``, and the maximum of every segment is kept, whatever the
    input size.

    Examples:

        >>> import brainstate
        >>> m = brainstate.nn.AdaptiveMaxPool2d((5, 7))
        >>> m(brainstate.random.randn(1, 8, 9, 64)).shape
        (1, 5, 7, 64)

    Parameters
    ----------
    target_size: int, sequence of int
        Desired size of the pooled axes (one int is used for both).
    channel_axis: int, optional
        Axis that is never pooled; ``None`` means there is no channel axis.
    name: str
        Name of this module.
    in_size: Sequence of int
        Shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None, ):
        super().__init__(
            target_size=target_size,
            num_spatial_dims=2,
            operation=jnp.max,  # keep the largest value of every segment
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|
1148
|
-
|
1149
|
-
|
1150
|
-
class AdaptiveMaxPool3d(_AdaptivePool):
    """Adaptive max pooling over three spatial axes.

    Each of the three pooled axes is cut into contiguous segments matching
    ``target_size``, and the maximum of every segment is kept, whatever the
    input size.

    Examples:

        >>> import brainstate
        >>> m = brainstate.nn.AdaptiveMaxPool3d((5, 7, 9))
        >>> m(brainstate.random.randn(1, 8, 9, 10, 64)).shape
        (1, 5, 7, 9, 64)

    Parameters
    ----------
    target_size: int, sequence of int
        Desired size of the pooled axes (one int is used for all three).
    channel_axis: int, optional
        Axis that is never pooled; ``None`` means there is no channel axis.
    name: str
        Name of this module.
    in_size: Sequence of int
        Shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None, ):
        super().__init__(
            target_size=target_size,
            num_spatial_dims=3,
            operation=jnp.max,  # keep the largest value of every segment
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
import functools
|
19
|
+
from typing import Sequence, Optional
|
20
|
+
from typing import Union, Tuple, Callable, List
|
21
|
+
|
22
|
+
import brainunit as u
|
23
|
+
import jax
|
24
|
+
import jax.numpy as jnp
|
25
|
+
import numpy as np
|
26
|
+
|
27
|
+
from brainstate import environ
|
28
|
+
from brainstate.typing import Size
|
29
|
+
from ._module import Module
|
30
|
+
|
31
|
+
# Public API of this module (order kept: shape ops, fixed pools, adaptive pools).
__all__ = [
    # reshaping layers
    'Flatten',
    'Unflatten',
    # fixed-window pooling
    'AvgPool1d',
    'AvgPool2d',
    'AvgPool3d',
    'MaxPool1d',
    'MaxPool2d',
    'MaxPool3d',
    # adaptive pooling
    'AdaptiveAvgPool1d',
    'AdaptiveAvgPool2d',
    'AdaptiveAvgPool3d',
    'AdaptiveMaxPool1d',
    'AdaptiveMaxPool2d',
    'AdaptiveMaxPool3d',
]
|
40
|
+
|
41
|
+
|
42
|
+
class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_axis: first dim to flatten (default = 0).
        end_axis: last dim to flatten, inclusive (default = -1).
        in_size: Sequence of int, optional. The shape of a single input sample.
            When given, inputs must end with exactly this shape. Any extra
            leading dimensions are treated as batch dimensions, and
            non-negative ``start_axis``/``end_axis`` are counted relative to
            ``in_size`` (i.e. shifted past those leading dimensions).

    Examples::

        >>> import brainstate
        >>> inp = brainstate.random.randn(32, 1, 5, 5)
        >>> # With default parameters every axis is flattened
        >>> m = brainstate.nn.Flatten()
        >>> m(inp).shape
        (800,)
        >>> # Keep the leading axis
        >>> m = brainstate.nn.Flatten(1)
        >>> m(inp).shape
        (32, 25)
        >>> # With non-default parameters
        >>> m = brainstate.nn.Flatten(0, 2)
        >>> m(inp).shape
        (160, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        start_axis: int = 0,
        end_axis: int = -1,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__()
        self.start_axis = start_axis
        self.end_axis = end_axis

        if in_size is not None:
            self.in_size = tuple(in_size)
            y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
                               jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = y.shape

    def update(self, x):
        """Flatten axes ``start_axis..end_axis`` (inclusive) of ``x``.

        Raises:
            ValueError: If ``in_size`` was given and the trailing shape of
                ``x`` does not match it.
        """
        if self._in_size is None:
            start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
            end_axis = self.end_axis
        else:
            assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
            dim_diff = x.ndim - len(self.in_size)
            if self.in_size != x.shape[dim_diff:]:
                raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
            # Non-negative axes refer to `in_size`, so both ends must skip the
            # extra leading (batch) dimensions; negative axes already count
            # from the end and need no adjustment.
            if self.start_axis >= 0:
                start_axis = self.start_axis + dim_diff
            else:
                start_axis = x.ndim + self.start_axis
            end_axis = self.end_axis + dim_diff if self.end_axis >= 0 else self.end_axis
        return u.math.flatten(x, start_axis, end_axis)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
|
105
|
+
|
106
|
+
|
107
|
+
class Unflatten(Module):
    r"""
    Unflatten a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`axis` specifies the dimension of the input tensor to be unflattened.

    * :attr:`sizes` is the new shape of the unflattened dimension of the tensor; it must be
      a `tuple` or a `list` of ints.

    Shape:
        - Input: :math:`(*, S_{\text{axis}}, *)`, where :math:`S_{\text{axis}}` is the size at
          dimension :attr:`axis` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
          :math:`\prod_{i=1}^n U_i = S_{\text{axis}}`.

    Args:
        axis: int, Dimension to be unflattened.
        sizes: Sequence of int. New shape of the unflattened dimension.
        name: str, optional. The module name.
        in_size: Sequence of int, optional. The shape of a single input sample. When given,
            a non-negative ``axis`` is counted relative to ``in_size`` and shifted past any
            extra leading (batch) dimensions of the input.

    Raises:
        TypeError: If ``sizes`` is not a tuple/list, or contains a non-int element.

    Examples::

        >>> import brainstate
        >>> m = brainstate.nn.Unflatten(1, (2, 5))
        >>> m(brainstate.random.randn(3, 10)).shape
        (3, 2, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        axis: int,
        sizes: Size,
        name: Optional[str] = None,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__(name=name)

        self.axis = axis
        self.sizes = sizes
        if isinstance(sizes, (tuple, list)):
            for idx, elem in enumerate(sizes):
                if not isinstance(elem, int):
                    raise TypeError("unflattened sizes must be tuple of ints, " +
                                    "but found element of type {} at pos {}".format(type(elem).__name__, idx))
        else:
            raise TypeError("unflattened sizes must be tuple or list, but found type {}".format(type(sizes).__name__))

        if in_size is not None:
            self.in_size = tuple(in_size)
            y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
                               jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = y.shape

    def update(self, x):
        """Split dimension ``axis`` of ``x`` into ``sizes``."""
        axis = self.axis
        if self._in_size is not None and axis >= 0:
            # `out_size` was computed relative to `in_size`; keep that meaning
            # for batched inputs by skipping the extra leading dimensions
            # (mirrors `Flatten.update`).
            axis += max(x.ndim - len(self.in_size), 0)
        return u.math.unflatten(x, axis, self.sizes)

    def __repr__(self):
        return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
|
161
|
+
|
162
|
+
|
163
|
+
class _MaxPool(Module):
|
164
|
+
def __init__(
|
165
|
+
self,
|
166
|
+
init_value: float,
|
167
|
+
computation: Callable,
|
168
|
+
pool_dim: int,
|
169
|
+
kernel_size: Size,
|
170
|
+
stride: Union[int, Sequence[int]] = None,
|
171
|
+
padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
|
172
|
+
channel_axis: Optional[int] = -1,
|
173
|
+
name: Optional[str] = None,
|
174
|
+
in_size: Optional[Size] = None,
|
175
|
+
):
|
176
|
+
super().__init__(name=name)
|
177
|
+
|
178
|
+
self.init_value = init_value
|
179
|
+
self.computation = computation
|
180
|
+
self.pool_dim = pool_dim
|
181
|
+
|
182
|
+
# kernel_size
|
183
|
+
if isinstance(kernel_size, int):
|
184
|
+
kernel_size = (kernel_size,) * pool_dim
|
185
|
+
elif isinstance(kernel_size, Sequence):
|
186
|
+
assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
|
187
|
+
assert all(
|
188
|
+
[isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
|
189
|
+
if len(kernel_size) != pool_dim:
|
190
|
+
raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
|
191
|
+
else:
|
192
|
+
raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
|
193
|
+
self.kernel_size = kernel_size
|
194
|
+
|
195
|
+
# stride
|
196
|
+
if stride is None:
|
197
|
+
stride = kernel_size
|
198
|
+
if isinstance(stride, int):
|
199
|
+
stride = (stride,) * pool_dim
|
200
|
+
elif isinstance(stride, Sequence):
|
201
|
+
assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
|
202
|
+
assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
|
203
|
+
if len(stride) != pool_dim:
|
204
|
+
raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
|
205
|
+
else:
|
206
|
+
raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
|
207
|
+
self.stride = stride
|
208
|
+
|
209
|
+
# padding
|
210
|
+
if isinstance(padding, str):
|
211
|
+
if padding not in ("SAME", "VALID"):
|
212
|
+
raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
|
213
|
+
elif isinstance(padding, int):
|
214
|
+
padding = [(padding, padding) for _ in range(pool_dim)]
|
215
|
+
elif isinstance(padding, (list, tuple)):
|
216
|
+
if isinstance(padding[0], int):
|
217
|
+
if len(padding) == pool_dim:
|
218
|
+
padding = [(x, x) for x in padding]
|
219
|
+
else:
|
220
|
+
raise ValueError(f'If padding is a sequence of ints, it '
|
221
|
+
f'should has the length of {pool_dim}.')
|
222
|
+
else:
|
223
|
+
if not all([isinstance(x, (tuple, list)) for x in padding]):
|
224
|
+
raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
|
225
|
+
if not all([len(x) == 2 for x in padding]):
|
226
|
+
raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
|
227
|
+
if len(padding) == 1:
|
228
|
+
padding = tuple(padding) * pool_dim
|
229
|
+
assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
|
230
|
+
else:
|
231
|
+
raise ValueError
|
232
|
+
self.padding = padding
|
233
|
+
|
234
|
+
# channel_axis
|
235
|
+
assert channel_axis is None or isinstance(channel_axis, int), \
|
236
|
+
f'channel_axis should be an int, but got {channel_axis}'
|
237
|
+
self.channel_axis = channel_axis
|
238
|
+
|
239
|
+
# in & out shapes
|
240
|
+
if in_size is not None:
|
241
|
+
in_size = tuple(in_size)
|
242
|
+
self.in_size = in_size
|
243
|
+
y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
|
244
|
+
self.out_size = y.shape[1:]
|
245
|
+
|
246
|
+
def update(self, x):
    """Reduce ``x`` over sliding windows along its spatial dimensions.

    The spatial dimensions are the last ``self.pool_dim`` axes once the
    channel axis (if any) is removed; every other axis uses a window and
    stride of 1 and no padding. The reduction itself is
    ``jax.lax.reduce_window`` with ``self.init_value`` / ``self.computation``.

    Parameters
    ----------
    x: Array
        Input with at least ``pool_dim`` dimensions, plus one more when
        ``channel_axis`` is not ``None``.

    Returns
    -------
    Array
        The pooled array.

    Raises
    ------
    ValueError
        If ``x`` has too few dimensions.
    """
    x_dim = self.pool_dim if self.channel_axis is None else self.pool_dim + 1
    if x.ndim < x_dim:
        raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')

    ndim = x.ndim
    # Expand the per-spatial-dimension settings to one entry per input axis
    # (same call order as before: window, stride, padding).
    window_dims = self._infer_shape(ndim, self.kernel_size, 1)
    window_strides = self._infer_shape(ndim, self.stride, 1)
    if isinstance(self.padding, str):
        pads = self.padding
    else:
        pads = self._infer_shape(ndim, self.padding, element=(0, 0))

    return jax.lax.reduce_window(
        x,
        init_value=self.init_value,
        computation=self.computation,
        window_dimensions=window_dims,
        window_strides=window_strides,
        padding=pads,
    )
|
263
|
+
|
264
|
+
def _infer_shape(self, x_dim, inputs, element):
    """Expand per-spatial-dimension settings into one entry per input axis.

    The spatial axes are the last ``self.pool_dim`` axes of an
    ``x_dim``-dimensional input once the channel axis (if any) is removed.
    ``inputs[i]`` is placed on the ``i``-th spatial axis; every other axis
    (batch, channel, ...) receives ``element``.

    Parameters
    ----------
    x_dim: int
        Number of dimensions of the input.
    inputs: sequence
        One value per spatial dimension (e.g. window size, stride, or a
        ``(low, high)`` padding pair).
    element: Any
        Filler value for the non-spatial axes (e.g. ``1`` for windows and
        strides, ``(0, 0)`` for padding).

    Returns
    -------
    list
        A list of length ``x_dim``.

    Raises
    ------
    ValueError
        If ``self.channel_axis`` is not in ``[-x_dim, x_dim)``.
    """
    channel_axis = self.channel_axis
    if channel_axis is not None:
        # Accept every valid Python axis index, including -x_dim; the old
        # ``abs(axis) < x_dim`` test rejected negative axes like -x_dim.
        if not -x_dim <= channel_axis < x_dim:
            raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
        if channel_axis < 0:
            channel_axis = x_dim + channel_axis
    all_dims = list(range(x_dim))
    if channel_axis is not None:
        all_dims.pop(channel_axis)
    pool_dims = all_dims[-self.pool_dim:]
    results = [element] * x_dim
    for i, dim in enumerate(pool_dims):
        results[dim] = inputs[i]
    return results
|
278
|
+
|
279
|
+
|
280
|
+
class _AvgPool(_MaxPool):
    """Shared N-dimensional average-pooling implementation.

    Argument handling (kernel size, stride, padding, channel axis) is
    inherited from ``_MaxPool``. The concrete subclasses configure the window
    reduction as a sum (``init_value=0.``, ``computation=jax.lax.add``);
    :meth:`update` then divides that sum by the number of real, non-padded
    input elements covered by each window.
    """

    def update(self, x):
        """Average-pool ``x`` over its spatial dimensions.

        Parameters
        ----------
        x: Array
            Input with at least ``pool_dim`` dimensions, plus one more when
            ``channel_axis`` is not ``None``.

        Returns
        -------
        Array
            The pooled array.

        Raises
        ------
        ValueError
            If ``x`` has too few dimensions.
        """
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
        dims = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        padding = (self.padding if isinstance(self.padding, str) else
                   self._infer_shape(x.ndim, self.padding, element=(0, 0)))
        pooled = jax.lax.reduce_window(x,
                                       init_value=self.init_value,
                                       computation=self.computation,
                                       window_dimensions=dims,
                                       window_strides=stride,
                                       padding=padding)

        if isinstance(padding, str):
            no_padding = padding == "VALID"
        else:
            # Explicit padding that is zero everywhere behaves like "VALID".
            no_padding = all(tuple(p) == (0, 0) for p in padding)
        if no_padding:
            # Every window covers exactly prod(dims) real elements, so the
            # extra counting reduce_window is unnecessary.
            return pooled / np.prod(dims)

        # Count the real (non-padded) entries in each window: padding is filled
        # with the init value 0, so summing ones gives the count. This must be a
        # plain sum regardless of the reduction used for ``pooled``.
        window_counts = jax.lax.reduce_window(jnp.ones_like(x),
                                              init_value=0.,
                                              computation=jax.lax.add,
                                              window_dimensions=dims,
                                              window_strides=stride,
                                              padding=padding)
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
|
310
|
+
|
311
|
+
|
312
|
+
class MaxPool1d(_MaxPool):
    r"""Applies a 1D max pooling over an input signal composed of several input planes.

    In the simplest case (no padding), the output value of the layer with input size
    :math:`(N, L, C)` and output :math:`(N, L_{out}, C)` can be precisely described as:

    .. math::
        out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                input(N_i, stride \times k + m, C_j)

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative
    infinity on both sides for :attr:`padding` number of points, so padded positions
    never win the maximum. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where, for a
          symmetric integer :attr:`padding`,

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding}
                    - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Examples::

        >>> import brainstate
        >>> # pool of size=3, stride=2
        >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 16)

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Defaults to ``None``, which is resolved by the ``_MaxPool`` base class.
    padding: str, int, sequence of int, sequence of tuple
        Either the string ``'SAME'``, the string ``'VALID'``, an int (the same padding
        on both sides of every spatial dimension), a sequence of ints (one symmetric
        padding per spatial dimension), or a sequence of ``(low, high)`` integer pairs
        that give the padding to apply before and after each spatial dimension
        (a single pair is broadcast to every spatial dimension). Defaults to ``'VALID'``.
    channel_axis: int, optional
        Axis holding the channels; pooling is never applied along it.
        If ``None``, there is no channel axis. Defaults to ``-1``.
    name: optional, str
        The object name.
    in_size: Sequence of int, optional
        Shape of a single input sample, without the leading batch dimension.
        If given, ``out_size`` is inferred at construction time.

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=1,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
|
386
|
+
|
387
|
+
|
388
|
+
class MaxPool2d(_MaxPool):
    r"""Applies a 2D max pooling over an input signal composed of several input planes.

    In the simplest case (no padding), the output value of the layer with input size
    :math:`(N, H, W, C)`, output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size`
    :math:`(kH, kW)` can be precisely described as:

    .. math::
        \begin{aligned}
            out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{input}(N_i, \text{stride[0]} \times h + m,
                                                   \text{stride[1]} \times w + n, C_j)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative
    infinity on both sides for :attr:`padding` number of points, so padded positions
    never win the maximum. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`,
          where, for symmetric integer :attr:`padding`,

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding[0]}
                    - \text{kernel\_size[0]}}{\text{stride[0]}} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding[1]}
                    - \text{kernel\_size[1]}}{\text{stride[1]}} + 1\right\rfloor

    Examples::

        >>> import brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = MaxPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Defaults to ``None``, which is resolved by the ``_MaxPool`` base class.
    padding: str, int, sequence of int, sequence of tuple
        Either the string ``'SAME'``, the string ``'VALID'``, an int (the same padding
        on both sides of every spatial dimension), a sequence of ints (one symmetric
        padding per spatial dimension), or a sequence of ``(low, high)`` integer pairs
        that give the padding to apply before and after each spatial dimension
        (a single pair is broadcast to every spatial dimension). Defaults to ``'VALID'``.
    channel_axis: int, optional
        Axis holding the channels; pooling is never applied along it.
        If ``None``, there is no channel axis. Defaults to ``-1``.
    name: optional, str
        The object name.
    in_size: Sequence of int, optional
        Shape of a single input sample, without the leading batch dimension.
        If given, ``out_size`` is inferred at construction time.

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
|
473
|
+
|
474
|
+
|
475
|
+
class MaxPool3d(_MaxPool):
    r"""Applies a 3D max pooling over an input signal composed of several input planes.

    In the simplest case (no padding), the output value of the layer with input size
    :math:`(N, D, H, W, C)`, output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and
    :attr:`kernel_size` :math:`(kD, kH, kW)` can be precisely described as:

    .. math::
        \begin{aligned}
            \text{out}(N_i, d, h, w, C_j) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1}
                                                \max_{n=0, \ldots, kW-1} \\
                & \text{input}(N_i, \text{stride[0]} \times d + k,
                               \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative
    infinity on both sides for :attr:`padding` number of points, so padded positions
    never win the maximum. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
          :math:`(D_{out}, H_{out}, W_{out}, C)`, where, for symmetric integer :attr:`padding`,

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0]
                    - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1]
                    - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2]
                    - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

    Examples::

        >>> import brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = MaxPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 43, 15, 16)

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Defaults to ``None``, which is resolved by the ``_MaxPool`` base class.
    padding: str, int, sequence of int, sequence of tuple
        Either the string ``'SAME'``, the string ``'VALID'``, an int (the same padding
        on both sides of every spatial dimension), a sequence of ints (one symmetric
        padding per spatial dimension), or a sequence of ``(low, high)`` integer pairs
        that give the padding to apply before and after each spatial dimension
        (a single pair is broadcast to every spatial dimension). Defaults to ``'VALID'``.
    channel_axis: int, optional
        Axis holding the channels; pooling is never applied along it.
        If ``None``, there is no channel axis. Defaults to ``-1``.
    name: optional, str
        The object name.
    in_size: Sequence of int, optional
        Shape of a single input sample, without the leading batch dimension.
        If given, ``out_size`` is inferred at construction time.

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=3,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
|
564
|
+
|
565
|
+
|
566
|
+
class AvgPool1d(_AvgPool):
    r"""1D average pooling over an input signal made of several input planes.

    Without padding, an input of shape :math:`(N, L, C)` with window size
    :math:`k` yields an output of shape :math:`(N, L_{out}, C)` given by

    .. math::

        \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
                               \text{input}(N_i, \text{stride} \times l + m, C_j)

    When :attr:`padding` is non-zero the input is implicitly zero-padded on both
    sides, and the padded positions are left out of the divisor: each output is
    the mean of the real input elements its window covers.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where, for a
          symmetric integer :attr:`padding`,

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} +
              2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Examples::

        >>> import brainstate
        >>> # pool with window of size=3, stride=2
        >>> m = AvgPool1d(3, stride=2)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> m(input).shape
        (20, 24, 16)

    Parameters
    ----------
    kernel_size: int, sequence of int
        Window size, either one integer or one integer per spatial dimension.
    stride: int, sequence of int
        Inter-window stride, either one integer or one integer per spatial
        dimension (default: ``1``).
    padding: str, int, sequence of int, sequence of tuple
        ``'SAME'``, ``'VALID'`` (default), an int, a sequence of ints, or a
        sequence of ``(low, high)`` integer pairs giving the padding before and
        after each spatial dimension.
    channel_axis: int, optional
        Axis holding the channels, which is never pooled (default: ``-1``).
        ``None`` means the input has no channel axis.
    name: optional, str
        The object name.
    in_size: Sequence of int, optional
        Shape of one input sample without the batch dimension; when given,
        ``out_size`` is computed at construction time.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Sum over each window; _AvgPool.update turns the sum into a mean.
        super().__init__(
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            init_value=0.,
            computation=jax.lax.add,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size,
        )
|
637
|
+
|
638
|
+
|
639
|
+
class AvgPool2d(_AvgPool):
    r"""Applies a 2D average pooling over an input signal composed of several input planes.

    In the simplest case (no padding), the output value of the layer with input size
    :math:`(N, H, W, C)`, output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size`
    :math:`(kH, kW)` can be precisely described as:

    .. math::

        out(N_i, h, w, C_j) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
                               input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)

    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
    for :attr:`padding` number of points. Padded positions are excluded from the divisor,
    so each output is the mean of the real input elements covered by its window.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`,
          where, for symmetric integer :attr:`padding`,

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
                \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
                \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Examples::

        >>> import brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = AvgPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = AvgPool2d((3, 2), stride=(2, 1))
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride
        (default: ``1``).
    padding: str, int, sequence of int, sequence of tuple
        Either the string ``'SAME'``, the string ``'VALID'``, an int (the same padding
        on both sides of every spatial dimension), a sequence of ints (one symmetric
        padding per spatial dimension), or a sequence of ``(low, high)`` integer pairs
        that give the padding to apply before and after each spatial dimension.
        Defaults to ``'VALID'``.
    channel_axis: int, optional
        Axis holding the channels; pooling is never applied along it.
        If ``None``, there is no channel axis. Defaults to ``-1``.
    name: optional, str
        The object name.
    in_size: Sequence of int, optional
        Shape of a single input sample, without the leading batch dimension.
        If given, ``out_size`` is inferred at construction time.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(in_size=in_size,
                         init_value=0.,
                         computation=jax.lax.add,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
|
716
|
+
|
717
|
+
|
718
|
+
class AvgPool3d(_AvgPool):
    r"""3D average pooling over an input signal made of several input planes.

    Without padding, an input of shape :math:`(N, D, H, W, C)` and a window of
    size :math:`(kD, kH, kW)` yield an output of shape
    :math:`(N, D_{out}, H_{out}, W_{out}, C)` given by

    .. math::
        \begin{aligned}
            \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
                                              & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
                                                      \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
                                                     {kD \times kH \times kW}
        \end{aligned}

    When :attr:`padding` is non-zero the input is implicitly zero-padded along all
    three spatial dimensions, and padded positions are left out of the divisor:
    each output is the mean of the real input elements its window covers.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
          :math:`(D_{out}, H_{out}, W_{out}, C)`, where, for symmetric integer :attr:`padding`,

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
                    \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
                    \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
                    \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

    Examples::

        >>> import brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = AvgPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 43, 15, 16)

    Parameters
    ----------
    kernel_size: int, sequence of int
        Window size, either one integer or one integer per spatial dimension.
    stride: int, sequence of int
        Inter-window stride, either one integer or one integer per spatial
        dimension (default: ``1``).
    padding: str, int, sequence of int, sequence of tuple
        ``'SAME'``, ``'VALID'`` (default), an int, a sequence of ints, or a
        sequence of ``(low, high)`` integer pairs giving the padding before and
        after each spatial dimension.
    channel_axis: int, optional
        Axis holding the channels, which is never pooled (default: ``-1``).
        ``None`` means the input has no channel axis.
    name: optional, str
        The object name.
    in_size: Sequence of int, optional
        Shape of one input sample without the batch dimension; when given,
        ``out_size`` is computed at construction time.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Sum over each window; _AvgPool.update turns the sum into a mean.
        super().__init__(
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            init_value=0.,
            computation=jax.lax.add,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size,
        )
|
805
|
+
|
806
|
+
|
807
|
+
def _adaptive_pool1d(x, target_size: int, operation: Callable):
    """Adaptively pool a 1-D array down to ``target_size`` bins.

    The ``n`` input elements are split into ``target_size`` contiguous,
    non-overlapping bins. With ``q = n // target_size`` and
    ``r = n % target_size``, the first ``r`` bins hold ``q + 1`` elements and
    the remaining ``target_size - r`` bins hold ``q`` elements. ``operation``
    is applied to every bin.

    Args:
      x: The input. Should be a JAX array of shape `(dim,)`.
      target_size: The number of output bins, or ``None`` to return ``x``
        unchanged (i.e. do not pool this dimension).
      operation: Reduction applied to each bin (e.g. ``jnp.mean``); it is
        vmapped over rows of a 2-D array and must reduce a 1-D array to a scalar.

    Returns:
      A JAX array of shape `(target_size,)`, or ``x`` itself when
      ``target_size`` is ``None``.

    Raises:
      ValueError: If ``target_size < 1`` or ``x`` has fewer than
        ``target_size`` elements.
    """
    if target_size is None:
        return x
    if target_size < 1:
        raise ValueError(f"target_size must be a positive integer, but got {target_size}.")
    size = jnp.size(x)
    if size < target_size:
        # Every bin needs at least one element; otherwise reshape would fail
        # with an obscure error.
        raise ValueError(f"Cannot adaptively pool {size} elements into "
                         f"{target_size} bins: the input is shorter than the target.")
    num_head_arrays = size % target_size
    num_block = size // target_size
    if num_head_arrays == 0:
        return jax.vmap(operation)(x.reshape(target_size, num_block))
    # The first ``num_head_arrays`` bins absorb the remainder, one extra element each.
    head_end_index = num_head_arrays * (num_block + 1)
    heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, num_block + 1))
    tails = jax.vmap(operation)(x[head_end_index:].reshape(target_size - num_head_arrays, num_block))
    return jnp.concatenate([heads, tails])
|
829
|
+
|
830
|
+
|
831
|
+
def _generate_vmap(fun: Callable, map_axes: List[int]):
    """Wrap ``fun`` in nested ``jax.vmap`` calls, one per axis in ``map_axes``.

    Only the first argument of ``fun`` is mapped; the second and third are
    passed through unmapped. Axes are wrapped in ascending order, so the
    outermost vmap handles the largest axis and the lower axis indices seen by
    the inner wrappers are unaffected. Each mapped axis keeps its position in
    the output.
    """
    wrapped = fun
    for ax in sorted(map_axes):
        wrapped = jax.vmap(wrapped, in_axes=(ax, None, None), out_axes=ax)
    return wrapped
|
836
|
+
|
837
|
+
|
838
|
+
class _AdaptivePool(Module):
    """General N dimensional adaptive down-sampling to a target shape.

    Parameters
    ----------
    in_size: Sequence of int
        Shape of a single input sample, without the leading batch dimension.
        If given, ``out_size`` is inferred by tracing :meth:`update` on a dummy batch.
    target_size: int, sequence of int
        The target output shape. An int is repeated for every spatial dimension.
        A ``None`` entry in a sequence leaves the corresponding dimension unpooled.
    num_spatial_dims: int
        The number of spatial dimensions.
    channel_axis: int, optional
        Axis holding the channels; pooling is never applied along it.
        If ``None``, there is no channel axis.
    operation: Callable
        The down-sampling operation applied to each bin (e.g. ``jnp.mean``).
    name: str
        The class name.
    """

    def __init__(
        self,
        in_size: Size,
        target_size: Size,
        num_spatial_dims: int,
        operation: Callable,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        self.channel_axis = channel_axis
        self.operation = operation
        if isinstance(target_size, int):
            self.target_shape = (target_size,) * num_spatial_dims
        elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
            self.target_shape = target_size
        else:
            raise ValueError("`target_size` must either be an int or tuple of length "
                             f"{num_spatial_dims} containing ints.")

        # in & out shapes
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            # Trace ``update`` on an abstract batch to infer the per-sample output shape.
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            self.out_size = y.shape[1:]

    def update(self, x):
        """Input-output mapping.

        Parameters
        ----------
        x: Array
            Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
            or `(..., dim_1, dim_2)`. The pooled dimensions are the last
            ``len(target_shape)`` axes once the channel axis is removed.

        Returns
        -------
        Array
            ``x`` with each pooled dimension reduced to its target size.

        Raises
        ------
        ValueError
            If ``channel_axis`` is out of range for ``x`` or ``x`` has too few dimensions.
        """
        # channel axis
        channel_axis = self.channel_axis
        if channel_axis is not None:
            # Accept every valid Python axis index, including -x.ndim.
            if not -x.ndim <= channel_axis < x.ndim:
                raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
            if channel_axis < 0:
                channel_axis = x.ndim + channel_axis
        # input dimension
        if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
            raise ValueError(f"Invalid input dimension. Except >={len(self.target_shape)} "
                             f"dimensions (channel_axis={self.channel_axis}). "
                             f"But got {x.ndim} dimensions.")
        # pooling dimensions
        pool_dims = list(range(x.ndim))
        if channel_axis is not None:
            pool_dims.pop(channel_axis)

        # pooling: reduce one dimension at a time, vmapping the 1-D kernel over all others
        for target, dim in zip(self.target_shape, pool_dims[-len(self.target_shape):]):
            if target is None:
                # ``None`` means "keep this dimension as is".
                continue
            batch_axes = [j for j in range(x.ndim) if j != dim]
            op = _generate_vmap(_adaptive_pool1d, batch_axes)
            x = op(x, target, self.operation)
        return x
|
919
|
+
|
920
|
+
|
921
|
+
class AdaptiveAvgPool1d(_AdaptivePool):
    r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.

    The output size is :math:`L_{out}` (``target_size``), for any input size at least
    that long. The pooled dimension is split into :math:`L_{out}` contiguous,
    non-overlapping bins (the first :math:`L_{in} \bmod L_{out}` bins hold one extra
    element) and each bin is averaged with ``jnp.mean``.
    The number of output features is equal to the number of input planes.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
          :math:`L_{out}=\text{target\_size}`.

    Examples:

        >>> import brainstate
        >>> # target output size of 5
        >>> m = AdaptiveAvgPool1d(5)
        >>> input = brainstate.random.randn(1, 64, 8)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 8)

    Parameters
    ----------
    target_size: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis holding the channels; pooling is never applied along it.
        If ``None``, there is no channel axis. Defaults to ``-1``.
    name: str
        The class name.
    in_size: Sequence of int, optional
        Shape of a single input sample, without the leading batch dimension.
        If given, ``out_size`` is inferred at construction time.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None, ):
        super().__init__(in_size=in_size,
                         target_size=target_size,
                         channel_axis=channel_axis,
                         num_spatial_dims=1,
                         operation=jnp.mean,
                         name=name)
|
967
|
+
|
968
|
+
|
969
|
+
class AdaptiveAvgPool2d(_AdaptivePool):
    r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.

    The output is of size :math:`H_{out} \times W_{out}` (``target_size``), for any input
    size at least that large. Each pooled dimension is split into contiguous,
    non-overlapping bins (the first :math:`D_{in} \bmod D_{out}` bins hold one extra
    element) and each bin is averaged with ``jnp.mean``.
    The number of output features is equal to the number of input planes.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
          :math:`(H_{out}, W_{out})=\text{target\_size}`.

    Examples:

        >>> import brainstate
        >>> # target output size of 5x7
        >>> m = AdaptiveAvgPool2d((5, 7))
        >>> input = brainstate.random.randn(1, 8, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 5, 7, 64)
        >>> # target output size of 7x7 (square)
        >>> m = AdaptiveAvgPool2d(7)
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 7, 7, 64)
        >>> # target output size of 10x7
        >>> m = AdaptiveAvgPool2d((None, 7))
        >>> input = brainstate.random.randn(1, 10, 9, 64)
        >>> output = m(input)
        >>> output.shape
        (1, 10, 7, 64)

    Parameters
    ----------
    target_size: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis holding the channels; pooling is never applied along it.
        If ``None``, there is no channel axis. Defaults to ``-1``.
    name: str
        The class name.
    in_size: Sequence of int, optional
        Shape of a single input sample, without the leading batch dimension.
        If given, ``out_size`` is inferred at construction time.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None, ):
        super().__init__(in_size=in_size,
                         target_size=target_size,
                         channel_axis=channel_axis,
                         num_spatial_dims=2,
                         operation=jnp.mean,
                         name=name)
|
1028
|
+
|
1029
|
+
|
1030
|
+
class AdaptiveAvgPool3d(_AdaptivePool):
    r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.

    The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.
    Each pooling region is reduced with ``jnp.mean``.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
          where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.

    Examples:

        >>> import brainstate
        >>> # target output size of 5x7x9
        >>> m = brainstate.nn.AdaptiveAvgPool3d((5, 7, 9))
        >>> x = brainstate.random.randn(1, 8, 9, 10, 64)
        >>> output = m(x)
        >>> output.shape
        (1, 5, 7, 9, 64)
        >>> # target output size of 7x7x7 (cube)
        >>> m = brainstate.nn.AdaptiveAvgPool3d(7)
        >>> x = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(x)
        >>> output.shape
        (1, 7, 7, 7, 64)
        >>> # target output size of 7x9x8
        >>> m = brainstate.nn.AdaptiveAvgPool3d((7, None, None))
        >>> x = brainstate.random.randn(1, 10, 9, 8, 64)
        >>> output = m(x)
        >>> output.shape
        (1, 7, 9, 8, 64)

    Parameters
    ----------
    target_size: int, sequence of int
        The target output shape of the three spatial dimensions.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str, optional
        The name of this module, forwarded to the base class.
    in_size: Sequence of int, optional
        The shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None, ):
        super().__init__(in_size=in_size,
                         target_size=target_size,
                         channel_axis=channel_axis,
                         num_spatial_dims=3,
                         # average pooling: each adaptive window is reduced by its mean
                         operation=jnp.mean,
                         name=name)
|
1088
|
+
|
1089
|
+
|
1090
|
+
class AdaptiveMaxPool1d(_AdaptivePool):
    """Adaptive max pooling over one spatial dimension.

    The single spatial dimension of the input is pooled down to ``target_size``
    regardless of its original length; each pooling region is reduced with
    ``jnp.max``. The channel axis (if any) is left untouched.

    Parameters
    ----------
    target_size: int, sequence of int
        Desired size of the pooled spatial dimension.
    channel_axis: int, optional
        Axis holding the channels, which is excluded from pooling.
        Pass ``None`` when the input has no channel axis. Defaults to ``-1``.
    name: str, optional
        Name of this module, forwarded to the base class.
    in_size: Sequence of int, optional
        Shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None):
        # Delegate everything to the shared adaptive-pooling base, fixing
        # the number of spatial dims and the max reduction for this variant.
        super().__init__(
            target_size=target_size,
            num_spatial_dims=1,
            operation=jnp.max,
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|
1118
|
+
|
1119
|
+
|
1120
|
+
class AdaptiveMaxPool2d(_AdaptivePool):
    """Adaptive max pooling over two spatial dimensions.

    Both spatial dimensions of the input are pooled down to ``target_size``
    regardless of their original lengths; each pooling region is reduced with
    ``jnp.max``. The channel axis (if any) is left untouched.

    Parameters
    ----------
    target_size: int, sequence of int
        Desired size of the pooled spatial dimensions.
    channel_axis: int, optional
        Axis holding the channels, which is excluded from pooling.
        Pass ``None`` when the input has no channel axis. Defaults to ``-1``.
    name: str, optional
        Name of this module, forwarded to the base class.
    in_size: Sequence of int, optional
        Shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None):
        # Delegate everything to the shared adaptive-pooling base, fixing
        # the number of spatial dims and the max reduction for this variant.
        super().__init__(
            target_size=target_size,
            num_spatial_dims=2,
            operation=jnp.max,
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|
1148
|
+
|
1149
|
+
|
1150
|
+
class AdaptiveMaxPool3d(_AdaptivePool):
    """Adaptive max pooling over three spatial dimensions.

    All three spatial dimensions of the input are pooled down to ``target_size``
    regardless of their original lengths; each pooling region is reduced with
    ``jnp.max``. The channel axis (if any) is left untouched.

    Parameters
    ----------
    target_size: int, sequence of int
        Desired size of the pooled spatial dimensions.
    channel_axis: int, optional
        Axis holding the channels, which is excluded from pooling.
        Pass ``None`` when the input has no channel axis. Defaults to ``-1``.
    name: str, optional
        Name of this module, forwarded to the base class.
    in_size: Sequence of int, optional
        Shape of the input tensor.
    """
    __module__ = 'brainstate.nn'

    def __init__(self,
                 target_size: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 in_size: Optional[Sequence[int]] = None):
        # Delegate everything to the shared adaptive-pooling base, fixing
        # the number of spatial dims and the max reduction for this variant.
        super().__init__(
            target_size=target_size,
            num_spatial_dims=3,
            operation=jnp.max,
            channel_axis=channel_axis,
            in_size=in_size,
            name=name,
        )
|