bayinx 0.2.12__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -115,7 +115,7 @@ class Variational(eqx.Module):
115
115
 
116
116
  # Initialize optimizer
117
117
  optim: GradientTransformation = opx.chain(
118
- opx.scale(-1.0), opx.nadamw(schedule,weight_decay=weight_decay)
118
+ opx.scale(-1.0), opx.nadamw(schedule, weight_decay=weight_decay)
119
119
  )
120
120
  opt_state: OptState = optim.init(dyn)
121
121
 
bayinx/dists/normal.py CHANGED
@@ -1,17 +1,12 @@
1
- # MARK: Imports ----
2
1
  import jax.lax as _lax
2
+ from jaxtyping import Array, ArrayLike, Float, Real
3
3
 
4
- ## Typing
5
- from jaxtyping import Array, Real
6
-
7
- # MARK: Constants
8
4
  _PI = 3.141592653589793
9
5
 
10
6
 
11
- # MARK: Functions ----
12
7
  def prob(
13
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
14
- ) -> Real[Array, "..."]:
8
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
9
+ ) -> Float[Array, "..."]:
15
10
  """
16
11
  The probability density function (PDF) for a Normal distribution.
17
12
 
@@ -30,8 +25,8 @@ def prob(
30
25
 
31
26
 
32
27
  def logprob(
33
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
34
- ) -> Real[Array, "..."]:
28
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
29
+ ) -> Float[Array, "..."]:
35
30
  """
36
31
  The log of the probability density function (log PDF) for a Normal distribution.
37
32
 
@@ -48,8 +43,8 @@ def logprob(
48
43
 
49
44
 
50
45
  def uprob(
51
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
52
- ) -> Real[Array, "..."]:
46
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
47
+ ) -> Float[Array, "..."]:
53
48
  """
54
49
  The unnormalized probability density function (uPDF) for a Normal distribution.
55
50
 
@@ -66,8 +61,8 @@ def uprob(
66
61
 
67
62
 
68
63
  def ulogprob(
69
- x: Real[Array, "..."], mu: Real[Array, "..."], sigma: Real[Array, "..."]
70
- ) -> Real[Array, "..."]:
64
+ x: Real[ArrayLike, "..."], mu: Real[ArrayLike, "..."], sigma: Real[ArrayLike, "..."]
65
+ ) -> Float[Array, "..."]:
71
66
  """
72
67
  The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
73
68
 
bayinx/mhx/vi/standard.py CHANGED
@@ -52,7 +52,7 @@ class Standard(Variational):
52
52
  x=draws,
53
53
  mu=jnp.array(0.0),
54
54
  sigma=jnp.array(1.0),
55
- ).sum(axis=1)
55
+ ).sum(axis=1, keepdims=True)
56
56
 
57
57
  @eqx.filter_jit
58
58
  def filter_spec(self):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayinx
3
- Version: 0.2.12
3
+ Version: 0.2.18
4
4
  Summary: Bayesian Inference with JAX
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: equinox>=0.11.12
@@ -4,23 +4,23 @@ bayinx/core/__init__.py,sha256=7vW2F8t3K4TWlSu5nZrYCdUrz5N9FMIfQQBn3IoeH6o,150
4
4
  bayinx/core/flow.py,sha256=oZE0OHCninIHjp-WVLFWd1DaN0-qXxNWFAUAdgIDmRU,2423
5
5
  bayinx/core/model.py,sha256=-rT3NHjxqGB0lDBMi0Mr9XNOz1_TUnJWtd4ITj0rsus,2257
6
6
  bayinx/core/utils.py,sha256=-YewhqzMFL3GJEjVdm3LgaZyHwDs9IVYllU9wAXZrtw,1859
7
- bayinx/core/variational.py,sha256=3CsDyQkq1XgV2ZBLzGrm5XgUFoJBnT6glHDgxHNcbTc,5250
7
+ bayinx/core/variational.py,sha256=k9wWn7Tnj3eET-qK1pZtzDyPZVvQTRUexJUBVSdGXOA,5251
8
8
  bayinx/dists/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  bayinx/dists/bernoulli.py,sha256=xMV9BgtVX_1XkPdZ43q0meMIEkgMyuUPx--dyo6_DKs,1006
10
10
  bayinx/dists/binomial.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  bayinx/dists/gamma.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  bayinx/dists/gamma2.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- bayinx/dists/normal.py,sha256=e9gXXAHeZQKjBndW2TnMvP3gtmvpfYGG7kehcpGeAoU,2590
13
+ bayinx/dists/normal.py,sha256=OOKg46y5hHFP76ydbRjEXaDkgefZcj9sd0XAl7yokww,2587
14
14
  bayinx/mhx/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
15
15
  bayinx/mhx/vi/__init__.py,sha256=YfkXKsqo9Dk_AmQGjZKm4vfG8eLer2ez92G-cOExphs,193
16
16
  bayinx/mhx/vi/meanfield.py,sha256=LNLwfjKO9os7YBmRBpGEiFxlxonuN7DHVFEmXV3hvfI,3876
17
17
  bayinx/mhx/vi/normalizing_flow.py,sha256=XBUWYZpm_Ipi6X9oTnGhqIs3ARY-5QFiuxM7uAWFRps,4790
18
- bayinx/mhx/vi/standard.py,sha256=m5gtcHfrYzV28h-Red3Zn6SxEgJlndeIXiIG5gDPecU,1703
18
+ bayinx/mhx/vi/standard.py,sha256=HaJsIz70Qo1Ql2hMQ-GQhcnfWiOGtyxgkOsm_yQaDKI,1718
19
19
  bayinx/mhx/vi/flows/__init__.py,sha256=Hn0Wqvvyv8Vr-mFmimwgNKCByxj-fjrlIvdR7tUSolg,180
20
20
  bayinx/mhx/vi/flows/fullaffine.py,sha256=2QbOtA1Jmu-yRcJeFmCKc8N1atm8G7JXYMLEZaEXKV0,1711
21
21
  bayinx/mhx/vi/flows/planar.py,sha256=qmtWpIBXRct2seI78pkmtF0X7cASUBELqmZmf2QS5Gs,1918
22
22
  bayinx/mhx/vi/flows/radial.py,sha256=c-SWybGn_jmgBQk9ZMQ5uHKPzcdhowNp8MD8t1-8VZM,2501
23
23
  bayinx/mhx/vi/flows/sylvester.py,sha256=ppK0BmS_ThvrCEhJiP_-p-kj67TQHSlU_RUZpDbIhsQ,469
24
- bayinx-0.2.12.dist-info/METADATA,sha256=q4e6XXwZ6ejyBWsyk_wXGDqJG9YCBK1gew93Pg_PncU,3058
25
- bayinx-0.2.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
- bayinx-0.2.12.dist-info/RECORD,,
24
+ bayinx-0.2.18.dist-info/METADATA,sha256=849aiZhV6f578GKb62AfN_YgIoyJZ7FfA-fYw-Ogk_0,3058
25
+ bayinx-0.2.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
+ bayinx-0.2.18.dist-info/RECORD,,