pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
pymc_extras/utils/linear_cg.py ADDED
@@ -0,0 +1,290 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import numpy as np
+
+ EVAL_CG_TOLERANCE = 0.01
+ CG_TOLERANCE = 1
+
+
+ def masked_fill(vector, mask, fill_value):
+     masked_vector = np.ma.array(vector, mask=mask)
+     vector = masked_vector.filled(fill_value=fill_value)
+     return vector
+
+
+ def linear_cg_updates(
+     result, alpha, residual_inner_prod, eps, beta, residual, precond_residual, curr_conjugate_vec
+ ):
+     # Everything inside _jit_linear_cg_updates
+     result = result + alpha * curr_conjugate_vec
+     beta = np.copy(residual_inner_prod)
+
+     residual_inner_prod = residual.T @ precond_residual
+
+     # safe division
+     is_zero = beta < eps
+     beta = masked_fill(beta, mask=is_zero, fill_value=1)
+
+     beta = residual_inner_prod / beta
+     beta = masked_fill(beta, mask=is_zero, fill_value=0)
+     curr_conjugate_vec = beta * curr_conjugate_vec + precond_residual
+     return (
+         result,
+         alpha,
+         residual_inner_prod,
+         eps,
+         beta,
+         residual,
+         precond_residual,
+         curr_conjugate_vec,
+     )
+
+
+ def linear_cg(
+     mat: np.matrix,
+     rhs,
+     n_tridiag=0,
+     tolerance=None,
+     eps=1e-10,
+     stop_updating_after=1e-10,
+     max_iter=1000,
+     max_tridiag_iter=20,
+     initial_guess=None,
+     preconditioner=None,
+     terminate_cg_by_size=False,
+     use_eval_tolerance=False,
+ ):
+     if initial_guess is None:
+         initial_guess = np.zeros_like(rhs)
+
+     if preconditioner is None:
+
+         def preconditioner(x):
+             return x
+
+         precond = False
+     else:
+         precond = True
+
+     if tolerance is None:
+         if use_eval_tolerance:
+             tolerance = EVAL_CG_TOLERANCE
+         else:
+             tolerance = CG_TOLERANCE
+
+     # If we run m CG iterations, we cannot obtain more than m Lanczos coefficients
+     if max_tridiag_iter > max_iter:
+         raise RuntimeError(
+             "Getting a tridiagonalization larger than the number of CG iterations run is not possible!"
+         )
+
+     is_vector = len(rhs.shape) == 1
+     if is_vector:
+         rhs = rhs[:, np.newaxis]
+
+     num_rows = rhs.size
+     n_iter = min(max_iter, num_rows) if terminate_cg_by_size else max_iter
+     n_tridiag_iter = min(max_tridiag_iter, num_rows)
+
+     # norm of rhs for convergence tests
+     rhs_norm = np.linalg.norm(rhs, 2)
+     # make almost-zero norms be 1 (so we don't get divide-by-zero errors)
+     rhs_is_zero = rhs_norm < eps
+     rhs_norm = masked_fill(rhs_norm, mask=rhs_is_zero, fill_value=1)
+
+     # normalize rhs
+     rhs = rhs / rhs_norm
+
+     # residuals
+     residual = rhs - mat @ initial_guess
+     batch_shape = residual.shape[:-2]
+
+     result = np.copy(initial_guess)
+
+     # allclose(x, x) is False only when x contains NaNs
+     if not np.allclose(residual, residual):
+         raise RuntimeError("NaNs encountered when trying to perform matrix-vector multiplication")
+
+     # sometimes we are lucky and the preconditioner solves the system right away;
+     # check for convergence
+     residual_norm = np.linalg.norm(residual, 2)
+     has_converged = residual_norm < stop_updating_after
+
+     if has_converged.all() and not n_tridiag:
+         n_iter = 0  # skip iterations
+     else:
+         precond_residual = preconditioner(residual)
+
+         curr_conjugate_vec = precond_residual
+         residual_inner_prod = residual.T @ precond_residual
+
+         # define storage matrices
+         alpha = np.zeros((*batch_shape, 1, rhs.shape[-1]))
+         beta = np.zeros_like(alpha)
+         is_zero = np.zeros((*batch_shape, 1, rhs.shape[-1]))
+
+         # Define tridiagonal matrices, if applicable
+         if n_tridiag:
+             t_mat = np.zeros((n_tridiag_iter, n_tridiag_iter, *batch_shape, n_tridiag))
+             alpha_tridiag_is_zero = np.zeros((*batch_shape, n_tridiag))
+             alpha_reciprocal = np.zeros((*batch_shape, n_tridiag))
+             prev_alpha_reciprocal = np.zeros_like(alpha_reciprocal)
+             prev_beta = np.zeros_like(alpha_reciprocal)
+
+     update_tridiag = True
+     last_tridiag_iter = 0
+
+     # it is possible that we don't reach tolerance even after all the iterations are over
+     tolerance_reached = False
+
+     # start iteration
+     for k in range(n_iter):
+         mvms = mat @ curr_conjugate_vec
+         if precond:
+             alpha = curr_conjugate_vec.T @ mvms
+
+             # safe division
+             is_zero = alpha < eps
+             alpha = masked_fill(alpha, mask=is_zero, fill_value=1)
+             alpha = residual_inner_prod / alpha
+             alpha = masked_fill(alpha, mask=is_zero, fill_value=0)
+
+             # cancel out updates by setting directions which have converged to zero
+             alpha = masked_fill(alpha, mask=has_converged, fill_value=0)
+             residual = residual - alpha * mvms
+
+             # update precond_residual
+             precond_residual = preconditioner(residual)
+
+             # Everything inside _jit_linear_cg_updates
+             (
+                 result,
+                 alpha,
+                 residual_inner_prod,
+                 eps,
+                 beta,
+                 residual,
+                 precond_residual,
+                 curr_conjugate_vec,
+             ) = linear_cg_updates(
+                 result,
+                 alpha,
+                 residual_inner_prod,
+                 eps,
+                 beta,
+                 residual,
+                 precond_residual,
+                 curr_conjugate_vec,
+             )
+
+         else:
+             # everything inside _jit_linear_cg_updates_no_precond
+             alpha = curr_conjugate_vec.T @ mvms
+
+             # safe division
+             is_zero = alpha < eps
+             alpha = masked_fill(alpha, mask=is_zero, fill_value=1)
+             alpha = residual_inner_prod / alpha
+             alpha = masked_fill(alpha, mask=is_zero, fill_value=0)
+
+             # cancel out updates by setting directions which have converged to zero
+             alpha = masked_fill(alpha, mask=has_converged, fill_value=0)
+             residual = residual - alpha * mvms
+             precond_residual = np.copy(residual)
+
+             (
+                 result,
+                 alpha,
+                 residual_inner_prod,
+                 eps,
+                 beta,
+                 residual,
+                 precond_residual,
+                 curr_conjugate_vec,
+             ) = linear_cg_updates(
+                 result,
+                 alpha,
+                 residual_inner_prod,
+                 eps,
+                 beta,
+                 residual,
+                 precond_residual,
+                 curr_conjugate_vec,
+             )
+
+         residual_norm = np.linalg.norm(residual, 2)
+         residual_norm = masked_fill(residual_norm, mask=rhs_is_zero, fill_value=0)
+         has_converged = residual_norm < stop_updating_after
+
+         if (
+             k >= min(10, max_iter - 1)
+             and bool(residual_norm.mean() < tolerance)
+             and not (n_tridiag and k < min(n_tridiag_iter, max_iter - 1))
+         ):
+             tolerance_reached = True
+             break
+
+         # Update tridiagonal matrices, if applicable
+         if n_tridiag and k < n_tridiag_iter and update_tridiag:
+             alpha_tridiag = np.copy(alpha)
+             beta_tridiag = np.copy(beta)
+
+             alpha_tridiag_is_zero = alpha_tridiag == 0
+             alpha_tridiag = masked_fill(alpha_tridiag, mask=alpha_tridiag_is_zero, fill_value=1)
+             alpha_reciprocal = 1 / alpha_tridiag
+             alpha_tridiag = masked_fill(alpha_tridiag, mask=alpha_tridiag_is_zero, fill_value=0)
+
+             if k == 0:
+                 t_mat[k, k] = alpha_reciprocal
+             else:
+                 t_mat[k, k] += np.squeeze(alpha_reciprocal + prev_beta * prev_alpha_reciprocal)
+                 t_mat[k, k - 1] = np.sqrt(prev_beta) * prev_alpha_reciprocal
+                 t_mat[k - 1, k] = np.copy(t_mat[k, k - 1])
+
+                 if t_mat[k - 1, k].max() < 1e-6:
+                     update_tridiag = False
+
+             last_tridiag_iter = k
+
+             prev_alpha_reciprocal = np.copy(alpha_reciprocal)
+             prev_beta = np.copy(beta_tridiag)
+
+     # Un-normalize
+     result = result * rhs_norm
+     if not tolerance_reached and n_iter > 0:
+         raise RuntimeError(
+             f"CG terminated in {k + 1} iterations with average residual norm {residual_norm.mean()}"
+             f" which is larger than the specified tolerance of {tolerance}."
+             " If performance is affected, consider raising the maximum number of CG iterations"
+             " via the max_iter argument."
+         )
+
+     if n_tridiag:
+         t_mat = t_mat[: last_tridiag_iter + 1, : last_tridiag_iter + 1]
+         return result, t_mat.transpose(-1, *range(2, 2 + len(batch_shape)), 0, 1)
+     else:
+         # We set the estimated Lanczos tridiagonal matrices to identity so that
+         # the subsequent eigendecomposition (eq. S7 of https://arxiv.org/pdf/1809.11165.pdf)
+         # works fine.
+         # t_mat = np.zeros((n_tridiag_iter, n_tridiag_iter, *batch_shape, n_tridiag))
+         # Note that after the transpose the last two dimensions are dimensions 0 and 1 of the
+         # matrix above, which have the same size, n_tridiag_iter. So we generate identity
+         # matrices of size n_tridiag_iter and tile them over the remaining dimensions.
+         # TODO: for the same input, n_tridiag=True and n_tridiag=False should produce t_mat
+         # with the same shape (with assumed n_tridiag=1)
+         n_tridiag = 1
+         eye = np.eye(n_tridiag_iter)
+         t_mat_eye = np.tile(eye, [n_tridiag] + [1] * (len(batch_shape) + 2))
+         return result, t_mat_eye
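
For orientation, here is a minimal usage sketch of `linear_cg` (not part of the wheel). It assumes the package is installed so that `pymc_extras.utils.linear_cg` is importable, per the module layout in the file list above. It solves a small symmetric positive-definite system and checks the result against a dense solve; `tolerance` is passed explicitly because the module default of `CG_TOLERANCE = 1` is deliberately loose.

import numpy as np

from pymc_extras.utils.linear_cg import linear_cg

rng = np.random.default_rng(0)

# a well-conditioned symmetric positive-definite system A x = b
A = rng.normal(size=(50, 50))
A = A @ A.T + 50 * np.eye(50)
b = rng.normal(size=50)

# 1-D right-hand sides are promoted to a column internally, so the solution
# comes back with shape (50, 1); the second return value holds the Lanczos
# tridiagonal matrices (identity placeholders when n_tridiag=0)
x, _ = linear_cg(A, b, tolerance=1e-6)

np.testing.assert_allclose(x.ravel(), np.linalg.solve(A, b), atol=1e-4)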
pymc_extras/utils/pivoted_cholesky.py ADDED
@@ -0,0 +1,69 @@
+ try:
+     import torch
+
+     from gpytorch.utils.permutation import apply_permutation
+ except ImportError as e:
+     raise ImportError("PyTorch and GPyTorch are required to use pivoted_cholesky.") from e
+
+ import numpy as np
+
+
+ def pp(x):
+     return np.array2string(x, precision=4, floatmode="fixed")
+
+
+ def pivoted_cholesky(mat: np.matrix, error_tol=1e-6, max_iter=np.inf):
+     """
+     mat: an N x N numpy matrix
+
+     This replicates what is done in GPyTorch verbatim.
+     """
+     n = mat.shape[-1]
+     # int(np.inf) raises OverflowError, so guard the default before clamping
+     max_iter = n if np.isinf(max_iter) else min(int(max_iter), n)
+
+     d = np.array(np.diag(mat))
+     orig_error = np.max(d)
+     error = np.linalg.norm(d, 1) / orig_error
+     pi = np.arange(n)
+
+     L = np.zeros((max_iter, n))
+
+     m = 0
+     while m < max_iter and error > error_tol:
+         permuted_d = d[pi]
+         max_diag_idx = np.argmax(permuted_d[m:])
+         max_diag_idx = max_diag_idx + m
+         max_diag_val = permuted_d[max_diag_idx]
+         i = max_diag_idx
+
+         # swap pi_m and pi_i
+         pi[m], pi[i] = pi[i], pi[m]
+         pim = pi[m]
+
+         L[m, pim] = np.sqrt(max_diag_val)
+
+         if m + 1 < n:
+             row = apply_permutation(
+                 torch.from_numpy(mat), torch.tensor(pim), right_permutation=None
+             )  # the left permutation just selects row pim
+             row = row.numpy().flatten()
+             pi_i = pi[m + 1 :]
+             L_m_new = row[pi_i]  # entries for the not-yet-pivoted columns
+
+             if m > 0:
+                 L_prev = L[:m, pi_i]
+                 update = L[:m, pim]
+                 prod = update @ L_prev
+                 L_m_new = L_m_new - prod
+
+             L_m = L[m, :]
+             L_m_new = L_m_new / L_m[pim]
+             L_m[pi_i] = L_m_new
+
+             matrix_diag_current = d[pi_i]
+             d[pi_i] = matrix_diag_current - L_m_new**2
+
+             L[m, :] = L_m
+             error = np.linalg.norm(d[pi_i], 1) / orig_error
+         m = m + 1
+     return L, pi
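
A short sketch of how the factor can be used (illustrative, not part of the wheel; it needs PyTorch and GPyTorch installed, per the import guard above). `L` has one row per pivot, with columns indexed in the original coordinates, so `L.T @ L` approximates the input matrix without any re-permutation:

import numpy as np

from pymc_extras.utils.pivoted_cholesky import pivoted_cholesky

rng = np.random.default_rng(1)

# a PSD matrix whose spectrum decays quickly (true rank ~ 5 plus jitter)
U = rng.normal(size=(30, 5))
mat = U @ U.T + 1e-3 * np.eye(30)

L, pi = pivoted_cholesky(mat, error_tol=1e-6, max_iter=10)

# a rank-10 factor already captures the matrix up to the jitter scale
print(np.max(np.abs(mat - L.T @ L)))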
pymc_extras/utils/prior.py ADDED
@@ -0,0 +1,200 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from collections.abc import Sequence
+ from typing import TypedDict
+
+ import arviz
+ import numpy as np
+ import pymc as pm
+ import pytensor.tensor as pt
+
+ from pymc.logprob.transforms import Transform
+
+
+ class ParamCfg(TypedDict):
+     name: str
+     transform: Transform | None
+     dims: str | tuple[str] | None
+
+
+ class ShapeInfo(TypedDict):
+     # shape might not match slice due to a transform
+     shape: tuple[int]  # transformed shape
+     slice: slice
+
+
+ class VarInfo(TypedDict):
+     sinfo: ShapeInfo
+     vinfo: ParamCfg
+
+
+ class FlatInfo(TypedDict):
+     data: np.ndarray
+     info: list[VarInfo]
+
+
+ def _arg_to_param_cfg(key, value: ParamCfg | Transform | str | tuple | None = None):
+     if value is None:
+         cfg = ParamCfg(name=key, transform=None, dims=None)
+     elif isinstance(value, tuple):
+         cfg = ParamCfg(name=key, transform=None, dims=value)
+     elif isinstance(value, str):
+         cfg = ParamCfg(name=value, transform=None, dims=None)
+     elif isinstance(value, Transform):
+         cfg = ParamCfg(name=key, transform=value, dims=None)
+     else:
+         cfg = value.copy()
+         cfg.setdefault("name", key)
+         cfg.setdefault("transform", None)
+         cfg.setdefault("dims", None)
+     return cfg
+
+
+ def _parse_args(
+     var_names: Sequence[str], **kwargs: ParamCfg | Transform | str | tuple
+ ) -> dict[str, ParamCfg]:
+     results = dict()
+     for var in var_names:
+         results[var] = _arg_to_param_cfg(var)
+     for key, val in kwargs.items():
+         results[key] = _arg_to_param_cfg(key, val)
+     return results
+
+
+ def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo:
+     posterior = idata.posterior
+     vars = list()
+     info = list()
+     begin = 0
+     for key, cfg in kwargs.items():
+         data = (
+             posterior[key]
+             # combine all draws from all chains
+             .stack(__sample__=["chain", "draw"])
+             # move the sample dim to the first position
+             # no matter where it was before
+             .transpose("__sample__", ...)
+             # we need numpy data for all the remaining functionality
+             .values
+         )
+         # omitting __sample__,
+         # we need the shape in the transformed space
+         if cfg["transform"] is not None:
+             # some transforms need the original shape
+             data = cfg["transform"].forward(data).eval()
+         shape = data.shape[1:]
+         # now we can get rid of the shape
+         data = data.reshape(data.shape[0], -1)
+         end = begin + data.shape[1]
+         vars.append(data)
+         sinfo = dict(shape=shape, slice=slice(begin, end))
+         info.append(dict(sinfo=sinfo, vinfo=cfg))
+         begin = end
+     return dict(data=np.concatenate(vars, axis=-1), info=info)
+
+
+ def _mean_chol(flat_array: np.ndarray):
+     mean = flat_array.mean(0)
+     cov = np.cov(flat_array, rowvar=False)
+     cov = np.atleast_2d(cov)
+     chol = np.linalg.cholesky(cov)
+     return mean, chol
+
+
+ def _mvn_prior_from_flat_info(name, flat_info: FlatInfo):
+     mean, chol = _mean_chol(flat_info["data"])
+     base_dist = pm.Normal(name, np.zeros_like(mean))
+     interim = mean + chol @ base_dist
+     result = dict()
+     for var_info in flat_info["info"]:
+         sinfo = var_info["sinfo"]
+         vinfo = var_info["vinfo"]
+         var = interim[sinfo["slice"]].reshape(sinfo["shape"])
+         if vinfo["transform"] is not None:
+             var = vinfo["transform"].backward(var)
+         var = pm.Deterministic(vinfo["name"], var, dims=vinfo["dims"])
+         result[vinfo["name"]] = var
+     return result
+
+
+ def prior_from_idata(
+     idata: arviz.InferenceData,
+     name="trace_prior_",
+     *,
+     var_names: Sequence[str] = (),
+     **kwargs: ParamCfg | Transform | str | tuple,
+ ) -> dict[str, pt.TensorVariable]:
+     """
+     Create a prior from a posterior using an MvNormal approximation.
+
+     Keep in mind that this function will only work well for unimodal
+     posteriors and will fail when the posterior has strong interactions
+     that an MvNormal cannot capture.
+
+     Moreover, if a retrieved variable is constrained, you
+     should specify a transform for the variable, e.g.
+     ``pymc.distributions.transforms.log`` for a standard
+     deviation posterior.
+
+     Parameters
+     ----------
+     idata: arviz.InferenceData
+         Inference data with a posterior group
+     var_names: Sequence[str]
+         names of variables to take as is from the posterior
+     kwargs: Union[ParamCfg, Transform, str, Tuple]
+         names of variables with additional configuration, see more in Examples
+
+     Examples
+     --------
+     >>> import pymc as pm
+     >>> import pymc.distributions.transforms as transforms
+     >>> import numpy as np
+     >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model1:
+     ...     a = pm.Normal("a")
+     ...     b = pm.Normal("b", dims="test")
+     ...     c = pm.HalfNormal("c")
+     ...     d = pm.Normal("d")
+     ...     e = pm.Normal("e")
+     ...     f = pm.Dirichlet("f", np.ones(3), dims="options")
+     ...     trace = pm.sample(progressbar=False)
+
+     You can reuse the posterior in a new model.
+
+     >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2:
+     ...     priors = prior_from_idata(
+     ...         trace,                 # the old trace (posterior)
+     ...         var_names=["a", "d"],  # take variables as is
+     ...
+     ...         e="new_e",             # assign the new name "new_e" to a variable,
+     ...                                # similar to dict(name="new_e")
+     ...
+     ...         b=("test",),           # set a dim to "test",
+     ...                                # similar to dict(dims=("test",))
+     ...
+     ...         c=transforms.log,      # apply a log transform to a positive variable,
+     ...                                # similar to dict(transform=transforms.log)
+     ...
+     ...         # set a name, assign a dim and apply the simplex transform
+     ...         f=dict(name="new_f", dims="options", transform=transforms.simplex),
+     ...     )
+     ...     trace1 = pm.sample_prior_predictive(100)
+     """
+     param_cfg = _parse_args(var_names=var_names, **kwargs)
+     if not param_cfg:
+         return {}
+     flat_info = _flatten(idata, **param_cfg)
+     return _mvn_prior_from_flat_info(name, flat_info)
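
As a quick reference for the configuration grammar accepted by `prior_from_idata`, here is a sketch of how the private helper `_parse_args` normalizes its inputs; the printed mapping is illustrative.

import pymc.distributions.transforms as transforms

from pymc_extras.utils.prior import _parse_args

cfg = _parse_args(
    var_names=["a"],   # taken as is
    e="new_e",         # str -> rename
    b=("test",),       # tuple -> dims
    c=transforms.log,  # Transform -> transform
)
for key, value in cfg.items():
    print(key, value)
# a {'name': 'a', 'transform': None, 'dims': None}
# e {'name': 'new_e', 'transform': None, 'dims': None}
# b {'name': 'b', 'transform': None, 'dims': ('test',)}
# c {'name': 'c', 'transform': <log transform>, 'dims': None}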
pymc_extras/utils/spline.py ADDED
@@ -0,0 +1,131 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import numpy as np
+ import pytensor
+ import pytensor.sparse as ps
+ import pytensor.tensor as pt
+ import scipy.interpolate
+ import scipy.sparse
+
+ from pytensor.graph.op import Apply, Op
+
+
+ def numpy_bspline_basis(eval_points: np.ndarray, k: int, degree=3):
+     k_knots = k + degree + 1
+     knots = np.linspace(0, 1, k_knots - 2 * degree)
+     knots = np.r_[[0] * degree, knots, [1] * degree]
+     basis_funcs = scipy.interpolate.BSpline(knots, np.eye(k), k=degree)
+     Bx = basis_funcs(eval_points).astype(eval_points.dtype)
+     return Bx
+
+
+ class BSplineBasis(Op):
+     __props__ = ("sparse",)
+
+     def __init__(self, sparse=True) -> None:
+         super().__init__()
+         if not isinstance(sparse, bool):
+             raise TypeError("sparse should be True or False")
+         self.sparse = sparse
+
+     def make_node(self, *inputs) -> Apply:
+         eval_points, k, d = map(pt.as_tensor, inputs)
+         if not (eval_points.ndim == 1 and np.issubdtype(eval_points.dtype, np.floating)):
+             raise TypeError("eval_points should be a vector of floats")
+         if k.type not in pt.int_types:
+             raise TypeError("k should be an integer")
+         if d.type not in pt.int_types:
+             raise TypeError("degree should be an integer")
+         if self.sparse:
+             out_type = ps.SparseTensorType("csr", eval_points.dtype)()
+         else:
+             out_type = pt.matrix(dtype=eval_points.dtype)
+         return Apply(self, [eval_points, k, d], [out_type])
+
+     def perform(self, node, inputs, output_storage, params=None) -> None:
+         eval_points, k, d = inputs
+         Bx = numpy_bspline_basis(eval_points, int(k), int(d))
+         if self.sparse:
+             Bx = scipy.sparse.csr_matrix(Bx, dtype=eval_points.dtype)
+         output_storage[0][0] = Bx
+
+     def infer_shape(self, fgraph, node, ins_shapes):
+         return [(node.inputs[0].shape[0], node.inputs[1])]
+
+
+ def bspline_basis(n, k, degree=3, dtype=None, sparse=True):
+     dtype = dtype or pytensor.config.floatX
+     eval_points = np.linspace(0, 1, n, dtype=dtype)
+     return BSplineBasis(sparse=sparse)(eval_points, k, degree)
+
+
+ def bspline_interpolation(x, *, n=None, eval_points=None, degree=3, sparse=True):
+     """Interpolate a sparse grid to a dense grid using B-splines.
+
+     Parameters
+     ----------
+     x : Variable
+         Input Variable to interpolate.
+         The 0th coordinate is assumed to be mapped regularly onto the [0, 1] interval
+     n : int (optional)
+         Resolution of the interpolation
+     eval_points : vector (optional)
+         Custom eval points in the [0, 1] interval (or scaled properly using min/max scaling)
+     degree : int, optional
+         B-spline degree, by default 3
+     sparse : bool, optional
+         Use a sparse operation, by default True
+
+     Returns
+     -------
+     Variable
+         The interpolated variable; interpolation is across the 0th axis
+
+     Examples
+     --------
+     >>> import pymc as pm
+     >>> import numpy as np
+     >>> half_months = np.linspace(0, 365, 12 * 2)
+     >>> with pm.Model(coords=dict(knots_time=half_months, time=np.arange(365))) as model:
+     ...     kernel = pm.gp.cov.ExpQuad(1, ls=365 / 12)
+     ...     # ready to define the gp (a latent process over the parameters)
+     ...     gp = pm.gp.gp.Latent(cov_func=kernel)
+     ...     y_knots = gp.prior("y_knots", half_months[:, None], dims="knots_time")
+     ...     y = pm.Deterministic(
+     ...         "y",
+     ...         bspline_interpolation(y_knots, n=365, degree=3),
+     ...         dims="time",
+     ...     )
+     ...     trace = pm.sample_prior_predictive(1)
+
+     Notes
+     -----
+     Adapted from `BayesAlpha <https://github.com/quantopian/bayesalpha/blob/676f4f194ad20211fd040d3b0c6e82969aafb87e/bayesalpha/dists.py#L97>`_,
+     where it was written by @aseyboldt
+     """
+     x = pt.as_tensor(x)
+     if n is not None and eval_points is not None:
+         raise ValueError("Please provide only one of n or eval_points")
+     elif n is not None:
+         eval_points = np.linspace(0, 1, n, dtype=x.dtype)
+     elif eval_points is None:
+         raise ValueError("Please provide one of n or eval_points")
+     basis = BSplineBasis(sparse=sparse)(eval_points, x.shape[0], degree)
+     if sparse:
+         return ps.dot(basis, x)
+     else:
+         return pt.dot(basis, x)
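
To make the basis construction concrete, a small sketch of `numpy_bspline_basis` on its own (illustrative, not part of the wheel): the result has one row per evaluation point and one column per basis function, and each row sums to one because clamped B-spline bases form a partition of unity on [0, 1].

import numpy as np

from pymc_extras.utils.spline import numpy_bspline_basis

eval_points = np.linspace(0, 1, 100)
Bx = numpy_bspline_basis(eval_points, k=10, degree=3)

print(Bx.shape)                 # (100, 10)
print(Bx.sum(axis=1).round(6))  # all ones: partition of unity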
pymc_extras/version.py ADDED
@@ -0,0 +1,11 @@
+ import os
+
+
+ def get_version():
+     version_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "version.txt")
+     with open(version_file) as f:
+         version = f.read().strip()
+     return version
+
+
+ __version__ = get_version()
pymc_extras/version.txt ADDED
@@ -0,0 +1 @@
+ 0.2.0