brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_poolings.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -33,6 +33,8 @@ __all__ = [
|
|
33
33
|
|
34
34
|
'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
|
35
35
|
'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
|
36
|
+
'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
|
37
|
+
'LPPool1d', 'LPPool2d', 'LPPool3d',
|
36
38
|
|
37
39
|
'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
|
38
40
|
'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
|
@@ -49,13 +51,20 @@ class Flatten(Module):
|
|
49
51
|
number of dimensions including none.
|
50
52
|
- Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
|
51
53
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
54
|
+
Parameters
|
55
|
+
----------
|
56
|
+
start_axis : int, optional
|
57
|
+
First dim to flatten (default = 0).
|
58
|
+
end_axis : int, optional
|
59
|
+
Last dim to flatten (default = -1).
|
60
|
+
in_size : Sequence of int, optional
|
61
|
+
The shape of the input tensor.
|
62
|
+
|
63
|
+
Examples
|
64
|
+
--------
|
65
|
+
.. code-block:: python
|
66
|
+
|
67
|
+
>>> import brainstate
|
59
68
|
>>> inp = brainstate.random.randn(32, 1, 5, 5)
|
60
69
|
>>> # With default parameters
|
61
70
|
>>> m = Flatten()
|
@@ -100,9 +109,6 @@ class Flatten(Module):
|
|
100
109
|
start_axis = x.ndim + self.start_axis
|
101
110
|
return u.math.flatten(x, start_axis, self.end_axis)
|
102
111
|
|
103
|
-
def __repr__(self) -> str:
|
104
|
-
return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
|
105
|
-
|
106
112
|
|
107
113
|
class Unflatten(Module):
|
108
114
|
r"""
|
@@ -121,10 +127,16 @@ class Unflatten(Module):
|
|
121
127
|
- Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
|
122
128
|
:math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
|
123
129
|
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
130
|
+
Parameters
|
131
|
+
----------
|
132
|
+
axis : int
|
133
|
+
Dimension to be unflattened.
|
134
|
+
sizes : Sequence of int
|
135
|
+
New shape of the unflattened dimension.
|
136
|
+
name : str, optional
|
137
|
+
The name of the module.
|
138
|
+
in_size : Sequence of int, optional
|
139
|
+
The shape of the input tensor.
|
128
140
|
"""
|
129
141
|
__module__ = 'brainstate.nn'
|
130
142
|
|
@@ -156,9 +168,6 @@ class Unflatten(Module):
|
|
156
168
|
def update(self, x):
|
157
169
|
return u.math.unflatten(x, self.axis, self.sizes)
|
158
170
|
|
159
|
-
def __repr__(self):
|
160
|
-
return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
|
161
|
-
|
162
171
|
|
163
172
|
class _MaxPool(Module):
|
164
173
|
def __init__(
|
@@ -170,6 +179,7 @@ class _MaxPool(Module):
|
|
170
179
|
stride: Union[int, Sequence[int]] = None,
|
171
180
|
padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
|
172
181
|
channel_axis: Optional[int] = -1,
|
182
|
+
return_indices: bool = False,
|
173
183
|
name: Optional[str] = None,
|
174
184
|
in_size: Optional[Size] = None,
|
175
185
|
):
|
@@ -178,6 +188,7 @@ class _MaxPool(Module):
|
|
178
188
|
self.init_value = init_value
|
179
189
|
self.computation = computation
|
180
190
|
self.pool_dim = pool_dim
|
191
|
+
self.return_indices = return_indices
|
181
192
|
|
182
193
|
# kernel_size
|
183
194
|
if isinstance(kernel_size, int):
|
@@ -247,11 +258,18 @@ class _MaxPool(Module):
|
|
247
258
|
x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
|
248
259
|
if x.ndim < x_dim:
|
249
260
|
raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
|
250
|
-
window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
|
251
|
-
stride = self._infer_shape(x.ndim, self.stride, 1)
|
252
|
-
|
253
|
-
|
254
|
-
|
261
|
+
window_shape = tuple(self._infer_shape(x.ndim, self.kernel_size, 1))
|
262
|
+
stride = tuple(self._infer_shape(x.ndim, self.stride, 1))
|
263
|
+
if isinstance(self.padding, str):
|
264
|
+
padding = tuple(jax.lax.padtype_to_pads(x.shape, window_shape, stride, self.padding))
|
265
|
+
else:
|
266
|
+
padding = tuple(self._infer_shape(x.ndim, self.padding, element=(0, 0)))
|
267
|
+
|
268
|
+
if self.return_indices:
|
269
|
+
# For returning indices, we need to use a custom implementation
|
270
|
+
return self._pooling_with_indices(x, window_shape, stride, padding)
|
271
|
+
|
272
|
+
return jax.lax.reduce_window(
|
255
273
|
x,
|
256
274
|
init_value=self.init_value,
|
257
275
|
computation=self.computation,
|
@@ -259,7 +277,39 @@ class _MaxPool(Module):
|
|
259
277
|
window_strides=stride,
|
260
278
|
padding=padding
|
261
279
|
)
|
262
|
-
|
280
|
+
|
281
|
+
def _pooling_with_indices(self, x, window_shape, stride, padding):
|
282
|
+
"""Perform max pooling and return both pooled values and indices."""
|
283
|
+
total_size = x.size
|
284
|
+
flat_indices = jnp.arange(total_size, dtype=jnp.int32).reshape(x.shape)
|
285
|
+
|
286
|
+
init_val = jnp.asarray(self.init_value, dtype=x.dtype)
|
287
|
+
init_idx = jnp.array(total_size, dtype=flat_indices.dtype)
|
288
|
+
|
289
|
+
def reducer(acc, operand):
|
290
|
+
acc_val, acc_idx = acc
|
291
|
+
cur_val, cur_idx = operand
|
292
|
+
|
293
|
+
better = cur_val > acc_val
|
294
|
+
best_val = jnp.where(better, cur_val, acc_val)
|
295
|
+
best_idx = jnp.where(better, cur_idx, acc_idx)
|
296
|
+
tie = jnp.logical_and(cur_val == acc_val, cur_idx < acc_idx)
|
297
|
+
best_idx = jnp.where(tie, cur_idx, best_idx)
|
298
|
+
|
299
|
+
return best_val, best_idx
|
300
|
+
|
301
|
+
pooled, indices_result = jax.lax.reduce_window(
|
302
|
+
(x, flat_indices),
|
303
|
+
(init_val, init_idx),
|
304
|
+
reducer,
|
305
|
+
window_dimensions=window_shape,
|
306
|
+
window_strides=stride,
|
307
|
+
padding=padding
|
308
|
+
)
|
309
|
+
|
310
|
+
indices_result = jnp.where(indices_result == total_size, 0, indices_result)
|
311
|
+
|
312
|
+
return pooled, indices_result.astype(jnp.int32)
|
263
313
|
|
264
314
|
def _infer_shape(self, x_dim, inputs, element):
|
265
315
|
channel_axis = self.channel_axis
|
@@ -331,10 +381,33 @@ class MaxPool1d(_MaxPool):
|
|
331
381
|
L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
|
332
382
|
\times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
|
333
383
|
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
384
|
+
Parameters
|
385
|
+
----------
|
386
|
+
kernel_size : int or sequence of int
|
387
|
+
An integer, or a sequence of integers defining the window to reduce over.
|
388
|
+
stride : int or sequence of int, optional
|
389
|
+
An integer, or a sequence of integers, representing the inter-window stride.
|
390
|
+
Default: kernel_size
|
391
|
+
padding : str, int or sequence of tuple, optional
|
392
|
+
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
393
|
+
of n `(low, high)` integer pairs that give the padding to apply before
|
394
|
+
and after each spatial dimension. Default: 'VALID'
|
395
|
+
channel_axis : int, optional
|
396
|
+
Axis of the spatial channels for which pooling is skipped.
|
397
|
+
If ``None``, there is no channel axis. Default: -1
|
398
|
+
return_indices : bool, optional
|
399
|
+
If True, will return the max indices along with the outputs.
|
400
|
+
Useful for MaxUnpool1d. Default: False
|
401
|
+
name : str, optional
|
402
|
+
The object name.
|
403
|
+
in_size : Sequence of int, optional
|
404
|
+
The shape of the input tensor.
|
405
|
+
|
406
|
+
Examples
|
407
|
+
--------
|
408
|
+
.. code-block:: python
|
409
|
+
|
410
|
+
>>> import brainstate
|
338
411
|
>>> # pool of size=3, stride=2
|
339
412
|
>>> m = MaxPool1d(3, stride=2, channel_axis=-1)
|
340
413
|
>>> input = brainstate.random.randn(20, 50, 16)
|
@@ -342,24 +415,6 @@ class MaxPool1d(_MaxPool):
|
|
342
415
|
>>> output.shape
|
343
416
|
(20, 24, 16)
|
344
417
|
|
345
|
-
Parameters
|
346
|
-
----------
|
347
|
-
in_size: Sequence of int
|
348
|
-
The shape of the input tensor.
|
349
|
-
kernel_size: int, sequence of int
|
350
|
-
An integer, or a sequence of integers defining the window to reduce over.
|
351
|
-
stride: int, sequence of int
|
352
|
-
An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
|
353
|
-
padding: str, int, sequence of tuple
|
354
|
-
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
355
|
-
of n `(low, high)` integer pairs that give the padding to apply before
|
356
|
-
and after each spatial dimension.
|
357
|
-
channel_axis: int, optional
|
358
|
-
Axis of the spatial channels for which pooling is skipped.
|
359
|
-
If ``None``, there is no channel axis.
|
360
|
-
name: optional, str
|
361
|
-
The object name.
|
362
|
-
|
363
418
|
.. _link:
|
364
419
|
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
365
420
|
"""
|
@@ -371,6 +426,7 @@ class MaxPool1d(_MaxPool):
|
|
371
426
|
stride: Union[int, Sequence[int]] = None,
|
372
427
|
padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
|
373
428
|
channel_axis: Optional[int] = -1,
|
429
|
+
return_indices: bool = False,
|
374
430
|
name: Optional[str] = None,
|
375
431
|
in_size: Optional[Size] = None,
|
376
432
|
):
|
@@ -382,6 +438,7 @@ class MaxPool1d(_MaxPool):
|
|
382
438
|
stride=stride,
|
383
439
|
padding=padding,
|
384
440
|
channel_axis=channel_axis,
|
441
|
+
return_indices=return_indices,
|
385
442
|
name=name)
|
386
443
|
|
387
444
|
|
@@ -403,7 +460,6 @@ class MaxPool2d(_MaxPool):
|
|
403
460
|
for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
|
404
461
|
It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
|
405
462
|
|
406
|
-
|
407
463
|
Shape:
|
408
464
|
- Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
|
409
465
|
- Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
|
@@ -416,9 +472,33 @@ class MaxPool2d(_MaxPool):
|
|
416
472
|
W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
|
417
473
|
\times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
|
418
474
|
|
419
|
-
|
420
|
-
|
421
|
-
|
475
|
+
Parameters
|
476
|
+
----------
|
477
|
+
kernel_size : int or sequence of int
|
478
|
+
An integer, or a sequence of integers defining the window to reduce over.
|
479
|
+
stride : int or sequence of int, optional
|
480
|
+
An integer, or a sequence of integers, representing the inter-window stride.
|
481
|
+
Default: kernel_size
|
482
|
+
padding : str, int or sequence of tuple, optional
|
483
|
+
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
484
|
+
of n `(low, high)` integer pairs that give the padding to apply before
|
485
|
+
and after each spatial dimension. Default: 'VALID'
|
486
|
+
channel_axis : int, optional
|
487
|
+
Axis of the spatial channels for which pooling is skipped.
|
488
|
+
If ``None``, there is no channel axis. Default: -1
|
489
|
+
return_indices : bool, optional
|
490
|
+
If True, will return the max indices along with the outputs.
|
491
|
+
Useful for MaxUnpool2d. Default: False
|
492
|
+
name : str, optional
|
493
|
+
The object name.
|
494
|
+
in_size : Sequence of int, optional
|
495
|
+
The shape of the input tensor.
|
496
|
+
|
497
|
+
Examples
|
498
|
+
--------
|
499
|
+
.. code-block:: python
|
500
|
+
|
501
|
+
>>> import brainstate
|
422
502
|
>>> # pool of square window of size=3, stride=2
|
423
503
|
>>> m = MaxPool2d(3, stride=2)
|
424
504
|
>>> # pool of non-square window
|
@@ -430,25 +510,6 @@ class MaxPool2d(_MaxPool):
|
|
430
510
|
|
431
511
|
.. _link:
|
432
512
|
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
433
|
-
|
434
|
-
Parameters
|
435
|
-
----------
|
436
|
-
in_size: Sequence of int
|
437
|
-
The shape of the input tensor.
|
438
|
-
kernel_size: int, sequence of int
|
439
|
-
An integer, or a sequence of integers defining the window to reduce over.
|
440
|
-
stride: int, sequence of int
|
441
|
-
An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
|
442
|
-
padding: str, int, sequence of tuple
|
443
|
-
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
444
|
-
of n `(low, high)` integer pairs that give the padding to apply before
|
445
|
-
and after each spatial dimension.
|
446
|
-
channel_axis: int, optional
|
447
|
-
Axis of the spatial channels for which pooling is skipped.
|
448
|
-
If ``None``, there is no channel axis.
|
449
|
-
name: optional, str
|
450
|
-
The object name.
|
451
|
-
|
452
513
|
"""
|
453
514
|
__module__ = 'brainstate.nn'
|
454
515
|
|
@@ -458,6 +519,7 @@ class MaxPool2d(_MaxPool):
|
|
458
519
|
stride: Union[int, Sequence[int]] = None,
|
459
520
|
padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
|
460
521
|
channel_axis: Optional[int] = -1,
|
522
|
+
return_indices: bool = False,
|
461
523
|
name: Optional[str] = None,
|
462
524
|
in_size: Optional[Size] = None,
|
463
525
|
):
|
@@ -469,6 +531,7 @@ class MaxPool2d(_MaxPool):
|
|
469
531
|
stride=stride,
|
470
532
|
padding=padding,
|
471
533
|
channel_axis=channel_axis,
|
534
|
+
return_indices=return_indices,
|
472
535
|
name=name)
|
473
536
|
|
474
537
|
|
@@ -490,7 +553,6 @@ class MaxPool3d(_MaxPool):
|
|
490
553
|
for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
|
491
554
|
It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
|
492
555
|
|
493
|
-
|
494
556
|
Shape:
|
495
557
|
- Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
|
496
558
|
- Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
|
@@ -507,9 +569,33 @@ class MaxPool3d(_MaxPool):
|
|
507
569
|
W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
|
508
570
|
(\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
|
509
571
|
|
510
|
-
|
511
|
-
|
512
|
-
|
572
|
+
Parameters
|
573
|
+
----------
|
574
|
+
kernel_size : int or sequence of int
|
575
|
+
An integer, or a sequence of integers defining the window to reduce over.
|
576
|
+
stride : int or sequence of int, optional
|
577
|
+
An integer, or a sequence of integers, representing the inter-window stride.
|
578
|
+
Default: kernel_size
|
579
|
+
padding : str, int or sequence of tuple, optional
|
580
|
+
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
581
|
+
of n `(low, high)` integer pairs that give the padding to apply before
|
582
|
+
and after each spatial dimension. Default: 'VALID'
|
583
|
+
channel_axis : int, optional
|
584
|
+
Axis of the spatial channels for which pooling is skipped.
|
585
|
+
If ``None``, there is no channel axis. Default: -1
|
586
|
+
return_indices : bool, optional
|
587
|
+
If True, will return the max indices along with the outputs.
|
588
|
+
Useful for MaxUnpool3d. Default: False
|
589
|
+
name : str, optional
|
590
|
+
The object name.
|
591
|
+
in_size : Sequence of int, optional
|
592
|
+
The shape of the input tensor.
|
593
|
+
|
594
|
+
Examples
|
595
|
+
--------
|
596
|
+
.. code-block:: python
|
597
|
+
|
598
|
+
>>> import brainstate
|
513
599
|
>>> # pool of square window of size=3, stride=2
|
514
600
|
>>> m = MaxPool3d(3, stride=2)
|
515
601
|
>>> # pool of non-square window
|
@@ -521,25 +607,6 @@ class MaxPool3d(_MaxPool):
|
|
521
607
|
|
522
608
|
.. _link:
|
523
609
|
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
524
|
-
|
525
|
-
Parameters
|
526
|
-
----------
|
527
|
-
in_size: Sequence of int
|
528
|
-
The shape of the input tensor.
|
529
|
-
kernel_size: int, sequence of int
|
530
|
-
An integer, or a sequence of integers defining the window to reduce over.
|
531
|
-
stride: int, sequence of int
|
532
|
-
An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
|
533
|
-
padding: str, int, sequence of tuple
|
534
|
-
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
535
|
-
of n `(low, high)` integer pairs that give the padding to apply before
|
536
|
-
and after each spatial dimension.
|
537
|
-
channel_axis: int, optional
|
538
|
-
Axis of the spatial channels for which pooling is skipped.
|
539
|
-
If ``None``, there is no channel axis.
|
540
|
-
name: optional, str
|
541
|
-
The object name.
|
542
|
-
|
543
610
|
"""
|
544
611
|
__module__ = 'brainstate.nn'
|
545
612
|
|
@@ -549,6 +616,7 @@ class MaxPool3d(_MaxPool):
|
|
549
616
|
stride: Union[int, Sequence[int]] = None,
|
550
617
|
padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
|
551
618
|
channel_axis: Optional[int] = -1,
|
619
|
+
return_indices: bool = False,
|
552
620
|
name: Optional[str] = None,
|
553
621
|
in_size: Optional[Size] = None,
|
554
622
|
):
|
@@ -560,9 +628,432 @@ class MaxPool3d(_MaxPool):
|
|
560
628
|
stride=stride,
|
561
629
|
padding=padding,
|
562
630
|
channel_axis=channel_axis,
|
631
|
+
return_indices=return_indices,
|
563
632
|
name=name)
|
564
633
|
|
565
634
|
|
635
|
+
class _MaxUnpool(Module):
|
636
|
+
"""Base class for max unpooling operations."""
|
637
|
+
|
638
|
+
def __init__(
|
639
|
+
self,
|
640
|
+
pool_dim: int,
|
641
|
+
kernel_size: Size,
|
642
|
+
stride: Union[int, Sequence[int]] = None,
|
643
|
+
padding: Union[int, Tuple[int, ...]] = 0,
|
644
|
+
channel_axis: Optional[int] = -1,
|
645
|
+
name: Optional[str] = None,
|
646
|
+
in_size: Optional[Size] = None,
|
647
|
+
):
|
648
|
+
super().__init__(name=name)
|
649
|
+
|
650
|
+
self.pool_dim = pool_dim
|
651
|
+
|
652
|
+
# kernel_size
|
653
|
+
if isinstance(kernel_size, int):
|
654
|
+
kernel_size = (kernel_size,) * pool_dim
|
655
|
+
elif isinstance(kernel_size, Sequence):
|
656
|
+
assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
|
657
|
+
assert all(
|
658
|
+
[isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
|
659
|
+
if len(kernel_size) != pool_dim:
|
660
|
+
raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
|
661
|
+
else:
|
662
|
+
raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
|
663
|
+
self.kernel_size = kernel_size
|
664
|
+
|
665
|
+
# stride
|
666
|
+
if stride is None:
|
667
|
+
stride = kernel_size
|
668
|
+
if isinstance(stride, int):
|
669
|
+
stride = (stride,) * pool_dim
|
670
|
+
elif isinstance(stride, Sequence):
|
671
|
+
assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
|
672
|
+
assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
|
673
|
+
if len(stride) != pool_dim:
|
674
|
+
raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
|
675
|
+
else:
|
676
|
+
raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
|
677
|
+
self.stride = stride
|
678
|
+
|
679
|
+
# padding
|
680
|
+
if isinstance(padding, int):
|
681
|
+
padding = (padding,) * pool_dim
|
682
|
+
elif isinstance(padding, (tuple, list)):
|
683
|
+
if len(padding) != pool_dim:
|
684
|
+
raise ValueError(f'padding should have {pool_dim} values, but got {len(padding)}')
|
685
|
+
else:
|
686
|
+
raise TypeError(f'padding should be int or tuple of {pool_dim} ints.')
|
687
|
+
self.padding = padding
|
688
|
+
|
689
|
+
# channel_axis
|
690
|
+
assert channel_axis is None or isinstance(channel_axis, int), \
|
691
|
+
f'channel_axis should be an int, but got {channel_axis}'
|
692
|
+
self.channel_axis = channel_axis
|
693
|
+
|
694
|
+
# in & out shapes
|
695
|
+
if in_size is not None:
|
696
|
+
in_size = tuple(in_size)
|
697
|
+
self.in_size = in_size
|
698
|
+
|
699
|
+
def _compute_output_shape(self, input_shape, output_size=None):
|
700
|
+
"""Compute the output shape after unpooling."""
|
701
|
+
if output_size is not None:
|
702
|
+
return output_size
|
703
|
+
|
704
|
+
# Calculate output shape based on kernel, stride, and padding
|
705
|
+
output_shape = []
|
706
|
+
for i in range(self.pool_dim):
|
707
|
+
dim_size = (input_shape[i] - 1) * self.stride[i] - 2 * self.padding[i] + self.kernel_size[i]
|
708
|
+
output_shape.append(dim_size)
|
709
|
+
|
710
|
+
return tuple(output_shape)
|
711
|
+
|
712
|
+
def _unpool_nd(self, x, indices, output_size=None):
|
713
|
+
"""Perform N-dimensional max unpooling."""
|
714
|
+
x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
|
715
|
+
if x.ndim < x_dim:
|
716
|
+
raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
|
717
|
+
|
718
|
+
# Determine output shape
|
719
|
+
if output_size is None:
|
720
|
+
# Infer output shape from input shape
|
721
|
+
spatial_dims = self._get_spatial_dims(x.shape)
|
722
|
+
output_spatial_shape = self._compute_output_shape(spatial_dims, output_size)
|
723
|
+
output_shape = list(x.shape)
|
724
|
+
|
725
|
+
# Update spatial dimensions in output shape
|
726
|
+
spatial_start = self._get_spatial_start_idx(x.ndim)
|
727
|
+
for i, size in enumerate(output_spatial_shape):
|
728
|
+
output_shape[spatial_start + i] = size
|
729
|
+
output_shape = tuple(output_shape)
|
730
|
+
else:
|
731
|
+
# Use provided output size
|
732
|
+
if isinstance(output_size, (list, tuple)):
|
733
|
+
if len(output_size) == x.ndim:
|
734
|
+
# Full output shape provided
|
735
|
+
output_shape = tuple(output_size)
|
736
|
+
else:
|
737
|
+
# Only spatial dimensions provided
|
738
|
+
if len(output_size) != self.pool_dim:
|
739
|
+
raise ValueError(f"output_size must have {self.pool_dim} spatial dimensions, got {len(output_size)}")
|
740
|
+
output_shape = list(x.shape)
|
741
|
+
spatial_start = self._get_spatial_start_idx(x.ndim)
|
742
|
+
for i, size in enumerate(output_size):
|
743
|
+
output_shape[spatial_start + i] = size
|
744
|
+
output_shape = tuple(output_shape)
|
745
|
+
else:
|
746
|
+
# Single integer provided, use for all spatial dims
|
747
|
+
output_shape = list(x.shape)
|
748
|
+
spatial_start = self._get_spatial_start_idx(x.ndim)
|
749
|
+
for i in range(self.pool_dim):
|
750
|
+
output_shape[spatial_start + i] = output_size
|
751
|
+
output_shape = tuple(output_shape)
|
752
|
+
|
753
|
+
# Create output array filled with zeros
|
754
|
+
output = jnp.zeros(output_shape, dtype=x.dtype)
|
755
|
+
|
756
|
+
# # Scatter input values to output using indices
|
757
|
+
# # Flatten spatial dimensions for easier indexing
|
758
|
+
# batch_dims = x.ndim - self.pool_dim - (0 if self.channel_axis is None else 1)
|
759
|
+
#
|
760
|
+
# # Reshape for processing
|
761
|
+
# if batch_dims > 0:
|
762
|
+
# batch_shape = x.shape[:batch_dims]
|
763
|
+
# if self.channel_axis is not None and self.channel_axis < batch_dims:
|
764
|
+
# # Channel axis is before spatial dims
|
765
|
+
# channel_idx = self.channel_axis
|
766
|
+
# n_channels = x.shape[channel_idx]
|
767
|
+
# elif self.channel_axis is not None:
|
768
|
+
# # Channel axis is after spatial dims
|
769
|
+
# if self.channel_axis < 0:
|
770
|
+
# channel_idx = x.ndim + self.channel_axis
|
771
|
+
# else:
|
772
|
+
# channel_idx = self.channel_axis
|
773
|
+
# n_channels = x.shape[channel_idx]
|
774
|
+
# else:
|
775
|
+
# n_channels = None
|
776
|
+
# else:
|
777
|
+
# batch_shape = ()
|
778
|
+
# if self.channel_axis is not None:
|
779
|
+
# if self.channel_axis < 0:
|
780
|
+
# channel_idx = x.ndim + self.channel_axis
|
781
|
+
# else:
|
782
|
+
# channel_idx = self.channel_axis
|
783
|
+
# n_channels = x.shape[channel_idx]
|
784
|
+
# else:
|
785
|
+
# n_channels = None
|
786
|
+
|
787
|
+
# Use JAX's scatter operation
|
788
|
+
# Flatten the indices to 1D for scatter
|
789
|
+
flat_indices = indices.ravel()
|
790
|
+
flat_values = x.ravel()
|
791
|
+
flat_output = output.ravel()
|
792
|
+
|
793
|
+
# Scatter the values
|
794
|
+
flat_output = flat_output.at[flat_indices].set(flat_values)
|
795
|
+
|
796
|
+
# Reshape back to original shape
|
797
|
+
output = flat_output.reshape(output_shape)
|
798
|
+
|
799
|
+
return output
|
800
|
+
|
801
|
+
def _get_spatial_dims(self, shape):
|
802
|
+
"""Extract spatial dimensions from input shape."""
|
803
|
+
if self.channel_axis is None:
|
804
|
+
return shape[-self.pool_dim:]
|
805
|
+
else:
|
806
|
+
channel_axis = self.channel_axis if self.channel_axis >= 0 else len(shape) + self.channel_axis
|
807
|
+
all_dims = list(range(len(shape)))
|
808
|
+
all_dims.pop(channel_axis)
|
809
|
+
return tuple(shape[i] for i in all_dims[-self.pool_dim:])
|
810
|
+
|
811
|
+
def _get_spatial_start_idx(self, ndim):
|
812
|
+
"""Get the starting index of spatial dimensions."""
|
813
|
+
if self.channel_axis is None:
|
814
|
+
return ndim - self.pool_dim
|
815
|
+
else:
|
816
|
+
channel_axis = self.channel_axis if self.channel_axis >= 0 else ndim + self.channel_axis
|
817
|
+
if channel_axis < ndim - self.pool_dim:
|
818
|
+
return ndim - self.pool_dim
|
819
|
+
else:
|
820
|
+
return ndim - self.pool_dim - 1
|
821
|
+
|
822
|
+
def update(self, x, indices, output_size=None):
    """Forward pass of max unpooling (shared by the 1d, 2d and 3d variants).

    All work is delegated to ``self._unpool_nd``.

    Parameters
    ----------
    x : Array
        Pooled values produced by the matching max-pooling layer.
    indices : Array
        Indices of the maximal values returned by that pooling layer
        (created with ``return_indices=True``).
    output_size : int or tuple, optional
        The targeted output size; forwarded unchanged to ``self._unpool_nd``.

    Returns
    -------
    Array
        Unpooled output in which all non-maximal positions are zero.
    """
    return self._unpool_nd(x, indices, output_size)
|
840
|
+
|
841
|
+
|
842
|
+
class MaxUnpool1d(_MaxUnpool):
    r"""Computes a partial inverse of MaxPool1d.

    MaxPool1d is not fully invertible, since the non-maximal values are lost.
    MaxUnpool1d takes in as input the output of MaxPool1d including the indices
    of the maximal values and computes a partial inverse in which all
    non-maximal values are set to zero.

    Note:
        Values are placed into the output with a scatter. If ``indices``
        contains duplicates (e.g. when pooling windows overlap), only one of
        the colliding values is kept, and which one is implementation-defined.
        Results may therefore differ between backends in that case.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where

          .. math::
              L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel\_size}

          or as given by :attr:`output_size` in the call operator

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the max pooling window.
    stride : int or tuple, optional
        Stride of the max pooling window. Default: kernel_size
    padding : int or tuple, optional
        Padding that was added to the input. Default: 0
    channel_axis : int, optional
        Axis of the channels. Default: -1
    name : str, optional
        Name of the module.
    in_size : Size, optional
        Input size for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Create pooling and unpooling layers
        >>> pool = brainstate.nn.MaxPool1d(2, stride=2, return_indices=True, channel_axis=-1)
        >>> unpool = brainstate.nn.MaxUnpool1d(2, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output, indices = pool(input)      # output has shape (20, 25, 16)
        >>> unpooled = unpool(output, indices)
        >>> # unpooled has shape (20, 50, 16) with zeros at non-maximal positions
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Optional[Union[int, Sequence[int]]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
912
|
+
|
913
|
+
|
914
|
+
class MaxUnpool2d(_MaxUnpool):
    r"""Computes a partial inverse of MaxPool2d.

    MaxPool2d is not fully invertible, since the non-maximal values are lost.
    MaxUnpool2d takes in as input the output of MaxPool2d including the indices
    of the maximal values and computes a partial inverse in which all
    non-maximal values are set to zero.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]

          .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]

          or as given by :attr:`output_size` in the call operator

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the max pooling window.
    stride : int or tuple, optional
        Stride of the max pooling window. Default: kernel_size
    padding : int or tuple, optional
        Padding that was added to the input. Default: 0
    channel_axis : int, optional
        Axis of the channels. Default: -1
    name : str, optional
        Name of the module.
    in_size : Size, optional
        Input size for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Create pooling and unpooling layers
        >>> pool = brainstate.nn.MaxPool2d(2, stride=2, return_indices=True, channel_axis=-1)
        >>> unpool = brainstate.nn.MaxUnpool2d(2, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(1, 4, 4, 16)
        >>> output, indices = pool(input)      # output has shape (1, 2, 2, 16)
        >>> unpooled = unpool(output, indices)
        >>> # unpooled has shape (1, 4, 4, 16) with zeros at non-maximal positions
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Optional[Union[int, Sequence[int]]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
982
|
+
|
983
|
+
|
984
|
+
class MaxUnpool3d(_MaxUnpool):
    r"""Computes a partial inverse of MaxPool3d.

    MaxPool3d is not fully invertible, since the non-maximal values are lost.
    MaxUnpool3d takes in as input the output of MaxPool3d including the indices
    of the maximal values and computes a partial inverse in which all
    non-maximal values are set to zero.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where

          .. math::
              D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]

          .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{kernel\_size}[2]

          or as given by :attr:`output_size` in the call operator

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the max pooling window.
    stride : int or tuple, optional
        Stride of the max pooling window. Default: kernel_size
    padding : int or tuple, optional
        Padding that was added to the input. Default: 0
    channel_axis : int, optional
        Axis of the channels. Default: -1
    name : str, optional
        Name of the module.
    in_size : Size, optional
        Input size for shape inference.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Create pooling and unpooling layers
        >>> pool = brainstate.nn.MaxPool3d(2, stride=2, return_indices=True, channel_axis=-1)
        >>> unpool = brainstate.nn.MaxUnpool3d(2, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(1, 4, 4, 4, 16)
        >>> output, indices = pool(input)      # output has shape (1, 2, 2, 2, 16)
        >>> unpooled = unpool(output, indices)
        >>> # unpooled has shape (1, 4, 4, 4, 16) with zeros at non-maximal positions
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Optional[Union[int, Sequence[int]]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
1055
|
+
|
1056
|
+
|
566
1057
|
class AvgPool1d(_AvgPool):
|
567
1058
|
r"""Applies a 1D average pooling over an input signal composed of several input planes.
|
568
1059
|
|
@@ -586,33 +1077,35 @@ class AvgPool1d(_AvgPool):
|
|
586
1077
|
L_{out} = \left\lfloor \frac{L_{in} +
|
587
1078
|
2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
|
588
1079
|
|
589
|
-
|
590
|
-
|
591
|
-
|
1080
|
+
Parameters
|
1081
|
+
----------
|
1082
|
+
kernel_size : int or sequence of int
|
1083
|
+
An integer, or a sequence of integers defining the window to reduce over.
|
1084
|
+
stride : int or sequence of int, optional
|
1085
|
+
An integer, or a sequence of integers, representing the inter-window stride.
|
1086
|
+
Default: 1
|
1087
|
+
padding : str, int or sequence of tuple, optional
|
1088
|
+
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
1089
|
+
of n `(low, high)` integer pairs that give the padding to apply before
|
1090
|
+
and after each spatial dimension. Default: 'VALID'
|
1091
|
+
channel_axis : int, optional
|
1092
|
+
Axis of the spatial channels for which pooling is skipped.
|
1093
|
+
If ``None``, there is no channel axis. Default: -1
|
1094
|
+
name : str, optional
|
1095
|
+
The object name.
|
1096
|
+
in_size : Sequence of int, optional
|
1097
|
+
The shape of the input tensor.
|
1098
|
+
|
1099
|
+
Examples
|
1100
|
+
--------
|
1101
|
+
.. code-block:: python
|
1102
|
+
|
1103
|
+
>>> import brainstate
|
592
1104
|
>>> # pool with window of size=3, stride=2
|
593
1105
|
>>> m = AvgPool1d(3, stride=2)
|
594
1106
|
>>> input = brainstate.random.randn(20, 50, 16)
|
595
1107
|
>>> m(input).shape
|
596
1108
|
(20, 24, 16)
|
597
|
-
|
598
|
-
Parameters
|
599
|
-
----------
|
600
|
-
in_size: Sequence of int
|
601
|
-
The shape of the input tensor.
|
602
|
-
kernel_size: int, sequence of int
|
603
|
-
An integer, or a sequence of integers defining the window to reduce over.
|
604
|
-
stride: int, sequence of int
|
605
|
-
An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
|
606
|
-
padding: str, int, sequence of tuple
|
607
|
-
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
608
|
-
of n `(low, high)` integer pairs that give the padding to apply before
|
609
|
-
and after each spatial dimension.
|
610
|
-
channel_axis: int, optional
|
611
|
-
Axis of the spatial channels for which pooling is skipped.
|
612
|
-
If ``None``, there is no channel axis.
|
613
|
-
name: optional, str
|
614
|
-
The object name.
|
615
|
-
|
616
1109
|
"""
|
617
1110
|
__module__ = 'brainstate.nn'
|
618
1111
|
|
@@ -625,15 +1118,17 @@ class AvgPool1d(_AvgPool):
|
|
625
1118
|
name: Optional[str] = None,
|
626
1119
|
in_size: Optional[Size] = None,
|
627
1120
|
):
|
628
|
-
super().__init__(
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
1121
|
+
super().__init__(
|
1122
|
+
in_size=in_size,
|
1123
|
+
init_value=0.,
|
1124
|
+
computation=jax.lax.add,
|
1125
|
+
pool_dim=1,
|
1126
|
+
kernel_size=kernel_size,
|
1127
|
+
stride=stride,
|
1128
|
+
padding=padding,
|
1129
|
+
channel_axis=channel_axis,
|
1130
|
+
name=name
|
1131
|
+
)
|
637
1132
|
|
638
1133
|
|
639
1134
|
class AvgPool2d(_AvgPool):
|
@@ -663,35 +1158,38 @@ class AvgPool2d(_AvgPool):
|
|
663
1158
|
W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
|
664
1159
|
\text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
|
665
1160
|
|
666
|
-
|
667
|
-
|
668
|
-
|
1161
|
+
Parameters
|
1162
|
+
----------
|
1163
|
+
kernel_size : int or sequence of int
|
1164
|
+
An integer, or a sequence of integers defining the window to reduce over.
|
1165
|
+
stride : int or sequence of int, optional
|
1166
|
+
An integer, or a sequence of integers, representing the inter-window stride.
|
1167
|
+
Default: 1
|
1168
|
+
padding : str, int or sequence of tuple, optional
|
1169
|
+
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
1170
|
+
of n `(low, high)` integer pairs that give the padding to apply before
|
1171
|
+
and after each spatial dimension. Default: 'VALID'
|
1172
|
+
channel_axis : int, optional
|
1173
|
+
Axis of the spatial channels for which pooling is skipped.
|
1174
|
+
If ``None``, there is no channel axis. Default: -1
|
1175
|
+
name : str, optional
|
1176
|
+
The object name.
|
1177
|
+
in_size : Sequence of int, optional
|
1178
|
+
The shape of the input tensor.
|
1179
|
+
|
1180
|
+
Examples
|
1181
|
+
--------
|
1182
|
+
.. code-block:: python
|
1183
|
+
|
1184
|
+
>>> import brainstate
|
669
1185
|
>>> # pool of square window of size=3, stride=2
|
670
1186
|
>>> m = AvgPool2d(3, stride=2)
|
671
1187
|
>>> # pool of non-square window
|
672
1188
|
>>> m = AvgPool2d((3, 2), stride=(2, 1))
|
673
|
-
>>> input = brainstate.random.randn(20, 50, 32,
|
1189
|
+
>>> input = brainstate.random.randn(20, 50, 32, 16)
|
674
1190
|
>>> output = m(input)
|
675
1191
|
>>> output.shape
|
676
1192
|
(20, 24, 31, 16)
|
677
|
-
|
678
|
-
Parameters
|
679
|
-
----------
|
680
|
-
in_size: Sequence of int
|
681
|
-
The shape of the input tensor.
|
682
|
-
kernel_size: int, sequence of int
|
683
|
-
An integer, or a sequence of integers defining the window to reduce over.
|
684
|
-
stride: int, sequence of int
|
685
|
-
An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
|
686
|
-
padding: str, int, sequence of tuple
|
687
|
-
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
688
|
-
of n `(low, high)` integer pairs that give the padding to apply before
|
689
|
-
and after each spatial dimension.
|
690
|
-
channel_axis: int, optional
|
691
|
-
Axis of the spatial channels for which pooling is skipped.
|
692
|
-
If ``None``, there is no channel axis.
|
693
|
-
name: optional, str
|
694
|
-
The object name.
|
695
1193
|
"""
|
696
1194
|
__module__ = 'brainstate.nn'
|
697
1195
|
|
@@ -704,15 +1202,17 @@ class AvgPool2d(_AvgPool):
|
|
704
1202
|
name: Optional[str] = None,
|
705
1203
|
in_size: Optional[Size] = None,
|
706
1204
|
):
|
707
|
-
super().__init__(
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
1205
|
+
super().__init__(
|
1206
|
+
in_size=in_size,
|
1207
|
+
init_value=0.,
|
1208
|
+
computation=jax.lax.add,
|
1209
|
+
pool_dim=2,
|
1210
|
+
kernel_size=kernel_size,
|
1211
|
+
stride=stride,
|
1212
|
+
padding=padding,
|
1213
|
+
channel_axis=channel_axis,
|
1214
|
+
name=name
|
1215
|
+
)
|
716
1216
|
|
717
1217
|
|
718
1218
|
class AvgPool3d(_AvgPool):
|
@@ -751,9 +1251,30 @@ class AvgPool3d(_AvgPool):
|
|
751
1251
|
W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
|
752
1252
|
\text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
|
753
1253
|
|
754
|
-
|
755
|
-
|
756
|
-
|
1254
|
+
Parameters
|
1255
|
+
----------
|
1256
|
+
kernel_size : int or sequence of int
|
1257
|
+
An integer, or a sequence of integers defining the window to reduce over.
|
1258
|
+
stride : int or sequence of int, optional
|
1259
|
+
An integer, or a sequence of integers, representing the inter-window stride.
|
1260
|
+
Default: 1
|
1261
|
+
padding : str, int or sequence of tuple, optional
|
1262
|
+
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
1263
|
+
of n `(low, high)` integer pairs that give the padding to apply before
|
1264
|
+
and after each spatial dimension. Default: 'VALID'
|
1265
|
+
channel_axis : int, optional
|
1266
|
+
Axis of the spatial channels for which pooling is skipped.
|
1267
|
+
If ``None``, there is no channel axis. Default: -1
|
1268
|
+
name : str, optional
|
1269
|
+
The object name.
|
1270
|
+
in_size : Sequence of int, optional
|
1271
|
+
The shape of the input tensor.
|
1272
|
+
|
1273
|
+
Examples
|
1274
|
+
--------
|
1275
|
+
.. code-block:: python
|
1276
|
+
|
1277
|
+
>>> import brainstate
|
757
1278
|
>>> # pool of square window of size=3, stride=2
|
758
1279
|
>>> m = AvgPool3d(3, stride=2)
|
759
1280
|
>>> # pool of non-square window
|
@@ -763,45 +1284,407 @@ class AvgPool3d(_AvgPool):
|
|
763
1284
|
>>> output.shape
|
764
1285
|
(20, 24, 43, 15, 16)
|
765
1286
|
|
1287
|
+
"""
|
1288
|
+
__module__ = 'brainstate.nn'
|
1289
|
+
|
1290
|
+
def __init__(
|
1291
|
+
self,
|
1292
|
+
kernel_size: Size,
|
1293
|
+
stride: Union[int, Sequence[int]] = 1,
|
1294
|
+
padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
|
1295
|
+
channel_axis: Optional[int] = -1,
|
1296
|
+
name: Optional[str] = None,
|
1297
|
+
in_size: Optional[Size] = None,
|
1298
|
+
):
|
1299
|
+
super().__init__(
|
1300
|
+
in_size=in_size,
|
1301
|
+
init_value=0.,
|
1302
|
+
computation=jax.lax.add,
|
1303
|
+
pool_dim=3,
|
1304
|
+
kernel_size=kernel_size,
|
1305
|
+
stride=stride,
|
1306
|
+
padding=padding,
|
1307
|
+
channel_axis=channel_axis,
|
1308
|
+
name=name
|
1309
|
+
)
|
1310
|
+
|
1311
|
+
|
1312
|
+
class _LPPool(Module):
    """Base class for Lp (power-average) pooling operations.

    For every pooling window ``X`` with ``N = prod(kernel_size)`` elements the
    layer computes the power mean ``(sum_{x in X} |x|**p / N) ** (1/p)``, where
    ``p`` is ``norm_type``. Positions introduced by padding contribute zero to
    the sum but are still counted in ``N``. ``norm_type == inf`` is handled as
    the limit of the power mean, i.e. the maximum of ``|x|`` over the window.

    Parameters
    ----------
    norm_type : float
        Exponent ``p``. Must be positive; ``float('inf')`` is accepted.
    pool_dim : int
        Number of spatial dimensions to pool over.
    kernel_size : int or sequence of int
        Window size: one int for all spatial dims, or ``pool_dim`` ints.
    stride : int or sequence of int, optional
        Window stride. Defaults to ``kernel_size``.
    padding : str, int, or sequence, optional
        ``'SAME'``, ``'VALID'``, one int applied on both sides of every spatial
        dim, ``pool_dim`` ints, or ``(low, high)`` pairs (a single pair is
        repeated for every spatial dim). Default: ``'VALID'``.
    channel_axis : int, optional
        Axis excluded from pooling; ``None`` means there is no channel axis.
    name : str, optional
        The object name.
    in_size : sequence of int, optional
        Input shape without the batch axis. When given, ``out_size`` is
        inferred by tracing ``update`` on a dummy batch.
    """

    def __init__(
        self,
        norm_type: float,
        pool_dim: int,
        kernel_size: Size,
        stride: Optional[Union[int, Sequence[int]]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(name=name)

        if norm_type <= 0:
            raise ValueError(f"norm_type must be positive, got {norm_type}")
        self.norm_type = norm_type
        self.pool_dim = pool_dim

        # kernel_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * pool_dim
        elif isinstance(kernel_size, Sequence):
            assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
            assert all(
                [isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
            if len(kernel_size) != pool_dim:
                raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
        else:
            raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
        self.kernel_size = kernel_size

        # stride (defaults to non-overlapping windows)
        if stride is None:
            stride = kernel_size
        if isinstance(stride, int):
            stride = (stride,) * pool_dim
        elif isinstance(stride, Sequence):
            assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
            assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
            if len(stride) != pool_dim:
                raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
        else:
            raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
        self.stride = stride

        # padding: normalised to a str or to a sequence of (low, high) pairs
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
        elif isinstance(padding, int):
            padding = [(padding, padding) for _ in range(pool_dim)]
        elif isinstance(padding, (list, tuple)):
            if isinstance(padding[0], int):
                if len(padding) == pool_dim:
                    padding = [(x, x) for x in padding]
                else:
                    raise ValueError(f'If padding is a sequence of ints, it '
                                     f'should has the length of {pool_dim}.')
            else:
                if not all([isinstance(x, (tuple, list)) for x in padding]):
                    raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
                if not all([len(x) == 2 for x in padding]):
                    raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
                if len(padding) == 1:
                    padding = tuple(padding) * pool_dim
                assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
        else:
            raise ValueError(f'Invalid padding {padding!r}: expected a str, an int, or a sequence.')
        self.padding = padding

        # channel_axis
        assert channel_axis is None or isinstance(channel_axis, int), \
            f'channel_axis should be an int, but got {channel_axis}'
        self.channel_axis = channel_axis

        # in & out shapes: trace `update` on a dummy batch of 128 to infer out_size
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            self.out_size = y.shape[1:]

    def update(self, x):
        """Apply Lp pooling to ``x``.

        Parameters
        ----------
        x : Array
            Input with at least ``pool_dim`` dimensions (plus one when a
            channel axis is configured). Pooling runs over the last
            ``pool_dim`` non-channel axes.

        Returns
        -------
        Array
            The pooled array.

        Raises
        ------
        ValueError
            If ``x`` has too few dimensions or the channel axis is out of range.
        """
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')

        window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        padding = (self.padding if isinstance(self.padding, str) else
                   self._infer_shape(x.ndim, self.padding, element=(0, 0)))

        x_abs = jnp.abs(x)

        if np.isinf(self.norm_type):
            # Limit p -> inf of the power mean: max of |x| over each window.
            # -inf is the identity of max, so padded positions never win.
            return jax.lax.reduce_window(
                x_abs,
                init_value=-jnp.inf,
                computation=jax.lax.max,
                window_dimensions=window_shape,
                window_strides=stride,
                padding=padding,
            )

        # Sum |x|^p over each window (padded positions contribute 0).
        pooled_sum = jax.lax.reduce_window(
            x_abs ** self.norm_type,
            init_value=0.,
            computation=jax.lax.add,
            window_dimensions=window_shape,
            window_strides=stride,
            padding=padding,
        )

        # p-th root scaled by N^(-1/p), i.e. (sum / N)^(1/p): the power mean.
        window_size = np.prod(self.kernel_size)
        norm_factor = window_size ** (-1.0 / self.norm_type)
        return norm_factor * (pooled_sum ** (1.0 / self.norm_type))

    def _infer_shape(self, x_dim, inputs, element):
        """Expand per-spatial-dim ``inputs`` to a per-axis list of length ``x_dim``.

        Non-pooled axes (batch/channel) receive ``element``. The pooled axes
        are the last ``pool_dim`` axes once the channel axis is excluded.
        """
        channel_axis = self.channel_axis
        if channel_axis is not None:
            # Accept every valid Python index, including -x_dim.
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis
        all_dims = list(range(x_dim))
        if channel_axis is not None:
            all_dims.pop(channel_axis)
        pool_dims = all_dims[-self.pool_dim:]
        results = [element] * x_dim
        for i, dim in enumerate(pool_dims):
            results[dim] = inputs[i]
        return results
|
1447
|
+
|
1448
|
+
|
1449
|
+
class LPPool1d(_LPPool):
    r"""Applies a 1D power-average pooling over an input signal composed of several input planes.

    On each window :math:`X` containing :math:`N` elements (:math:`N` is the
    kernel size; padded positions contribute zero but are still counted), the
    function computed is the power mean:

    .. math::
        f(X) = \left(\frac{1}{N} \sum_{x \in X} |x|^{p}\right)^{1/p}

    - As :math:`p \to \infty`, the power mean approaches the maximum of :math:`|x|` (max pooling of absolute values)
    - At :math:`p = 1`, one gets average pooling (with absolute values)
    - At :math:`p = 2`, one gets root mean square (RMS) pooling

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Parameters
    ----------
    norm_type : float
        Exponent :math:`p` of the power mean. Must be positive (required).
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # power-average pooling of window of size=3, stride=2 with norm_type=2.0
        >>> m = brainstate.nn.LPPool1d(2, 3, stride=2)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        norm_type: float,
        kernel_size: Size,
        stride: Optional[Union[int, Sequence[int]]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            norm_type=norm_type,
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
1523
|
+
|
1524
|
+
|
1525
|
+
class LPPool2d(_LPPool):
    r"""Applies a 2D power-average pooling over an input signal composed of several input planes.

    On each window :math:`X` containing :math:`N` elements (:math:`N` is the
    product of the kernel sizes; padded positions contribute zero but are
    still counted), the function computed is the power mean:

    .. math::
        f(X) = \left(\frac{1}{N} \sum_{x \in X} |x|^{p}\right)^{1/p}

    - As :math:`p \to \infty`, the power mean approaches the maximum of :math:`|x|` (max pooling of absolute values)
    - At :math:`p = 1`, one gets average pooling (with absolute values)
    - At :math:`p = 2`, one gets root mean square (RMS) pooling

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Parameters
    ----------
    norm_type : float
        Exponent :math:`p` of the power mean. Must be positive (required).
    kernel_size : int or sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride : int or sequence of int, optional
        An integer, or a sequence of integers, representing the inter-window stride.
        Default: kernel_size
    padding : str, int or sequence of tuple, optional
        Either the string `'SAME'`, the string `'VALID'`, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension. Default: 'VALID'
    channel_axis : int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis. Default: -1
    name : str, optional
        The object name.
    in_size : Sequence of int, optional
        The shape of the input tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # power-average pooling of square window of size=3, stride=2
        >>> m = brainstate.nn.LPPool2d(2, 3, stride=2)
        >>> # pool of non-square window with norm_type=1.5
        >>> m = brainstate.nn.LPPool2d(1.5, (3, 2), stride=(2, 1), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        norm_type: float,
        kernel_size: Size,
        stride: Optional[Union[int, Sequence[int]]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(
            norm_type=norm_type,
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size
        )
|
1604
|
+
|
1605
|
+
|
1606
|
+
class LPPool3d(_LPPool):
|
1607
|
+
r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
|
1608
|
+
|
1609
|
+
On each window, the function computed is:
|
1610
|
+
|
1611
|
+
.. math::
|
1612
|
+
f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}}
|
1613
|
+
|
1614
|
+
- At :math:`p = \infty`, one gets max pooling
|
1615
|
+
- At :math:`p = 1`, one gets average pooling (with absolute values)
|
1616
|
+
- At :math:`p = 2`, one gets root mean square (RMS) pooling
|
1617
|
+
|
1618
|
+
Shape:
|
1619
|
+
- Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
|
1620
|
+
- Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
|
1621
|
+
|
1622
|
+
.. math::
|
1623
|
+
D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
|
1624
|
+
|
1625
|
+
.. math::
|
1626
|
+
H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
|
1627
|
+
|
1628
|
+
.. math::
|
1629
|
+
W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
|
1630
|
+
|
1631
|
+
Parameters
|
1632
|
+
----------
|
1633
|
+
norm_type : float
|
1634
|
+
Exponent for the pooling operation. Default: 2.0
|
1635
|
+
kernel_size : int or sequence of int
|
1636
|
+
An integer, or a sequence of integers defining the window to reduce over.
|
1637
|
+
stride : int or sequence of int, optional
|
1638
|
+
An integer, or a sequence of integers, representing the inter-window stride.
|
1639
|
+
Default: kernel_size
|
1640
|
+
padding : str, int or sequence of tuple, optional
|
1641
|
+
Either the string `'SAME'`, the string `'VALID'`, or a sequence
|
1642
|
+
of n `(low, high)` integer pairs that give the padding to apply before
|
1643
|
+
and after each spatial dimension. Default: 'VALID'
|
1644
|
+
channel_axis : int, optional
|
1645
|
+
Axis of the spatial channels for which pooling is skipped.
|
1646
|
+
If ``None``, there is no channel axis. Default: -1
|
1647
|
+
name : str, optional
|
1648
|
+
The object name.
|
1649
|
+
in_size : Sequence of int, optional
|
1650
|
+
The shape of the input tensor.
|
1651
|
+
|
1652
|
+
Examples
|
1653
|
+
--------
|
1654
|
+
.. code-block:: python
|
1655
|
+
|
1656
|
+
>>> import brainstate
|
1657
|
+
>>> # power-average pooling of cube window of size=3, stride=2
|
1658
|
+
>>> m = LPPool3d(2, 3, stride=2)
|
1659
|
+
>>> # pool of non-cubic window with norm_type=1.5
|
1660
|
+
>>> m = LPPool3d(1.5, (3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
|
1661
|
+
>>> input = brainstate.random.randn(20, 50, 44, 31, 16)
|
1662
|
+
>>> output = m(input)
|
1663
|
+
>>> output.shape
|
1664
|
+
(20, 24, 43, 15, 16)
|
1665
|
+
"""
|
1666
|
+
__module__ = 'brainstate.nn'
|
1667
|
+
|
1668
|
+
def __init__(
|
1669
|
+
self,
|
1670
|
+
norm_type: float,
|
1671
|
+
kernel_size: Size,
|
1672
|
+
stride: Union[int, Sequence[int]] = None,
|
1673
|
+
padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
|
1674
|
+
channel_axis: Optional[int] = -1,
|
1675
|
+
name: Optional[str] = None,
|
1676
|
+
in_size: Optional[Size] = None,
|
1677
|
+
):
|
1678
|
+
super().__init__(
|
1679
|
+
norm_type=norm_type,
|
1680
|
+
pool_dim=3,
|
1681
|
+
kernel_size=kernel_size,
|
1682
|
+
stride=stride,
|
1683
|
+
padding=padding,
|
1684
|
+
channel_axis=channel_axis,
|
1685
|
+
name=name,
|
1686
|
+
in_size=in_size
|
1687
|
+
)
|
805
1688
|
|
806
1689
|
|
807
1690
|
def _adaptive_pool1d(x, target_size: int, operation: Callable):
|
@@ -919,259 +1802,438 @@ class _AdaptivePool(Module):
|
|
919
1802
|
|
920
1803
|
|
921
1804
|
class AdaptiveAvgPool1d(_AdaptivePool):
|
922
|
-
r"""Applies a 1D adaptive
|
1805
|
+
r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
|
923
1806
|
|
924
1807
|
The output size is :math:`L_{out}`, for any input size.
|
925
1808
|
The number of output features is equal to the number of input planes.
|
926
1809
|
|
1810
|
+
Adaptive pooling automatically computes the kernel size and stride to achieve the desired
|
1811
|
+
output size, making it useful for creating fixed-size representations from variable-sized inputs.
|
1812
|
+
|
927
1813
|
Shape:
|
928
1814
|
- Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
|
929
1815
|
- Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
|
930
|
-
:math:`L_{out}=\text{
|
1816
|
+
:math:`L_{out}=\text{target\_size}`.
|
931
1817
|
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
1818
|
+
Parameters
|
1819
|
+
----------
|
1820
|
+
target_size : int or sequence of int
|
1821
|
+
The target output size. The number of output features for each channel.
|
1822
|
+
channel_axis : int, optional
|
1823
|
+
Axis of the spatial channels for which pooling is skipped.
|
1824
|
+
If ``None``, there is no channel axis. Default: -1
|
1825
|
+
name : str, optional
|
1826
|
+
The name of the module.
|
1827
|
+
in_size : Sequence of int, optional
|
1828
|
+
The shape of the input tensor for shape inference.
|
1829
|
+
|
1830
|
+
Examples
|
1831
|
+
--------
|
1832
|
+
.. code-block:: python
|
1833
|
+
|
1834
|
+
>>> import brainstate
|
1835
|
+
>>> # Target output size of 5
|
1836
|
+
>>> m = AdaptiveAvgPool1d(5)
|
937
1837
|
>>> input = brainstate.random.randn(1, 64, 8)
|
938
1838
|
>>> output = m(input)
|
939
1839
|
>>> output.shape
|
940
1840
|
(1, 5, 8)
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
name: str
|
952
|
-
The class name.
|
1841
|
+
>>> # Can handle variable input sizes
|
1842
|
+
>>> input2 = brainstate.random.randn(1, 32, 8)
|
1843
|
+
>>> output2 = m(input2)
|
1844
|
+
>>> output2.shape
|
1845
|
+
(1, 5, 8) # Same output size regardless of input size
|
1846
|
+
|
1847
|
+
See Also
|
1848
|
+
--------
|
1849
|
+
AvgPool1d : Non-adaptive 1D average pooling.
|
1850
|
+
AdaptiveMaxPool1d : Adaptive 1D max pooling.
|
953
1851
|
"""
|
954
1852
|
__module__ = 'brainstate.nn'
|
955
1853
|
|
956
|
-
def __init__(
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
1854
|
+
def __init__(
|
1855
|
+
self,
|
1856
|
+
target_size: Union[int, Sequence[int]],
|
1857
|
+
channel_axis: Optional[int] = -1,
|
1858
|
+
name: Optional[str] = None,
|
1859
|
+
in_size: Optional[Sequence[int]] = None,
|
1860
|
+
):
|
1861
|
+
super().__init__(
|
1862
|
+
in_size=in_size,
|
1863
|
+
target_size=target_size,
|
1864
|
+
channel_axis=channel_axis,
|
1865
|
+
num_spatial_dims=1,
|
1866
|
+
operation=jnp.mean,
|
1867
|
+
name=name
|
1868
|
+
)
|
967
1869
|
|
968
1870
|
|
969
1871
|
class AdaptiveAvgPool2d(_AdaptivePool):
|
970
|
-
r"""Applies a 2D adaptive
|
1872
|
+
r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
|
971
1873
|
|
972
1874
|
The output is of size :math:`H_{out} \times W_{out}`, for any input size.
|
973
1875
|
The number of output features is equal to the number of input planes.
|
974
1876
|
|
1877
|
+
Adaptive pooling automatically computes the kernel size and stride to achieve the desired
|
1878
|
+
output size, making it useful for creating fixed-size representations from variable-sized inputs.
|
1879
|
+
|
975
1880
|
Shape:
|
976
1881
|
- Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
|
977
1882
|
- Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
|
978
|
-
:math:`(H_{out}, W_{out})=\text{
|
979
|
-
|
980
|
-
Examples:
|
1883
|
+
:math:`(H_{out}, W_{out})=\text{target\_size}`.
|
981
1884
|
|
982
|
-
|
983
|
-
|
984
|
-
|
1885
|
+
Parameters
|
1886
|
+
----------
|
1887
|
+
target_size : int or tuple of int
|
1888
|
+
The target output size. If a single integer is provided, the output will be a square
|
1889
|
+
of that size. If a tuple is provided, it specifies (H_out, W_out).
|
1890
|
+
Use None for dimensions that should not be pooled.
|
1891
|
+
channel_axis : int, optional
|
1892
|
+
Axis of the spatial channels for which pooling is skipped.
|
1893
|
+
If ``None``, there is no channel axis. Default: -1
|
1894
|
+
name : str, optional
|
1895
|
+
The name of the module.
|
1896
|
+
in_size : Sequence of int, optional
|
1897
|
+
The shape of the input tensor for shape inference.
|
1898
|
+
|
1899
|
+
Examples
|
1900
|
+
--------
|
1901
|
+
.. code-block:: python
|
1902
|
+
|
1903
|
+
>>> import brainstate
|
1904
|
+
>>> # Target output size of 5x7
|
1905
|
+
>>> m = AdaptiveAvgPool2d((5, 7))
|
985
1906
|
>>> input = brainstate.random.randn(1, 8, 9, 64)
|
986
1907
|
>>> output = m(input)
|
987
1908
|
>>> output.shape
|
988
1909
|
(1, 5, 7, 64)
|
989
|
-
>>> #
|
990
|
-
>>> m =
|
1910
|
+
>>> # Target output size of 7x7 (square)
|
1911
|
+
>>> m = AdaptiveAvgPool2d(7)
|
991
1912
|
>>> input = brainstate.random.randn(1, 10, 9, 64)
|
992
1913
|
>>> output = m(input)
|
993
1914
|
>>> output.shape
|
994
1915
|
(1, 7, 7, 64)
|
995
|
-
>>> #
|
996
|
-
>>> m =
|
1916
|
+
>>> # Target output size of 10x7
|
1917
|
+
>>> m = AdaptiveAvgPool2d((None, 7))
|
997
1918
|
>>> input = brainstate.random.randn(1, 10, 9, 64)
|
998
1919
|
>>> output = m(input)
|
999
1920
|
>>> output.shape
|
1000
1921
|
(1, 10, 7, 64)
|
1001
1922
|
|
1002
|
-
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
1006
|
-
target_size: int, sequence of int
|
1007
|
-
The target output shape.
|
1008
|
-
channel_axis: int, optional
|
1009
|
-
Axis of the spatial channels for which pooling is skipped.
|
1010
|
-
If ``None``, there is no channel axis.
|
1011
|
-
name: str
|
1012
|
-
The class name.
|
1923
|
+
See Also
|
1924
|
+
--------
|
1925
|
+
AvgPool2d : Non-adaptive 2D average pooling.
|
1926
|
+
AdaptiveMaxPool2d : Adaptive 2D max pooling.
|
1013
1927
|
"""
|
1014
1928
|
__module__ = 'brainstate.nn'
|
1015
1929
|
|
1016
|
-
def __init__(
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1930
|
+
def __init__(
|
1931
|
+
self,
|
1932
|
+
target_size: Union[int, Sequence[int]],
|
1933
|
+
channel_axis: Optional[int] = -1,
|
1934
|
+
name: Optional[str] = None,
|
1935
|
+
in_size: Optional[Sequence[int]] = None,
|
1936
|
+
):
|
1937
|
+
super().__init__(
|
1938
|
+
in_size=in_size,
|
1939
|
+
target_size=target_size,
|
1940
|
+
channel_axis=channel_axis,
|
1941
|
+
num_spatial_dims=2,
|
1942
|
+
operation=jnp.mean,
|
1943
|
+
name=name
|
1944
|
+
)
|
1028
1945
|
|
1029
1946
|
|
1030
1947
|
class AdaptiveAvgPool3d(_AdaptivePool):
|
1031
|
-
r"""Applies a 3D adaptive
|
1948
|
+
r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
|
1032
1949
|
|
1033
1950
|
The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
|
1034
1951
|
The number of output features is equal to the number of input planes.
|
1035
1952
|
|
1953
|
+
Adaptive pooling automatically computes the kernel size and stride to achieve the desired
|
1954
|
+
output size, making it useful for creating fixed-size representations from variable-sized inputs.
|
1955
|
+
|
1036
1956
|
Shape:
|
1037
1957
|
- Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
|
1038
1958
|
- Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
|
1039
|
-
where :math:`(D_{out}, H_{out}, W_{out})=\text{
|
1040
|
-
|
1041
|
-
Examples:
|
1959
|
+
where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
|
1042
1960
|
|
1043
|
-
|
1044
|
-
|
1045
|
-
|
1961
|
+
Parameters
|
1962
|
+
----------
|
1963
|
+
target_size : int or tuple of int
|
1964
|
+
The target output size. If a single integer is provided, the output will be a cube
|
1965
|
+
of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
|
1966
|
+
Use None for dimensions that should not be pooled.
|
1967
|
+
channel_axis : int, optional
|
1968
|
+
Axis of the spatial channels for which pooling is skipped.
|
1969
|
+
If ``None``, there is no channel axis. Default: -1
|
1970
|
+
name : str, optional
|
1971
|
+
The name of the module.
|
1972
|
+
in_size : Sequence of int, optional
|
1973
|
+
The shape of the input tensor for shape inference.
|
1974
|
+
|
1975
|
+
Examples
|
1976
|
+
--------
|
1977
|
+
.. code-block:: python
|
1978
|
+
|
1979
|
+
>>> import brainstate
|
1980
|
+
>>> # Target output size of 5x7x9
|
1981
|
+
>>> m = AdaptiveAvgPool3d((5, 7, 9))
|
1046
1982
|
>>> input = brainstate.random.randn(1, 8, 9, 10, 64)
|
1047
1983
|
>>> output = m(input)
|
1048
1984
|
>>> output.shape
|
1049
1985
|
(1, 5, 7, 9, 64)
|
1050
|
-
>>> #
|
1051
|
-
>>> m =
|
1986
|
+
>>> # Target output size of 7x7x7 (cube)
|
1987
|
+
>>> m = AdaptiveAvgPool3d(7)
|
1052
1988
|
>>> input = brainstate.random.randn(1, 10, 9, 8, 64)
|
1053
1989
|
>>> output = m(input)
|
1054
1990
|
>>> output.shape
|
1055
1991
|
(1, 7, 7, 7, 64)
|
1056
|
-
>>> #
|
1057
|
-
>>> m =
|
1992
|
+
>>> # Target output size of 7x9x8
|
1993
|
+
>>> m = AdaptiveAvgPool3d((7, None, None))
|
1058
1994
|
>>> input = brainstate.random.randn(1, 10, 9, 8, 64)
|
1059
1995
|
>>> output = m(input)
|
1060
1996
|
>>> output.shape
|
1061
1997
|
(1, 7, 9, 8, 64)
|
1062
1998
|
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
target_size: int, sequence of int
|
1068
|
-
The target output shape.
|
1069
|
-
channel_axis: int, optional
|
1070
|
-
Axis of the spatial channels for which pooling is skipped.
|
1071
|
-
If ``None``, there is no channel axis.
|
1072
|
-
name: str
|
1073
|
-
The class name.
|
1999
|
+
See Also
|
2000
|
+
--------
|
2001
|
+
AvgPool3d : Non-adaptive 3D average pooling.
|
2002
|
+
AdaptiveMaxPool3d : Adaptive 3D max pooling.
|
1074
2003
|
"""
|
1075
2004
|
__module__ = 'brainstate.nn'
|
1076
2005
|
|
1077
|
-
def __init__(
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
2006
|
+
def __init__(
|
2007
|
+
self,
|
2008
|
+
target_size: Union[int, Sequence[int]],
|
2009
|
+
channel_axis: Optional[int] = -1,
|
2010
|
+
name: Optional[str] = None,
|
2011
|
+
in_size: Optional[Sequence[int]] = None,
|
2012
|
+
):
|
2013
|
+
super().__init__(
|
2014
|
+
in_size=in_size,
|
2015
|
+
target_size=target_size,
|
2016
|
+
channel_axis=channel_axis,
|
2017
|
+
num_spatial_dims=3,
|
2018
|
+
operation=jnp.mean,
|
2019
|
+
name=name
|
2020
|
+
)
|
1088
2021
|
|
1089
2022
|
|
1090
2023
|
class AdaptiveMaxPool1d(_AdaptivePool):
|
1091
|
-
"""
|
2024
|
+
r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
|
2025
|
+
|
2026
|
+
The output size is :math:`L_{out}`, for any input size.
|
2027
|
+
The number of output features is equal to the number of input planes.
|
2028
|
+
|
2029
|
+
Adaptive pooling automatically computes the kernel size and stride to achieve the desired
|
2030
|
+
output size, making it useful for creating fixed-size representations from variable-sized inputs.
|
2031
|
+
|
2032
|
+
Shape:
|
2033
|
+
- Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
|
2034
|
+
- Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
|
2035
|
+
:math:`L_{out}=\text{target\_size}`.
|
1092
2036
|
|
1093
2037
|
Parameters
|
1094
2038
|
----------
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
2039
|
+
target_size : int or sequence of int
|
2040
|
+
The target output size. The number of output features for each channel.
|
2041
|
+
channel_axis : int, optional
|
2042
|
+
Axis of the spatial channels for which pooling is skipped.
|
2043
|
+
If ``None``, there is no channel axis. Default: -1
|
2044
|
+
name : str, optional
|
2045
|
+
The name of the module.
|
2046
|
+
in_size : Sequence of int, optional
|
2047
|
+
The shape of the input tensor for shape inference.
|
2048
|
+
|
2049
|
+
Examples
|
2050
|
+
--------
|
2051
|
+
.. code-block:: python
|
2052
|
+
|
2053
|
+
>>> import brainstate
|
2054
|
+
>>> # Target output size of 5
|
2055
|
+
>>> m = AdaptiveMaxPool1d(5)
|
2056
|
+
>>> input = brainstate.random.randn(1, 64, 8)
|
2057
|
+
>>> output = m(input)
|
2058
|
+
>>> output.shape
|
2059
|
+
(1, 5, 8)
|
2060
|
+
>>> # Can handle variable input sizes
|
2061
|
+
>>> input2 = brainstate.random.randn(1, 32, 8)
|
2062
|
+
>>> output2 = m(input2)
|
2063
|
+
>>> output2.shape
|
2064
|
+
(1, 5, 8) # Same output size regardless of input size
|
2065
|
+
|
2066
|
+
See Also
|
2067
|
+
--------
|
2068
|
+
MaxPool1d : Non-adaptive 1D max pooling.
|
2069
|
+
AdaptiveAvgPool1d : Adaptive 1D average pooling.
|
1104
2070
|
"""
|
1105
2071
|
__module__ = 'brainstate.nn'
|
1106
2072
|
|
1107
|
-
def __init__(
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
2073
|
+
def __init__(
|
2074
|
+
self,
|
2075
|
+
target_size: Union[int, Sequence[int]],
|
2076
|
+
channel_axis: Optional[int] = -1,
|
2077
|
+
name: Optional[str] = None,
|
2078
|
+
in_size: Optional[Sequence[int]] = None,
|
2079
|
+
):
|
2080
|
+
super().__init__(
|
2081
|
+
in_size=in_size,
|
2082
|
+
target_size=target_size,
|
2083
|
+
channel_axis=channel_axis,
|
2084
|
+
num_spatial_dims=1,
|
2085
|
+
operation=jnp.max,
|
2086
|
+
name=name
|
2087
|
+
)
|
1118
2088
|
|
1119
2089
|
|
1120
2090
|
class AdaptiveMaxPool2d(_AdaptivePool):
|
1121
|
-
"""
|
2091
|
+
r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
|
2092
|
+
|
2093
|
+
The output is of size :math:`H_{out} \times W_{out}`, for any input size.
|
2094
|
+
The number of output features is equal to the number of input planes.
|
2095
|
+
|
2096
|
+
Adaptive pooling automatically computes the kernel size and stride to achieve the desired
|
2097
|
+
output size, making it useful for creating fixed-size representations from variable-sized inputs.
|
2098
|
+
|
2099
|
+
Shape:
|
2100
|
+
- Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
|
2101
|
+
- Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
|
2102
|
+
:math:`(H_{out}, W_{out})=\text{target\_size}`.
|
1122
2103
|
|
1123
2104
|
Parameters
|
1124
2105
|
----------
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
channel_axis: int, optional
|
1130
|
-
|
1131
|
-
|
1132
|
-
name: str
|
1133
|
-
|
2106
|
+
target_size : int or tuple of int
|
2107
|
+
The target output size. If a single integer is provided, the output will be a square
|
2108
|
+
of that size. If a tuple is provided, it specifies (H_out, W_out).
|
2109
|
+
Use None for dimensions that should not be pooled.
|
2110
|
+
channel_axis : int, optional
|
2111
|
+
Axis of the spatial channels for which pooling is skipped.
|
2112
|
+
If ``None``, there is no channel axis. Default: -1
|
2113
|
+
name : str, optional
|
2114
|
+
The name of the module.
|
2115
|
+
in_size : Sequence of int, optional
|
2116
|
+
The shape of the input tensor for shape inference.
|
2117
|
+
|
2118
|
+
Examples
|
2119
|
+
--------
|
2120
|
+
.. code-block:: python
|
2121
|
+
|
2122
|
+
>>> import brainstate
|
2123
|
+
>>> # Target output size of 5x7
|
2124
|
+
>>> m = AdaptiveMaxPool2d((5, 7))
|
2125
|
+
>>> input = brainstate.random.randn(1, 8, 9, 64)
|
2126
|
+
>>> output = m(input)
|
2127
|
+
>>> output.shape
|
2128
|
+
(1, 5, 7, 64)
|
2129
|
+
>>> # Target output size of 7x7 (square)
|
2130
|
+
>>> m = AdaptiveMaxPool2d(7)
|
2131
|
+
>>> input = brainstate.random.randn(1, 10, 9, 64)
|
2132
|
+
>>> output = m(input)
|
2133
|
+
>>> output.shape
|
2134
|
+
(1, 7, 7, 64)
|
2135
|
+
>>> # Target output size of 10x7
|
2136
|
+
>>> m = AdaptiveMaxPool2d((None, 7))
|
2137
|
+
>>> input = brainstate.random.randn(1, 10, 9, 64)
|
2138
|
+
>>> output = m(input)
|
2139
|
+
>>> output.shape
|
2140
|
+
(1, 10, 7, 64)
|
2141
|
+
|
2142
|
+
See Also
|
2143
|
+
--------
|
2144
|
+
MaxPool2d : Non-adaptive 2D max pooling.
|
2145
|
+
AdaptiveAvgPool2d : Adaptive 2D average pooling.
|
1134
2146
|
"""
|
1135
2147
|
__module__ = 'brainstate.nn'
|
1136
2148
|
|
1137
|
-
def __init__(
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
2149
|
+
def __init__(
|
2150
|
+
self,
|
2151
|
+
target_size: Union[int, Sequence[int]],
|
2152
|
+
channel_axis: Optional[int] = -1,
|
2153
|
+
name: Optional[str] = None,
|
2154
|
+
in_size: Optional[Sequence[int]] = None,
|
2155
|
+
):
|
2156
|
+
super().__init__(
|
2157
|
+
in_size=in_size,
|
2158
|
+
target_size=target_size,
|
2159
|
+
channel_axis=channel_axis,
|
2160
|
+
num_spatial_dims=2,
|
2161
|
+
operation=jnp.max,
|
2162
|
+
name=name
|
2163
|
+
)
|
1148
2164
|
|
1149
2165
|
|
1150
2166
|
class AdaptiveMaxPool3d(_AdaptivePool):
|
1151
|
-
"""
|
2167
|
+
r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
|
2168
|
+
|
2169
|
+
The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
|
2170
|
+
The number of output features is equal to the number of input planes.
|
2171
|
+
|
2172
|
+
Adaptive pooling automatically computes the kernel size and stride to achieve the desired
|
2173
|
+
output size, making it useful for creating fixed-size representations from variable-sized inputs.
|
2174
|
+
|
2175
|
+
Shape:
|
2176
|
+
- Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
|
2177
|
+
- Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
|
2178
|
+
where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
|
1152
2179
|
|
1153
2180
|
Parameters
|
1154
2181
|
----------
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1159
|
-
channel_axis: int, optional
|
1160
|
-
|
1161
|
-
|
1162
|
-
name: str
|
1163
|
-
|
2182
|
+
target_size : int or tuple of int
|
2183
|
+
The target output size. If a single integer is provided, the output will be a cube
|
2184
|
+
of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
|
2185
|
+
Use None for dimensions that should not be pooled.
|
2186
|
+
channel_axis : int, optional
|
2187
|
+
Axis of the spatial channels for which pooling is skipped.
|
2188
|
+
If ``None``, there is no channel axis. Default: -1
|
2189
|
+
name : str, optional
|
2190
|
+
The name of the module.
|
2191
|
+
in_size : Sequence of int, optional
|
2192
|
+
The shape of the input tensor for shape inference.
|
2193
|
+
|
2194
|
+
Examples
|
2195
|
+
--------
|
2196
|
+
.. code-block:: python
|
2197
|
+
|
2198
|
+
>>> import brainstate
|
2199
|
+
>>> # Target output size of 5x7x9
|
2200
|
+
>>> m = AdaptiveMaxPool3d((5, 7, 9))
|
2201
|
+
>>> input = brainstate.random.randn(1, 8, 9, 10, 64)
|
2202
|
+
>>> output = m(input)
|
2203
|
+
>>> output.shape
|
2204
|
+
(1, 5, 7, 9, 64)
|
2205
|
+
>>> # Target output size of 7x7x7 (cube)
|
2206
|
+
>>> m = AdaptiveMaxPool3d(7)
|
2207
|
+
>>> input = brainstate.random.randn(1, 10, 9, 8, 64)
|
2208
|
+
>>> output = m(input)
|
2209
|
+
>>> output.shape
|
2210
|
+
(1, 7, 7, 7, 64)
|
2211
|
+
>>> # Target output size of 7x9x8
|
2212
|
+
>>> m = AdaptiveMaxPool3d((7, None, None))
|
2213
|
+
>>> input = brainstate.random.randn(1, 10, 9, 8, 64)
|
2214
|
+
>>> output = m(input)
|
2215
|
+
>>> output.shape
|
2216
|
+
(1, 7, 9, 8, 64)
|
2217
|
+
|
2218
|
+
See Also
|
2219
|
+
--------
|
2220
|
+
MaxPool3d : Non-adaptive 3D max pooling.
|
2221
|
+
AdaptiveAvgPool3d : Adaptive 3D average pooling.
|
1164
2222
|
"""
|
1165
2223
|
__module__ = 'brainstate.nn'
|
1166
2224
|
|
1167
|
-
def __init__(
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
2225
|
+
def __init__(
|
2226
|
+
self,
|
2227
|
+
target_size: Union[int, Sequence[int]],
|
2228
|
+
channel_axis: Optional[int] = -1,
|
2229
|
+
name: Optional[str] = None,
|
2230
|
+
in_size: Optional[Sequence[int]] = None,
|
2231
|
+
):
|
2232
|
+
super().__init__(
|
2233
|
+
in_size=in_size,
|
2234
|
+
target_size=target_size,
|
2235
|
+
channel_axis=channel_axis,
|
2236
|
+
num_spatial_dims=3,
|
2237
|
+
operation=jnp.max,
|
2238
|
+
name=name
|
2239
|
+
)
|