brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
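Taken together, the list describes a package-level restructuring rather than incremental edits: the old `brainstate.transform` namespace is split into a new `brainstate.compile` package (JIT, conditionals, loops, jaxpr construction) and a new `brainstate.augment` package (autograd, mapping, eval_shape); the flat `brainstate.nn` modules are regrouped under `_dyn_impl`, `_dynamics`, `_elementwise`, and `_interaction` subpackages; `brainstate/nn/event` is promoted to a top-level `brainstate.event` package; and a `brainstate.graph` package is added. A minimal sketch of what the split means for downstream imports, assuming the new packages re-export their transforms at package level (the re-exports themselves are not visible in this diff):

    # Hypothetical migration sketch for brainstate 0.0.2 -> 0.1.0.
    # Module paths are inferred from the file list above; the exact public
    # names are assumptions, not confirmed by this diff.
    import brainstate as bst

    # 0.0.2: bst.transform.jit / bst.transform.cond / bst.transform.grad
    jit = bst.compile.jit      # compilation-style transforms -> brainstate.compile
    cond = bst.compile.cond
    grad = bst.augment.grad    # function augmentations -> brainstate.augment

The hunk reproduced below is the deletion of the flat pooling module `brainstate/nn/_poolings.py`; its classes evidently move to `brainstate/nn/_interaction/_poolings.py` (entry 104).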
brainstate/nn/_poolings.py
@@ -1,1229 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- from __future__ import annotations
-
- import functools
- from typing import Sequence, Optional
- from typing import Union, Tuple, Callable, List
-
- import brainunit as u
- import jax
- import jax.numpy as jnp
- import numpy as np
-
- from ._base import DnnLayer, ExplicitInOutSize
- from .. import environ
- from ..mixin import Mode
- from brainstate.typing import Size
-
- __all__ = [
-     'Flatten', 'Unflatten',
-
-     'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
-     'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
-
-     'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
-     'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
- ]
-
-
- class Flatten(DnnLayer, ExplicitInOutSize):
-     r"""
-     Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
-
-     Shape:
-         - Input: :math:`(*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)`,
-           where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
-           number of dimensions including none.
-         - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
-
-     Args:
-         in_size: Sequence of int. The shape of the input tensor.
-         start_axis: first dim to flatten (default = 1).
-         end_axis: last dim to flatten (default = -1).
-
-     Examples::
-         >>> import brainstate as bst
-         >>> inp = bst.random.randn(32, 1, 5, 5)
-         >>> # With default parameters
-         >>> m = Flatten()
-         >>> output = m(inp)
-         >>> output.shape
-         (32, 25)
-         >>> # With non-default parameters
-         >>> m = Flatten(0, 2)
-         >>> output = m(inp)
-         >>> output.shape
-         (160, 5)
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         start_axis: int = 0,
-         end_axis: int = -1,
-         in_size: Optional[Size] = None
-     ) -> None:
-         super().__init__()
-         self.start_axis = start_axis
-         self.end_axis = end_axis
-
-         if in_size is not None:
-             self.in_size = tuple(in_size)
-             y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
-                                jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
-             self.out_size = y.shape
-
-     def update(self, x):
-         if self._in_size is None:
-             start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
-         else:
-             assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
-             dim_diff = x.ndim - len(self.in_size)
-             if self.in_size != x.shape[dim_diff:]:
-                 raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
-             if self.start_axis >= 0:
-                 start_axis = self.start_axis + dim_diff
-             else:
-                 start_axis = x.ndim + self.start_axis
-         return u.math.flatten(x, start_axis, self.end_axis)
-
-     def __repr__(self) -> str:
-         return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
-
-
- class Unflatten(DnnLayer, ExplicitInOutSize):
-     r"""
-     Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~nn.Sequential`.
-
-     * :attr:`axis` specifies the dimension of the input tensor to be unflattened.
-
-     * :attr:`sizes` is the new shape of the unflattened dimension, and it can be
-       a `tuple` of ints or a `list` of ints.
-
-     Shape:
-         - Input: :math:`(*, S_{\text{axis}}, *)`, where :math:`S_{\text{axis}}` is the size at
-           dimension :attr:`axis` and :math:`*` means any number of dimensions including none.
-         - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`sizes` and
-           :math:`\prod_{i=1}^n U_i = S_{\text{axis}}`.
-
-     Args:
-         axis: int, Dimension to be unflattened.
-         sizes: Sequence of int. New shape of the unflattened dimension.
-         in_size: Sequence of int. The shape of the input tensor.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         axis: int,
-         sizes: Size,
-         mode: Mode = None,
-         name: str = None,
-         in_size: Optional[Size] = None
-     ) -> None:
-         super().__init__(mode=mode, name=name)
-
-         self.axis = axis
-         self.sizes = sizes
-         if isinstance(sizes, (tuple, list)):
-             for idx, elem in enumerate(sizes):
-                 if not isinstance(elem, int):
-                     raise TypeError("unflattened sizes must be tuple of ints, " +
-                                     "but found element of type {} at pos {}".format(type(elem).__name__, idx))
-         else:
-             raise TypeError("unflattened sizes must be tuple or list, but found type {}".format(type(sizes).__name__))
-
-         if in_size is not None:
-             self.in_size = tuple(in_size)
-             y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
-                                jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
-             self.out_size = y.shape
-
-     def update(self, x):
-         return u.math.unflatten(x, self.axis, self.sizes)
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
-
-
- class _MaxPool(DnnLayer, ExplicitInOutSize):
-     def __init__(
-         self,
-         init_value: float,
-         computation: Callable,
-         pool_dim: int,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         mode: Mode = None,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(name=name, mode=mode)
-
-         self.init_value = init_value
-         self.computation = computation
-         self.pool_dim = pool_dim
-
-         # kernel_size
-         if isinstance(kernel_size, int):
-             kernel_size = (kernel_size,) * pool_dim
-         elif isinstance(kernel_size, Sequence):
-             assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
-             assert all([isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
-             if len(kernel_size) != pool_dim:
-                 raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, but got {len(kernel_size)}')
-         else:
-             raise TypeError(f'kernel_size should be an int or a tuple with {pool_dim} ints.')
-         self.kernel_size = kernel_size
-
-         # stride
-         if stride is None:
-             stride = kernel_size
-         if isinstance(stride, int):
-             stride = (stride,) * pool_dim
-         elif isinstance(stride, Sequence):
-             assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
-             assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
-             if len(stride) != pool_dim:
-                 raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
-         else:
-             raise TypeError(f'stride should be an int or a tuple with {pool_dim} ints.')
-         self.stride = stride
-
-         # padding
-         if isinstance(padding, str):
-             if padding not in ("SAME", "VALID"):
-                 raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
-         elif isinstance(padding, int):
-             padding = [(padding, padding) for _ in range(pool_dim)]
-         elif isinstance(padding, (list, tuple)):
-             if isinstance(padding[0], int):
-                 if len(padding) == pool_dim:
-                     padding = [(x, x) for x in padding]
-                 else:
-                     raise ValueError(f'If padding is a sequence of ints, it '
-                                      f'should have length {pool_dim}.')
-             else:
-                 if not all([isinstance(x, (tuple, list)) for x in padding]):
-                     raise ValueError(f'padding should be a sequence of Tuple[int, int]. {padding}')
-                 if not all([len(x) == 2 for x in padding]):
-                     raise ValueError(f"Each entry in padding must be a tuple of 2 ints. {padding}")
-                 if len(padding) == 1:
-                     padding = tuple(padding) * pool_dim
-                 assert len(padding) == pool_dim, f'padding should have length {pool_dim}. {padding}'
-         else:
-             raise ValueError
-         self.padding = padding
-
-         # channel_axis
-         assert channel_axis is None or isinstance(channel_axis, int), \
-             f'channel_axis should be an int, but got {channel_axis}'
-         self.channel_axis = channel_axis
-
-         # in & out shapes
-         if in_size is not None:
-             in_size = tuple(in_size)
-             self.in_size = in_size
-             y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
-             self.out_size = y.shape[1:]
-
-     def update(self, x):
-         x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
-         if x.ndim < x_dim:
-             raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
-         window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
-         stride = self._infer_shape(x.ndim, self.stride, 1)
-         padding = (self.padding if isinstance(self.padding, str) else
-                    self._infer_shape(x.ndim, self.padding, element=(0, 0)))
-         r = jax.lax.reduce_window(
-             x,
-             init_value=self.init_value,
-             computation=self.computation,
-             window_dimensions=window_shape,
-             window_strides=stride,
-             padding=padding
-         )
-         return r
-
-     def _infer_shape(self, x_dim, inputs, element):
-         channel_axis = self.channel_axis
-         if channel_axis and not 0 <= abs(channel_axis) < x_dim:
-             raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
-         if channel_axis and channel_axis < 0:
-             channel_axis = x_dim + channel_axis
-         all_dims = list(range(x_dim))
-         if channel_axis is not None:
-             all_dims.pop(channel_axis)
-         pool_dims = all_dims[-self.pool_dim:]
-         results = [element] * x_dim
-         for i, dim in enumerate(pool_dims):
-             results[dim] = inputs[i]
-         return results
-
-
- class _AvgPool(_MaxPool):
-     def update(self, x):
-         x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
-         if x.ndim < x_dim:
-             raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
-         dims = self._infer_shape(x.ndim, self.kernel_size, 1)
-         stride = self._infer_shape(x.ndim, self.stride, 1)
-         padding = (self.padding if isinstance(self.padding, str) else
-                    self._infer_shape(x.ndim, self.padding, element=(0, 0)))
-         pooled = jax.lax.reduce_window(x,
-                                        init_value=self.init_value,
-                                        computation=self.computation,
-                                        window_dimensions=dims,
-                                        window_strides=stride,
-                                        padding=padding)
-         if padding == "VALID":
-             # Avoid the extra reduce_window.
-             return pooled / np.prod(dims)
-         else:
-             # Count the number of valid entries at each input point, then use that for
-             # computing average. Assumes that any two arrays of same shape will be
-             # padded the same.
-             window_counts = jax.lax.reduce_window(jnp.ones_like(x),
-                                                   init_value=self.init_value,
-                                                   computation=self.computation,
-                                                   window_dimensions=dims,
-                                                   window_strides=stride,
-                                                   padding=padding)
-             assert pooled.shape == window_counts.shape
-             return pooled / window_counts
-
-
- class MaxPool1d(_MaxPool):
-     r"""Applies a 1D max pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
-     and output :math:`(N, L_{out}, C)` can be precisely described as:
-
-     .. math::
-         out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
-             input(N_i, stride \times k + m, C_j)
-
-     If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
-     for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
-     sliding window. This `link`_ has a nice visualization of the pooling parameters.
-
-     Shape:
-         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
-         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
-
-           .. math::
-               L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
-                   \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
-
-     Examples::
-
-         >>> import brainstate as bst
-         >>> # pool of size=3, stride=2
-         >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
-         >>> input = bst.random.randn(20, 50, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 16)
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     kernel_size: int, sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride: int, sequence of int
-         An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
-     padding: str, int, sequence of tuple
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     mode: Mode
-         The computation mode.
-     name: optional, str
-         The object name.
-
-     .. _link:
-         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         mode: Mode = None,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(in_size=in_size,
-                          init_value=-jax.numpy.inf,
-                          computation=jax.lax.max,
-                          pool_dim=1,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          channel_axis=channel_axis,
-                          name=name,
-                          mode=mode)
-
-
- class MaxPool2d(_MaxPool):
-     r"""Applies a 2D max pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
-     output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
-     can be precisely described as:
-
-     .. math::
-         \begin{aligned}
-             out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
-                                     & \text{input}(N_i, \text{stride[0]} \times h + m,
-                                                    \text{stride[1]} \times w + n, C_j)
-         \end{aligned}
-
-     If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
-     for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
-     It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
-
-     Shape:
-         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
-         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
-
-           .. math::
-               H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
-                   \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor
-
-           .. math::
-               W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
-                   \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
-
-     Examples::
-
-         >>> import brainstate as bst
-         >>> # pool of square window of size=3, stride=2
-         >>> m = MaxPool2d(3, stride=2)
-         >>> # pool of non-square window
-         >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
-         >>> input = bst.random.randn(20, 50, 32, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 31, 16)
-
-     .. _link:
-         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     kernel_size: int, sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride: int, sequence of int
-         An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
-     padding: str, int, sequence of tuple
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     mode: Mode
-         The computation mode.
-     name: optional, str
-         The object name.
-
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         mode: Mode = None,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(in_size=in_size,
-                          init_value=-jax.numpy.inf,
-                          computation=jax.lax.max,
-                          pool_dim=2,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          channel_axis=channel_axis,
-                          name=name, mode=mode)
-
-
- class MaxPool3d(_MaxPool):
-     r"""Applies a 3D max pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
-     output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
-     can be precisely described as:
-
-     .. math::
-         \begin{aligned}
-             \text{out}(N_i, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
-                                          & \text{input}(N_i, \text{stride[0]} \times d + k,
-                                                         \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
-         \end{aligned}
-
-     If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
-     for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
-     It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
-
-     Shape:
-         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
-         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
-
-           .. math::
-               D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
-                   (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
-
-           .. math::
-               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
-                   (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
-
-           .. math::
-               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
-                   (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
-
-     Examples::
-
-         >>> import brainstate as bst
-         >>> # pool of square window of size=3, stride=2
-         >>> m = MaxPool3d(3, stride=2)
-         >>> # pool of non-square window
-         >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
-         >>> input = bst.random.randn(20, 50, 44, 31, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 43, 15, 16)
-
-     .. _link:
-         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     kernel_size: int, sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride: int, sequence of int
-         An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
-     padding: str, int, sequence of tuple
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     mode: Mode
-         The computation mode.
-     name: optional, str
-         The object name.
-
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = None,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         mode: Mode = None,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(in_size=in_size,
-                          init_value=-jax.numpy.inf,
-                          computation=jax.lax.max,
-                          pool_dim=3,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          channel_axis=channel_axis,
-                          name=name, mode=mode)
-
-
- class AvgPool1d(_AvgPool):
-     r"""Applies a 1D average pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, L, C)`,
-     output :math:`(N, L_{out}, C)` and :attr:`kernel_size` :math:`k`
-     can be precisely described as:
-
-     .. math::
-
-         \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
-             \text{input}(N_i, \text{stride} \times l + m, C_j)
-
-     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
-     for :attr:`padding` number of points.
-
-     Shape:
-         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
-         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
-
-           .. math::
-               L_{out} = \left\lfloor \frac{L_{in} +
-                   2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
-
-     Examples::
-
-         >>> import brainstate as bst
-         >>> # pool with window of size=3, stride=2
-         >>> m = AvgPool1d(3, stride=2)
-         >>> input = bst.random.randn(20, 50, 16)
-         >>> m(input).shape
-         (20, 24, 16)
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     kernel_size: int, sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride: int, sequence of int
-         An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
-     padding: str, int, sequence of tuple
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     mode: Mode
-         The computation mode.
-     name: optional, str
-         The object name.
-
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = 1,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         mode: Mode = None,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(in_size=in_size,
-                          init_value=0.,
-                          computation=jax.lax.add,
-                          pool_dim=1,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          channel_axis=channel_axis,
-                          name=name,
-                          mode=mode)
-
-
- class AvgPool2d(_AvgPool):
-     r"""Applies a 2D average pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
-     output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
-     can be precisely described as:
-
-     .. math::
-
-         out(N_i, h, w, C_j) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
-             input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)
-
-     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
-     for :attr:`padding` number of points.
-
-     Shape:
-         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
-         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
-
-           .. math::
-               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
-                   \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
-
-           .. math::
-               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
-                   \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
-
-     Examples::
-
-         >>> import brainstate as bst
-         >>> # pool of square window of size=3, stride=2
-         >>> m = AvgPool2d(3, stride=2)
-         >>> # pool of non-square window
-         >>> m = AvgPool2d((3, 2), stride=(2, 1))
-         >>> input = bst.random.randn(20, 50, 32, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 31, 16)
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     kernel_size: int, sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride: int, sequence of int
-         An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
-     padding: str, int, sequence of tuple
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     mode: Mode
-         The computation mode.
-     name: optional, str
-         The object name.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = 1,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         mode: Mode = None,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(in_size=in_size,
-                          init_value=0.,
-                          computation=jax.lax.add,
-                          pool_dim=2,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          channel_axis=channel_axis,
-                          name=name,
-                          mode=mode)
-
-
- class AvgPool3d(_AvgPool):
-     r"""Applies a 3D average pooling over an input signal composed of several input planes.
-
-     In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
-     output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
-     can be precisely described as:
-
-     .. math::
-         \begin{aligned}
-             \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
-                                               & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
-                                                   \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
-                                                   {kD \times kH \times kW}
-         \end{aligned}
-
-     If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
-     for :attr:`padding` number of points.
-
-     Shape:
-         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
-         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
-           :math:`(D_{out}, H_{out}, W_{out}, C)`, where
-
-           .. math::
-               D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
-                   \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
-
-           .. math::
-               H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
-                   \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
-
-           .. math::
-               W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
-                   \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
-
-     Examples::
-
-         >>> import brainstate as bst
-         >>> # pool of square window of size=3, stride=2
-         >>> m = AvgPool3d(3, stride=2)
-         >>> # pool of non-square window
-         >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
-         >>> input = bst.random.randn(20, 50, 44, 31, 16)
-         >>> output = m(input)
-         >>> output.shape
-         (20, 24, 43, 15, 16)
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     kernel_size: int, sequence of int
-         An integer, or a sequence of integers defining the window to reduce over.
-     stride: int, sequence of int
-         An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
-     padding: str, int, sequence of tuple
-         Either the string `'SAME'`, the string `'VALID'`, or a sequence
-         of n `(low, high)` integer pairs that give the padding to apply before
-         and after each spatial dimension.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     mode: Mode
-         The computation mode.
-     name: optional, str
-         The object name.
-
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         kernel_size: Size,
-         stride: Union[int, Sequence[int]] = 1,
-         padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
-         channel_axis: Optional[int] = -1,
-         mode: Mode = None,
-         name: Optional[str] = None,
-         in_size: Optional[Size] = None,
-     ):
-         super().__init__(in_size=in_size,
-                          init_value=0.,
-                          computation=jax.lax.add,
-                          pool_dim=3,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          channel_axis=channel_axis,
-                          name=name,
-                          mode=mode)
-
-
- def _adaptive_pool1d(x, target_size: int, operation: Callable):
-     """Adaptive pool 1D.
-
-     Args:
-         x: The input. Should be a JAX array of shape `(dim,)`.
-         target_size: The shape of the output after the pooling operation `(target_size,)`.
-         operation: The pooling operation to be performed on the input array.
-
-     Returns:
-         A JAX array of shape `(target_size, )`.
-     """
-     size = jnp.size(x)
-     num_head_arrays = size % target_size
-     num_block = size // target_size
-     if num_head_arrays != 0:
-         head_end_index = num_head_arrays * (num_block + 1)
-         heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
-         tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
-         outs = jnp.concatenate([heads, tails])
-     else:
-         outs = jax.vmap(operation)(x.reshape(-1, num_block))
-     return outs
-
-
- def _generate_vmap(fun: Callable, map_axes: List[int]):
-     map_axes = sorted(map_axes)
-     for axis in map_axes:
-         fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
-     return fun
-
-
- class _AdaptivePool(DnnLayer, ExplicitInOutSize):
-     """General N dimensional adaptive down-sampling to a target shape.
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     target_size: int, sequence of int
-         The target output shape.
-     num_spatial_dims: int
-         The number of spatial dimensions.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     operation: Callable
-         The down-sampling operation.
-     name: str
-         The class name.
-     mode: Mode
-         The computing mode.
-     """
-
-     def __init__(
-         self,
-         in_size: Size,
-         target_size: Size,
-         num_spatial_dims: int,
-         operation: Callable,
-         channel_axis: Optional[int] = -1,
-         name: Optional[str] = None,
-         mode: Optional[Mode] = None,
-     ):
-         super().__init__(name=name, mode=mode)
-
-         self.channel_axis = channel_axis
-         self.operation = operation
-         if isinstance(target_size, int):
-             self.target_shape = (target_size,) * num_spatial_dims
-         elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
-             self.target_shape = target_size
-         else:
-             raise ValueError("`target_size` must either be an int or tuple of length "
-                              f"{num_spatial_dims} containing ints.")
-
-         # in & out shapes
-         if in_size is not None:
-             in_size = tuple(in_size)
-             self.in_size = in_size
-             y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
-             self.out_size = y.shape[1:]
-
-     def update(self, x):
-         """Input-output mapping.
-
-         Parameters
-         ----------
-         x: Array
-             Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
-             or `(..., dim_1, dim_2)`.
-         """
-         # channel axis
-         channel_axis = self.channel_axis
-
-         if channel_axis:
-             if not 0 <= abs(channel_axis) < x.ndim:
-                 raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
-             if channel_axis < 0:
-                 channel_axis = x.ndim + channel_axis
-         # input dimension
-         if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
-             raise ValueError(f"Invalid input dimension. Expected >= {len(self.target_shape)} "
-                              f"dimensions (channel_axis={self.channel_axis}). "
-                              f"But got {x.ndim} dimensions.")
-         # pooling dimensions
-         pool_dims = list(range(x.ndim))
-         if channel_axis:
-             pool_dims.pop(channel_axis)
-
-         # pooling
-         for i, di in enumerate(pool_dims[-len(self.target_shape):]):
-             pool_axes = [j for j in range(x.ndim) if j != di]
-             op = _generate_vmap(_adaptive_pool1d, pool_axes)
-             x = op(x, self.target_shape[i], self.operation)
-         return x
-
-
- class AdaptiveAvgPool1d(_AdaptivePool):
-     r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
-
-     The output size is :math:`L_{out}`, for any input size.
-     The number of output features is equal to the number of input planes.
-
-     Shape:
-         - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
-         - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
-           :math:`L_{out}=\text{output\_size}`.
-
-     Examples:
-
-         >>> import brainstate as bst
-         >>> # target output size of 5
-         >>> m = AdaptiveAvgPool1d(5)
-         >>> input = bst.random.randn(1, 64, 8)
-         >>> output = m(input)
-         >>> output.shape
-         (1, 5, 8)
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     target_size: int, sequence of int
-         The target output shape.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     name: str
-         The class name.
-     mode: Mode
-         The computing mode.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(self,
-                  target_size: Union[int, Sequence[int]],
-                  channel_axis: Optional[int] = -1,
-                  name: Optional[str] = None,
-                  mode: Optional[Mode] = None,
-                  in_size: Optional[Sequence[int]] = None):
-         super().__init__(in_size=in_size,
-                          target_size=target_size,
-                          channel_axis=channel_axis,
-                          num_spatial_dims=1,
-                          operation=jnp.mean,
-                          name=name,
-                          mode=mode)
-
-
- class AdaptiveAvgPool2d(_AdaptivePool):
-     r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
-
-     The output is of size :math:`H_{out} \times W_{out}`, for any input size.
-     The number of output features is equal to the number of input planes.
-
-     Shape:
-         - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
-         - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
-           :math:`(H_{out}, W_{out})=\text{output\_size}`.
-
-     Examples:
-
-         >>> import brainstate as bst
-         >>> # target output size of 5x7
-         >>> m = AdaptiveAvgPool2d((5, 7))
-         >>> input = bst.random.randn(1, 8, 9, 64)
-         >>> output = m(input)
-         >>> output.shape
-         (1, 5, 7, 64)
-         >>> # target output size of 7x7 (square)
-         >>> m = AdaptiveAvgPool2d(7)
-         >>> input = bst.random.randn(1, 10, 9, 64)
-         >>> output = m(input)
-         >>> output.shape
-         (1, 7, 7, 64)
-         >>> # target output size of 10x7
-         >>> m = AdaptiveAvgPool2d((None, 7))
-         >>> input = bst.random.randn(1, 10, 9, 64)
-         >>> output = m(input)
-         >>> output.shape
-         (1, 10, 7, 64)
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     target_size: int, sequence of int
-         The target output shape.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     name: str
-         The class name.
-     mode: Mode
-         The computing mode.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(self,
-                  target_size: Union[int, Sequence[int]],
-                  channel_axis: Optional[int] = -1,
-                  name: Optional[str] = None,
-                  mode: Optional[Mode] = None,
-                  in_size: Optional[Sequence[int]] = None):
-         super().__init__(in_size=in_size,
-                          target_size=target_size,
-                          channel_axis=channel_axis,
-                          num_spatial_dims=2,
-                          operation=jnp.mean,
-                          name=name,
-                          mode=mode)
-
-
- class AdaptiveAvgPool3d(_AdaptivePool):
-     r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
-
-     The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
-     The number of output features is equal to the number of input planes.
-
-     Shape:
-         - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
-         - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
-           where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`.
-
-     Examples:
-
-         >>> import brainstate as bst
-         >>> # target output size of 5x7x9
-         >>> m = AdaptiveAvgPool3d((5, 7, 9))
-         >>> input = bst.random.randn(1, 8, 9, 10, 64)
-         >>> output = m(input)
-         >>> output.shape
-         (1, 5, 7, 9, 64)
-         >>> # target output size of 7x7x7 (cube)
-         >>> m = AdaptiveAvgPool3d(7)
-         >>> input = bst.random.randn(1, 10, 9, 8, 64)
-         >>> output = m(input)
-         >>> output.shape
-         (1, 7, 7, 7, 64)
-         >>> # target output size of 7x9x8
-         >>> m = AdaptiveAvgPool3d((7, None, None))
-         >>> input = bst.random.randn(1, 10, 9, 8, 64)
-         >>> output = m(input)
-         >>> output.shape
-         (1, 7, 9, 8, 64)
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     target_size: int, sequence of int
-         The target output shape.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     name: str
-         The class name.
-     mode: Mode
-         The computing mode.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(self,
-                  target_size: Union[int, Sequence[int]],
-                  channel_axis: Optional[int] = -1,
-                  name: Optional[str] = None,
-                  mode: Optional[Mode] = None,
-                  in_size: Optional[Sequence[int]] = None):
-         super().__init__(in_size=in_size,
-                          target_size=target_size,
-                          channel_axis=channel_axis,
-                          num_spatial_dims=3,
-                          operation=jnp.mean,
-                          name=name,
-                          mode=mode)
-
-
- class AdaptiveMaxPool1d(_AdaptivePool):
-     """Adaptive one-dimensional maximum down-sampling.
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     target_size: int, sequence of int
-         The target output shape.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     name: str
-         The class name.
-     mode: Mode
-         The computing mode.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(self,
-                  target_size: Union[int, Sequence[int]],
-                  channel_axis: Optional[int] = -1,
-                  name: Optional[str] = None,
-                  mode: Optional[Mode] = None,
-                  in_size: Optional[Sequence[int]] = None):
-         super().__init__(in_size=in_size,
-                          target_size=target_size,
-                          channel_axis=channel_axis,
-                          num_spatial_dims=1,
-                          operation=jnp.max,
-                          name=name,
-                          mode=mode)
-
-
- class AdaptiveMaxPool2d(_AdaptivePool):
-     """Adaptive two-dimensional maximum down-sampling.
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     target_size: int, sequence of int
-         The target output shape.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     name: str
-         The class name.
-     mode: Mode
-         The computing mode.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(self,
-                  target_size: Union[int, Sequence[int]],
-                  channel_axis: Optional[int] = -1,
-                  name: Optional[str] = None,
-                  mode: Optional[Mode] = None,
-                  in_size: Optional[Sequence[int]] = None):
-         super().__init__(in_size=in_size,
-                          target_size=target_size,
-                          channel_axis=channel_axis,
-                          num_spatial_dims=2,
-                          operation=jnp.max,
-                          name=name,
-                          mode=mode)
-
-
- class AdaptiveMaxPool3d(_AdaptivePool):
-     """Adaptive three-dimensional maximum down-sampling.
-
-     Parameters
-     ----------
-     in_size: Sequence of int
-         The shape of the input tensor.
-     target_size: int, sequence of int
-         The target output shape.
-     channel_axis: int, optional
-         Axis of the spatial channels for which pooling is skipped.
-         If ``None``, there is no channel axis.
-     name: str
-         The class name.
-     mode: Mode
-         The computing mode.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(self,
-                  target_size: Union[int, Sequence[int]],
-                  channel_axis: Optional[int] = -1,
-                  name: Optional[str] = None,
-                  mode: Optional[Mode] = None,
-                  in_size: Optional[Sequence[int]] = None):
-         super().__init__(in_size=in_size,
-                          target_size=target_size,
-                          channel_axis=channel_axis,
-                          num_spatial_dims=3,
-                          operation=jnp.max,
-                          name=name,
-                          mode=mode)
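
The one non-obvious piece of logic in the deleted module is the `'SAME'`-padding branch of `_AvgPool.update`: rather than dividing each window sum by the fixed window size, it divides by a per-position count of valid entries, computed by a second `reduce_window` over an array of ones, so padded positions do not bias the average toward zero. A standalone sketch of that identity in plain JAX, independent of the deleted classes:

    import jax
    import jax.numpy as jnp

    # Sketch of the deleted _AvgPool logic: with 'SAME' padding, divide each
    # window sum by the number of *valid* (unpadded) entries in that window,
    # obtained from a second reduce_window over an array of ones.
    x = jnp.arange(1.0, 6.0)                 # shape (5,)
    window, stride, pad = (3,), (1,), 'SAME'

    sums = jax.lax.reduce_window(x, 0.0, jax.lax.add, window, stride, pad)
    counts = jax.lax.reduce_window(jnp.ones_like(x), 0.0, jax.lax.add,
                                   window, stride, pad)

    print(sums / counts)                     # -> [1.5, 2., 3., 4., 4.5]
    # Edge windows divide by 2, interior windows by 3; dividing by the fixed
    # window size (3) would instead drag the edge averages toward zero.

With `'VALID'` padding every window is fully populated, which is why the deleted code takes the cheaper `pooled / np.prod(dims)` path in that case.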