bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
bayinx/dists/normal.py CHANGED
@@ -1,16 +1,25 @@
1
+ from typing import Tuple
2
+
3
+ import equinox as eqx
1
4
  import jax.lax as lax
2
5
  import jax.numpy as jnp
3
- import jax.scipy.special as jss
4
- from jaxtyping import Array, ArrayLike, Float
6
+ import jax.random as jr
7
+ import jax.scipy.special as jsp
8
+ import jax.tree as jt
9
+ from jaxtyping import Array, ArrayLike, PRNGKeyArray, Real, Scalar
10
+
11
+ from bayinx.core.distribution import Distribution
12
+ from bayinx.core.node import Node
13
+ from bayinx.nodes import Observed
5
14
 
6
- __PI = 3.141592653589793
15
+ PI = 3.141592653589793
7
16
 
8
17
 
9
18
  def prob(
10
- x: Float[ArrayLike, "..."],
11
- mu: Float[ArrayLike, "..."],
12
- sigma: Float[ArrayLike, "..."],
13
- ) -> Float[Array, "..."]:
19
+ x: Real[ArrayLike, "..."],
20
+ mu: Real[ArrayLike, "..."],
21
+ sigma: Real[ArrayLike, "..."],
22
+ ) -> Real[Array, "..."]:
14
23
  """
15
24
  The probability density function (PDF) for a Normal distribution.
16
25
 
@@ -25,14 +34,14 @@ def prob(
25
34
  # Cast to Array
26
35
  x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
27
36
 
28
- return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (sigma * lax.sqrt(2.0 * __PI))
37
+ return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / (sigma * lax.sqrt(2.0 * PI))
29
38
 
30
39
 
31
40
  def logprob(
32
- x: Float[ArrayLike, "..."],
33
- mu: Float[ArrayLike, "..."],
34
- sigma: Float[ArrayLike, "..."],
35
- ) -> Float[Array, "..."]:
41
+ x: Real[ArrayLike, "..."],
42
+ mu: Real[ArrayLike, "..."],
43
+ sigma: Real[ArrayLike, "..."],
44
+ ) -> Real[Array, "..."]:
36
45
  """
37
46
  The log of the probability density function (log PDF) for a Normal distribution.
38
47
 
@@ -47,92 +56,155 @@ def logprob(
47
56
  # Cast to Array
48
57
  x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
49
58
 
50
- return -lax.log(sigma * lax.sqrt(2.0 * __PI)) - 0.5 * lax.square((x - mu) / sigma)
59
+ return -lax.log(lax.sqrt(2.0 * PI)) - lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma)
51
60
 
52
61
 
53
- def uprob(
54
- x: Float[ArrayLike, "..."],
55
- mu: Float[ArrayLike, "..."],
56
- sigma: Float[ArrayLike, "..."],
57
- ) -> Float[Array, "..."]:
62
+ def cdf(
63
+ x: Real[ArrayLike, "..."],
64
+ mu: Real[ArrayLike, "..."],
65
+ sigma: Real[ArrayLike, "..."],
66
+ ) -> Real[Array, "..."]:
58
67
  """
59
- The unnormalized probability density function (uPDF) for a Normal distribution.
68
+ The cumulative density function (CDF) for a Normal distribution.
60
69
 
61
70
  # Parameters
62
- - `x`: Where to evaluate the PDF.
71
+ - `x`: Where to evaluate the CDF.
63
72
  - `mu`: The mean.
64
73
  - `sigma`: The standard deviation.
65
74
 
66
75
  # Returns
67
- The uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
76
+ The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
68
77
  """
69
78
  # Cast to Array
70
79
  x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
71
80
 
72
- return lax.exp(-0.5 * lax.square((x - mu) / sigma)) / sigma
81
+ return jsp.ndtr((x - mu) / sigma)
73
82
 
74
83
 
75
- def ulogprob(
76
- x: Float[ArrayLike, "..."],
77
- mu: Float[ArrayLike, "..."],
78
- sigma: Float[ArrayLike, "..."],
79
- ) -> Float[Array, "..."]:
84
+ def logcdf(
85
+ x: Real[ArrayLike, "..."],
86
+ mu: Real[ArrayLike, "..."],
87
+ sigma: Real[ArrayLike, "..."],
88
+ ) -> Real[Array, "..."]:
80
89
  """
81
- The log of the unnormalized probability density function (log uPDF) for a Normal distribution.
90
+ The log of the cumulative density function (log CDF) for a Normal distribution.
82
91
 
83
92
  # Parameters
84
- - `x`: Where to evaluate the PDF.
93
+ - `x`: Where to evaluate the log CDF.
85
94
  - `mu`: The mean.
86
95
  - `sigma`: The standard deviation.
87
96
 
88
97
  # Returns
89
- The log uPDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
98
+ The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
90
99
  """
91
100
  # Cast to Array
92
101
  x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
93
102
 
94
- return -lax.log(sigma) - 0.5 * lax.square((x - mu) / sigma)
103
+ return jsp.log_ndtr((x - mu) / sigma)
95
104
 
96
105
 
97
- def cdf(
98
- x: Float[ArrayLike, "..."],
99
- mu: Float[ArrayLike, "..."],
100
- sigma: Float[ArrayLike, "..."],
101
- ) -> Float[Array, "..."]:
106
+ def ccdf(
107
+ x: Real[ArrayLike, "..."],
108
+ mu: Real[ArrayLike, "..."],
109
+ sigma: Real[ArrayLike, "..."],
110
+ ) -> Real[Array, "..."]:
111
+ """
112
+ The complementary cumulative density function (cCDF) for a Normal distribution.
113
+
114
+ # Parameters
115
+ - `x`: Where to evaluate the cCDF.
116
+ - `mu`: The mean.
117
+ - `sigma`: The standard deviation.
118
+
119
+ # Returns
120
+ The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
121
+ """
102
122
  # Cast to Array
103
123
  x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
104
124
 
105
- return jss.ndtr((x - mu) / sigma)
125
+ return jsp.ndtr((mu - x) / sigma)
106
126
 
107
127
 
108
- def logcdf(
109
- x: Float[ArrayLike, "..."],
110
- mu: Float[ArrayLike, "..."],
111
- sigma: Float[ArrayLike, "..."],
112
- ) -> Float[Array, "..."]:
128
+ def logccdf(
129
+ x: Real[ArrayLike, "..."],
130
+ mu: Real[ArrayLike, "..."],
131
+ sigma: Real[ArrayLike, "..."],
132
+ ) -> Real[Array, "..."]:
133
+ """
134
+ The log of the complementary cumulative density function (log cCDF) for a Normal distribution.
135
+
136
+ # Parameters
137
+ - `x`: Where to evaluate the log cCDF.
138
+ - `mu`: The mean.
139
+ - `sigma`: The standard deviation.
140
+
141
+ # Returns
142
+ The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `mu`, and `sigma`.
143
+ """
113
144
  # Cast to Array
114
145
  x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
115
146
 
116
- return jss.log_ndtr((x - mu) / sigma)
147
+ return jsp.log_ndtr((mu - x) / sigma)
117
148
 
118
149
 
119
- def ccdf(
120
- x: Float[ArrayLike, "..."],
121
- mu: Float[ArrayLike, "..."],
122
- sigma: Float[ArrayLike, "..."],
123
- ) -> Float[Array, "..."]:
124
- # Cast to Array
125
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
150
+ class Normal(Distribution):
151
+ """
152
+ A normal distribution.
126
153
 
127
- return jss.ndtr((mu - x) / sigma)
154
+ # Attributes
155
+ - `mu`: The mean/location parameter.
156
+ - `sigma`: The standard-deviation/scale parameter.
157
+ """
128
158
 
159
+ mu: Node[Real[Array, "..."]]
160
+ sigma: Node[Real[Array, "..."]]
129
161
 
130
- def logccdf(
131
- x: Float[ArrayLike, "..."],
132
- mu: Float[ArrayLike, "..."],
133
- sigma: Float[ArrayLike, "..."],
134
- ) -> Float[Array, "..."]:
135
- # Cast to Array
136
- x, mu, sigma = jnp.asarray(x), jnp.asarray(mu), jnp.asarray(sigma)
137
162
 
138
- return jss.log_ndtr((mu - x) / sigma)
163
+ def __init__(
164
+ self,
165
+ mu: Real[ArrayLike, "..."] | Node[Real[Array, "..."]],
166
+ sigma: Real[ArrayLike, "..."] | Node[Real[Array, "..."]]
167
+ ):
168
+ # Initialize mean/location parameter (mu)
169
+ if isinstance(mu, Node):
170
+ if isinstance(mu.obj, ArrayLike):
171
+ self.mu = mu # type: ignore
172
+ else:
173
+ self.mu = Observed(jnp.asarray(mu))
174
+
175
+ # Initialize dispersion/scale parameter (sigma)
176
+ if isinstance(sigma, Node):
177
+ if isinstance(sigma.obj, ArrayLike):
178
+ self.sigma = sigma # type: ignore
179
+ else:
180
+ self.sigma = Observed(jnp.asarray(sigma))
181
+
182
+
183
+ def logprob(self, node: Node) -> Scalar:
184
+ obj, mu, sigma = (
185
+ node.obj,
186
+ self.mu.obj,
187
+ self.sigma.obj
188
+ )
189
+
190
+ # Filter out irrelevant values
191
+ obj, _ = eqx.partition(obj, node._filter_spec)
192
+
193
+ # Helper function for the single-leaf log-probability evaluation
194
+ def leaf_logprob(x: Real[ArrayLike, "..."]) -> Scalar:
195
+ return logprob(x, mu, sigma).sum()
196
+
197
+ # Compute log probabilities across leaves
198
+ eval_obj = jt.map(leaf_logprob, obj)
199
+
200
+ # Compute total sum
201
+ total = jt.reduce_associative(lambda x,y: x + y, eval_obj, identity=0.0)
202
+
203
+ return jnp.asarray(total)
204
+
205
+ def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
206
+ # Coerce to tuple
207
+ if isinstance(shape, int):
208
+ shape = (shape, )
209
+
210
+ return jr.normal(key, shape) * self.sigma.obj + self.mu.obj
@@ -0,0 +1,203 @@
1
+ from typing import Tuple
2
+
3
+ import equinox as eqx
4
+ import jax.lax as lax
5
+ import jax.numpy as jnp
6
+ import jax.random as jr
7
+ import jax.scipy.special as jsp
8
+ import jax.tree as jt
9
+ from jaxtyping import Array, ArrayLike, Integer, PRNGKeyArray, Real, Scalar
10
+
11
+ from bayinx.core.distribution import Distribution
12
+ from bayinx.core.node import Node
13
+ from bayinx.nodes import Observed
14
+
15
+
16
def prob(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    Evaluate the probability mass function (PMF) of a Poisson distribution.

    # Parameters
    - `x`: Where to evaluate the PMF.
    - `lam`: The rate parameter (lambda).

    # Returns
    The PMF evaluated at `x`, computed as the exponential of `logprob(x, lam)`.
    """
    # Coerce both inputs to arrays before delegating to the log PMF
    counts = jnp.asarray(x)
    rate = jnp.asarray(lam)

    log_pmf = logprob(counts, rate)

    return lax.exp(log_pmf)
34
+
35
+
36
def logprob(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the probability mass function (log PMF) for a Poisson distribution.

    Computes `x * log(lam) - lam - log(x!)`, with `log(x!)` evaluated as
    `gammaln(x + 1)`.

    # Parameters
    - `x`: Where to evaluate the log PMF.
    - `lam`: The rate parameter (lambda), representing the average number of events.

    # Returns
    The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    # xlogy(x, lam) equals x * log(lam) but is defined as 0 when x == 0, so the
    # degenerate case x == 0, lam == 0 gives 0 instead of NaN (0 * -inf). It also
    # promotes integer inputs to floating point, which lax.log does not.
    return jsp.xlogy(x, lam) - lam - jsp.gammaln(x + 1)
54
+
55
+
56
def cdf(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The cumulative distribution function (CDF) for a Poisson distribution (P(X <= x)).

    # Parameters
    - `x`: Where to evaluate the CDF.
    - `lam`: The rate parameter (lambda).

    # Returns
    The CDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    Values of `x` below the support (`x < 0`) map to 0.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    in_support = x >= 0

    # Use a harmless shape parameter outside the support so the incomplete gamma
    # function is never evaluated at a non-positive first argument.
    shape_param = jnp.where(in_support, x + 1.0, 1.0)

    # P(X <= x) is the regularized UPPER incomplete gamma function Q(x + 1, lam).
    result = jsp.gammaincc(shape_param, lam)

    # jnp.where broadcasts the scalar fill value; lax.select would require both
    # cases to already share one shape.
    return jnp.where(in_support, result, 0.0)
77
+
78
+
79
def logcdf(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the cumulative distribution function (log CDF) for a Poisson distribution.

    # Parameters
    - `x`: Where to evaluate the log CDF (ln P(X <= x)).
    - `lam`: The rate parameter (lambda).

    # Returns
    The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    Values of `x` below the support (`x < 0`) map to `-inf`.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    in_support = x >= 0

    # Use a harmless shape parameter outside the support so the incomplete gamma
    # function is never evaluated at a non-positive first argument.
    shape_param = jnp.where(in_support, x + 1.0, 1.0)

    # P(X <= x) is the regularized UPPER incomplete gamma function Q(x + 1, lam).
    result = jnp.log(jsp.gammaincc(shape_param, lam))

    # Handle values outside of support; jnp.where broadcasts the scalar fill value.
    return jnp.where(in_support, result, -jnp.inf)
102
+
103
+
104
def ccdf(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The complementary cumulative distribution function (cCDF) for a Poisson distribution (P(X > x)).

    # Parameters
    - `x`: Where to evaluate the cCDF.
    - `lam`: The rate parameter (lambda).

    # Returns
    The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    Values of `x` below the support (`x < 0`) map to 1.
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    in_support = x >= 0

    # Use a harmless shape parameter outside the support so the incomplete gamma
    # function is never evaluated at a non-positive first argument.
    shape_param = jnp.where(in_support, x + 1.0, 1.0)

    # P(X > x) is the regularized LOWER incomplete gamma function P(x + 1, lam).
    result = jsp.gammainc(shape_param, lam)

    # Handle values outside of support; jnp.where broadcasts the scalar fill value.
    return jnp.where(in_support, result, 1.0)
127
+
128
+
129
def logccdf(
    x: Integer[ArrayLike, "..."],
    lam: Real[ArrayLike, "..."],
) -> Real[Array, "..."]:
    """
    The log of the complementary cumulative distribution function (log cCDF) for a Poisson distribution.

    # Parameters
    - `x`: Where to evaluate the log cCDF (ln P(X > x)).
    - `lam`: The rate parameter (lambda).

    # Returns
    The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
    Values of `x` below the support (`x < 0`) map to 0 (the log of probability 1).
    """
    # Cast to Array
    x, lam = jnp.asarray(x), jnp.asarray(lam)

    in_support = x >= 0

    # Use a harmless shape parameter outside the support so the incomplete gamma
    # function is never evaluated at a non-positive first argument.
    shape_param = jnp.where(in_support, x + 1.0, 1.0)

    # P(X > x) is the regularized LOWER incomplete gamma function P(x + 1, lam).
    result = jnp.log(jsp.gammainc(shape_param, lam))

    # Handle values outside of support; jnp.where broadcasts the scalar fill value.
    return jnp.where(in_support, result, 0.0)
152
+
153
+
154
class Poisson(Distribution):
    """
    A Poisson distribution.

    # Attributes
    - `lam`: The rate parameter (lambda), representing the average number of events in an interval.
    """

    lam: Node[Real[Array, "..."]]

    def __init__(
        self,
        lam: Real[ArrayLike, "..."] | Node[Real[Array, "..."]],
    ):
        """
        Initializes a Poisson distribution.

        # Parameters
        - `lam`: The rate parameter, given either as an array-like value (wrapped
          in an `Observed` node) or as an existing `Node` holding an array-like object.

        # Raises
        - `TypeError`: If `lam` is a `Node` whose `obj` is not array-like.
        """
        # Initialize rate parameter (lambda)
        if isinstance(lam, Node):
            # Previously this case fell through silently and left `lam` unassigned;
            # fail here with a clear message instead.
            if not isinstance(lam.obj, ArrayLike):
                raise TypeError(
                    "Poisson: `lam` must wrap an array-like object, "
                    f"got {type(lam.obj).__name__}."
                )
            self.lam = lam  # type: ignore
        else:
            self.lam = Observed(jnp.asarray(lam))

    def logprob(self, node: Node) -> Scalar:
        """
        Evaluates the total log PMF of the values held by `node`.

        # Parameters
        - `node`: The node whose `obj` is evaluated under this distribution.

        # Returns
        The sum of the log PMF over every leaf kept by `node._filter_spec`.
        """
        obj, lam = (
            node.obj,
            self.lam.obj,
        )

        # Filter out irrelevant values
        obj, _ = eqx.partition(obj, node._filter_spec)

        # Helper function for the single-leaf log-probability evaluation;
        # `logprob` here resolves to the module-level function, not this method.
        def leaf_logprob(x: Integer[ArrayLike, "..."]) -> Scalar:
            return logprob(x, lam).sum()

        # Compute log probabilities across leaves
        eval_obj = jt.map(leaf_logprob, obj)

        # Compute total sum
        total = jt.reduce_associative(lambda x, y: x + y, eval_obj, identity=0.0)

        return jnp.asarray(total)

    # NOTE: the default key is the fixed `jr.key(0)` created when the class body is
    # executed, so calls that omit `key` all use that same key.
    def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
        """
        Draws samples from the distribution.

        # Parameters
        - `shape`: The shape of the draws; an `int` is treated as a 1-tuple.
        - `key`: The PRNG key used for sampling.

        # Returns
        The result of `jax.random.poisson` with rate `lam` and the given shape.
        """
        # Coerce to tuple
        if isinstance(shape, int):
            shape = (shape,)

        return jr.poisson(key, self.lam.obj, shape=shape)
@@ -0,0 +1,5 @@
1
+ from bayinx.flows.diagaffine import DiagAffine as DiagAffine
2
+ from bayinx.flows.fullaffine import FullAffine as FullAffine
3
+ from bayinx.flows.lowrankaffine import LowRankAffine as LowRankAffine
4
+ from bayinx.flows.planar import Planar as Planar
5
+ from bayinx.flows.sylvester import Sylvester as Sylvester
@@ -0,0 +1,120 @@
1
+ from typing import Callable, Dict, Tuple
2
+
3
+ import equinox as eqx
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import jax.random as jr
7
+ from jaxtyping import Array, Float, PRNGKeyArray, Scalar
8
+
9
+ from bayinx.core.flow import FlowLayer, FlowSpec
10
+
11
+
12
class DiagAffineLayer(FlowLayer):
    """
    A diagonal (element-wise) affine flow.

    # Attributes
    - `params`: The parameters of the diagonal affine flow.
    - `constraints`: The constraining transformations for the parameters of the diagonal affine flow.
    - `static`: Whether the flow layer is frozen (parameters are not subject to further optimization).
    """

    params: Dict[str, Array]
    constraints: Dict[str, Callable[[Array], Array]]
    static: bool

    def __init__(self, dim: int, key: PRNGKeyArray):
        """
        Initializes a diagonal affine flow.

        # Parameters
        - `dim`: The dimension of the parameter space.
        - `key`: The PRNG key used to initialize the shift and scale parameters.
        """
        self.static = False

        # Split key so shift and scale get independent initial draws
        k1, k2 = jr.split(key)

        # Initialize parameters as standard-normal draws scaled by 1/sqrt(dim)
        self.params = {
            "shift": jr.normal(k1, (dim,)) / dim**0.5,
            "scale": jr.normal(k2, (dim,)) / dim**0.5,
        }

        # Define constraints (presumably applied by the inherited
        # `transform_params` - defined outside this block)
        self.constraints = {"scale": jnp.exp}

    def __forward(self, draw: Float[Array, " dim"]) -> Float[Array, " dim"]:
        # Forward transformation of a single draw: draw * scale + shift
        params = self.transform_params()

        assert len(draw.shape) == 1

        # Extract parameters
        shift: Float[Array, " dim"] = params["shift"]
        scale: Float[Array, " dim"] = params["scale"]

        # Compute forward transformation
        draw = draw * scale + shift

        return draw

    @eqx.filter_jit
    def forward(self, draws: Float[Array, "draws dim"]) -> Float[Array, "draws dim"]:
        """
        Applies the forward transformation to each draw (row) of `draws`.
        """
        f = jax.vmap(self.__forward, 0)
        return f(draws)

    def __adjust(self, draw: Float[Array, " dim"]) -> Scalar:
        # Log-Jacobian adjustment for a single draw; it depends only on `scale`
        params = self.transform_params()

        # Extract parameters
        scale: Float[Array, " dim"] = params["scale"]

        # Compute log-Jacobian adjustment
        lja: Array = jnp.log(scale).sum()

        assert lja.shape == ()

        return lja

    @eqx.filter_jit
    def adjust(self, draws: Float[Array, "draws dim"]) -> Float[Array, " draws"]:
        """
        Computes the log-Jacobian adjustment for each draw (row) of `draws`.

        # Returns
        One scalar adjustment per draw, stacked along the leading axis.
        """
        f = jax.vmap(self.__adjust, 0)
        return f(draws)

    def __forward_and_adjust(self, draw: Float[Array, " dim"]) -> Tuple[Float[Array, " dim"], Scalar]:
        # Forward transformation and log-Jacobian adjustment of a single draw
        params = self.transform_params()

        assert len(draw.shape) == 1

        # Extract parameters
        shift: Float[Array, " dim"] = params["shift"]
        scale: Float[Array, " dim"] = params["scale"]

        # Compute forward transformation
        draw = scale * draw + shift

        assert len(draw.shape) == 1

        # Compute log-Jacobian adjustment
        lja: Scalar = jnp.log(scale).sum()

        assert lja.shape == ()

        return draw, lja

    @eqx.filter_jit
    def forward_and_adjust(
        self, draws: Float[Array, "draws dim"]
    ) -> Tuple[Float[Array, "draws dim"], Float[Array, " draws"]]:
        """
        Applies the forward transformation and computes the log-Jacobian
        adjustment for each draw (row) of `draws`.

        # Returns
        The transformed draws and one scalar adjustment per draw.
        """
        f = jax.vmap(self.__forward_and_adjust, 0)
        return f(draws)
109
+
110
+
111
class DiagAffine(FlowSpec):
    """
    A specification from which a diagonal affine flow layer is constructed.

    # Attributes
    - `key`: The PRNG key passed to the layer when it is constructed.
    """

    key: PRNGKeyArray

    def __init__(self, key: PRNGKeyArray = jr.key(0)):
        self.key = key

    def construct(self, dim: int) -> DiagAffineLayer:
        """
        Builds a `DiagAffineLayer` of dimension `dim` using the stored key.
        """
        layer = DiagAffineLayer(dim, self.key)
        return layer