pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
pymc_extras/distributions/timeseries.py
@@ -0,0 +1,356 @@
+import warnings
+
+import numpy as np
+import pymc as pm
+import pytensor
+import pytensor.tensor as pt
+
+from pymc.distributions.dist_math import check_parameters
+from pymc.distributions.distribution import (
+    Distribution,
+    SymbolicRandomVariable,
+    _support_point,
+    support_point,
+)
+from pymc.distributions.shape_utils import (
+    _change_dist_size,
+    change_dist_size,
+    get_support_shape_1d,
+)
+from pymc.logprob.abstract import _logprob
+from pymc.logprob.basic import logp
+from pymc.pytensorf import constant_fold, intX
+from pymc.step_methods import STEP_METHODS
+from pymc.step_methods.arraystep import ArrayStep
+from pymc.step_methods.compound import Competence
+from pymc.step_methods.metropolis import CategoricalGibbsMetropolis
+from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars
+from pytensor import Mode
+from pytensor.graph.basic import Node
+from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.op import RandomVariable
+
+
+def _make_outputs_info(n_lags: int, init_dist: Distribution) -> list[Distribution | dict]:
+    """
+    Two cases are needed for outputs_info in the scans used by DiscreteMarkovChainRV. If n_lags == 1, we need to
+    throw away the first dimension of init_dist_, or else markov_chain will have shape (steps, 1, *batch_size)
+    instead of the desired (steps, *batch_size).
+
+    Parameters
+    ----------
+    n_lags: int
+        Number of lags the Markov chain considers when transitioning to the next state
+    init_dist: RandomVariable
+        Distribution over initial states
+
+    Returns
+    -------
+    taps: list
+        Lags to be fed into pytensor.scan when drawing a Markov chain
+    """
+
+    if n_lags > 1:
+        return [{"initial": init_dist, "taps": list(range(-n_lags, 0))}]
+    else:
+        return [init_dist[0]]
+
+
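The two branches above map onto pytensor.scan's two `outputs_info` conventions. A minimal sketch (not part of the released file; variable names are illustrative) of what each branch hands to the scan:

    import pytensor.tensor as pt

    # Stand-in for a resized init_dist with n_lags = 3 and batch size 2
    init = pt.zeros((3, 2))

    # n_lags > 1: scan feeds the step function the previous three states as taps
    outputs_info_lagged = [{"initial": init, "taps": [-3, -2, -1]}]  # list(range(-3, 0))

    # n_lags == 1: drop the leading lag dimension so the chain ends up (steps, *batch)
    outputs_info_single = [init[0]]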
+class DiscreteMarkovChainRV(SymbolicRandomVariable):
+    n_lags: int
+    default_output = 1
+    _print_name = ("DiscreteMC", "\\operatorname{DiscreteMC}")
+
+    def __init__(self, *args, n_lags, **kwargs):
+        self.n_lags = n_lags
+        super().__init__(*args, **kwargs)
+
+    def update(self, node: Node):
+        return {node.inputs[-1]: node.outputs[0]}
+
+
+class DiscreteMarkovChain(Distribution):
+    r"""
+    A discrete Markov chain is a sequence of random variables
+
+    .. math::
+
+        \{x_t\}_{t=0}^T
+
+    where the transition probability :math:`P(x_t | x_{t-1})` depends only on the state of the
+    system at :math:`x_{t-1}`.
+
+    Parameters
+    ----------
+    P: tensor, optional
+        Matrix of transition probabilities between states. Rows must sum to 1.
+        One of P or logit_P must be provided.
+    logit_P: tensor, optional
+        Matrix of transition logits. Converted to probabilities via a softmax over the last axis.
+        One of P or logit_P must be provided.
+    steps: tensor, optional
+        Length of the Markov chain. Only needed if shape is not provided.
+    init_dist : unnamed distribution, optional
+        Vector distribution for initial values. Unnamed refers to distributions
+        created with the ``.dist()`` API. The distribution should have shape n_states.
+        If not, it will be automatically resized.
+
+        .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
+
+    Examples
+    --------
+    Create a Markov chain of length 100 with 3 states. The number of states is given by the shape
+    of P, 3 in this case.
+
+    .. code-block:: python
+
+        import numpy as np
+        import pymc as pm
+        import pymc_extras as pmx
+
+        with pm.Model() as model:
+            P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
+            init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+            markov_chain = pmx.DiscreteMarkovChain(
+                "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+            )
+
+    """
+
+    rv_type = DiscreteMarkovChainRV
+
+    def __new__(cls, *args, steps=None, n_lags=1, **kwargs):
+        steps = get_support_shape_1d(
+            support_shape=steps,
+            shape=None,
+            dims=kwargs.get("dims", None),
+            observed=kwargs.get("observed", None),
+            support_shape_offset=n_lags,
+        )
+
+        return super().__new__(cls, *args, steps=steps, n_lags=n_lags, **kwargs)
+
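The `support_shape_offset=n_lags` above encodes the bookkeeping between the declared chain length and the number of scan steps, since the chain stores the `n_lags` initial states followed by the sampled transitions. An illustrative calculation (assumed, not from the released file):

    # For a chain declared with shape=(100,) and n_lags=1,
    # get_support_shape_1d infers steps = 100 - 1 = 99 transitions.
    shape_last, n_lags = 100, 1
    steps = shape_last - n_lags
    assert steps == 99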
+    @classmethod
+    def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwargs):
+        steps = get_support_shape_1d(
+            support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=n_lags
+        )
+
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
+        if P is None and logit_P is None:
+            raise ValueError("Must specify P or logit_P parameter")
+        if P is not None and logit_P is not None:
+            raise ValueError("Must specify only one of either P or logit_P parameter")
+
+        if logit_P is not None:
+            P = pm.math.softmax(logit_P, axis=-1)
+
+        P = pt.as_tensor_variable(P)
+        steps = pt.as_tensor_variable(intX(steps))
+
+        if init_dist is not None:
+            if not isinstance(init_dist, TensorVariable) or not isinstance(
+                init_dist.owner.op, RandomVariable | SymbolicRandomVariable
+            ):
+                raise ValueError(
+                    f"Init dist must be a distribution created via the `.dist()` API, "
+                    f"got {type(init_dist)}"
+                )
+
+            check_dist_not_registered(init_dist)
+            if init_dist.owner.op.ndim_supp > 1:
+                raise ValueError(
+                    "Init distribution must have a scalar or vector support dimension, "
+                    f"got ndim_supp={init_dist.owner.op.ndim_supp}."
+                )
+        else:
+            warnings.warn(
+                "Initial distribution not specified, defaulting to "
+                "`Categorical.dist(p=pt.full((k_states, ), 1/k_states), shape=...)`. You can specify an init_dist "
+                "manually to suppress this warning.",
+                UserWarning,
+            )
+            k = P.shape[-1]
+            init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))
+
+        return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)
+
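A hedged sketch of the two parametrizations accepted by `dist` above (the matrices here are arbitrary; `pmx` follows the docstring's import alias):

    import numpy as np
    import pymc as pm
    import pymc_extras as pmx

    P = np.full((2, 2), 0.5)  # row-stochastic 2-state transition matrix
    chain = pmx.DiscreteMarkovChain.dist(
        P=P, init_dist=pm.Categorical.dist(p=[0.5, 0.5]), steps=10
    )

    # The same chain on the logit scale: softmax over rows of zeros is uniform
    chain_logit = pmx.DiscreteMarkovChain.dist(
        logit_P=np.zeros((2, 2)), init_dist=pm.Categorical.dist(p=[0.5, 0.5]), steps=10
    )

Omitting `init_dist` falls back to a uniform `Categorical` over the states of P, with the `UserWarning` shown above.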
+    @classmethod
+    def rv_op(cls, P, steps, init_dist, n_lags, size=None):
+        if size is not None:
+            batch_size = size
+        else:
+            batch_size = pt.broadcast_shape(
+                P[tuple([...] + [0] * (n_lags + 1))], pt.atleast_1d(init_dist)[..., 0]
+            )
+
+        init_dist = change_dist_size(init_dist, (n_lags, *batch_size))
+        init_dist_ = init_dist.type()
+        P_ = P.type()
+        steps_ = steps.type()
+
+        state_rng = pytensor.shared(np.random.default_rng())
+
+        def transition(*args):
+            *states, transition_probs, old_rng = args
+            p = transition_probs[tuple(states)]
+            next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
+            return next_state, {old_rng: next_rng}
+
+        markov_chain, state_updates = pytensor.scan(
+            transition,
+            non_sequences=[P_, state_rng],
+            outputs_info=_make_outputs_info(n_lags, init_dist_),
+            n_steps=steps_,
+            strict=True,
+        )
+
+        (state_next_rng,) = tuple(state_updates.values())
+
+        discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)
+
+        discrete_mc_op = DiscreteMarkovChainRV(
+            inputs=[P_, steps_, init_dist_, state_rng],
+            outputs=[state_next_rng, discrete_mc_],
+            ndim_supp=1,
+            n_lags=n_lags,
+        )
+
+        discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)
+        return discrete_mc
+
+
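The scan above is what forward sampling executes; a sketch (values are arbitrary, not from the released file) of drawing from the resulting symbolic graph:

    import numpy as np
    import pymc as pm
    import pymc_extras as pmx

    P = np.array([[0.9, 0.1], [0.2, 0.8]])
    init_dist = pm.Categorical.dist(p=[0.5, 0.5])
    chain = pmx.DiscreteMarkovChain.dist(P=P, init_dist=init_dist, steps=20)

    # Each draw holds the n_lags initial state(s) followed by 20 transitions
    draws = pm.draw(chain, draws=5)  # array of shape (5, 21)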
+@_change_dist_size.register(DiscreteMarkovChainRV)
+def change_mc_size(op, dist, new_size, expand=False):
+    if expand:
+        old_size = dist.shape[:-1]
+        new_size = tuple(new_size) + tuple(old_size)
+
+    return DiscreteMarkovChain.rv_op(*dist.owner.inputs[:-1], size=new_size, n_lags=op.n_lags)
+
+
+@_support_point.register(DiscreteMarkovChainRV)
+def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
+    init_dist_moment = support_point(init_dist)
+    n_lags = op.n_lags
+
+    def greedy_transition(*args):
+        *states, transition_probs, old_rng = args
+        p = transition_probs[tuple(states)]
+        return pt.argmax(p)
+
+    chain_moment, moment_updates = pytensor.scan(
+        greedy_transition,
+        non_sequences=[P, state_rng],
+        outputs_info=_make_outputs_info(n_lags, init_dist),
+        n_steps=steps,
+        strict=True,
+    )
+    chain_moment = pt.concatenate([init_dist_moment, chain_moment])
+    return chain_moment
+
+
+@_logprob.register(DiscreteMarkovChainRV)
+def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
+    value = values[0]
+    n_lags = op.n_lags
+
+    indexes = [value[..., i : -(n_lags - i) if n_lags != i else None] for i in range(n_lags + 1)]
+
+    mc_logprob = logp(init_dist, value[..., :n_lags]).sum(axis=-1)
+    mc_logprob += pt.log(P[tuple(indexes)]).sum(axis=-1)
+
+    # We cannot leave any RV in the logp graph, even if just for an assert
+    [init_dist_leading_dim] = constant_fold(
+        [pt.atleast_1d(init_dist).shape[0]], raise_not_constant=False
+    )
+
+    return check_parameters(
+        mc_logprob,
+        pt.all(pt.eq(P.shape[-(n_lags + 1) :], P.shape[-1])),
+        pt.all(pt.allclose(P.sum(axis=-1), 1.0)),
+        pt.eq(init_dist_leading_dim, n_lags),
+        msg="Last (n_lags + 1) dimensions of P must be square, "
+        "P must sum to 1 along the last axis, "
+        "First dimension of init_dist must be n_lags",
+    )
+
+
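For `n_lags=1` the density above reduces to the initial log-probability plus the summed log transition probabilities, which is easy to verify by hand. A sketch with arbitrary numbers (not from the released file):

    import numpy as np
    import pymc as pm
    import pymc_extras as pmx

    P = np.array([[0.9, 0.1], [0.2, 0.8]])
    init_dist = pm.Categorical.dist(p=[0.5, 0.5])
    chain = pmx.DiscreteMarkovChain.dist(P=P, init_dist=init_dist, steps=3)

    value = np.array([0, 0, 1, 1])  # transitions 0->0, 0->1, 1->1
    logp_val = pm.logp(chain, value).eval()

    expected = np.log(0.5) + np.log(P[0, 0]) + np.log(P[0, 1]) + np.log(P[1, 1])
    np.testing.assert_allclose(logp_val, expected)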
+class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
+    name = "discrete_markov_chain_gibbs_metropolis"
+
+    def __init__(
+        self,
+        vars,
+        proposal="uniform",
+        order="random",
+        model=None,
+        initial_point=None,
+        compile_kwargs: dict | None = None,
+        **kwargs,
+    ):
+        model = pm.modelcontext(model)
+        vars = get_value_vars_from_user_vars(vars, model)
+        if initial_point is None:
+            initial_point = model.initial_point()
+
+        dimcats = []
+        # dimcats is a list of pairs (aggregate dimension, number of
+        # categories). For example, if vars = [x, y] with x being a 2-D
+        # variable with M categories and y being a 3-D variable with N
+        # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
+        for v in vars:
+            v_init_val = initial_point[v.name]
+            rv_var = model.values_to_rvs[v]
+            rv_op = rv_var.owner.op
+
+            if not isinstance(rv_op, DiscreteMarkovChainRV):
+                raise TypeError("All variables must be DiscreteMarkovChainRV")
+
+            k_graph = rv_var.owner.inputs[0].shape[-1]
+            (k_graph,) = model.replace_rvs_by_values((k_graph,))
+            k = model.compile_fn(
+                k_graph,
+                inputs=model.value_vars,
+                on_unused_input="ignore",
+                mode=Mode(linker="py", optimizer=None),
+            )(initial_point)
+            start = len(dimcats)
+            dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]
+
+        if order == "random":
+            self.shuffle_dims = True
+            self.dimcats = dimcats
+        else:
+            if sorted(order) != list(range(len(dimcats))):
+                raise ValueError("Argument 'order' has to be a permutation")
+            self.shuffle_dims = False
+            self.dimcats = [dimcats[j] for j in order]
+
+        if proposal == "uniform":
+            self.astep = self.astep_unif
+        elif proposal == "proportional":
+            # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
+            self.astep = self.astep_prop
+        else:
+            raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
+
+        # Doesn't actually tune, but it's required to emit a sampler stat
+        # that indicates whether a draw was done in a tuning phase.
+        self.tune = True
+
+        # We bypass CategoricalGibbsMetropolis's __init__ to avoid its specialized initialization logic
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        ArrayStep.__init__(self, vars, [model.compile_logp(**compile_kwargs)], **kwargs)
+
+    @staticmethod
+    def competence(var):
+        if isinstance(var.owner.op, DiscreteMarkovChainRV):
+            return Competence.IDEAL
+        return Competence.INCOMPATIBLE
+
+
+STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis)
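Because the step method is appended to `STEP_METHODS` and reports `Competence.IDEAL` for chain variables, `pm.sample` should assign it automatically once `pymc_extras` is imported. A hedged sketch (model and values are arbitrary):

    import numpy as np
    import pymc as pm
    import pymc_extras as pmx

    with pm.Model():
        init_dist = pm.Categorical.dist(p=np.full(2, 0.5))
        pmx.DiscreteMarkovChain(
            "chain", P=np.array([[0.9, 0.1], [0.2, 0.8]]), init_dist=init_dist, steps=10
        )
        # Expected to pick DiscreteMarkovChainGibbsMetropolis for "chain"
        idata = pm.sample(draws=500, chains=2)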
pymc_extras/gp/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2022 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pymc_extras.gp.latent_approx import KarhunenLoeveExpansion, ProjectedProcess
+
+__all__ = ["KarhunenLoeveExpansion", "ProjectedProcess"]
pymc_extras/gp/latent_approx.py
@@ -0,0 +1,183 @@
+# Copyright 2022 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+
+import numpy as np
+import pymc as pm
+import pytensor.tensor as pt
+
+from pymc.gp.util import JITTER_DEFAULT, stabilize
+from pytensor.tensor.linalg import cholesky, solve_triangular
+
+solve_lower = partial(solve_triangular, lower=True)
+solve_upper = partial(solve_triangular, lower=False)
+
+
+class LatentApprox(pm.gp.Latent):
+    ## TODO: use strings to select approximation, like pm.gp.MarginalApprox?
+    pass
+
+
+class ProjectedProcess(pm.gp.Latent):
+    ## AKA: DTC
+    def __init__(
+        self,
+        n_inducing: int | None = None,
+        *,
+        mean_func=pm.gp.mean.Zero(),
+        cov_func=pm.gp.cov.Constant(0.0),
+    ):
+        self.n_inducing = n_inducing
+        super().__init__(mean_func=mean_func, cov_func=cov_func)
+
+    def _build_prior(self, name, X, X_inducing, jitter=JITTER_DEFAULT, **kwargs):
+        mu = self.mean_func(X)
+        Kuu = self.cov_func(X_inducing)
+        L = cholesky(stabilize(Kuu, jitter))
+
+        n_inducing_points = np.shape(X_inducing)[0]
+        v = pm.Normal(name + "_u_rotated_", mu=0.0, sigma=1.0, size=n_inducing_points, **kwargs)
+        u = pm.Deterministic(name + "_u", L @ v)
+
+        Kfu = self.cov_func(X, X_inducing)
+        Kuuiu = solve_upper(pt.transpose(L), solve_lower(L, u))
+
+        return pm.Deterministic(name, mu + Kfu @ Kuuiu), Kuuiu, L
+
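`_build_prior` above is the standard projected-process (DTC) construction: whitened inducing values `v` are rotated through the Cholesky factor `L` of `K_uu` to give `u`, then projected onto the data locations via `K_fu @ K_uu^{-1} @ u`. A usage sketch (data, kernel, and priors are assumptions, not from the released file):

    import numpy as np
    import pymc as pm
    from pymc_extras.gp import ProjectedProcess

    rng = np.random.default_rng(0)
    X = np.linspace(0, 10, 200)[:, None]
    y = np.sin(X).ravel() + 0.1 * rng.normal(size=200)

    with pm.Model() as model:
        ell = pm.Gamma("ell", alpha=2.0, beta=1.0)
        gp = ProjectedProcess(n_inducing=20, cov_func=pm.gp.cov.ExpQuad(1, ls=ell))
        f = gp.prior("f", X=X)
        pm.Normal("y_obs", mu=f, sigma=0.1, observed=y)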
+    def prior(
+        self,
+        name: str,
+        X: np.ndarray,
+        X_inducing: np.ndarray | None = None,
+        jitter: float = JITTER_DEFAULT,
+        **kwargs,
+    ):
+        """
+        Builds the GP prior with optional inducing point locations.
+
+        Parameters
+        ----------
+        name: Name for the GP variable.
+        X: Input data.
+        X_inducing: Optional. Inducing points for the GP.
+        jitter: Jitter to ensure numerical stability.
+
+        Returns
+        -------
+        The GP prior function.
+        """
+        # Check if X is a numpy array
+        if not isinstance(X, np.ndarray):
+            raise ValueError("'X' must be a numpy array.")
+
+        # Proceed with provided X_inducing, or determine X_inducing based on n_inducing
+        if X_inducing is not None:
+            pass  # X_inducing is directly used
+        elif self.n_inducing is not None:
+            # Validate n_inducing
+            if not isinstance(self.n_inducing, int) or self.n_inducing <= 0:
+                raise ValueError(
+                    "The number of inducing points, 'n_inducing', must be a positive integer."
+                )
+            if self.n_inducing > len(X):
+                raise ValueError(
+                    "The number of inducing points, 'n_inducing', cannot be greater than the number of data points in 'X'."
+                )
+            # Use k-means to select X_inducing from X based on n_inducing
+            X_inducing = pm.gp.util.kmeans_inducing_points(self.n_inducing, X, **kwargs)
+        else:
+            # Neither X_inducing nor n_inducing provided
+            raise ValueError(
+                "Either 'X_inducing' (inducing points) or 'n_inducing' (number of inducing points) must be specified."
+            )
+
+        f, Kuuiu, L = self._build_prior(name, X, X_inducing, jitter, **kwargs)
+        self.X, self.X_inducing = X, X_inducing
+        self.L, self.Kuuiu = L, Kuuiu
+        self.f = f
+        return f
+
+    def _build_conditional(self, name, Xnew, X_inducing, L, Kuuiu, jitter, **kwargs):
+        Ksu = self.cov_func(Xnew, X_inducing)
+        mu = self.mean_func(Xnew) + Ksu @ Kuuiu
+        tmp = solve_lower(L, pt.transpose(Ksu))
+        Qss = pt.transpose(tmp) @ tmp  # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
+        Kss = self.cov_func(Xnew)
+        Lss = cholesky(stabilize(Kss - Qss, jitter))
+        return mu, Lss
+
+    def conditional(self, name, Xnew, jitter=1e-6, **kwargs):
+        mu, chol = self._build_conditional(
+            name, Xnew, self.X_inducing, self.L, self.Kuuiu, jitter, **kwargs
+        )
+        return pm.MvNormal(name, mu=mu, chol=chol)
+
+
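Prediction then follows the usual PyMC GP pattern; a sketch continuing the example above (assuming `model`, `gp`, and an `idata` returned by `pm.sample` inside that model):

    X_new = np.linspace(10, 12, 40)[:, None]

    with model:
        f_new = gp.conditional("f_new", Xnew=X_new)
        preds = pm.sample_posterior_predictive(idata, var_names=["f_new"])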
+class KarhunenLoeveExpansion(pm.gp.Latent):
+    def __init__(
+        self,
+        variance_limit=None,
+        n_eigs=None,
+        *,
+        mean_func=pm.gp.mean.Zero(),
+        cov_func=pm.gp.cov.Constant(0.0),
+    ):
+        self.variance_limit = variance_limit
+        self.n_eigs = n_eigs
+        super().__init__(mean_func=mean_func, cov_func=cov_func)
+
+    def _build_prior(self, name, X, jitter=1e-6, **kwargs):
+        self.mean_func(X)
+        Kxx = pm.gp.util.stabilize(self.cov_func(X), jitter)
+        vals, vecs = pt.linalg.eigh(Kxx)
+        ## NOTE: REMOVED PRECISION CUTOFF
+        if self.variance_limit is None:
+            n_eigs = self.n_eigs
+        else:
+            if self.variance_limit == 1:
+                n_eigs = len(vals)
+            else:
+                n_eigs = ((vals[::-1].cumsum() / vals.sum()) > self.variance_limit).nonzero()[0][0]
+        U = vecs[:, -n_eigs:]
+        s = vals[-n_eigs:]
+        basis = U * pt.sqrt(s)
+
+        coefs_raw = pm.Normal(f"_gp_{name}_coefs", mu=0, sigma=1, size=n_eigs)
+        # weight = pm.HalfNormal(f"_gp_{name}_sd")
+        # coefs = weight * coefs_raw  # don't understand this prior, why weight * coefs_raw?
+        f = basis @ coefs_raw
+        return f, U, s, n_eigs
+
+    def prior(self, name, X, jitter=1e-6, **kwargs):
+        f, U, s, n_eigs = self._build_prior(name, X, jitter, **kwargs)
+        self.U, self.s, self.n_eigs = U, s, n_eigs
+        self.X = X
+        self.f = f
+        return pm.Deterministic(name, f)
+
+    def _build_conditional(self, Xnew, X, f, U, s, jitter):
+        Kxs = self.cov_func(X, Xnew)
+        Kss = self.cov_func(Xnew)
+        Kxxpinv = U @ pt.diag(1.0 / s) @ U.T
+        mus = Kxs.T @ Kxxpinv @ f
+        K = Kss - Kxs.T @ Kxxpinv @ Kxs
+        L = cholesky(stabilize(K, jitter))
+        return mus, L
+
+    def conditional(self, name, Xnew, jitter=1e-6, **kwargs):
+        X, f = self.X, self.f
+        U, s = self.U, self.s
+        mu, L = self._build_conditional(Xnew, X, f, U, s, jitter)
+        return pm.MvNormal(name, mu=mu, chol=L, **kwargs)
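`KarhunenLoeveExpansion` truncates the eigendecomposition of the training covariance, keeping either a fixed `n_eigs` or roughly enough leading eigenvectors to explain `variance_limit` of the total variance. A hedged usage sketch (inputs and kernel are arbitrary, not from the released file):

    import numpy as np
    import pymc as pm
    from pymc_extras.gp import KarhunenLoeveExpansion

    X = np.linspace(0, 5, 100)[:, None]

    with pm.Model():
        gp = KarhunenLoeveExpansion(variance_limit=0.99, cov_func=pm.gp.cov.ExpQuad(1, ls=1.0))
        f = gp.prior("f", X=X)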
pymc_extras/inference/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2022 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pymc_extras.inference.fit import fit
+
+__all__ = ["fit"]