brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -33,6 +33,8 @@ __all__ = [
33
33
 
34
34
  'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
35
35
  'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
36
+ 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
37
+ 'LPPool1d', 'LPPool2d', 'LPPool3d',
36
38
 
37
39
  'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
38
40
  'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
@@ -49,13 +51,20 @@ class Flatten(Module):
49
51
  number of dimensions including none.
50
52
  - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
51
53
 
52
- Args:
53
- in_size: Sequence of int. The shape of the input tensor.
54
- start_axis: first dim to flatten (default = 1).
55
- end_axis: last dim to flatten (default = -1).
56
-
57
- Examples::
58
- >>> import brainstate as brainstate
54
+ Parameters
55
+ ----------
56
+ start_axis : int, optional
57
+ First dim to flatten (default = 0).
58
+ end_axis : int, optional
59
+ Last dim to flatten (default = -1).
60
+ in_size : Sequence of int, optional
61
+ The shape of the input tensor.
62
+
63
+ Examples
64
+ --------
65
+ .. code-block:: python
66
+
67
+ >>> import brainstate
59
68
  >>> inp = brainstate.random.randn(32, 1, 5, 5)
60
69
  >>> # With default parameters
61
70
  >>> m = Flatten()
@@ -100,9 +109,6 @@ class Flatten(Module):
100
109
  start_axis = x.ndim + self.start_axis
101
110
  return u.math.flatten(x, start_axis, self.end_axis)
102
111
 
103
- def __repr__(self) -> str:
104
- return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
105
-
106
112
 
107
113
  class Unflatten(Module):
108
114
  r"""
@@ -121,10 +127,16 @@ class Unflatten(Module):
121
127
  - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
122
128
  :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
123
129
 
124
- Args:
125
- axis: int, Dimension to be unflattened.
126
- sizes: Sequence of int. New shape of the unflattened dimension.
127
- in_size: Sequence of int. The shape of the input tensor.
130
+ Parameters
131
+ ----------
132
+ axis : int
133
+ Dimension to be unflattened.
134
+ sizes : Sequence of int
135
+ New shape of the unflattened dimension.
136
+ name : str, optional
137
+ The name of the module.
138
+ in_size : Sequence of int, optional
139
+ The shape of the input tensor.
128
140
  """
129
141
  __module__ = 'brainstate.nn'
130
142
 
@@ -156,9 +168,6 @@ class Unflatten(Module):
156
168
  def update(self, x):
157
169
  return u.math.unflatten(x, self.axis, self.sizes)
158
170
 
159
- def __repr__(self):
160
- return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
161
-
162
171
 
163
172
  class _MaxPool(Module):
164
173
  def __init__(
@@ -170,6 +179,7 @@ class _MaxPool(Module):
170
179
  stride: Union[int, Sequence[int]] = None,
171
180
  padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
172
181
  channel_axis: Optional[int] = -1,
182
+ return_indices: bool = False,
173
183
  name: Optional[str] = None,
174
184
  in_size: Optional[Size] = None,
175
185
  ):
@@ -178,6 +188,7 @@ class _MaxPool(Module):
178
188
  self.init_value = init_value
179
189
  self.computation = computation
180
190
  self.pool_dim = pool_dim
191
+ self.return_indices = return_indices
181
192
 
182
193
  # kernel_size
183
194
  if isinstance(kernel_size, int):
@@ -247,11 +258,18 @@ class _MaxPool(Module):
247
258
  x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
248
259
  if x.ndim < x_dim:
249
260
  raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
250
- window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
251
- stride = self._infer_shape(x.ndim, self.stride, 1)
252
- padding = (self.padding if isinstance(self.padding, str) else
253
- self._infer_shape(x.ndim, self.padding, element=(0, 0)))
254
- r = jax.lax.reduce_window(
261
+ window_shape = tuple(self._infer_shape(x.ndim, self.kernel_size, 1))
262
+ stride = tuple(self._infer_shape(x.ndim, self.stride, 1))
263
+ if isinstance(self.padding, str):
264
+ padding = tuple(jax.lax.padtype_to_pads(x.shape, window_shape, stride, self.padding))
265
+ else:
266
+ padding = tuple(self._infer_shape(x.ndim, self.padding, element=(0, 0)))
267
+
268
+ if self.return_indices:
269
+ # For returning indices, we need to use a custom implementation
270
+ return self._pooling_with_indices(x, window_shape, stride, padding)
271
+
272
+ return jax.lax.reduce_window(
255
273
  x,
256
274
  init_value=self.init_value,
257
275
  computation=self.computation,
@@ -259,7 +277,39 @@ class _MaxPool(Module):
259
277
  window_strides=stride,
260
278
  padding=padding
261
279
  )
262
- return r
280
+
281
+ def _pooling_with_indices(self, x, window_shape, stride, padding):
282
+ """Perform max pooling and return both pooled values and indices."""
283
+ total_size = x.size
284
+ flat_indices = jnp.arange(total_size, dtype=jnp.int32).reshape(x.shape)
285
+
286
+ init_val = jnp.asarray(self.init_value, dtype=x.dtype)
287
+ init_idx = jnp.array(total_size, dtype=flat_indices.dtype)
288
+
289
+ def reducer(acc, operand):
290
+ acc_val, acc_idx = acc
291
+ cur_val, cur_idx = operand
292
+
293
+ better = cur_val > acc_val
294
+ best_val = jnp.where(better, cur_val, acc_val)
295
+ best_idx = jnp.where(better, cur_idx, acc_idx)
296
+ tie = jnp.logical_and(cur_val == acc_val, cur_idx < acc_idx)
297
+ best_idx = jnp.where(tie, cur_idx, best_idx)
298
+
299
+ return best_val, best_idx
300
+
301
+ pooled, indices_result = jax.lax.reduce_window(
302
+ (x, flat_indices),
303
+ (init_val, init_idx),
304
+ reducer,
305
+ window_dimensions=window_shape,
306
+ window_strides=stride,
307
+ padding=padding
308
+ )
309
+
310
+ indices_result = jnp.where(indices_result == total_size, 0, indices_result)
311
+
312
+ return pooled, indices_result.astype(jnp.int32)
263
313
 
264
314
  def _infer_shape(self, x_dim, inputs, element):
265
315
  channel_axis = self.channel_axis
@@ -331,10 +381,33 @@ class MaxPool1d(_MaxPool):
331
381
  L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
332
382
  \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
333
383
 
334
-
335
- Examples::
336
-
337
- >>> import brainstate as brainstate
384
+ Parameters
385
+ ----------
386
+ kernel_size : int or sequence of int
387
+ An integer, or a sequence of integers defining the window to reduce over.
388
+ stride : int or sequence of int, optional
389
+ An integer, or a sequence of integers, representing the inter-window stride.
390
+ Default: kernel_size
391
+ padding : str, int or sequence of tuple, optional
392
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
393
+ of n `(low, high)` integer pairs that give the padding to apply before
394
+ and after each spatial dimension. Default: 'VALID'
395
+ channel_axis : int, optional
396
+ Axis of the spatial channels for which pooling is skipped.
397
+ If ``None``, there is no channel axis. Default: -1
398
+ return_indices : bool, optional
399
+ If True, will return the max indices along with the outputs.
400
+ Useful for MaxUnpool1d. Default: False
401
+ name : str, optional
402
+ The object name.
403
+ in_size : Sequence of int, optional
404
+ The shape of the input tensor.
405
+
406
+ Examples
407
+ --------
408
+ .. code-block:: python
409
+
410
+ >>> import brainstate
338
411
  >>> # pool of size=3, stride=2
339
412
  >>> m = MaxPool1d(3, stride=2, channel_axis=-1)
340
413
  >>> input = brainstate.random.randn(20, 50, 16)
@@ -342,24 +415,6 @@ class MaxPool1d(_MaxPool):
342
415
  >>> output.shape
343
416
  (20, 24, 16)
344
417
 
345
- Parameters
346
- ----------
347
- in_size: Sequence of int
348
- The shape of the input tensor.
349
- kernel_size: int, sequence of int
350
- An integer, or a sequence of integers defining the window to reduce over.
351
- stride: int, sequence of int
352
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
353
- padding: str, int, sequence of tuple
354
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
355
- of n `(low, high)` integer pairs that give the padding to apply before
356
- and after each spatial dimension.
357
- channel_axis: int, optional
358
- Axis of the spatial channels for which pooling is skipped.
359
- If ``None``, there is no channel axis.
360
- name: optional, str
361
- The object name.
362
-
363
418
  .. _link:
364
419
  https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
365
420
  """
@@ -371,6 +426,7 @@ class MaxPool1d(_MaxPool):
371
426
  stride: Union[int, Sequence[int]] = None,
372
427
  padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
373
428
  channel_axis: Optional[int] = -1,
429
+ return_indices: bool = False,
374
430
  name: Optional[str] = None,
375
431
  in_size: Optional[Size] = None,
376
432
  ):
@@ -382,6 +438,7 @@ class MaxPool1d(_MaxPool):
382
438
  stride=stride,
383
439
  padding=padding,
384
440
  channel_axis=channel_axis,
441
+ return_indices=return_indices,
385
442
  name=name)
386
443
 
387
444
 
@@ -403,7 +460,6 @@ class MaxPool2d(_MaxPool):
403
460
  for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
404
461
  It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
405
462
 
406
-
407
463
  Shape:
408
464
  - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
409
465
  - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
@@ -416,9 +472,33 @@ class MaxPool2d(_MaxPool):
416
472
  W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
417
473
  \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
418
474
 
419
- Examples::
420
-
421
- >>> import brainstate as brainstate
475
+ Parameters
476
+ ----------
477
+ kernel_size : int or sequence of int
478
+ An integer, or a sequence of integers defining the window to reduce over.
479
+ stride : int or sequence of int, optional
480
+ An integer, or a sequence of integers, representing the inter-window stride.
481
+ Default: kernel_size
482
+ padding : str, int or sequence of tuple, optional
483
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
484
+ of n `(low, high)` integer pairs that give the padding to apply before
485
+ and after each spatial dimension. Default: 'VALID'
486
+ channel_axis : int, optional
487
+ Axis of the spatial channels for which pooling is skipped.
488
+ If ``None``, there is no channel axis. Default: -1
489
+ return_indices : bool, optional
490
+ If True, will return the max indices along with the outputs.
491
+ Useful for MaxUnpool2d. Default: False
492
+ name : str, optional
493
+ The object name.
494
+ in_size : Sequence of int, optional
495
+ The shape of the input tensor.
496
+
497
+ Examples
498
+ --------
499
+ .. code-block:: python
500
+
501
+ >>> import brainstate
422
502
  >>> # pool of square window of size=3, stride=2
423
503
  >>> m = MaxPool2d(3, stride=2)
424
504
  >>> # pool of non-square window
@@ -430,25 +510,6 @@ class MaxPool2d(_MaxPool):
430
510
 
431
511
  .. _link:
432
512
  https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
433
-
434
- Parameters
435
- ----------
436
- in_size: Sequence of int
437
- The shape of the input tensor.
438
- kernel_size: int, sequence of int
439
- An integer, or a sequence of integers defining the window to reduce over.
440
- stride: int, sequence of int
441
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
442
- padding: str, int, sequence of tuple
443
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
444
- of n `(low, high)` integer pairs that give the padding to apply before
445
- and after each spatial dimension.
446
- channel_axis: int, optional
447
- Axis of the spatial channels for which pooling is skipped.
448
- If ``None``, there is no channel axis.
449
- name: optional, str
450
- The object name.
451
-
452
513
  """
453
514
  __module__ = 'brainstate.nn'
454
515
 
@@ -458,6 +519,7 @@ class MaxPool2d(_MaxPool):
458
519
  stride: Union[int, Sequence[int]] = None,
459
520
  padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
460
521
  channel_axis: Optional[int] = -1,
522
+ return_indices: bool = False,
461
523
  name: Optional[str] = None,
462
524
  in_size: Optional[Size] = None,
463
525
  ):
@@ -469,6 +531,7 @@ class MaxPool2d(_MaxPool):
469
531
  stride=stride,
470
532
  padding=padding,
471
533
  channel_axis=channel_axis,
534
+ return_indices=return_indices,
472
535
  name=name)
473
536
 
474
537
 
@@ -490,7 +553,6 @@ class MaxPool3d(_MaxPool):
490
553
  for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
491
554
  It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
492
555
 
493
-
494
556
  Shape:
495
557
  - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
496
558
  - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
@@ -507,9 +569,33 @@ class MaxPool3d(_MaxPool):
507
569
  W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
508
570
  (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
509
571
 
510
- Examples::
511
-
512
- >>> import brainstate as brainstate
572
+ Parameters
573
+ ----------
574
+ kernel_size : int or sequence of int
575
+ An integer, or a sequence of integers defining the window to reduce over.
576
+ stride : int or sequence of int, optional
577
+ An integer, or a sequence of integers, representing the inter-window stride.
578
+ Default: kernel_size
579
+ padding : str, int or sequence of tuple, optional
580
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
581
+ of n `(low, high)` integer pairs that give the padding to apply before
582
+ and after each spatial dimension. Default: 'VALID'
583
+ channel_axis : int, optional
584
+ Axis of the spatial channels for which pooling is skipped.
585
+ If ``None``, there is no channel axis. Default: -1
586
+ return_indices : bool, optional
587
+ If True, will return the max indices along with the outputs.
588
+ Useful for MaxUnpool3d. Default: False
589
+ name : str, optional
590
+ The object name.
591
+ in_size : Sequence of int, optional
592
+ The shape of the input tensor.
593
+
594
+ Examples
595
+ --------
596
+ .. code-block:: python
597
+
598
+ >>> import brainstate
513
599
  >>> # pool of square window of size=3, stride=2
514
600
  >>> m = MaxPool3d(3, stride=2)
515
601
  >>> # pool of non-square window
@@ -521,25 +607,6 @@ class MaxPool3d(_MaxPool):
521
607
 
522
608
  .. _link:
523
609
  https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
524
-
525
- Parameters
526
- ----------
527
- in_size: Sequence of int
528
- The shape of the input tensor.
529
- kernel_size: int, sequence of int
530
- An integer, or a sequence of integers defining the window to reduce over.
531
- stride: int, sequence of int
532
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
533
- padding: str, int, sequence of tuple
534
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
535
- of n `(low, high)` integer pairs that give the padding to apply before
536
- and after each spatial dimension.
537
- channel_axis: int, optional
538
- Axis of the spatial channels for which pooling is skipped.
539
- If ``None``, there is no channel axis.
540
- name: optional, str
541
- The object name.
542
-
543
610
  """
544
611
  __module__ = 'brainstate.nn'
545
612
 
@@ -549,6 +616,7 @@ class MaxPool3d(_MaxPool):
549
616
  stride: Union[int, Sequence[int]] = None,
550
617
  padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
551
618
  channel_axis: Optional[int] = -1,
619
+ return_indices: bool = False,
552
620
  name: Optional[str] = None,
553
621
  in_size: Optional[Size] = None,
554
622
  ):
@@ -560,9 +628,432 @@ class MaxPool3d(_MaxPool):
560
628
  stride=stride,
561
629
  padding=padding,
562
630
  channel_axis=channel_axis,
631
+ return_indices=return_indices,
563
632
  name=name)
564
633
 
565
634
 
635
+ class _MaxUnpool(Module):
636
+ """Base class for max unpooling operations."""
637
+
638
+ def __init__(
639
+ self,
640
+ pool_dim: int,
641
+ kernel_size: Size,
642
+ stride: Union[int, Sequence[int]] = None,
643
+ padding: Union[int, Tuple[int, ...]] = 0,
644
+ channel_axis: Optional[int] = -1,
645
+ name: Optional[str] = None,
646
+ in_size: Optional[Size] = None,
647
+ ):
648
+ super().__init__(name=name)
649
+
650
+ self.pool_dim = pool_dim
651
+
652
+ # kernel_size
653
+ if isinstance(kernel_size, int):
654
+ kernel_size = (kernel_size,) * pool_dim
655
+ elif isinstance(kernel_size, Sequence):
656
+ assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
657
+ assert all(
658
+ [isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
659
+ if len(kernel_size) != pool_dim:
660
+ raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
661
+ else:
662
+ raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
663
+ self.kernel_size = kernel_size
664
+
665
+ # stride
666
+ if stride is None:
667
+ stride = kernel_size
668
+ if isinstance(stride, int):
669
+ stride = (stride,) * pool_dim
670
+ elif isinstance(stride, Sequence):
671
+ assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
672
+ assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
673
+ if len(stride) != pool_dim:
674
+ raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
675
+ else:
676
+ raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
677
+ self.stride = stride
678
+
679
+ # padding
680
+ if isinstance(padding, int):
681
+ padding = (padding,) * pool_dim
682
+ elif isinstance(padding, (tuple, list)):
683
+ if len(padding) != pool_dim:
684
+ raise ValueError(f'padding should have {pool_dim} values, but got {len(padding)}')
685
+ else:
686
+ raise TypeError(f'padding should be int or tuple of {pool_dim} ints.')
687
+ self.padding = padding
688
+
689
+ # channel_axis
690
+ assert channel_axis is None or isinstance(channel_axis, int), \
691
+ f'channel_axis should be an int, but got {channel_axis}'
692
+ self.channel_axis = channel_axis
693
+
694
+ # in & out shapes
695
+ if in_size is not None:
696
+ in_size = tuple(in_size)
697
+ self.in_size = in_size
698
+
699
+ def _compute_output_shape(self, input_shape, output_size=None):
700
+ """Compute the output shape after unpooling."""
701
+ if output_size is not None:
702
+ return output_size
703
+
704
+ # Calculate output shape based on kernel, stride, and padding
705
+ output_shape = []
706
+ for i in range(self.pool_dim):
707
+ dim_size = (input_shape[i] - 1) * self.stride[i] - 2 * self.padding[i] + self.kernel_size[i]
708
+ output_shape.append(dim_size)
709
+
710
+ return tuple(output_shape)
711
+
712
+ def _unpool_nd(self, x, indices, output_size=None):
713
+ """Perform N-dimensional max unpooling."""
714
+ x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
715
+ if x.ndim < x_dim:
716
+ raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
717
+
718
+ # Determine output shape
719
+ if output_size is None:
720
+ # Infer output shape from input shape
721
+ spatial_dims = self._get_spatial_dims(x.shape)
722
+ output_spatial_shape = self._compute_output_shape(spatial_dims, output_size)
723
+ output_shape = list(x.shape)
724
+
725
+ # Update spatial dimensions in output shape
726
+ spatial_start = self._get_spatial_start_idx(x.ndim)
727
+ for i, size in enumerate(output_spatial_shape):
728
+ output_shape[spatial_start + i] = size
729
+ output_shape = tuple(output_shape)
730
+ else:
731
+ # Use provided output size
732
+ if isinstance(output_size, (list, tuple)):
733
+ if len(output_size) == x.ndim:
734
+ # Full output shape provided
735
+ output_shape = tuple(output_size)
736
+ else:
737
+ # Only spatial dimensions provided
738
+ if len(output_size) != self.pool_dim:
739
+ raise ValueError(f"output_size must have {self.pool_dim} spatial dimensions, got {len(output_size)}")
740
+ output_shape = list(x.shape)
741
+ spatial_start = self._get_spatial_start_idx(x.ndim)
742
+ for i, size in enumerate(output_size):
743
+ output_shape[spatial_start + i] = size
744
+ output_shape = tuple(output_shape)
745
+ else:
746
+ # Single integer provided, use for all spatial dims
747
+ output_shape = list(x.shape)
748
+ spatial_start = self._get_spatial_start_idx(x.ndim)
749
+ for i in range(self.pool_dim):
750
+ output_shape[spatial_start + i] = output_size
751
+ output_shape = tuple(output_shape)
752
+
753
+ # Create output array filled with zeros
754
+ output = jnp.zeros(output_shape, dtype=x.dtype)
755
+
756
+ # # Scatter input values to output using indices
757
+ # # Flatten spatial dimensions for easier indexing
758
+ # batch_dims = x.ndim - self.pool_dim - (0 if self.channel_axis is None else 1)
759
+ #
760
+ # # Reshape for processing
761
+ # if batch_dims > 0:
762
+ # batch_shape = x.shape[:batch_dims]
763
+ # if self.channel_axis is not None and self.channel_axis < batch_dims:
764
+ # # Channel axis is before spatial dims
765
+ # channel_idx = self.channel_axis
766
+ # n_channels = x.shape[channel_idx]
767
+ # elif self.channel_axis is not None:
768
+ # # Channel axis is after spatial dims
769
+ # if self.channel_axis < 0:
770
+ # channel_idx = x.ndim + self.channel_axis
771
+ # else:
772
+ # channel_idx = self.channel_axis
773
+ # n_channels = x.shape[channel_idx]
774
+ # else:
775
+ # n_channels = None
776
+ # else:
777
+ # batch_shape = ()
778
+ # if self.channel_axis is not None:
779
+ # if self.channel_axis < 0:
780
+ # channel_idx = x.ndim + self.channel_axis
781
+ # else:
782
+ # channel_idx = self.channel_axis
783
+ # n_channels = x.shape[channel_idx]
784
+ # else:
785
+ # n_channels = None
786
+
787
+ # Use JAX's scatter operation
788
+ # Flatten the indices to 1D for scatter
789
+ flat_indices = indices.ravel()
790
+ flat_values = x.ravel()
791
+ flat_output = output.ravel()
792
+
793
+ # Scatter the values
794
+ flat_output = flat_output.at[flat_indices].set(flat_values)
795
+
796
+ # Reshape back to original shape
797
+ output = flat_output.reshape(output_shape)
798
+
799
+ return output
800
+
801
+ def _get_spatial_dims(self, shape):
802
+ """Extract spatial dimensions from input shape."""
803
+ if self.channel_axis is None:
804
+ return shape[-self.pool_dim:]
805
+ else:
806
+ channel_axis = self.channel_axis if self.channel_axis >= 0 else len(shape) + self.channel_axis
807
+ all_dims = list(range(len(shape)))
808
+ all_dims.pop(channel_axis)
809
+ return tuple(shape[i] for i in all_dims[-self.pool_dim:])
810
+
811
+ def _get_spatial_start_idx(self, ndim):
812
+ """Get the starting index of spatial dimensions."""
813
+ if self.channel_axis is None:
814
+ return ndim - self.pool_dim
815
+ else:
816
+ channel_axis = self.channel_axis if self.channel_axis >= 0 else ndim + self.channel_axis
817
+ if channel_axis < ndim - self.pool_dim:
818
+ return ndim - self.pool_dim
819
+ else:
820
+ return ndim - self.pool_dim - 1
821
+
822
+ def update(self, x, indices, output_size=None):
823
+ """Forward pass of MaxUnpool1d.
824
+
825
+ Parameters
826
+ ----------
827
+ x : Array
828
+ Input tensor from MaxPool1d
829
+ indices : Array
830
+ Indices of maximum values from MaxPool1d
831
+ output_size : int or tuple, optional
832
+ The targeted output size
833
+
834
+ Returns
835
+ -------
836
+ Array
837
+ Unpooled output
838
+ """
839
+ return self._unpool_nd(x, indices, output_size)
840
+
841
+
842
+ class MaxUnpool1d(_MaxUnpool):
843
+ r"""Computes a partial inverse of MaxPool1d.
844
+
845
+ MaxPool1d is not fully invertible, since the non-maximal values are lost.
846
+ MaxUnpool1d takes in as input the output of MaxPool1d including the indices
847
+ of the maximal values and computes a partial inverse in which all
848
+ non-maximal values are set to zero.
849
+
850
+ Note:
851
+ This function may produce nondeterministic gradients when given tensors
852
+ on a CUDA device. See notes on reproducibility for more information.
853
+
854
+ Shape:
855
+ - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`
856
+ - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
857
+
858
+ .. math::
859
+ L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel\_size}
860
+
861
+ or as given by :attr:`output_size` in the call operator
862
+
863
+ Parameters
864
+ ----------
865
+ kernel_size : int or tuple
866
+ Size of the max pooling window.
867
+ stride : int or tuple, optional
868
+ Stride of the max pooling window. Default: kernel_size
869
+ padding : int or tuple, optional
870
+ Padding that was added to the input. Default: 0
871
+ channel_axis : int, optional
872
+ Axis of the channels. Default: -1
873
+ name : str, optional
874
+ Name of the module.
875
+ in_size : Size, optional
876
+ Input size for shape inference.
877
+
878
+ Examples
879
+ --------
880
+ .. code-block:: python
881
+
882
+ >>> import brainstate
883
+ >>> import jax.numpy as jnp
884
+ >>> # Create pooling and unpooling layers
885
+ >>> pool = MaxPool1d(2, stride=2, return_indices=True, channel_axis=-1)
886
+ >>> unpool = MaxUnpool1d(2, stride=2, channel_axis=-1)
887
+ >>> input = brainstate.random.randn(20, 50, 16)
888
+ >>> output, indices = pool(input)
889
+ >>> unpooled = unpool(output, indices)
890
+ >>> # unpooled will have shape (20, 100, 16) with zeros at non-maximal positions
891
+ """
892
+ __module__ = 'brainstate.nn'
893
+
894
+ def __init__(
895
+ self,
896
+ kernel_size: Size,
897
+ stride: Union[int, Sequence[int]] = None,
898
+ padding: Union[int, Tuple[int, ...]] = 0,
899
+ channel_axis: Optional[int] = -1,
900
+ name: Optional[str] = None,
901
+ in_size: Optional[Size] = None,
902
+ ):
903
+ super().__init__(
904
+ pool_dim=1,
905
+ kernel_size=kernel_size,
906
+ stride=stride,
907
+ padding=padding,
908
+ channel_axis=channel_axis,
909
+ name=name,
910
+ in_size=in_size
911
+ )
912
+
913
+
914
+ class MaxUnpool2d(_MaxUnpool):
915
+ r"""Computes a partial inverse of MaxPool2d.
916
+
917
+ MaxPool2d is not fully invertible, since the non-maximal values are lost.
918
+ MaxUnpool2d takes in as input the output of MaxPool2d including the indices
919
+ of the maximal values and computes a partial inverse in which all
920
+ non-maximal values are set to zero.
921
+
922
+ Shape:
923
+ - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
924
+ - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
925
+
926
+ .. math::
927
+ H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]
928
+
929
+ .. math::
930
+ W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]
931
+
932
+ or as given by :attr:`output_size` in the call operator
933
+
934
+ Parameters
935
+ ----------
936
+ kernel_size : int or tuple
937
+ Size of the max pooling window.
938
+ stride : int or tuple, optional
939
+ Stride of the max pooling window. Default: kernel_size
940
+ padding : int or tuple, optional
941
+ Padding that was added to the input. Default: 0
942
+ channel_axis : int, optional
943
+ Axis of the channels. Default: -1
944
+ name : str, optional
945
+ Name of the module.
946
+ in_size : Size, optional
947
+ Input size for shape inference.
948
+
949
+ Examples
950
+ --------
951
+ .. code-block:: python
952
+
953
+ >>> import brainstate
954
+ >>> # Create pooling and unpooling layers
955
+ >>> pool = MaxPool2d(2, stride=2, return_indices=True, channel_axis=-1)
956
+ >>> unpool = MaxUnpool2d(2, stride=2, channel_axis=-1)
957
+ >>> input = brainstate.random.randn(1, 4, 4, 16)
958
+ >>> output, indices = pool(input)
959
+ >>> unpooled = unpool(output, indices)
960
+ >>> # unpooled will have shape (1, 8, 8, 16) with zeros at non-maximal positions
961
+ """
962
+ __module__ = 'brainstate.nn'
963
+
964
+ def __init__(
965
+ self,
966
+ kernel_size: Size,
967
+ stride: Union[int, Sequence[int]] = None,
968
+ padding: Union[int, Tuple[int, ...]] = 0,
969
+ channel_axis: Optional[int] = -1,
970
+ name: Optional[str] = None,
971
+ in_size: Optional[Size] = None,
972
+ ):
973
+ super().__init__(
974
+ pool_dim=2,
975
+ kernel_size=kernel_size,
976
+ stride=stride,
977
+ padding=padding,
978
+ channel_axis=channel_axis,
979
+ name=name,
980
+ in_size=in_size
981
+ )
982
+
983
+
984
+ class MaxUnpool3d(_MaxUnpool):
985
+ r"""Computes a partial inverse of MaxPool3d.
986
+
987
+ MaxPool3d is not fully invertible, since the non-maximal values are lost.
988
+ MaxUnpool3d takes in as input the output of MaxPool3d including the indices
989
+ of the maximal values and computes a partial inverse in which all
990
+ non-maximal values are set to zero.
991
+
992
+ Shape:
993
+ - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`
994
+ - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
995
+
996
+ .. math::
997
+ D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]
998
+
999
+ .. math::
1000
+ H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{kernel\_size}[1]
1001
+
1002
+ .. math::
1003
+ W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{kernel\_size}[2]
1004
+
1005
+ or as given by :attr:`output_size` in the call operator
1006
+
1007
+ Parameters
1008
+ ----------
1009
+ kernel_size : int or tuple
1010
+ Size of the max pooling window.
1011
+ stride : int or tuple, optional
1012
+ Stride of the max pooling window. Default: kernel_size
1013
+ padding : int or tuple, optional
1014
+ Padding that was added to the input. Default: 0
1015
+ channel_axis : int, optional
1016
+ Axis of the channels. Default: -1
1017
+ name : str, optional
1018
+ Name of the module.
1019
+ in_size : Size, optional
1020
+ Input size for shape inference.
1021
+
1022
+ Examples
1023
+ --------
1024
+ .. code-block:: python
1025
+
1026
+ >>> import brainstate
1027
+ >>> # Create pooling and unpooling layers
1028
+ >>> pool = MaxPool3d(2, stride=2, return_indices=True, channel_axis=-1)
1029
+ >>> unpool = MaxUnpool3d(2, stride=2, channel_axis=-1)
1030
+ >>> input = brainstate.random.randn(1, 4, 4, 4, 16)
1031
+ >>> output, indices = pool(input)
1032
+ >>> unpooled = unpool(output, indices)
1033
+ >>> # unpooled will have shape (1, 8, 8, 8, 16) with zeros at non-maximal positions
1034
+ """
1035
+ __module__ = 'brainstate.nn'
1036
+
1037
+ def __init__(
1038
+ self,
1039
+ kernel_size: Size,
1040
+ stride: Union[int, Sequence[int]] = None,
1041
+ padding: Union[int, Tuple[int, ...]] = 0,
1042
+ channel_axis: Optional[int] = -1,
1043
+ name: Optional[str] = None,
1044
+ in_size: Optional[Size] = None,
1045
+ ):
1046
+ super().__init__(
1047
+ pool_dim=3,
1048
+ kernel_size=kernel_size,
1049
+ stride=stride,
1050
+ padding=padding,
1051
+ channel_axis=channel_axis,
1052
+ name=name,
1053
+ in_size=in_size
1054
+ )
1055
+
1056
+
566
1057
  class AvgPool1d(_AvgPool):
567
1058
  r"""Applies a 1D average pooling over an input signal composed of several input planes.
568
1059
 
@@ -586,33 +1077,35 @@ class AvgPool1d(_AvgPool):
586
1077
  L_{out} = \left\lfloor \frac{L_{in} +
587
1078
  2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
588
1079
 
589
- Examples::
590
-
591
- >>> import brainstate as brainstate
1080
+ Parameters
1081
+ ----------
1082
+ kernel_size : int or sequence of int
1083
+ An integer, or a sequence of integers defining the window to reduce over.
1084
+ stride : int or sequence of int, optional
1085
+ An integer, or a sequence of integers, representing the inter-window stride.
1086
+ Default: 1
1087
+ padding : str, int or sequence of tuple, optional
1088
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
1089
+ of n `(low, high)` integer pairs that give the padding to apply before
1090
+ and after each spatial dimension. Default: 'VALID'
1091
+ channel_axis : int, optional
1092
+ Axis of the spatial channels for which pooling is skipped.
1093
+ If ``None``, there is no channel axis. Default: -1
1094
+ name : str, optional
1095
+ The object name.
1096
+ in_size : Sequence of int, optional
1097
+ The shape of the input tensor.
1098
+
1099
+ Examples
1100
+ --------
1101
+ .. code-block:: python
1102
+
1103
+ >>> import brainstate
592
1104
  >>> # pool with window of size=3, stride=2
593
1105
  >>> m = AvgPool1d(3, stride=2)
594
1106
  >>> input = brainstate.random.randn(20, 50, 16)
595
1107
  >>> m(input).shape
596
1108
  (20, 24, 16)
597
-
598
- Parameters
599
- ----------
600
- in_size: Sequence of int
601
- The shape of the input tensor.
602
- kernel_size: int, sequence of int
603
- An integer, or a sequence of integers defining the window to reduce over.
604
- stride: int, sequence of int
605
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
606
- padding: str, int, sequence of tuple
607
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
608
- of n `(low, high)` integer pairs that give the padding to apply before
609
- and after each spatial dimension.
610
- channel_axis: int, optional
611
- Axis of the spatial channels for which pooling is skipped.
612
- If ``None``, there is no channel axis.
613
- name: optional, str
614
- The object name.
615
-
616
1109
  """
617
1110
  __module__ = 'brainstate.nn'
618
1111
 
@@ -625,15 +1118,17 @@ class AvgPool1d(_AvgPool):
625
1118
  name: Optional[str] = None,
626
1119
  in_size: Optional[Size] = None,
627
1120
  ):
628
- super().__init__(in_size=in_size,
629
- init_value=0.,
630
- computation=jax.lax.add,
631
- pool_dim=1,
632
- kernel_size=kernel_size,
633
- stride=stride,
634
- padding=padding,
635
- channel_axis=channel_axis,
636
- name=name)
1121
+ super().__init__(
1122
+ in_size=in_size,
1123
+ init_value=0.,
1124
+ computation=jax.lax.add,
1125
+ pool_dim=1,
1126
+ kernel_size=kernel_size,
1127
+ stride=stride,
1128
+ padding=padding,
1129
+ channel_axis=channel_axis,
1130
+ name=name
1131
+ )
637
1132
 
638
1133
 
639
1134
  class AvgPool2d(_AvgPool):
@@ -663,35 +1158,38 @@ class AvgPool2d(_AvgPool):
663
1158
  W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
664
1159
  \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
665
1160
 
666
- Examples::
667
-
668
- >>> import brainstate as brainstate
1161
+ Parameters
1162
+ ----------
1163
+ kernel_size : int or sequence of int
1164
+ An integer, or a sequence of integers defining the window to reduce over.
1165
+ stride : int or sequence of int, optional
1166
+ An integer, or a sequence of integers, representing the inter-window stride.
1167
+ Default: 1
1168
+ padding : str, int or sequence of tuple, optional
1169
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
1170
+ of n `(low, high)` integer pairs that give the padding to apply before
1171
+ and after each spatial dimension. Default: 'VALID'
1172
+ channel_axis : int, optional
1173
+ Axis of the spatial channels for which pooling is skipped.
1174
+ If ``None``, there is no channel axis. Default: -1
1175
+ name : str, optional
1176
+ The object name.
1177
+ in_size : Sequence of int, optional
1178
+ The shape of the input tensor.
1179
+
1180
+ Examples
1181
+ --------
1182
+ .. code-block:: python
1183
+
1184
+ >>> import brainstate
669
1185
  >>> # pool of square window of size=3, stride=2
670
1186
  >>> m = AvgPool2d(3, stride=2)
671
1187
  >>> # pool of non-square window
672
1188
  >>> m = AvgPool2d((3, 2), stride=(2, 1))
673
- >>> input = brainstate.random.randn(20, 50, 32, , 16)
1189
+ >>> input = brainstate.random.randn(20, 50, 32, 16)
674
1190
  >>> output = m(input)
675
1191
  >>> output.shape
676
1192
  (20, 24, 31, 16)
677
-
678
- Parameters
679
- ----------
680
- in_size: Sequence of int
681
- The shape of the input tensor.
682
- kernel_size: int, sequence of int
683
- An integer, or a sequence of integers defining the window to reduce over.
684
- stride: int, sequence of int
685
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
686
- padding: str, int, sequence of tuple
687
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
688
- of n `(low, high)` integer pairs that give the padding to apply before
689
- and after each spatial dimension.
690
- channel_axis: int, optional
691
- Axis of the spatial channels for which pooling is skipped.
692
- If ``None``, there is no channel axis.
693
- name: optional, str
694
- The object name.
695
1193
  """
696
1194
  __module__ = 'brainstate.nn'
697
1195
 
@@ -704,15 +1202,17 @@ class AvgPool2d(_AvgPool):
704
1202
  name: Optional[str] = None,
705
1203
  in_size: Optional[Size] = None,
706
1204
  ):
707
- super().__init__(in_size=in_size,
708
- init_value=0.,
709
- computation=jax.lax.add,
710
- pool_dim=2,
711
- kernel_size=kernel_size,
712
- stride=stride,
713
- padding=padding,
714
- channel_axis=channel_axis,
715
- name=name)
1205
+ super().__init__(
1206
+ in_size=in_size,
1207
+ init_value=0.,
1208
+ computation=jax.lax.add,
1209
+ pool_dim=2,
1210
+ kernel_size=kernel_size,
1211
+ stride=stride,
1212
+ padding=padding,
1213
+ channel_axis=channel_axis,
1214
+ name=name
1215
+ )
716
1216
 
717
1217
 
718
1218
  class AvgPool3d(_AvgPool):
@@ -751,9 +1251,30 @@ class AvgPool3d(_AvgPool):
751
1251
  W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
752
1252
  \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
753
1253
 
754
- Examples::
755
-
756
- >>> import brainstate as brainstate
1254
+ Parameters
1255
+ ----------
1256
+ kernel_size : int or sequence of int
1257
+ An integer, or a sequence of integers defining the window to reduce over.
1258
+ stride : int or sequence of int, optional
1259
+ An integer, or a sequence of integers, representing the inter-window stride.
1260
+ Default: 1
1261
+ padding : str, int or sequence of tuple, optional
1262
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
1263
+ of n `(low, high)` integer pairs that give the padding to apply before
1264
+ and after each spatial dimension. Default: 'VALID'
1265
+ channel_axis : int, optional
1266
+ Axis of the spatial channels for which pooling is skipped.
1267
+ If ``None``, there is no channel axis. Default: -1
1268
+ name : str, optional
1269
+ The object name.
1270
+ in_size : Sequence of int, optional
1271
+ The shape of the input tensor.
1272
+
1273
+ Examples
1274
+ --------
1275
+ .. code-block:: python
1276
+
1277
+ >>> import brainstate
757
1278
  >>> # pool of square window of size=3, stride=2
758
1279
  >>> m = AvgPool3d(3, stride=2)
759
1280
  >>> # pool of non-square window
@@ -763,45 +1284,407 @@ class AvgPool3d(_AvgPool):
763
1284
  >>> output.shape
764
1285
  (20, 24, 43, 15, 16)
765
1286
 
1287
+ """
1288
+ __module__ = 'brainstate.nn'
1289
+
1290
+ def __init__(
1291
+ self,
1292
+ kernel_size: Size,
1293
+ stride: Union[int, Sequence[int]] = 1,
1294
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
1295
+ channel_axis: Optional[int] = -1,
1296
+ name: Optional[str] = None,
1297
+ in_size: Optional[Size] = None,
1298
+ ):
1299
+ super().__init__(
1300
+ in_size=in_size,
1301
+ init_value=0.,
1302
+ computation=jax.lax.add,
1303
+ pool_dim=3,
1304
+ kernel_size=kernel_size,
1305
+ stride=stride,
1306
+ padding=padding,
1307
+ channel_axis=channel_axis,
1308
+ name=name
1309
+ )
1310
+
1311
+
1312
+ class _LPPool(Module):
1313
+ """Base class for Lp pooling operations."""
1314
+
1315
+ def __init__(
1316
+ self,
1317
+ norm_type: float,
1318
+ pool_dim: int,
1319
+ kernel_size: Size,
1320
+ stride: Union[int, Sequence[int]] = None,
1321
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
1322
+ channel_axis: Optional[int] = -1,
1323
+ name: Optional[str] = None,
1324
+ in_size: Optional[Size] = None,
1325
+ ):
1326
+ super().__init__(name=name)
1327
+
1328
+ if norm_type <= 0:
1329
+ raise ValueError(f"norm_type must be positive, got {norm_type}")
1330
+ self.norm_type = norm_type
1331
+ self.pool_dim = pool_dim
1332
+
1333
+ # kernel_size
1334
+ if isinstance(kernel_size, int):
1335
+ kernel_size = (kernel_size,) * pool_dim
1336
+ elif isinstance(kernel_size, Sequence):
1337
+ assert isinstance(kernel_size, (tuple, list)), f'kernel_size should be a tuple, but got {type(kernel_size)}'
1338
+ assert all(
1339
+ [isinstance(x, int) for x in kernel_size]), f'kernel_size should be a tuple of ints. {kernel_size}'
1340
+ if len(kernel_size) != pool_dim:
1341
+ raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
1342
+ else:
1343
+ raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
1344
+ self.kernel_size = kernel_size
1345
+
1346
+ # stride
1347
+ if stride is None:
1348
+ stride = kernel_size
1349
+ if isinstance(stride, int):
1350
+ stride = (stride,) * pool_dim
1351
+ elif isinstance(stride, Sequence):
1352
+ assert isinstance(stride, (tuple, list)), f'stride should be a tuple, but got {type(stride)}'
1353
+ assert all([isinstance(x, int) for x in stride]), f'stride should be a tuple of ints. {stride}'
1354
+ if len(stride) != pool_dim:
1355
+ raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
1356
+ else:
1357
+ raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
1358
+ self.stride = stride
1359
+
1360
+ # padding
1361
+ if isinstance(padding, str):
1362
+ if padding not in ("SAME", "VALID"):
1363
+ raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
1364
+ elif isinstance(padding, int):
1365
+ padding = [(padding, padding) for _ in range(pool_dim)]
1366
+ elif isinstance(padding, (list, tuple)):
1367
+ if isinstance(padding[0], int):
1368
+ if len(padding) == pool_dim:
1369
+ padding = [(x, x) for x in padding]
1370
+ else:
1371
+ raise ValueError(f'If padding is a sequence of ints, it '
1372
+ f'should has the length of {pool_dim}.')
1373
+ else:
1374
+ if not all([isinstance(x, (tuple, list)) for x in padding]):
1375
+ raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
1376
+ if not all([len(x) == 2 for x in padding]):
1377
+ raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
1378
+ if len(padding) == 1:
1379
+ padding = tuple(padding) * pool_dim
1380
+ assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}'
1381
+ else:
1382
+ raise ValueError
1383
+ self.padding = padding
1384
+
1385
+ # channel_axis
1386
+ assert channel_axis is None or isinstance(channel_axis, int), \
1387
+ f'channel_axis should be an int, but got {channel_axis}'
1388
+ self.channel_axis = channel_axis
1389
+
1390
+ # in & out shapes
1391
+ if in_size is not None:
1392
+ in_size = tuple(in_size)
1393
+ self.in_size = in_size
1394
+ y = jax.eval_shape(self.update, jax.ShapeDtypeStruct((128,) + in_size, environ.dftype()))
1395
+ self.out_size = y.shape[1:]
1396
+
1397
+ def update(self, x):
1398
+ x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
1399
+ if x.ndim < x_dim:
1400
+ raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
1401
+
1402
+ window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
1403
+ stride = self._infer_shape(x.ndim, self.stride, 1)
1404
+ padding = (self.padding if isinstance(self.padding, str) else
1405
+ self._infer_shape(x.ndim, self.padding, element=(0, 0)))
1406
+
1407
+ # For Lp pooling, we need to:
1408
+ # 1. Take absolute value and raise to power p
1409
+ # 2. Sum over the window
1410
+ # 3. Take the p-th root
1411
+
1412
+ # Step 1: |x|^p
1413
+ x_pow = jnp.abs(x) ** self.norm_type
1414
+
1415
+ # Step 2: Sum over window
1416
+ pooled_sum = jax.lax.reduce_window(
1417
+ x_pow,
1418
+ init_value=0.,
1419
+ computation=jax.lax.add,
1420
+ window_dimensions=window_shape,
1421
+ window_strides=stride,
1422
+ padding=padding
1423
+ )
1424
+
1425
+ # Step 3: Take p-th root and multiply by normalization factor
1426
+ # The normalization factor is (1/N)^(1/p) where N is the window size
1427
+ window_size = np.prod([w for i, w in enumerate(self.kernel_size)])
1428
+ norm_factor = window_size ** (-1.0 / self.norm_type)
1429
+ result = norm_factor * (pooled_sum ** (1.0 / self.norm_type))
1430
+
1431
+ return result
1432
+
1433
+ def _infer_shape(self, x_dim, inputs, element):
1434
+ channel_axis = self.channel_axis
1435
+ if channel_axis and not 0 <= abs(channel_axis) < x_dim:
1436
+ raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
1437
+ if channel_axis and channel_axis < 0:
1438
+ channel_axis = x_dim + channel_axis
1439
+ all_dims = list(range(x_dim))
1440
+ if channel_axis is not None:
1441
+ all_dims.pop(channel_axis)
1442
+ pool_dims = all_dims[-self.pool_dim:]
1443
+ results = [element] * x_dim
1444
+ for i, dim in enumerate(pool_dims):
1445
+ results[dim] = inputs[i]
1446
+ return results
1447
+
1448
+
1449
+ class LPPool1d(_LPPool):
1450
+ r"""Applies a 1D power-average pooling over an input signal composed of several input planes.
1451
+
1452
+ On each window, the function computed is:
1453
+
1454
+ .. math::
1455
+ f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}}
1456
+
1457
+ - At :math:`p = \infty`, one gets max pooling
1458
+ - At :math:`p = 1`, one gets average pooling (with absolute values)
1459
+ - At :math:`p = 2`, one gets root mean square (RMS) pooling
1460
+
1461
+ Shape:
1462
+ - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
1463
+ - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
1464
+
1465
+ .. math::
1466
+ L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
1467
+
766
1468
  Parameters
767
1469
  ----------
768
- in_size: Sequence of int
769
- The shape of the input tensor.
770
- kernel_size: int, sequence of int
771
- An integer, or a sequence of integers defining the window to reduce over.
772
- stride: int, sequence of int
773
- An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
774
- padding: str, int, sequence of tuple
775
- Either the string `'SAME'`, the string `'VALID'`, or a sequence
776
- of n `(low, high)` integer pairs that give the padding to apply before
777
- and after each spatial dimension.
778
- channel_axis: int, optional
779
- Axis of the spatial channels for which pooling is skipped.
780
- If ``None``, there is no channel axis.
781
- name: optional, str
782
- The object name.
1470
+ norm_type : float
1471
+ Exponent for the pooling operation. Default: 2.0
1472
+ kernel_size : int or sequence of int
1473
+ An integer, or a sequence of integers defining the window to reduce over.
1474
+ stride : int or sequence of int, optional
1475
+ An integer, or a sequence of integers, representing the inter-window stride.
1476
+ Default: kernel_size
1477
+ padding : str, int or sequence of tuple, optional
1478
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
1479
+ of n `(low, high)` integer pairs that give the padding to apply before
1480
+ and after each spatial dimension. Default: 'VALID'
1481
+ channel_axis : int, optional
1482
+ Axis of the spatial channels for which pooling is skipped.
1483
+ If ``None``, there is no channel axis. Default: -1
1484
+ name : str, optional
1485
+ The object name.
1486
+ in_size : Sequence of int, optional
1487
+ The shape of the input tensor.
1488
+
1489
+ Examples
1490
+ --------
1491
+ .. code-block:: python
1492
+
1493
+ >>> import brainstate
1494
+ >>> # power-average pooling of window of size=3, stride=2 with norm_type=2.0
1495
+ >>> m = LPPool1d(2, 3, stride=2)
1496
+ >>> input = brainstate.random.randn(20, 50, 16)
1497
+ >>> output = m(input)
1498
+ >>> output.shape
1499
+ (20, 24, 16)
1500
+ """
1501
+ __module__ = 'brainstate.nn'
783
1502
 
1503
+ def __init__(
1504
+ self,
1505
+ norm_type: float,
1506
+ kernel_size: Size,
1507
+ stride: Union[int, Sequence[int]] = None,
1508
+ padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
1509
+ channel_axis: Optional[int] = -1,
1510
+ name: Optional[str] = None,
1511
+ in_size: Optional[Size] = None,
1512
+ ):
1513
+ super().__init__(
1514
+ norm_type=norm_type,
1515
+ pool_dim=1,
1516
+ kernel_size=kernel_size,
1517
+ stride=stride,
1518
+ padding=padding,
1519
+ channel_axis=channel_axis,
1520
+ name=name,
1521
+ in_size=in_size
1522
+ )
1523
+
1524
+
1525
+ class LPPool2d(_LPPool):
1526
+ r"""Applies a 2D power-average pooling over an input signal composed of several input planes.
1527
+
1528
+ On each window, the function computed is:
1529
+
1530
+ .. math::
1531
+ f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}}
1532
+
1533
+ - At :math:`p = \infty`, one gets max pooling
1534
+ - At :math:`p = 1`, one gets average pooling (with absolute values)
1535
+ - At :math:`p = 2`, one gets root mean square (RMS) pooling
1536
+
1537
+ Shape:
1538
+ - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`
1539
+ - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
1540
+
1541
+ .. math::
1542
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
1543
+
1544
+ .. math::
1545
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
1546
+
1547
+ Parameters
1548
+ ----------
1549
+ norm_type : float
1550
+ Exponent for the pooling operation. Default: 2.0
1551
+ kernel_size : int or sequence of int
1552
+ An integer, or a sequence of integers defining the window to reduce over.
1553
+ stride : int or sequence of int, optional
1554
+ An integer, or a sequence of integers, representing the inter-window stride.
1555
+ Default: kernel_size
1556
+ padding : str, int or sequence of tuple, optional
1557
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
1558
+ of n `(low, high)` integer pairs that give the padding to apply before
1559
+ and after each spatial dimension. Default: 'VALID'
1560
+ channel_axis : int, optional
1561
+ Axis of the spatial channels for which pooling is skipped.
1562
+ If ``None``, there is no channel axis. Default: -1
1563
+ name : str, optional
1564
+ The object name.
1565
+ in_size : Sequence of int, optional
1566
+ The shape of the input tensor.
1567
+
1568
+ Examples
1569
+ --------
1570
+ .. code-block:: python
1571
+
1572
+ >>> import brainstate
1573
+ >>> # power-average pooling of square window of size=3, stride=2
1574
+ >>> m = LPPool2d(2, 3, stride=2)
1575
+ >>> # pool of non-square window with norm_type=1.5
1576
+ >>> m = LPPool2d(1.5, (3, 2), stride=(2, 1), channel_axis=-1)
1577
+ >>> input = brainstate.random.randn(20, 50, 32, 16)
1578
+ >>> output = m(input)
1579
+ >>> output.shape
1580
+ (20, 24, 31, 16)
784
1581
  """
785
1582
  __module__ = 'brainstate.nn'
786
1583
 
787
1584
  def __init__(
788
1585
  self,
1586
+ norm_type: float,
789
1587
  kernel_size: Size,
790
- stride: Union[int, Sequence[int]] = 1,
1588
+ stride: Union[int, Sequence[int]] = None,
791
1589
  padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
792
1590
  channel_axis: Optional[int] = -1,
793
1591
  name: Optional[str] = None,
794
1592
  in_size: Optional[Size] = None,
795
1593
  ):
796
- super().__init__(in_size=in_size,
797
- init_value=0.,
798
- computation=jax.lax.add,
799
- pool_dim=3,
800
- kernel_size=kernel_size,
801
- stride=stride,
802
- padding=padding,
803
- channel_axis=channel_axis,
804
- name=name)
1594
+ super().__init__(
1595
+ norm_type=norm_type,
1596
+ pool_dim=2,
1597
+ kernel_size=kernel_size,
1598
+ stride=stride,
1599
+ padding=padding,
1600
+ channel_axis=channel_axis,
1601
+ name=name,
1602
+ in_size=in_size
1603
+ )
1604
+
1605
+
1606
+ class LPPool3d(_LPPool):
1607
+ r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
1608
+
1609
+ On each window, the function computed is:
1610
+
1611
+ .. math::
1612
+ f(X) = \sqrt[p]{\sum_{x \in X} |x|^{p}}
1613
+
1614
+ - At :math:`p = \infty`, one gets max pooling
1615
+ - At :math:`p = 1`, one gets average pooling (with absolute values)
1616
+ - At :math:`p = 2`, one gets root mean square (RMS) pooling
1617
+
1618
+ Shape:
1619
+ - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
1620
+ - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`, where
1621
+
1622
+ .. math::
1623
+ D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
1624
+
1625
+ .. math::
1626
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
1627
+
1628
+ .. math::
1629
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
1630
+
1631
+ Parameters
1632
+ ----------
1633
+ norm_type : float
1634
+ Exponent for the pooling operation. Default: 2.0
1635
+ kernel_size : int or sequence of int
1636
+ An integer, or a sequence of integers defining the window to reduce over.
1637
+ stride : int or sequence of int, optional
1638
+ An integer, or a sequence of integers, representing the inter-window stride.
1639
+ Default: kernel_size
1640
+ padding : str, int or sequence of tuple, optional
1641
+ Either the string `'SAME'`, the string `'VALID'`, or a sequence
1642
+ of n `(low, high)` integer pairs that give the padding to apply before
1643
+ and after each spatial dimension. Default: 'VALID'
1644
+ channel_axis : int, optional
1645
+ Axis of the spatial channels for which pooling is skipped.
1646
+ If ``None``, there is no channel axis. Default: -1
1647
+ name : str, optional
1648
+ The object name.
1649
+ in_size : Sequence of int, optional
1650
+ The shape of the input tensor.
1651
+
1652
+ Examples
1653
+ --------
1654
+ .. code-block:: python
1655
+
1656
+ >>> import brainstate
1657
+ >>> # power-average pooling of cube window of size=3, stride=2
1658
+ >>> m = LPPool3d(2, 3, stride=2)
1659
+ >>> # pool of non-cubic window with norm_type=1.5
1660
+ >>> m = LPPool3d(1.5, (3, 2, 2), stride=(2, 1, 2), channel_axis=-1)
1661
+ >>> input = brainstate.random.randn(20, 50, 44, 31, 16)
1662
+ >>> output = m(input)
1663
+ >>> output.shape
1664
+ (20, 24, 43, 15, 16)
1665
+ """
1666
+ __module__ = 'brainstate.nn'
1667
+
1668
+ def __init__(
1669
+ self,
1670
+ norm_type: float,
1671
+ kernel_size: Size,
1672
+ stride: Union[int, Sequence[int]] = None,
1673
+ padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
1674
+ channel_axis: Optional[int] = -1,
1675
+ name: Optional[str] = None,
1676
+ in_size: Optional[Size] = None,
1677
+ ):
1678
+ super().__init__(
1679
+ norm_type=norm_type,
1680
+ pool_dim=3,
1681
+ kernel_size=kernel_size,
1682
+ stride=stride,
1683
+ padding=padding,
1684
+ channel_axis=channel_axis,
1685
+ name=name,
1686
+ in_size=in_size
1687
+ )
805
1688
 
806
1689
 
807
1690
  def _adaptive_pool1d(x, target_size: int, operation: Callable):
@@ -919,259 +1802,438 @@ class _AdaptivePool(Module):
919
1802
 
920
1803
 
921
1804
  class AdaptiveAvgPool1d(_AdaptivePool):
922
- r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
1805
+ r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
923
1806
 
924
1807
  The output size is :math:`L_{out}`, for any input size.
925
1808
  The number of output features is equal to the number of input planes.
926
1809
 
1810
+ Adaptive pooling automatically computes the kernel size and stride to achieve the desired
1811
+ output size, making it useful for creating fixed-size representations from variable-sized inputs.
1812
+
927
1813
  Shape:
928
1814
  - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
929
1815
  - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
930
- :math:`L_{out}=\text{output\_size}`.
1816
+ :math:`L_{out}=\text{target\_size}`.
931
1817
 
932
- Examples:
933
-
934
- >>> import brainstate as brainstate
935
- >>> # target output size of 5
936
- >>> m = AdaptiveMaxPool1d(5)
1818
+ Parameters
1819
+ ----------
1820
+ target_size : int or sequence of int
1821
+ The target output length :math:`L_{out}` along the pooled dimension (per channel).
1822
+ channel_axis : int, optional
1823
+ Axis of the spatial channels for which pooling is skipped.
1824
+ If ``None``, there is no channel axis. Default: -1
1825
+ name : str, optional
1826
+ The name of the module.
1827
+ in_size : Sequence of int, optional
1828
+ The shape of the input tensor for shape inference.
1829
+
1830
+ Examples
1831
+ --------
1832
+ .. code-block:: python
1833
+
1834
+ >>> import brainstate
1835
+ >>> # Target output size of 5
1836
+ >>> m = brainstate.nn.AdaptiveAvgPool1d(5)
937
1837
  >>> input = brainstate.random.randn(1, 64, 8)
938
1838
  >>> output = m(input)
939
1839
  >>> output.shape
940
1840
  (1, 5, 8)
941
-
942
- Parameters
943
- ----------
944
- in_size: Sequence of int
945
- The shape of the input tensor.
946
- target_size: int, sequence of int
947
- The target output shape.
948
- channel_axis: int, optional
949
- Axis of the spatial channels for which pooling is skipped.
950
- If ``None``, there is no channel axis.
951
- name: str
952
- The class name.
1841
+ >>> # Can handle variable input sizes; the output size stays the same
1842
+ >>> input2 = brainstate.random.randn(1, 32, 8)
1843
+ >>> output2 = m(input2)
1844
+ >>> output2.shape
1845
+ (1, 5, 8)
1846
+
1847
+ See Also
1848
+ --------
1849
+ AvgPool1d : Non-adaptive 1D average pooling.
1850
+ AdaptiveMaxPool1d : Adaptive 1D max pooling.
953
1851
  """
954
1852
  __module__ = 'brainstate.nn'
955
1853
 
956
- def __init__(self,
957
- target_size: Union[int, Sequence[int]],
958
- channel_axis: Optional[int] = -1,
959
- name: Optional[str] = None,
960
- in_size: Optional[Sequence[int]] = None, ):
961
- super().__init__(in_size=in_size,
962
- target_size=target_size,
963
- channel_axis=channel_axis,
964
- num_spatial_dims=1,
965
- operation=jnp.mean,
966
- name=name)
1854
+ def __init__(
1855
+ self,
1856
+ target_size: Union[int, Sequence[int]],
1857
+ channel_axis: Optional[int] = -1,
1858
+ name: Optional[str] = None,
1859
+ in_size: Optional[Sequence[int]] = None,
1860
+ ):
1861
+ super().__init__(
1862
+ in_size=in_size,
1863
+ target_size=target_size,
1864
+ channel_axis=channel_axis,
1865
+ num_spatial_dims=1,
1866
+ operation=jnp.mean,
1867
+ name=name
1868
+ )
967
1869
 
968
1870
 
969
1871
  class AdaptiveAvgPool2d(_AdaptivePool):
970
- r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
1872
+ r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
971
1873
 
972
1874
  The output is of size :math:`H_{out} \times W_{out}`, for any input size.
973
1875
  The number of output features is equal to the number of input planes.
974
1876
 
1877
+ Adaptive pooling automatically computes the kernel size and stride to achieve the desired
1878
+ output size, making it useful for creating fixed-size representations from variable-sized inputs.
1879
+
975
1880
  Shape:
976
1881
  - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
977
1882
  - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
978
- :math:`(H_{out}, W_{out})=\text{output\_size}`.
979
-
980
- Examples:
1883
+ :math:`(H_{out}, W_{out})=\text{target\_size}`.
981
1884
 
982
- >>> import brainstate as brainstate
983
- >>> # target output size of 5x7
984
- >>> m = AdaptiveMaxPool2d((5, 7))
1885
+ Parameters
1886
+ ----------
1887
+ target_size : int or tuple of int
1888
+ The target output size. If a single integer is provided, the output will be a square
1889
+ of that size. If a tuple is provided, it specifies (H_out, W_out).
1890
+ Use None for dimensions that should not be pooled.
1891
+ channel_axis : int, optional
1892
+ Axis of the spatial channels for which pooling is skipped.
1893
+ If ``None``, there is no channel axis. Default: -1
1894
+ name : str, optional
1895
+ The name of the module.
1896
+ in_size : Sequence of int, optional
1897
+ The shape of the input tensor for shape inference.
1898
+
1899
+ Examples
1900
+ --------
1901
+ .. code-block:: python
1902
+
1903
+ >>> import brainstate
1904
+ >>> # Target output size of 5x7
1905
+ >>> m = brainstate.nn.AdaptiveAvgPool2d((5, 7))
985
1906
  >>> input = brainstate.random.randn(1, 8, 9, 64)
986
1907
  >>> output = m(input)
987
1908
  >>> output.shape
988
1909
  (1, 5, 7, 64)
989
- >>> # target output size of 7x7 (square)
990
- >>> m = AdaptiveMaxPool2d(7)
1910
+ >>> # Target output size of 7x7 (square)
1911
+ >>> m = brainstate.nn.AdaptiveAvgPool2d(7)
991
1912
  >>> input = brainstate.random.randn(1, 10, 9, 64)
992
1913
  >>> output = m(input)
993
1914
  >>> output.shape
994
1915
  (1, 7, 7, 64)
995
- >>> # target output size of 10x7
996
- >>> m = AdaptiveMaxPool2d((None, 7))
1916
+ >>> # Target output size of 10x7
1917
+ >>> m = brainstate.nn.AdaptiveAvgPool2d((None, 7))
997
1918
  >>> input = brainstate.random.randn(1, 10, 9, 64)
998
1919
  >>> output = m(input)
999
1920
  >>> output.shape
1000
1921
  (1, 10, 7, 64)
1001
1922
 
1002
- Parameters
1003
- ----------
1004
- in_size: Sequence of int
1005
- The shape of the input tensor.
1006
- target_size: int, sequence of int
1007
- The target output shape.
1008
- channel_axis: int, optional
1009
- Axis of the spatial channels for which pooling is skipped.
1010
- If ``None``, there is no channel axis.
1011
- name: str
1012
- The class name.
1923
+ See Also
1924
+ --------
1925
+ AvgPool2d : Non-adaptive 2D average pooling.
1926
+ AdaptiveMaxPool2d : Adaptive 2D max pooling.
1013
1927
  """
1014
1928
  __module__ = 'brainstate.nn'
1015
1929
 
1016
- def __init__(self,
1017
- target_size: Union[int, Sequence[int]],
1018
- channel_axis: Optional[int] = -1,
1019
- name: Optional[str] = None,
1020
-
1021
- in_size: Optional[Sequence[int]] = None, ):
1022
- super().__init__(in_size=in_size,
1023
- target_size=target_size,
1024
- channel_axis=channel_axis,
1025
- num_spatial_dims=2,
1026
- operation=jnp.mean,
1027
- name=name)
1930
+ def __init__(
1931
+ self,
1932
+ target_size: Union[int, Sequence[int]],
1933
+ channel_axis: Optional[int] = -1,
1934
+ name: Optional[str] = None,
1935
+ in_size: Optional[Sequence[int]] = None,
1936
+ ):
1937
+ super().__init__(
1938
+ in_size=in_size,
1939
+ target_size=target_size,
1940
+ channel_axis=channel_axis,
1941
+ num_spatial_dims=2,
1942
+ operation=jnp.mean,
1943
+ name=name
1944
+ )
1028
1945
 
1029
1946
 
1030
1947
  class AdaptiveAvgPool3d(_AdaptivePool):
1031
- r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
1948
+ r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
1032
1949
 
1033
1950
  The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
1034
1951
  The number of output features is equal to the number of input planes.
1035
1952
 
1953
+ Adaptive pooling automatically computes the kernel size and stride to achieve the desired
1954
+ output size, making it useful for creating fixed-size representations from variable-sized inputs.
1955
+
1036
1956
  Shape:
1037
1957
  - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
1038
1958
  - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
1039
- where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`.
1040
-
1041
- Examples:
1959
+ where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
1042
1960
 
1043
- >>> import brainstate as brainstate
1044
- >>> # target output size of 5x7x9
1045
- >>> m = AdaptiveMaxPool3d((5, 7, 9))
1961
+ Parameters
1962
+ ----------
1963
+ target_size : int or tuple of int
1964
+ The target output size. If a single integer is provided, the output will be a cube
1965
+ of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
1966
+ Use None for dimensions that should not be pooled.
1967
+ channel_axis : int, optional
1968
+ Axis of the spatial channels for which pooling is skipped.
1969
+ If ``None``, there is no channel axis. Default: -1
1970
+ name : str, optional
1971
+ The name of the module.
1972
+ in_size : Sequence of int, optional
1973
+ The shape of the input tensor for shape inference.
1974
+
1975
+ Examples
1976
+ --------
1977
+ .. code-block:: python
1978
+
1979
+ >>> import brainstate
1980
+ >>> # Target output size of 5x7x9
1981
+ >>> m = brainstate.nn.AdaptiveAvgPool3d((5, 7, 9))
1046
1982
  >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
1047
1983
  >>> output = m(input)
1048
1984
  >>> output.shape
1049
1985
  (1, 5, 7, 9, 64)
1050
- >>> # target output size of 7x7x7 (cube)
1051
- >>> m = AdaptiveMaxPool3d(7)
1986
+ >>> # Target output size of 7x7x7 (cube)
1987
+ >>> m = brainstate.nn.AdaptiveAvgPool3d(7)
1052
1988
  >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1053
1989
  >>> output = m(input)
1054
1990
  >>> output.shape
1055
1991
  (1, 7, 7, 7, 64)
1056
- >>> # target output size of 7x9x8
1057
- >>> m = AdaptiveMaxPool3d((7, None, None))
1992
+ >>> # Target output size of 7x9x8
1993
+ >>> m = brainstate.nn.AdaptiveAvgPool3d((7, None, None))
1058
1994
  >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
1059
1995
  >>> output = m(input)
1060
1996
  >>> output.shape
1061
1997
  (1, 7, 9, 8, 64)
1062
1998
 
1063
- Parameters
1064
- ----------
1065
- in_size: Sequence of int
1066
- The shape of the input tensor.
1067
- target_size: int, sequence of int
1068
- The target output shape.
1069
- channel_axis: int, optional
1070
- Axis of the spatial channels for which pooling is skipped.
1071
- If ``None``, there is no channel axis.
1072
- name: str
1073
- The class name.
1999
+ See Also
2000
+ --------
2001
+ AvgPool3d : Non-adaptive 3D average pooling.
2002
+ AdaptiveMaxPool3d : Adaptive 3D max pooling.
1074
2003
  """
1075
2004
  __module__ = 'brainstate.nn'
1076
2005
 
1077
- def __init__(self,
1078
- target_size: Union[int, Sequence[int]],
1079
- channel_axis: Optional[int] = -1,
1080
- name: Optional[str] = None,
1081
- in_size: Optional[Sequence[int]] = None, ):
1082
- super().__init__(in_size=in_size,
1083
- target_size=target_size,
1084
- channel_axis=channel_axis,
1085
- num_spatial_dims=3,
1086
- operation=jnp.mean,
1087
- name=name)
2006
+ def __init__(
2007
+ self,
2008
+ target_size: Union[int, Sequence[int]],
2009
+ channel_axis: Optional[int] = -1,
2010
+ name: Optional[str] = None,
2011
+ in_size: Optional[Sequence[int]] = None,
2012
+ ):
2013
+ super().__init__(
2014
+ in_size=in_size,
2015
+ target_size=target_size,
2016
+ channel_axis=channel_axis,
2017
+ num_spatial_dims=3,
2018
+ operation=jnp.mean,
2019
+ name=name
2020
+ )
1088
2021
 
1089
2022
 
1090
2023
  class AdaptiveMaxPool1d(_AdaptivePool):
1091
- """Adaptive one-dimensional maximum down-sampling.
2024
+ r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
2025
+
2026
+ The output size is :math:`L_{out}`, for any input size.
2027
+ The number of output features is equal to the number of input planes.
2028
+
2029
+ Adaptive pooling automatically computes the kernel size and stride to achieve the desired
2030
+ output size, making it useful for creating fixed-size representations from variable-sized inputs.
2031
+
2032
+ Shape:
2033
+ - Input: :math:`(N, L_{in}, C)` or :math:`(L_{in}, C)`.
2034
+ - Output: :math:`(N, L_{out}, C)` or :math:`(L_{out}, C)`, where
2035
+ :math:`L_{out}=\text{target\_size}`.
1092
2036
 
1093
2037
  Parameters
1094
2038
  ----------
1095
- in_size: Sequence of int
1096
- The shape of the input tensor.
1097
- target_size: int, sequence of int
1098
- The target output shape.
1099
- channel_axis: int, optional
1100
- Axis of the spatial channels for which pooling is skipped.
1101
- If ``None``, there is no channel axis.
1102
- name: str
1103
- The class name.
2039
+ target_size : int or sequence of int
2040
+ The target output length :math:`L_{out}` along the pooled dimension (per channel).
2041
+ channel_axis : int, optional
2042
+ Axis of the spatial channels for which pooling is skipped.
2043
+ If ``None``, there is no channel axis. Default: -1
2044
+ name : str, optional
2045
+ The name of the module.
2046
+ in_size : Sequence of int, optional
2047
+ The shape of the input tensor for shape inference.
2048
+
2049
+ Examples
2050
+ --------
2051
+ .. code-block:: python
2052
+
2053
+ >>> import brainstate
2054
+ >>> # Target output size of 5
2055
+ >>> m = brainstate.nn.AdaptiveMaxPool1d(5)
2056
+ >>> input = brainstate.random.randn(1, 64, 8)
2057
+ >>> output = m(input)
2058
+ >>> output.shape
2059
+ (1, 5, 8)
2060
+ >>> # Can handle variable input sizes; the output size stays the same
2061
+ >>> input2 = brainstate.random.randn(1, 32, 8)
2062
+ >>> output2 = m(input2)
2063
+ >>> output2.shape
2064
+ (1, 5, 8)
2065
+
2066
+ See Also
2067
+ --------
2068
+ MaxPool1d : Non-adaptive 1D max pooling.
2069
+ AdaptiveAvgPool1d : Adaptive 1D average pooling.
1104
2070
  """
1105
2071
  __module__ = 'brainstate.nn'
1106
2072
 
1107
- def __init__(self,
1108
- target_size: Union[int, Sequence[int]],
1109
- channel_axis: Optional[int] = -1,
1110
- name: Optional[str] = None,
1111
- in_size: Optional[Sequence[int]] = None, ):
1112
- super().__init__(in_size=in_size,
1113
- target_size=target_size,
1114
- channel_axis=channel_axis,
1115
- num_spatial_dims=1,
1116
- operation=jnp.max,
1117
- name=name)
2073
+ def __init__(
2074
+ self,
2075
+ target_size: Union[int, Sequence[int]],
2076
+ channel_axis: Optional[int] = -1,
2077
+ name: Optional[str] = None,
2078
+ in_size: Optional[Sequence[int]] = None,
2079
+ ):
2080
+ super().__init__(
2081
+ in_size=in_size,
2082
+ target_size=target_size,
2083
+ channel_axis=channel_axis,
2084
+ num_spatial_dims=1,
2085
+ operation=jnp.max,
2086
+ name=name
2087
+ )
1118
2088
 
1119
2089
 
1120
2090
  class AdaptiveMaxPool2d(_AdaptivePool):
1121
- """Adaptive two-dimensional maximum down-sampling.
2091
+ r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
2092
+
2093
+ The output is of size :math:`H_{out} \times W_{out}`, for any input size.
2094
+ The number of output features is equal to the number of input planes.
2095
+
2096
+ Adaptive pooling automatically computes the kernel size and stride to achieve the desired
2097
+ output size, making it useful for creating fixed-size representations from variable-sized inputs.
2098
+
2099
+ Shape:
2100
+ - Input: :math:`(N, H_{in}, W_{in}, C)` or :math:`(H_{in}, W_{in}, C)`.
2101
+ - Output: :math:`(N, H_{out}, W_{out}, C)` or :math:`(H_{out}, W_{out}, C)`, where
2102
+ :math:`(H_{out}, W_{out})=\text{target\_size}`.
1122
2103
 
1123
2104
  Parameters
1124
2105
  ----------
1125
- in_size: Sequence of int
1126
- The shape of the input tensor.
1127
- target_size: int, sequence of int
1128
- The target output shape.
1129
- channel_axis: int, optional
1130
- Axis of the spatial channels for which pooling is skipped.
1131
- If ``None``, there is no channel axis.
1132
- name: str
1133
- The class name.
2106
+ target_size : int or tuple of int
2107
+ The target output size. If a single integer is provided, the output will be a square
2108
+ of that size. If a tuple is provided, it specifies (H_out, W_out).
2109
+ Use None for dimensions that should not be pooled.
2110
+ channel_axis : int, optional
2111
+ Axis of the spatial channels for which pooling is skipped.
2112
+ If ``None``, there is no channel axis. Default: -1
2113
+ name : str, optional
2114
+ The name of the module.
2115
+ in_size : Sequence of int, optional
2116
+ The shape of the input tensor for shape inference.
2117
+
2118
+ Examples
2119
+ --------
2120
+ .. code-block:: python
2121
+
2122
+ >>> import brainstate
2123
+ >>> # Target output size of 5x7
2124
+ >>> m = brainstate.nn.AdaptiveMaxPool2d((5, 7))
2125
+ >>> input = brainstate.random.randn(1, 8, 9, 64)
2126
+ >>> output = m(input)
2127
+ >>> output.shape
2128
+ (1, 5, 7, 64)
2129
+ >>> # Target output size of 7x7 (square)
2130
+ >>> m = brainstate.nn.AdaptiveMaxPool2d(7)
2131
+ >>> input = brainstate.random.randn(1, 10, 9, 64)
2132
+ >>> output = m(input)
2133
+ >>> output.shape
2134
+ (1, 7, 7, 64)
2135
+ >>> # Target output size of 10x7
2136
+ >>> m = brainstate.nn.AdaptiveMaxPool2d((None, 7))
2137
+ >>> input = brainstate.random.randn(1, 10, 9, 64)
2138
+ >>> output = m(input)
2139
+ >>> output.shape
2140
+ (1, 10, 7, 64)
2141
+
2142
+ See Also
2143
+ --------
2144
+ MaxPool2d : Non-adaptive 2D max pooling.
2145
+ AdaptiveAvgPool2d : Adaptive 2D average pooling.
1134
2146
  """
1135
2147
  __module__ = 'brainstate.nn'
1136
2148
 
1137
- def __init__(self,
1138
- target_size: Union[int, Sequence[int]],
1139
- channel_axis: Optional[int] = -1,
1140
- name: Optional[str] = None,
1141
- in_size: Optional[Sequence[int]] = None, ):
1142
- super().__init__(in_size=in_size,
1143
- target_size=target_size,
1144
- channel_axis=channel_axis,
1145
- num_spatial_dims=2,
1146
- operation=jnp.max,
1147
- name=name)
2149
+ def __init__(
2150
+ self,
2151
+ target_size: Union[int, Sequence[int]],
2152
+ channel_axis: Optional[int] = -1,
2153
+ name: Optional[str] = None,
2154
+ in_size: Optional[Sequence[int]] = None,
2155
+ ):
2156
+ super().__init__(
2157
+ in_size=in_size,
2158
+ target_size=target_size,
2159
+ channel_axis=channel_axis,
2160
+ num_spatial_dims=2,
2161
+ operation=jnp.max,
2162
+ name=name
2163
+ )
1148
2164
 
1149
2165
 
1150
2166
  class AdaptiveMaxPool3d(_AdaptivePool):
1151
- """Adaptive three-dimensional maximum down-sampling.
2167
+ r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
2168
+
2169
+ The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
2170
+ The number of output features is equal to the number of input planes.
2171
+
2172
+ Adaptive pooling automatically computes the kernel size and stride to achieve the desired
2173
+ output size, making it useful for creating fixed-size representations from variable-sized inputs.
2174
+
2175
+ Shape:
2176
+ - Input: :math:`(N, D_{in}, H_{in}, W_{in}, C)` or :math:`(D_{in}, H_{in}, W_{in}, C)`.
2177
+ - Output: :math:`(N, D_{out}, H_{out}, W_{out}, C)` or :math:`(D_{out}, H_{out}, W_{out}, C)`,
2178
+ where :math:`(D_{out}, H_{out}, W_{out})=\text{target\_size}`.
1152
2179
 
1153
2180
  Parameters
1154
2181
  ----------
1155
- in_size: Sequence of int
1156
- The shape of the input tensor.
1157
- target_size: int, sequence of int
1158
- The target output shape.
1159
- channel_axis: int, optional
1160
- Axis of the spatial channels for which pooling is skipped.
1161
- If ``None``, there is no channel axis.
1162
- name: str
1163
- The class name.
2182
+ target_size : int or tuple of int
2183
+ The target output size. If a single integer is provided, the output will be a cube
2184
+ of that size. If a tuple is provided, it specifies (D_out, H_out, W_out).
2185
+ Use None for dimensions that should not be pooled.
2186
+ channel_axis : int, optional
2187
+ Axis of the spatial channels for which pooling is skipped.
2188
+ If ``None``, there is no channel axis. Default: -1
2189
+ name : str, optional
2190
+ The name of the module.
2191
+ in_size : Sequence of int, optional
2192
+ The shape of the input tensor for shape inference.
2193
+
2194
+ Examples
2195
+ --------
2196
+ .. code-block:: python
2197
+
2198
+ >>> import brainstate
2199
+ >>> # Target output size of 5x7x9
2200
+ >>> m = brainstate.nn.AdaptiveMaxPool3d((5, 7, 9))
2201
+ >>> input = brainstate.random.randn(1, 8, 9, 10, 64)
2202
+ >>> output = m(input)
2203
+ >>> output.shape
2204
+ (1, 5, 7, 9, 64)
2205
+ >>> # Target output size of 7x7x7 (cube)
2206
+ >>> m = brainstate.nn.AdaptiveMaxPool3d(7)
2207
+ >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
2208
+ >>> output = m(input)
2209
+ >>> output.shape
2210
+ (1, 7, 7, 7, 64)
2211
+ >>> # Target output size of 7x9x8
2212
+ >>> m = brainstate.nn.AdaptiveMaxPool3d((7, None, None))
2213
+ >>> input = brainstate.random.randn(1, 10, 9, 8, 64)
2214
+ >>> output = m(input)
2215
+ >>> output.shape
2216
+ (1, 7, 9, 8, 64)
2217
+
2218
+ See Also
2219
+ --------
2220
+ MaxPool3d : Non-adaptive 3D max pooling.
2221
+ AdaptiveAvgPool3d : Adaptive 3D average pooling.
1164
2222
  """
1165
2223
  __module__ = 'brainstate.nn'
1166
2224
 
1167
- def __init__(self,
1168
- target_size: Union[int, Sequence[int]],
1169
- channel_axis: Optional[int] = -1,
1170
- name: Optional[str] = None,
1171
- in_size: Optional[Sequence[int]] = None, ):
1172
- super().__init__(in_size=in_size,
1173
- target_size=target_size,
1174
- channel_axis=channel_axis,
1175
- num_spatial_dims=3,
1176
- operation=jnp.max,
1177
- name=name)
2225
+ def __init__(
2226
+ self,
2227
+ target_size: Union[int, Sequence[int]],
2228
+ channel_axis: Optional[int] = -1,
2229
+ name: Optional[str] = None,
2230
+ in_size: Optional[Sequence[int]] = None,
2231
+ ):
2232
+ super().__init__(
2233
+ in_size=in_size,
2234
+ target_size=target_size,
2235
+ channel_axis=channel_axis,
2236
+ num_spatial_dims=3,
2237
+ operation=jnp.max,
2238
+ name=name
2239
+ )