pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
pymc_extras/distributions/timeseries.py
@@ -0,0 +1,356 @@
+ import warnings
+
+ import numpy as np
+ import pymc as pm
+ import pytensor
+ import pytensor.tensor as pt
+
+ from pymc.distributions.dist_math import check_parameters
+ from pymc.distributions.distribution import (
+     Distribution,
+     SymbolicRandomVariable,
+     _support_point,
+     support_point,
+ )
+ from pymc.distributions.shape_utils import (
+     _change_dist_size,
+     change_dist_size,
+     get_support_shape_1d,
+ )
+ from pymc.logprob.abstract import _logprob
+ from pymc.logprob.basic import logp
+ from pymc.pytensorf import constant_fold, intX
+ from pymc.step_methods import STEP_METHODS
+ from pymc.step_methods.arraystep import ArrayStep
+ from pymc.step_methods.compound import Competence
+ from pymc.step_methods.metropolis import CategoricalGibbsMetropolis
+ from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars
+ from pytensor import Mode
+ from pytensor.graph.basic import Node
+ from pytensor.tensor import TensorVariable
+ from pytensor.tensor.random.op import RandomVariable
+
+
+ def _make_outputs_info(n_lags: int, init_dist: Distribution) -> list[Distribution | dict]:
+     """
+     Construct the outputs_info for the scans used by DiscreteMarkovChainRV. Two cases are
+     needed: if n_lags == 1, we must throw away the first dimension of init_dist_, or else
+     markov_chain will have shape (steps, 1, *batch_size) instead of the desired
+     (steps, *batch_size).
+
+     Parameters
+     ----------
+     n_lags: int
+         Number of lags the Markov chain considers when transitioning to the next state
+     init_dist: Distribution
+         Distribution over initial states
+
+     Returns
+     -------
+     outputs_info: list
+         Initial states and taps to be fed into pytensor.scan when drawing a Markov chain
+     """
+
+     if n_lags > 1:
+         return [{"initial": init_dist, "taps": list(range(-n_lags, 0))}]
+     else:
+         return [init_dist[0]]
+
+
+ class DiscreteMarkovChainRV(SymbolicRandomVariable):
+     n_lags: int
+     default_output = 1
+     _print_name = ("DiscreteMC", "\\operatorname{DiscreteMC}")
+
+     def __init__(self, *args, n_lags, **kwargs):
+         self.n_lags = n_lags
+         super().__init__(*args, **kwargs)
+
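+     # Map the rng input (the last input) to the updated rng output (the first
+     # output) so that successive draws advance the underlying generator.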
+     def update(self, node: Node):
+         return {node.inputs[-1]: node.outputs[0]}
+
+
+ class DiscreteMarkovChain(Distribution):
+     r"""
+     A discrete Markov chain is a sequence of random variables
+
+     .. math::
+
+         \{x_t\}_{t=0}^T
+
+     where the transition probability :math:`P(x_t | x_{t-1})` depends only on the state of the
+     system at :math:`x_{t-1}`.
+
+     Parameters
+     ----------
+     P: tensor
+         Matrix of transition probabilities between states. Rows must sum to 1.
+         One of P or logit_P must be provided.
+     logit_P: tensor, optional
+         Matrix of transition logits. Converted to probabilities via a softmax over the last axis.
+         One of P or logit_P must be provided.
+     steps: tensor, optional
+         Length of the Markov chain. Only needed if state is not provided.
+     init_dist : unnamed distribution, optional
+         Vector distribution for initial values. Unnamed refers to distributions
+         created with the ``.dist()`` API. Distribution should have shape n_states.
+         If not, it will be automatically resized.
+
+         .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
+
+     Notes
+     -----
+     The initial distribution will be cloned, rendering it distinct from the one passed as
+     input.
+
+     Examples
+     --------
+     Create a Markov chain of length 100 with 3 states. The number of states is given by the
+     shape of P, 3 in this case.
+
+     .. code-block:: python
+
+         import numpy as np
+         import pymc as pm
+         import pymc_extras as pmx
+
+         with pm.Model() as markov_chain:
+             P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
+             init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+             markov_chain = pmx.DiscreteMarkovChain(
+                 "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+             )
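+             # For example, draw prior samples of the chain:
+             idata = pm.sample_prior_predictive()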
+     """
+
+     rv_type = DiscreteMarkovChainRV
+
+     def __new__(cls, *args, steps=None, n_lags=1, **kwargs):
+         steps = get_support_shape_1d(
+             support_shape=steps,
+             shape=None,
+             dims=kwargs.get("dims", None),
+             observed=kwargs.get("observed", None),
+             support_shape_offset=n_lags,
+         )
+
+         return super().__new__(cls, *args, steps=steps, n_lags=n_lags, **kwargs)
+
+     @classmethod
+     def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwargs):
+         steps = get_support_shape_1d(
+             support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=n_lags
+         )
+
+         if steps is None:
+             raise ValueError("Must specify steps or shape parameter")
+         if P is None and logit_P is None:
+             raise ValueError("Must specify P or logit_P parameter")
+         if P is not None and logit_P is not None:
+             raise ValueError("Must specify only one of either P or logit_P parameter")
+
+         if logit_P is not None:
+             P = pm.math.softmax(logit_P, axis=-1)
+
+         P = pt.as_tensor_variable(P)
+         steps = pt.as_tensor_variable(intX(steps))
+
+         if init_dist is not None:
+             if not isinstance(init_dist, TensorVariable) or not isinstance(
+                 init_dist.owner.op, RandomVariable | SymbolicRandomVariable
+             ):
+                 raise ValueError(
+                     "Init dist must be a distribution created via the `.dist()` API, "
+                     f"got {type(init_dist)}"
+                 )
+
+             check_dist_not_registered(init_dist)
+             if init_dist.owner.op.ndim_supp > 1:
+                 raise ValueError(
+                     "Init distribution must have a scalar or vector support dimension, "
+                     f"got ndim_supp={init_dist.owner.op.ndim_supp}."
+                 )
+         else:
+             warnings.warn(
+                 "Initial distribution not specified, defaulting to "
+                 "`Categorical.dist(p=pt.full((k_states, ), 1/k_states), shape=...)`. You can specify an init_dist "
+                 "manually to suppress this warning.",
+                 UserWarning,
+             )
+             k = P.shape[-1]
+             init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))
+
+         return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)
+
+     @classmethod
+     def rv_op(cls, P, steps, init_dist, n_lags, size=None):
+         if size is not None:
+             batch_size = size
+         else:
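+             # Infer the batch shape by broadcasting the batch dimensions of P
+             # (everything to the left of its n_lags + 1 square state dimensions)
+             # with the batch dimensions of init_dist.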
+             batch_size = pt.broadcast_shape(
+                 P[tuple([...] + [0] * (n_lags + 1))], pt.atleast_1d(init_dist)[..., 0]
+             )
+
+         init_dist = change_dist_size(init_dist, (n_lags, *batch_size))
+         init_dist_ = init_dist.type()
+         P_ = P.type()
+         steps_ = steps.type()
+
+         state_rng = pytensor.shared(np.random.default_rng())
+
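+         # One step of the chain: index the transition tensor with the previous
+         # state(s), then draw the next state from the resulting Categorical.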
+         def transition(*args):
+             *states, transition_probs, old_rng = args
+             p = transition_probs[tuple(states)]
+             next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
+             return next_state, {old_rng: next_rng}
+
+         markov_chain, state_updates = pytensor.scan(
+             transition,
+             non_sequences=[P_, state_rng],
+             outputs_info=_make_outputs_info(n_lags, init_dist_),
+             n_steps=steps_,
+             strict=True,
+         )
+
+         (state_next_rng,) = tuple(state_updates.values())
+
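+         # Prepend the initial states and move the time axis to the end, making
+         # the support (time) dimension the trailing axis of the output.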
+         discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)
+
+         discrete_mc_op = DiscreteMarkovChainRV(
+             inputs=[P_, steps_, init_dist_, state_rng],
+             outputs=[state_next_rng, discrete_mc_],
+             ndim_supp=1,
+             n_lags=n_lags,
+         )
+
+         discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)
+         return discrete_mc
+
+
+ @_change_dist_size.register(DiscreteMarkovChainRV)
+ def change_mc_size(op, dist, new_size, expand=False):
+     if expand:
+         old_size = dist.shape[:-1]
+         new_size = tuple(new_size) + tuple(old_size)
+
+     return DiscreteMarkovChain.rv_op(*dist.owner.inputs[:-1], size=new_size, n_lags=op.n_lags)
+
+
+ @_support_point.register(DiscreteMarkovChainRV)
+ def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
+     init_dist_moment = support_point(init_dist)
+     n_lags = op.n_lags
+
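+     # The support point follows the most probable (argmax) transition at each
+     # step, starting from the support point of the initial distribution.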
+     def greedy_transition(*args):
+         *states, transition_probs, old_rng = args
+         p = transition_probs[tuple(states)]
+         return pt.argmax(p)
+
+     chain_moment, moment_updates = pytensor.scan(
+         greedy_transition,
+         non_sequences=[P, state_rng],
+         outputs_info=_make_outputs_info(n_lags, init_dist),
+         n_steps=steps,
+         strict=True,
+     )
+     chain_moment = pt.concatenate([init_dist_moment, chain_moment])
+     return chain_moment
+
+
+ @_logprob.register(DiscreteMarkovChainRV)
+ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
+     value = values[0]
+     n_lags = op.n_lags
+
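+     # Build n_lags + 1 aligned slices of the chain so that indexes[0] holds
+     # x_{t-n_lags}, ..., indexes[-1] holds x_t; P[tuple(indexes)] then gathers
+     # each step's transition probability in a single vectorized lookup.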
+     indexes = [value[..., i : -(n_lags - i) if n_lags != i else None] for i in range(n_lags + 1)]
+
+     mc_logprob = logp(init_dist, value[..., :n_lags]).sum(axis=-1)
+     mc_logprob += pt.log(P[tuple(indexes)]).sum(axis=-1)
+
+     # We cannot leave any RV in the logp graph, even if just for an assert
+     [init_dist_leading_dim] = constant_fold(
+         [pt.atleast_1d(init_dist).shape[0]], raise_not_constant=False
+     )
+
+     return check_parameters(
+         mc_logprob,
+         pt.all(pt.eq(P.shape[-(n_lags + 1) :], P.shape[-1])),
+         pt.all(pt.allclose(P.sum(axis=-1), 1.0)),
+         pt.eq(init_dist_leading_dim, n_lags),
+         msg="Last (n_lags + 1) dimensions of P must be square, "
+         "P must sum to 1 along the last axis, "
+         "first dimension of init_dist must be n_lags",
+     )
+
+
+ class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
+     name = "discrete_markov_chain_gibbs_metropolis"
+
+     def __init__(
+         self,
+         vars,
+         proposal="uniform",
+         order="random",
+         model=None,
+         initial_point=None,
+         compile_kwargs: dict | None = None,
+         **kwargs,
+     ):
+         model = pm.modelcontext(model)
+         vars = get_value_vars_from_user_vars(vars, model)
+         if initial_point is None:
+             initial_point = model.initial_point()
+
+         dimcats = []
+         # dimcats is a list of pairs (aggregate dimension, number of
+         # categories). For example, if vars = [x, y] with x being a 2-D
+         # variable with M categories and y being a 3-D variable with N
+         # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
+         for v in vars:
+             v_init_val = initial_point[v.name]
+             rv_var = model.values_to_rvs[v]
+             rv_op = rv_var.owner.op
+
+             if not isinstance(rv_op, DiscreteMarkovChainRV):
+                 raise TypeError("All variables must be DiscreteMarkovChainRV")
+
+             k_graph = rv_var.owner.inputs[0].shape[-1]
+             (k_graph,) = model.replace_rvs_by_values((k_graph,))
+             k = model.compile_fn(
+                 k_graph,
+                 inputs=model.value_vars,
+                 on_unused_input="ignore",
+                 mode=Mode(linker="py", optimizer=None),
+             )(initial_point)
+             start = len(dimcats)
+             dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]
+
+         if order == "random":
+             self.shuffle_dims = True
+             self.dimcats = dimcats
+         else:
+             if sorted(order) != list(range(len(dimcats))):
+                 raise ValueError("Argument 'order' has to be a permutation")
+             self.shuffle_dims = False
+             self.dimcats = [dimcats[j] for j in order]
+
+         if proposal == "uniform":
+             self.astep = self.astep_unif
+         elif proposal == "proportional":
+             # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
+             self.astep = self.astep_prop
+         else:
+             raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
+
+         # Doesn't actually tune, but it's required to emit a sampler stat
+         # that indicates whether a draw was done in a tuning phase.
+         self.tune = True
+
+         # We bypass CategoricalGibbsMetropolis's __init__ to avoid its specialized initialization logic
+         if compile_kwargs is None:
+             compile_kwargs = {}
+         ArrayStep.__init__(self, vars, [model.compile_logp(**compile_kwargs)], **kwargs)
+
+     @staticmethod
+     def competence(var):
+         if isinstance(var.owner.op, DiscreteMarkovChainRV):
+             return Competence.IDEAL
+         return Competence.INCOMPATIBLE
+
+
+ STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis)
pymc_extras/gp/__init__.py
@@ -0,0 +1,18 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from pymc_extras.gp.latent_approx import KarhunenLoeveExpansion, ProjectedProcess
+
+ __all__ = ["KarhunenLoeveExpansion", "ProjectedProcess"]
pymc_extras/gp/latent_approx.py
@@ -0,0 +1,183 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from functools import partial
+
+ import numpy as np
+ import pymc as pm
+ import pytensor.tensor as pt
+
+ from pymc.gp.util import JITTER_DEFAULT, stabilize
+ from pytensor.tensor.linalg import cholesky, solve_triangular
+
+ solve_lower = partial(solve_triangular, lower=True)
+ solve_upper = partial(solve_triangular, lower=False)
+
+
+ class LatentApprox(pm.gp.Latent):
+     ## TODO: use strings to select approximation, like pm.gp.MarginalApprox?
+     pass
+
+
+ class ProjectedProcess(pm.gp.Latent):
+     ## AKA: DTC (Deterministic Training Conditional)
+     def __init__(
+         self,
+         n_inducing: int | None = None,
+         *,
+         mean_func=pm.gp.mean.Zero(),
+         cov_func=pm.gp.cov.Constant(0.0),
+     ):
+         self.n_inducing = n_inducing
+         super().__init__(mean_func=mean_func, cov_func=cov_func)
+
+     def _build_prior(self, name, X, X_inducing, jitter=JITTER_DEFAULT, **kwargs):
+         mu = self.mean_func(X)
+         Kuu = self.cov_func(X_inducing)
+         L = cholesky(stabilize(Kuu, jitter))
+
+         n_inducing_points = np.shape(X_inducing)[0]
+         v = pm.Normal(name + "_u_rotated_", mu=0.0, sigma=1.0, size=n_inducing_points, **kwargs)
+         u = pm.Deterministic(name + "_u", L @ v)
+
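+         # Kuu^{-1} @ u via two triangular solves against the Cholesky factor,
+         # avoiding an explicit matrix inverse.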
+         Kfu = self.cov_func(X, X_inducing)
+         Kuuiu = solve_upper(pt.transpose(L), solve_lower(L, u))
+
+         return pm.Deterministic(name, mu + Kfu @ Kuuiu), Kuuiu, L
+
+     def prior(
+         self,
+         name: str,
+         X: np.ndarray,
+         X_inducing: np.ndarray | None = None,
+         jitter: float = JITTER_DEFAULT,
+         **kwargs,
+     ) -> pt.TensorVariable:
+         """
+         Builds the GP prior with optional inducing point locations.
+
+         Parameters
+         ----------
+         - name: Name for the GP variable.
+         - X: Input data.
+         - X_inducing: Optional. Inducing points for the GP.
+         - jitter: Jitter to ensure numerical stability.
+
+         Returns
+         -------
+         - GP function
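+
+         Examples
+         --------
+         A minimal usage sketch; the input array and kernel choice here are
+         illustrative, not prescribed:
+
+         .. code-block:: python
+
+             import numpy as np
+             import pymc as pm
+             from pymc_extras.gp.latent_approx import ProjectedProcess
+
+             X = np.linspace(0, 1, 50)[:, None]
+             with pm.Model():
+                 gp = ProjectedProcess(n_inducing=10, cov_func=pm.gp.cov.ExpQuad(1, ls=0.1))
+                 f = gp.prior("f", X=X)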
+         """
+         # Check if X is a numpy array
+         if not isinstance(X, np.ndarray):
+             raise ValueError("'X' must be a numpy array.")
+
+         # Proceed with provided X_inducing or determine X_inducing based on n_inducing
+         if X_inducing is not None:
+             pass  # X_inducing is directly used
+
+         elif self.n_inducing is not None:
+             # Validate n_inducing
+             if not isinstance(self.n_inducing, int) or self.n_inducing <= 0:
+                 raise ValueError(
+                     "The number of inducing points, 'n_inducing', must be a positive integer."
+                 )
+             if self.n_inducing > len(X):
+                 raise ValueError(
+                     "The number of inducing points, 'n_inducing', cannot be greater than the number of data points in 'X'."
+                 )
+             # Use k-means to select X_inducing from X based on n_inducing
+             X_inducing = pm.gp.util.kmeans_inducing_points(self.n_inducing, X, **kwargs)
+         else:
+             # Neither X_inducing nor n_inducing provided
+             raise ValueError(
+                 "Either 'X_inducing' (inducing points) or 'n_inducing' (number of inducing points) must be specified."
+             )
+
+         f, Kuuiu, L = self._build_prior(name, X, X_inducing, jitter, **kwargs)
+         self.X, self.X_inducing = X, X_inducing
+         self.L, self.Kuuiu = L, Kuuiu
+         self.f = f
+         return f
+
+     def _build_conditional(self, name, Xnew, X_inducing, L, Kuuiu, jitter, **kwargs):
+         Ksu = self.cov_func(Xnew, X_inducing)
+         mu = self.mean_func(Xnew) + Ksu @ Kuuiu
+         tmp = solve_lower(L, pt.transpose(Ksu))
+         Qss = pt.transpose(tmp) @ tmp  # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
+         Kss = self.cov_func(Xnew)
+         Lss = cholesky(stabilize(Kss - Qss, jitter))
+         return mu, Lss
+
+     def conditional(self, name, Xnew, jitter=1e-6, **kwargs):
+         mu, chol = self._build_conditional(
+             name, Xnew, self.X_inducing, self.L, self.Kuuiu, jitter, **kwargs
+         )
+         return pm.MvNormal(name, mu=mu, chol=chol)
+
+
+ class KarhunenLoeveExpansion(pm.gp.Latent):
+     def __init__(
+         self,
+         variance_limit=None,
+         n_eigs=None,
+         *,
+         mean_func=pm.gp.mean.Zero(),
+         cov_func=pm.gp.cov.Constant(0.0),
+     ):
+         self.variance_limit = variance_limit
+         self.n_eigs = n_eigs
+         super().__init__(mean_func=mean_func, cov_func=cov_func)
+
+     def _build_prior(self, name, X, jitter=1e-6, **kwargs):
+         self.mean_func(X)
+         Kxx = pm.gp.util.stabilize(self.cov_func(X), jitter)
+         vals, vecs = pt.linalg.eigh(Kxx)
+         ## NOTE: REMOVED PRECISION CUTOFF
+         if self.variance_limit is None:
+             n_eigs = self.n_eigs
+         else:
+             if self.variance_limit == 1:
+                 n_eigs = len(vals)
+             else:
+                 n_eigs = ((vals[::-1].cumsum() / vals.sum()) > self.variance_limit).nonzero()[0][0]
+         U = vecs[:, -n_eigs:]
+         s = vals[-n_eigs:]
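+         # Karhunen-Loeve basis: eigenvectors scaled by the square root of their
+         # eigenvalues, so i.i.d. standard-normal coefficients reproduce the
+         # covariance of the retained components.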
+         basis = U * pt.sqrt(s)
+
+         coefs_raw = pm.Normal(f"_gp_{name}_coefs", mu=0, sigma=1, size=n_eigs)
+         # weight = pm.HalfNormal(f"_gp_{name}_sd")
+         # coefs = weight * coefs_raw  # don't understand this prior, why weight * coefs_raw?
+         f = basis @ coefs_raw
+         return f, U, s, n_eigs
+
+     def prior(self, name, X, jitter=1e-6, **kwargs):
+         f, U, s, n_eigs = self._build_prior(name, X, jitter, **kwargs)
+         self.U, self.s, self.n_eigs = U, s, n_eigs
+         self.X = X
+         self.f = f
+         return pm.Deterministic(name, f)
+
+     def _build_conditional(self, Xnew, X, f, U, s, jitter):
+         Kxs = self.cov_func(X, Xnew)
+         Kss = self.cov_func(Xnew)
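+         # Low-rank pseudo-inverse of Kxx assembled from the retained eigenpairs.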
+         Kxxpinv = U @ pt.diag(1.0 / s) @ U.T
+         mus = Kxs.T @ Kxxpinv @ f
+         K = Kss - Kxs.T @ Kxxpinv @ Kxs
+         L = cholesky(stabilize(K, jitter))
+         return mus, L
+
+     def conditional(self, name, Xnew, jitter=1e-6, **kwargs):
+         X, f = self.X, self.f
+         U, s = self.U, self.s
+         mu, L = self._build_conditional(Xnew, X, f, U, s, jitter)
+         return pm.MvNormal(name, mu=mu, chol=L, **kwargs)
pymc_extras/inference/__init__.py
@@ -0,0 +1,18 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from pymc_extras.inference.fit import fit
+
+ __all__ = ["fit"]