brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,1020 +1,1020 @@
1
- # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Padding layers for neural networks.
18
-
19
- This module provides various padding operations for 1D, 2D, and 3D tensors:
20
- - ReflectionPad: Pads using reflection of the input boundary
21
- - ReplicationPad: Pads using replication of the input boundary
22
- - ZeroPad: Pads with zeros
23
- - ConstantPad: Pads with a constant value
24
- - CircularPad: Pads circularly (wrap around)
25
- """
26
-
27
- import functools
28
- from typing import Union, Sequence, Optional
29
-
30
- import jax
31
- import jax.numpy as jnp
32
-
33
- from brainstate import environ
34
- from brainstate.typing import Size
35
- from ._module import Module
36
-
37
__all__ = [
    # Reflection padding (jnp.pad mode='reflect')
    'ReflectionPad1d',
    'ReflectionPad2d',
    'ReflectionPad3d',
    # Replication padding (jnp.pad mode='edge')
    'ReplicationPad1d',
    'ReplicationPad2d',
    'ReplicationPad3d',
    # Zero padding (jnp.pad mode='constant', value 0)
    'ZeroPad1d',
    'ZeroPad2d',
    'ZeroPad3d',
    # Constant padding (jnp.pad mode='constant', user-supplied value)
    'ConstantPad1d',
    'ConstantPad2d',
    'ConstantPad3d',
    # Circular padding
    'CircularPad1d',
    'CircularPad2d',
    'CircularPad3d',
]
49
-
50
-
51
- def _format_padding(padding: Union[int, Sequence[int]], ndim: int) -> Sequence[tuple]:
52
- """
53
- Convert padding specification to format required by jax.numpy.pad.
54
-
55
- Args:
56
- padding: Padding size(s). Can be:
57
- - int: same padding for all sides
58
- - Sequence of length 2*ndim: (left, right) for each dimension
59
- - Sequence of length ndim: same padding for left and right of each dimension
60
- ndim: Number of spatial dimensions (1, 2, or 3)
61
-
62
- Returns:
63
- List of padding tuples for each dimension
64
- """
65
- if isinstance(padding, int):
66
- # Same padding for all sides of all dimensions
67
- return [(padding, padding) for _ in range(ndim)]
68
-
69
- padding = list(padding)
70
-
71
- if len(padding) == ndim:
72
- # Same padding for left and right of each dimension
73
- return [(p, p) for p in padding]
74
- elif len(padding) == 2 * ndim:
75
- # Different padding for each side: (left1, right1, left2, right2, ...)
76
- return [(padding[2 * i], padding[2 * i + 1]) for i in range(ndim)]
77
- else:
78
- raise ValueError(f"Padding must have length {ndim} or {2 * ndim}, got {len(padding)}")
79
-
80
-
81
- # =============================================================================
82
- # Reflection Padding
83
- # =============================================================================
84
-
85
class ReflectionPad1d(Module):
    """
    Pads the input tensor using the reflection of the input boundary.

    The input is expected in channels-last layout, either ``(length, channels)``
    or ``(batch, length, channels)``; only the length axis is padded
    (``jax.numpy.pad`` mode ``'reflect'``).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        The input shape, ``(length, channels)`` or ``(batch, length, channels)``.
        When given, the output shape is inferred with :func:`jax.eval_shape`
        and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad1d(2)
        >>> input = jnp.arange(1.0, 6.0).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Reflect-pad the length axis of ``x``.

        Parameters
        ----------
        x : Array
            Input of shape ``(length, channels)`` or ``(batch, length, channels)``.

        Returns
        -------
        Array
            The padded array; only the length axis grows.

        Raises
        ------
        ValueError
            If ``x`` is not 2D or 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='reflect')
143
-
144
-
145
class ReflectionPad2d(Module):
    """
    Pads the input tensor using the reflection of the input boundary.

    The input is expected in channels-last layout, either
    ``(height, width, channels)`` or ``(batch, height, width, channels)``;
    only the height and width axes are padded (``jax.numpy.pad`` mode
    ``'reflect'``).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied to
          both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right)

        Pairs are ordered from the first spatial axis (height) to the last
        (width). Note that this differs from PyTorch's
        ``(left, right, top, bottom)`` order.
    in_size : Size, optional
        The input shape, ``(height, width, channels)`` or
        ``(batch, height, width, channels)``. When given, the output shape is
        inferred with :func:`jax.eval_shape` and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Reflect-pad the height and width axes of ``x``.

        Parameters
        ----------
        x : Array
            Input of shape ``(height, width, channels)`` or
            ``(batch, height, width, channels)``.

        Returns
        -------
        Array
            The padded array; only the height and width axes grow.

        Raises
        ------
        ValueError
            If ``x`` is not 3D or 4D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='reflect')
204
-
205
-
206
class ReflectionPad3d(Module):
    """
    Pads the input tensor using the reflection of the input boundary.

    The input is expected in channels-last layout, either
    ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``; only the depth, height and
    width axes are padded (``jax.numpy.pad`` mode ``'reflect'``).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right)

        Pairs are ordered from the first spatial axis (depth) to the last
        (width). Note that this differs from PyTorch's
        ``(left, right, top, bottom, front, back)`` order.
    in_size : Size, optional
        The input shape, ``(depth, height, width, channels)`` or
        ``(batch, depth, height, width, channels)``. When given, the output
        shape is inferred with :func:`jax.eval_shape` and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Reflect-pad the depth, height and width axes of ``x``.

        Parameters
        ----------
        x : Array
            Input of shape ``(depth, height, width, channels)`` or
            ``(batch, depth, height, width, channels)``.

        Returns
        -------
        Array
            The padded array; only the three spatial axes grow.

        Raises
        ------
        ValueError
            If ``x`` is not 4D or 5D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='reflect')
265
-
266
-
267
- # =============================================================================
268
- # Replication Padding
269
- # =============================================================================
270
-
271
class ReplicationPad1d(Module):
    """
    Pads the input tensor using replication of the input boundary.

    The input is expected in channels-last layout, either ``(length, channels)``
    or ``(batch, length, channels)``; only the length axis is padded
    (``jax.numpy.pad`` mode ``'edge'``).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        The input shape, ``(length, channels)`` or ``(batch, length, channels)``.
        When given, the output shape is inferred with :func:`jax.eval_shape`
        and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad1d(2)
        >>> input = jnp.arange(1.0, 6.0).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Edge-replicate-pad the length axis of ``x``.

        Parameters
        ----------
        x : Array
            Input of shape ``(length, channels)`` or ``(batch, length, channels)``.

        Returns
        -------
        Array
            The padded array; only the length axis grows.

        Raises
        ------
        ValueError
            If ``x`` is not 2D or 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='edge')
329
-
330
-
331
class ReplicationPad2d(Module):
    """
    Pads the input tensor using replication of the input boundary.

    The input is expected in channels-last layout, either
    ``(height, width, channels)`` or ``(batch, height, width, channels)``;
    only the height and width axes are padded (``jax.numpy.pad`` mode
    ``'edge'``).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied to
          both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right)

        Pairs are ordered from the first spatial axis (height) to the last
        (width). Note that this differs from PyTorch's
        ``(left, right, top, bottom)`` order.
    in_size : Size, optional
        The input shape, ``(height, width, channels)`` or
        ``(batch, height, width, channels)``. When given, the output shape is
        inferred with :func:`jax.eval_shape` and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Edge-replicate-pad the height and width axes of ``x``.

        Parameters
        ----------
        x : Array
            Input of shape ``(height, width, channels)`` or
            ``(batch, height, width, channels)``.

        Returns
        -------
        Array
            The padded array; only the height and width axes grow.

        Raises
        ------
        ValueError
            If ``x`` is not 3D or 4D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='edge')
390
-
391
-
392
class ReplicationPad3d(Module):
    """
    Pads the input tensor using replication of the input boundary.

    The input is expected in channels-last layout, either
    ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``; only the depth, height and
    width axes are padded (``jax.numpy.pad`` mode ``'edge'``).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right)

        Pairs are ordered from the first spatial axis (depth) to the last
        (width). Note that this differs from PyTorch's
        ``(left, right, top, bottom, front, back)`` order.
    in_size : Size, optional
        The input shape, ``(depth, height, width, channels)`` or
        ``(batch, depth, height, width, channels)``. When given, the output
        shape is inferred with :func:`jax.eval_shape` and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Edge-replicate-pad the depth, height and width axes of ``x``.

        Parameters
        ----------
        x : Array
            Input of shape ``(depth, height, width, channels)`` or
            ``(batch, depth, height, width, channels)``.

        Returns
        -------
        Array
            The padded array; only the three spatial axes grow.

        Raises
        ------
        ValueError
            If ``x`` is not 4D or 5D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='edge')
451
-
452
-
453
- # =============================================================================
454
- # Zero Padding
455
- # =============================================================================
456
-
457
class ZeroPad1d(Module):
    """
    Pads the input tensor with zeros.

    The input is expected in channels-last layout, either ``(length, channels)``
    or ``(batch, length, channels)``; only the length axis is padded
    (``jax.numpy.pad`` mode ``'constant'`` with value 0).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        The input shape, ``(length, channels)`` or ``(batch, length, channels)``.
        When given, the output shape is inferred with :func:`jax.eval_shape`
        and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad1d(2)
        >>> input = jnp.arange(1.0, 6.0).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Zero-pad the length axis of ``x``.

        Parameters
        ----------
        x : Array
            Input of shape ``(length, channels)`` or ``(batch, length, channels)``.

        Returns
        -------
        Array
            The padded array; only the length axis grows.

        Raises
        ------
        ValueError
            If ``x`` is not 2D or 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
515
-
516
-
517
class ZeroPad2d(Module):
    """
    Pads the input tensor with zeros.

    The input is expected in channels-last layout, either
    ``(height, width, channels)`` or ``(batch, height, width, channels)``;
    only the height and width axes are padded (``jax.numpy.pad`` mode
    ``'constant'`` with value 0).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied to
          both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right)

        Pairs are ordered from the first spatial axis (height) to the last
        (width). Note that this differs from PyTorch's
        ``(left, right, top, bottom)`` order.
    in_size : Size, optional
        The input shape, ``(height, width, channels)`` or
        ``(batch, height, width, channels)``. When given, the output shape is
        inferred with :func:`jax.eval_shape` and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Zero-pad the height and width axes of ``x``.

        Parameters
        ----------
        x : Array
            Input of shape ``(height, width, channels)`` or
            ``(batch, height, width, channels)``.

        Returns
        -------
        Array
            The padded array; only the height and width axes grow.

        Raises
        ------
        ValueError
            If ``x`` is not 3D or 4D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
576
-
577
-
578
class ZeroPad3d(Module):
    """
    Pads the input tensor with zeros.

    The input is expected in channels-last layout, either
    ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``; only the depth, height and
    width axes are padded (``jax.numpy.pad`` mode ``'constant'`` with value 0).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right)

        Pairs are ordered from the first spatial axis (depth) to the last
        (width). Note that this differs from PyTorch's
        ``(left, right, top, bottom, front, back)`` order.
    in_size : Size, optional
        The input shape, ``(depth, height, width, channels)`` or
        ``(batch, depth, height, width, channels)``. When given, the output
        shape is inferred with :func:`jax.eval_shape` and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Zero-pad the depth, height and width axes of ``x``.

        Parameters
        ----------
        x : Array
            Input of shape ``(depth, height, width, channels)`` or
            ``(batch, depth, height, width, channels)``.

        Returns
        -------
        Array
            The padded array; only the three spatial axes grow.

        Raises
        ------
        ValueError
            If ``x`` is not 4D or 5D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
637
-
638
-
639
- # =============================================================================
640
- # Constant Padding
641
- # =============================================================================
642
-
643
class ConstantPad1d(Module):
    """
    Pads the input tensor with a constant value.

    The input is expected in channels-last layout, either ``(length, channels)``
    or ``(batch, length, channels)``; only the length axis is padded
    (``jax.numpy.pad`` mode ``'constant'``).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input shape, ``(length, channels)`` or ``(batch, length, channels)``.
        When given, the output shape is inferred with :func:`jax.eval_shape`
        and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad1d(2, value=3.5)
        >>> input = jnp.arange(1.0, 6.0).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 1)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the length axis of ``x`` with ``self.value``.

        Parameters
        ----------
        x : Array
            Input of shape ``(length, channels)`` or ``(batch, length, channels)``.

        Returns
        -------
        Array
            The padded array; only the length axis grows.

        Raises
        ------
        ValueError
            If ``x`` is not 2D or 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
705
-
706
-
707
class ConstantPad2d(Module):
    """
    Pads the input tensor with a constant value.

    The input is expected in channels-last layout, either
    ``(height, width, channels)`` or ``(batch, height, width, channels)``;
    only the height and width axes are padded (``jax.numpy.pad`` mode
    ``'constant'``).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied to
          both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right)

        Pairs are ordered from the first spatial axis (height) to the last
        (width). Note that this differs from PyTorch's
        ``(left, right, top, bottom)`` order.
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input shape, ``(height, width, channels)`` or
        ``(batch, height, width, channels)``. When given, the output shape is
        inferred with :func:`jax.eval_shape` and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad2d(1, value=3.5)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 2)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the height and width axes of ``x`` with ``self.value``.

        Parameters
        ----------
        x : Array
            Input of shape ``(height, width, channels)`` or
            ``(batch, height, width, channels)``.

        Returns
        -------
        Array
            The padded array; only the height and width axes grow.

        Raises
        ------
        ValueError
            If ``x`` is not 3D or 4D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
770
-
771
-
772
class ConstantPad3d(Module):
    """
    Pads the input tensor with a constant value.

    The input is expected in channels-last layout, either
    ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``; only the depth, height and
    width axes are padded (``jax.numpy.pad`` mode ``'constant'``).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right)

        Pairs are ordered from the first spatial axis (depth) to the last
        (width). Note that this differs from PyTorch's
        ``(left, right, top, bottom, front, back)`` order.
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input shape, ``(depth, height, width, channels)`` or
        ``(batch, depth, height, width, channels)``. When given, the output
        shape is inferred with :func:`jax.eval_shape` and stored in ``out_size``.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad3d(1, value=3.5)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 3)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Trace ``update`` on an abstract array (nothing is allocated or
            # computed) just to learn the padded output shape.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the depth, height and width axes of ``x`` with ``self.value``.

        Parameters
        ----------
        x : Array
            Input of shape ``(depth, height, width, channels)`` or
            ``(batch, depth, height, width, channels)``.

        Returns
        -------
        Array
            The padded array; only the three spatial axes grow.

        Raises
        ------
        ValueError
            If ``x`` is not 4D or 5D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
835
-
836
-
837
- # =============================================================================
838
- # Circular Padding
839
- # =============================================================================
840
-
841
class CircularPad1d(Module):
    """
    Pads the input tensor using circular padding (wrap around).

    Inputs are channels-last, either ``(length, channels)`` or
    ``(batch, length, channels)``; only the length axis is padded. The work is
    done by ``jax.numpy.pad`` with ``mode='wrap'``, which fills each side with
    values taken from the opposite end of the axis, e.g. the sequence
    ``[1, 2, 3, 4, 5]`` padded by 2 becomes ``[4, 5, 1, 2, 3, 4, 5, 1, 2]``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: (pad,), applied to both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad1d(2)
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair for the single spatial (length) axis.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the length axis of ``x`` and return the result.

        ``x`` must be ``(length, channels)`` or ``(batch, length, channels)``;
        any other rank raises ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='wrap')
899
-
900
-
901
class CircularPad2d(Module):
    """
    Pads the input tensor using circular padding (wrap around).

    Inputs are channels-last, either ``(height, width, channels)`` or
    ``(batch, height, width, channels)``; only the height and width axes are
    padded, via ``jax.numpy.pad`` with ``mode='wrap'`` (each side is filled
    with values taken from the opposite end of its axis).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied to
          both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right), i.e.
          (height_before, height_after, width_before, width_after). Note that
          this differs from PyTorch's (left, right, top, bottom) order.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: height, width.
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the height and width axes of ``x`` and return the result.

        ``x`` must be ``(height, width, channels)`` or
        ``(batch, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='wrap')
960
-
961
-
962
class CircularPad3d(Module):
    """
    Pads the input tensor using circular padding (wrap around).

    Inputs are channels-last, either ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``; only the depth, height and
    width axes are padded, via ``jax.numpy.pad`` with ``mode='wrap'`` (each
    side is filled with values taken from the opposite end of its axis).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right),
          i.e. (depth_before, depth_after, height_before, height_after,
          width_before, width_after). Note that this differs from PyTorch's
          (left, right, top, bottom, front, back) order.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: depth, height, width.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the depth, height and width axes of ``x`` and return the result.

        ``x`` must be ``(depth, height, width, channels)`` or
        ``(batch, depth, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='wrap')
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Padding layers for neural networks.
18
+
19
+ This module provides various padding operations for 1D, 2D, and 3D tensors:
20
+ - ReflectionPad: Pads using reflection of the input boundary
21
+ - ReplicationPad: Pads using replication of the input boundary
22
+ - ZeroPad: Pads with zeros
23
+ - ConstantPad: Pads with a constant value
24
+ - CircularPad: Pads circularly (wrap around)
25
+ """
26
+
27
import functools
import numbers
from typing import Union, Sequence, Optional

import jax
import jax.numpy as jnp

from brainstate import environ
from brainstate.typing import Size
from ._module import Module
36
+
37
# Public API of this module: the names exported by ``from ... import *``.
# Each padding mode comes in 1D, 2D and 3D variants.
__all__ = [
    # Reflection padding
    'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d',
    # Replication padding
    'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d',
    # Zero padding
    'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d',
    # Constant padding
    'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d',
    # Circular padding
    'CircularPad1d', 'CircularPad2d', 'CircularPad3d',
]
49
+
50
+
51
+ def _format_padding(padding: Union[int, Sequence[int]], ndim: int) -> Sequence[tuple]:
52
+ """
53
+ Convert padding specification to format required by jax.numpy.pad.
54
+
55
+ Args:
56
+ padding: Padding size(s). Can be:
57
+ - int: same padding for all sides
58
+ - Sequence of length 2*ndim: (left, right) for each dimension
59
+ - Sequence of length ndim: same padding for left and right of each dimension
60
+ ndim: Number of spatial dimensions (1, 2, or 3)
61
+
62
+ Returns:
63
+ List of padding tuples for each dimension
64
+ """
65
+ if isinstance(padding, int):
66
+ # Same padding for all sides of all dimensions
67
+ return [(padding, padding) for _ in range(ndim)]
68
+
69
+ padding = list(padding)
70
+
71
+ if len(padding) == ndim:
72
+ # Same padding for left and right of each dimension
73
+ return [(p, p) for p in padding]
74
+ elif len(padding) == 2 * ndim:
75
+ # Different padding for each side: (left1, right1, left2, right2, ...)
76
+ return [(padding[2 * i], padding[2 * i + 1]) for i in range(ndim)]
77
+ else:
78
+ raise ValueError(f"Padding must have length {ndim} or {2 * ndim}, got {len(padding)}")
79
+
80
+
81
+ # =============================================================================
82
+ # Reflection Padding
83
+ # =============================================================================
84
+
85
class ReflectionPad1d(Module):
    """
    Pads the input tensor using the reflection of the input boundary.

    Inputs are channels-last, either ``(length, channels)`` or
    ``(batch, length, channels)``; only the length axis is padded. The work is
    done by ``jax.numpy.pad`` with ``mode='reflect'``, which mirrors values
    about the boundary without repeating the boundary element, e.g. the
    sequence ``[1, 2, 3, 4, 5]`` padded by 2 becomes ``[3, 2, 1, 2, 3, 4, 5, 4, 3]``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: (pad,), applied to both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad1d(2)
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair for the single spatial (length) axis.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the length axis of ``x`` and return the result.

        ``x`` must be ``(length, channels)`` or ``(batch, length, channels)``;
        any other rank raises ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='reflect')
143
+
144
+
145
class ReflectionPad2d(Module):
    """
    Pads the input tensor using the reflection of the input boundary.

    Inputs are channels-last, either ``(height, width, channels)`` or
    ``(batch, height, width, channels)``; only the height and width axes are
    padded, via ``jax.numpy.pad`` with ``mode='reflect'`` (values are mirrored
    about the boundary without repeating the boundary element).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied to
          both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right), i.e.
          (height_before, height_after, width_before, width_after). Note that
          this differs from PyTorch's (left, right, top, bottom) order.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: height, width.
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the height and width axes of ``x`` and return the result.

        ``x`` must be ``(height, width, channels)`` or
        ``(batch, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='reflect')
204
+
205
+
206
class ReflectionPad3d(Module):
    """
    Pads the input tensor using the reflection of the input boundary.

    Inputs are channels-last, either ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``; only the depth, height and
    width axes are padded, via ``jax.numpy.pad`` with ``mode='reflect'``
    (values are mirrored about the boundary without repeating the boundary
    element).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right),
          i.e. (depth_before, depth_after, height_before, height_after,
          width_before, width_after). Note that this differs from PyTorch's
          (left, right, top, bottom, front, back) order.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: depth, height, width.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the depth, height and width axes of ``x`` and return the result.

        ``x`` must be ``(depth, height, width, channels)`` or
        ``(batch, depth, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='reflect')
265
+
266
+
267
+ # =============================================================================
268
+ # Replication Padding
269
+ # =============================================================================
270
+
271
class ReplicationPad1d(Module):
    """
    Pads the input tensor using replication of the input boundary.

    Inputs are channels-last, either ``(length, channels)`` or
    ``(batch, length, channels)``; only the length axis is padded. The work is
    done by ``jax.numpy.pad`` with ``mode='edge'``, which repeats the outermost
    value on each side, e.g. the sequence ``[1, 2, 3, 4, 5]`` padded by 2
    becomes ``[1, 1, 1, 2, 3, 4, 5, 5, 5]``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: (pad,), applied to both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad1d(2)
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair for the single spatial (length) axis.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the length axis of ``x`` and return the result.

        ``x`` must be ``(length, channels)`` or ``(batch, length, channels)``;
        any other rank raises ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='edge')
329
+
330
+
331
class ReplicationPad2d(Module):
    """
    Pads the input tensor using replication of the input boundary.

    Inputs are channels-last, either ``(height, width, channels)`` or
    ``(batch, height, width, channels)``; only the height and width axes are
    padded, via ``jax.numpy.pad`` with ``mode='edge'`` (the outermost value of
    each padded axis is repeated).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied to
          both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right), i.e.
          (height_before, height_after, width_before, width_after). Note that
          this differs from PyTorch's (left, right, top, bottom) order.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: height, width.
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the height and width axes of ``x`` and return the result.

        ``x`` must be ``(height, width, channels)`` or
        ``(batch, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='edge')
390
+
391
+
392
class ReplicationPad3d(Module):
    """
    Pads the input tensor using replication of the input boundary.

    Inputs are channels-last, either ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``; only the depth, height and
    width axes are padded, via ``jax.numpy.pad`` with ``mode='edge'`` (the
    outermost value of each padded axis is repeated).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right),
          i.e. (depth_before, depth_after, height_before, height_after,
          width_before, width_after). Note that this differs from PyTorch's
          (left, right, top, bottom, front, back) order.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: depth, height, width.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the depth, height and width axes of ``x`` and return the result.

        ``x`` must be ``(depth, height, width, channels)`` or
        ``(batch, depth, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='edge')
451
+
452
+
453
+ # =============================================================================
454
+ # Zero Padding
455
+ # =============================================================================
456
+
457
class ZeroPad1d(Module):
    """
    Pads the input tensor with zeros.

    Inputs are channels-last, either ``(length, channels)`` or
    ``(batch, length, channels)``; only the length axis is padded. The work is
    done by ``jax.numpy.pad`` with ``mode='constant'`` and a fill value of 0,
    e.g. the sequence ``[1, 2, 3, 4, 5]`` padded by 2 becomes
    ``[0, 0, 1, 2, 3, 4, 5, 0, 0]``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: (pad,), applied to both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad1d(2)
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair for the single spatial (length) axis.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the length axis of ``x`` and return the result.

        ``x`` must be ``(length, channels)`` or ``(batch, length, channels)``;
        any other rank raises ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
515
+
516
+
517
class ZeroPad2d(Module):
    """
    Pads the input tensor with zeros.

    Inputs are channels-last, either ``(height, width, channels)`` or
    ``(batch, height, width, channels)``; only the height and width axes are
    padded, via ``jax.numpy.pad`` with ``mode='constant'`` and a fill value of 0.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied to
          both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right), i.e.
          (height_before, height_after, width_before, width_after). Note that
          this differs from PyTorch's (left, right, top, bottom) order.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: height, width.
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the height and width axes of ``x`` and return the result.

        ``x`` must be ``(height, width, channels)`` or
        ``(batch, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
576
+
577
+
578
class ZeroPad3d(Module):
    """
    Pads the input tensor with zeros.

    Inputs are channels-last, either ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``; only the depth, height and
    width axes are padded, via ``jax.numpy.pad`` with ``mode='constant'`` and
    a fill value of 0.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right),
          i.e. (depth_before, depth_after, height_before, height_after,
          width_before, width_after). Note that this differs from PyTorch's
          (left, right, top, bottom, front, back) order.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: depth, height, width.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the depth, height and width axes of ``x`` and return the result.

        ``x`` must be ``(depth, height, width, channels)`` or
        ``(batch, depth, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
637
+
638
+
639
+ # =============================================================================
640
+ # Constant Padding
641
+ # =============================================================================
642
+
643
class ConstantPad1d(Module):
    """
    Pads the input tensor with a constant value.

    Inputs are channels-last, either ``(length, channels)`` or
    ``(batch, length, channels)``; only the length axis is padded. The work is
    done by ``jax.numpy.pad`` with ``mode='constant'`` and ``value`` as the
    fill, e.g. the sequence ``[1, 2, 3, 4, 5]`` padded by 2 becomes
    ``[v, v, 1, 2, 3, 4, 5, v, v]`` where ``v`` is ``value``.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: (pad,), applied to both sides
        - Sequence[int] of length 2: (left, right)
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad1d(2, value=3.5)
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair for the single spatial (length) axis.
        self.padding = _format_padding(padding, 1)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the length axis of ``x`` with ``self.value`` and return the result.

        ``x`` must be ``(length, channels)`` or ``(batch, length, channels)``;
        any other rank raises ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
705
+
706
+
707
class ConstantPad2d(Module):
    """
    Pads the input tensor with a constant value.

    Inputs are channels-last, either ``(height, width, channels)`` or
    ``(batch, height, width, channels)``; only the height and width axes are
    padded, via ``jax.numpy.pad`` with ``mode='constant'`` and ``value`` as
    the fill.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad), each applied to
          both sides of its axis
        - Sequence[int] of length 4: (top, bottom, left, right), i.e.
          (height_before, height_after, width_before, width_after). Note that
          this differs from PyTorch's (left, right, top, bottom) order.
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad2d(1, value=3.5)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: height, width.
        self.padding = _format_padding(padding, 2)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the height and width axes of ``x`` with ``self.value`` and return the result.

        ``x`` must be ``(height, width, channels)`` or
        ``(batch, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
770
+
771
+
772
class ConstantPad3d(Module):
    """
    Pads the input tensor with a constant value.

    Inputs are channels-last, either ``(depth, height, width, channels)`` or
    ``(batch, depth, height, width, channels)``; only the depth, height and
    width axes are padded, via ``jax.numpy.pad`` with ``mode='constant'`` and
    ``value`` as the fill.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad), each
          applied to both sides of its axis
        - Sequence[int] of length 6: (front, back, top, bottom, left, right),
          i.e. (depth_before, depth_after, height_before, height_after,
          width_before, width_after). Note that this differs from PyTorch's
          (left, right, top, bottom, front, back) order.
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input size. When given, ``out_size`` is computed at construction
        by shape-evaluating :meth:`update` with ``jax.eval_shape``.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        If ``padding`` does not have one of the lengths listed above.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad3d(1, value=3.5)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # One ``(before, after)`` pair per spatial axis: depth, height, width.
        self.padding = _format_padding(padding, 3)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Infer the output shape abstractly; no data is allocated.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Pad the depth, height and width axes of ``x`` with ``self.value``.

        ``x`` must be ``(depth, height, width, channels)`` or
        ``(batch, depth, height, width, channels)``; any other rank raises
        ``ValueError``. Batch and channel axes are not padded.
        """
        n_spatial = len(self.padding)
        if x.ndim == n_spatial + 1:
            # Unbatched: spatial axes, then channels.
            pad_width = [*self.padding, (0, 0)]
        elif x.ndim == n_spatial + 2:
            # Batched: batch axis, spatial axes, then channels.
            pad_width = [(0, 0), *self.padding, (0, 0)]
        else:
            raise ValueError(f"Expected {n_spatial + 1}D or {n_spatial + 2}D input, got {x.ndim}D")
        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
835
+
836
+
837
+ # =============================================================================
838
+ # Circular Padding
839
+ # =============================================================================
840
+
841
class CircularPad1d(Module):
    """
    Pads the input tensor using circular padding (wrap around).

    The input is expected in channels-last layout, either unbatched
    ``(length, channels)`` or batched ``(batch, length, channels)``.
    Only the length axis is padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        The input size. When given, ``out_size`` is inferred by abstractly
        evaluating :meth:`update` on an input of this shape.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        Raised by :meth:`update` (and therefore by the constructor when
        ``in_size`` is given) if the input is neither 2D nor 3D.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad1d(2)
        >>> x = jnp.array([[[1], [2], [3], [4], [5]]])  # (batch=1, length=5, channels=1)
        >>> output = pad(x)
        >>> print(output.shape)
        (1, 9, 1)
        >>> print(output[0, :, 0])
        [4 5 1 2 3 4 5 1 2]
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only trace of ``update``: infers ``out_size`` without
            # materializing an input array.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Circularly pad the length axis of ``x``.

        Parameters
        ----------
        x : jax.Array
            A 2D ``(L, C)`` or 3D ``(B, L, C)`` array.

        Returns
        -------
        jax.Array
            The padded array.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D nor 3D.
        """
        ndim = x.ndim
        if ndim not in (2, 3):
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")
        # Length axis gets the configured padding; the trailing channel axis gets none.
        pad_width = [self.padding[0], (0, 0)]
        if ndim == 3:
            # Leading batch axis is never padded.
            pad_width.insert(0, (0, 0))
        return jnp.pad(x, pad_width, mode='wrap')
899
+
900
+
901
class CircularPad2d(Module):
    """
    Pads the input tensor using circular padding (wrap around).

    The input is expected in channels-last layout, either unbatched
    ``(height, width, channels)`` or batched
    ``(batch, height, width, channels)``. Only the two spatial axes are
    padded; the batch and channel axes are left untouched.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad)
        - Sequence[int] of length 4: (left, right, top, bottom)

        After normalization, entries ``0`` and ``1`` of the normalized
        padding are applied to the height and width axes respectively.
    in_size : Size, optional
        The input size. When given, ``out_size`` is inferred by abstractly
        evaluating :meth:`update` on an input of this shape.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        Raised by :meth:`update` (and therefore by the constructor when
        ``in_size`` is given) if the input is neither 3D nor 4D.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad2d(1)
        >>> x = jnp.ones((1, 4, 4, 3))
        >>> output = pad(x)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only trace of ``update``: infers ``out_size`` without
            # materializing an input array.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Circularly pad the height and width axes of ``x``.

        Parameters
        ----------
        x : jax.Array
            A 3D ``(H, W, C)`` or 4D ``(B, H, W, C)`` array.

        Returns
        -------
        jax.Array
            The padded array.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D nor 4D.
        """
        ndim = x.ndim
        if ndim not in (3, 4):
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")
        # NOTE(review): entry 0 is applied to height. Confirm that
        # _format_padding groups a length-4 sequence so that this matches the
        # documented (left, right, top, bottom) order.
        pad_width = [self.padding[0], self.padding[1], (0, 0)]
        if ndim == 4:
            # Leading batch axis is never padded.
            pad_width.insert(0, (0, 0))
        return jnp.pad(x, pad_width, mode='wrap')
960
+
961
+
962
class CircularPad3d(Module):
    """
    Pads the input tensor using circular padding (wrap around).

    The input is expected in channels-last layout, either unbatched
    ``(depth, height, width, channels)`` or batched
    ``(batch, depth, height, width, channels)``. Only the three spatial
    axes are padded; the batch and channel axes are left untouched.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
        - Sequence[int] of length 6: (left, right, top, bottom, front, back)

        After normalization, entries ``0``, ``1`` and ``2`` of the
        normalized padding are applied to the depth, height and width
        axes respectively.
    in_size : Size, optional
        The input size. When given, ``out_size`` is inferred by abstractly
        evaluating :meth:`update` on an input of this shape.
    name : str, optional
        The name of the module.

    Raises
    ------
    ValueError
        Raised by :meth:`update` (and therefore by the constructor when
        ``in_size`` is given) if the input is neither 4D nor 5D.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad3d(1)
        >>> x = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(x)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only trace of ``update``: infers ``out_size`` without
            # materializing an input array.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Circularly pad the depth, height and width axes of ``x``.

        Parameters
        ----------
        x : jax.Array
            A 4D ``(D, H, W, C)`` or 5D ``(B, D, H, W, C)`` array.

        Returns
        -------
        jax.Array
            The padded array.

        Raises
        ------
        ValueError
            If ``x`` is neither 4D nor 5D.
        """
        ndim = x.ndim
        if ndim not in (4, 5):
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")
        # NOTE(review): entry 0 is applied to depth. Confirm that
        # _format_padding groups a length-6 sequence so that this matches the
        # documented (left, right, top, bottom, front, back) order.
        pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        if ndim == 5:
            # Leading batch axis is never padded.
            pad_width.insert(0, (0, 0))
        return jnp.pad(x, pad_width, mode='wrap')