bayinx 0.3.10__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. bayinx/__init__.py +3 -3
  2. bayinx/constraints/__init__.py +4 -3
  3. bayinx/constraints/identity.py +26 -0
  4. bayinx/constraints/interval.py +62 -0
  5. bayinx/constraints/lower.py +31 -24
  6. bayinx/constraints/upper.py +57 -0
  7. bayinx/core/__init__.py +0 -7
  8. bayinx/core/constraint.py +32 -0
  9. bayinx/core/context.py +42 -0
  10. bayinx/core/distribution.py +34 -0
  11. bayinx/core/flow.py +99 -0
  12. bayinx/core/model.py +228 -0
  13. bayinx/core/node.py +201 -0
  14. bayinx/core/types.py +17 -0
  15. bayinx/core/utils.py +109 -0
  16. bayinx/core/variational.py +170 -0
  17. bayinx/dists/__init__.py +5 -3
  18. bayinx/dists/bernoulli.py +180 -11
  19. bayinx/dists/binomial.py +215 -0
  20. bayinx/dists/exponential.py +211 -0
  21. bayinx/dists/normal.py +131 -59
  22. bayinx/dists/poisson.py +203 -0
  23. bayinx/flows/__init__.py +5 -0
  24. bayinx/flows/diagaffine.py +120 -0
  25. bayinx/flows/fullaffine.py +123 -0
  26. bayinx/flows/lowrankaffine.py +165 -0
  27. bayinx/flows/planar.py +155 -0
  28. bayinx/flows/radial.py +1 -0
  29. bayinx/flows/sylvester.py +225 -0
  30. bayinx/nodes/__init__.py +3 -0
  31. bayinx/nodes/continuous.py +64 -0
  32. bayinx/nodes/observed.py +36 -0
  33. bayinx/nodes/stochastic.py +25 -0
  34. bayinx/ops.py +104 -0
  35. bayinx/posterior.py +220 -0
  36. bayinx/vi/__init__.py +0 -0
  37. bayinx/{mhx/vi → vi}/meanfield.py +33 -29
  38. bayinx/vi/normalizing_flow.py +246 -0
  39. bayinx/vi/standard.py +95 -0
  40. bayinx-0.5.3.dist-info/METADATA +93 -0
  41. bayinx-0.5.3.dist-info/RECORD +44 -0
  42. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/WHEEL +1 -1
  43. bayinx/core/_constraint.py +0 -28
  44. bayinx/core/_flow.py +0 -80
  45. bayinx/core/_model.py +0 -98
  46. bayinx/core/_parameter.py +0 -44
  47. bayinx/core/_variational.py +0 -181
  48. bayinx/dists/censored/__init__.py +0 -3
  49. bayinx/dists/censored/gamma2/__init__.py +0 -3
  50. bayinx/dists/censored/gamma2/r.py +0 -68
  51. bayinx/dists/censored/posnormal/__init__.py +0 -3
  52. bayinx/dists/censored/posnormal/r.py +0 -116
  53. bayinx/dists/gamma2.py +0 -49
  54. bayinx/dists/posnormal.py +0 -260
  55. bayinx/dists/uniform.py +0 -75
  56. bayinx/mhx/__init__.py +0 -1
  57. bayinx/mhx/vi/__init__.py +0 -5
  58. bayinx/mhx/vi/flows/__init__.py +0 -3
  59. bayinx/mhx/vi/flows/fullaffine.py +0 -75
  60. bayinx/mhx/vi/flows/planar.py +0 -74
  61. bayinx/mhx/vi/flows/radial.py +0 -94
  62. bayinx/mhx/vi/flows/sylvester.py +0 -19
  63. bayinx/mhx/vi/normalizing_flow.py +0 -149
  64. bayinx/mhx/vi/standard.py +0 -63
  65. bayinx-0.3.10.dist-info/METADATA +0 -39
  66. bayinx-0.3.10.dist-info/RECORD +0 -35
  67. /bayinx/{py.typed → flows/otflow.py} +0 -0
  68. {bayinx-0.3.10.dist-info → bayinx-0.5.3.dist-info}/licenses/LICENSE +0 -0
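The file list above amounts to a flattening of the package layout: the bayinx.mhx.vi namespace and the underscore-prefixed bayinx.core modules are removed in favor of top-level bayinx.vi, bayinx.flows, and bayinx.nodes packages plus public bayinx.core modules. A minimal import-migration sketch, assuming the file moves map one-to-one onto importable module paths (only module names are taken from this list; no symbol names inside them are implied):

    # Hypothetical import migration from 0.3.10 to 0.5.3, based only on the
    # renames in the file list above (module paths only, no symbol names).

    # 0.3.10
    # from bayinx.mhx.vi import meanfield, normalizing_flow, standard
    # from bayinx.mhx.vi.flows import fullaffine, planar, sylvester

    # 0.5.3
    from bayinx.vi import meanfield, normalizing_flow, standard
    from bayinx.flows import fullaffine, planar, sylvester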
bayinx/dists/bernoulli.py CHANGED
@@ -1,33 +1,202 @@
+ from typing import Tuple
+
+ import equinox as eqx
  import jax.lax as lax
- from jaxtyping import Array, ArrayLike, Real, UInt
+ import jax.numpy as jnp
+ import jax.random as jr
+ import jax.tree as jt
+ from jaxtyping import Array, ArrayLike, Float, Integer, PRNGKeyArray, Real, Scalar
+
+ from bayinx.core.distribution import Distribution
+ from bayinx.core.node import Node
+ from bayinx.nodes import Observed
 
 
- # MARK: Functions ----
- def prob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
+ def prob(
+     x: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."]
+ ) -> Real[Array, "..."]:
      """
      The probability mass function (PMF) for a Bernoulli distribution.
 
      # Parameters
-     - `x`: Value(s) at which to evaluate the PDF.
-     - `p`: The probability parameter(s).
+     - `x`: Where to evaluate the PMF (must be 0 or 1).
+     - `p`: The probability of success.
 
      # Returns
      The PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
      """
+     # Cast to Array
+     x, p = jnp.asarray(x), jnp.asarray(p)
 
-     return lax.pow(p, x) * lax.pow(1 - p, 1 - x)
+     # Bernoulli PMF: p^x * (1-p)^(1-x)
+     return lax.exp(x * lax.log(p) + (1.0 - x) * lax.log1p(-p))
 
 
- def logprob(x: UInt[ArrayLike, "..."], p: Real[ArrayLike, "..."]) -> Real[Array, "..."]:
+ def logprob(
+     x: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."]
+ ) -> Real[Array, "..."]:
      """
-     The log probability mass function (log PMF) for a Bernoulli distribution.
+     The log of the probability mass function (log PMF) for a Bernoulli distribution.
 
      # Parameters
-     - `x`: Value(s) at which to evaluate the log PMF.
-     - `p`: The probability parameter(s).
+     - `x`: Where to evaluate the log PMF (must be 0 or 1).
+     - `p`: The probability of success.
 
      # Returns
      The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
      """
+     # Cast to Array
+     x, p = jnp.asarray(x), jnp.asarray(p)
+
+     return x * lax.log(p) + (1.0 - x) * lax.log(1.0 - p)
+
+
+ def cdf(
+     x: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."]
+ ) -> Real[Array, "..."]:
+     """
+     The cumulative distribution function (CDF) for a Bernoulli distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the CDF.
+     - `p`: The probability of success.
+
+     # Returns
+     The CDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+     """
+     # Cast to Array
+     x, p = jnp.asarray(x), jnp.asarray(p)
+
+     return jnp.where(
+         x < 0.0,
+         0.0,
+         jnp.where(x < 1.0, 1.0 - p, 1.0)
+     )
+
+
+ def logcdf(
+     x: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."]
+ ) -> Real[Array, "..."]:
+     """
+     The log of the cumulative distribution function (log CDF) for a Bernoulli distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the log CDF.
+     - `p`: The probability of success.
+
+     # Returns
+     The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+     """
+     # Cast to Array
+     x, p = jnp.asarray(x), jnp.asarray(p)
+
+     return jnp.where(
+         x < 0.0,
+         -jnp.inf,
+         jnp.where(x < 1.0, lax.log1p(-p), 0.0)
+     )
+
+
+ def ccdf(
+     x: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."]
+ ) -> Real[Array, "..."]:
+     """
+     The complementary cumulative distribution function (cCDF) for a Bernoulli distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the cCDF.
+     - `p`: The probability of success.
+
+     # Returns
+     The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+     """
+
+     # Cast to Array
+     x, p = jnp.asarray(x), jnp.asarray(p)
+
+     return jnp.where(
+         x < 0.0,
+         1.0,
+         jnp.where(x < 1.0, p, 0.0)
+     )
+
+
+ def logccdf(
+     x: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."]
+ ) -> Real[Array, "..."]:
+     """
+     The log of the complementary cumulative distribution function (log cCDF) for a Bernoulli distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the log cCDF.
+     - `p`: The probability of success.
+
+     # Returns
+     The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `p`.
+     """
+     # Cast to Array
+     x, p = jnp.asarray(x), jnp.asarray(p)
+
+     return jnp.where(
+         x < 0.0,
+         0.0,
+         jnp.where(x < 1.0, lax.log(p), -jnp.inf)
+     )
+
+
+ class Bernoulli(Distribution):
+     """
+     A Bernoulli distribution.
+
+     # Attributes
+     - `p`: The probability of success.
+     """
+
+     p: Node[Real[Array, "..."]]
+
+
+     def __init__(
+         self,
+         p: Real[ArrayLike, "..."] | Node[Real[ArrayLike, "..."]],
+     ):
+         # Initialize probability of success parameter (p)
+         if isinstance(p, Node):
+             if isinstance(p.obj, ArrayLike):
+                 self.p = p # type: ignore
+         else:
+             self.p = Observed(jnp.asarray(p))
+
+
+     def logprob(self, node: Node) -> Scalar:
+         obj, p = (
+             node.obj,
+             self.p.obj,
+         )
+
+         # Filter out irrelevant values
+         obj, _ = eqx.partition(obj, node._filter_spec)
+
+         # Helper function for the single-leaf log-probability evaluation
+         def leaf_logprob(x: Float[ArrayLike, ""]) -> Scalar:
+             return logprob(x, p).sum()
+
+         # Compute log probabilities across leaves
+         eval_obj = jt.map(leaf_logprob, obj)
+
+         # Compute total sum
+         total = jt.reduce_associative(lambda x,y: x + y, eval_obj, identity=0.0)
+
+         return jnp.asarray(total)
+
+     def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
+         # Coerce to tuple
+         if isinstance(shape, int):
+             shape = (shape, )
 
-     return lax.log(p) * x + (1 - x) * lax.log(1 - p)
+         return jr.bernoulli(key, self.p.obj, shape)
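Taken together, bernoulli.py now exposes the full pointwise family (prob, logprob, cdf, logcdf, ccdf, logccdf) plus a Bernoulli Distribution class with logprob and sample methods. A small usage sketch of the module-level functions as defined above (illustrative values; assumes the submodule is importable as bayinx.dists.bernoulli):

    import jax.numpy as jnp
    from bayinx.dists import bernoulli

    x = jnp.array([0, 1, 1, 0])
    p = 0.3

    bernoulli.logprob(x, p)   # elementwise log PMF, broadcast over x and p
    bernoulli.cdf(0, p)       # P(X <= 0) = 1 - p
    bernoulli.ccdf(0, p)      # P(X > 0) = p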
bayinx/dists/binomial.py ADDED
@@ -0,0 +1,215 @@
+ from typing import Tuple
+
+ import equinox as eqx
+ import jax.lax as lax
+ import jax.numpy as jnp
+ import jax.random as jr
+ import jax.scipy.special as jsp
+ import jax.tree as jt
+ from jaxtyping import Array, ArrayLike, Float, Integer, PRNGKeyArray, Real, Scalar
+
+ from bayinx.core.distribution import Distribution
+ from bayinx.core.node import Node
+ from bayinx.nodes import Observed
+
+
+ def log_binom_coeff(n: ArrayLike, x: ArrayLike) -> Array:
+     n, x = jnp.asarray(n), jnp.asarray(x)
+     return jsp.gammaln(n + 1) - jsp.gammaln(x + 1) - jsp.gammaln(n - x + 1)
+
+
+ def prob(
+     x: Integer[ArrayLike, "..."],
+     n: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The probability mass function (PMF) for a Binomial distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the PMF (number of successes).
+     - `n`: The number of trials.
+     - `p`: The probability of success.
+
+     # Returns
+     The PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `n`, and `p`.
+     """
+     # Cast to Array
+     x, n, p = jnp.asarray(x), jnp.asarray(n), jnp.asarray(p)
+
+     return lax.exp(logprob(x, n, p))
+
+
+ def logprob(
+     x: Integer[ArrayLike, "..."],
+     n: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The log of the probability mass function (log PMF) for a Binomial distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the log PMF (number of successes).
+     - `n`: The number of trials.
+     - `p`: The probability of success.
+
+     # Returns
+     The log PMF evaluated at `x`. The output will have the broadcasted shapes of `x`, `n`, and `p`.
+     """
+     # Cast to Array
+     k, n, p = jnp.asarray(x), jnp.asarray(n), jnp.asarray(p)
+
+     return log_binom_coeff(n, k) + k * lax.log(p) + (n - k) * lax.log1p(-p)
+
+
+ def cdf(
+     x: Integer[ArrayLike, "..."],
+     n: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The cumulative distribution function (CDF) for a Binomial distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the CDF (P(X <= x)).
+     - `n`: The number of trials.
+     - `p`: The probability of success.
+
+     # Returns
+     The CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `n`, and `p`.
+     """
+     # Cast to Array
+     x, n, p = jnp.asarray(x), jnp.asarray(n), jnp.asarray(p)
+
+     return jsp.betainc(n - x, x + 1, 1.0 - p)
+
+
+ def logcdf(
+     x: Integer[ArrayLike, "..."],
+     n: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The log of the cumulative distribution function (log CDF) for a Binomial distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the log CDF (ln P(X <= x)).
+     - `n`: The number of trials.
+     - `p`: The probability of success.
+
+     # Returns
+     The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `n`, and `p`.
+     """
+     # Cast to Array
+     x, n, p = jnp.asarray(x), jnp.asarray(n), jnp.asarray(p)
+
+     return lax.log(cdf(x, n, p))
+
+
+ def ccdf(
+     x: Integer[ArrayLike, "..."],
+     n: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The complementary cumulative distribution function (cCDF) for a Binomial distribution (P(X > x)).
+
+     # Parameters
+     - `x`: Where to evaluate the cCDF.
+     - `n`: The number of trials.
+     - `p`: The probability of success.
+
+     # Returns
+     The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `n`, and `p`.
+     """
+
+     # Cast to Array
+     x, n, p = jnp.asarray(x), jnp.asarray(n), jnp.asarray(p)
+
+     return jsp.betainc(x + 1, n - x, p)
+
+
+ def logccdf(
+     x: Integer[ArrayLike, "..."],
+     n: Integer[ArrayLike, "..."],
+     p: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The log of the complementary cumulative distribution function (log cCDF) for a Binomial distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the log cCDF.
+     - `n`: The number of trials.
+     - `p`: The probability of success.
+
+     # Returns
+     The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x`, `n`, and `p`.
+     """
+     # Cast to Array
+     x, n, p = jnp.asarray(x), jnp.asarray(n), jnp.asarray(p)
+
+     return lax.log(ccdf(x, n, p))
+
+
+ class Binomial(Distribution):
+     """
+     A Binomial distribution.
+
+     # Attributes
+     - `n`: The number of trials parameter (integer).
+     - `p`: The probability of success parameter (float, [0, 1]).
+     """
+
+     n: Node[Integer[Array, "..."]]
+     p: Node[Real[Array, "..."]]
+
+
+     def __init__(
+         self,
+         n: Integer[ArrayLike, "..."] | Node[Integer[Array, "..."]],
+         p: Real[ArrayLike, "..."] | Node[Real[Array, "..."]]
+     ):
+         # Initialize number of trials (n)
+         if isinstance(n, Node):
+             if isinstance(n.obj, ArrayLike):
+                 self.n = n # type: ignore
+         else:
+             self.n = Observed(jnp.asarray(n))
+
+         # Initialize probability of success parameter (p)
+         if isinstance(p, Node):
+             if isinstance(p.obj, ArrayLike):
+                 self.p = p # type: ignore
+         else:
+             self.p = Observed(jnp.asarray(p))
+
+     def logprob(self, node: Node) -> Scalar:
+         obj, n, p = (
+             node.obj,
+             self.n.obj,
+             self.p.obj
+         )
+
+         # Filter out irrelevant values
+         obj, _ = eqx.partition(obj, node._filter_spec)
+
+         # Helper function for the single-leaf log-probability evaluation
+         def leaf_logprob(k: Float[ArrayLike, "..."]) -> Scalar:
+             return logprob(k, n, p).sum()
+
+         # Compute log probabilities across leaves
+         eval_obj = jt.map(leaf_logprob, obj)
+
+         # Compute total sum
+         total = jt.reduce_associative(lambda x,y: x + y, eval_obj, identity=0.0)
+
+         return jnp.asarray(total)
+
+     def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
+         # Coerce to tuple
+         if isinstance(shape, int):
+             shape = (shape, )
+
+         # Use jr.binomial for sampling (returns integer array)
+         # Note: jr.binomial accepts n as int, p as float, and shape
+         return jr.binomial(key, self.n.obj, self.p.obj, shape=shape)
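The Binomial CDF/cCDF above lean on the binomial-beta identity P(X <= x) = I_{1-p}(n - x, x + 1) (and P(X > x) = I_p(x + 1, n - x)) via jax.scipy.special.betainc rather than summing the PMF. A small sanity-check sketch, assuming the submodule is importable as bayinx.dists.binomial (values illustrative):

    import jax.numpy as jnp
    from bayinx.dists import binomial

    n, p = 10, 0.4

    # CDF via the regularized incomplete beta identity used above
    binomial.cdf(3, n, p)

    # Same quantity by summing the PMF directly; should agree numerically
    jnp.sum(binomial.prob(jnp.arange(4), n, p))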
bayinx/dists/exponential.py ADDED
@@ -0,0 +1,211 @@
+ from typing import Tuple
+
+ import equinox as eqx
+ import jax.lax as lax
+ import jax.numpy as jnp
+ import jax.random as jr
+ import jax.tree as jt
+ from jaxtyping import Array, ArrayLike, PRNGKeyArray, Real, Scalar
+
+ from bayinx.core.distribution import Distribution
+ from bayinx.core.node import Node
+ from bayinx.nodes import Observed
+
+
+ def prob(
+     x: Real[ArrayLike, "..."],
+     lam: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The probability density function (PDF) for an Exponential distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the PDF (must be >= 0).
+     - `lam`: The rate parameter (lambda), must be > 0.
+
+     # Returns
+     The PDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
+     """
+     # Cast to Array
+     x, lam = jnp.asarray(x), jnp.asarray(lam)
+
+     result = lam * lax.exp(-lam * x)
+
+     # Handle values outside of support
+     result = lax.select(x >= 0, result, jnp.array(0.0))
+
+     return result
+
+
+ def logprob(
+     x: Real[ArrayLike, "..."],
+     lam: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The log of the probability density function (log PDF) for an Exponential distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the log PDF (must be >= 0).
+     - `lam`: The rate parameter (lambda), must be > 0.
+
+     # Returns
+     The log PDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
+     """
+     # Cast to Array
+     x, lam = jnp.asarray(x), jnp.asarray(lam)
+
+     result = lax.log(lam) - lam * x
+
+     # Handle values outside of support
+     result = lax.select(x >= 0, result, -jnp.inf)
+
+     return result
+
+ def cdf(
+     x: Real[ArrayLike, "..."],
+     lam: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The cumulative distribution function (CDF) for an Exponential distribution (P(X <= x)).
+
+     # Parameters
+     - `x`: Where to evaluate the CDF.
+     - `lam`: The rate parameter (lambda).
+
+     # Returns
+     The CDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
+     """
+     # Cast to Array
+     x, lam = jnp.asarray(x), jnp.asarray(lam)
+
+     result = 1.0 - lax.exp(-lam * x)
+
+     # Handle values outside of support
+     result = lax.select(x >= 0, result, jnp.array(0.0))
+
+     return result
+
+ def logcdf(
+     x: Real[ArrayLike, "..."],
+     lam: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The log of the cumulative distribution function (log CDF) for an Exponential distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the log CDF (ln P(X <= x)).
+     - `lam`: The rate parameter (lambda).
+
+     # Returns
+     The log CDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
+     """
+     # Cast to Array
+     x, lam = jnp.asarray(x), jnp.asarray(lam)
+
+     result = lax.log1p(-lax.exp(-lam * x))
+
+     # Handle values outside of support (x < 0)
+     result = lax.select(x >= 0, result, -jnp.inf)
+
+     return result
+
+
+ def ccdf(
+     x: Real[ArrayLike, "..."],
+     lam: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The complementary cumulative distribution function (cCDF) for an Exponential distribution (P(X > x)).
+
+     # Parameters
+     - `x`: Where to evaluate the cCDF.
+     - `lam`: The rate parameter (lambda).
+
+     # Returns
+     The cCDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
+     """
+     # Cast to Array
+     x, lam = jnp.asarray(x), jnp.asarray(lam)
+
+     result = lax.exp(-lam * x)
+
+     # Handle values outside of support
+     result = lax.select(x >= 0, result, jnp.array(1.0))
+
+     return result
+
+
+ def logccdf(
+     x: Real[ArrayLike, "..."],
+     lam: Real[ArrayLike, "..."],
+ ) -> Real[Array, "..."]:
+     """
+     The log of the complementary cumulative distribution function (log cCDF) for an Exponential distribution.
+
+     # Parameters
+     - `x`: Where to evaluate the log cCDF.
+     - `lam`: The rate parameter (lambda).
+
+     # Returns
+     The log cCDF evaluated at `x`. The output will have the broadcasted shapes of `x` and `lam`.
+     """
+     # Cast to Array
+     x, lam = jnp.asarray(x), jnp.asarray(lam)
+
+     # log(cCDF(x)) = -lambda * x for x >= 0
+     log_ccdf_val = -lam * x
+
+     # Handle values outside of support (x < 0), log(P(X > x)) = log(1.0) = 0.0
+     return lax.select(x >= 0, log_ccdf_val, jnp.array(0.0))
+
+
+ class Exponential(Distribution):
+     """
+     An Exponential distribution.
+
+     # Attributes
+     - `lam`: The rate parameter (lambda), must be positive.
+     """
+
+     lam: Node[Real[Array, "..."]]
+
+
+     def __init__(
+         self,
+         lam: Real[ArrayLike, "..."] | Node[Real[Array, "..."]],
+     ):
+         # Initialize rate parameter (lambda)
+         if isinstance(lam, Node):
+             if isinstance(lam.obj, ArrayLike):
+                 self.lam = lam # type: ignore
+         else:
+             self.lam = Observed(jnp.asarray(lam))
+
+
+     def logprob(self, node: Node) -> Scalar:
+         obj, lam = (
+             node.obj,
+             self.lam.obj,
+         )
+
+         # Filter out irrelevant values
+         obj, _ = eqx.partition(obj, node._filter_spec)
+
+         # Helper function for the single-leaf log-probability evaluation
+         def leaf_logprob(x: Real[ArrayLike, "..."]) -> Scalar:
+             return logprob(x, lam).sum()
+
+         # Compute log probabilities across leaves
+         eval_obj = jt.map(leaf_logprob, obj)
+
+         # Compute total sum
+         total = jt.reduce_associative(lambda x,y: x + y, eval_obj, identity=0.0)
+
+         return jnp.asarray(total)
+
+     def sample(self, shape: int | Tuple[int, ...], key: PRNGKeyArray = jr.key(0)):
+         # Coerce to tuple
+         if isinstance(shape, int):
+             shape = (shape, )
+
+         return jr.exponential(key, shape=shape) * 1.0 / self.lam.obj
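For completeness, a short sketch of the Exponential additions as defined above; note that sample draws standard exponentials with jax.random.exponential and rescales by 1/lam. Scalar arguments are used here, and the submodule is assumed importable as bayinx.dists.exponential:

    import jax.random as jr
    from bayinx.dists import exponential

    lam = 2.0

    exponential.logprob(0.5, lam)   # log(lam) - lam * 0.5
    exponential.logprob(-1.0, lam)  # -inf, outside the support
    exponential.ccdf(1.0, lam)      # exp(-lam * 1.0)

    # Roughly what Exponential(lam).sample(1000) does: standard draws rescaled by 1/lam
    draws = jr.exponential(jr.key(0), shape=(1000,)) / lam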