bayinx 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,3 @@
1
- from bayinx.mhx.vi.flows.affine import Affine as Affine
1
+ from bayinx.mhx.vi.flows.fullaffine import FullAffine as FullAffine
2
2
  from bayinx.mhx.vi.flows.planar import Planar as Planar
3
3
  from bayinx.mhx.vi.flows.radial import Radial as Radial
@@ -9,7 +9,7 @@ from jaxtyping import Array, Float, Scalar
9
9
  from bayinx.core import Flow
10
10
 
11
11
 
12
- class Affine(Flow):
12
+ class FullAffine(Flow):
13
13
  """
14
14
  An affine flow.
15
15
 
@@ -30,16 +30,16 @@ class Planar(Flow):
30
30
  - `dim`: The dimension of the parameter space.
31
31
  """
32
32
  self.params = {
33
- "u": jr.normal(key, (dim,)),
34
- "w": jr.normal(key, (dim,)),
35
- "b": jr.normal(key, (1,)),
33
+ "u": jnp.ones(dim),
34
+ "w": jnp.ones(dim),
35
+ "b": jnp.zeros(1),
36
36
  }
37
37
  self.constraints = {}
38
38
 
39
39
  @eqx.filter_jit
40
40
  @partial(jax.vmap, in_axes=(None, 0))
41
41
  def forward(self, draws: Array) -> Array:
42
- params = self.constrain_pars()
42
+ params = self.transform_pars()
43
43
 
44
44
  # Extract parameters
45
45
  w: Array = params["w"]
@@ -54,7 +54,7 @@ class Planar(Flow):
54
54
  @eqx.filter_jit
55
55
  @partial(jax.vmap, in_axes=(None, 0))
56
56
  def adjust_density(self, draws: Array) -> Tuple[Scalar, Array]:
57
- params = self.constrain_pars()
57
+ params = self.transform_pars()
58
58
 
59
59
  # Extract parameters
60
60
  w: Array = params["w"]
@@ -115,7 +115,7 @@ class NormalizingFlow(Variational):
115
115
  return filter_spec
116
116
 
117
117
  @eqx.filter_jit
118
- def elbo(self, n: int, key: Key, data: Any = None) -> Scalar:
118
+ def elbo(self, n: int, key: Key = jr.PRNGKey(0), data: Any = None) -> Scalar:
119
119
  dyn, static = eqx.partition(self, self.filter_spec())
120
120
 
121
121
  @eqx.filter_jit
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.10
3
+ Version: 0.2.12
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -14,13 +14,13 @@ bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
14
14
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
15
15
  bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
16
16
  bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
17
- bayinx/mhx/vi/normalizing_flow.py,sha256=V0-R2Nc_3Vy5c2rkjDOwIutA0G-_kN6trxjwsT5FgsA,4774
17
+ bayinx/mhx/vi/normalizing_flow.py,sha256=XBUWYZpm_Ipi6X9oTnGhqIs3ARY-5QFiuxM7uAWFRps,4790
18
18
  bayinx/mhx/vi/standard.py,sha256=m5gtcHfrYzV28h-Red3Zn6SxEgJlndeIXiIG5gDPecU,1703
19
- bayinx/mhx/vi/flows/__init__.py,sha256=V_Ng5cecKlLlFSI9ncmaiyvoy_d2EAfeDhBFcy5aQhA,168
20
- bayinx/mhx/vi/flows/affine.py,sha256=a205nNx6KRvOwGlnjI6YeDo7OTWPPIxffGZfAcTecNA,1707
21
- bayinx/mhx/vi/flows/planar.py,sha256=0BGdMm-GpTCJnxq9cOrgLl8IsHgGIL0eSFagWJNVdqQ,1944
19
+ bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
20
+ bayinx/mhx/vi/flows/fullaffine.py,sha256=2QbOtA1Jmu-yRcJeFmCKc8N1atm8G7JXYMLEZaEXKV0,1711
21
+ bayinx/mhx/vi/flows/planar.py,sha256=qmtWpIBXRct2seI78pkmtF0X7cASUBELqmZmf2QS5Gs,1918
22
22
  bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
23
23
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
24
- bayinx-0.2.10.dist-info/METADATA,sha256=7Ej3pWMyQr0xLMmWb1WPhRDyxIQHiJ2sNfbTHkCCJ-E,3058
25
- bayinx-0.2.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
- bayinx-0.2.10.dist-info/RECORD,,
24
+ bayinx-0.2.12.dist-info/METADATA,sha256=q4e6XXwZ6ejyBWsyk_wXGDqJG9YCBK1gew93Pg_PncU,3058
25
+ bayinx-0.2.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
+ bayinx-0.2.12.dist-info/RECORD,,