pymc_extras-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
pymc_extras/distributions/discrete.py
@@ -0,0 +1,399 @@
+ # Copyright 2023 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+ import pymc as pm
+
+ from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
+ from pymc.distributions.shape_utils import rv_size_is_none
+ from pytensor import tensor as pt
+ from pytensor.tensor.random.op import RandomVariable
+
+
+ def log1mexp(x):
+     cond = x < np.log(0.5)
+     return np.piecewise(
+         x,
+         [cond, ~cond],
+         [lambda x: np.log1p(-np.exp(x)), lambda x: np.log(-np.expm1(x))],
+     )
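A quick numerical illustration (editor's sketch, not part of the packaged file): the piecewise split at log(0.5) keeps log(1 - exp(x)) accurate at both ends of the domain, where the naive formula loses precision:

    import numpy as np
    from pymc_extras.distributions.discrete import log1mexp

    x = np.array([-1e-12, -0.1, -5.0, -50.0])
    stable = log1mexp(x)
    naive = np.log(1 - np.exp(x))  # cancels near 0 and rounds to log(1.0) = 0.0 at x = -50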
+
+
+ class GeneralizedPoissonRV(RandomVariable):
+     name = "generalized_poisson"
+     signature = "(),()->()"
+     dtype = "int64"
+     _print_name = ("GeneralizedPoisson", "\\operatorname{GeneralizedPoisson}")
+
+     @classmethod
+     def rng_fn(cls, rng, theta, lam, size):
+         theta = np.asarray(theta)
+         lam = np.asarray(lam)
+
+         if size is not None:
+             dist_size = size
+         else:
+             dist_size = np.broadcast_shapes(theta.shape, lam.shape)
+
+         # A mix of two algorithms described by Famoye (1997) is used, depending on the sign of lam:
+         # - Inverse method, computed on the log scale. Used when lam < 0.
+         # - Branching method. Used when lam >= 0.
+         x = np.empty(dist_size)
+         idxs_mask = np.broadcast_to(lam < 0, dist_size)
+         if np.any(idxs_mask):
+             x[idxs_mask] = cls._inverse_rng_fn(rng, theta, lam, dist_size, idxs_mask=idxs_mask)[
+                 idxs_mask
+             ]
+         idxs_mask = ~idxs_mask
+         if np.any(idxs_mask):
+             x[idxs_mask] = cls._branching_rng_fn(rng, theta, lam, dist_size, idxs_mask=idxs_mask)[
+                 idxs_mask
+             ]
+         return x
+
+     @classmethod
+     def _inverse_rng_fn(cls, rng, theta, lam, dist_size, idxs_mask):
+         # We handle x/0 and log(0) issues with branching
+         with np.errstate(divide="ignore", invalid="ignore"):
+             log_u = np.log(rng.uniform(size=dist_size))
+             pos_lam = lam > 0
+             abs_log_lam = np.log(np.abs(lam))
+             theta_m_lam = theta - lam
+             log_s = -theta
+             log_p = log_s.copy()
+             x_ = 0
+             x = np.zeros(dist_size)
+             below_cutpoint = log_s < log_u
+             while np.any(below_cutpoint[idxs_mask]):
+                 x_ += 1
+                 x[below_cutpoint] += 1
+                 log_c = np.log(theta_m_lam + lam * x_)
+                 # Compute log(1 + lam / C)
+                 log1p_lam_m_C = np.where(
+                     pos_lam,
+                     np.log1p(np.exp(abs_log_lam - log_c)),
+                     log1mexp(abs_log_lam - log_c),
+                 )
+                 log_p = log_c + log1p_lam_m_C * (x_ - 1) + log_p - np.log(x_) - lam
+                 log_s = np.logaddexp(log_s, log_p)
+                 below_cutpoint = log_s < log_u
+         return x
+
+     @classmethod
+     def _branching_rng_fn(cls, rng, theta, lam, dist_size, idxs_mask):
+         lam_ = np.abs(lam)  # This algorithm is only valid for positive lam
+         y = rng.poisson(theta, size=dist_size)
+         x = y.copy()
+         higher_than_zero = y > 0
+         while np.any(higher_than_zero[idxs_mask]):
+             y = rng.poisson(lam_ * y)
+             x[higher_than_zero] = x[higher_than_zero] + y[higher_than_zero]
+             higher_than_zero = y > 0
+         return x
+
+
+ generalized_poisson = GeneralizedPoissonRV()
+
+
+ class GeneralizedPoisson(pm.distributions.Discrete):
+     R"""
+     Generalized Poisson distribution.
+
+     Used to model count data that can be either overdispersed or underdispersed.
+     It offers greater flexibility than the standard Poisson, which assumes
+     equidispersion (mean equal to variance).
+
+     The pmf of this distribution is
+
+     .. math:: f(x \mid \mu, \lambda) =
+           \frac{\mu (\mu + \lambda x)^{x-1} e^{-\mu - \lambda x}}{x!}
+
+     ========  ======================================
+     Support   :math:`x \in \mathbb{N}_0`
+     Mean      :math:`\frac{\mu}{1 - \lambda}`
+     Variance  :math:`\frac{\mu}{(1 - \lambda)^3}`
+     ========  ======================================
+
+     Parameters
+     ----------
+     mu : tensor_like of float
+         Mean parameter (mu > 0).
+     lam : tensor_like of float
+         Dispersion parameter (max(-1, -mu/4) <= lam <= 1).
+
+     Notes
+     -----
+     When lam = 0, the Generalized Poisson reduces to the standard Poisson with the same mu.
+     When lam < 0, the mean is greater than the variance (underdispersion).
+     When lam > 0, the mean is less than the variance (overdispersion).
+
+     References
+     ----------
+     The PMF is taken from [1] and the random generator function is adapted from [2].
+
+     .. [1] Consul, P. C., and Felix Famoye. "Generalized Poisson regression model."
+        Communications in Statistics - Theory and Methods 21.1 (1992): 89-109.
+     .. [2] Famoye, Felix. "Generalized Poisson random variate generation." American
+        Journal of Mathematical and Management Sciences 17.3-4 (1997): 219-237.
+     """
+
+     rv_op = generalized_poisson
+
+     @classmethod
+     def dist(cls, mu, lam, **kwargs):
+         mu = pt.as_tensor_variable(mu)
+         lam = pt.as_tensor_variable(lam)
+         return super().dist([mu, lam], **kwargs)
+
+     def support_point(rv, size, mu, lam):
+         mean = pt.floor(mu / (1 - lam))
+         if not rv_size_is_none(size):
+             mean = pt.full(size, mean)
+         return mean
+
+     def logp(value, mu, lam):
+         mu_lam_value = mu + lam * value
+         logprob = pt.log(mu) + logpow(mu_lam_value, value - 1) - mu_lam_value - factln(value)
+
+         # Probability is 0 when value > m, where m is the largest positive integer for
+         # which mu + m * lam > 0 (when lam < 0).
+         logprob = pt.switch(
+             pt.or_(
+                 mu_lam_value < 0,
+                 value < 0,
+             ),
+             -np.inf,
+             logprob,
+         )
+
+         return check_parameters(
+             logprob,
+             0 < mu,
+             pt.abs(lam) <= 1,
+             (-mu / 4) <= lam,
+             msg="0 < mu, max(-1, -mu/4) <= lam <= 1",
+         )
+
+
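A minimal usage sketch (editor's illustration, not part of the packaged file; it assumes GeneralizedPoisson is re-exported from pymc_extras.distributions, as the package __init__ suggests):

    import numpy as np
    import pymc as pm
    from pymc_extras.distributions import GeneralizedPoisson

    counts = np.random.default_rng(0).poisson(3.0, size=200)  # toy count data

    with pm.Model():
        mu = pm.Exponential("mu", 1.0)
        # lam < 0 models underdispersion, lam > 0 overdispersion; lam = 0 recovers Poisson.
        lam = pm.Uniform("lam", -0.2, 0.9)
        GeneralizedPoisson("obs", mu=mu, lam=lam, observed=counts)
        idata = pm.sample()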
+ class BetaNegativeBinomial:
+     R"""
+     Beta Negative Binomial distribution.
+
+     The pmf of this distribution is
+
+     .. math::
+
+         f(x \mid \alpha, \beta, r) = \frac{B(r + x, \alpha + \beta)}{B(r, \alpha)} \frac{\Gamma(x + \beta)}{x! \, \Gamma(\beta)}
+
+     where :math:`B` is the Beta function and :math:`\Gamma` is the Gamma function.
+
+     For more information, see https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution.
+
+     .. plot::
+         :context: close-figs
+
+         import matplotlib.pyplot as plt
+         import numpy as np
+         from scipy.special import betaln, gammaln
+         def factln(x):
+             return gammaln(x + 1)
+         def logp(x, alpha, beta, r):
+             return (
+                 betaln(r + x, alpha + beta)
+                 - betaln(r, alpha)
+                 + gammaln(x + beta)
+                 - factln(x)
+                 - gammaln(beta)
+             )
+         plt.style.use('arviz-darkgrid')
+         x = np.arange(0, 25)
+         params = [
+             (1, 1, 1),
+             (1, 1, 10),
+             (1, 10, 1),
+             (1, 10, 10),
+             (10, 10, 10),
+         ]
+         for alpha, beta, r in params:
+             pmf = np.exp(logp(x, alpha, beta, r))
+             plt.plot(x, pmf, "-o", label=r'$\alpha$ = {}, $\beta$ = {}, $r$ = {}'.format(alpha, beta, r))
+         plt.xlabel('x', fontsize=12)
+         plt.ylabel('f(x)', fontsize=12)
+         plt.legend(loc=1)
+         plt.show()
+
+     ========  ======================================
+     Support   :math:`x \in \mathbb{N}_0`
+     Mean      :math:`\begin{cases} \frac{r \beta}{\alpha - 1} & \text{if } \alpha > 1 \\ \infty & \text{otherwise} \end{cases}`
+     Variance  :math:`\begin{cases} \frac{r \beta (r + \alpha - 1)(\beta + \alpha - 1)}{(\alpha - 2)(\alpha - 1)^{2}} & \text{if } \alpha > 2 \\ \infty & \text{otherwise} \end{cases}`
+     ========  ======================================
+
+     Parameters
+     ----------
+     alpha : tensor_like of float
+         Shape parameter of the beta distribution (alpha > 0).
+     beta : tensor_like of float
+         Shape parameter of the beta distribution (beta > 0).
+     r : tensor_like of float
+         Number of successes until the experiment is stopped (integer, but can be extended to reals).
+     """
+
+     @staticmethod
+     def beta_negative_binomial_dist(alpha, beta, r, size):
+         if rv_size_is_none(size):
+             alpha, beta, r = pt.broadcast_arrays(alpha, beta, r)
+
+         # Beta prior on the success probability p of a NegativeBinomial with n = r;
+         # p and n must be passed by keyword, since the positional parameters are (mu, alpha).
+         p = pm.Beta.dist(alpha, beta, size=size)
+         return pm.NegativeBinomial.dist(p=p, n=r, size=size)
+
+     @staticmethod
+     def beta_negative_binomial_logp(value, alpha, beta, r):
+         res = (
+             betaln(r + value, alpha + beta)
+             - betaln(r, alpha)
+             + pt.gammaln(value + beta)
+             - factln(value)
+             - pt.gammaln(beta)
+         )
+         res = pt.switch(
+             pt.lt(value, 0),
+             -np.inf,
+             res,
+         )
+
+         return check_parameters(
+             res,
+             alpha > 0,
+             beta > 0,
+             r > 0,
+             msg="alpha > 0, beta > 0, r > 0",
+         )
+
+     def __new__(cls, name, alpha, beta, r, **kwargs):
+         return pm.CustomDist(
+             name,
+             alpha,
+             beta,
+             r,
+             dist=cls.beta_negative_binomial_dist,
+             logp=cls.beta_negative_binomial_logp,
+             class_name="BetaNegativeBinomial",
+             **kwargs,
+         )
+
+     @classmethod
+     def dist(cls, alpha, beta, r, **kwargs):
+         return pm.CustomDist.dist(
+             alpha,
+             beta,
+             r,
+             dist=cls.beta_negative_binomial_dist,
+             logp=cls.beta_negative_binomial_logp,
+             class_name="BetaNegativeBinomial",
+             **kwargs,
+         )
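Because the class wraps pm.CustomDist, both the named constructor and .dist() work. A sketch (editor's illustration, same re-export assumption as above):

    import pymc as pm
    from pymc_extras.distributions import BetaNegativeBinomial

    # Standalone variable: draw prior samples without a model.
    rv = BetaNegativeBinomial.dist(alpha=3.0, beta=2.0, r=5.0)
    samples = pm.draw(rv, draws=1000)

    with pm.Model():
        alpha = pm.HalfNormal("alpha", 5.0)
        beta = pm.HalfNormal("beta", 5.0)
        BetaNegativeBinomial("obs", alpha, beta, r=5.0, observed=samples)
        idata = pm.sample()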
+
+
+ class Skellam:
+     R"""
+     Skellam distribution.
+
+     The Skellam distribution is the distribution of the difference of two
+     Poisson random variables.
+
+     The pmf of this distribution is
+
+     .. math::
+
+         f(x \mid \mu_1, \mu_2) = e^{-(\mu_1 + \mu_2)} \left( \frac{\mu_1}{\mu_2} \right)^{x/2} I_{x}\left( 2 \sqrt{\mu_1 \mu_2} \right)
+
+     where :math:`I_{x}` is the modified Bessel function of the first kind of order :math:`x`.
+
+     Read more about the Skellam distribution at https://en.wikipedia.org/wiki/Skellam_distribution.
+
+     .. plot::
+         :context: close-figs
+
+         import matplotlib.pyplot as plt
+         import numpy as np
+         import scipy.stats as st
+         import arviz as az
+         plt.style.use('arviz-darkgrid')
+         x = np.arange(-15, 15)
+         params = [
+             (1, 1),
+             (5, 5),
+             (5, 1),
+         ]
+         for mu1, mu2 in params:
+             pmf = st.skellam.pmf(x, mu1, mu2)
+             plt.plot(x, pmf, "-o", label=r'$\mu_1$ = {}, $\mu_2$ = {}'.format(mu1, mu2))
+         plt.xlabel('x', fontsize=12)
+         plt.ylabel('f(x)', fontsize=12)
+         plt.legend(loc=1)
+         plt.show()
+
+     ========  ======================================
+     Support   :math:`x \in \mathbb{Z}`
+     Mean      :math:`\mu_{1} - \mu_{2}`
+     Variance  :math:`\mu_{1} + \mu_{2}`
+     ========  ======================================
+
+     Parameters
+     ----------
+     mu1 : tensor_like of float
+         Mean parameter (mu1 >= 0).
+     mu2 : tensor_like of float
+         Mean parameter (mu2 >= 0).
+     """
+
+     @staticmethod
+     def skellam_dist(mu1, mu2, size):
+         if rv_size_is_none(size):
+             mu1, mu2 = pt.broadcast_arrays(mu1, mu2)
+
+         return pm.Poisson.dist(mu=mu1, size=size) - pm.Poisson.dist(mu=mu2, size=size)
+
+     @staticmethod
+     def skellam_logp(value, mu1, mu2):
+         res = (
+             -mu1
+             - mu2
+             + 0.5 * value * (pt.log(mu1) - pt.log(mu2))
+             + pt.log(pt.iv(value, 2 * pt.sqrt(mu1 * mu2)))
+         )
+         return check_parameters(
+             res,
+             mu1 >= 0,
+             mu2 >= 0,
+             msg="mu1 >= 0, mu2 >= 0",
+         )
+
+     def __new__(cls, name, mu1, mu2, **kwargs):
+         return pm.CustomDist(
+             name,
+             mu1,
+             mu2,
+             dist=cls.skellam_dist,
+             logp=cls.skellam_logp,
+             class_name="Skellam",
+             **kwargs,
+         )
+
+     @classmethod
+     def dist(cls, mu1, mu2, **kwargs):
+         return pm.CustomDist.dist(
+             mu1,
+             mu2,
+             dist=cls.skellam_dist,
+             logp=cls.skellam_logp,
+             class_name="Skellam",
+             **kwargs,
+         )
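As a cross-check of the logp above (editor's sketch; scipy.stats.skellam uses the same (mu1, mu2) parametrization):

    import numpy as np
    import pymc as pm
    import scipy.stats as st
    from pymc_extras.distributions import Skellam

    x = np.arange(-5, 6)
    logp_pt = pm.logp(Skellam.dist(mu1=2.0, mu2=3.0), x).eval()
    np.testing.assert_allclose(logp_pt, st.skellam.logpmf(x, 2.0, 3.0), rtol=1e-6)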
pymc_extras/distributions/histogram_utils.py
@@ -0,0 +1,163 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import numpy as np
+ import pymc as pm
+
+ from numpy.typing import ArrayLike
+
+ __all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
+
+
+ def quantile_histogram(
+     data: ArrayLike, n_quantiles=1000, zero_inflation=False
+ ) -> dict[str, ArrayLike]:
+     try:
+         import xhistogram.core
+     except ImportError as e:
+         raise RuntimeError("quantile_histogram requires xhistogram package") from e
+     try:
+         import dask.array
+         import dask.dataframe
+     except ImportError:
+         dask = None
+     if dask and isinstance(data, dask.dataframe.Series | dask.dataframe.DataFrame):
+         data = data.to_dask_array(lengths=True)
+     if zero_inflation:
+         zeros = (data == 0).sum(0)
+         mdata = np.ma.masked_where(data == 0, data)
+         qdata = data[data > 0]
+     else:
+         mdata = data
+         qdata = data.flatten()
+     quantiles = np.percentile(qdata, np.linspace(0, 100, n_quantiles))
+     if dask:
+         (quantiles,) = dask.compute(quantiles)
+     count, _ = xhistogram.core.histogram(mdata, bins=[quantiles], axis=0)
+     count = count.transpose(count.ndim - 1, *range(count.ndim - 1))
+     lower = quantiles[:-1]
+     upper = quantiles[1:]
+
+     if zero_inflation:
+         count = np.concatenate([zeros[None], count])
+         lower = np.concatenate([[0], lower])
+         upper = np.concatenate([[0], upper])
+     lower = lower.reshape(lower.shape + (1,) * (count.ndim - 1))
+     upper = upper.reshape(upper.shape + (1,) * (count.ndim - 1))
+
+     result = dict(
+         lower=lower,
+         upper=upper,
+         mid=(lower + upper) / 2,
+         count=count,
+     )
+     return result
+
+
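A sketch of the returned structure (editor's illustration; it requires the optional xhistogram dependency): for an input of shape (n, ...), the histogram is taken over axis 0, and the edge arrays gain a broadcast axis per remaining dimension:

    import numpy as np
    from pymc_extras.distributions.histogram_utils import quantile_histogram

    data = np.random.default_rng(0).normal(size=(10_000, 3))
    hist = quantile_histogram(data, n_quantiles=50)

    assert hist["count"].shape == (49, 3)  # 50 quantile edges -> 49 bins
    assert hist["lower"].shape == hist["upper"].shape == hist["mid"].shape == (49, 1)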
+ def discrete_histogram(data: ArrayLike, min_count=None) -> dict[str, ArrayLike]:
+     try:
+         import xhistogram.core
+     except ImportError as e:
+         raise RuntimeError("discrete_histogram requires xhistogram package") from e
+     try:
+         import dask.array
+         import dask.dataframe
+     except ImportError:
+         dask = None
+
+     if dask and isinstance(data, dask.dataframe.Series | dask.dataframe.DataFrame):
+         data = data.to_dask_array(lengths=True)
+     mid, count_uniq = np.unique(data, return_counts=True)
+     if min_count is not None:
+         mid = mid[count_uniq >= min_count]
+         count_uniq = count_uniq[count_uniq >= min_count]
+     bins = np.concatenate([mid, [mid.max() + 1]])
+     if dask:
+         mid, bins = dask.compute(mid, bins)
+     count, _ = xhistogram.core.histogram(data, bins=[bins], axis=0)
+     count = count.transpose(count.ndim - 1, *range(count.ndim - 1))
+     mid = mid.reshape(mid.shape + (1,) * (count.ndim - 1))
+     return dict(mid=mid, count=count)
+
+
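The discrete counterpart keeps one bin per retained unique value (editor's sketch, same xhistogram assumption):

    import numpy as np
    from pymc_extras.distributions.histogram_utils import discrete_histogram

    data = np.random.default_rng(0).poisson(2.0, size=(10_000, 3))
    hist = discrete_histogram(data, min_count=5)

    assert hist["count"].shape[0] == hist["mid"].shape[0]  # one count row per unique value
    assert hist["count"].shape[1:] == (3,)                 # trailing batch shape is kept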
+ def histogram_approximation(name, dist, *, observed, **h_kwargs):
+     """Approximate a distribution with a histogram potential.
+
+     Parameters
+     ----------
+     name : str
+         Name for the Potential.
+     dist : TensorVariable
+         The output of pm.Distribution.dist().
+     observed : ArrayLike
+         Observed value used to construct the histogram. The histogram is computed
+         over the 0th axis. Dask is supported.
+
+     Returns
+     -------
+     TensorVariable
+         Potential
+
+     Examples
+     --------
+     Discrete variables are reduced to unique repetitions (up to min_count)
+
+     >>> import numpy as np
+     >>> import pymc as pm
+     >>> import pymc_extras as pmx
+     >>> production = np.random.poisson([1, 2, 5], size=(1000, 3))
+     >>> with pm.Model(coords=dict(plant=range(3))):
+     ...     lam = pm.Exponential("lam", 1.0, dims="plant")
+     ...     pot = pmx.distributions.histogram_approximation(
+     ...         "pot", pm.Poisson.dist(lam), observed=production, min_count=2
+     ...     )
+
+     Continuous variables are discretized into n_quantiles
+
+     >>> measurements = np.random.normal([1, 2, 3], [0.1, 0.4, 0.2], size=(10000, 3))
+     >>> with pm.Model(coords=dict(tests=range(3))):
+     ...     m = pm.Normal("m", dims="tests")
+     ...     s = pm.LogNormal("s", dims="tests")
+     ...     pot = pmx.distributions.histogram_approximation(
+     ...         "pot", pm.Normal.dist(m, s),
+     ...         observed=measurements, n_quantiles=50
+     ...     )
+
+     For special cases such as zero inflation in continuous variables, the
+     zero_inflation flag adds a separate bin for zeros
+
+     >>> measurements = abs(measurements)
+     >>> measurements[:100] = 0
+     >>> with pm.Model(coords=dict(tests=range(3))):
+     ...     m = pm.Normal("m", dims="tests")
+     ...     s = pm.LogNormal("s", dims="tests")
+     ...     pot = pmx.distributions.histogram_approximation(
+     ...         "pot", pm.Normal.dist(m, s),
+     ...         observed=measurements, n_quantiles=50, zero_inflation=True
+     ...     )
+     """
+     try:
+         import dask.array
+         import dask.dataframe
+     except ImportError:
+         dask = None
+     if dask and isinstance(observed, dask.dataframe.Series | dask.dataframe.DataFrame):
+         observed = observed.to_dask_array(lengths=True)
+     if np.issubdtype(observed.dtype, np.integer):
+         histogram = discrete_histogram(observed, **h_kwargs)
+     else:
+         histogram = quantile_histogram(observed, **h_kwargs)
+     if dask is not None:
+         (histogram,) = dask.compute(histogram)
+     return pm.Potential(name, pm.logp(dist, histogram["mid"]) * histogram["count"])
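The potential replaces the exact observed log-likelihood sum_i logp(x_i) with the binned approximation sum_b count_b * logp(mid_b), so logp is evaluated once per bin instead of once per observation. A quick check of that claim (editor's sketch):

    import numpy as np
    import pymc as pm
    from pymc_extras.distributions.histogram_utils import quantile_histogram

    data = np.random.default_rng(0).normal(1.0, 0.5, size=10_000)
    dist = pm.Normal.dist(1.0, 0.5)

    exact = pm.logp(dist, data).sum().eval()
    hist = quantile_histogram(data, n_quantiles=200)
    approx = (pm.logp(dist, hist["mid"]) * hist["count"]).sum().eval()
    # exact and approx agree up to the binning error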
pymc_extras/distributions/multivariate/__init__.py
@@ -0,0 +1,3 @@
+ from pymc_extras.distributions.multivariate.r2d2m2cp import R2D2M2CP
+
+ __all__ = ["R2D2M2CP"]