brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240612__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +43 -5
- brainstate/_state.py +17 -0
- brainstate/environ.py +2 -1
- brainstate/functional/__init__.py +3 -2
- brainstate/functional/_activations.py +1 -1
- brainstate/functional/_normalization.py +3 -0
- brainstate/functional/_others.py +49 -0
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_base.py +10 -7
- brainstate/nn/_dynamics.py +20 -0
- brainstate/nn/_embedding.py +66 -0
- brainstate/nn/_rate_rnns.py +17 -0
- brainstate/nn/_readout.py +6 -0
- brainstate/optim/_lr_scheduler_test.py +13 -0
- brainstate/transform/_jit.py +47 -21
- brainstate/transform/_make_jaxpr.py +165 -3
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/METADATA +8 -6
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/RECORD +21 -29
- brainstate/nn/functional/__init__.py +0 -25
- brainstate/nn/functional/_activations.py +0 -754
- brainstate/nn/functional/_normalization.py +0 -69
- brainstate/nn/functional/_spikes.py +0 -90
- brainstate/nn/init/__init__.py +0 -26
- brainstate/nn/init/_base.py +0 -36
- brainstate/nn/init/_generic.py +0 -175
- brainstate/nn/init/_random_inits.py +0 -489
- brainstate/nn/init/_regular_inits.py +0 -109
- brainstate/nn/surrogate.py +0 -1740
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/top_level.txt +0 -0
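Judging from the file lists above, this post-release flattens the brainstate.nn sub-packages: brainstate/nn/functional/, brainstate/nn/init/, and brainstate/nn/surrogate.py are deleted, while the top-level brainstate/functional/ package gains a new _others.py and small edits. A minimal migration sketch, assuming the removed utilities were re-homed one level up (this diff confirms the removals but not the exact new import paths):

# Hypothetical migration sketch. The removed paths are confirmed by this diff;
# the replacement paths are an assumption inferred from the changed files.

# brainstate 0.0.1:
#   import brainstate.nn.functional   # deleted in 0.0.1.post20240612
#   import brainstate.nn.init         # deleted in 0.0.1.post20240612

# brainstate 0.0.1.post20240612 (assumed new home):
import brainstate.functional          # brainstate/functional/* is extended in this diff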
brainstate/nn/init/_random_inits.py
@@ -1,489 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# -*- coding: utf-8 -*-
-
-import math
-
-import jax.numpy as jnp
-import numpy as np
-
-from brainstate import environ, random
-from ._base import Initializer, to_size
-
-__all__ = [
-  'Normal',
-  'TruncatedNormal',
-  'Uniform',
-  'VarianceScaling',
-  'KaimingUniform',
-  'KaimingNormal',
-  'XavierUniform',
-  'XavierNormal',
-  'LecunUniform',
-  'LecunNormal',
-  'Orthogonal',
-  'DeltaOrthogonal',
-]
-
-
-def calculate_gain(nonlinearity, param=None):
-  r"""Return the recommended gain value for the given nonlinearity function.
-
-  The values are as follows:
-
-  ================= ====================================================
-  nonlinearity      gain
-  ================= ====================================================
-  Linear / Identity :math:`1`
-  Conv{1,2,3}D      :math:`1`
-  Sigmoid           :math:`1`
-  Tanh              :math:`\frac{5}{3}`
-  ReLU              :math:`\sqrt{2}`
-  Leaky ReLU        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
-  SELU              :math:`\frac{3}{4}`
-  ================= ====================================================
-
-  .. warning::
-    In order to implement `Self-Normalizing Neural Networks`_ ,
-    you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
-    This gives the initial weights a variance of ``1 / N``,
-    which is necessary to induce a stable fixed point in the forward pass.
-    In contrast, the default gain for ``SELU`` sacrifices the normalisation
-    effect for more stable gradient flow in rectangular layers.
-
-  Args:
-    nonlinearity: the non-linear function (`nn.functional` name)
-    param: optional parameter for the non-linear function
-
-  .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
-  """
-  linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
-  if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
-    return 1
-  elif nonlinearity == 'tanh':
-    return 5.0 / 3
-  elif nonlinearity == 'relu':
-    return math.sqrt(2.0)
-  elif nonlinearity == 'leaky_relu':
-    if param is None:
-      negative_slope = 0.01
-    elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
-      # True/False are instances of int, hence the bool check above
-      negative_slope = param
-    else:
-      raise ValueError("negative_slope {} not a valid number".format(param))
-    return math.sqrt(2.0 / (1 + negative_slope ** 2))
-  elif nonlinearity == 'selu':
-    return 3.0 / 4
-  else:
-    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
-
-
-def _format_shape(shape):
-  if isinstance(shape, int):
-    return (shape,)
-  if len(shape) == 0:
-    raise ValueError('Please provide shape.')
-  if len(shape) == 1:
-    if isinstance(shape[0], (tuple, list)):
-      return shape[0]
-    else:
-      return shape
-  else:
-    return shape
-
-
-def _compute_fans(shape, in_axis=-2, out_axis=-1):
-  receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
-  fan_in = shape[in_axis] * receptive_field_size
-  fan_out = shape[out_axis] * receptive_field_size
-  return fan_in, fan_out
-
-
-class Normal(Initializer):
-  """Initialize weights with a normal distribution.
-
-  Parameters
-  ----------
-  mean : float
-    The mean of the normal distribution.
-  scale : float
-    The standard deviation of the normal distribution.
-  """
-
-  def __init__(self, mean=0., scale=1., dtype=None):
-    super(Normal, self).__init__()
-    self.scale = scale
-    self.mean = mean
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    weights = random.normal(size=shape, loc=self.mean, scale=self.scale, dtype=self.dtype)
-    return weights
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(mean={self.mean}, scale={self.scale}, dtype={self.dtype})'
-
-
-class TruncatedNormal(Initializer):
-  """Initialize weights with a truncated normal distribution.
-
-  Parameters
-  ----------
-  loc : float, ndarray
-    Mean ("centre") of the distribution before truncating. Note that
-    the mean of the truncated distribution will not be exactly equal
-    to ``loc``.
-  scale : float
-    The standard deviation of the normal distribution before truncating.
-  lower : float, ndarray
-    A float or array of floats representing the lower bound for
-    truncation. Must be broadcast-compatible with ``upper``.
-  upper : float, ndarray
-    A float or array of floats representing the upper bound for
-    truncation. Must be broadcast-compatible with ``lower``.
-  """
-
-  def __init__(self, loc=0., scale=1., lower=None, upper=None, dtype=None):
-    super(TruncatedNormal, self).__init__()
-    assert scale > 0, '`scale` must be positive.'
-    self.scale = scale
-    self.loc = loc
-    self.lower = lower
-    self.upper = upper
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    weights = random.truncated_normal(
-      size=shape,
-      scale=self.scale,
-      lower=self.lower,
-      upper=self.upper,
-      loc=self.loc,
-      dtype=self.dtype
-    )
-    return weights
-
-  def __repr__(self):
-    return (f'{self.__class__.__name__}(loc={self.loc}, scale={self.scale}, '
-            f'lower={self.lower}, upper={self.upper}, dtype={self.dtype})')
-
-
-class Gamma(Initializer):
-  """Initialize weights with a Gamma distribution.
-
-  Parameters
-  ----------
-  shape : float, Array
-    The shape parameter of the Gamma distribution.
-  scale : float, Array
-    The scale parameter of the Gamma distribution.
-  """
-
-  def __init__(self, shape, scale=None, dtype=None):
-    self.shape = shape
-    self.scale = scale
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    weights = random.gamma(self.shape, scale=self.scale, size=shape, dtype=self.dtype)
-    return weights
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(shape={self.shape}, scale={self.scale}, dtype={self.dtype})'
-
-
-class Exponential(Initializer):
-  """Initialize weights with an exponential distribution.
-
-  Parameters
-  ----------
-  scale : float, Array
-    The scale parameter of the exponential distribution.
-  """
-
-  def __init__(self, scale=None, dtype=None):
-    self.scale = scale
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    weights = random.exponential(scale=self.scale, size=shape, dtype=self.dtype)
-    return weights
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(scale={self.scale}, dtype={self.dtype})'
-
-
-class Uniform(Initializer):
-  """Initialize weights with a uniform distribution.
-
-  Parameters
-  ----------
-  min_val : float
-    The lower limit of the uniform distribution.
-  max_val : float
-    The upper limit of the uniform distribution.
-  """
-
-  def __init__(self, min_val: float = 0., max_val: float = 1., dtype=None):
-    super(Uniform, self).__init__()
-    self.min_val = min_val
-    self.max_val = max_val
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    return random.uniform(low=self.min_val, high=self.max_val, size=shape, dtype=self.dtype)
-
-  def __repr__(self):
-    return (f'{self.__class__.__name__}(min_val={self.min_val}, '
-            f'max_val={self.max_val}, dtype={self.dtype})')
-
-
-class VarianceScaling(Initializer):
-  def __init__(
-      self,
-      scale: float,
-      mode: str,
-      distribution: str,
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    assert mode in ['fan_in', 'fan_out', 'fan_avg']
-    assert distribution in ['truncated_normal', 'normal', 'uniform']
-    self.scale = scale
-    self.mode = mode
-    self.in_axis = in_axis
-    self.out_axis = out_axis
-    self.distribution = distribution
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
-    if self.mode == "fan_in":
-      denominator = fan_in
-    elif self.mode == "fan_out":
-      denominator = fan_out
-    elif self.mode == "fan_avg":
-      denominator = (fan_in + fan_out) / 2
-    else:
-      raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
-    variance = (self.scale / denominator).astype(self.dtype)
-    if self.distribution == "truncated_normal":
-      stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
-      res = random.truncated_normal(-2, 2, shape, dtype=self.dtype) * stddev
-    elif self.distribution == "normal":
-      res = random.randn(*shape, dtype=self.dtype) * jnp.sqrt(variance).astype(self.dtype)
-    elif self.distribution == "uniform":
-      res = (random.uniform(low=-1, high=1, size=shape, dtype=self.dtype) *
-             jnp.sqrt(3 * variance).astype(self.dtype))
-    else:
-      raise ValueError("invalid distribution for variance scaling initializer")
-    return res
-
-  def __repr__(self):
-    name = self.__class__.__name__
-    blank = ' ' * len(name)
-    return (f'{name}(scale={self.scale}, mode={self.mode}, in_axis={self.in_axis}, \n'
-            f'{blank}out_axis={self.out_axis}, distribution={self.distribution}, dtype={self.dtype})')
-
-
-class KaimingUniform(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 2.0,
-      mode: str = "fan_in",
-      distribution: str = "uniform",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
-
-
-class KaimingNormal(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 2.0,
-      mode: str = "fan_in",
-      distribution: str = "truncated_normal",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
-
-
-class XavierUniform(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 1.0,
-      mode: str = "fan_avg",
-      distribution: str = "uniform",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
-
-
-class XavierNormal(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 1.0,
-      mode: str = "fan_avg",
-      distribution: str = "truncated_normal",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
-
-
-class LecunUniform(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 1.0,
-      mode: str = "fan_in",
-      distribution: str = "uniform",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
-
-
-class LecunNormal(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 1.0,
-      mode: str = "fan_in",
-      distribution: str = "truncated_normal",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
-
-
-class Orthogonal(Initializer):
-  """
-  Construct an initializer for uniformly distributed orthogonal matrices.
-
-  If the shape is not square, the matrix will have orthonormal rows or columns
-  depending on which side is smaller.
-  """
-
-  def __init__(
-      self,
-      scale: float = 1.,
-      axis: int = -1,
-      dtype=None
-  ):
-    super().__init__()
-    self.scale = scale
-    self.axis = axis
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    n_rows = shape[self.axis]
-    n_cols = np.prod(shape) // n_rows
-    matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
-    norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)
-    q_mat, r_mat = jnp.linalg.qr(norm_dst)
-    # Enforce Q is uniformly distributed
-    q_mat *= jnp.sign(jnp.diag(r_mat))
-    if n_rows < n_cols:
-      q_mat = q_mat.T
-    q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
-    q_mat = jnp.moveaxis(q_mat, 0, self.axis)
-    return jnp.asarray(self.scale, dtype=self.dtype) * q_mat
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
-
-
-class DeltaOrthogonal(Initializer):
-  """
-  Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
-
-  The shape must be 3D, 4D or 5D.
-  """
-
-  def __init__(self, scale=1.0, axis=-1, dtype=None):
-    super(DeltaOrthogonal, self).__init__()
-    self.scale = scale
-    self.axis = axis
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    if len(shape) not in [3, 4, 5]:
-      raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
-    if shape[-1] < shape[-2]:
-      raise ValueError("`fan_in` must be less than or equal to `fan_out`.")
-    ortho_matrix = Orthogonal(scale=self.scale, axis=self.axis, dtype=self.dtype)(shape[-2:])
-    W = jnp.zeros(shape, dtype=self.dtype)
-    if len(shape) == 3:
-      k = shape[0]
-      W = W.at[(k - 1) // 2].set(ortho_matrix)
-    elif len(shape) == 4:
-      k1, k2 = shape[:2]
-      W = W.at[(k1 - 1) // 2, (k2 - 1) // 2].set(ortho_matrix)
-    else:
-      k1, k2, k3 = shape[:3]
-      W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
-    return W
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
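For reference, each removed initializer is a plain callable that maps a shape to a weight array, as the __call__ signatures above show. A minimal usage sketch against brainstate 0.0.1, the last version that still ships brainstate.nn.init:

# Usage sketch for the deleted random initializers; assumes brainstate 0.0.1 is
# installed, where brainstate.nn.init still exists.
from brainstate.nn.init import KaimingUniform, Orthogonal

w = KaimingUniform()((256, 128))     # fan_in-scaled uniform weights of shape (256, 128)
q = Orthogonal(scale=1.0)((64, 64))  # 64x64 matrix with orthonormal rows/columns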
brainstate/nn/init/_regular_inits.py
@@ -1,109 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# -*- coding: utf-8 -*-
-
-import jax.numpy as jnp
-
-from brainstate import environ
-from ._base import Initializer, to_size
-
-__all__ = [
-  'ZeroInit',
-  'Constant',
-  'Identity',
-]
-
-
-class ZeroInit(Initializer):
-  """Zero initializer.
-
-  Initialize the weights with zeros.
-  """
-
-  def __init__(self, dtype=None):
-    super(ZeroInit, self).__init__()
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    return jnp.zeros(shape, dtype=self.dtype)
-
-  def __repr__(self):
-    return f"{self.__class__.__name__}(dtype={self.dtype})"
-
-
-class Constant(Initializer):
-  """Constant initializer.
-
-  Initialize the weights with the given value.
-
-  Parameters
-  ----------
-  value : float, int, jnp.ndarray
-    The value to fill the weights with.
-  """
-
-  def __init__(self, value=1., dtype=None):
-    super(Constant, self).__init__()
-    self.dtype = dtype or environ.dftype()
-    self.value = jnp.asarray(value, dtype=self.dtype)
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    return jnp.full(shape, self.value, dtype=self.dtype)
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(value={self.value}, dtype={self.dtype})'
-
-
-class Identity(Initializer):
-  """Identity-matrix initializer.
-
-  This initializer was proposed in (Le, et al., 2015) [1]_.
-
-  Parameters
-  ----------
-  value : float
-    The optional scaling factor applied to the diagonal.
-
-  Returns
-  -------
-  weight : jnp.ndarray
-    A 2D identity matrix scaled by ``value``.
-
-  References
-  ----------
-  .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to
-         initialize recurrent networks of rectified linear units." arXiv preprint
-         arXiv:1504.00941 (2015).
-  """
-
-  def __init__(self, value=1., dtype=None):
-    super(Identity, self).__init__()
-    self.dtype = dtype or environ.dftype()
-    self.value = jnp.asarray(value, dtype=self.dtype)
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    if isinstance(shape, (tuple, list)):
-      if len(shape) > 2:
-        raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.')
-      r = jnp.eye(*shape, dtype=self.dtype)
-    else:
-      r = jnp.eye(shape, dtype=self.dtype)
-    return r * self.value
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(value={self.value}, dtype={self.dtype})'
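The regular initializers removed in this second hunk follow the same callable protocol. A short sketch of their intended use, again assuming brainstate 0.0.1 is installed:

# Usage sketch for the deleted regular initializers (brainstate 0.0.1 paths).
from brainstate.nn.init import ZeroInit, Constant, Identity

b = ZeroInit()((128,))            # all-zero vector of length 128
w = Constant(value=0.5)((4, 4))   # 4x4 array with every entry 0.5
r = Identity(value=2.0)((3, 3))   # 3x3 matrix with 2.0 on the diagonal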