brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,1177 +1,1177 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- import functools
19
- from typing import Sequence, Optional
20
- from typing import Union, Tuple, Callable, List
21
-
22
- import brainunit as u
23
- import jax
24
- import jax.numpy as jnp
25
- import numpy as np
26
-
27
- from brainstate import environ
28
- from brainstate.typing import Size
29
- from ._module import Module
30
-
31
- __all__ = [
32
- 'Flatten', 'Unflatten',
33
-
34
- 'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
35
- 'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
36
-
37
- 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
38
- 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
39
- ]
40
-
41
-
class Flatten(Module):
    r"""
    Flatten a contiguous range of axes of a tensor into a single axis.

    For use with :class:`~nn.Sequential`.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_axis: first axis to flatten (default = 0).
        end_axis: last axis to flatten (default = -1).
        in_size: Sequence of int. The shape of the input tensor. When given,
            ``out_size`` is inferred at construction time, and inputs passed to
            :meth:`update` may carry extra leading axes in front of ``in_size``;
            non-negative ``start_axis`` / ``end_axis`` are then interpreted
            relative to ``in_size``.

    Examples::
        >>> import brainstate as brainstate
        >>> inp = brainstate.random.randn(32, 1, 5, 5)
        >>> # Keep the leading axis and flatten the rest
        >>> m = Flatten(start_axis=1)
        >>> output = m(inp)
        >>> output.shape
        (32, 25)
        >>> # Flatten the first three axes
        >>> m = Flatten(0, 2)
        >>> output = m(inp)
        >>> output.shape
        (160, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        start_axis: int = 0,
        end_axis: int = -1,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__()
        self.start_axis = start_axis
        self.end_axis = end_axis

        if in_size is not None:
            self.in_size = tuple(in_size)
            # Trace the flatten abstractly to learn the output shape without allocating data.
            y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
                               jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = y.shape

    def update(self, x):
        """Flatten axes ``start_axis..end_axis`` of ``x``.

        Raises:
            ValueError: if ``in_size`` was given and the trailing axes of ``x``
                do not match it.
        """
        start_axis = self.start_axis
        end_axis = self.end_axis
        if self._in_size is None:
            if start_axis < 0:
                start_axis = x.ndim + start_axis
        else:
            assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
            # Number of extra leading axes in front of the declared ``in_size``.
            dim_diff = x.ndim - len(self.in_size)
            if self.in_size != x.shape[dim_diff:]:
                raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
            if start_axis >= 0:
                start_axis = start_axis + dim_diff
            else:
                start_axis = x.ndim + start_axis
            # A non-negative end axis is relative to ``in_size`` as well, so it must be
            # shifted by the same offset; otherwise the flattened range would not match
            # the ``out_size`` computed in ``__init__``. Negative values already count
            # from the end and need no adjustment.
            if end_axis >= 0:
                end_axis = end_axis + dim_diff
        return u.math.flatten(x, start_axis, end_axis)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
105
-
106
-
class Unflatten(Module):
    r"""
    Expand a single axis of a tensor into several axes of a given shape.

    For use with :class:`~nn.Sequential`.

    Shape:
        - Input: :math:`(*, S_{\text{axis}}, *)`, where :math:`S_{\text{axis}}` is the size at
          axis :attr:`axis` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
          :math:`\prod_{i=1}^n U_i = S_{\text{axis}}`.

    Args:
        axis: int, the axis to be unflattened.
        sizes: tuple or list of int. New shape of the unflattened axis.
        name: str, optional. The name of the module.
        in_size: Sequence of int. The shape of the input tensor. When given,
            ``out_size`` is inferred at construction time.

    Raises:
        TypeError: if ``sizes`` is not a tuple/list, or contains a non-int element.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        axis: int,
        sizes: Size,
        name: str = None,
        in_size: Optional[Size] = None
    ) -> None:
        super().__init__(name=name)

        self.axis = axis
        self.sizes = sizes

        # Reject anything that is not a tuple/list of plain ints.
        if not isinstance(sizes, (tuple, list)):
            raise TypeError("unflattened sizes must be tuple or list, but found type {}".format(type(sizes).__name__))
        for position, entry in enumerate(sizes):
            if isinstance(entry, int):
                continue
            raise TypeError("unflattened sizes must be tuple of ints, " +
                            "but found element of type {} at pos {}".format(type(entry).__name__, position))

        if in_size is not None:
            self.in_size = tuple(in_size)
            # Abstractly evaluate the unflatten to obtain the output shape.
            unflatten = functools.partial(u.math.unflatten, axis=axis, sizes=sizes)
            abstract_out = jax.eval_shape(unflatten, jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
            self.out_size = abstract_out.shape

    def update(self, x):
        """Return ``x`` with axis ``self.axis`` expanded to shape ``self.sizes``."""
        return u.math.unflatten(x, self.axis, self.sizes)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(axis={self.axis}, sizes={self.sizes})'
161
-
162
-
class _MaxPool(Module):
    """Base class for window pooling layers built on ``jax.lax.reduce_window``.

    Subclasses supply the reduction (``init_value`` and ``computation``) and the
    number of pooled axes (``pool_dim``). The pooled axes are the last
    ``pool_dim`` axes of the input once the channel axis (if any) is removed.

    Parameters
    ----------
    init_value: float
        Initial value handed to ``jax.lax.reduce_window``.
    computation: Callable
        Binary reduction function handed to ``jax.lax.reduce_window``.
    pool_dim: int
        Number of axes to pool over.
    kernel_size: int, sequence of int
        Window size. An int is repeated for each pooled axis.
    stride: int, sequence of int, optional
        Window stride. ``None`` means "same as ``kernel_size``".
    padding: str, int, sequence of int, sequence of tuple
        ``'SAME'``, ``'VALID'``, one int applied on both sides of every pooled
        axis, one int per pooled axis, or ``(low, high)`` pairs per pooled axis
        (a single pair is repeated for each pooled axis).
    channel_axis: int, optional
        Axis excluded from pooling. ``None`` means there is no channel axis.
    name: str, optional
        The object name.
    in_size: sequence of int, optional
        Input shape. When given, ``out_size`` is inferred at construction by
        abstractly evaluating :meth:`update` on an input with one extra leading
        axis, which is then dropped from the result.
    """

    def __init__(
        self,
        init_value: float,
        computation: Callable,
        pool_dim: int,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        super().__init__(name=name)

        self.init_value = init_value
        self.computation = computation
        self.pool_dim = pool_dim

        # kernel_size
        self.kernel_size = self._check_window_arg(kernel_size, pool_dim, 'kernel_size')

        # stride: a missing stride falls back to the (already normalized) kernel size
        self.stride = self._check_window_arg(
            self.kernel_size if stride is None else stride, pool_dim, 'stride'
        )

        # padding
        self.padding = self._check_padding(padding, pool_dim)

        # channel_axis
        assert channel_axis is None or isinstance(channel_axis, int), \
            f'channel_axis should be an int, but got {channel_axis}'
        self.channel_axis = channel_axis

        # in & out shapes
        if in_size is not None:
            in_size = tuple(in_size)
            self.in_size = in_size
            # Prepend a dummy leading axis for the abstract evaluation, then strip it.
            y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
            self.out_size = y.shape[1:]

    @staticmethod
    def _check_window_arg(value, pool_dim, arg_name):
        """Validate ``kernel_size``/``stride`` and expand an int to ``pool_dim`` entries."""
        if isinstance(value, int):
            return (value,) * pool_dim
        if isinstance(value, Sequence):
            assert isinstance(value, (tuple, list)), f'{arg_name} should be a tuple, but got {type(value)}'
            assert all(isinstance(x, int) for x in value), f'{arg_name} should be a tuple of ints. {value}'
            if len(value) != pool_dim:
                # Report the length of the argument actually being checked.
                raise ValueError(f'{arg_name} should a tuple with {pool_dim} ints, but got {len(value)}')
            return value
        raise TypeError(f'{arg_name} should be a int or a tuple with {pool_dim} ints.')

    @staticmethod
    def _check_padding(padding, pool_dim):
        """Validate ``padding`` and normalize non-string forms to ``(low, high)`` pairs."""
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
            return padding
        if isinstance(padding, int):
            return [(padding, padding) for _ in range(pool_dim)]
        if not isinstance(padding, (list, tuple)):
            raise ValueError(f'padding should be a str, an int or a sequence, but got {type(padding)}.')
        if len(padding) == 0:
            # Guard the ``padding[0]`` lookup below against an empty sequence.
            raise ValueError(f'padding should has the length of {pool_dim}, but got an empty sequence.')
        if isinstance(padding[0], int):
            if len(padding) != pool_dim:
                raise ValueError(f'If padding is a sequence of ints, it '
                                 f'should has the length of {pool_dim}.')
            return [(x, x) for x in padding]
        if not all(isinstance(x, (tuple, list)) for x in padding):
            raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
        if not all(len(x) == 2 for x in padding):
            raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
        if len(padding) == 1:
            # A single (low, high) pair applies to every pooled axis.
            padding = tuple(padding) * pool_dim
        assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
        return padding

    def update(self, x):
        """Apply the windowed reduction to ``x``.

        Raises:
            ValueError: if ``x`` has fewer axes than the pooled axes plus the
                channel axis (when one is configured).
        """
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
        # Expand per-pooled-axis settings to one entry per input axis;
        # non-pooled axes get a window/stride of 1 and no padding.
        window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        padding = (self.padding if isinstance(self.padding, str) else
                   self._infer_shape(x.ndim, self.padding, element=(0, 0)))
        r = jax.lax.reduce_window(
            x,
            init_value=self.init_value,
            computation=self.computation,
            window_dimensions=window_shape,
            window_strides=stride,
            padding=padding
        )
        return r

    def _infer_shape(self, x_dim, inputs, element):
        """Spread ``inputs`` over the pooled axes of an ``x_dim``-axis input.

        Returns a list of length ``x_dim`` holding ``element`` everywhere except
        at the pooled axes, which receive the entries of ``inputs`` in order.

        Raises:
            ValueError: if ``channel_axis`` is out of range for ``x_dim`` axes.
        """
        channel_axis = self.channel_axis
        all_dims = list(range(x_dim))
        if channel_axis is not None:
            # Accept the full valid index range, including ``-x_dim``.
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis
            all_dims.pop(channel_axis)
        pool_dims = all_dims[-self.pool_dim:]
        results = [element] * x_dim
        for i, dim in enumerate(pool_dims):
            results[dim] = inputs[i]
        return results
278
-
279
-
class _AvgPool(_MaxPool):
    """Average pooling built on the windowed reduction configured by ``_MaxPool``."""

    def update(self, x):
        """Return the windowed reduction of ``x`` divided by the window population.

        With ``'VALID'`` padding every window is full, so the divisor is the
        window volume. Otherwise the divisor is obtained by running the same
        reduction over an array of ones shaped like ``x``.
        """
        min_ndim = self.pool_dim + (1 if self.channel_axis is not None else 0)
        if x.ndim < min_ndim:
            raise ValueError(f'Excepted input with >= {min_ndim} dimensions, but got {x.ndim}.')

        window = self._infer_shape(x.ndim, self.kernel_size, 1)
        strides = self._infer_shape(x.ndim, self.stride, 1)
        if isinstance(self.padding, str):
            pad = self.padding
        else:
            pad = self._infer_shape(x.ndim, self.padding, element=(0, 0))

        # Both reductions share every setting except the operand.
        windowed_reduce = functools.partial(
            jax.lax.reduce_window,
            init_value=self.init_value,
            computation=self.computation,
            window_dimensions=window,
            window_strides=strides,
            padding=pad,
        )

        pooled = windowed_reduce(x)
        if pad == "VALID":
            # Avoid the extra reduce_window.
            return pooled / np.prod(window)

        # Count the entries contributing to each window. Assumes that any two
        # arrays of the same shape will be padded the same.
        window_counts = windowed_reduce(jnp.ones_like(x))
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
310
-
311
-
class MaxPool1d(_MaxPool):
    r"""Applies a 1D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
    and output :math:`(N, L_{out}, C)` can be precisely described as:

    .. math::
        out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                input(N_i, stride \times k + m, C_j)

    If explicit :attr:`padding` is given, the input is implicitly padded with negative
    infinity on each side. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where, for explicit
          padding :math:`(p_{low}, p_{high})` (both zero for ``'VALID'``),

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + p_{low} + p_{high}
                    - \text{kernel\_size}}{\text{stride}}\right\rfloor + 1

    Examples::

        >>> import brainstate as brainstate
        >>> # pool of size=3, stride=2
        >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 16)

    Parameters
    ----------
    in_size: Sequence of int
      The shape of the input tensor.
    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride
      (default: ``None``, which uses ``kernel_size``).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: optional, str
      The object name.

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Max reduction: start from -inf so any real input value wins.
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=1,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
386
-
387
-
class MaxPool2d(_MaxPool):
    r"""Applies a 2D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
    output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{input}(N_i, \text{stride[0]} \times h + m,
                                                   \text{stride[1]} \times w + n, C_j)
        \end{aligned}

    If explicit :attr:`padding` is given, the input is implicitly padded with negative
    infinity on each side. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
        - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where,
          for explicit padding :math:`(p_{low}[i], p_{high}[i])` per spatial axis
          (all zero for ``'VALID'``),

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + p_{low}[0] + p_{high}[0]
                    - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + p_{low}[1] + p_{high}[1]
                    - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1

    Examples::

        >>> import brainstate as brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = MaxPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 32, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 31, 16)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    Parameters
    ----------
    in_size: Sequence of int
      The shape of the input tensor.
    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride
      (default: ``None``, which uses ``kernel_size``).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: optional, str
      The object name.

    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Max reduction: start from -inf so any real input value wins.
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
473
-
474
-
class MaxPool3d(_MaxPool):
    r"""Applies a 3D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
    output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            \text{out}(N_i, d, h, w, C_j) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                              & \text{input}(N_i, \text{stride[0]} \times d + k,
                                                             \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
        \end{aligned}

    If explicit :attr:`padding` is given, the input is implicitly padded with negative
    infinity on each side. This `link`_ has a nice visualization of the pooling parameters.

    Shape:
        - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
        - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
          where, for explicit padding :math:`(p_{low}[i], p_{high}[i])` per spatial axis
          (all zero for ``'VALID'``),

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + p_{low}[0] + p_{high}[0]
                - \text{kernel\_size}[0]}{\text{stride}[0]}\right\rfloor + 1

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + p_{low}[1] + p_{high}[1]
                - \text{kernel\_size}[1]}{\text{stride}[1]}\right\rfloor + 1

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + p_{low}[2] + p_{high}[2]
                - \text{kernel\_size}[2]}{\text{stride}[2]}\right\rfloor + 1

    Examples::

        >>> import brainstate as brainstate
        >>> # pool of square window of size=3, stride=2
        >>> m = MaxPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
        >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
        >>> output = m(input)
        >>> output.shape
        (20, 24, 43, 15, 16)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    Parameters
    ----------
    in_size: Sequence of int
      The shape of the input tensor.
    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride
      (default: ``None``, which uses ``kernel_size``).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: optional, str
      The object name.

    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Max reduction: start from -inf so any real input value wins.
        super().__init__(in_size=in_size,
                         init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=3,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name)
564
-
565
-
class AvgPool1d(_AvgPool):
    r"""1D average pooling over an input signal composed of several input planes.

    For an input of size :math:`(N, L, C)`, an output of size :math:`(N, L_{out}, C)`
    and a :attr:`kernel_size` of :math:`k`, the ``'VALID'`` case computes:

    .. math::

        \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
                               \text{input}(N_i, \text{stride} \times l + m, C_j)

    When :attr:`padding` is anything other than ``'VALID'``, the windowed sum is
    divided by the number of input entries that fall inside each window, so
    padded positions do not contribute to the average.

    Shape:
        - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
        - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} +
              2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Examples::

        >>> import brainstate as brainstate
        >>> # pool with window of size=3, stride=2
        >>> m = AvgPool1d(3, stride=2)
        >>> input = brainstate.random.randn(20, 50, 16)
        >>> m(input).shape
        (20, 24, 16)

    Parameters
    ----------
    in_size: Sequence of int
      The shape of the input tensor.
    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride (default: ``1``).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: optional, str
      The object name.

    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        kernel_size: Size,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        in_size: Optional[Size] = None,
    ):
        # Sum reduction starting from zero; the parent ``update`` turns it into a mean.
        super().__init__(
            init_value=0.,
            computation=jax.lax.add,
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            name=name,
            in_size=in_size,
        )
637
-
638
-
639
- class AvgPool2d(_AvgPool):
640
- r"""Applies a 2D average pooling over an input signal composed of several input planes.
641
-
642
- In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
643
- output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
644
- can be precisely described as:
645
-
646
- .. math::
647
-
648
- out(N_i, h, w, C_j) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
649
- input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)
650
-
651
- If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
652
- for :attr:`padding` number of points.
653
-
654
- Shape:
655
- - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
656
- - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
657
-
658
- .. math::
659
- H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
660
- \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
661
-
662
- .. math::
663
- W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
664
- \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
665
-
666
- Examples::
667
-
668
- >>> import brainstate as brainstate
669
- >>> # pool of square window of size=3, stride=2
670
- >>> m = AvgPool2d(3, stride=2)
671
- >>> # pool of non-square window
672
- >>> m = AvgPool2d((3, 2), stride=(2, 1))
673
- >>> input = brainstate.random.randn(20, 50, 32, 16)
674
- >>> output = m(input)
675
- >>> output.shape
676
- (20, 24, 31, 16)
677
-
678
- Parameters
679
- ----------
680
- in_size: Sequence of int
681
- The shape of the input tensor.
682
- kernel_size: int, sequence of int
683
- An integer, or a sequence of integers defining the window to reduce over.
684
- stride: int, sequence of int
685
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
686
- padding: str, int, sequence of tuple
687
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
688
- of n `(low, high)` integer pairs that give the padding to apply before
689
- and after each spatial dimension.
690
- channel_axis: int, optional
691
- Axis of the spatial channels for which pooling is skipped.
692
- If ``None``, there is no channel axis.
693
- name: optional, str
694
- The object name.
695
- """
696
- __module__ = 'brainstate.nn'
697
-
698
- def __init__(
699
- self,
700
- kernel_size: Size,
701
- stride: Union[int, Sequence[int]] = 1,
702
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
703
- channel_axis: Optional[int] = -1,
704
- name: Optional[str] = None,
705
- in_size: Optional[Size] = None,
706
- ):
707
- super().__init__(in_size=in_size,
708
- init_value=0.,
709
- computation=jax.lax.add,
710
- pool_dim=2,
711
- kernel_size=kernel_size,
712
- stride=stride,
713
- padding=padding,
714
- channel_axis=channel_axis,
715
- name=name)
716
-
717
-
718
- class AvgPool3d(_AvgPool):
719
- r"""Applies a 3D average pooling over an input signal composed of several input planes.
720
-
721
-
722
- In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
723
- output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
724
- can be precisely described as:
725
-
726
- .. math::
727
- \begin{aligned}
728
- \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
729
- & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
730
- \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
731
- {kD \times kH \times kW}
732
- \end{aligned}
733
-
734
- If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
735
- for :attr:`padding` number of points.
736
-
737
- Shape:
738
- - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
739
- - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
740
- :math:`(D_{out}, H_{out}, W_{out}, C)`, where
741
-
742
- .. math::
743
- D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
744
- \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
745
-
746
- .. math::
747
- H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
748
- \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
749
-
750
- .. math::
751
- W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
752
- \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
753
-
754
- Examples::
755
-
756
- >>> import brainstate as brainstate
757
- >>> # pool of square window of size=3, stride=2
758
- >>> m = AvgPool3d(3, stride=2)
759
- >>> # pool of non-square window
760
- >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
761
- >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
762
- >>> output = m(input)
763
- >>> output.shape
764
- (20, 24, 43, 15, 16)
765
-
766
- Parameters
767
- ----------
768
- in_size: Sequence of int
769
- The shape of the input tensor.
770
- kernel_size: int, sequence of int
771
- An integer, or a sequence of integers defining the window to reduce over.
772
- stride: int, sequence of int
773
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
774
- padding: str, int, sequence of tuple
775
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
776
- of n `(low, high)` integer pairs that give the padding to apply before
777
- and after each spatial dimension.
778
- channel_axis: int, optional
779
- Axis of the spatial channels for which pooling is skipped.
780
- If ``None``, there is no channel axis.
781
- name: optional, str
782
- The object name.
783
-
784
- """
785
- __module__ = 'brainstate.nn'
786
-
787
- def __init__(
788
- self,
789
- kernel_size: Size,
790
- stride: Union[int, Sequence[int]] = 1,
791
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
792
- channel_axis: Optional[int] = -1,
793
- name: Optional[str] = None,
794
- in_size: Optional[Size] = None,
795
- ):
796
- super().__init__(in_size=in_size,
797
- init_value=0.,
798
- computation=jax.lax.add,
799
- pool_dim=3,
800
- kernel_size=kernel_size,
801
- stride=stride,
802
- padding=padding,
803
- channel_axis=channel_axis,
804
- name=name)
805
-
806
-
807
- def _adaptive_pool1d(x, target_size: int, operation: Callable):
808
- """Adaptive pool 1D.
809
-
810
- Args:
811
- x: The input. Should be a JAX array of shape `(dim,)`.
812
- target_size: The shape of the output after the pooling operation `(target_size,)`.
813
- operation: The pooling operation to be performed on the input array.
814
-
815
- Returns:
816
- A JAX array of shape `(target_size, )`.
817
- """
818
- size = jnp.size(x)
819
- num_head_arrays = size % target_size
820
- num_block = size // target_size
821
- if num_head_arrays != 0:
822
- head_end_index = num_head_arrays * (num_block + 1)
823
- heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
824
- tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
825
- outs = jnp.concatenate([heads, tails])
826
- else:
827
- outs = jax.vmap(operation)(x.reshape(-1, num_block))
828
- return outs
829
-
830
-
831
- def _generate_vmap(fun: Callable, map_axes: List[int]):
832
- map_axes = sorted(map_axes)
833
- for axis in map_axes:
834
- fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
835
- return fun
836
-
837
-
838
- class _AdaptivePool(Module):
839
- """General N dimensional adaptive down-sampling to a target shape.
840
-
841
- Parameters
842
- ----------
843
- in_size: Sequence of int
844
- The shape of the input tensor.
845
- target_size: int, sequence of int
846
- The target output shape.
847
- num_spatial_dims: int
848
- The number of spatial dimensions.
849
- channel_axis: int, optional
850
- Axis of the spatial channels for which pooling is skipped.
851
- If ``None``, there is no channel axis.
852
- operation: Callable
853
- The down-sampling operation.
854
- name: str
855
- The class name.
856
- """
857
-
858
- def __init__(
859
- self,
860
- in_size: Size,
861
- target_size: Size,
862
- num_spatial_dims: int,
863
- operation: Callable,
864
- channel_axis: Optional[int] = -1,
865
- name: Optional[str] = None,
866
- ):
867
- super().__init__(name=name)
868
-
869
- self.channel_axis = channel_axis
870
- self.operation = operation
871
- if isinstance(target_size, int):
872
- self.target_shape = (target_size,) * num_spatial_dims
873
- elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
874
- self.target_shape = target_size
875
- else:
876
- raise ValueError("`target_size` must either be an int or tuple of length "
877
- f"{num_spatial_dims} containing ints.")
878
-
879
- # in & out shapes
880
- if in_size is not None:
881
- in_size = tuple(in_size)
882
- self.in_size = in_size
883
- y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
884
- self.out_size = y.shape[1:]
885
-
886
- def update(self, x):
887
- """Input-output mapping.
888
-
889
- Parameters
890
- ----------
891
- x: Array
892
- Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
893
- or `(..., dim_1, dim_2)`.
894
- """
895
- # channel axis
896
- channel_axis = self.channel_axis
897
-
898
- if channel_axis:
899
- if not 0 <= abs(channel_axis) < x.ndim:
900
- raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
901
- if channel_axis < 0:
902
- channel_axis = x.ndim + channel_axis
903
- # input dimension
904
- if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
905
- raise ValueError(f"Invalid input dimension. Except >={len(self.target_shape)} "
906
- f"dimensions (channel_axis={self.channel_axis}). "
907
- f"But got {x.ndim} dimensions.")
908
- # pooling dimensions
909
- pool_dims = list(range(x.ndim))
910
- if channel_axis:
911
- pool_dims.pop(channel_axis)
912
-
913
- # pooling
914
- for i, di in enumerate(pool_dims[-len(self.target_shape):]):
915
- poo_axes = [j for j in range(x.ndim) if j != di]
916
- op = _generate_vmap(_adaptive_pool1d, poo_axes)
917
- x = op(x, self.target_shape[i], self.operation)
918
- return x
919
-
920
-
921
- class AdaptiveAvgPool1d(_AdaptivePool):
922
- r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
923
-
924
- The output size is :math:`L_{out}`, for any input size.
925
- The number of output features is equal to the number of input planes.
926
-
927
- Shape:
928
- - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
929
- - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
930
- :math:`L_{out}=\text{output\_size}`.
931
-
932
- Examples:
933
-
934
- >>> import brainstate as brainstate
935
- >>> # target output size of 5
936
- >>> m = AdaptiveAvgPool1d(5)
937
- >>> input = brainstate.random.randn(1, 64, 8)
938
- >>> output = m(input)
939
- >>> output.shape
940
- (1, 5, 8)
941
-
942
- Parameters
943
- ----------
944
- in_size: Sequence of int
945
- The shape of the input tensor.
946
- target_size: int, sequence of int
947
- The target output shape.
948
- channel_axis: int, optional
949
- Axis of the spatial channels for which pooling is skipped.
950
- If ``None``, there is no channel axis.
951
- name: str
952
- The class name.
953
- """
954
- __module__ = 'brainstate.nn'
955
-
956
- def __init__(self,
957
- target_size: Union[int, Sequence[int]],
958
- channel_axis: Optional[int] = -1,
959
- name: Optional[str] = None,
960
- in_size: Optional[Sequence[int]] = None, ):
961
- super().__init__(in_size=in_size,
962
- target_size=target_size,
963
- channel_axis=channel_axis,
964
- num_spatial_dims=1,
965
- operation=jnp.mean,
966
- name=name)
967
-
968
-
969
- class AdaptiveAvgPool2d(_AdaptivePool):
970
- r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
971
-
972
- The output is of size :math:`H_{out} \times W_{out}`, for any input size.
973
- The number of output features is equal to the number of input planes.
974
-
975
- Shape:
976
- - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
977
- - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
978
- :math:`(H_{out}, W_{out})=\text{output\_size}`.
979
-
980
- Examples:
981
-
982
- >>> import brainstate as brainstate
983
- >>> # target output size of 5x7
984
- >>> m = AdaptiveAvgPool2d((5, 7))
985
- >>> input = brainstate.random.randn(1, 8, 9, 64)
986
- >>> output = m(input)
987
- >>> output.shape
988
- (1, 5, 7, 64)
989
- >>> # target output size of 7x7 (square)
990
- >>> m = AdaptiveAvgPool2d(7)
991
- >>> input = brainstate.random.randn(1, 10, 9, 64)
992
- >>> output = m(input)
993
- >>> output.shape
994
- (1, 7, 7, 64)
995
- >>> # target output size of 10x7
996
- >>> m = AdaptiveAvgPool2d((None, 7))
997
- >>> input = brainstate.random.randn(1, 10, 9, 64)
998
- >>> output = m(input)
999
- >>> output.shape
1000
- (1, 10, 7, 64)
1001
-
1002
- Parameters
1003
- ----------
1004
- in_size: Sequence of int
1005
- The shape of the input tensor.
1006
- target_size: int, sequence of int
1007
- The target output shape.
1008
- channel_axis: int, optional
1009
- Axis of the spatial channels for which pooling is skipped.
1010
- If ``None``, there is no channel axis.
1011
- name: str
1012
- The class name.
1013
- """
1014
- __module__ = 'brainstate.nn'
1015
-
1016
- def __init__(self,
1017
- target_size: Union[int, Sequence[int]],
1018
- channel_axis: Optional[int] = -1,
1019
- name: Optional[str] = None,
1020
-
1021
- in_size: Optional[Sequence[int]] = None, ):
1022
- super().__init__(in_size=in_size,
1023
- target_size=target_size,
1024
- channel_axis=channel_axis,
1025
- num_spatial_dims=2,
1026
- operation=jnp.mean,
1027
- name=name)
1028
-
1029
-
1030
- class AdaptiveAvgPool3d(_AdaptivePool):
1031
- r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
1032
-
1033
- The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
1034
- The number of output features is equal to the number of input planes.
1035
-
1036
- Shape:
1037
- - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
1038
- - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
1039
- where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`.
1040
-
1041
- Examples:
1042
-
1043
- >>> import brainstate as brainstate
1044
- >>> # target output size of 5x7x9
1045
- >>> m = AdaptiveAvgPool3d((5, 7, 9))
1046
- >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
1047
- >>> output = m(input)
1048
- >>> output.shape
1049
- (1, 5, 7, 9, 64)
1050
- >>> # target output size of 7x7x7 (cube)
1051
- >>> m = AdaptiveAvgPool3d(7)
1052
- >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1053
- >>> output = m(input)
1054
- >>> output.shape
1055
- (1, 7, 7, 7, 64)
1056
- >>> # target output size of 7x9x8
1057
- >>> m = AdaptiveAvgPool3d((7, None, None))
1058
- >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1059
- >>> output = m(input)
1060
- >>> output.shape
1061
- (1, 7, 9, 8, 64)
1062
-
1063
- Parameters
1064
- ----------
1065
- in_size: Sequence of int
1066
- The shape of the input tensor.
1067
- target_size: int, sequence of int
1068
- The target output shape.
1069
- channel_axis: int, optional
1070
- Axis of the spatial channels for which pooling is skipped.
1071
- If ``None``, there is no channel axis.
1072
- name: str
1073
- The class name.
1074
- """
1075
- __module__ = 'brainstate.nn'
1076
-
1077
- def __init__(self,
1078
- target_size: Union[int, Sequence[int]],
1079
- channel_axis: Optional[int] = -1,
1080
- name: Optional[str] = None,
1081
- in_size: Optional[Sequence[int]] = None, ):
1082
- super().__init__(in_size=in_size,
1083
- target_size=target_size,
1084
- channel_axis=channel_axis,
1085
- num_spatial_dims=3,
1086
- operation=jnp.mean,
1087
- name=name)
1088
-
1089
-
1090
- class AdaptiveMaxPool1d(_AdaptivePool):
1091
- """Adaptive one-dimensional maximum down-sampling.
1092
-
1093
- Parameters
1094
- ----------
1095
- in_size: Sequence of int
1096
- The shape of the input tensor.
1097
- target_size: int, sequence of int
1098
- The target output shape.
1099
- channel_axis: int, optional
1100
- Axis of the spatial channels for which pooling is skipped.
1101
- If ``None``, there is no channel axis.
1102
- name: str
1103
- The class name.
1104
- """
1105
- __module__ = 'brainstate.nn'
1106
-
1107
- def __init__(self,
1108
- target_size: Union[int, Sequence[int]],
1109
- channel_axis: Optional[int] = -1,
1110
- name: Optional[str] = None,
1111
- in_size: Optional[Sequence[int]] = None, ):
1112
- super().__init__(in_size=in_size,
1113
- target_size=target_size,
1114
- channel_axis=channel_axis,
1115
- num_spatial_dims=1,
1116
- operation=jnp.max,
1117
- name=name)
1118
-
1119
-
1120
- class AdaptiveMaxPool2d(_AdaptivePool):
1121
- """Adaptive two-dimensional maximum down-sampling.
1122
-
1123
- Parameters
1124
- ----------
1125
- in_size: Sequence of int
1126
- The shape of the input tensor.
1127
- target_size: int, sequence of int
1128
- The target output shape.
1129
- channel_axis: int, optional
1130
- Axis of the spatial channels for which pooling is skipped.
1131
- If ``None``, there is no channel axis.
1132
- name: str
1133
- The class name.
1134
- """
1135
- __module__ = 'brainstate.nn'
1136
-
1137
- def __init__(self,
1138
- target_size: Union[int, Sequence[int]],
1139
- channel_axis: Optional[int] = -1,
1140
- name: Optional[str] = None,
1141
- in_size: Optional[Sequence[int]] = None, ):
1142
- super().__init__(in_size=in_size,
1143
- target_size=target_size,
1144
- channel_axis=channel_axis,
1145
- num_spatial_dims=2,
1146
- operation=jnp.max,
1147
- name=name)
1148
-
1149
-
1150
- class AdaptiveMaxPool3d(_AdaptivePool):
1151
- """Adaptive three-dimensional maximum down-sampling.
1152
-
1153
- Parameters
1154
- ----------
1155
- in_size: Sequence of int
1156
- The shape of the input tensor.
1157
- target_size: int, sequence of int
1158
- The target output shape.
1159
- channel_axis: int, optional
1160
- Axis of the spatial channels for which pooling is skipped.
1161
- If ``None``, there is no channel axis.
1162
- name: str
1163
- The class name.
1164
- """
1165
- __module__ = 'brainstate.nn'
1166
-
1167
- def __init__(self,
1168
- target_size: Union[int, Sequence[int]],
1169
- channel_axis: Optional[int] = -1,
1170
- name: Optional[str] = None,
1171
- in_size: Optional[Sequence[int]] = None, ):
1172
- super().__init__(in_size=in_size,
1173
- target_size=target_size,
1174
- channel_axis=channel_axis,
1175
- num_spatial_dims=3,
1176
- operation=jnp.max,
1177
- name=name)
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ import functools
19
+ from typing import Sequence, Optional
20
+ from typing import Union, Tuple, Callable, List
21
+
22
+ import brainunit as u
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+ from brainstate import environ
28
+ from brainstate.typing import Size
29
+ from ._module import Module
30
+
31
+ __all__ = [
32
+ 'Flatten', 'Unflatten',
33
+
34
+ 'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
35
+ 'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
36
+
37
+ 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
38
+ 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
39
+ ]
40
+
41
+
42
+ class Flatten(Module):
43
+ r"""
44
+ Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
45
+
46
+ Shape:
47
+ - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
48
+ where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
49
+ number of dimensions including none.
50
+ - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
51
+
52
+ Args:
53
+ in_size: Sequence of int. The shape of the input tensor.
54
+ start_axis: first dim to flatten (default = 0).
55
+ end_axis: last dim to flatten (default = -1).
56
+
57
+ Examples::
58
+ >>> import brainstate as brainstate
59
+ >>> inp = brainstate.random.randn(32, 1, 5, 5)
60
+ >>> # Keep the leading (batch) axis and flatten the rest
61
+ >>> m = Flatten(start_axis=1)
62
+ >>> output = m(inp)
63
+ >>> output.shape
64
+ (32, 25)
65
+ >>> # With non-default parameters
66
+ >>> m = Flatten(0, 2)
67
+ >>> output = m(inp)
68
+ >>> output.shape
69
+ (160, 5)
70
+ """
71
+ __module__ = 'brainstate.nn'
72
+
73
+ def __init__(
74
+ self,
75
+ start_axis: int = 0,
76
+ end_axis: int = -1,
77
+ in_size: Optional[Size] = None
78
+ ) -> None:
79
+ super().__init__()
80
+ self.start_axis = start_axis
81
+ self.end_axis = end_axis
82
+
83
+ if in_size is not None:
84
+ self.in_size = tuple(in_size)
85
+ y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
86
+ jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
87
+ self.out_size = y.shape
88
+
89
+ def update(self, x):
90
+ if self._in_size is None:
91
+ start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
92
+ else:
93
+ assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
94
+ dim_diff = x.ndim - len(self.in_size)
95
+ if self.in_size != x.shape[dim_diff:]:
96
+ raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
97
+ if self.start_axis >= 0:
98
+ start_axis = self.start_axis + dim_diff
99
+ else:
100
+ start_axis = x.ndim + self.start_axis
101
+ return u.math.flatten(x, start_axis, self.end_axis)
102
+
103
+ def __repr__(self) -> str:
104
+ return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
105
+
106
+
107
+ class Unflatten(Module):
108
+ r"""
109
+ Unflatten a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.
110
+
111
+ * :attr:`axis` specifies the dimension of the input tensor to be unflattened, and it
112
+ is given as an `int`.
113
+
114
+ * :attr:`sizes` is the new shape of the unflattened dimension of the tensor and it must be
115
+ a `tuple` of ints or a `list` of ints; any other type (or a non-int element) raises
116
+ a `TypeError`.
117
+
118
+ Shape:
119
+ - Input: :math:`(*, S_{\text{axis}}, *)`, where :math:`S_{\text{axis}}` is the size at
120
+ dimension :attr:`axis` and :math:`*` means any number of dimensions including none.
121
+ - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
122
+ :math:`\prod_{i=1}^n U_i = S_{\text{axis}}`.
123
+
124
+ Args:
125
+ axis: int, Dimension to be unflattened.
126
+ sizes: Sequence of int. New shape of the unflattened dimension.
127
+ in_size: Sequence of int. The shape of the input tensor.
128
+ """
129
+ __module__ = 'brainstate.nn'
130
+
131
+ def __init__(
132
+ self,
133
+ axis: int,
134
+ sizes: Size,
135
+ name: str = None,
136
+ in_size: Optional[Size] = None
137
+ ) -> None:
138
+ super().__init__(name=name)
139
+
140
+ self.axis = axis
141
+ self.sizes = sizes
142
+ if isinstance(sizes, (tuple, list)):
143
+ for idx, elem in enumerate(sizes):
144
+ if not isinstance(elem, int):
145
+ raise TypeError("unflattened sizes must be tuple of ints, " +
146
+ "but found element of type {} at pos {}".format(type(elem).__name__, idx))
147
+ else:
148
+ raise TypeError("unflattened sizes must be tuple or list, but found type {}".format(type(sizes).__name__))
149
+
150
+ if in_size is not None:
151
+ self.in_size = tuple(in_size)
152
+ y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
153
+ jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
154
+ self.out_size = y.shape
155
+
156
+ def update(self, x):
157
+ return u.math.unflatten(x, self.axis, self.sizes)
158
+
159
+ def __repr__(self):
160
+ return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
161
+
162
+
163
+ class _MaxPool(Module):
164
+ def __init__(
165
+ self,
166
+ init_value: float,
167
+ computation: Callable,
168
+ pool_dim: int,
169
+ kernel_size: Size,
170
+ stride: Union[int, Sequence[int]] = None,
171
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
172
+ channel_axis: Optional[int] = -1,
173
+ name: Optional[str] = None,
174
+ in_size: Optional[Size] = None,
175
+ ):
176
+ super().__init__(name=name)
177
+
178
+ self.init_value = init_value
179
+ self.computation = computation
180
+ self.pool_dim = pool_dim
181
+
182
+ # kernel_size
183
+ if isinstance(kernel_size, int):
184
+ kernel_size = (kernel_size,) * pool_dim
185
+ elif isinstance(kernel_size, Sequence):
186
+ assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
187
+ assert all(
188
+ [isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
189
+ if len(kernel_size) != pool_dim:
190
+ raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
191
+ else:
192
+ raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
193
+ self.kernel_size = kernel_size
194
+
195
+ # stride
196
+ if stride is None:
197
+ stride = kernel_size
198
+ if isinstance(stride, int):
199
+ stride = (stride,) * pool_dim
200
+ elif isinstance(stride, Sequence):
201
+ assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
202
+ assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
203
+ if len(stride) != pool_dim:
204
+ raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
205
+ else:
206
+ raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
207
+ self.stride = stride
208
+
209
+ # padding
210
+ if isinstance(padding, str):
211
+ if padding not in ("SAME", "VALID"):
212
+ raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
213
+ elif isinstance(padding, int):
214
+ padding = [(padding, padding) for _ in range(pool_dim)]
215
+ elif isinstance(padding, (list, tuple)):
216
+ if isinstance(padding[0], int):
217
+ if len(padding) == pool_dim:
218
+ padding = [(x, x) for x in padding]
219
+ else:
220
+ raise ValueError(f'If padding is a sequence of ints, it '
221
+ f'should has the length of {pool_dim}.')
222
+ else:
223
+ if not all([isinstance(x, (tuple, list)) for x in padding]):
224
+ raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
225
+ if not all([len(x) == 2 for x in padding]):
226
+ raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
227
+ if len(padding) == 1:
228
+ padding = tuple(padding) * pool_dim
229
+ assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
230
+ else:
231
+ raise ValueError
232
+ self.padding = padding
233
+
234
+ # channel_axis
235
+ assert channel_axis is None or isinstance(channel_axis, int), \
236
+ f'channel_axis should be an int, but got {channel_axis}'
237
+ self.channel_axis = channel_axis
238
+
239
+ # in & out shapes
240
+ if in_size is not None:
241
+ in_size = tuple(in_size)
242
+ self.in_size = in_size
243
+ y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
244
+ self.out_size = y.shape[1:]
245
+
246
+ def update(self, x):
247
+ x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
248
+ if x.ndim < x_dim:
249
+ raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
250
+ window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
251
+ stride = self._infer_shape(x.ndim, self.stride, 1)
252
+ padding = (self.padding if isinstance(self.padding, str) else
253
+ self._infer_shape(x.ndim, self.padding, element=(0, 0)))
254
+ r = jax.lax.reduce_window(
255
+ x,
256
+ init_value=self.init_value,
257
+ computation=self.computation,
258
+ window_dimensions=window_shape,
259
+ window_strides=stride,
260
+ padding=padding
261
+ )
262
+ return r
263
+
264
+ def _infer_shape(self, x_dim, inputs, element):
265
+ channel_axis = self.channel_axis
266
+ if channel_axis and not 0 <= abs(channel_axis) < x_dim:
267
+ raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
268
+ if channel_axis and channel_axis < 0:
269
+ channel_axis = x_dim + channel_axis
270
+ all_dims = list(range(x_dim))
271
+ if channel_axis is not None:
272
+ all_dims.pop(channel_axis)
273
+ pool_dims = all_dims[-self.pool_dim:]
274
+ results = [element] * x_dim
275
+ for i, dim in enumerate(pool_dims):
276
+ results[dim] = inputs[i]
277
+ return results
278
+
279
+
280
+ class _AvgPool(_MaxPool):
281
+ def update(self, x):
282
+ x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
283
+ if x.ndim < x_dim:
284
+ raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
285
+ dims = self._infer_shape(x.ndim, self.kernel_size, 1)
286
+ stride = self._infer_shape(x.ndim, self.stride, 1)
287
+ padding = (self.padding if isinstance(self.padding, str) else
288
+ self._infer_shape(x.ndim, self.padding, element=(0, 0)))
289
+ pooled = jax.lax.reduce_window(x,
290
+ init_value=self.init_value,
291
+ computation=self.computation,
292
+ window_dimensions=dims,
293
+ window_strides=stride,
294
+ padding=padding)
295
+ if padding == "VALID":
296
+ # Avoid the extra reduce_window.
297
+ return pooled / np.prod(dims)
298
+ else:
299
+ # Count the number of valid entries at each input point, then use that for
300
+ # computing average. Assumes that any two arrays of same shape will be
301
+ # padded the same.
302
+ window_counts = jax.lax.reduce_window(jnp.ones_like(x),
303
+ init_value=self.init_value,
304
+ computation=self.computation,
305
+ window_dimensions=dims,
306
+ window_strides=stride,
307
+ padding=padding)
308
+ assert pooled.shape == window_counts.shape
309
+ return pooled / window_counts
310
+
311
+
312
+ class MaxPool1d(_MaxPool):
313
+ r"""Applies a 1D max pooling over an input signal composed of several input planes.
314
+
315
+ In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
316
+ and output :math:`(N, L_{out}, C)` can be precisely described as:
317
+
318
+ .. math::
319
+ out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
320
+ input(N_i, stride \times k + m, C_j)
321
+
322
+ If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
323
+ for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
324
+ sliding window. This `link`_ has a nice visualization of the pooling parameters.
325
+
326
+ Shape:
327
+ - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
328
+ - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
329
+
330
+ .. math::
331
+ L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
332
+ \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
333
+
334
+
335
+ Examples::
336
+
337
+ >>> import brainstate as brainstate
338
+ >>> # pool of size=3, stride=2
339
+ >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
340
+ >>> input = brainstate.random.randn(20, 50, 16)
341
+ >>> output = m(input)
342
+ >>> output.shape
343
+ (20, 24, 16)
344
+
345
+ Parameters
346
+ ----------
347
+ in_size: Sequence of int
348
+ The shape of the input tensor.
349
+ kernel_size: int, sequence of int
350
+ An integer, or a sequence of integers defining the window to reduce over.
351
+ stride: int, sequence of int
352
+ An integer, or a sequence of integers, representing the inter-window stride (default: `kernel_size`).
353
+ padding: str, int, sequence of tuple
354
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
355
+ of n `(low, high)` integer pairs that give the padding to apply before
356
+ and after each spatial dimension.
357
+ channel_axis: int, optional
358
+ Axis of the spatial channels for which pooling is skipped.
359
+ If ``None``, there is no channel axis.
360
+ name: optional, str
361
+ The object name.
362
+
363
+ .. _link:
364
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
365
+ """
366
+ __module__ = 'brainstate.nn'
367
+
368
+ def __init__(
369
+ self,
370
+ kernel_size: Size,
371
+ stride: Union[int, Sequence[int]] = None,
372
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
373
+ channel_axis: Optional[int] = -1,
374
+ name: Optional[str] = None,
375
+ in_size: Optional[Size] = None,
376
+ ):
377
+ super().__init__(in_size=in_size,
378
+ init_value=-jax.numpy.inf,
379
+ computation=jax.lax.max,
380
+ pool_dim=1,
381
+ kernel_size=kernel_size,
382
+ stride=stride,
383
+ padding=padding,
384
+ channel_axis=channel_axis,
385
+ name=name)
386
+
387
+
388
+ class MaxPool2d(_MaxPool):
389
+ r"""Applies a 2D max pooling over an input signal composed of several input planes.
390
+
391
+ In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
392
+ output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
393
+ can be precisely described as:
394
+
395
+ .. math::
396
+ \begin{aligned}
397
+ out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
398
+ & \text{input}(N_i, \text{stride[0]} \times h + m,
399
+ \text{stride[1]} \times w + n, C_j)
400
+ \end{aligned}
401
+
402
+ If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
403
+ for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
404
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
405
+
406
+
407
+ Shape:
408
+ - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
409
+ - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
410
+
411
+ .. math::
412
+ H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
413
+ \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor
414
+
415
+ .. math::
416
+ W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
417
+ \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
418
+
419
+ Examples::
420
+
421
+ >>> import brainstate as brainstate
422
+ >>> # pool of square window of size=3, stride=2
423
+ >>> m = MaxPool2d(3, stride=2)
424
+ >>> # pool of non-square window
425
+ >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
426
+ >>> input = brainstate.random.randn(20, 50, 32, 16)
427
+ >>> output = m(input)
428
+ >>> output.shape
429
+ (20, 24, 31, 16)
430
+
431
+ .. _link:
432
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
433
+
434
+ Parameters
435
+ ----------
436
+ in_size: Sequence of int
437
+ The shape of the input tensor.
438
+ kernel_size: int, sequence of int
439
+ An integer, or a sequence of integers defining the window to reduce over.
440
+ stride: int, sequence of int
441
+ An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
442
+ padding: str, int, sequence of tuple
443
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
444
+ of n `(low, high)` integer pairs that give the padding to apply before
445
+ and after each spatial dimension.
446
+ channel_axis: int, optional
447
+ Axis of the spatial channels for which pooling is skipped.
448
+ If ``None``, there is no channel axis.
449
+ name: optional, str
450
+ The object name.
451
+
452
+ """
453
+ __module__ = 'brainstate.nn'
454
+
455
+ def __init__(
456
+ self,
457
+ kernel_size: Size,
458
+ stride: Union[int, Sequence[int]] = None,
459
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
460
+ channel_axis: Optional[int] = -1,
461
+ name: Optional[str] = None,
462
+ in_size: Optional[Size] = None,
463
+ ):
464
+ super().__init__(in_size=in_size,
465
+ init_value=-jax.numpy.inf,
466
+ computation=jax.lax.max,
467
+ pool_dim=2,
468
+ kernel_size=kernel_size,
469
+ stride=stride,
470
+ padding=padding,
471
+ channel_axis=channel_axis,
472
+ name=name)
473
+
474
+
475
+ class MaxPool3d(_MaxPool):
476
+ r"""Applies a 3D max pooling over an input signal composed of several input planes.
477
+
478
+ In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
479
+ output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
480
+ can be precisely described as:
481
+
482
+ .. math::
483
+ \begin{aligned}
484
+ \text{out}(N_i, d, h, w, C_j) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
485
+ & \text{input}(N_i, \text{stride[0]} \times d + k,
486
+ \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
487
+ \end{aligned}
488
+
489
+ If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
490
+ for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
491
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
492
+
493
+
494
+ Shape:
495
+ - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
496
+ - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
497
+
498
+ .. math::
499
+ D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
500
+ (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
501
+
502
+ .. math::
503
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
504
+ (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
505
+
506
+ .. math::
507
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
508
+ (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
509
+
510
+ Examples::
511
+
512
+ >>> import brainstate as brainstate
513
+ >>> # pool of square window of size=3, stride=2
514
+ >>> m = MaxPool3d(3, stride=2)
515
+ >>> # pool of non-square window
516
+ >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
517
+ >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
518
+ >>> output = m(input)
519
+ >>> output.shape
520
+ (20, 24, 43, 15, 16)
521
+
522
+ .. _link:
523
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
524
+
525
+ Parameters
526
+ ----------
527
+ in_size: Sequence of int
528
+ The shape of the input tensor.
529
+ kernel_size: int, sequence of int
530
+ An integer, or a sequence of integers defining the window to reduce over.
531
+ stride: int, sequence of int
532
+ An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
533
+ padding: str, int, sequence of tuple
534
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
535
+ of n `(low, high)` integer pairs that give the padding to apply before
536
+ and after each spatial dimension.
537
+ channel_axis: int, optional
538
+ Axis of the spatial channels for which pooling is skipped.
539
+ If ``None``, there is no channel axis.
540
+ name: optional, str
541
+ The object name.
542
+
543
+ """
544
+ __module__ = 'brainstate.nn'
545
+
546
+ def __init__(
547
+ self,
548
+ kernel_size: Size,
549
+ stride: Union[int, Sequence[int]] = None,
550
+ padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
551
+ channel_axis: Optional[int] = -1,
552
+ name: Optional[str] = None,
553
+ in_size: Optional[Size] = None,
554
+ ):
555
+ super().__init__(in_size=in_size,
556
+ init_value=-jax.numpy.inf,
557
+ computation=jax.lax.max,
558
+ pool_dim=3,
559
+ kernel_size=kernel_size,
560
+ stride=stride,
561
+ padding=padding,
562
+ channel_axis=channel_axis,
563
+ name=name)
564
+
565
+
566
+ class AvgPool1d(_AvgPool):
567
+ r"""Applies a 1D average pooling over an input signal composed of several input planes.
568
+
569
+ In the simplest case, the output value of the layer with input size :math:`(N, L, C)`,
570
+ output :math:`(N, L_{out}, C)` and :attr:`kernel_size` :math:`k`
571
+ can be precisely described as:
572
+
573
+ .. math::
574
+
575
+ \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
576
+ \text{input}(N_i, \text{stride} \times l + m, C_j)
577
+
578
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
579
+ for :attr:`padding` number of points.
580
+
581
+ Shape:
582
+ - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
583
+ - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
584
+
585
+ .. math::
586
+ L_{out} = \left\lfloor \frac{L_{in} +
587
+ 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
588
+
589
+ Examples::
590
+
591
+ >>> import brainstate as brainstate
592
+ >>> # pool with window of size=3, stride=2
593
+ >>> m = AvgPool1d(3, stride=2)
594
+ >>> input = brainstate.random.randn(20, 50, 16)
595
+ >>> m(input).shape
596
+ (20, 24, 16)
597
+
598
+ Parameters
599
+ ----------
600
+ in_size: Sequence of int
601
+ The shape of the input tensor.
602
+ kernel_size: int, sequence of int
603
+ An integer, or a sequence of integers defining the window to reduce over.
604
+ stride: int, sequence of int
605
+ An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
606
+ padding: str, int, sequence of tuple
607
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
608
+ of n `(low, high)` integer pairs that give the padding to apply before
609
+ and after each spatial dimension.
610
+ channel_axis: int, optional
611
+ Axis of the spatial channels for which pooling is skipped.
612
+ If ``None``, there is no channel axis.
613
+ name: optional, str
614
+ The object name.
615
+
616
+ """
617
+ __module__ = 'brainstate.nn'
618
+
619
+ def __init__(
620
+ self,
621
+ kernel_size: Size,
622
+ stride: Union[int, Sequence[int]] = 1,
623
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
624
+ channel_axis: Optional[int] = -1,
625
+ name: Optional[str] = None,
626
+ in_size: Optional[Size] = None,
627
+ ):
628
+ super().__init__(in_size=in_size,
629
+ init_value=0.,
630
+ computation=jax.lax.add,
631
+ pool_dim=1,
632
+ kernel_size=kernel_size,
633
+ stride=stride,
634
+ padding=padding,
635
+ channel_axis=channel_axis,
636
+ name=name)
637
+
638
+
639
+ class AvgPool2d(_AvgPool):
640
+ r"""Applies a 2D average pooling over an input signal composed of several input planes.
641
+
642
+ In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
643
+ output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
644
+ can be precisely described as:
645
+
646
+ .. math::
647
+
648
+ out(N_i, h, w, C_j) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
649
+ input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)
650
+
651
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
652
+ for :attr:`padding` number of points.
653
+
654
+ Shape:
655
+ - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
656
+ - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
657
+
658
+ .. math::
659
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
660
+ \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
661
+
662
+ .. math::
663
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
664
+ \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
665
+
666
+ Examples::
667
+
668
+ >>> import brainstate as brainstate
669
+ >>> # pool of square window of size=3, stride=2
670
+ >>> m = AvgPool2d(3, stride=2)
671
+ >>> # pool of non-square window
672
+ >>> m = AvgPool2d((3, 2), stride=(2, 1))
673
+ >>> input = brainstate.random.randn(20, 50, 32, , 16)
674
+ >>> output = m(input)
675
+ >>> output.shape
676
+ (20, 24, 31, 16)
677
+
678
+ Parameters
679
+ ----------
680
+ in_size: Sequence of int
681
+ The shape of the input tensor.
682
+ kernel_size: int, sequence of int
683
+ An integer, or a sequence of integers defining the window to reduce over.
684
+ stride: int, sequence of int
685
+ An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
686
+ padding: str, int, sequence of tuple
687
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
688
+ of n `(low, high)` integer pairs that give the padding to apply before
689
+ and after each spatial dimension.
690
+ channel_axis: int, optional
691
+ Axis of the spatial channels for which pooling is skipped.
692
+ If ``None``, there is no channel axis.
693
+ name: optional, str
694
+ The object name.
695
+ """
696
+ __module__ = 'brainstate.nn'
697
+
698
+ def __init__(
699
+ self,
700
+ kernel_size: Size,
701
+ stride: Union[int, Sequence[int]] = 1,
702
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
703
+ channel_axis: Optional[int] = -1,
704
+ name: Optional[str] = None,
705
+ in_size: Optional[Size] = None,
706
+ ):
707
+ super().__init__(in_size=in_size,
708
+ init_value=0.,
709
+ computation=jax.lax.add,
710
+ pool_dim=2,
711
+ kernel_size=kernel_size,
712
+ stride=stride,
713
+ padding=padding,
714
+ channel_axis=channel_axis,
715
+ name=name)
716
+
717
+
718
+ class AvgPool3d(_AvgPool):
719
+ r"""Applies a 3D average pooling over an input signal composed of several input planes.
720
+
721
+
722
+ In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
723
+ output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
724
+ can be precisely described as:
725
+
726
+ .. math::
727
+ \begin{aligned}
728
+ \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
729
+ & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
730
+ \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
731
+ {kD \times kH \times kW}
732
+ \end{aligned}
733
+
734
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
735
+ for :attr:`padding` number of points.
736
+
737
+ Shape:
738
+ - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
739
+ - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
740
+ :math:`(D_{out}, H_{out}, W_{out}, C)`, where
741
+
742
+ .. math::
743
+ D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
744
+ \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
745
+
746
+ .. math::
747
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
748
+ \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
749
+
750
+ .. math::
751
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
752
+ \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
753
+
754
+ Examples::
755
+
756
+ >>> import brainstate as brainstate
757
+ >>> # pool of square window of size=3, stride=2
758
+ >>> m = AvgPool3d(3, stride=2)
759
+ >>> # pool of non-square window
760
+ >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
761
+ >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
762
+ >>> output = m(input)
763
+ >>> output.shape
764
+ (20, 24, 43, 15, 16)
765
+
766
+ Parameters
767
+ ----------
768
+ in_size: Sequence of int
769
+ The shape of the input tensor.
770
+ kernel_size: int, sequence of int
771
+ An integer, or a sequence of integers defining the window to reduce over.
772
+ stride: int, sequence of int
773
+ An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
774
+ padding: str, int, sequence of tuple
775
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
776
+ of n `(low, high)` integer pairs that give the padding to apply before
777
+ and after each spatial dimension.
778
+ channel_axis: int, optional
779
+ Axis of the spatial channels for which pooling is skipped.
780
+ If ``None``, there is no channel axis.
781
+ name: optional, str
782
+ The object name.
783
+
784
+ """
785
+ __module__ = 'brainstate.nn'
786
+
787
+ def __init__(
788
+ self,
789
+ kernel_size: Size,
790
+ stride: Union[int, Sequence[int]] = 1,
791
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
792
+ channel_axis: Optional[int] = -1,
793
+ name: Optional[str] = None,
794
+ in_size: Optional[Size] = None,
795
+ ):
796
+ super().__init__(in_size=in_size,
797
+ init_value=0.,
798
+ computation=jax.lax.add,
799
+ pool_dim=3,
800
+ kernel_size=kernel_size,
801
+ stride=stride,
802
+ padding=padding,
803
+ channel_axis=channel_axis,
804
+ name=name)
805
+
806
+
807
+ def _adaptive_pool1d(x, target_size: int, operation: Callable):
808
+ """Adaptive pool 1D.
809
+
810
+ Args:
811
+ x: The input. Should be a JAX array of shape `(dim,)`.
812
+ target_size: The shape of the output after the pooling operation `(target_size,)`.
813
+ operation: The pooling operation to be performed on the input array.
814
+
815
+ Returns:
816
+ A JAX array of shape `(target_size, )`.
817
+ """
818
+ size = jnp.size(x)
819
+ num_head_arrays = size % target_size
820
+ num_block = size // target_size
821
+ if num_head_arrays != 0:
822
+ head_end_index = num_head_arrays * (num_block + 1)
823
+ heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
824
+ tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
825
+ outs = jnp.concatenate([heads, tails])
826
+ else:
827
+ outs = jax.vmap(operation)(x.reshape(-1, num_block))
828
+ return outs
829
+
830
+
831
+ def _generate_vmap(fun: Callable, map_axes: List[int]):
832
+ map_axes = sorted(map_axes)
833
+ for axis in map_axes:
834
+ fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
835
+ return fun
836
+
837
+
838
+ class _AdaptivePool(Module):
839
+ """General N dimensional adaptive down-sampling to a target shape.
840
+
841
+ Parameters
842
+ ----------
843
+ in_size: Sequence of int
844
+ The shape of the input tensor.
845
+ target_size: int, sequence of int
846
+ The target output shape.
847
+ num_spatial_dims: int
848
+ The number of spatial dimensions.
849
+ channel_axis: int, optional
850
+ Axis of the spatial channels for which pooling is skipped.
851
+ If ``None``, there is no channel axis.
852
+ operation: Callable
853
+ The down-sampling operation.
854
+ name: str
855
+ The class name.
856
+ """
857
+
858
+ def __init__(
859
+ self,
860
+ in_size: Size,
861
+ target_size: Size,
862
+ num_spatial_dims: int,
863
+ operation: Callable,
864
+ channel_axis: Optional[int] = -1,
865
+ name: Optional[str] = None,
866
+ ):
867
+ super().__init__(name=name)
868
+
869
+ self.channel_axis = channel_axis
870
+ self.operation = operation
871
+ if isinstance(target_size, int):
872
+ self.target_shape = (target_size,) * num_spatial_dims
873
+ elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
874
+ self.target_shape = target_size
875
+ else:
876
+ raise ValueError("`target_size` must either be an int or tuple of length "
877
+ f"{num_spatial_dims} containing ints.")
878
+
879
+ # in & out shapes
880
+ if in_size is not None:
881
+ in_size = tuple(in_size)
882
+ self.in_size = in_size
883
+ y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
884
+ self.out_size = y.shape[1:]
885
+
886
+ def update(self, x):
887
+ """Input-output mapping.
888
+
889
+ Parameters
890
+ ----------
891
+ x: Array
892
+ Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
893
+ or `(..., dim_1, dim_2)`.
894
+ """
895
+ # channel axis
896
+ channel_axis = self.channel_axis
897
+
898
+ if channel_axis:
899
+ if not 0 <= abs(channel_axis) < x.ndim:
900
+ raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
901
+ if channel_axis < 0:
902
+ channel_axis = x.ndim + channel_axis
903
+ # input dimension
904
+ if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
905
+ raise ValueError(f"Invalid input dimension. Except >={len(self.target_shape)} "
906
+ f"dimensions (channel_axis={self.channel_axis}). "
907
+ f"But got {x.ndim} dimensions.")
908
+ # pooling dimensions
909
+ pool_dims = list(range(x.ndim))
910
+ if channel_axis:
911
+ pool_dims.pop(channel_axis)
912
+
913
+ # pooling
914
+ for i, di in enumerate(pool_dims[-len(self.target_shape):]):
915
+ poo_axes = [j for j in range(x.ndim) if j != di]
916
+ op = _generate_vmap(_adaptive_pool1d, poo_axes)
917
+ x = op(x, self.target_shape[i], self.operation)
918
+ return x
919
+
920
+
921
+ class AdaptiveAvgPool1d(_AdaptivePool):
922
+ r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
923
+
924
+ The output size is :math:`L_{out}`, for any input size.
925
+ The number of output features is equal to the number of input planes.
926
+
927
+ Shape:
928
+ - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
929
+ - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
930
+ :math:`L_{out}=\text{target\_size}`.
931
+
932
+ Examples:
933
+
934
+ >>> import brainstate as brainstate
935
+ >>> # target output size of 5
936
+ >>> m = AdaptiveMaxPool1d(5)
937
+ >>> input = brainstate.random.randn(1, 64, 8)
938
+ >>> output = m(input)
939
+ >>> output.shape
940
+ (1, 5, 8)
941
+
942
+ Parameters
943
+ ----------
944
+ in_size: Sequence of int
945
+ The shape of the input tensor.
946
+ target_size: int, sequence of int
947
+ The target output shape.
948
+ channel_axis: int, optional
949
+ Axis of the spatial channels for which pooling is skipped.
950
+ If ``None``, there is no channel axis.
951
+ name: str
952
+ The class name.
953
+ """
954
+ __module__ = 'brainstate.nn'
955
+
956
+ def __init__(self,
957
+ target_size: Union[int, Sequence[int]],
958
+ channel_axis: Optional[int] = -1,
959
+ name: Optional[str] = None,
960
+ in_size: Optional[Sequence[int]] = None, ):
961
+ super().__init__(in_size=in_size,
962
+ target_size=target_size,
963
+ channel_axis=channel_axis,
964
+ num_spatial_dims=1,
965
+ operation=jnp.mean,
966
+ name=name)
967
+
968
+
969
+ class AdaptiveAvgPool2d(_AdaptivePool):
970
+ r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
971
+
972
+ The output is of size :math:`H_{out} \times W_{out}`, for any input size.
973
+ The number of output features is equal to the number of input planes.
974
+
975
+ Shape:
976
+ - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
977
+ - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
978
+ :math:`(H_{out}, W_{out})=\text{target\_size}`.
979
+
980
+ Examples:
981
+
982
+ >>> import brainstate as brainstate
983
+ >>> # target output size of 5x7
984
+ >>> m = AdaptiveMaxPool2d((5, 7))
985
+ >>> input = brainstate.random.randn(1, 8, 9, 64)
986
+ >>> output = m(input)
987
+ >>> output.shape
988
+ (1, 5, 7, 64)
989
+ >>> # target output size of 7x7 (square)
990
+ >>> m = AdaptiveMaxPool2d(7)
991
+ >>> input = brainstate.random.randn(1, 10, 9, 64)
992
+ >>> output = m(input)
993
+ >>> output.shape
994
+ (1, 7, 7, 64)
995
+ >>> # target output size of 10x7
996
+ >>> m = AdaptiveMaxPool2d((None, 7))
997
+ >>> input = brainstate.random.randn(1, 10, 9, 64)
998
+ >>> output = m(input)
999
+ >>> output.shape
1000
+ (1, 10, 7, 64)
1001
+
1002
+ Parameters
1003
+ ----------
1004
+ in_size: Sequence of int
1005
+ The shape of the input tensor.
1006
+ target_size: int, sequence of int
1007
+ The target output shape.
1008
+ channel_axis: int, optional
1009
+ Axis of the spatial channels for which pooling is skipped.
1010
+ If ``None``, there is no channel axis.
1011
+ name: str
1012
+ The class name.
1013
+ """
1014
+ __module__ = 'brainstate.nn'
1015
+
1016
+ def __init__(self,
1017
+ target_size: Union[int, Sequence[int]],
1018
+ channel_axis: Optional[int] = -1,
1019
+ name: Optional[str] = None,
1020
+
1021
+ in_size: Optional[Sequence[int]] = None, ):
1022
+ super().__init__(in_size=in_size,
1023
+ target_size=target_size,
1024
+ channel_axis=channel_axis,
1025
+ num_spatial_dims=2,
1026
+ operation=jnp.mean,
1027
+ name=name)
1028
+
1029
+
1030
+ class AdaptiveAvgPool3d(_AdaptivePool):
1031
+ r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
1032
+
1033
+ The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
1034
+ The number of output features is equal to the number of input planes.
1035
+
1036
+ Shape:
1037
+ - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
1038
+ - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
1039
+ where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
1040
+
1041
+ Examples:
1042
+
1043
+ >>> import brainstate as brainstate
1044
+ >>> # target output size of 5x7x9
1045
+ >>> m = AdaptiveMaxPool3d((5, 7, 9))
1046
+ >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
1047
+ >>> output = m(input)
1048
+ >>> output.shape
1049
+ (1, 5, 7, 9, 64)
1050
+ >>> # target output size of 7x7x7 (cube)
1051
+ >>> m = AdaptiveMaxPool3d(7)
1052
+ >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1053
+ >>> output = m(input)
1054
+ >>> output.shape
1055
+ (1, 7, 7, 7, 64)
1056
+ >>> # target output size of 7x9x8
1057
+ >>> m = AdaptiveMaxPool3d((7, None, None))
1058
+ >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1059
+ >>> output = m(input)
1060
+ >>> output.shape
1061
+ (1, 7, 9, 8, 64)
1062
+
1063
+ Parameters
1064
+ ----------
1065
+ in_size: Sequence of int
1066
+ The shape of the input tensor.
1067
+ target_size: int, sequence of int
1068
+ The target output shape.
1069
+ channel_axis: int, optional
1070
+ Axis of the spatial channels for which pooling is skipped.
1071
+ If ``None``, there is no channel axis.
1072
+ name: str
1073
+ The class name.
1074
+ """
1075
+ __module__ = 'brainstate.nn'
1076
+
1077
+ def __init__(self,
1078
+ target_size: Union[int, Sequence[int]],
1079
+ channel_axis: Optional[int] = -1,
1080
+ name: Optional[str] = None,
1081
+ in_size: Optional[Sequence[int]] = None, ):
1082
+ super().__init__(in_size=in_size,
1083
+ target_size=target_size,
1084
+ channel_axis=channel_axis,
1085
+ num_spatial_dims=3,
1086
+ operation=jnp.mean,
1087
+ name=name)
1088
+
1089
+
1090
+ class AdaptiveMaxPool1d(_AdaptivePool):
1091
+ """Adaptive one-dimensional maximum down-sampling.
1092
+
1093
+ Parameters
1094
+ ----------
1095
+ in_size: Sequence of int
1096
+ The shape of the input tensor.
1097
+ target_size: int, sequence of int
1098
+ The target output shape.
1099
+ channel_axis: int, optional
1100
+ Axis of the spatial channels for which pooling is skipped.
1101
+ If ``None``, there is no channel axis.
1102
+ name: str
1103
+ The class name.
1104
+ """
1105
+ __module__ = 'brainstate.nn'
1106
+
1107
+ def __init__(self,
1108
+ target_size: Union[int, Sequence[int]],
1109
+ channel_axis: Optional[int] = -1,
1110
+ name: Optional[str] = None,
1111
+ in_size: Optional[Sequence[int]] = None, ):
1112
+ super().__init__(in_size=in_size,
1113
+ target_size=target_size,
1114
+ channel_axis=channel_axis,
1115
+ num_spatial_dims=1,
1116
+ operation=jnp.max,
1117
+ name=name)
1118
+
1119
+
1120
+ class AdaptiveMaxPool2d(_AdaptivePool):
1121
+ """Adaptive two-dimensional maximum down-sampling.
1122
+
1123
+ Parameters
1124
+ ----------
1125
+ in_size: Sequence of int
1126
+ The shape of the input tensor.
1127
+ target_size: int, sequence of int
1128
+ The target output shape.
1129
+ channel_axis: int, optional
1130
+ Axis of the spatial channels for which pooling is skipped.
1131
+ If ``None``, there is no channel axis.
1132
+ name: str
1133
+ The class name.
1134
+ """
1135
+ __module__ = 'brainstate.nn'
1136
+
1137
+ def __init__(self,
1138
+ target_size: Union[int, Sequence[int]],
1139
+ channel_axis: Optional[int] = -1,
1140
+ name: Optional[str] = None,
1141
+ in_size: Optional[Sequence[int]] = None, ):
1142
+ super().__init__(in_size=in_size,
1143
+ target_size=target_size,
1144
+ channel_axis=channel_axis,
1145
+ num_spatial_dims=2,
1146
+ operation=jnp.max,
1147
+ name=name)
1148
+
1149
+
1150
+ class AdaptiveMaxPool3d(_AdaptivePool):
1151
+ """Adaptive three-dimensional maximum down-sampling.
1152
+
1153
+ Parameters
1154
+ ----------
1155
+ in_size: Sequence of int
1156
+ The shape of the input tensor.
1157
+ target_size: int, sequence of int
1158
+ The target output shape.
1159
+ channel_axis: int, optional
1160
+ Axis of the spatial channels for which pooling is skipped.
1161
+ If ``None``, there is no channel axis.
1162
+ name: str
1163
+ The class name.
1164
+ """
1165
+ __module__ = 'brainstate.nn'
1166
+
1167
+ def __init__(self,
1168
+ target_size: Union[int, Sequence[int]],
1169
+ channel_axis: Optional[int] = -1,
1170
+ name: Optional[str] = None,
1171
+ in_size: Optional[Sequence[int]] = None, ):
1172
+ super().__init__(in_size=in_size,
1173
+ target_size=target_size,
1174
+ channel_axis=channel_axis,
1175
+ num_spatial_dims=3,
1176
+ operation=jnp.max,
1177
+ name=name)