brainstate 0.0.2__py2.py3-none-any.whl → 0.0.2.post20240814__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -289,8 +289,8 @@ class VarianceScaling(Initializer):
289
289
  denominator = (fan_in + fan_out) / 2
290
290
  else:
291
291
  raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
292
- scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
293
- dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
292
+ scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
293
+ unit = bu.get_unit(self.scale)
294
294
  variance = (scale / denominator).astype(self.dtype)
295
295
  if self.distribution == "truncated_normal":
296
296
  stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
@@ -302,7 +302,7 @@ class VarianceScaling(Initializer):
302
302
  jnp.sqrt(3 * variance).astype(self.dtype))
303
303
  else:
304
304
  raise ValueError("invalid distribution for variance scaling initializer")
305
- return res if dim == bu.DIMENSIONLESS else res * dim
305
+ return res if unit.is_unitless else bu.Quantity(res, unit=unit)
306
306
 
307
307
  def __repr__(self):
308
308
  name = self.__class__.__name__
@@ -445,8 +445,8 @@ class Orthogonal(Initializer):
445
445
  matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
446
446
  norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)
447
447
 
448
- scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
449
- dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
448
+ scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
449
+ unit = bu.get_unit(self.scale)
450
450
  q_mat, r_mat = jnp.linalg.qr(norm_dst)
451
451
  # Enforce Q is uniformly distributed
452
452
  q_mat *= jnp.sign(jnp.diag(r_mat))
@@ -455,7 +455,7 @@ class Orthogonal(Initializer):
455
455
  q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
456
456
  q_mat = jnp.moveaxis(q_mat, 0, self.axis)
457
457
  r = jnp.asarray(scale, dtype=self.dtype) * q_mat
458
- return r if dim == bu.DIMENSIONLESS else r * dim
458
+ return r if unit.is_unitless else bu.Quantity(r, unit=unit)
459
459
 
460
460
  def __repr__(self):
461
461
  return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
@@ -480,8 +480,8 @@ class DeltaOrthogonal(Initializer):
480
480
  raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
481
481
  if shape[-1] < shape[-2]:
482
482
  raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
483
- scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
484
- dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
483
+ scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
484
+ unit = bu.get_unit(self.scale)
485
485
  ortho_matrix = Orthogonal(scale=scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
486
486
  W = jnp.zeros(shape, dtype=self.dtype)
487
487
  if len(shape) == 3:
@@ -493,7 +493,7 @@ class DeltaOrthogonal(Initializer):
493
493
  else:
494
494
  k1, k2, k3 = shape[:3]
495
495
  W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
496
- return W if dim == bu.DIMENSIONLESS else W * dim
496
+ return W if unit.is_unitless else bu.Quantity(W, unit=unit)
497
497
 
498
498
  def __repr__(self):
499
499
  return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.0.2
3
+ Version: 0.0.2.post20240814
4
4
  Summary: A State-based Transformation System for Brain Dynamics Programming.
5
5
  Home-page: https://github.com/brainpy/brainstate
6
6
  Author: BDP
@@ -21,7 +21,7 @@ brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il
21
21
  brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
22
22
  brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
23
23
  brainstate/init/_generic.py,sha256=LB7IQfswOG6X-q0QX5N8T5vZmxdygetsSBQ6iXlZ0oU,7324
24
- brainstate/init/_random_inits.py,sha256=LsfvKSX4wsR7Kh5jgKgdyXTCEEa5Nn_iYcp_2GgLQKY,16030
24
+ brainstate/init/_random_inits.py,sha256=vNUVDdUOCXTx2i3i1enzxgg1USCzugYd56r0-2lBL-0,15919
25
25
  brainstate/init/_regular_inits.py,sha256=u77aSM0BkK9VULFJQZ1lIEYA_sJJzEZBTEttBSJ79RI,3090
26
26
  brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
27
27
  brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
59
59
  brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
60
60
  brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
61
61
  brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
62
- brainstate-0.0.2.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
- brainstate-0.0.2.dist-info/METADATA,sha256=K6yiVOqGj3Qs_vKGgQmFXZtlu8cS4r7EZXl_iyCjwh0,3792
64
- brainstate-0.0.2.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
- brainstate-0.0.2.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
- brainstate-0.0.2.dist-info/RECORD,,
62
+ brainstate-0.0.2.post20240814.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
+ brainstate-0.0.2.post20240814.dist-info/METADATA,sha256=skMWlfxiGaJHxzQS7dY95V91umhXVjN_HuhGO0xHP1M,3805
64
+ brainstate-0.0.2.post20240814.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
+ brainstate-0.0.2.post20240814.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
+ brainstate-0.0.2.post20240814.dist-info/RECORD,,