brainstate 0.0.2__py2.py3-none-any.whl → 0.0.2.post20240824__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/init/_random_inits.py +9 -9
- {brainstate-0.0.2.dist-info → brainstate-0.0.2.post20240824.dist-info}/METADATA +2 -2
- {brainstate-0.0.2.dist-info → brainstate-0.0.2.post20240824.dist-info}/RECORD +6 -6
- {brainstate-0.0.2.dist-info → brainstate-0.0.2.post20240824.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.dist-info → brainstate-0.0.2.post20240824.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.dist-info → brainstate-0.0.2.post20240824.dist-info}/top_level.txt +0 -0
brainstate/init/_random_inits.py
CHANGED
@@ -289,8 +289,8 @@ class VarianceScaling(Initializer):
|
|
289
289
|
denominator = (fan_in + fan_out) / 2
|
290
290
|
else:
|
291
291
|
raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
|
292
|
-
scale = self.scale.
|
293
|
-
|
292
|
+
scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
|
293
|
+
unit = bu.get_unit(self.scale)
|
294
294
|
variance = (scale / denominator).astype(self.dtype)
|
295
295
|
if self.distribution == "truncated_normal":
|
296
296
|
stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
|
@@ -302,7 +302,7 @@ class VarianceScaling(Initializer):
|
|
302
302
|
jnp.sqrt(3 * variance).astype(self.dtype))
|
303
303
|
else:
|
304
304
|
raise ValueError("invalid distribution for variance scaling initializer")
|
305
|
-
return res if
|
305
|
+
return res if unit.is_unitless else bu.Quantity(res, unit=unit)
|
306
306
|
|
307
307
|
def __repr__(self):
|
308
308
|
name = self.__class__.__name__
|
@@ -445,8 +445,8 @@ class Orthogonal(Initializer):
|
|
445
445
|
matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
|
446
446
|
norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)
|
447
447
|
|
448
|
-
scale = self.scale.
|
449
|
-
|
448
|
+
scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
|
449
|
+
unit = bu.get_unit(self.scale)
|
450
450
|
q_mat, r_mat = jnp.linalg.qr(norm_dst)
|
451
451
|
# Enforce Q is uniformly distributed
|
452
452
|
q_mat *= jnp.sign(jnp.diag(r_mat))
|
@@ -455,7 +455,7 @@ class Orthogonal(Initializer):
|
|
455
455
|
q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
|
456
456
|
q_mat = jnp.moveaxis(q_mat, 0, self.axis)
|
457
457
|
r = jnp.asarray(scale, dtype=self.dtype) * q_mat
|
458
|
-
return r if
|
458
|
+
return r if unit.is_unitless else bu.Quantity(r, unit=unit)
|
459
459
|
|
460
460
|
def __repr__(self):
|
461
461
|
return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
|
@@ -480,8 +480,8 @@ class DeltaOrthogonal(Initializer):
|
|
480
480
|
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
|
481
481
|
if shape[-1] < shape[-2]:
|
482
482
|
raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
|
483
|
-
scale = self.scale.
|
484
|
-
|
483
|
+
scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
|
484
|
+
unit = bu.get_unit(self.scale)
|
485
485
|
ortho_matrix = Orthogonal(scale=scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
|
486
486
|
W = jnp.zeros(shape, dtype=self.dtype)
|
487
487
|
if len(shape) == 3:
|
@@ -493,7 +493,7 @@ class DeltaOrthogonal(Initializer):
|
|
493
493
|
else:
|
494
494
|
k1, k2, k3 = shape[:3]
|
495
495
|
W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
|
496
|
-
return W if
|
496
|
+
return W if unit.is_unitless else bu.Quantity(W, unit=unit)
|
497
497
|
|
498
498
|
def __repr__(self):
|
499
499
|
return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.0.2
|
3
|
+
Version: 0.0.2.post20240824
|
4
4
|
Summary: A State-based Transformation System for Brain Dynamics Programming.
|
5
5
|
Home-page: https://github.com/brainpy/brainstate
|
6
6
|
Author: BDP
|
@@ -35,7 +35,7 @@ Requires-Dist: brainunit
|
|
35
35
|
Provides-Extra: cpu
|
36
36
|
Requires-Dist: jaxlib ; extra == 'cpu'
|
37
37
|
Provides-Extra: cuda12
|
38
|
-
Requires-Dist: jaxlib[
|
38
|
+
Requires-Dist: jaxlib[cuda12] ; extra == 'cuda12'
|
39
39
|
Provides-Extra: testing
|
40
40
|
Requires-Dist: pytest ; extra == 'testing'
|
41
41
|
Provides-Extra: tpu
|
@@ -21,7 +21,7 @@ brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il
|
|
21
21
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
22
22
|
brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
|
23
23
|
brainstate/init/_generic.py,sha256=LB7IQfswOG6X-q0QX5N8T5vZmxdygetsSBQ6iXlZ0oU,7324
|
24
|
-
brainstate/init/_random_inits.py,sha256=
|
24
|
+
brainstate/init/_random_inits.py,sha256=vNUVDdUOCXTx2i3i1enzxgg1USCzugYd56r0-2lBL-0,15919
|
25
25
|
brainstate/init/_regular_inits.py,sha256=u77aSM0BkK9VULFJQZ1lIEYA_sJJzEZBTEttBSJ79RI,3090
|
26
26
|
brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
|
27
27
|
brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
|
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
|
|
59
59
|
brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
|
60
60
|
brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
|
61
61
|
brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
|
62
|
-
brainstate-0.0.2.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
63
|
-
brainstate-0.0.2.dist-info/METADATA,sha256=
|
64
|
-
brainstate-0.0.2.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
65
|
-
brainstate-0.0.2.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
66
|
-
brainstate-0.0.2.dist-info/RECORD,,
|
62
|
+
brainstate-0.0.2.post20240824.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
63
|
+
brainstate-0.0.2.post20240824.dist-info/METADATA,sha256=XmDSiVoXh250MvzBm2tJGfcNpRQ4FUrQTJsCvphGSSA,3801
|
64
|
+
brainstate-0.0.2.post20240824.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
65
|
+
brainstate-0.0.2.post20240824.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
66
|
+
brainstate-0.0.2.post20240824.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|