bayinx 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -89,6 +89,7 @@ class Variational(eqx.Module):
89
89
  max_iters: int,
90
90
  data: Any = None,
91
91
  learning_rate: float = 1,
92
+ weight_decay: float = 1e-4,
92
93
  tolerance: float = 1e-4,
93
94
  var_draws: int = 1,
94
95
  key: Key = jr.PRNGKey(0),
@@ -114,7 +115,7 @@ class Variational(eqx.Module):
114
115
 
115
116
  # Initialize optimizer
116
117
  optim: GradientTransformation = opx.chain(
117
- opx.scale(-1.0), opx.nadam(schedule)
118
+ opx.scale(-1.0), opx.nadamw(schedule,weight_decay=weight_decay)
118
119
  )
119
120
  opt_state: OptState = optim.init(dyn)
120
121
 
@@ -1,3 +1,3 @@
1
- from bayinx.mhx.vi.flows.affine import Affine as Affine
1
+ from bayinx.mhx.vi.flows.fullaffine import FullAffine as FullAffine
2
2
  from bayinx.mhx.vi.flows.planar import Planar as Planar
3
3
  from bayinx.mhx.vi.flows.radial import Radial as Radial
@@ -9,7 +9,7 @@ from jaxtyping import Array, Float, Scalar
9
9
  from bayinx.core import Flow
10
10
 
11
11
 
12
- class Affine(Flow):
12
+ class FullAffine(Flow):
13
13
  """
14
14
  An affine flow.
15
15
 
@@ -3,6 +3,7 @@ from typing import Callable, Dict, Tuple
3
3
 
4
4
  import equinox as eqx
5
5
  import jax
6
+ import jax.nn as jnn
6
7
  import jax.numpy as jnp
7
8
  import jax.random as jr
8
9
  from jaxtyping import Array, Float, Scalar
@@ -36,10 +37,23 @@ class Planar(Flow):
36
37
  }
37
38
  self.constraints = {}
38
39
 
40
+ def transform_pars(self):
41
+ params = self.params
42
+
43
+ u = params['u']
44
+ w = params['w']
45
+ b = params['b']
46
+
47
+ m = jnn.softplus(w.dot(u)) - 1.0
48
+
49
+ u = u + (m - w.dot(u)) * w / (w**2).sum()
50
+
51
+ return {'u': u, 'w': w, 'b': b}
52
+
39
53
  @eqx.filter_jit
40
54
  @partial(jax.vmap, in_axes=(None, 0))
41
55
  def forward(self, draws: Array) -> Array:
42
- params = self.constrain_pars()
56
+ params = self.transform_pars()
43
57
 
44
58
  # Extract parameters
45
59
  w: Array = params["w"]
@@ -54,7 +68,7 @@ class Planar(Flow):
54
68
  @eqx.filter_jit
55
69
  @partial(jax.vmap, in_axes=(None, 0))
56
70
  def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
57
- params = self.constrain_pars()
71
+ params = self.transform_pars()
58
72
 
59
73
  # Extract parameters
60
74
  w: Array = params["w"]
@@ -115,7 +115,7 @@ class NormalizingFlow(Variational):
115
115
  return filter_spec
116
116
 
117
117
  @eqx.filter_jit
118
- def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
118
+ def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
119
119
  dyn, static = eqx.partition(self, self.filter_spec())
120
120
 
121
121
  @eqx.filter_jit
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.9
3
+ Version: 0.2.11
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -4,7 +4,7 @@ bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
4
4
  bayinx/core/flow.py,sha256=oZE0OHCninIHjp-WVLFWd1DaN0-qXxNWFAUAdgIDmRU,2423
5
5
  bayinx/core/model.py,sha256=-rT3NHjxqGB0lDBMi0Mr9XNOz1_TUnJWtd4ITj0rsus,2257
6
6
  bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
7
- bayinx/core/variational.py,sha256=yhraTVlNOJaU1NEYVrWpUXVzzWvY1Mq9ZOZv6V0_Vo0,5183
7
+ bayinx/core/variational.py,sha256=3CsDyQkq1XgV2ZBLzGrm5XgUFoJBnT6glHDgxHNcbTc,5250
8
8
  bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
10
10
  bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -14,13 +14,13 @@ bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
14
14
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
15
15
  bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
16
16
  bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
17
- bayinx/mhx/vi/normalizing_flow.py,sha256=V0-R2Nc_3Vy5c2rkjDOwIutA0G-_kN6trxjwsT5FgsA,4774
17
+ bayinx/mhx/vi/normalizing_flow.py,sha256=XBUWYZpm_Ipi6X9oTnGhqIs3ARY-5QFiuxM7uAWFRps,4790
18
18
  bayinx/mhx/vi/standard.py,sha256=m5gtcHfrYzV28h-Red3Zn6SxEgJlndeIXiIG5gDPecU,1703
19
- bayinx/mhx/vi/flows/__init__.py,sha256=V_Ng5cecKlLlFSI9ncmaiyvoy_d2EAfeDhBFcy5aQhA,168
20
- bayinx/mhx/vi/flows/affine.py,sha256=a205nNx6KRvOwGlnjI6YeDo7OTWPPIxffGZfAcTecNA,1707
21
- bayinx/mhx/vi/flows/planar.py,sha256=0BGdMm-GpTCJnxq9cOrgLl8IsHgGIL0eSFagWJNVdqQ,1944
19
+ bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
20
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=2QbOtA1Jmu-yRcJeFmCKc8N1atm8G7JXYMLEZaEXKV0,1711
21
+ bayinx/mhx/vi/flows/planar.py,sha256=u4heNqxpmfXsACsE3RH8XkBd04Emd8G674M3sSd7CxM,2232
22
22
  bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
23
23
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
24
- bayinx-0.2.9.dist-info/METADATA,sha256=xp6L_DdXPC-TMHV4SL5LdIuhzX8GizUlx2muMgSFcy0,3057
25
- bayinx-0.2.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
- bayinx-0.2.9.dist-info/RECORD,,
24
+ bayinx-0.2.11.dist-info/METADATA,sha256=TkIGXb5baSFeP5K4s1bynQ3QqNsR-DKePlAOV_vV0Nc,3058
25
+ bayinx-0.2.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
+ bayinx-0.2.11.dist-info/RECORD,,