brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_paddings.py
CHANGED
@@ -1,1020 +1,1020 @@
|
|
1
|
-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
Padding layers for neural networks.
|
18
|
-
|
19
|
-
This module provides various padding operations for 1D, 2D, and 3D tensors:
|
20
|
-
- ReflectionPad: Pads using reflection of the input boundary
|
21
|
-
- ReplicationPad: Pads using replication of the input boundary
|
22
|
-
- ZeroPad: Pads with zeros
|
23
|
-
- ConstantPad: Pads with a constant value
|
24
|
-
- CircularPad: Pads circularly (wrap around)
|
25
|
-
"""
|
26
|
-
|
27
|
-
import functools
|
28
|
-
from typing import Union, Sequence, Optional
|
29
|
-
|
30
|
-
import jax
|
31
|
-
import jax.numpy as jnp
|
32
|
-
|
33
|
-
from brainstate import environ
|
34
|
-
from brainstate.typing import Size
|
35
|
-
from ._module import Module
|
36
|
-
|
37
|
-
__all__ = [
|
38
|
-
# Reflection padding
|
39
|
-
'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d',
|
40
|
-
# Replication padding
|
41
|
-
'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d',
|
42
|
-
# Zero padding
|
43
|
-
'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d',
|
44
|
-
# Constant padding
|
45
|
-
'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d',
|
46
|
-
# Circular padding
|
47
|
-
'CircularPad1d', 'CircularPad2d', 'CircularPad3d',
|
48
|
-
]
|
49
|
-
|
50
|
-
|
51
|
-
def _format_padding(padding: Union[int, Sequence[int]], ndim: int) -> Sequence[tuple]:
    """
    Convert a padding specification into per-dimension ``(before, after)`` pairs.

    The result has one tuple per spatial dimension, in the same order as the
    spatial dimensions themselves, and can be spliced into the ``pad_width``
    argument of ``jax.numpy.pad``.

    Args:
        padding: Padding size(s). Can be:
            - integer scalar (``int`` or any ``numbers.Integral``, e.g. a
              NumPy integer): same padding on both sides of every dimension
            - Sequence of length ``ndim``: one value per dimension, applied to
              both sides of that dimension
            - Sequence of length ``2 * ndim``: ``(before, after)`` for each
              dimension, flattened, i.e. ``(before_0, after_0, before_1, ...)``
        ndim: Number of spatial dimensions (1, 2, or 3)

    Returns:
        List of ``ndim`` padding tuples, one per spatial dimension

    Raises:
        ValueError: If ``padding`` is a sequence whose length is neither
            ``ndim`` nor ``2 * ndim``.
    """
    # Local import so this helper does not depend on a module-level import.
    import numbers

    # Accept every integral scalar, not only the builtin ``int``; otherwise a
    # NumPy integer would fall through to ``list(padding)`` and raise TypeError.
    if isinstance(padding, numbers.Integral):
        size = int(padding)
        return [(size, size) for _ in range(ndim)]

    padding = list(padding)
    count = len(padding)

    if count == ndim:
        # Same padding before and after each dimension
        return [(p, p) for p in padding]

    if count == 2 * ndim:
        # Flattened pairs: (before_0, after_0, before_1, after_1, ...)
        return [(padding[2 * i], padding[2 * i + 1]) for i in range(ndim)]

    raise ValueError(f"Padding must have length {ndim} or {2 * ndim}, got {count}")
|
79
|
-
|
80
|
-
|
81
|
-
# =============================================================================
|
82
|
-
# Reflection Padding
|
83
|
-
# =============================================================================
|
84
|
-
|
85
|
-
class ReflectionPad1d(Module):
    """
    Pads the length axis of the input using the reflection of the input boundary.

    The input is treated as channels-last: ``(length, channels)`` or
    ``(batch, length, channels)``. Only the length axis is padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: ``(pad,)``, same padding for both sides
        - Sequence[int] of length 2: ``(left, right)``
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad1d(2)
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to a list holding a single (left, right) tuple.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its length axis reflect-padded.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D nor 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='reflect')
|
143
|
-
|
144
|
-
|
145
|
-
class ReflectionPad2d(Module):
    """
    Pads the height and width axes using the reflection of the input boundary.

    The input is treated as channels-last: ``(height, width, channels)`` or
    ``(batch, height, width, channels)``. Batch and channel axes are not padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: ``(height_pad, width_pad)``, each applied
          to both sides of its axis
        - Sequence[int] of length 4: ``(height_before, height_after,
          width_before, width_after)``; the first pair pads the height axis
          and the second pair pads the width axis
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to [(height_before, height_after), (width_before, width_after)].
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its height and width axes reflect-padded.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D nor 4D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='reflect')
|
204
|
-
|
205
|
-
|
206
|
-
class ReflectionPad3d(Module):
    """
    Pads the depth, height and width axes using the reflection of the input boundary.

    The input is treated as channels-last: ``(depth, height, width, channels)``
    or ``(batch, depth, height, width, channels)``. Batch and channel axes are
    not padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: ``(depth_pad, height_pad, width_pad)``,
          each applied to both sides of its axis
        - Sequence[int] of length 6: ``(depth_before, depth_after,
          height_before, height_after, width_before, width_after)``; the
          pairs pad the depth, height and width axes in that order
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReflectionPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to three (before, after) tuples: depth, height, width.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its depth, height and width axes reflect-padded.

        Raises
        ------
        ValueError
            If ``x`` is neither 4D nor 5D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='reflect')
|
265
|
-
|
266
|
-
|
267
|
-
# =============================================================================
|
268
|
-
# Replication Padding
|
269
|
-
# =============================================================================
|
270
|
-
|
271
|
-
class ReplicationPad1d(Module):
    """
    Pads the length axis of the input using replication of the input boundary.

    The input is treated as channels-last: ``(length, channels)`` or
    ``(batch, length, channels)``. Only the length axis is padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: ``(pad,)``, same padding for both sides
        - Sequence[int] of length 2: ``(left, right)``
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad1d(2)
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to a list holding a single (left, right) tuple.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its length axis padded by repeating the edge values.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D nor 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='edge')
|
329
|
-
|
330
|
-
|
331
|
-
class ReplicationPad2d(Module):
    """
    Pads the height and width axes using replication of the input boundary.

    The input is treated as channels-last: ``(height, width, channels)`` or
    ``(batch, height, width, channels)``. Batch and channel axes are not padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: ``(height_pad, width_pad)``, each applied
          to both sides of its axis
        - Sequence[int] of length 4: ``(height_before, height_after,
          width_before, width_after)``; the first pair pads the height axis
          and the second pair pads the width axis
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to [(height_before, height_after), (width_before, width_after)].
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its height and width axes padded by repeating the edge values.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D nor 4D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='edge')
|
390
|
-
|
391
|
-
|
392
|
-
class ReplicationPad3d(Module):
    """
    Pads the depth, height and width axes using replication of the input boundary.

    The input is treated as channels-last: ``(depth, height, width, channels)``
    or ``(batch, depth, height, width, channels)``. Batch and channel axes are
    not padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: ``(depth_pad, height_pad, width_pad)``,
          each applied to both sides of its axis
        - Sequence[int] of length 6: ``(depth_before, depth_after,
          height_before, height_after, width_before, width_after)``; the
          pairs pad the depth, height and width axes in that order
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ReplicationPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to three (before, after) tuples: depth, height, width.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its depth, height and width axes padded by repeating the edge values.

        Raises
        ------
        ValueError
            If ``x`` is neither 4D nor 5D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='edge')
|
451
|
-
|
452
|
-
|
453
|
-
# =============================================================================
|
454
|
-
# Zero Padding
|
455
|
-
# =============================================================================
|
456
|
-
|
457
|
-
class ZeroPad1d(Module):
    """
    Pads the length axis of the input with zeros.

    The input is treated as channels-last: ``(length, channels)`` or
    ``(batch, length, channels)``. Only the length axis is padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: ``(pad,)``, same padding for both sides
        - Sequence[int] of length 2: ``(left, right)``
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad1d(2)
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to a list holding a single (left, right) tuple.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its length axis zero-padded.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D nor 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
|
515
|
-
|
516
|
-
|
517
|
-
class ZeroPad2d(Module):
    """
    Pads the height and width axes of the input with zeros.

    The input is treated as channels-last: ``(height, width, channels)`` or
    ``(batch, height, width, channels)``. Batch and channel axes are not padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: ``(height_pad, width_pad)``, each applied
          to both sides of its axis
        - Sequence[int] of length 4: ``(height_before, height_after,
          width_before, width_after)``; the first pair pads the height axis
          and the second pair pads the width axis
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to [(height_before, height_after), (width_before, width_after)].
        self.padding = _format_padding(padding, 2)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its height and width axes zero-padded.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D nor 4D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
|
576
|
-
|
577
|
-
|
578
|
-
class ZeroPad3d(Module):
    """
    Pads the depth, height and width axes of the input with zeros.

    The input is treated as channels-last: ``(depth, height, width, channels)``
    or ``(batch, depth, height, width, channels)``. Batch and channel axes are
    not padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: ``(depth_pad, height_pad, width_pad)``,
          each applied to both sides of its axis
        - Sequence[int] of length 6: ``(depth_before, depth_after,
          height_before, height_after, width_before, width_after)``; the
          pairs pad the depth, height and width axes in that order
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ZeroPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to three (before, after) tuples: depth, height, width.
        self.padding = _format_padding(padding, 3)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its depth, height and width axes zero-padded.

        Raises
        ------
        ValueError
            If ``x`` is neither 4D nor 5D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 4:
            # (depth, height, width, channels) -> pad depth, height and width
            pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        elif ndim == 5:
            # (batch, depth, height, width, channels) -> pad depth, height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=0)
|
637
|
-
|
638
|
-
|
639
|
-
# =============================================================================
|
640
|
-
# Constant Padding
|
641
|
-
# =============================================================================
|
642
|
-
|
643
|
-
class ConstantPad1d(Module):
    """
    Pads the length axis of the input with a constant value.

    The input is treated as channels-last: ``(length, channels)`` or
    ``(batch, length, channels)``. Only the length axis is padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 1: ``(pad,)``, same padding for both sides
        - Sequence[int] of length 2: ``(left, right)``
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad1d(2, value=3.5)
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)  # (batch, length, channels)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to a list holding a single (left, right) tuple.
        self.padding = _format_padding(padding, 1)
        # Must be set before the shape evaluation below, since `update` reads it.
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its length axis padded with ``self.value``.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D nor 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
|
705
|
-
|
706
|
-
|
707
|
-
class ConstantPad2d(Module):
    """
    Pads the height and width axes of the input with a constant value.

    The input is treated as channels-last: ``(height, width, channels)`` or
    ``(batch, height, width, channels)``. Batch and channel axes are not padded.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: ``(height_pad, width_pad)``, each applied
          to both sides of its axis
        - Sequence[int] of length 4: ``(height_before, height_after,
          width_before, width_after)``; the first pair pads the height axis
          and the second pair pads the width axis
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived from it by
        abstractly evaluating :meth:`update`.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad2d(1, value=3.5)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # Normalised to [(height_before, height_after), (width_before, width_after)].
        self.padding = _format_padding(padding, 2)
        # Must be set before the shape evaluation below, since `update` reads it.
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation of `update` to obtain the padded shape.
            # NOTE(review): the `functools.partial` wrapper is kept as-is;
            # presumably it avoids handing the bound method to JAX directly.
            y = jax.eval_shape(
                functools.partial(self.update),
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` with its height and width axes padded with ``self.value``.

        Raises
        ------
        ValueError
            If ``x`` is neither 3D nor 4D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 3:
            # (height, width, channels) -> pad height and width
            pad_width = [self.padding[0], self.padding[1], (0, 0)]
        elif ndim == 4:
            # (batch, height, width, channels) -> pad height and width
            pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
|
770
|
-
|
771
|
-
|
772
|
-
class ConstantPad3d(Module):
|
773
|
-
"""
|
774
|
-
Pads the input tensor with a constant value.
|
775
|
-
|
776
|
-
Parameters
|
777
|
-
----------
|
778
|
-
padding : int or Sequence[int]
|
779
|
-
The size of the padding. Can be:
|
780
|
-
|
781
|
-
- int: same padding for all sides
|
782
|
-
- Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
|
783
|
-
- Sequence[int] of length 6: (left, right, top, bottom, front, back)
|
784
|
-
value : float, optional
|
785
|
-
The constant value to use for padding. Default is 0.
|
786
|
-
in_size : Size, optional
|
787
|
-
The input size.
|
788
|
-
name : str, optional
|
789
|
-
The name of the module.
|
790
|
-
|
791
|
-
Examples
|
792
|
-
--------
|
793
|
-
.. code-block:: python
|
794
|
-
|
795
|
-
>>> import brainstate as brainstate
|
796
|
-
>>> import jax.numpy as jnp
|
797
|
-
>>> pad = brainstate.nn.ConstantPad3d(1, value=3.5)
|
798
|
-
>>> input = jnp.ones((1, 4, 4, 4, 3))
|
799
|
-
>>> output = pad(input)
|
800
|
-
>>> print(output.shape)
|
801
|
-
(1, 6, 6, 6, 3)
|
802
|
-
"""
|
803
|
-
|
804
|
-
def __init__(
|
805
|
-
self,
|
806
|
-
padding: Union[int, Sequence[int]],
|
807
|
-
value: float = 0,
|
808
|
-
in_size: Optional[Size] = None,
|
809
|
-
name: Optional[str] = None
|
810
|
-
):
|
811
|
-
super().__init__(name=name)
|
812
|
-
self.padding = _format_padding(padding, 3)
|
813
|
-
self.value = value
|
814
|
-
if in_size is not None:
|
815
|
-
self.in_size = in_size
|
816
|
-
y = jax.eval_shape(
|
817
|
-
functools.partial(self.update),
|
818
|
-
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
819
|
-
)
|
820
|
-
self.out_size = y.shape
|
821
|
-
|
822
|
-
def update(self, x):
|
823
|
-
# Add (0, 0) padding for non-spatial dimensions
|
824
|
-
ndim = x.ndim
|
825
|
-
if ndim == 4:
|
826
|
-
# (depth, height, width, channels) -> pad depth, height and width
|
827
|
-
pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
828
|
-
elif ndim == 5:
|
829
|
-
# (batch, depth, height, width, channels) -> pad depth, height and width
|
830
|
-
pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
831
|
-
else:
|
832
|
-
raise ValueError(f"Expected 4D or 5D input, got {ndim}D")
|
833
|
-
|
834
|
-
return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
|
835
|
-
|
836
|
-
|
837
|
-
# =============================================================================
|
838
|
-
# Circular Padding
|
839
|
-
# =============================================================================
|
840
|
-
|
841
|
-
class CircularPad1d(Module):
|
842
|
-
"""
|
843
|
-
Pads the input tensor using circular padding (wrap around).
|
844
|
-
|
845
|
-
Parameters
|
846
|
-
----------
|
847
|
-
padding : int or Sequence[int]
|
848
|
-
The size of the padding. Can be:
|
849
|
-
|
850
|
-
- int: same padding for both sides
|
851
|
-
- Sequence[int] of length 2: (left, right)
|
852
|
-
in_size : Size, optional
|
853
|
-
The input size.
|
854
|
-
name : str, optional
|
855
|
-
The name of the module.
|
856
|
-
|
857
|
-
Examples
|
858
|
-
--------
|
859
|
-
.. code-block:: python
|
860
|
-
|
861
|
-
>>> import brainstate as brainstate
|
862
|
-
>>> import jax.numpy as jnp
|
863
|
-
>>> pad = brainstate.nn.CircularPad1d(2)
|
864
|
-
>>> input = jnp.array([[[1, 2, 3, 4, 5]]])
|
865
|
-
>>> output = pad(input)
|
866
|
-
>>> print(output.shape)
|
867
|
-
(1, 9, 1)
|
868
|
-
"""
|
869
|
-
|
870
|
-
def __init__(
|
871
|
-
self,
|
872
|
-
padding: Union[int, Sequence[int]],
|
873
|
-
in_size: Optional[Size] = None,
|
874
|
-
name: Optional[str] = None
|
875
|
-
):
|
876
|
-
super().__init__(name=name)
|
877
|
-
self.padding = _format_padding(padding, 1)
|
878
|
-
if in_size is not None:
|
879
|
-
self.in_size = in_size
|
880
|
-
y = jax.eval_shape(
|
881
|
-
functools.partial(self.update),
|
882
|
-
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
883
|
-
)
|
884
|
-
self.out_size = y.shape
|
885
|
-
|
886
|
-
def update(self, x):
|
887
|
-
# Add (0, 0) padding for non-spatial dimensions
|
888
|
-
ndim = x.ndim
|
889
|
-
if ndim == 2:
|
890
|
-
# (length, channels) -> pad only length dimension
|
891
|
-
pad_width = [self.padding[0], (0, 0)]
|
892
|
-
elif ndim == 3:
|
893
|
-
# (batch, length, channels) -> pad only length dimension
|
894
|
-
pad_width = [(0, 0), self.padding[0], (0, 0)]
|
895
|
-
else:
|
896
|
-
raise ValueError(f"Expected 2D or 3D input, got {ndim}D")
|
897
|
-
|
898
|
-
return jnp.pad(x, pad_width, mode='wrap')
|
899
|
-
|
900
|
-
|
901
|
-
class CircularPad2d(Module):
|
902
|
-
"""
|
903
|
-
Pads the input tensor using circular padding (wrap around).
|
904
|
-
|
905
|
-
Parameters
|
906
|
-
----------
|
907
|
-
padding : int or Sequence[int]
|
908
|
-
The size of the padding. Can be:
|
909
|
-
|
910
|
-
- int: same padding for all sides
|
911
|
-
- Sequence[int] of length 2: (height_pad, width_pad)
|
912
|
-
- Sequence[int] of length 4: (left, right, top, bottom)
|
913
|
-
in_size : Size, optional
|
914
|
-
The input size.
|
915
|
-
name : str, optional
|
916
|
-
The name of the module.
|
917
|
-
|
918
|
-
Examples
|
919
|
-
--------
|
920
|
-
.. code-block:: python
|
921
|
-
|
922
|
-
>>> import brainstate as brainstate
|
923
|
-
>>> import jax.numpy as jnp
|
924
|
-
>>> pad = brainstate.nn.CircularPad2d(1)
|
925
|
-
>>> input = jnp.ones((1, 4, 4, 3))
|
926
|
-
>>> output = pad(input)
|
927
|
-
>>> print(output.shape)
|
928
|
-
(1, 6, 6, 3)
|
929
|
-
"""
|
930
|
-
|
931
|
-
def __init__(
|
932
|
-
self,
|
933
|
-
padding: Union[int, Sequence[int]],
|
934
|
-
in_size: Optional[Size] = None,
|
935
|
-
name: Optional[str] = None
|
936
|
-
):
|
937
|
-
super().__init__(name=name)
|
938
|
-
self.padding = _format_padding(padding, 2)
|
939
|
-
if in_size is not None:
|
940
|
-
self.in_size = in_size
|
941
|
-
y = jax.eval_shape(
|
942
|
-
functools.partial(self.update),
|
943
|
-
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
944
|
-
)
|
945
|
-
self.out_size = y.shape
|
946
|
-
|
947
|
-
def update(self, x):
|
948
|
-
# Add (0, 0) padding for non-spatial dimensions
|
949
|
-
ndim = x.ndim
|
950
|
-
if ndim == 3:
|
951
|
-
# (height, width, channels) -> pad height and width
|
952
|
-
pad_width = [self.padding[0], self.padding[1], (0, 0)]
|
953
|
-
elif ndim == 4:
|
954
|
-
# (batch, height, width, channels) -> pad height and width
|
955
|
-
pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
|
956
|
-
else:
|
957
|
-
raise ValueError(f"Expected 3D or 4D input, got {ndim}D")
|
958
|
-
|
959
|
-
return jnp.pad(x, pad_width, mode='wrap')
|
960
|
-
|
961
|
-
|
962
|
-
class CircularPad3d(Module):
|
963
|
-
"""
|
964
|
-
Pads the input tensor using circular padding (wrap around).
|
965
|
-
|
966
|
-
Parameters
|
967
|
-
----------
|
968
|
-
padding : int or Sequence[int]
|
969
|
-
The size of the padding. Can be:
|
970
|
-
|
971
|
-
- int: same padding for all sides
|
972
|
-
- Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
|
973
|
-
- Sequence[int] of length 6: (left, right, top, bottom, front, back)
|
974
|
-
in_size : Size, optional
|
975
|
-
The input size.
|
976
|
-
name : str, optional
|
977
|
-
The name of the module.
|
978
|
-
|
979
|
-
Examples
|
980
|
-
--------
|
981
|
-
.. code-block:: python
|
982
|
-
|
983
|
-
>>> import brainstate as brainstate
|
984
|
-
>>> import jax.numpy as jnp
|
985
|
-
>>> pad = brainstate.nn.CircularPad3d(1)
|
986
|
-
>>> input = jnp.ones((1, 4, 4, 4, 3))
|
987
|
-
>>> output = pad(input)
|
988
|
-
>>> print(output.shape)
|
989
|
-
(1, 6, 6, 6, 3)
|
990
|
-
"""
|
991
|
-
|
992
|
-
def __init__(
|
993
|
-
self,
|
994
|
-
padding: Union[int, Sequence[int]],
|
995
|
-
in_size: Optional[Size] = None,
|
996
|
-
name: Optional[str] = None
|
997
|
-
):
|
998
|
-
super().__init__(name=name)
|
999
|
-
self.padding = _format_padding(padding, 3)
|
1000
|
-
if in_size is not None:
|
1001
|
-
self.in_size = in_size
|
1002
|
-
y = jax.eval_shape(
|
1003
|
-
functools.partial(self.update),
|
1004
|
-
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
1005
|
-
)
|
1006
|
-
self.out_size = y.shape
|
1007
|
-
|
1008
|
-
def update(self, x):
|
1009
|
-
# Add (0, 0) padding for non-spatial dimensions
|
1010
|
-
ndim = x.ndim
|
1011
|
-
if ndim == 4:
|
1012
|
-
# (depth, height, width, channels) -> pad depth, height and width
|
1013
|
-
pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
1014
|
-
elif ndim == 5:
|
1015
|
-
# (batch, depth, height, width, channels) -> pad depth, height and width
|
1016
|
-
pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
1017
|
-
else:
|
1018
|
-
raise ValueError(f"Expected 4D or 5D input, got {ndim}D")
|
1019
|
-
|
1020
|
-
return jnp.pad(x, pad_width, mode='wrap')
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Padding layers for neural networks.
|
18
|
+
|
19
|
+
This module provides various padding operations for 1D, 2D, and 3D tensors:
|
20
|
+
- ReflectionPad: Pads using reflection of the input boundary
|
21
|
+
- ReplicationPad: Pads using replication of the input boundary
|
22
|
+
- ZeroPad: Pads with zeros
|
23
|
+
- ConstantPad: Pads with a constant value
|
24
|
+
- CircularPad: Pads circularly (wrap around)
|
25
|
+
"""
|
26
|
+
|
27
|
+
import functools
|
28
|
+
from typing import Union, Sequence, Optional
|
29
|
+
|
30
|
+
import jax
|
31
|
+
import jax.numpy as jnp
|
32
|
+
|
33
|
+
from brainstate import environ
|
34
|
+
from brainstate.typing import Size
|
35
|
+
from ._module import Module
|
36
|
+
|
37
|
+
__all__ = [
|
38
|
+
# Reflection padding
|
39
|
+
'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d',
|
40
|
+
# Replication padding
|
41
|
+
'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d',
|
42
|
+
# Zero padding
|
43
|
+
'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d',
|
44
|
+
# Constant padding
|
45
|
+
'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d',
|
46
|
+
# Circular padding
|
47
|
+
'CircularPad1d', 'CircularPad2d', 'CircularPad3d',
|
48
|
+
]
|
49
|
+
|
50
|
+
|
51
|
+
def _format_padding(padding: Union[int, Sequence[int]], ndim: int) -> Sequence[tuple]:
|
52
|
+
"""
|
53
|
+
Convert padding specification to format required by jax.numpy.pad.
|
54
|
+
|
55
|
+
Args:
|
56
|
+
padding: Padding size(s). Can be:
|
57
|
+
- int: same padding for all sides
|
58
|
+
- Sequence of length 2*ndim: (left, right) for each dimension
|
59
|
+
- Sequence of length ndim: same padding for left and right of each dimension
|
60
|
+
ndim: Number of spatial dimensions (1, 2, or 3)
|
61
|
+
|
62
|
+
Returns:
|
63
|
+
List of padding tuples for each dimension
|
64
|
+
"""
|
65
|
+
if isinstance(padding, int):
|
66
|
+
# Same padding for all sides of all dimensions
|
67
|
+
return [(padding, padding) for _ in range(ndim)]
|
68
|
+
|
69
|
+
padding = list(padding)
|
70
|
+
|
71
|
+
if len(padding) == ndim:
|
72
|
+
# Same padding for left and right of each dimension
|
73
|
+
return [(p, p) for p in padding]
|
74
|
+
elif len(padding) == 2 * ndim:
|
75
|
+
# Different padding for each side: (left1, right1, left2, right2, ...)
|
76
|
+
return [(padding[2 * i], padding[2 * i + 1]) for i in range(ndim)]
|
77
|
+
else:
|
78
|
+
raise ValueError(f"Padding must have length {ndim} or {2 * ndim}, got {len(padding)}")
|
79
|
+
|
80
|
+
|
81
|
+
# =============================================================================
|
82
|
+
# Reflection Padding
|
83
|
+
# =============================================================================
|
84
|
+
|
85
|
+
class ReflectionPad1d(Module):
|
86
|
+
"""
|
87
|
+
Pads the input tensor using the reflection of the input boundary.
|
88
|
+
|
89
|
+
Parameters
|
90
|
+
----------
|
91
|
+
padding : int or Sequence[int]
|
92
|
+
The size of the padding. Can be:
|
93
|
+
|
94
|
+
- int: same padding for both sides
|
95
|
+
- Sequence[int] of length 2: (left, right)
|
96
|
+
in_size : Size, optional
|
97
|
+
The input size.
|
98
|
+
name : str, optional
|
99
|
+
The name of the module.
|
100
|
+
|
101
|
+
Examples
|
102
|
+
--------
|
103
|
+
.. code-block:: python
|
104
|
+
|
105
|
+
>>> import brainstate as brainstate
|
106
|
+
>>> import jax.numpy as jnp
|
107
|
+
>>> pad = brainstate.nn.ReflectionPad1d(2)
|
108
|
+
>>> input = jnp.array([[[1, 2, 3, 4, 5]]])
|
109
|
+
>>> output = pad(input)
|
110
|
+
>>> print(output.shape)
|
111
|
+
(1, 9, 1)
|
112
|
+
"""
|
113
|
+
|
114
|
+
def __init__(
|
115
|
+
self,
|
116
|
+
padding: Union[int, Sequence[int]],
|
117
|
+
in_size: Optional[Size] = None,
|
118
|
+
name: Optional[str] = None
|
119
|
+
):
|
120
|
+
super().__init__(name=name)
|
121
|
+
self.padding = _format_padding(padding, 1)
|
122
|
+
if in_size is not None:
|
123
|
+
self.in_size = in_size
|
124
|
+
y = jax.eval_shape(
|
125
|
+
functools.partial(self.update),
|
126
|
+
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
127
|
+
)
|
128
|
+
self.out_size = y.shape
|
129
|
+
|
130
|
+
def update(self, x):
|
131
|
+
# Add (0, 0) padding for non-spatial dimensions
|
132
|
+
ndim = x.ndim
|
133
|
+
if ndim == 2:
|
134
|
+
# (length, channels) -> pad only length dimension
|
135
|
+
pad_width = [self.padding[0], (0, 0)]
|
136
|
+
elif ndim == 3:
|
137
|
+
# (batch, length, channels) -> pad only length dimension
|
138
|
+
pad_width = [(0, 0), self.padding[0], (0, 0)]
|
139
|
+
else:
|
140
|
+
raise ValueError(f"Expected 2D or 3D input, got {ndim}D")
|
141
|
+
|
142
|
+
return jnp.pad(x, pad_width, mode='reflect')
|
143
|
+
|
144
|
+
|
145
|
+
class ReflectionPad2d(Module):
|
146
|
+
"""
|
147
|
+
Pads the input tensor using the reflection of the input boundary.
|
148
|
+
|
149
|
+
Parameters
|
150
|
+
----------
|
151
|
+
padding : int or Sequence[int]
|
152
|
+
The size of the padding. Can be:
|
153
|
+
|
154
|
+
- int: same padding for all sides
|
155
|
+
- Sequence[int] of length 2: (height_pad, width_pad)
|
156
|
+
- Sequence[int] of length 4: (left, right, top, bottom)
|
157
|
+
in_size : Size, optional
|
158
|
+
The input size.
|
159
|
+
name : str, optional
|
160
|
+
The name of the module.
|
161
|
+
|
162
|
+
Examples
|
163
|
+
--------
|
164
|
+
.. code-block:: python
|
165
|
+
|
166
|
+
>>> import brainstate as brainstate
|
167
|
+
>>> import jax.numpy as jnp
|
168
|
+
>>> pad = brainstate.nn.ReflectionPad2d(1)
|
169
|
+
>>> input = jnp.ones((1, 4, 4, 3))
|
170
|
+
>>> output = pad(input)
|
171
|
+
>>> print(output.shape)
|
172
|
+
(1, 6, 6, 3)
|
173
|
+
"""
|
174
|
+
|
175
|
+
def __init__(
|
176
|
+
self,
|
177
|
+
padding: Union[int, Sequence[int]],
|
178
|
+
in_size: Optional[Size] = None,
|
179
|
+
name: Optional[str] = None
|
180
|
+
):
|
181
|
+
super().__init__(name=name)
|
182
|
+
self.padding = _format_padding(padding, 2)
|
183
|
+
if in_size is not None:
|
184
|
+
self.in_size = in_size
|
185
|
+
y = jax.eval_shape(
|
186
|
+
functools.partial(self.update),
|
187
|
+
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
188
|
+
)
|
189
|
+
self.out_size = y.shape
|
190
|
+
|
191
|
+
def update(self, x):
|
192
|
+
# Add (0, 0) padding for non-spatial dimensions
|
193
|
+
ndim = x.ndim
|
194
|
+
if ndim == 3:
|
195
|
+
# (height, width, channels) -> pad height and width
|
196
|
+
pad_width = [self.padding[0], self.padding[1], (0, 0)]
|
197
|
+
elif ndim == 4:
|
198
|
+
# (batch, height, width, channels) -> pad height and width
|
199
|
+
pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
|
200
|
+
else:
|
201
|
+
raise ValueError(f"Expected 3D or 4D input, got {ndim}D")
|
202
|
+
|
203
|
+
return jnp.pad(x, pad_width, mode='reflect')
|
204
|
+
|
205
|
+
|
206
|
+
class ReflectionPad3d(Module):
|
207
|
+
"""
|
208
|
+
Pads the input tensor using the reflection of the input boundary.
|
209
|
+
|
210
|
+
Parameters
|
211
|
+
----------
|
212
|
+
padding : int or Sequence[int]
|
213
|
+
The size of the padding. Can be:
|
214
|
+
|
215
|
+
- int: same padding for all sides
|
216
|
+
- Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
|
217
|
+
- Sequence[int] of length 6: (left, right, top, bottom, front, back)
|
218
|
+
in_size : Size, optional
|
219
|
+
The input size.
|
220
|
+
name : str, optional
|
221
|
+
The name of the module.
|
222
|
+
|
223
|
+
Examples
|
224
|
+
--------
|
225
|
+
.. code-block:: python
|
226
|
+
|
227
|
+
>>> import brainstate as brainstate
|
228
|
+
>>> import jax.numpy as jnp
|
229
|
+
>>> pad = brainstate.nn.ReflectionPad3d(1)
|
230
|
+
>>> input = jnp.ones((1, 4, 4, 4, 3))
|
231
|
+
>>> output = pad(input)
|
232
|
+
>>> print(output.shape)
|
233
|
+
(1, 6, 6, 6, 3)
|
234
|
+
"""
|
235
|
+
|
236
|
+
def __init__(
|
237
|
+
self,
|
238
|
+
padding: Union[int, Sequence[int]],
|
239
|
+
in_size: Optional[Size] = None,
|
240
|
+
name: Optional[str] = None
|
241
|
+
):
|
242
|
+
super().__init__(name=name)
|
243
|
+
self.padding = _format_padding(padding, 3)
|
244
|
+
if in_size is not None:
|
245
|
+
self.in_size = in_size
|
246
|
+
y = jax.eval_shape(
|
247
|
+
functools.partial(self.update),
|
248
|
+
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
249
|
+
)
|
250
|
+
self.out_size = y.shape
|
251
|
+
|
252
|
+
def update(self, x):
|
253
|
+
# Add (0, 0) padding for non-spatial dimensions
|
254
|
+
ndim = x.ndim
|
255
|
+
if ndim == 4:
|
256
|
+
# (depth, height, width, channels) -> pad depth, height and width
|
257
|
+
pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
258
|
+
elif ndim == 5:
|
259
|
+
# (batch, depth, height, width, channels) -> pad depth, height and width
|
260
|
+
pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
261
|
+
else:
|
262
|
+
raise ValueError(f"Expected 4D or 5D input, got {ndim}D")
|
263
|
+
|
264
|
+
return jnp.pad(x, pad_width, mode='reflect')
|
265
|
+
|
266
|
+
|
267
|
+
# =============================================================================
|
268
|
+
# Replication Padding
|
269
|
+
# =============================================================================
|
270
|
+
|
271
|
+
class ReplicationPad1d(Module):
|
272
|
+
"""
|
273
|
+
Pads the input tensor using replication of the input boundary.
|
274
|
+
|
275
|
+
Parameters
|
276
|
+
----------
|
277
|
+
padding : int or Sequence[int]
|
278
|
+
The size of the padding. Can be:
|
279
|
+
|
280
|
+
- int: same padding for both sides
|
281
|
+
- Sequence[int] of length 2: (left, right)
|
282
|
+
in_size : Size, optional
|
283
|
+
The input size.
|
284
|
+
name : str, optional
|
285
|
+
The name of the module.
|
286
|
+
|
287
|
+
Examples
|
288
|
+
--------
|
289
|
+
.. code-block:: python
|
290
|
+
|
291
|
+
>>> import brainstate as brainstate
|
292
|
+
>>> import jax.numpy as jnp
|
293
|
+
>>> pad = brainstate.nn.ReplicationPad1d(2)
|
294
|
+
>>> input = jnp.array([[[1, 2, 3, 4, 5]]])
|
295
|
+
>>> output = pad(input)
|
296
|
+
>>> print(output.shape)
|
297
|
+
(1, 9, 1)
|
298
|
+
"""
|
299
|
+
|
300
|
+
def __init__(
|
301
|
+
self,
|
302
|
+
padding: Union[int, Sequence[int]],
|
303
|
+
in_size: Optional[Size] = None,
|
304
|
+
name: Optional[str] = None
|
305
|
+
):
|
306
|
+
super().__init__(name=name)
|
307
|
+
self.padding = _format_padding(padding, 1)
|
308
|
+
if in_size is not None:
|
309
|
+
self.in_size = in_size
|
310
|
+
y = jax.eval_shape(
|
311
|
+
functools.partial(self.update),
|
312
|
+
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
313
|
+
)
|
314
|
+
self.out_size = y.shape
|
315
|
+
|
316
|
+
def update(self, x):
|
317
|
+
# Add (0, 0) padding for non-spatial dimensions
|
318
|
+
ndim = x.ndim
|
319
|
+
if ndim == 2:
|
320
|
+
# (length, channels) -> pad only length dimension
|
321
|
+
pad_width = [self.padding[0], (0, 0)]
|
322
|
+
elif ndim == 3:
|
323
|
+
# (batch, length, channels) -> pad only length dimension
|
324
|
+
pad_width = [(0, 0), self.padding[0], (0, 0)]
|
325
|
+
else:
|
326
|
+
raise ValueError(f"Expected 2D or 3D input, got {ndim}D")
|
327
|
+
|
328
|
+
return jnp.pad(x, pad_width, mode='edge')
|
329
|
+
|
330
|
+
|
331
|
+
class ReplicationPad2d(Module):
|
332
|
+
"""
|
333
|
+
Pads the input tensor using replication of the input boundary.
|
334
|
+
|
335
|
+
Parameters
|
336
|
+
----------
|
337
|
+
padding : int or Sequence[int]
|
338
|
+
The size of the padding. Can be:
|
339
|
+
|
340
|
+
- int: same padding for all sides
|
341
|
+
- Sequence[int] of length 2: (height_pad, width_pad)
|
342
|
+
- Sequence[int] of length 4: (left, right, top, bottom)
|
343
|
+
in_size : Size, optional
|
344
|
+
The input size.
|
345
|
+
name : str, optional
|
346
|
+
The name of the module.
|
347
|
+
|
348
|
+
Examples
|
349
|
+
--------
|
350
|
+
.. code-block:: python
|
351
|
+
|
352
|
+
>>> import brainstate as brainstate
|
353
|
+
>>> import jax.numpy as jnp
|
354
|
+
>>> pad = brainstate.nn.ReplicationPad2d(1)
|
355
|
+
>>> input = jnp.ones((1, 4, 4, 3))
|
356
|
+
>>> output = pad(input)
|
357
|
+
>>> print(output.shape)
|
358
|
+
(1, 6, 6, 3)
|
359
|
+
"""
|
360
|
+
|
361
|
+
def __init__(
|
362
|
+
self,
|
363
|
+
padding: Union[int, Sequence[int]],
|
364
|
+
in_size: Optional[Size] = None,
|
365
|
+
name: Optional[str] = None
|
366
|
+
):
|
367
|
+
super().__init__(name=name)
|
368
|
+
self.padding = _format_padding(padding, 2)
|
369
|
+
if in_size is not None:
|
370
|
+
self.in_size = in_size
|
371
|
+
y = jax.eval_shape(
|
372
|
+
functools.partial(self.update),
|
373
|
+
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
374
|
+
)
|
375
|
+
self.out_size = y.shape
|
376
|
+
|
377
|
+
def update(self, x):
|
378
|
+
# Add (0, 0) padding for non-spatial dimensions
|
379
|
+
ndim = x.ndim
|
380
|
+
if ndim == 3:
|
381
|
+
# (height, width, channels) -> pad height and width
|
382
|
+
pad_width = [self.padding[0], self.padding[1], (0, 0)]
|
383
|
+
elif ndim == 4:
|
384
|
+
# (batch, height, width, channels) -> pad height and width
|
385
|
+
pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
|
386
|
+
else:
|
387
|
+
raise ValueError(f"Expected 3D or 4D input, got {ndim}D")
|
388
|
+
|
389
|
+
return jnp.pad(x, pad_width, mode='edge')
|
390
|
+
|
391
|
+
|
392
|
+
class ReplicationPad3d(Module):
|
393
|
+
"""
|
394
|
+
Pads the input tensor using replication of the input boundary.
|
395
|
+
|
396
|
+
Parameters
|
397
|
+
----------
|
398
|
+
padding : int or Sequence[int]
|
399
|
+
The size of the padding. Can be:
|
400
|
+
|
401
|
+
- int: same padding for all sides
|
402
|
+
- Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
|
403
|
+
- Sequence[int] of length 6: (left, right, top, bottom, front, back)
|
404
|
+
in_size : Size, optional
|
405
|
+
The input size.
|
406
|
+
name : str, optional
|
407
|
+
The name of the module.
|
408
|
+
|
409
|
+
Examples
|
410
|
+
--------
|
411
|
+
.. code-block:: python
|
412
|
+
|
413
|
+
>>> import brainstate as brainstate
|
414
|
+
>>> import jax.numpy as jnp
|
415
|
+
>>> pad = brainstate.nn.ReplicationPad3d(1)
|
416
|
+
>>> input = jnp.ones((1, 4, 4, 4, 3))
|
417
|
+
>>> output = pad(input)
|
418
|
+
>>> print(output.shape)
|
419
|
+
(1, 6, 6, 6, 3)
|
420
|
+
"""
|
421
|
+
|
422
|
+
def __init__(
|
423
|
+
self,
|
424
|
+
padding: Union[int, Sequence[int]],
|
425
|
+
in_size: Optional[Size] = None,
|
426
|
+
name: Optional[str] = None
|
427
|
+
):
|
428
|
+
super().__init__(name=name)
|
429
|
+
self.padding = _format_padding(padding, 3)
|
430
|
+
if in_size is not None:
|
431
|
+
self.in_size = in_size
|
432
|
+
y = jax.eval_shape(
|
433
|
+
functools.partial(self.update),
|
434
|
+
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
435
|
+
)
|
436
|
+
self.out_size = y.shape
|
437
|
+
|
438
|
+
def update(self, x):
|
439
|
+
# Add (0, 0) padding for non-spatial dimensions
|
440
|
+
ndim = x.ndim
|
441
|
+
if ndim == 4:
|
442
|
+
# (depth, height, width, channels) -> pad depth, height and width
|
443
|
+
pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
444
|
+
elif ndim == 5:
|
445
|
+
# (batch, depth, height, width, channels) -> pad depth, height and width
|
446
|
+
pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
447
|
+
else:
|
448
|
+
raise ValueError(f"Expected 4D or 5D input, got {ndim}D")
|
449
|
+
|
450
|
+
return jnp.pad(x, pad_width, mode='edge')
|
451
|
+
|
452
|
+
|
453
|
+
# =============================================================================
|
454
|
+
# Zero Padding
|
455
|
+
# =============================================================================
|
456
|
+
|
457
|
+
class ZeroPad1d(Module):
|
458
|
+
"""
|
459
|
+
Pads the input tensor with zeros.
|
460
|
+
|
461
|
+
Parameters
|
462
|
+
----------
|
463
|
+
padding : int or Sequence[int]
|
464
|
+
The size of the padding. Can be:
|
465
|
+
|
466
|
+
- int: same padding for both sides
|
467
|
+
- Sequence[int] of length 2: (left, right)
|
468
|
+
in_size : Size, optional
|
469
|
+
The input size.
|
470
|
+
name : str, optional
|
471
|
+
The name of the module.
|
472
|
+
|
473
|
+
Examples
|
474
|
+
--------
|
475
|
+
.. code-block:: python
|
476
|
+
|
477
|
+
>>> import brainstate as brainstate
|
478
|
+
>>> import jax.numpy as jnp
|
479
|
+
>>> pad = brainstate.nn.ZeroPad1d(2)
|
480
|
+
>>> input = jnp.array([[[1, 2, 3, 4, 5]]])
|
481
|
+
>>> output = pad(input)
|
482
|
+
>>> print(output.shape)
|
483
|
+
(1, 9, 1)
|
484
|
+
"""
|
485
|
+
|
486
|
+
def __init__(
|
487
|
+
self,
|
488
|
+
padding: Union[int, Sequence[int]],
|
489
|
+
in_size: Optional[Size] = None,
|
490
|
+
name: Optional[str] = None
|
491
|
+
):
|
492
|
+
super().__init__(name=name)
|
493
|
+
self.padding = _format_padding(padding, 1)
|
494
|
+
if in_size is not None:
|
495
|
+
self.in_size = in_size
|
496
|
+
y = jax.eval_shape(
|
497
|
+
functools.partial(self.update),
|
498
|
+
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
499
|
+
)
|
500
|
+
self.out_size = y.shape
|
501
|
+
|
502
|
+
def update(self, x):
|
503
|
+
# Add (0, 0) padding for non-spatial dimensions
|
504
|
+
ndim = x.ndim
|
505
|
+
if ndim == 2:
|
506
|
+
# (length, channels) -> pad only length dimension
|
507
|
+
pad_width = [self.padding[0], (0, 0)]
|
508
|
+
elif ndim == 3:
|
509
|
+
# (batch, length, channels) -> pad only length dimension
|
510
|
+
pad_width = [(0, 0), self.padding[0], (0, 0)]
|
511
|
+
else:
|
512
|
+
raise ValueError(f"Expected 2D or 3D input, got {ndim}D")
|
513
|
+
|
514
|
+
return jnp.pad(x, pad_width, mode='constant', constant_values=0)
|
515
|
+
|
516
|
+
|
517
|
+
class ZeroPad2d(Module):
|
518
|
+
"""
|
519
|
+
Pads the input tensor with zeros.
|
520
|
+
|
521
|
+
Parameters
|
522
|
+
----------
|
523
|
+
padding : int or Sequence[int]
|
524
|
+
The size of the padding. Can be:
|
525
|
+
|
526
|
+
- int: same padding for all sides
|
527
|
+
- Sequence[int] of length 2: (height_pad, width_pad)
|
528
|
+
- Sequence[int] of length 4: (left, right, top, bottom)
|
529
|
+
in_size : Size, optional
|
530
|
+
The input size.
|
531
|
+
name : str, optional
|
532
|
+
The name of the module.
|
533
|
+
|
534
|
+
Examples
|
535
|
+
--------
|
536
|
+
.. code-block:: python
|
537
|
+
|
538
|
+
>>> import brainstate as brainstate
|
539
|
+
>>> import jax.numpy as jnp
|
540
|
+
>>> pad = brainstate.nn.ZeroPad2d(1)
|
541
|
+
>>> input = jnp.ones((1, 4, 4, 3))
|
542
|
+
>>> output = pad(input)
|
543
|
+
>>> print(output.shape)
|
544
|
+
(1, 6, 6, 3)
|
545
|
+
"""
|
546
|
+
|
547
|
+
def __init__(
|
548
|
+
self,
|
549
|
+
padding: Union[int, Sequence[int]],
|
550
|
+
in_size: Optional[Size] = None,
|
551
|
+
name: Optional[str] = None
|
552
|
+
):
|
553
|
+
super().__init__(name=name)
|
554
|
+
self.padding = _format_padding(padding, 2)
|
555
|
+
if in_size is not None:
|
556
|
+
self.in_size = in_size
|
557
|
+
y = jax.eval_shape(
|
558
|
+
functools.partial(self.update),
|
559
|
+
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
560
|
+
)
|
561
|
+
self.out_size = y.shape
|
562
|
+
|
563
|
+
def update(self, x):
|
564
|
+
# Add (0, 0) padding for non-spatial dimensions
|
565
|
+
ndim = x.ndim
|
566
|
+
if ndim == 3:
|
567
|
+
# (height, width, channels) -> pad height and width
|
568
|
+
pad_width = [self.padding[0], self.padding[1], (0, 0)]
|
569
|
+
elif ndim == 4:
|
570
|
+
# (batch, height, width, channels) -> pad height and width
|
571
|
+
pad_width = [(0, 0), self.padding[0], self.padding[1], (0, 0)]
|
572
|
+
else:
|
573
|
+
raise ValueError(f"Expected 3D or 4D input, got {ndim}D")
|
574
|
+
|
575
|
+
return jnp.pad(x, pad_width, mode='constant', constant_values=0)
|
576
|
+
|
577
|
+
|
578
|
+
class ZeroPad3d(Module):
|
579
|
+
"""
|
580
|
+
Pads the input tensor with zeros.
|
581
|
+
|
582
|
+
Parameters
|
583
|
+
----------
|
584
|
+
padding : int or Sequence[int]
|
585
|
+
The size of the padding. Can be:
|
586
|
+
|
587
|
+
- int: same padding for all sides
|
588
|
+
- Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
|
589
|
+
- Sequence[int] of length 6: (left, right, top, bottom, front, back)
|
590
|
+
in_size : Size, optional
|
591
|
+
The input size.
|
592
|
+
name : str, optional
|
593
|
+
The name of the module.
|
594
|
+
|
595
|
+
Examples
|
596
|
+
--------
|
597
|
+
.. code-block:: python
|
598
|
+
|
599
|
+
>>> import brainstate as brainstate
|
600
|
+
>>> import jax.numpy as jnp
|
601
|
+
>>> pad = brainstate.nn.ZeroPad3d(1)
|
602
|
+
>>> input = jnp.ones((1, 4, 4, 4, 3))
|
603
|
+
>>> output = pad(input)
|
604
|
+
>>> print(output.shape)
|
605
|
+
(1, 6, 6, 6, 3)
|
606
|
+
"""
|
607
|
+
|
608
|
+
def __init__(
|
609
|
+
self,
|
610
|
+
padding: Union[int, Sequence[int]],
|
611
|
+
in_size: Optional[Size] = None,
|
612
|
+
name: Optional[str] = None
|
613
|
+
):
|
614
|
+
super().__init__(name=name)
|
615
|
+
self.padding = _format_padding(padding, 3)
|
616
|
+
if in_size is not None:
|
617
|
+
self.in_size = in_size
|
618
|
+
y = jax.eval_shape(
|
619
|
+
functools.partial(self.update),
|
620
|
+
jax.ShapeDtypeStruct(self.in_size, environ.dftype())
|
621
|
+
)
|
622
|
+
self.out_size = y.shape
|
623
|
+
|
624
|
+
def update(self, x):
|
625
|
+
# Add (0, 0) padding for non-spatial dimensions
|
626
|
+
ndim = x.ndim
|
627
|
+
if ndim == 4:
|
628
|
+
# (depth, height, width, channels) -> pad depth, height and width
|
629
|
+
pad_width = [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
630
|
+
elif ndim == 5:
|
631
|
+
# (batch, depth, height, width, channels) -> pad depth, height and width
|
632
|
+
pad_width = [(0, 0), self.padding[0], self.padding[1], self.padding[2], (0, 0)]
|
633
|
+
else:
|
634
|
+
raise ValueError(f"Expected 4D or 5D input, got {ndim}D")
|
635
|
+
|
636
|
+
return jnp.pad(x, pad_width, mode='constant', constant_values=0)
|
637
|
+
|
638
|
+
|
639
|
+
# =============================================================================
|
640
|
+
# Constant Padding
|
641
|
+
# =============================================================================
|
642
|
+
|
643
|
+
class ConstantPad1d(Module):
    """
    Pads the length axis of a channels-last input with a constant value.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived at construction
        time by abstractly evaluating :meth:`update` on that shape.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad1d(2, value=3.5)
        >>> # layout is (batch, length, channels): 5 steps, 1 channel
        >>> input = jnp.arange(1., 6.).reshape(1, 5, 1)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # NOTE(review): _format_padding is defined elsewhere; update() relies on
        # entry 0 being the pad spec for the single spatial axis -- confirm.
        self.padding = _format_padding(padding, 1)
        self.value = value
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation: no array is allocated to infer out_size.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` padded along its length axis with ``self.value``.

        Parameters
        ----------
        x : array
            Either ``(length, channels)`` or ``(batch, length, channels)``.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D nor 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='constant', constant_values=self.value)
|
705
|
+
|
706
|
+
|
707
|
+
class ConstantPad2d(Module):
    """
    Constant-value padding over the height and width axes of a
    channels-last input.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad)
        - Sequence[int] of length 4: (left, right, top, bottom)
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input size. If supplied, ``out_size`` is computed from it
        during construction.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad2d(1, value=3.5)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 2)
        self.value = value
        if in_size is None:
            return

        # With a known input size, infer the output size by tracing update()
        # on an abstract array (shape and dtype only).
        self.in_size = in_size
        abstract_input = jax.ShapeDtypeStruct(self.in_size, environ.dftype())
        abstract_output = jax.eval_shape(functools.partial(self.update), abstract_input)
        self.out_size = abstract_output.shape

    def update(self, x):
        """Pad ``x`` (3D unbatched or 4D batched) on height and width with ``self.value``."""
        rank = x.ndim
        if rank == 3:
            # (height, width, channels): no leading batch axis
            leading = []
        elif rank == 4:
            # (batch, height, width, channels): batch axis stays unpadded
            leading = [(0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {rank}D")

        # Channels (last axis) are never padded.
        widths = leading + [self.padding[0], self.padding[1], (0, 0)]
        return jnp.pad(x, widths, mode='constant', constant_values=self.value)
|
770
|
+
|
771
|
+
|
772
|
+
class ConstantPad3d(Module):
    """
    Constant-value padding over the depth, height and width axes of a
    channels-last input.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
        - Sequence[int] of length 6: (left, right, top, bottom, front, back)
    value : float, optional
        The constant value to use for padding. Default is 0.
    in_size : Size, optional
        The input size. If supplied, ``out_size`` is computed from it
        during construction.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.ConstantPad3d(1, value=3.5)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        value: float = 0,
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 3)
        self.value = value
        if in_size is None:
            return

        # With a known input size, infer the output size by tracing update()
        # on an abstract array (shape and dtype only).
        self.in_size = in_size
        abstract_input = jax.ShapeDtypeStruct(self.in_size, environ.dftype())
        abstract_output = jax.eval_shape(functools.partial(self.update), abstract_input)
        self.out_size = abstract_output.shape

    def update(self, x):
        """Pad ``x`` (4D unbatched or 5D batched) on its three spatial axes with ``self.value``."""
        rank = x.ndim
        if rank == 4:
            # (depth, height, width, channels): no leading batch axis
            leading = []
        elif rank == 5:
            # (batch, depth, height, width, channels): batch axis stays unpadded
            leading = [(0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {rank}D")

        # Channels (last axis) are never padded.
        widths = leading + [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        return jnp.pad(x, widths, mode='constant', constant_values=self.value)
|
835
|
+
|
836
|
+
|
837
|
+
# =============================================================================
|
838
|
+
# Circular Padding
|
839
|
+
# =============================================================================
|
840
|
+
|
841
|
+
class CircularPad1d(Module):
    """
    Pads the length axis of a channels-last input using circular padding
    (wrap around).

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for both sides
        - Sequence[int] of length 2: (left, right)
    in_size : Size, optional
        The input size. When given, ``out_size`` is derived at construction
        time by abstractly evaluating :meth:`update` on that shape.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad1d(2)
        >>> # layout is (batch, length, channels): 5 steps, 1 channel
        >>> input = jnp.arange(1, 6).reshape(1, 5, 1)
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 9, 1)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        # NOTE(review): _format_padding is defined elsewhere; update() relies on
        # entry 0 being the pad spec for the single spatial axis -- confirm.
        self.padding = _format_padding(padding, 1)
        if in_size is not None:
            self.in_size = in_size
            # Shape-only evaluation: no array is allocated to infer out_size.
            y = jax.eval_shape(
                self.update,
                jax.ShapeDtypeStruct(self.in_size, environ.dftype())
            )
            self.out_size = y.shape

    def update(self, x):
        """
        Return ``x`` padded along its length axis with wrapped-around values.

        Parameters
        ----------
        x : array
            Either ``(length, channels)`` or ``(batch, length, channels)``.

        Raises
        ------
        ValueError
            If ``x`` is neither 2D nor 3D.
        """
        # Add (0, 0) padding for non-spatial dimensions
        ndim = x.ndim
        if ndim == 2:
            # (length, channels) -> pad only length dimension
            pad_width = [self.padding[0], (0, 0)]
        elif ndim == 3:
            # (batch, length, channels) -> pad only length dimension
            pad_width = [(0, 0), self.padding[0], (0, 0)]
        else:
            raise ValueError(f"Expected 2D or 3D input, got {ndim}D")

        return jnp.pad(x, pad_width, mode='wrap')
|
899
|
+
|
900
|
+
|
901
|
+
class CircularPad2d(Module):
    """
    Circular (wrap-around) padding over the height and width axes of a
    channels-last input.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 2: (height_pad, width_pad)
        - Sequence[int] of length 4: (left, right, top, bottom)
    in_size : Size, optional
        The input size. If supplied, ``out_size`` is computed from it
        during construction.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad2d(1)
        >>> input = jnp.ones((1, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 2)
        if in_size is None:
            return

        # With a known input size, infer the output size by tracing update()
        # on an abstract array (shape and dtype only).
        self.in_size = in_size
        abstract_input = jax.ShapeDtypeStruct(self.in_size, environ.dftype())
        abstract_output = jax.eval_shape(functools.partial(self.update), abstract_input)
        self.out_size = abstract_output.shape

    def update(self, x):
        """Wrap-pad ``x`` (3D unbatched or 4D batched) on its height and width axes."""
        rank = x.ndim
        if rank == 3:
            # (height, width, channels): no leading batch axis
            leading = []
        elif rank == 4:
            # (batch, height, width, channels): batch axis stays unpadded
            leading = [(0, 0)]
        else:
            raise ValueError(f"Expected 3D or 4D input, got {rank}D")

        # Channels (last axis) are never padded.
        widths = leading + [self.padding[0], self.padding[1], (0, 0)]
        return jnp.pad(x, widths, mode='wrap')
|
960
|
+
|
961
|
+
|
962
|
+
class CircularPad3d(Module):
    """
    Circular (wrap-around) padding over the depth, height and width axes of
    a channels-last input.

    Parameters
    ----------
    padding : int or Sequence[int]
        The size of the padding. Can be:

        - int: same padding for all sides
        - Sequence[int] of length 3: (depth_pad, height_pad, width_pad)
        - Sequence[int] of length 6: (left, right, top, bottom, front, back)
    in_size : Size, optional
        The input size. If supplied, ``out_size`` is computed from it
        during construction.
    name : str, optional
        The name of the module.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as brainstate
        >>> import jax.numpy as jnp
        >>> pad = brainstate.nn.CircularPad3d(1)
        >>> input = jnp.ones((1, 4, 4, 4, 3))
        >>> output = pad(input)
        >>> print(output.shape)
        (1, 6, 6, 6, 3)
    """

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        in_size: Optional[Size] = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.padding = _format_padding(padding, 3)
        if in_size is None:
            return

        # With a known input size, infer the output size by tracing update()
        # on an abstract array (shape and dtype only).
        self.in_size = in_size
        abstract_input = jax.ShapeDtypeStruct(self.in_size, environ.dftype())
        abstract_output = jax.eval_shape(functools.partial(self.update), abstract_input)
        self.out_size = abstract_output.shape

    def update(self, x):
        """Wrap-pad ``x`` (4D unbatched or 5D batched) on its three spatial axes."""
        rank = x.ndim
        if rank == 4:
            # (depth, height, width, channels): no leading batch axis
            leading = []
        elif rank == 5:
            # (batch, depth, height, width, channels): batch axis stays unpadded
            leading = [(0, 0)]
        else:
            raise ValueError(f"Expected 4D or 5D input, got {rank}D")

        # Channels (last axis) are never padded.
        widths = leading + [self.padding[0], self.padding[1], self.padding[2], (0, 0)]
        return jnp.pad(x, widths, mode='wrap')
|