brainstate-0.0.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,489 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ import math
+
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate import environ, random
+ from ._base import Initializer, to_size
+
+ __all__ = [
+   'Normal',
+   'TruncatedNormal',
+   'Uniform',
+   'VarianceScaling',
+   'KaimingUniform',
+   'KaimingNormal',
+   'XavierUniform',
+   'XavierNormal',
+   'LecunUniform',
+   'LecunNormal',
+   'Orthogonal',
+   'DeltaOrthogonal',
+ ]
+
+
+ def calculate_gain(nonlinearity, param=None):
+   r"""Return the recommended gain value for the given nonlinearity function.
+
+   The values are as follows:
+
+   ================= ====================================================
+   nonlinearity      gain
+   ================= ====================================================
+   Linear / Identity :math:`1`
+   Conv{1,2,3}D      :math:`1`
+   Sigmoid           :math:`1`
+   Tanh              :math:`\frac{5}{3}`
+   ReLU              :math:`\sqrt{2}`
+   Leaky ReLU        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
+   SELU              :math:`\frac{3}{4}`
+   ================= ====================================================
+
+   .. warning::
+     In order to implement `Self-Normalizing Neural Networks`_ ,
+     you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
+     This gives the initial weights a variance of ``1 / N``,
+     which is necessary to induce a stable fixed point in the forward pass.
+     In contrast, the default gain for ``SELU`` sacrifices the normalisation
+     effect for more stable gradient flow in rectangular layers.
+
+   Args:
+     nonlinearity: the non-linear function (`nn.functional` name)
+     param: optional parameter for the non-linear function
+
+   .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
+   """
+   linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
+   if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
+     return 1
+   elif nonlinearity == 'tanh':
+     return 5.0 / 3
+   elif nonlinearity == 'relu':
+     return math.sqrt(2.0)
+   elif nonlinearity == 'leaky_relu':
+     if param is None:
+       negative_slope = 0.01
+     elif isinstance(param, (int, float)) and not isinstance(param, bool):
+       # True/False are instances of int, hence the explicit bool exclusion
+       negative_slope = param
+     else:
+       raise ValueError("negative_slope {} not a valid number".format(param))
+     return math.sqrt(2.0 / (1 + negative_slope ** 2))
+   elif nonlinearity == 'selu':
+     return 3.0 / 4
+   else:
+     raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+
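A quick sanity check of the gain rules above, reading the expected values straight off the table (`calculate_gain` is the function defined in this diff):

```python
import math

assert calculate_gain('linear') == 1
assert calculate_gain('tanh') == 5.0 / 3
assert calculate_gain('relu') == math.sqrt(2.0)
# leaky_relu defaults to negative_slope=0.01; an explicit slope can be passed:
assert calculate_gain('leaky_relu') == math.sqrt(2.0 / (1 + 0.01 ** 2))
assert calculate_gain('leaky_relu', 0.2) == math.sqrt(2.0 / (1 + 0.2 ** 2))
```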
+
+ def _format_shape(shape):
+   if isinstance(shape, int):
+     return (shape,)
+   if len(shape) == 0:
+     raise ValueError('Please provide shape.')
+   if len(shape) == 1:
+     if isinstance(shape[0], (tuple, list)):
+       return shape[0]
+     else:
+       return shape
+   else:
+     return shape
+
+
+ def _compute_fans(shape, in_axis=-2, out_axis=-1):
+   receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+   fan_in = shape[in_axis] * receptive_field_size
+   fan_out = shape[out_axis] * receptive_field_size
+   return fan_in, fan_out
+
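To make the fan computation concrete, a small worked example using only the helper above: for a dense weight the fans are just the two dimensions, while for a convolution kernel the receptive field size multiplies both.

```python
# Dense layer: shape (in_features, out_features) = (128, 256)
fan_in, fan_out = _compute_fans((128, 256))
assert (fan_in, fan_out) == (128, 256)

# 2D conv kernel: (kh, kw, in_channels, out_channels) = (3, 3, 16, 32).
# receptive_field_size = 3 * 3 = 9, so fan_in = 16 * 9 = 144 and
# fan_out = 32 * 9 = 288.
fan_in, fan_out = _compute_fans((3, 3, 16, 32))
assert (fan_in, fan_out) == (144, 288)
```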
+ class Normal(Initializer):
+   """Initialize weights with a normal distribution.
+
+   Parameters
+   ----------
+   mean : float
+     The mean of the normal distribution.
+   scale : float
+     The standard deviation of the normal distribution.
+   """
+
+   def __init__(self, mean=0., scale=1., dtype=None):
+     super(Normal, self).__init__()
+     self.scale = scale
+     self.mean = mean
+     self.dtype = dtype or environ.dftype()
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     weights = random.normal(size=shape, loc=self.mean, scale=self.scale, dtype=self.dtype)
+     return weights
+
+   def __repr__(self):
+     return f'{self.__class__.__name__}(mean={self.mean}, scale={self.scale}, dtype={self.dtype})'
+
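A minimal usage sketch (assuming these initializers are re-exported through `brainstate.init`, as the file list above suggests):

```python
import brainstate as bst

init = bst.init.Normal(mean=0., scale=0.1)
w = init((128, 256))  # samples ~ N(0, 0.1**2); dtype comes from environ.dftype()
```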
+
+ class TruncatedNormal(Initializer):
+   """Initialize weights with a truncated normal distribution.
+
+   Parameters
+   ----------
+   loc : float, ndarray
+     Mean ("centre") of the distribution before truncating. Note that
+     the mean of the truncated distribution will not be exactly equal
+     to ``loc``.
+   scale : float
+     The standard deviation of the normal distribution before truncating.
+   lower : float, ndarray
+     A float or array of floats representing the lower bound for
+     truncation. Must be broadcast-compatible with ``upper``.
+   upper : float, ndarray
+     A float or array of floats representing the upper bound for
+     truncation. Must be broadcast-compatible with ``lower``.
+   """
+
+   def __init__(self, loc=0., scale=1., lower=None, upper=None, dtype=None):
+     super(TruncatedNormal, self).__init__()
+     assert scale > 0, '`scale` must be positive.'
+     self.scale = scale
+     self.loc = loc
+     self.lower = lower
+     self.upper = upper
+     self.dtype = dtype or environ.dftype()
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     weights = random.truncated_normal(
+       size=shape,
+       scale=self.scale,
+       lower=self.lower,
+       upper=self.upper,
+       loc=self.loc,
+       dtype=self.dtype
+     )
+     return weights
+
+   def __repr__(self):
+     return (f'{self.__class__.__name__}(loc={self.loc}, scale={self.scale}, '
+             f'lower={self.lower}, upper={self.upper}, dtype={self.dtype})')
+
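Another hedged sketch; the exact interpretation of ``lower``/``upper`` relative to ``loc`` and ``scale`` follows ``brainstate.random.truncated_normal``:

```python
import brainstate as bst

# Samples from N(0, 0.5**2), truncated to the interval [-1, 1]:
init = bst.init.TruncatedNormal(loc=0., scale=0.5, lower=-1., upper=1.)
w = init((64, 64))
```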
+
+ class Gamma(Initializer):
+   """Initialize weights with a Gamma distribution.
+
+   Parameters
+   ----------
+   shape : float, Array
+     The shape parameter of the Gamma distribution.
+   scale : float, Array
+     The scale parameter of the Gamma distribution.
+   """
+
+   def __init__(self, shape, scale=None, dtype=None):
+     super(Gamma, self).__init__()
+     self.shape = shape
+     self.scale = scale
+     self.dtype = dtype or environ.dftype()
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     weights = random.gamma(self.shape, scale=self.scale, size=shape, dtype=self.dtype)
+     return weights
+
+   def __repr__(self):
+     return f'{self.__class__.__name__}(shape={self.shape}, scale={self.scale}, dtype={self.dtype})'
+
+
+ class Exponential(Initializer):
+   """Initialize weights with an exponential distribution.
+
+   Parameters
+   ----------
+   scale : float, Array
+     The scale parameter of the exponential distribution.
+   """
+
+   def __init__(self, scale=None, dtype=None):
+     super(Exponential, self).__init__()
+     self.scale = scale
+     self.dtype = dtype or environ.dftype()
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     weights = random.exponential(scale=self.scale, size=shape, dtype=self.dtype)
+     return weights
+
+   def __repr__(self):
+     return f'{self.__class__.__name__}(scale={self.scale}, dtype={self.dtype})'
+
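Note that ``Gamma`` and ``Exponential`` are defined here but not listed in ``__all__``, so they may need to be imported from the module directly. A hedged sketch, with the module path taken from the file list above:

```python
from brainstate.init._random_inits import Exponential, Gamma

w1 = Gamma(shape=2.0, scale=1.0)((32, 32))  # Gamma-distributed weights
w2 = Exponential(scale=0.5)((32, 32))       # exponentially distributed weights
```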
+
+ class Uniform(Initializer):
+   """Initialize weights with a uniform distribution.
+
+   Parameters
+   ----------
+   min_val : float
+     The lower limit of the uniform distribution.
+   max_val : float
+     The upper limit of the uniform distribution.
+   """
+
+   def __init__(self, min_val: float = 0., max_val: float = 1., dtype=None):
+     super(Uniform, self).__init__()
+     self.min_val = min_val
+     self.max_val = max_val
+     self.dtype = dtype or environ.dftype()
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     return random.uniform(low=self.min_val, high=self.max_val, size=shape, dtype=self.dtype)
+
+   def __repr__(self):
+     return (f'{self.__class__.__name__}(min_val={self.min_val}, '
+             f'max_val={self.max_val}, dtype={self.dtype})')
+
+
+ class VarianceScaling(Initializer):
+   """Initialize weights by scaling their variance with the layer's fan.
+
+   The variance is ``scale / n``, where ``n`` is the fan-in, the fan-out,
+   or their average, depending on ``mode``. Samples are drawn from a
+   truncated normal, normal, or uniform distribution according to
+   ``distribution``.
+   """
+
+   def __init__(
+       self,
+       scale: float,
+       mode: str,
+       distribution: str,
+       in_axis: int = -2,
+       out_axis: int = -1,
+       dtype=None
+   ):
+     super().__init__()
+     assert mode in ['fan_in', 'fan_out', 'fan_avg']
+     assert distribution in ['truncated_normal', 'normal', 'uniform']
+     self.scale = scale
+     self.mode = mode
+     self.in_axis = in_axis
+     self.out_axis = out_axis
+     self.distribution = distribution
+     self.dtype = dtype or environ.dftype()
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
+     if self.mode == "fan_in":
+       denominator = fan_in
+     elif self.mode == "fan_out":
+       denominator = fan_out
+     elif self.mode == "fan_avg":
+       denominator = (fan_in + fan_out) / 2
+     else:
+       raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
+     variance = (self.scale / denominator).astype(self.dtype)
+     if self.distribution == "truncated_normal":
+       # The constant is the stddev of a standard normal truncated to (-2, 2).
+       stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
+       res = random.truncated_normal(-2, 2, shape, dtype=self.dtype) * stddev
+     elif self.distribution == "normal":
+       res = random.randn(*shape, dtype=self.dtype) * jnp.sqrt(variance).astype(self.dtype)
+     elif self.distribution == "uniform":
+       res = (random.uniform(low=-1, high=1, size=shape, dtype=self.dtype) *
+              jnp.sqrt(3 * variance).astype(self.dtype))
+     else:
+       raise ValueError("invalid distribution for variance scaling initializer")
+     return res
+
+   def __repr__(self):
+     name = self.__class__.__name__
+     blank = ' ' * len(name)
+     return (f'{name}(scale={self.scale}, mode={self.mode}, in_axis={self.in_axis},\n'
+             f'{blank}out_axis={self.out_axis}, distribution={self.distribution}, dtype={self.dtype})')
+
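One detail worth calling out: the literal ``0.87962566103423978`` is the standard deviation of a standard normal truncated to ``(-2, 2)``, so dividing by it renormalizes the truncated samples to unit variance before the requested ``stddev`` is applied. A quick cross-check (SciPy is used only for the check; it is not a dependency of this file):

```python
from scipy.stats import truncnorm

# Standard deviation of a standard normal truncated to (-2, 2):
print(truncnorm(-2, 2).std())  # ~0.8796256610342398
```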
+
+ class KaimingUniform(VarianceScaling):
+   def __init__(
+       self,
+       scale: float = 2.0,
+       mode: str = "fan_in",
+       distribution: str = "uniform",
+       in_axis: int = -2,
+       out_axis: int = -1,
+       dtype=None
+   ):
+     super().__init__(scale,
+                      mode,
+                      distribution,
+                      in_axis=in_axis,
+                      out_axis=out_axis,
+                      dtype=dtype)
+
+
+ class KaimingNormal(VarianceScaling):
+   def __init__(
+       self,
+       scale: float = 2.0,
+       mode: str = "fan_in",
+       distribution: str = "truncated_normal",
+       in_axis: int = -2,
+       out_axis: int = -1,
+       dtype=None
+   ):
+     super().__init__(scale,
+                      mode,
+                      distribution,
+                      in_axis=in_axis,
+                      out_axis=out_axis,
+                      dtype=dtype)
+
+
+ class XavierUniform(VarianceScaling):
+   def __init__(
+       self,
+       scale: float = 1.0,
+       mode: str = "fan_avg",
+       distribution: str = "uniform",
+       in_axis: int = -2,
+       out_axis: int = -1,
+       dtype=None
+   ):
+     super().__init__(scale,
+                      mode,
+                      distribution,
+                      in_axis=in_axis,
+                      out_axis=out_axis,
+                      dtype=dtype)
+
+
+ class XavierNormal(VarianceScaling):
+   def __init__(
+       self,
+       scale: float = 1.0,
+       mode: str = "fan_avg",
+       distribution: str = "truncated_normal",
+       in_axis: int = -2,
+       out_axis: int = -1,
+       dtype=None
+   ):
+     super().__init__(scale,
+                      mode,
+                      distribution,
+                      in_axis=in_axis,
+                      out_axis=out_axis,
+                      dtype=dtype)
+
+
+ class LecunUniform(VarianceScaling):
+   def __init__(
+       self,
+       scale: float = 1.0,
+       mode: str = "fan_in",
+       distribution: str = "uniform",
+       in_axis: int = -2,
+       out_axis: int = -1,
+       dtype=None
+   ):
+     super().__init__(scale,
+                      mode,
+                      distribution,
+                      in_axis=in_axis,
+                      out_axis=out_axis,
+                      dtype=dtype)
+
+
+ class LecunNormal(VarianceScaling):
+   def __init__(
+       self,
+       scale: float = 1.0,
+       mode: str = "fan_in",
+       distribution: str = "truncated_normal",
+       in_axis: int = -2,
+       out_axis: int = -1,
+       dtype=None
+   ):
+     super().__init__(scale,
+                      mode,
+                      distribution,
+                      in_axis=in_axis,
+                      out_axis=out_axis,
+                      dtype=dtype)
+
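For reference, the six presets above differ only in the defaults they pass to ``VarianceScaling``:

================  =====  =======  ================
initializer       scale  mode     distribution
================  =====  =======  ================
KaimingUniform    2.0    fan_in   uniform
KaimingNormal     2.0    fan_in   truncated_normal
XavierUniform     1.0    fan_avg  uniform
XavierNormal      1.0    fan_avg  truncated_normal
LecunUniform      1.0    fan_in   uniform
LecunNormal       1.0    fan_in   truncated_normal
================  =====  =======  ================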
+ class Orthogonal(Initializer):
+   """
+   Construct an initializer for uniformly distributed orthogonal matrices.
+
+   If the shape is not square, the matrix will have orthonormal rows or columns
+   depending on which side is smaller.
+   """
+
+   def __init__(
+       self,
+       scale: float = 1.,
+       axis: int = -1,
+       dtype=None
+   ):
+     super().__init__()
+     self.scale = scale
+     self.axis = axis
+     self.dtype = dtype or environ.dftype()
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     n_rows = shape[self.axis]
+     n_cols = np.prod(shape) // n_rows
+     matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
+     norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)
+     q_mat, r_mat = jnp.linalg.qr(norm_dst)
+     # Enforce Q is uniformly distributed
+     q_mat *= jnp.sign(jnp.diag(r_mat))
+     if n_rows < n_cols:
+       q_mat = q_mat.T
+     q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
+     q_mat = jnp.moveaxis(q_mat, 0, self.axis)
+     return jnp.asarray(self.scale, dtype=self.dtype) * q_mat
+
+   def __repr__(self):
+     return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
+
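A quick orthonormality check, again assuming the class is exported via `brainstate.init`:

```python
import jax.numpy as jnp
import brainstate as bst

q = bst.init.Orthogonal()((16, 16))
# Rows/columns of a square orthogonal matrix are orthonormal:
assert jnp.allclose(q @ q.T, jnp.eye(16), atol=1e-4)
```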
+
+ class DeltaOrthogonal(Initializer):
+   """
+   Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
+
+   The shape must be 3D, 4D or 5D.
+   """
+
+   def __init__(self, scale=1.0, axis=-1, dtype=None):
+     super(DeltaOrthogonal, self).__init__()
+     self.scale = scale
+     self.axis = axis
+     self.dtype = dtype or environ.dftype()
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     if len(shape) not in [3, 4, 5]:
+       raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
+     if shape[-1] < shape[-2]:
+       raise ValueError("`fan_in` must be less than or equal to `fan_out`.")
+     ortho_matrix = Orthogonal(scale=self.scale, axis=self.axis, dtype=self.dtype)(shape[-2:])
+     W = jnp.zeros(shape, dtype=self.dtype)
+     if len(shape) == 3:
+       k = shape[0]
+       W = W.at[(k - 1) // 2].set(ortho_matrix)
+     elif len(shape) == 4:
+       k1, k2 = shape[:2]
+       W = W.at[(k1 - 1) // 2, (k2 - 1) // 2].set(ortho_matrix)
+     else:
+       k1, k2, k3 = shape[:3]
+       W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
+     return W
+
+   def __repr__(self):
+     return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
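A sketch of the delta-orthogonal construction for a 3x3 convolution kernel mapping 8 to 16 channels: only the spatially centred tap is nonzero, and it holds an orthogonal ``(8, 16)`` matrix (export path assumed as before):

```python
import jax.numpy as jnp
import brainstate as bst

w = bst.init.DeltaOrthogonal()((3, 3, 8, 16))
assert w.shape == (3, 3, 8, 16)
assert not jnp.any(w[0]) and not jnp.any(w[2])  # off-centre taps are all zero
assert jnp.any(w[1, 1])                         # centre tap carries the matrix
```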
@@ -0,0 +1,109 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ import jax.numpy as jnp
+
+ from brainstate import environ
+ from ._base import Initializer, to_size
+
+ __all__ = [
+   'ZeroInit',
+   'Constant',
+   'Identity',
+ ]
+
+
+ class ZeroInit(Initializer):
+   """Zero initializer.
+
+   Initialize the weights with zeros.
+   """
+
+   def __init__(self, dtype=None):
+     super(ZeroInit, self).__init__()
+     self.dtype = dtype or environ.dftype()
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     return jnp.zeros(shape, dtype=self.dtype)
+
+   def __repr__(self):
+     return f"{self.__class__.__name__}(dtype={self.dtype})"
+
+
+ class Constant(Initializer):
+   """Constant initializer.
+
+   Initialize the weights with the given value.
+
+   Parameters
+   ----------
+   value : float, int, jnp.ndarray
+     The value to fill the weights with.
+   """
+
+   def __init__(self, value=1., dtype=None):
+     super(Constant, self).__init__()
+     self.dtype = dtype or environ.dftype()
+     self.value = jnp.asarray(value, dtype=self.dtype)
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     return jnp.full(shape, self.value, dtype=self.dtype)
+
+   def __repr__(self):
+     return f'{self.__class__.__name__}(value={self.value}, dtype={self.dtype})'
+
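A minimal usage sketch for the two initializers above (export via `brainstate.init` assumed, as before):

```python
import brainstate as bst

b = bst.init.ZeroInit()((256,))           # all-zero bias vector
w = bst.init.Constant(value=0.5)((4, 4))  # every entry equals 0.5
```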
+
+ class Identity(Initializer):
+   """Initialize weights with a (scaled) identity matrix.
+
+   This initializer was proposed in (Le, et al., 2015) [1]_.
+
+   Parameters
+   ----------
+   value : float
+     The optional scaling factor applied to the identity matrix.
+
+   References
+   ----------
+   .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to
+          initialize recurrent networks of rectified linear units." arXiv preprint
+          arXiv:1504.00941 (2015).
+   """
+
+   def __init__(self, value=1., dtype=None):
+     super(Identity, self).__init__()
+     self.dtype = dtype or environ.dftype()
+     self.value = jnp.asarray(value, dtype=self.dtype)
+
+   def __call__(self, shape):
+     shape = to_size(shape)
+     if isinstance(shape, (tuple, list)):
+       if len(shape) > 2:
+         raise ValueError(f'Only support initializing 2D weights for {self.__class__.__name__}.')
+       # ``jnp.eye`` takes row/column counts, not a shape tuple.
+       r = jnp.eye(*shape, dtype=self.dtype)
+     else:
+       r = jnp.eye(shape, dtype=self.dtype)
+     # Scale the diagonal by multiplication; off-diagonal entries stay zero.
+     return r * self.value
+
+   def __repr__(self):
+     return f'{self.__class__.__name__}(value={self.value}, dtype={self.dtype})'
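A usage sketch for the (scaled) identity initializer; non-square shapes yield a rectangular matrix with ``value`` on the main diagonal:

```python
import jax.numpy as jnp
import brainstate as bst

w = bst.init.Identity(value=2.)((3, 3))
assert jnp.allclose(w, 2. * jnp.eye(3))
```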