brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,1229 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from __future__ import annotations
19
-
20
- import functools
21
- from typing import Sequence, Optional
22
- from typing import Union, Tuple, Callable, List
23
-
24
- import brainunit as u
25
- import jax
26
- import jax.numpy as jnp
27
- import numpy as np
28
-
29
- from ._base import DnnLayer, ExplicitInOutSize
30
- from .. import environ
31
- from ..mixin import Mode
32
- from brainstate.typing import Size
33
-
34
- __all__ = [
35
- 'Flatten', 'Unflatten',
36
-
37
- 'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
38
- 'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
39
-
40
- 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
41
- 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
42
- ]
43
-
44
-
45
- class Flatten(DnnLayer, ExplicitInOutSize):
46
- r"""
47
- Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
48
-
49
- Shape:
50
- - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,'
51
- where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
52
- number of dimensions including none.
53
- - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
54
-
55
- Args:
56
- in_size: Sequence of int. The shape of the input tensor.
57
- start_axis: first dim to flatten (default = 1).
58
- end_axis: last dim to flatten (default = -1).
59
-
60
- Examples::
61
- >>> import brainstate as bst
62
- >>> inp = bst.random.randn(32, 1, 5, 5)
63
- >>> # With default parameters
64
- >>> m = Flatten()
65
- >>> output = m(inp)
66
- >>> output.shape
67
- (32, 25)
68
- >>> # With non-default parameters
69
- >>> m = Flatten(0, 2)
70
- >>> output = m(inp)
71
- >>> output.shape
72
- (160, 5)
73
- """
74
- __module__ = 'brainstate.nn'
75
-
76
- def __init__(
77
- self,
78
- start_axis: int = 0,
79
- end_axis: int = -1,
80
- in_size: Optional[Size] = None
81
- ) -> None:
82
- super().__init__()
83
- self.start_axis = start_axis
84
- self.end_axis = end_axis
85
-
86
- if in_size is not None:
87
- self.in_size = tuple(in_size)
88
- y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
89
- jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
90
- self.out_size = y.shape
91
-
92
- def update(self, x):
93
- if self._in_size is None:
94
- start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
95
- else:
96
- assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
97
- dim_diff = x.ndim - len(self.in_size)
98
- if self.in_size != x.shape[dim_diff:]:
99
- raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
100
- if self.start_axis >= 0:
101
- start_axis = self.start_axis + dim_diff
102
- else:
103
- start_axis = x.ndim + self.start_axis
104
- return u.math.flatten(x, start_axis, self.end_axis)
105
-
106
- def __repr__(self) -> str:
107
- return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
108
-
109
-
110
- class Unflatten(DnnLayer, ExplicitInOutSize):
111
- r"""
112
- Unflatten a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.
113
-
114
- * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
115
- be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.
116
-
117
- * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
118
- a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
119
- (tuple of `(name, size)` tuples) for `NamedTensor` input.
120
-
121
- Shape:
122
- - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
123
- dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
124
- - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
125
- :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
126
-
127
- Args:
128
- axis: int, Dimension to be unflattened.
129
- sizes: Sequence of int. New shape of the unflattened dimension.
130
- in_size: Sequence of int. The shape of the input tensor.
131
- """
132
- __module__ = 'brainstate.nn'
133
-
134
- def __init__(
135
- self,
136
- axis: int,
137
- sizes: Size,
138
- mode: Mode = None,
139
- name: str = None,
140
- in_size: Optional[Size] = None
141
- ) -> None:
142
- super().__init__(mode=mode, name=name)
143
-
144
- self.axis = axis
145
- self.sizes = sizes
146
- if isinstance(sizes, (tuple, list)):
147
- for idx, elem in enumerate(sizes):
148
- if not isinstance(elem, int):
149
- raise TypeError("unflattened sizes must be tuple of ints, " +
150
- "but found element of type {} at pos {}".format(type(elem).__name__, idx))
151
- else:
152
- raise TypeError("unflattened sizes must be tuple or list, but found type {}".format(type(sizes).__name__))
153
-
154
- if in_size is not None:
155
- self.in_size = tuple(in_size)
156
- y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
157
- jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
158
- self.out_size = y.shape
159
-
160
- def update(self, x):
161
- return u.math.unflatten(x, self.axis, self.sizes)
162
-
163
- def __repr__(self):
164
- return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
165
-
166
-
167
- class _MaxPool(DnnLayer, ExplicitInOutSize):
168
- def __init__(
169
- self,
170
- init_value: float,
171
- computation: Callable,
172
- pool_dim: int,
173
- kernel_size: Size,
174
- stride: Union[int, Sequence[int]] = None,
175
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
176
- channel_axis: Optional[int] = -1,
177
- mode: Mode = None,
178
- name: Optional[str] = None,
179
- in_size: Optional[Size] = None,
180
- ):
181
- super().__init__(name=name, mode=mode)
182
-
183
- self.init_value = init_value
184
- self.computation = computation
185
- self.pool_dim = pool_dim
186
-
187
- # kernel_size
188
- if isinstance(kernel_size, int):
189
- kernel_size = (kernel_size,) * pool_dim
190
- elif isinstance(kernel_size, Sequence):
191
- assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
192
- assert all([isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
193
- if len(kernel_size) != pool_dim:
194
- raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
195
- else:
196
- raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
197
- self.kernel_size = kernel_size
198
-
199
- # stride
200
- if stride is None:
201
- stride = kernel_size
202
- if isinstance(stride, int):
203
- stride = (stride,) * pool_dim
204
- elif isinstance(stride, Sequence):
205
- assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
206
- assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
207
- if len(stride) != pool_dim:
208
- raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
209
- else:
210
- raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
211
- self.stride = stride
212
-
213
- # padding
214
- if isinstance(padding, str):
215
- if padding not in ("SAME", "VALID"):
216
- raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
217
- elif isinstance(padding, int):
218
- padding = [(padding, padding) for _ in range(pool_dim)]
219
- elif isinstance(padding, (list, tuple)):
220
- if isinstance(padding[0], int):
221
- if len(padding) == pool_dim:
222
- padding = [(x, x) for x in padding]
223
- else:
224
- raise ValueError(f'If padding is a sequence of ints, it '
225
- f'should has the length of {pool_dim}.')
226
- else:
227
- if not all([isinstance(x, (tuple, list)) for x in padding]):
228
- raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
229
- if not all([len(x) == 2 for x in padding]):
230
- raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
231
- if len(padding) == 1:
232
- padding = tuple(padding) * pool_dim
233
- assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
234
- else:
235
- raise ValueError
236
- self.padding = padding
237
-
238
- # channel_axis
239
- assert channel_axis is None or isinstance(channel_axis, int), \
240
- f'channel_axis should be an int, but got {channel_axis}'
241
- self.channel_axis = channel_axis
242
-
243
- # in & out shapes
244
- if in_size is not None:
245
- in_size = tuple(in_size)
246
- self.in_size = in_size
247
- y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
248
- self.out_size = y.shape[1:]
249
-
250
- def update(self, x):
251
- x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
252
- if x.ndim < x_dim:
253
- raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
254
- window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
255
- stride = self._infer_shape(x.ndim, self.stride, 1)
256
- padding = (self.padding if isinstance(self.padding, str) else
257
- self._infer_shape(x.ndim, self.padding, element=(0, 0)))
258
- r = jax.lax.reduce_window(
259
- x,
260
- init_value=self.init_value,
261
- computation=self.computation,
262
- window_dimensions=window_shape,
263
- window_strides=stride,
264
- padding=padding
265
- )
266
- return r
267
-
268
- def _infer_shape(self, x_dim, inputs, element):
269
- channel_axis = self.channel_axis
270
- if channel_axis and not 0 <= abs(channel_axis) < x_dim:
271
- raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
272
- if channel_axis and channel_axis < 0:
273
- channel_axis = x_dim + channel_axis
274
- all_dims = list(range(x_dim))
275
- if channel_axis is not None:
276
- all_dims.pop(channel_axis)
277
- pool_dims = all_dims[-self.pool_dim:]
278
- results = [element] * x_dim
279
- for i, dim in enumerate(pool_dims):
280
- results[dim] = inputs[i]
281
- return results
282
-
283
-
284
- class _AvgPool(_MaxPool):
285
- def update(self, x):
286
- x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
287
- if x.ndim < x_dim:
288
- raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
289
- dims = self._infer_shape(x.ndim, self.kernel_size, 1)
290
- stride = self._infer_shape(x.ndim, self.stride, 1)
291
- padding = (self.padding if isinstance(self.padding, str) else
292
- self._infer_shape(x.ndim, self.padding, element=(0, 0)))
293
- pooled = jax.lax.reduce_window(x,
294
- init_value=self.init_value,
295
- computation=self.computation,
296
- window_dimensions=dims,
297
- window_strides=stride,
298
- padding=padding)
299
- if padding == "VALID":
300
- # Avoid the extra reduce_window.
301
- return pooled / np.prod(dims)
302
- else:
303
- # Count the number of valid entries at each input point, then use that for
304
- # computing average. Assumes that any two arrays of same shape will be
305
- # padded the same.
306
- window_counts = jax.lax.reduce_window(jnp.ones_like(x),
307
- init_value=self.init_value,
308
- computation=self.computation,
309
- window_dimensions=dims,
310
- window_strides=stride,
311
- padding=padding)
312
- assert pooled.shape == window_counts.shape
313
- return pooled / window_counts
314
-
315
-
316
- class MaxPool1d(_MaxPool):
317
- r"""Applies a 1D max pooling over an input signal composed of several input planes.
318
-
319
- In the simplest case, the output value of the layer with input size :math:`(N, L, C)`
320
- and output :math:`(N, L_{out}, C)` can be precisely described as:
321
-
322
- .. math::
323
- out(N_i, k, C_j) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
324
- input(N_i, stride \times k + m, C_j)
325
-
326
- If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
327
- for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
328
- sliding window. This `link`_ has a nice visualization of the pooling parameters.
329
-
330
- Shape:
331
- - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
332
- - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
333
-
334
- .. math::
335
- L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
336
- \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
337
-
338
-
339
- Examples::
340
-
341
- >>> import brainstate as bst
342
- >>> # pool of size=3, stride=2
343
- >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
344
- >>> input = bst.random.randn(20, 50, 16)
345
- >>> output = m(input)
346
- >>> output.shape
347
- (20, 24, 16)
348
-
349
- Parameters
350
- ----------
351
- in_size: Sequence of int
352
- The shape of the input tensor.
353
- kernel_size: int, sequence of int
354
- An integer, or a sequence of integers defining the window to reduce over.
355
- stride: int, sequence of int
356
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
357
- padding: str, int, sequence of tuple
358
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
359
- of n `(low, high)` integer pairs that give the padding to apply before
360
- and after each spatial dimension.
361
- channel_axis: int, optional
362
- Axis of the spatial channels for which pooling is skipped.
363
- If ``None``, there is no channel axis.
364
- mode: Mode
365
- The computation mode.
366
- name: optional, str
367
- The object name.
368
-
369
- .. _link:
370
- https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
371
- """
372
- __module__ = 'brainstate.nn'
373
-
374
- def __init__(
375
- self,
376
- kernel_size: Size,
377
- stride: Union[int, Sequence[int]] = None,
378
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
379
- channel_axis: Optional[int] = -1,
380
- mode: Mode = None,
381
- name: Optional[str] = None,
382
- in_size: Optional[Size] = None,
383
- ):
384
- super().__init__(in_size=in_size,
385
- init_value=-jax.numpy.inf,
386
- computation=jax.lax.max,
387
- pool_dim=1,
388
- kernel_size=kernel_size,
389
- stride=stride,
390
- padding=padding,
391
- channel_axis=channel_axis,
392
- name=name,
393
- mode=mode)
394
-
395
-
396
- class MaxPool2d(_MaxPool):
397
- r"""Applies a 2D max pooling over an input signal composed of several input planes.
398
-
399
- In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
400
- output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
401
- can be precisely described as:
402
-
403
- .. math::
404
- \begin{aligned}
405
- out(N_i, h, w, C_j) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
406
- & \text{input}(N_i, \text{stride[0]} \times h + m,
407
- \text{stride[1]} \times w + n, C_j)
408
- \end{aligned}
409
-
410
- If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
411
- for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
412
- It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
413
-
414
-
415
- Shape:
416
- - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
417
- - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
418
-
419
- .. math::
420
- H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
421
- \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor
422
-
423
- .. math::
424
- W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
425
- \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
426
-
427
- Examples::
428
-
429
- >>> import brainstate as bst
430
- >>> # pool of square window of size=3, stride=2
431
- >>> m = MaxPool2d(3, stride=2)
432
- >>> # pool of non-square window
433
- >>> m = MaxPool2d((3, 2), stride=(2, 1), channel_axis=-1)
434
- >>> input = bst.random.randn(20, 50, 32, 16)
435
- >>> output = m(input)
436
- >>> output.shape
437
- (20, 24, 31, 16)
438
-
439
- .. _link:
440
- https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
441
-
442
- Parameters
443
- ----------
444
- in_size: Sequence of int
445
- The shape of the input tensor.
446
- kernel_size: int, sequence of int
447
- An integer, or a sequence of integers defining the window to reduce over.
448
- stride: int, sequence of int
449
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
450
- padding: str, int, sequence of tuple
451
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
452
- of n `(low, high)` integer pairs that give the padding to apply before
453
- and after each spatial dimension.
454
- channel_axis: int, optional
455
- Axis of the spatial channels for which pooling is skipped.
456
- If ``None``, there is no channel axis.
457
- mode: Mode
458
- The computation mode.
459
- name: optional, str
460
- The object name.
461
-
462
- """
463
- __module__ = 'brainstate.nn'
464
-
465
- def __init__(
466
- self,
467
- kernel_size: Size,
468
- stride: Union[int, Sequence[int]] = None,
469
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
470
- channel_axis: Optional[int] = -1,
471
- mode: Mode = None,
472
- name: Optional[str] = None,
473
- in_size: Optional[Size] = None,
474
- ):
475
- super().__init__(in_size=in_size,
476
- init_value=-jax.numpy.inf,
477
- computation=jax.lax.max,
478
- pool_dim=2,
479
- kernel_size=kernel_size,
480
- stride=stride,
481
- padding=padding,
482
- channel_axis=channel_axis,
483
- name=name, mode=mode)
484
-
485
-
486
- class MaxPool3d(_MaxPool):
487
- r"""Applies a 3D max pooling over an input signal composed of several input planes.
488
-
489
- In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
490
- output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
491
- can be precisely described as:
492
-
493
- .. math::
494
- \begin{aligned}
495
- \text{out}(N_i, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
496
- & \text{input}(N_i, \text{stride[0]} \times d + k,
497
- \text{stride[1]} \times h + m, \text{stride[2]} \times w + n, C_j)
498
- \end{aligned}
499
-
500
- If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
501
- for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
502
- It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
503
-
504
-
505
- Shape:
506
- - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
507
- - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
508
-
509
- .. math::
510
- D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
511
- (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
512
-
513
- .. math::
514
- H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
515
- (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
516
-
517
- .. math::
518
- W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
519
- (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
520
-
521
- Examples::
522
-
523
- >>> import brainstate as bst
524
- >>> # pool of square window of size=3, stride=2
525
- >>> m = MaxPool3d(3, stride=2)
526
- >>> # pool of non-square window
527
- >>> m = MaxPool3d((3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
528
- >>> input = bst.random.randn(20, 50, 44, 31, 16)
529
- >>> output = m(input)
530
- >>> output.shape
531
- (20, 24, 43, 15, 16)
532
-
533
- .. _link:
534
- https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
535
-
536
- Parameters
537
- ----------
538
- in_size: Sequence of int
539
- The shape of the input tensor.
540
- kernel_size: int, sequence of int
541
- An integer, or a sequence of integers defining the window to reduce over.
542
- stride: int, sequence of int
543
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
544
- padding: str, int, sequence of tuple
545
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
546
- of n `(low, high)` integer pairs that give the padding to apply before
547
- and after each spatial dimension.
548
- channel_axis: int, optional
549
- Axis of the spatial channels for which pooling is skipped.
550
- If ``None``, there is no channel axis.
551
- mode: Mode
552
- The computation mode.
553
- name: optional, str
554
- The object name.
555
-
556
- """
557
- __module__ = 'brainstate.nn'
558
-
559
- def __init__(
560
- self,
561
- kernel_size: Size,
562
- stride: Union[int, Sequence[int]] = None,
563
- padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
564
- channel_axis: Optional[int] = -1,
565
- mode: Mode = None,
566
- name: Optional[str] = None,
567
- in_size: Optional[Size] = None,
568
- ):
569
- super().__init__(in_size=in_size,
570
- init_value=-jax.numpy.inf,
571
- computation=jax.lax.max,
572
- pool_dim=3,
573
- kernel_size=kernel_size,
574
- stride=stride,
575
- padding=padding,
576
- channel_axis=channel_axis,
577
- name=name, mode=mode)
578
-
579
-
580
- class AvgPool1d(_AvgPool):
581
- r"""Applies a 1D average pooling over an input signal composed of several input planes.
582
-
583
- In the simplest case, the output value of the layer with input size :math:`(N, L, C)`,
584
- output :math:`(N, L_{out}, C)` and :attr:`kernel_size` :math:`k`
585
- can be precisely described as:
586
-
587
- .. math::
588
-
589
- \text{out}(N_i, l, C_j) = \frac{1}{k} \sum_{m=0}^{k-1}
590
- \text{input}(N_i, \text{stride} \times l + m, C_j)
591
-
592
- If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
593
- for :attr:`padding` number of points.
594
-
595
- Shape:
596
- - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
597
- - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
598
-
599
- .. math::
600
- L_{out} = \left\lfloor \frac{L_{in} +
601
- 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
602
-
603
- Examples::
604
-
605
- >>> import brainstate as bst
606
- >>> # pool with window of size=3, stride=2
607
- >>> m = AvgPool1d(3, stride=2)
608
- >>> input = bst.random.randn(20, 50, 16)
609
- >>> m(input).shape
610
- (20, 24, 16)
611
-
612
- Parameters
613
- ----------
614
- in_size: Sequence of int
615
- The shape of the input tensor.
616
- kernel_size: int, sequence of int
617
- An integer, or a sequence of integers defining the window to reduce over.
618
- stride: int, sequence of int
619
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
620
- padding: str, int, sequence of tuple
621
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
622
- of n `(low, high)` integer pairs that give the padding to apply before
623
- and after each spatial dimension.
624
- channel_axis: int, optional
625
- Axis of the spatial channels for which pooling is skipped.
626
- If ``None``, there is no channel axis.
627
- mode: Mode
628
- The computation mode.
629
- name: optional, str
630
- The object name.
631
-
632
- """
633
- __module__ = 'brainstate.nn'
634
-
635
- def __init__(
636
- self,
637
- kernel_size: Size,
638
- stride: Union[int, Sequence[int]] = 1,
639
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
640
- channel_axis: Optional[int] = -1,
641
- mode: Mode = None,
642
- name: Optional[str] = None,
643
- in_size: Optional[Size] = None,
644
- ):
645
- super().__init__(in_size=in_size,
646
- init_value=0.,
647
- computation=jax.lax.add,
648
- pool_dim=1,
649
- kernel_size=kernel_size,
650
- stride=stride,
651
- padding=padding,
652
- channel_axis=channel_axis,
653
- name=name,
654
- mode=mode)
655
-
656
-
657
- class AvgPool2d(_AvgPool):
658
- r"""Applies a 2D average pooling over an input signal composed of several input planes.
659
-
660
- In the simplest case, the output value of the layer with input size :math:`(N, H, W, C)`,
661
- output :math:`(N, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kH, kW)`
662
- can be precisely described as:
663
-
664
- .. math::
665
-
666
- out(N_i, h, w, C_j) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
667
- input(N_i, stride[0] \times h + m, stride[1] \times w + n, C_j)
668
-
669
- If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
670
- for :attr:`padding` number of points.
671
-
672
- Shape:
673
- - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
674
- - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
675
-
676
- .. math::
677
- H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
678
- \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
679
-
680
- .. math::
681
- W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
682
- \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
683
-
684
- Examples::
685
-
686
- >>> import brainstate as bst
687
- >>> # pool of square window of size=3, stride=2
688
- >>> m = AvgPool2d(3, stride=2)
689
- >>> # pool of non-square window
690
- >>> m = AvgPool2d((3, 2), stride=(2, 1))
691
- >>> input = bst.random.randn(20, 50, 32, , 16)
692
- >>> output = m(input)
693
- >>> output.shape
694
- (20, 24, 31, 16)
695
-
696
- Parameters
697
- ----------
698
- in_size: Sequence of int
699
- The shape of the input tensor.
700
- kernel_size: int, sequence of int
701
- An integer, or a sequence of integers defining the window to reduce over.
702
- stride: int, sequence of int
703
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
704
- padding: str, int, sequence of tuple
705
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
706
- of n `(low, high)` integer pairs that give the padding to apply before
707
- and after each spatial dimension.
708
- channel_axis: int, optional
709
- Axis of the spatial channels for which pooling is skipped.
710
- If ``None``, there is no channel axis.
711
- mode: Mode
712
- The computation mode.
713
- name: optional, str
714
- The object name.
715
- """
716
- __module__ = 'brainstate.nn'
717
-
718
- def __init__(
719
- self,
720
- kernel_size: Size,
721
- stride: Union[int, Sequence[int]] = 1,
722
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
723
- channel_axis: Optional[int] = -1,
724
- mode: Mode = None,
725
- name: Optional[str] = None,
726
- in_size: Optional[Size] = None,
727
- ):
728
- super().__init__(in_size=in_size,
729
- init_value=0.,
730
- computation=jax.lax.add,
731
- pool_dim=2,
732
- kernel_size=kernel_size,
733
- stride=stride,
734
- padding=padding,
735
- channel_axis=channel_axis,
736
- name=name,
737
- mode=mode)
738
-
739
-
740
- class AvgPool3d(_AvgPool):
741
- r"""Applies a 3D average pooling over an input signal composed of several input planes.
742
-
743
-
744
- In the simplest case, the output value of the layer with input size :math:`(N, D, H, W, C)`,
745
- output :math:`(N, D_{out}, H_{out}, W_{out}, C)` and :attr:`kernel_size` :math:`(kD, kH, kW)`
746
- can be precisely described as:
747
-
748
- .. math::
749
- \begin{aligned}
750
- \text{out}(N_i, d, h, w, C_j) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
751
- & \frac{\text{input}(N_i, \text{stride}[0] \times d + k,
752
- \text{stride}[1] \times h + m, \text{stride}[2] \times w + n, C_j)}
753
- {kD \times kH \times kW}
754
- \end{aligned}
755
-
756
- If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
757
- for :attr:`padding` number of points.
758
-
759
- Shape:
760
- - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
761
- - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or
762
- :math:`(D_{out}, H_{out}, W_{out}, C)`, where
763
-
764
- .. math::
765
- D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
766
- \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
767
-
768
- .. math::
769
- H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
770
- \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
771
-
772
- .. math::
773
- W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
774
- \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
775
-
776
- Examples::
777
-
778
- >>> import brainstate as bst
779
- >>> # pool of square window of size=3, stride=2
780
- >>> m = AvgPool3d(3, stride=2)
781
- >>> # pool of non-square window
782
- >>> m = AvgPool3d((3, 2, 2), stride=(2, 1, 2))
783
- >>> input = bst.random.randn(20, 50, 44, 31, 16)
784
- >>> output = m(input)
785
- >>> output.shape
786
- (20, 24, 43, 15, 16)
787
-
788
- Parameters
789
- ----------
790
- in_size: Sequence of int
791
- The shape of the input tensor.
792
- kernel_size: int, sequence of int
793
- An integer, or a sequence of integers defining the window to reduce over.
794
- stride: int, sequence of int
795
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
796
- padding: str, int, sequence of tuple
797
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
798
- of n `(low, high)` integer pairs that give the padding to apply before
799
- and after each spatial dimension.
800
- channel_axis: int, optional
801
- Axis of the spatial channels for which pooling is skipped.
802
- If ``None``, there is no channel axis.
803
- mode: Mode
804
- The computation mode.
805
- name: optional, str
806
- The object name.
807
-
808
- """
809
- __module__ = 'brainstate.nn'
810
-
811
- def __init__(
812
- self,
813
- kernel_size: Size,
814
- stride: Union[int, Sequence[int]] = 1,
815
- padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
816
- channel_axis: Optional[int] = -1,
817
- mode: Mode = None,
818
- name: Optional[str] = None,
819
- in_size: Optional[Size] = None,
820
- ):
821
- super().__init__(in_size=in_size,
822
- init_value=0.,
823
- computation=jax.lax.add,
824
- pool_dim=3,
825
- kernel_size=kernel_size,
826
- stride=stride,
827
- padding=padding,
828
- channel_axis=channel_axis,
829
- name=name,
830
- mode=mode)
831
-
832
-
833
- def _adaptive_pool1d(x, target_size: int, operation: Callable):
834
- """Adaptive pool 1D.
835
-
836
- Args:
837
- x: The input. Should be a JAX array of shape `(dim,)`.
838
- target_size: The shape of the output after the pooling operation `(target_size,)`.
839
- operation: The pooling operation to be performed on the input array.
840
-
841
- Returns:
842
- A JAX array of shape `(target_size, )`.
843
- """
844
- size = jnp.size(x)
845
- num_head_arrays = size % target_size
846
- num_block = size // target_size
847
- if num_head_arrays != 0:
848
- head_end_index = num_head_arrays * (num_block + 1)
849
- heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
850
- tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
851
- outs = jnp.concatenate([heads, tails])
852
- else:
853
- outs = jax.vmap(operation)(x.reshape(-1, num_block))
854
- return outs
855
-
856
-
857
- def _generate_vmap(fun: Callable, map_axes: List[int]):
858
- map_axes = sorted(map_axes)
859
- for axis in map_axes:
860
- fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
861
- return fun
862
-
863
-
864
- class _AdaptivePool(DnnLayer, ExplicitInOutSize):
865
- """General N dimensional adaptive down-sampling to a target shape.
866
-
867
- Parameters
868
- ----------
869
- in_size: Sequence of int
870
- The shape of the input tensor.
871
- target_size: int, sequence of int
872
- The target output shape.
873
- num_spatial_dims: int
874
- The number of spatial dimensions.
875
- channel_axis: int, optional
876
- Axis of the spatial channels for which pooling is skipped.
877
- If ``None``, there is no channel axis.
878
- operation: Callable
879
- The down-sampling operation.
880
- name: str
881
- The class name.
882
- mode: Mode
883
- The computing mode.
884
- """
885
-
886
- def __init__(
887
- self,
888
- in_size: Size,
889
- target_size: Size,
890
- num_spatial_dims: int,
891
- operation: Callable,
892
- channel_axis: Optional[int] = -1,
893
- name: Optional[str] = None,
894
- mode: Optional[Mode] = None,
895
- ):
896
- super().__init__(name=name, mode=mode)
897
-
898
- self.channel_axis = channel_axis
899
- self.operation = operation
900
- if isinstance(target_size, int):
901
- self.target_shape = (target_size,) * num_spatial_dims
902
- elif isinstance(target_size, Sequence) and (len(target_size) == num_spatial_dims):
903
- self.target_shape = target_size
904
- else:
905
- raise ValueError("`target_size` must either be an int or tuple of length "
906
- f"{num_spatial_dims} containing ints.")
907
-
908
- # in & out shapes
909
- if in_size is not None:
910
- in_size = tuple(in_size)
911
- self.in_size = in_size
912
- y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
913
- self.out_size = y.shape[1:]
914
-
915
- def update(self, x):
916
- """Input-output mapping.
917
-
918
- Parameters
919
- ----------
920
- x: Array
921
- Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
922
- or `(..., dim_1, dim_2)`.
923
- """
924
- # channel axis
925
- channel_axis = self.channel_axis
926
-
927
- if channel_axis:
928
- if not 0 <= abs(channel_axis) < x.ndim:
929
- raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
930
- if channel_axis < 0:
931
- channel_axis = x.ndim + channel_axis
932
- # input dimension
933
- if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
934
- raise ValueError(f"Invalid input dimension. Except >={len(self.target_shape)} "
935
- f"dimensions (channel_axis={self.channel_axis}). "
936
- f"But got {x.ndim} dimensions.")
937
- # pooling dimensions
938
- pool_dims = list(range(x.ndim))
939
- if channel_axis:
940
- pool_dims.pop(channel_axis)
941
-
942
- # pooling
943
- for i, di in enumerate(pool_dims[-len(self.target_shape):]):
944
- poo_axes = [j for j in range(x.ndim) if j != di]
945
- op = _generate_vmap(_adaptive_pool1d, poo_axes)
946
- x = op(x, self.target_shape[i], self.operation)
947
- return x
948
-
949
-
950
- class AdaptiveAvgPool1d(_AdaptivePool):
951
- r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
952
-
953
- The output size is :math:`L_{out}`, for any input size.
954
- The number of output features is equal to the number of input planes.
955
-
956
- Shape:
957
- - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
958
- - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
959
- :math:`L_{out}=\text{output\_size}`.
960
-
961
- Examples:
962
-
963
- >>> import brainstate as bst
964
- >>> # target output size of 5
965
- >>> m = AdaptiveMaxPool1d(5)
966
- >>> input = bst.random.randn(1, 64, 8)
967
- >>> output = m(input)
968
- >>> output.shape
969
- (1, 5, 8)
970
-
971
- Parameters
972
- ----------
973
- in_size: Sequence of int
974
- The shape of the input tensor.
975
- target_size: int, sequence of int
976
- The target output shape.
977
- channel_axis: int, optional
978
- Axis of the spatial channels for which pooling is skipped.
979
- If ``None``, there is no channel axis.
980
- name: str
981
- The class name.
982
- mode: Mode
983
- The computing mode.
984
- """
985
- __module__ = 'brainstate.nn'
986
-
987
- def __init__(self,
988
- target_size: Union[int, Sequence[int]],
989
- channel_axis: Optional[int] = -1,
990
- name: Optional[str] = None,
991
- mode: Optional[Mode] = None,
992
- in_size: Optional[Sequence[int]] = None, ):
993
- super().__init__(in_size=in_size,
994
- target_size=target_size,
995
- channel_axis=channel_axis,
996
- num_spatial_dims=1,
997
- operation=jnp.mean,
998
- name=name,
999
- mode=mode)
1000
-
1001
-
1002
- class AdaptiveAvgPool2d(_AdaptivePool):
1003
- r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
1004
-
1005
- The output is of size :math:`H_{out} \times W_{out}`, for any input size.
1006
- The number of output features is equal to the number of input planes.
1007
-
1008
- Shape:
1009
- - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
1010
- - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
1011
- :math:`(H_{out}, W_{out})=\text{output\_size}`.
1012
-
1013
- Examples:
1014
-
1015
- >>> import brainstate as bst
1016
- >>> # target output size of 5x7
1017
- >>> m = AdaptiveMaxPool2d((5, 7))
1018
- >>> input = bst.random.randn(1, 8, 9, 64)
1019
- >>> output = m(input)
1020
- >>> output.shape
1021
- (1, 5, 7, 64)
1022
- >>> # target output size of 7x7 (square)
1023
- >>> m = AdaptiveMaxPool2d(7)
1024
- >>> input = bst.random.randn(1, 10, 9, 64)
1025
- >>> output = m(input)
1026
- >>> output.shape
1027
- (1, 7, 7, 64)
1028
- >>> # target output size of 10x7
1029
- >>> m = AdaptiveMaxPool2d((None, 7))
1030
- >>> input = bst.random.randn(1, 10, 9, 64)
1031
- >>> output = m(input)
1032
- >>> output.shape
1033
- (1, 10, 7, 64)
1034
-
1035
- Parameters
1036
- ----------
1037
- in_size: Sequence of int
1038
- The shape of the input tensor.
1039
- target_size: int, sequence of int
1040
- The target output shape.
1041
- channel_axis: int, optional
1042
- Axis of the spatial channels for which pooling is skipped.
1043
- If ``None``, there is no channel axis.
1044
- name: str
1045
- The class name.
1046
- mode: Mode
1047
- The computing mode.
1048
- """
1049
- __module__ = 'brainstate.nn'
1050
-
1051
- def __init__(self,
1052
- target_size: Union[int, Sequence[int]],
1053
- channel_axis: Optional[int] = -1,
1054
- name: Optional[str] = None,
1055
- mode: Optional[Mode] = None,
1056
- in_size: Optional[Sequence[int]] = None, ):
1057
- super().__init__(in_size=in_size,
1058
- target_size=target_size,
1059
- channel_axis=channel_axis,
1060
- num_spatial_dims=2,
1061
- operation=jnp.mean,
1062
- name=name,
1063
- mode=mode)
1064
-
1065
-
1066
- class AdaptiveAvgPool3d(_AdaptivePool):
1067
- r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
1068
-
1069
- The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
1070
- The number of output features is equal to the number of input planes.
1071
-
1072
- Shape:
1073
- - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
1074
- - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
1075
- where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`.
1076
-
1077
- Examples:
1078
-
1079
- >>> import brainstate as bst
1080
- >>> # target output size of 5x7x9
1081
- >>> m = AdaptiveMaxPool3d((5, 7, 9))
1082
- >>> input = bst.random.randn(1, 8, 9, 10, 64)
1083
- >>> output = m(input)
1084
- >>> output.shape
1085
- (1, 5, 7, 9, 64)
1086
- >>> # target output size of 7x7x7 (cube)
1087
- >>> m = AdaptiveMaxPool3d(7)
1088
- >>> input = bst.random.randn(1, 10, 9, 8, 64)
1089
- >>> output = m(input)
1090
- >>> output.shape
1091
- (1, 7, 7, 7, 64)
1092
- >>> # target output size of 7x9x8
1093
- >>> m = AdaptiveMaxPool3d((7, None, None))
1094
- >>> input = bst.random.randn(1, 10, 9, 8, 64)
1095
- >>> output = m(input)
1096
- >>> output.shape
1097
- (1, 7, 9, 8, 64)
1098
-
1099
- Parameters
1100
- ----------
1101
- in_size: Sequence of int
1102
- The shape of the input tensor.
1103
- target_size: int, sequence of int
1104
- The target output shape.
1105
- channel_axis: int, optional
1106
- Axis of the spatial channels for which pooling is skipped.
1107
- If ``None``, there is no channel axis.
1108
- name: str
1109
- The class name.
1110
- mode: Mode
1111
- The computing mode.
1112
- """
1113
- __module__ = 'brainstate.nn'
1114
-
1115
- def __init__(self,
1116
- target_size: Union[int, Sequence[int]],
1117
- channel_axis: Optional[int] = -1,
1118
- name: Optional[str] = None,
1119
- mode: Optional[Mode] = None,
1120
- in_size: Optional[Sequence[int]] = None, ):
1121
- super().__init__(in_size=in_size,
1122
- target_size=target_size,
1123
- channel_axis=channel_axis,
1124
- num_spatial_dims=3,
1125
- operation=jnp.mean,
1126
- name=name,
1127
- mode=mode)
1128
-
1129
-
1130
- class AdaptiveMaxPool1d(_AdaptivePool):
1131
- """Adaptive one-dimensional maximum down-sampling.
1132
-
1133
- Parameters
1134
- ----------
1135
- in_size: Sequence of int
1136
- The shape of the input tensor.
1137
- target_size: int, sequence of int
1138
- The target output shape.
1139
- channel_axis: int, optional
1140
- Axis of the spatial channels for which pooling is skipped.
1141
- If ``None``, there is no channel axis.
1142
- name: str
1143
- The class name.
1144
- mode: Mode
1145
- The computing mode.
1146
- """
1147
- __module__ = 'brainstate.nn'
1148
-
1149
- def __init__(self,
1150
- target_size: Union[int, Sequence[int]],
1151
- channel_axis: Optional[int] = -1,
1152
- name: Optional[str] = None,
1153
- mode: Optional[Mode] = None,
1154
- in_size: Optional[Sequence[int]] = None, ):
1155
- super().__init__(in_size=in_size,
1156
- target_size=target_size,
1157
- channel_axis=channel_axis,
1158
- num_spatial_dims=1,
1159
- operation=jnp.max,
1160
- name=name,
1161
- mode=mode)
1162
-
1163
-
1164
- class AdaptiveMaxPool2d(_AdaptivePool):
1165
- """Adaptive two-dimensional maximum down-sampling.
1166
-
1167
- Parameters
1168
- ----------
1169
- in_size: Sequence of int
1170
- The shape of the input tensor.
1171
- target_size: int, sequence of int
1172
- The target output shape.
1173
- channel_axis: int, optional
1174
- Axis of the spatial channels for which pooling is skipped.
1175
- If ``None``, there is no channel axis.
1176
- name: str
1177
- The class name.
1178
- mode: Mode
1179
- The computing mode.
1180
- """
1181
- __module__ = 'brainstate.nn'
1182
-
1183
- def __init__(self,
1184
- target_size: Union[int, Sequence[int]],
1185
- channel_axis: Optional[int] = -1,
1186
- name: Optional[str] = None,
1187
- mode: Optional[Mode] = None,
1188
- in_size: Optional[Sequence[int]] = None, ):
1189
- super().__init__(in_size=in_size,
1190
- target_size=target_size,
1191
- channel_axis=channel_axis,
1192
- num_spatial_dims=2,
1193
- operation=jnp.max,
1194
- name=name,
1195
- mode=mode)
1196
-
1197
-
1198
- class AdaptiveMaxPool3d(_AdaptivePool):
1199
- """Adaptive three-dimensional maximum down-sampling.
1200
-
1201
- Parameters
1202
- ----------
1203
- in_size: Sequence of int
1204
- The shape of the input tensor.
1205
- target_size: int, sequence of int
1206
- The target output shape.
1207
- channel_axis: int, optional
1208
- Axis of the spatial channels for which pooling is skipped.
1209
- If ``None``, there is no channel axis.
1210
- name: str
1211
- The class name.
1212
- mode: Mode
1213
- The computing mode.
1214
- """
1215
- __module__ = 'brainstate.nn'
1216
-
1217
- def __init__(self,
1218
- target_size: Union[int, Sequence[int]],
1219
- channel_axis: Optional[int] = -1,
1220
- name: Optional[str] = None,
1221
- mode: Optional[Mode] = None,
1222
- in_size: Optional[Sequence[int]] = None, ):
1223
- super().__init__(in_size=in_size,
1224
- target_size=target_size,
1225
- channel_axis=channel_axis,
1226
- num_spatial_dims=3,
1227
- operation=jnp.max,
1228
- name=name,
1229
- mode=mode)