brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1020 @@
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Padding layers for neural networks.
18
+
19
+ This module provides various padding operations for 1D, 2D, and 3D tensors:
20
+ - ReflectionPad: Pads using reflection of the input boundary
21
+ - ReplicationPad: Pads using replication of the input boundary
22
+ - ZeroPad: Pads with zeros
23
+ - ConstantPad: Pads with a constant value
24
+ - CircularPad: Pads circularly (wrap around)
25
+ """
26
+
27
+ import functools
28
+ from typing import Union, Sequence, Optional
29
+
30
+ import jax
31
+ import jax.numpy as jnp
32
+
33
+ from brainstate import environ
34
+ from brainstate.typing import Size
35
+ from ._module import Module
36
+
37
# Public API of this module, grouped by padding mode (1d, 2d, 3d per mode).
__all__ = [
    # -- reflection padding --
    'ReflectionPad1d',
    'ReflectionPad2d',
    'ReflectionPad3d',
    # -- replication padding --
    'ReplicationPad1d',
    'ReplicationPad2d',
    'ReplicationPad3d',
    # -- zero padding --
    'ZeroPad1d',
    'ZeroPad2d',
    'ZeroPad3d',
    # -- constant padding --
    'ConstantPad1d',
    'ConstantPad2d',
    'ConstantPad3d',
    # -- circular padding --
    'CircularPad1d',
    'CircularPad2d',
    'CircularPad3d',
]
49
+
50
+
51
def _format_padding(padding: Union[int, Sequence[int]], ndim: int) -> Sequence[tuple]:
    """
    Convert a padding specification to the per-axis pairs used by ``jax.numpy.pad``.

    Args:
        padding: Padding size(s). Can be:
            - int (any integer-like scalar, e.g. ``numpy.int64``): same padding
              on both sides of every spatial dimension
            - Sequence of length ``ndim``: each entry ``p`` is applied as
              ``(p, p)`` to the corresponding spatial dimension
            - Sequence of length ``2 * ndim``: ``(before_0, after_0, before_1,
              after_1, ...)``; consecutive pairs map to consecutive spatial
              dimensions, in order
        ndim: Number of spatial dimensions (1, 2, or 3)

    Returns:
        List of ``ndim`` ``(before, after)`` tuples of Python ints.

    Raises:
        TypeError: If ``padding`` is neither an integer nor a sequence of integers.
        ValueError: If the sequence length is neither ``ndim`` nor ``2 * ndim``,
            or if any padding value is negative.
    """
    import operator  # local import: only this helper needs it

    try:
        # ``operator.index`` accepts Python ints as well as NumPy/JAX integer
        # scalars (which are not ``int`` instances) and rejects floats.
        size = operator.index(padding)
    except TypeError:
        size = None

    if size is not None:
        # Same padding for all sides of all dimensions
        pairs = [(size, size) for _ in range(ndim)]
    else:
        try:
            values = [operator.index(p) for p in padding]
        except TypeError:
            # Either ``padding`` is not iterable or an entry is not an integer.
            raise TypeError(
                f"Padding must be an int or a sequence of ints, got {padding!r}"
            ) from None

        if len(values) == ndim:
            # Same padding for both sides of each dimension
            pairs = [(p, p) for p in values]
        elif len(values) == 2 * ndim:
            # Different padding for each side: (before1, after1, before2, after2, ...)
            pairs = [(values[2 * i], values[2 * i + 1]) for i in range(ndim)]
        else:
            raise ValueError(f"Padding must have length {ndim} or {2 * ndim}, got {len(values)}")

    # Reject negative sizes at construction time instead of deep inside jnp.pad.
    if any(before < 0 or after < 0 for before, after in pairs):
        raise ValueError(f"Padding values must be non-negative, got {padding!r}")
    return pairs
79
+
80
+
81
+ # =============================================================================
82
+ # Reflection Padding
83
+ # =============================================================================
84
+
85
class ReflectionPad1d(Module):
    """
    Pads the input tensor using the reflection of the input boundary.

    The input is expected in channels-last layout, either ``(length, channels)``
    or ``(batch, length, channels)``. Only the length axis is padded, via
    :func:`jax.numpy.pad` with ``mode='reflect'``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad1d(2)
        >>> input = jnp.array([[[1], [2], [3], [4], [5]]])  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
        >>> print(output[0, :, 0])
        [3 2 1 2 3 4 5 4 3]
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # A single (before, after) pair for the length axis.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Reflect-pad the length axis of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D ``(length, channels)`` nor 3D
            ``(batch, length, channels)``.
        """
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only the length axis
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only the length axis
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='reflect')
143
+
144
+
145
class ReflectionPad2d(Module):
    """
    Pads the input tensor using the reflection of the input boundary.

    The input is expected in channels-last layout, either
    ``(height, width, channels)`` or ``(batch, height, width, channels)``.
    Only the height and width axes are padded, via :func:`jax.numpy.pad`
    with ``mode='reflect'``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied
          to both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right), i.e. the
          (before, after) padding of the height axis followed by that of the
          width axis
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
        >>> # Asymmetric: one row on top, two columns on the left
        >>> pad = brainstate.nn.ReflectionPad2d((1, 0, 2, 0))
        >>> print(pad(input).shape)
        (1, 5, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # (before, after) pairs for the height axis, then the width axis.
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Reflect-pad the height and width axes of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D ``(height, width, channels)`` nor 4D
            ``(batch, height, width, channels)``.
        """
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='reflect')
204
+
205
+
206
class ReflectionPad3d(Module):
    """
    Pads the input tensor using the reflection of the input boundary.

    The input is expected in channels-last layout, either
    ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``. Only the depth, height and
    width axes are padded, via :func:`jax.numpy.pad` with ``mode='reflect'``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right),
          i.e. the (before, after) padding of the depth, height and width
          axes, in that order
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
        >>> # Asymmetric: one slice in front, two columns on each width side
        >>> pad = brainstate.nn.ReflectionPad3d((1, 0, 0, 0, 2, 2))
        >>> print(pad(input).shape)
        (1, 5, 4, 8, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # (before, after) pairs for the depth, height and width axes.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Reflect-pad the depth, height and width axes of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is neither 4D ``(depth, height, width, channels)`` nor 5D
            ``(batch, depth, height, width, channels)``.
        """
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='reflect')
265
+
266
+
267
+ # =============================================================================
268
+ # Replication Padding
269
+ # =============================================================================
270
+
271
class ReplicationPad1d(Module):
    """
    Pads the input tensor using replication of the input boundary.

    The input is expected in channels-last layout, either ``(length, channels)``
    or ``(batch, length, channels)``. Only the length axis is padded, via
    :func:`jax.numpy.pad` with ``mode='edge'``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad1d(2)
        >>> input = jnp.array([[[1], [2], [3], [4], [5]]])  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
        >>> print(output[0, :, 0])
        [1 1 1 2 3 4 5 5 5]
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # A single (before, after) pair for the length axis.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Edge-pad (replicate) the length axis of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D ``(length, channels)`` nor 3D
            ``(batch, length, channels)``.
        """
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only the length axis
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only the length axis
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='edge')
329
+
330
+
331
class ReplicationPad2d(Module):
    """
    Pads the input tensor using replication of the input boundary.

    The input is expected in channels-last layout, either
    ``(height, width, channels)`` or ``(batch, height, width, channels)``.
    Only the height and width axes are padded, via :func:`jax.numpy.pad`
    with ``mode='edge'``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied
          to both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right), i.e. the
          (before, after) padding of the height axis followed by that of the
          width axis
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
        >>> # Asymmetric: one row on top, two columns on the left
        >>> pad = brainstate.nn.ReplicationPad2d((1, 0, 2, 0))
        >>> print(pad(input).shape)
        (1, 5, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # (before, after) pairs for the height axis, then the width axis.
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Edge-pad (replicate) the height and width axes of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D ``(height, width, channels)`` nor 4D
            ``(batch, height, width, channels)``.
        """
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='edge')
390
+
391
+
392
class ReplicationPad3d(Module):
    """
    Pads the input tensor using replication of the input boundary.

    The input is expected in channels-last layout, either
    ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``. Only the depth, height and
    width axes are padded, via :func:`jax.numpy.pad` with ``mode='edge'``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right),
          i.e. the (before, after) padding of the depth, height and width
          axes, in that order
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
        >>> # Asymmetric: one slice in front, two columns on each width side
        >>> pad = brainstate.nn.ReplicationPad3d((1, 0, 0, 0, 2, 2))
        >>> print(pad(input).shape)
        (1, 5, 4, 8, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # (before, after) pairs for the depth, height and width axes.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Edge-pad (replicate) the depth, height and width axes of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is neither 4D ``(depth, height, width, channels)`` nor 5D
            ``(batch, depth, height, width, channels)``.
        """
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='edge')
451
+
452
+
453
+ # =============================================================================
454
+ # Zero Padding
455
+ # =============================================================================
456
+
457
class ZeroPad1d(Module):
    """
    Pads the input tensor with zeros.

    The input is expected in channels-last layout, either ``(length, channels)``
    or ``(batch, length, channels)``. Only the length axis is padded, via
    :func:`jax.numpy.pad` with ``mode='constant'`` and ``constant_values=0``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad1d(2)
        >>> input = jnp.array([[[1], [2], [3], [4], [5]]])  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
        >>> print(output[0, :, 0])
        [0 0 1 2 3 4 5 0 0]
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # A single (before, after) pair for the length axis.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Zero-pad the length axis of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D ``(length, channels)`` nor 3D
            ``(batch, length, channels)``.
        """
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only the length axis
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only the length axis
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
515
+
516
+
517
class ZeroPad2d(Module):
    """
    Pads the input tensor with zeros.

    The input is expected in channels-last layout, either
    ``(height, width, channels)`` or ``(batch, height, width, channels)``.
    Only the height and width axes are padded, via :func:`jax.numpy.pad`
    with ``mode='constant'`` and ``constant_values=0``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied
          to both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right), i.e. the
          (before, after) padding of the height axis followed by that of the
          width axis
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
        >>> # Asymmetric: one row on top, two columns on the left
        >>> pad = brainstate.nn.ZeroPad2d((1, 0, 2, 0))
        >>> print(pad(input).shape)
        (1, 5, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # (before, after) pairs for the height axis, then the width axis.
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Zero-pad the height and width axes of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D ``(height, width, channels)`` nor 4D
            ``(batch, height, width, channels)``.
        """
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
576
+
577
+
578
class ZeroPad3d(Module):
    """
    Pads the input tensor with zeros.

    The input is expected in channels-last layout, either
    ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``. Only the depth, height and
    width axes are padded, via :func:`jax.numpy.pad` with ``mode='constant'``
    and ``constant_values=0``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right),
          i.e. the (before, after) padding of the depth, height and width
          axes, in that order
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
        >>> # Asymmetric: one slice in front, two columns on each width side
        >>> pad = brainstate.nn.ZeroPad3d((1, 0, 0, 0, 2, 2))
        >>> print(pad(input).shape)
        (1, 5, 4, 8, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # (before, after) pairs for the depth, height and width axes.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Zero-pad the depth, height and width axes of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is neither 4D ``(depth, height, width, channels)`` nor 5D
            ``(batch, depth, height, width, channels)``.
        """
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
637
+
638
+
639
+ # =============================================================================
640
+ # Constant Padding
641
+ # =============================================================================
642
+
643
class ConstantPad1d(Module):
    """
    Pads the input tensor with a constant value.

    The input is expected in channels-last layout, either ``(length, channels)``
    or ``(batch, length, channels)``. Only the length axis is padded, via
    :func:`jax.numpy.pad` with ``mode='constant'`` and
    ``constant_values=value``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad1d(2, value=3.5)
        >>> input = jnp.array([[[1.], [2.], [3.], [4.], [5.]]])  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # A single (before, after) pair for the length axis.
        self.padding = _format_padding(padding, 1)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the length axis of ``x`` with ``self.value``.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D ``(length, channels)`` nor 3D
            ``(batch, length, channels)``.
        """
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only the length axis
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only the length axis
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
705
+
706
+
707
class ConstantPad2d(Module):
    """
    Pads the input tensor with a constant value.

    The input is expected in channels-last layout, either
    ``(height, width, channels)`` or ``(batch, height, width, channels)``.
    Only the height and width axes are padded, via :func:`jax.numpy.pad`
    with ``mode='constant'`` and ``constant_values=value``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied
          to both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right), i.e. the
          (before, after) padding of the height axis followed by that of the
          width axis
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        Shape of the expected input, in either accepted layout (with or
        without the batch axis). When given, ``out_size`` is computed eagerly
        with :func:`jax.eval_shape`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad2d(1, value=3.5)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
        >>> # Asymmetric: one row on top, two columns on the left
        >>> pad = brainstate.nn.ConstantPad2d((1, 0, 2, 0), value=3.5)
        >>> print(pad(input).shape)
        (1, 5, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # (before, after) pairs for the height axis, then the width axis.
        self.padding = _format_padding(padding, 2)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly, without computing any data.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the height and width axes of ``x`` with ``self.value``.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D ``(height, width, channels)`` nor 4D
            ``(batch, height, width, channels)``.
        """
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
770
+
771
+
772
class ConstantPad3d(Module):
    """
    Pad the three spatial axes of a channels-last tensor with a fixed value.

    Depth, height and width are padded; an optional leading batch axis and
    the trailing channel axis are left unpadded.

    Parameters
    ----------
    padding : int or Sequence[int]
        Amount of padding, normalized by ``_format_padding`` into one
        ``(before, after)`` pair per spatial axis. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
        - Sequence[int] of length 6: (left, right, top, bottom, front, back)

        The normalized pairs are applied, in order, to the depth, height and
        width axes.
    value : float, optional
        Fill value written into the padded region. Default is 0.
    in_size : Size, optional
        Input shape. When given, ``out_size`` is inferred abstractly with
        ``jax.eval_shape`` during construction.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        Raised by ``update`` (and therefore at construction when ``in_size``
        is given) if the input is neither 4D nor 5D.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad3d(1, value=3.5)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 3)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            spec = jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            inferred = jax.eval_shape(self.update, spec)
            self.out_size = inferred.shape

    def update(self, x):
        """
        Constant-pad ``x`` of shape ``([batch,] depth, height, width, channels)``.

        Raises
        ------
        ValueError
            If ``x`` is not 4D or 5D.
        """
        n_dims = x.ndim
        if n_dims == 4:
            prefix = []
        elif n_dims == 5:
            # leading batch axis: never padded
            prefix = [(0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {n_dims}D")
        spatial = [self.padding[axis] for axis in range(3)]
        return jnp.pad(
            x,
            prefix + spatial + [(0, 0)],
            mode='constant',
            constant_values=self.value,
        )
835
+
836
+
837
+ # =============================================================================
838
+ # Circular Padding
839
+ # =============================================================================
840
+
841
class CircularPad1d(Module):
    """
    Pad the length axis of a channels-last tensor by wrapping around.

    Values beyond one end of the length axis are copied in from the opposite
    end (``jnp.pad(..., mode='wrap')``). An optional leading batch axis and
    the trailing channel axis are not padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        Input shape. When given, ``out_size`` is inferred abstractly with
        ``jax.eval_shape`` during construction.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        Raised by ``update`` (and therefore at construction when ``in_size``
        is given) if the input is neither 2D nor 3D.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad1d(2)
        >>> # (batch, length, channels) = (1, 5, 1)
        >>> input = jnp.array([[[1], [2], [3], [4], [5]]])
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
        >>> print(output[0, :, 0])
        [4 5 1 2 3 4 5 1 2]
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Trace update() abstractly (no computation) to infer the output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Circularly pad ``x`` along its length axis.

        Parameters
        ----------
        x : jax.Array
            Input of shape ``(length, channels)`` or
            ``(batch, length, channels)``.

        Returns
        -------
        jax.Array
            ``x`` with the length axis extended by the configured
            ``(before, after)`` amounts; other axes are unchanged.

        Raises
        ------
        ValueError
            If ``x`` is not 2D or 3D.
        """
        # Only the length axis is padded; batch/channel axes get (0, 0).
        ndim = x.ndim
        if ndim == 2:
            # (length, channels)
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels)
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")
        return jnp.pad(x, pad_width, mode='wrap')
899
+
900
+
901
class CircularPad2d(Module):
    """
    Pad the height and width axes of a channels-last tensor by wrapping around.

    Padded entries are copied from the opposite edge of each spatial axis
    (``jnp.pad(..., mode='wrap')``). An optional leading batch axis and the
    trailing channel axis are never padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        Amount of padding, normalized by ``_format_padding`` into one
        ``(before, after)`` pair per spatial axis. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad)
        - Sequence[int] of length 4: (left, right, top, bottom)

        The first normalized pair is applied to the height axis and the
        second to the width axis.
    in_size : Size, optional
        Input shape. When given, ``out_size`` is inferred abstractly with
        ``jax.eval_shape`` during construction.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        Raised by ``update`` (and therefore at construction when ``in_size``
        is given) if the input is neither 3D nor 4D.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # NOTE(review): confirm the 4-element ordering produced by
        # _format_padding matches update()'s height-then-width consumption.
        self.padding = _format_padding(padding, 2)
        if in_size is None:
            return
        self.in_size = in_size
        probe = jax.ShapeDtypeStruct(self.in_size, environ.dftype())
        self.out_size = jax.eval_shape(self.update, probe).shape

    def update(self, x):
        """
        Wrap-pad ``x`` of shape ``([batch,] height, width, channels)``.

        Raises
        ------
        ValueError
            If ``x`` is not 3D or 4D.
        """
        has_batch = x.ndim == 4
        if not has_batch and x.ndim != 3:
            raise ValueError(f"Expected 3D or 4D input, got {x.ndim}D")
        pad_width = [self.padding[0], self.padding[1], (0, 0)]
        if has_batch:
            # leading batch axis stays unpadded
            pad_width.insert(0, (0, 0))
        return jnp.pad(x, pad_width, mode='wrap')
960
+
961
+
962
class CircularPad3d(Module):
    """
    Pad the depth, height and width axes of a channels-last tensor by wrapping around.

    Padded entries are copied from the opposite edge of each spatial axis
    (``jnp.pad(..., mode='wrap')``). An optional leading batch axis and the
    trailing channel axis are never padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        Amount of padding, normalized by ``_format_padding`` into one
        ``(before, after)`` pair per spatial axis. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
        - Sequence[int] of length 6: (left, right, top, bottom, front, back)

        The normalized pairs are applied, in order, to the depth, height and
        width axes.
    in_size : Size, optional
        Input shape. When given, ``out_size`` is inferred abstractly with
        ``jax.eval_shape`` during construction.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        Raised by ``update`` (and therefore at construction when ``in_size``
        is given) if the input is neither 4D nor 5D.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            dtype = environ.dftype()
            out_spec = jax.eval_shape(
                self.update, jax.ShapeDtypeStruct(self.in_size, dtype)
            )
            self.out_size = out_spec.shape

    def update(self, x):
        """
        Wrap-pad ``x`` of shape ``([batch,] depth, height, width, channels)``.

        Raises
        ------
        ValueError
            If ``x`` is not 4D or 5D.
        """
        dims = x.ndim
        if dims not in (4, 5):
            raise ValueError(f"Expected 4D or 5D input, got {dims}D")
        # A 5D input has a leading batch axis that must not be padded.
        leading = ((0, 0),) if dims == 5 else ()
        pad_width = [
            *leading,
            self.padding[0],
            self.padding[1],
            self.padding[2],
            (0, 0),
        ]
        return jnp.pad(x, pad_width, mode='wrap')