pymc_extras-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,434 @@
+ import logging
+
+ from collections.abc import Sequence
+ from dataclasses import dataclass
+ from functools import singledispatch
+
+ import numpy as np
+ import pymc as pm
+ import pytensor
+ import pytensor.tensor as pt
+ import scipy.special
+
+ from pymc.distributions import SymbolicRandomVariable
+ from pymc.logprob.transforms import Transform
+ from pymc.model.fgraph import (
+     ModelDeterministic,
+     ModelNamed,
+     fgraph_from_model,
+     model_deterministic,
+     model_free_rv,
+     model_from_fgraph,
+     model_named,
+ )
+ from pymc.pytensorf import toposort_replace
+ from pytensor.graph.basic import Apply, Variable
+ from pytensor.tensor.random.op import RandomVariable
+
+ _log = logging.getLogger("pmx")
+
+
+ @dataclass
+ class VIP:
+     r"""Helper to reparametrize a model with VIP.
+
+     Manipulation of :math:`\lambda` in the equation below is done through this helper class.
+
+     .. math::
+
+         \begin{align*}
+         \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\
+         \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu)
+             \sim \text{normal}(\mu, \sigma).
+         \end{align*}
+     """
+
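+     # One shared logit(lambda) tensor per reparametrized variable, keyed by
+     # variable name; lambda = sigmoid(logit_lambda) lies in (0, 1).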
+     _logit_lambda: dict[str, pytensor.tensor.sharedvar.TensorSharedVariable]
+
+     @property
+     def variational_parameters(self) -> list[pytensor.tensor.sharedvar.TensorSharedVariable]:
+         r"""Return the raw :math:`\operatorname{logit}(\lambda_k)` for custom optimization.
+
+         Examples
+         --------
+
+         .. code-block:: python
+
+             with model:
+                 # set all parameterizations to a mix of centered and non-centered
+                 vip.set_all_lambda(0.5)
+
+                 pm.fit(more_obj_params=vip.variational_parameters, method="fullrank_advi")
+         """
+         return list(self._logit_lambda.values())
+
+     def truncate_lambda(self, **kwargs: float):
+         r"""Truncate :math:`\lambda_k` with :math:`\varepsilon`.
+
+         .. math::
+
+             \hat \lambda_k = \begin{cases}
+             0, \quad &\lambda_k \le \varepsilon\\
+             \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
+             1, \quad &\lambda_k \ge 1-\varepsilon\\
+             \end{cases}
+
+         Parameters
+         ----------
+         kwargs : Dict[str, float]
+             Variable to :math:`\varepsilon` mapping.
+             If :math:`\lambda_k \le \varepsilon` it is clipped to :math:`0`,
+             and if :math:`\lambda_k \ge 1 - \varepsilon` it is clipped to :math:`1`.
+         """
+         lambdas = self.get_lambda()
+         update = dict()
+         for var, eps in kwargs.items():
+             lam = lambdas[var]
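+             # np.piecewise applies the last function where no condition holds,
+             # so values strictly between eps and 1 - eps pass through unchanged.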
+             update[var] = np.piecewise(
+                 lam,
+                 [lam < eps, lam > (1 - eps)],
+                 [0, 1, lambda x: x],
+             )
+         self.set_lambda(**update)
+
+     def truncate_all_lambda(self, value: float):
+         r"""Truncate all :math:`\lambda_k` with :math:`\varepsilon`.
+
+         .. math::
+
+             \hat \lambda_k = \begin{cases}
+             0, \quad &\lambda_k \le \varepsilon\\
+             \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
+             1, \quad &\lambda_k \ge 1-\varepsilon\\
+             \end{cases}
+
+         Parameters
+         ----------
+         value : float
+             :math:`\varepsilon`
+         """
+         truncate = dict.fromkeys(
+             self._logit_lambda.keys(),
+             value,
+         )
+         self.truncate_lambda(**truncate)
+
+     def get_lambda(self) -> dict[str, np.ndarray]:
+         r"""Get the :math:`\lambda_k` currently used by the model.
+
+         Returns
+         -------
+         Dict[str, np.ndarray]
+             Mapping from variable name to :math:`\lambda_k`.
+         """
+         return {
+             name: scipy.special.expit(shared.get_value())
+             for name, shared in self._logit_lambda.items()
+         }
+
+     def set_lambda(self, **kwargs: np.ndarray | float):
+         r"""Set :math:`\lambda_k` per variable."""
+         for key, value in kwargs.items():
+             logit_lam = scipy.special.logit(value)
+             shared = self._logit_lambda[key]
+             fill = np.broadcast_to(
+                 logit_lam,
+                 shared.type.shape,
+             )
+             shared.set_value(fill)
+
+     def set_all_lambda(self, value: np.ndarray | float):
+         r"""Set :math:`\lambda_k` globally."""
+         config = dict.fromkeys(
+             self._logit_lambda.keys(),
+             value,
+         )
+         self.set_lambda(**config)
+
+     def fit(self, *args, **kwargs) -> pm.Approximation:
+         r"""Set :math:`\lambda_k` using variational inference.
+
+         Examples
+         --------
+
+         .. code-block:: python
+
+             with model:
+                 # set all parameterizations to a mix of centered and non-centered
+                 vip.set_all_lambda(0.5)
+
+                 # fit using ADVI
+                 mf = vip.fit(random_seed=42)
+         """
+         kwargs.setdefault("obj_optimizer", pm.adagrad_window(learning_rate=0.1))
+         kwargs.setdefault("method", "advi")
+         return pm.fit(
+             *args,
+             more_obj_params=self.variational_parameters,
+             **kwargs,
+         )
+
+
+ def vip_reparam_node(
+     op: RandomVariable,
+     node: Apply,
+     name: str,
+     dims: list[Variable],
+     transform: Transform | None,
+ ) -> tuple[ModelDeterministic, ModelNamed]:
+     if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
+         raise TypeError("Op should be a RandomVariable or SymbolicRandomVariable")
+     # FIXME: This is wrong when size is None
+     _, size, *_ = node.inputs
+     eval_size = size.eval(mode="FAST_COMPILE")
+     if eval_size is not None:
+         rv_shape = tuple(eval_size)
+     else:
+         rv_shape = ()
+     lam_name = f"{name}::lam_logit__"
+     _log.debug(f"Creating {lam_name} with shape {rv_shape}")
+     logit_lam_ = pytensor.shared(
+         np.zeros(rv_shape),
+         shape=rv_shape,
+         name=lam_name,
+     )
+     logit_lam = model_named(logit_lam_, *dims)
+     lam = pt.sigmoid(logit_lam)
+     return (
+         _vip_reparam_node(
+             op,
+             node=node,
+             name=name,
+             dims=dims,
+             transform=transform,
+             lam=lam,
+         ),
+         logit_lam,
+     )
+
+
+ @singledispatch
+ def _vip_reparam_node(
+     op: RandomVariable,
+     node: Apply,
+     name: str,
+     dims: list[Variable],
+     transform: Transform | None,
+     lam: pt.TensorVariable,
+ ) -> ModelDeterministic:
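+     # singledispatch fallback: only ops registered below (Normal, Exponential)
+     # currently support VIP reparametrization.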
+     raise NotImplementedError
+
+
+ @_vip_reparam_node.register
+ def _(
+     op: pm.Normal,
+     node: Apply,
+     name: str,
+     dims: list[Variable],
+     transform: Transform | None,
+     lam: pt.TensorVariable,
+ ) -> ModelDeterministic:
+     rng, size, loc, scale = node.inputs
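+     # eta ~ Normal(lam * mu, sigma**lam) is sampled here; the deterministic
+     # theta = mu + sigma**(1 - lam) * (eta - lam * mu) ~ Normal(mu, sigma)
+     # is rebuilt below.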
+     if transform is not None:
+         raise NotImplementedError("Reparametrization of Normal with Transform is not implemented")
+     vip_rv_ = pm.Normal.dist(
+         lam * loc,
+         scale**lam,
+         size=size,
+         rng=rng,
+     )
+     vip_rv_.name = f"{name}::tau_"
+
+     vip_rv = model_free_rv(
+         vip_rv_,
+         vip_rv_.clone(),
+         None,
+         *dims,
+     )
+
+     vip_rep_ = loc + scale ** (1 - lam) * (vip_rv - lam * loc)
+
+     vip_rep_.name = name
+
+     vip_rep = model_deterministic(vip_rep_, *dims)
+     return vip_rep
+
+
+ @_vip_reparam_node.register
+ def _(
+     op: pm.Exponential,
+     node: Apply,
+     name: str,
+     dims: list[Variable],
+     transform: Transform | None,
+     lam: pt.TensorVariable,
+ ) -> ModelDeterministic:
+     rng, size, scale = node.inputs
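+     # Split the scale between the sampled variable (scale**lam) and the
+     # deterministic (scale**(1 - lam)), interpolating between non-centered
+     # (lam = 0) and centered (lam = 1) parametrizations.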
+     scale_centered = scale**lam
+     scale_noncentered = scale ** (1 - lam)
+     vip_rv_ = pm.Exponential.dist(
+         scale=scale_centered,
+         size=size,
+         rng=rng,
+     )
+     vip_rv_value_ = vip_rv_.clone()
+     vip_rv_.name = f"{name}::tau_"
+     if transform is not None:
+         vip_rv_value_.name = f"{vip_rv_.name}_{transform.name}__"
+     else:
+         vip_rv_value_.name = vip_rv_.name
+     vip_rv = model_free_rv(
+         vip_rv_,
+         vip_rv_value_,
+         transform,
+         *dims,
+     )
+
+     vip_rep_ = scale_noncentered * vip_rv
+
+     vip_rep_.name = name
+
+     vip_rep = model_deterministic(vip_rep_, *dims)
+     return vip_rep
+
+
+ def vip_reparametrize(
+     model: pm.Model,
+     var_names: Sequence[str],
+ ) -> tuple[pm.Model, VIP]:
+     r"""Reparametrize a model using the Variationally Informed Parametrization (VIP).
+
+     .. math::
+
+         \begin{align*}
+         \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\
+         \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu)
+             \sim \text{normal}(\mu, \sigma).
+         \end{align*}
+
+     Parameters
+     ----------
+     model : Model
+         Model with centered parameterizations for variables.
+     var_names : Sequence[str]
+         Target variables to reparametrize.
+
+     Returns
+     -------
+     Tuple[Model, VIP]
+         Updated model and a VIP helper to reparametrize the model or infer its parametrization.
+
+     Examples
+     --------
+     The traditional eight schools.
+
+     .. code-block:: python
+
+         import pymc as pm
+         import numpy as np
+
+         J = 8
+         y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
+         sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
+
+         with pm.Model() as Centered_eight:
+             mu = pm.Normal("mu", mu=0, sigma=5)
+             tau = pm.HalfCauchy("tau", beta=5)
+             theta = pm.Normal("theta", mu=mu, sigma=tau, shape=J)
+             obs = pm.Normal("obs", mu=theta, sigma=sigma, observed=y)
+
+     The regular model definition with centered parametrization is sufficient to use VIP.
+     To change the model parametrization, use the following function.
+
+     .. code-block:: python
+
+         from pymc_extras.model.transforms.autoreparam import vip_reparametrize
+         Reparam_eight, vip = vip_reparametrize(Centered_eight, ["theta"])
+
+         with Reparam_eight:
+             # set all parameterizations to centered (not needed)
+             vip.set_all_lambda(1)
+
+             # set all parameterizations to non-centered (desired)
+             vip.set_all_lambda(0)
+
+             # or per variable
+             vip.set_lambda(theta=0)
+
+             # sample with the non-centered parameterization
+             trace = pm.sample()
+
+     However, setting :math:`\lambda` manually is not always a great experience; it can be learned instead.
+
+     .. code-block:: python
+
+         with Reparam_eight:
+             # set all parameterizations to a mix of centered and non-centered
+             vip.set_all_lambda(0.5)
+
+             # fit using ADVI
+             mf = vip.fit(random_seed=42)
+
+             # display lambdas
+             print(vip.get_lambda())
+
+         # {'theta': array([0.01473405, 0.02221006, 0.03656685, 0.03798879, 0.04876761,
+         #        0.0300203 , 0.02733082, 0.01817754])}
+
+     Now you can use sampling again:
+
+     .. code-block:: python
+
+         with Reparam_eight:
+             trace = pm.sample()
+
+     Sometimes it makes sense to enable clipping (off by default).
+     The idea is to round :math:`\lambda_k` to the closest extremum (:math:`0` or :math:`1`)
+     when it is within :math:`\varepsilon` of it:
+
+     .. math::
+
+         \hat \lambda_k = \begin{cases}
+         0, \quad &\lambda_k \le \varepsilon\\
+         \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
+         1, \quad &\lambda_k \ge 1-\varepsilon
+         \end{cases}
+
+     .. code-block:: python
+
+         vip.truncate_all_lambda(0.1)
+
+     Sampling then has to be performed again:
+
+     .. code-block:: python
+
+         with Reparam_eight:
+             trace = pm.sample()
+
+     References
+     ----------
+     - Automatic Reparameterisation of Probabilistic Programs,
+       Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
+     """
+     fmodel, memo = fgraph_from_model(model)
+     lambda_names = []
+     replacements = []
+     for name in var_names:
+         old = memo[model.named_vars[name]]
+         rv, _, *dims = old.owner.inputs
+         new, lam = vip_reparam_node(
+             rv.owner.op,
+             rv.owner,
+             name=rv.name,
+             dims=dims,
+             transform=old.owner.op.transform,
+         )
+         replacements.append((old, new))
+         lambda_names.append(lam.name)
+     toposort_replace(fmodel, replacements, reverse=True)
+     reparam_model = model_from_fgraph(fmodel)
+     model_lambdas = {
+         var_name: reparam_model[lambda_name]
+         for lambda_name, var_name in zip(lambda_names, var_names)
+     }
+     vip = VIP(model_lambdas)
+     return reparam_model, vip
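
For quick reference, a minimal end-to-end sketch assembled from the docstring examples above (it assumes the eight-schools Centered_eight model from the vip_reparametrize docstring has already been defined):

    import pymc as pm

    from pymc_extras.model.transforms.autoreparam import vip_reparametrize

    Reparam_eight, vip = vip_reparametrize(Centered_eight, ["theta"])

    with Reparam_eight:
        vip.set_all_lambda(0.5)       # start from a mix of centered and non-centered
        mf = vip.fit(random_seed=42)  # learn per-coordinate lambdas with ADVI
        vip.truncate_all_lambda(0.1)  # optionally snap near-extreme lambdas to 0 or 1
        trace = pm.sample()           # sample under the learned parametrization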