pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,15 @@
1
+ from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace
2
+ from pymc_extras.statespace.filters.kalman_filter import (
3
+ SquareRootFilter,
4
+ StandardFilter,
5
+ UnivariateFilter,
6
+ )
7
+ from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
8
+
9
+ __all__ = [
10
+ "StandardFilter",
11
+ "UnivariateFilter",
12
+ "KalmanSmoother",
13
+ "SquareRootFilter",
14
+ "LinearGaussianStateSpace",
15
+ ]
@@ -0,0 +1,453 @@
1
+ import numpy as np
2
+ import pymc as pm
3
+ import pytensor
4
+ import pytensor.tensor as pt
5
+
6
+ from pymc import intX
7
+ from pymc.distributions.dist_math import check_parameters
8
+ from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
9
+ from pymc.distributions.multivariate import MvNormal
10
+ from pymc.distributions.shape_utils import get_support_shape_1d
11
+ from pymc.logprob.abstract import _logprob
12
+ from pytensor.graph.basic import Node
13
+ from pytensor.tensor.random.basic import MvNormalRV
14
+
15
+ floatX = pytensor.config.floatX
16
+ COV_ZERO_TOL = 0
17
+
18
+ lgss_shape_message = (
19
+ "The LinearGaussianStateSpace distribution needs shape information to be constructed. "
20
+ "Ensure that all input matrices have shape information specified."
21
+ )
22
+
23
+
24
+ def make_signature(sequence_names):
25
+ states = "s"
26
+ obs = "p"
27
+ exog = "r"
28
+ time = "t"
29
+ state_and_obs = "n"
30
+
31
+ matrix_to_shape = {
32
+ "x0": (states,),
33
+ "P0": (states, states),
34
+ "c": (states,),
35
+ "d": (obs,),
36
+ "T": (states, states),
37
+ "Z": (obs, states),
38
+ "R": (states, exog),
39
+ "H": (obs, obs),
40
+ "Q": (exog, exog),
41
+ }
42
+
43
+ for matrix in sequence_names:
44
+ base_shape = matrix_to_shape[matrix]
45
+ matrix_to_shape[matrix] = (time, *base_shape)
46
+
47
+ signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in matrix_to_shape.values()])
48
+
49
+ return f"{signature},[rng]->[rng],({time},{state_and_obs})"
50
+
51
+
52
class MvNormalSVDRV(MvNormalRV):
    """Multivariate normal RV op distinguished only by its type.

    Identical in interface to the base ``MvNormalRV``; it exists so that the
    JAX dispatcher below can register a sampler that uses ``method="svd"``,
    which tolerates low-rank (singular) covariance matrices.
    """

    name = "multivariate_normal"
    # Core signature: mean (n), covariance (n, n) -> draw (n).
    signature = "(n),(n,n)->(n)"
    dtype = "floatX"
    _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
57
+
58
+
59
class MvNormalSVD(MvNormal):
    """Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".

    A JAX MvNormal robust to low-rank covariance matrices.

    Behaves exactly like ``pm.MvNormal`` except that its ``rv_op`` is a
    ``MvNormalSVDRV``, which the JAX backend dispatches to an SVD-based sampler.
    """

    rv_op = MvNormalSVDRV()
66
+
67
+
68
# Register a JAX sampler for MvNormalSVDRV only when JAX is installed; the
# module must still import cleanly in a JAX-free environment.
try:
    import jax.random

    from pytensor.link.jax.dispatch.random import jax_sample_fn

    @jax_sample_fn.register(MvNormalSVDRV)
    def jax_sample_fn_mvnormal_svd(op, node):
        """Return a JAX sampling function that draws with ``method="svd"``.

        The SVD decomposition of the covariance is slower than the default
        Cholesky but works for positive *semi*-definite (low-rank) matrices.
        """

        def sample_fn(rng, size, dtype, *parameters):
            # Split the carried key so the returned state differs from the
            # key actually consumed by this draw.
            rng_key = rng["jax_state"]
            rng_key, sampling_key = jax.random.split(rng_key, 2)
            sample = jax.random.multivariate_normal(
                sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
            )
            rng["jax_state"] = rng_key
            return (rng, sample)

        return sample_fn

except ImportError:
    # JAX not available; sampling falls back to the default backends.
    pass
88
+
89
+
90
class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
    """Symbolic RV wrapping the scan that simulates a linear Gaussian state space.

    Outputs are ``[rng, statespace_draw]``; ``default_output = 1`` makes the
    draw (not the RNG) what callers receive.
    """

    default_output = 1
    # Bug fix: the print name was misspelled "LinearGuassianStateSpace".
    _print_name = ("LinearGaussianStateSpace", "\\operatorname{LinearGaussianStateSpace}")

    def update(self, node: Node):
        """Map the input RNG (last input) to the updated RNG (first output)."""
        return {node.inputs[-1]: node.outputs[0]}
96
+
97
+
98
class _LinearGaussianStateSpace(Continuous):
    """Internal "combined" distribution over stacked latent and observed states.

    Each draw at time t is ``concatenate([latent_states, observed_states])``.
    This distribution is an implementation detail of ``LinearGaussianStateSpace``,
    which slices the combined draw back apart; it should not be used directly.
    """

    def __new__(
        cls,
        name,
        a0,
        P0,
        c,
        d,
        T,
        Z,
        R,
        H,
        Q,
        steps=None,
        mode=None,
        sequence_names=None,
        append_x0=True,
        **kwargs,
    ):
        # Ignore dims in support shape because they are just passed along to the "observed" and "latent" distributions
        # created by LinearGaussianStateSpace. This "combined" distribution shouldn't ever be directly used.
        steps = get_support_shape_1d(
            support_shape=steps,
            shape=None,
            dims=None,
            observed=kwargs.get("observed", None),
            support_shape_offset=0,
        )

        return super().__new__(
            cls,
            name,
            a0,
            P0,
            c,
            d,
            T,
            Z,
            R,
            H,
            Q,
            steps=steps,
            mode=mode,
            sequence_names=sequence_names,
            append_x0=append_x0,
            **kwargs,
        )

    @classmethod
    def dist(
        cls,
        a0,
        P0,
        c,
        d,
        T,
        Z,
        R,
        H,
        Q,
        steps=None,
        mode=None,
        sequence_names=None,
        append_x0=True,
        **kwargs,
    ):
        # Resolve the number of time steps from either `steps` or `shape`;
        # one of the two must be provided.
        steps = get_support_shape_1d(
            support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=0
        )

        if steps is None:
            raise ValueError("Must specify steps or shape parameter")

        steps = pt.as_tensor_variable(intX(steps), ndim=0)

        return super().dist(
            [a0, P0, c, d, T, Z, R, H, Q, steps],
            mode=mode,
            sequence_names=sequence_names,
            append_x0=append_x0,
            **kwargs,
        )

    @classmethod
    def rv_op(
        cls,
        a0,
        P0,
        c,
        d,
        T,
        Z,
        R,
        H,
        Q,
        steps,
        size=None,
        mode=None,
        sequence_names=None,
        append_x0=True,
    ):
        """Build the scan graph that simulates the state space forward in time.

        Matrices named in ``sequence_names`` are treated as time-varying scan
        sequences (leading time dimension); all others are non-sequences.
        """
        if sequence_names is None:
            sequence_names = []

        # Fresh type-only clones of the inputs; the scan is built on these and
        # the real inputs are substituted when the op is applied at the end.
        a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_ = map(
            lambda x: x.type(), [a0, P0, c, d, T, Z, R, H, Q]
        )

        # Names are load-bearing: sort_args below recovers the canonical
        # ordering inside the scan by variable name.
        c_.name = "c"
        d_.name = "d"
        T_.name = "T"
        Z_.name = "Z"
        R_.name = "R"
        H_.name = "H"
        Q_.name = "Q"

        sequences = [
            x
            for x, name in zip([c_, d_, T_, Z_, R_, H_, Q_], ["c", "d", "T", "Z", "R", "H", "Q"])
            if name in sequence_names
        ]
        non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences]

        rng = pytensor.shared(np.random.default_rng())

        def sort_args(args):
            # Scan passes sequences and non-sequences in an order that depends
            # on which matrices are time-varying; restore canonical order.
            sorted_args = []

            # Inside the scan, outputs_info variables get a time step appended to their name
            # e.g. x -> x[t]. Remove this so we can identify variables by name.
            arg_names = [x.name.replace("[t]", "") for x in args]

            # c, d, T, Z, R, H, Q is the "canonical" ordering
            for name in ["c", "d", "T", "Z", "R", "H", "Q"]:
                idx = arg_names.index(name)
                sorted_args.append(args[idx])

            return sorted_args

        n_seq = len(sequence_names)

        def step_fn(*args):
            # Scan argument layout: [sequences..., recurrent state, non_sequences..., rng]
            seqs, state, non_seqs = args[:n_seq], args[n_seq], args[n_seq + 1 :]
            non_seqs, rng = non_seqs[:-1], non_seqs[-1]

            c, d, T, Z, R, H, Q = sort_args(seqs + non_seqs)
            k = T.shape[0]
            # The carried state is [latent; observed]; only the latent part
            # feeds the transition equation.
            a = state[:k]

            # Chain the RNG through both draws so they are independent.
            middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs
            next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs

            # State transition: a[t+1] = c + T a[t] + R eps,  eps ~ N(0, Q)
            a_mu = c + T @ a
            a_next = a_mu + R @ a_innovation

            # Observation: y[t+1] = d + Z a[t+1] + eta,  eta ~ N(0, H)
            y_mu = d + Z @ a_next
            y_next = y_mu + y_innovation

            next_state = pt.concatenate([a_next, y_next], axis=0)

            return next_state, {rng: next_rng}

        # For time-varying Z/H, the initial observation uses the t=0 slice.
        Z_init = Z_ if Z_ in non_sequences else Z_[0]
        H_init = H_ if H_ in non_sequences else H_[0]

        init_x_ = MvNormalSVD.dist(a0_, P0_, rng=rng)
        init_y_ = MvNormalSVD.dist(Z_init @ init_x_, H_init, rng=rng)

        init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)

        statespace, updates = pytensor.scan(
            step_fn,
            outputs_info=[init_dist_],
            sequences=None if len(sequences) == 0 else sequences,
            non_sequences=[*non_sequences, rng],
            n_steps=steps,
            mode=mode,
            strict=True,
        )

        # Optionally prepend the initial draw, giving steps + 1 time points.
        if append_x0:
            statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0)
            statespace_ = pt.specify_shape(statespace_, (steps + 1, None))
        else:
            statespace_ = statespace
            statespace_ = pt.specify_shape(statespace_, (steps, None))

        (ss_rng,) = tuple(updates.values())
        linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
            inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
            outputs=[ss_rng, statespace_],
            extended_signature=make_signature(sequence_names),
        )

        # Apply the op to the actual inputs (replacing the type-only clones).
        linear_gaussian_ss = linear_gaussian_ss_op(a0, P0, c, d, T, Z, R, H, Q, steps, rng)
        return linear_gaussian_ss
294
+
295
+
296
class LinearGaussianStateSpace(Continuous):
    """
    Linear Gaussian Statespace distribution

    User-facing wrapper around ``_LinearGaussianStateSpace``. It draws the
    combined (latent + observed) trajectory, slices it into its two parts,
    registers them as ``{name}_latent`` and ``{name}_observed`` Deterministics,
    and returns ``(latent_states, obs_states)``.
    """

    def __new__(
        cls,
        name,
        a0,
        P0,
        c,
        d,
        T,
        Z,
        R,
        H,
        Q,
        *,
        steps,
        k_endog=None,
        sequence_names=None,
        mode=None,
        append_x0=True,
        **kwargs,
    ):
        dims = kwargs.pop("dims", None)
        latent_dims = None
        obs_dims = None
        if dims is not None:
            if len(dims) != 3:
                # Bug fix: the ValueError was previously constructed but never
                # raised, so invalid dims were silently accepted.
                raise ValueError(
                    "LinearGaussianStateSpace expects 3 dims: time, all_states, and observed_states"
                )
            time_dim, state_dim, obs_dim = dims
            latent_dims = [time_dim, state_dim]
            obs_dims = [time_dim, obs_dim]

        latent_obs_combined = _LinearGaussianStateSpace(
            f"{name}_combined",
            a0,
            P0,
            c,
            d,
            T,
            Z,
            R,
            H,
            Q,
            steps=steps,
            mode=mode,
            sequence_names=sequence_names,
            append_x0=append_x0,
            **kwargs,
        )
        # With append_x0 the initial draw is prepended, so the time dimension
        # is steps + 1 rather than steps.
        latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None))
        if k_endog is None:
            k_endog = cls._get_k_endog(H)
        # The last k_endog entries of the combined state vector are observed.
        latent_slice = slice(None, -k_endog)
        obs_slice = slice(-k_endog, None)

        latent_states = latent_obs_combined[..., latent_slice]
        obs_states = latent_obs_combined[..., obs_slice]

        latent_states = pm.Deterministic(f"{name}_latent", latent_states, dims=latent_dims)
        obs_states = pm.Deterministic(f"{name}_observed", obs_states, dims=obs_dims)

        return latent_states, obs_states

    @classmethod
    def dist(cls, a0, P0, c, d, T, Z, R, H, Q, *, steps=None, **kwargs):
        """Return unnamed (latent, observed) draws without registering Deterministics."""
        latent_obs_combined = _LinearGaussianStateSpace.dist(
            a0, P0, c, d, T, Z, R, H, Q, steps=steps, **kwargs
        )
        # Here the split uses the number of latent states (from T) rather than
        # the number of observed states (from H) as in __new__.
        k_states = T.type.shape[0]

        latent_states = latent_obs_combined[..., :k_states]
        obs_states = latent_obs_combined[..., k_states:]

        return latent_states, obs_states

    @classmethod
    def _get_k_states(cls, T):
        """Number of latent states, read from T's static shape; raises if unknown."""
        k_states = T.type.shape[0]
        if k_states is None:
            raise ValueError(lgss_shape_message)
        return k_states

    @classmethod
    def _get_k_endog(cls, H):
        """Number of observed states, read from H's static shape; raises if unknown."""
        k_endog = H.type.shape[0]
        if k_endog is None:
            raise ValueError(lgss_shape_message)

        return k_endog
391
+
392
+
393
class KalmanFilterRV(SymbolicRandomVariable):
    """Symbolic RV for a sequence of MvNormal draws from Kalman filter outputs.

    Inputs are (means, covariances, precomputed logp, rng); outputs are
    ``[rng, draws]`` with ``default_output = 1`` selecting the draws.
    """

    default_output = 1
    _print_name = ("KalmanFilter", "\\operatorname{KalmanFilter}")
    # t = time, s = state dimension; the rng is threaded through explicitly.
    extended_signature = "(t,s),(t,s,s),(t),[rng]->[rng],(t,s)"

    def update(self, node: Node):
        """Map the input RNG (last input) to the updated RNG (first output)."""
        return {node.inputs[-1]: node.outputs[0]}
400
+
401
+
402
class SequenceMvNormal(Continuous):
    """Sequence of independent multivariate normals, one per time step.

    Used to turn Kalman filter/smoother outputs (a mean and covariance per
    timestep, plus an already-computed log-probability) into a single
    sampleable random variable.
    """

    def __new__(cls, *args, **kwargs):
        return super().__new__(cls, *args, **kwargs)

    @classmethod
    def dist(cls, mus, covs, logp, **kwargs):
        # logp is carried as a parameter; the registered _logprob simply
        # returns it (with shape checks) instead of recomputing it.
        return super().dist([mus, covs, logp], **kwargs)

    @classmethod
    def rv_op(cls, mus, covs, logp, size=None):
        # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
        if mus.ndim > 2:
            mus = pt.moveaxis(mus, -2, 0)
        if covs.ndim > 3:
            covs = pt.moveaxis(covs, -3, 0)

        # Type-only clones: the scan graph is built on these and the real
        # inputs are substituted when the op is applied below.
        mus_, covs_ = mus.type(), covs.type()

        logp_ = logp.type()
        rng = pytensor.shared(np.random.default_rng())

        def step(mu, cov, rng):
            # One SVD-robust MvNormal draw per timestep, threading the rng.
            new_rng, mvn = MvNormalSVD.dist(mu=mu, cov=cov, rng=rng).owner.outputs
            return mvn, {rng: new_rng}

        mvn_seq, updates = pytensor.scan(
            step, sequences=[mus_, covs_], non_sequences=[rng], strict=True, n_steps=mus_.shape[0]
        )
        mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)

        # Move time axis back to position -2 so batches are on the left
        if mvn_seq.ndim > 2:
            mvn_seq = pt.moveaxis(mvn_seq, 0, -2)

        (seq_mvn_rng,) = tuple(updates.values())

        mvn_seq_op = KalmanFilterRV(
            inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
        )

        mvn_seq = mvn_seq_op(mus, covs, logp, rng)
        return mvn_seq
444
+
445
+
446
@_logprob.register(KalmanFilterRV)
def sequence_mvnormal_logp(op, values, mus, covs, logp, rng, **kwargs):
    """Return the precomputed logp carried by the RV, guarded by shape checks.

    The Kalman filter already computed the log-likelihood, so this just
    validates that the observed values, means, and covariances agree on the
    number of timesteps before passing `logp` through.
    """
    return check_parameters(
        logp,
        pt.eq(values[0].shape[0], mus.shape[0]),
        pt.eq(covs.shape[0], mus.shape[0]),
        msg="Observed data and parameters must have the same number of timesteps (dimension 0)",
    )