pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,126 @@
1
+ import pytensor
2
+ import pytensor.tensor as pt
3
+
4
+ from pytensor.compile import get_mode
5
+ from pytensor.tensor.nlinalg import matrix_dot
6
+
7
+ from pymc_extras.statespace.filters.utilities import (
8
+ quad_form_sym,
9
+ split_vars_into_seq_and_nonseq,
10
+ stabilize,
11
+ )
12
+ from pymc_extras.statespace.utils.constants import JITTER_DEFAULT
13
+
14
+
15
class KalmanSmoother:
    """
    Kalman Smoother

    Builds a symbolic PyTensor graph for the backward (Rauch-Tung-Striebel) smoothing pass over
    the output of a Kalman filter. The recursion runs backwards in time inside ``pytensor.scan``.
    """

    def __init__(self, mode: str | None = None):
        # Compilation mode passed (via ``get_mode``) to the smoother's ``pytensor.scan``.
        self.mode = mode
        # Diagonal jitter added by ``stabilize`` to predicted and smoothed covariances.
        self.cov_jitter = JITTER_DEFAULT
        # Which of "T", "R", "Q" are scan sequences (time-varying) vs. non-sequences (static).
        # Filled in by ``build_graph`` and read by ``unpack_args``.
        self.seq_names = []
        self.non_seq_names = []

    def unpack_args(self, args):
        """
        Reorder the positional arguments received by the inner scan function.

        The order of inputs to the inner scan function is not known in advance, since some, all,
        or none of the input matrices can be time varying. Scan feeds the inner function its
        sequences, then its outputs_info, then its non-sequences. This function works out which
        matrices are where, and returns them in the standardized order expected by
        ``smoother_step``.

        The standard order is: a, P, a_smooth, P_smooth, T, R, Q
        """
        args = list(args)
        n_seq = len(self.seq_names)

        # If there are no sequence parameters (all params are static),
        # no changes are needed, params will be in order.
        if n_seq == 0:
            return args

        # The first two args are always a and P (the filtered state/covariance sequences).
        a = args.pop(0)
        P = args.pop(0)

        # There are always two outputs_info wedged between the seqs and non_seqs
        seqs, (a_smooth, P_smooth), non_seqs = (
            args[:n_seq],
            args[n_seq : n_seq + 2],
            args[n_seq + 2 :],
        )

        return_ordered = []
        for name in ["T", "R", "Q"]:
            if name in self.seq_names:
                return_ordered.append(seqs[self.seq_names.index(name)])
            else:
                return_ordered.append(non_seqs[self.non_seq_names.index(name)])

        T, R, Q = return_ordered

        return a, P, a_smooth, P_smooth, T, R, Q

    def build_graph(
        self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
    ):
        """
        Build the symbolic smoothed states and covariances.

        Parameters
        ----------
        T, R, Q : TensorVariable
            System matrices used in the one-step prediction ``a_hat = T a`` and
            ``P_hat = T P T^T + R Q R^T``. Each may be static or time-varying (one extra
            dimension, iterated over by scan); see ``split_vars_into_seq_and_nonseq``.
        filtered_states : TensorVariable
            Filtered state means, shape ``(n_timesteps, k)``.
        filtered_covariances : TensorVariable
            Filtered state covariances, shape ``(n_timesteps, k, k)``.
        mode : str, optional
            PyTensor compilation mode for the scan. If None, the mode given to the constructor
            is kept rather than overwritten.
        cov_jitter : float
            Diagonal jitter added to covariance matrices for numerical stability.

        Returns
        -------
        smoothed_states, smoothed_covariances : TensorVariable
            Smoothed means ``(n_timesteps, k)`` and covariances ``(n_timesteps, k, k)`` in
            forward time order. The final time step equals the final filtered estimate.
        """
        # Only override the constructor's mode when one is explicitly supplied.
        if mode is not None:
            self.mode = mode
        self.cov_jitter = cov_jitter

        _, k = filtered_states.type.shape

        # The backward recursion is initialized at the last filtered estimate.
        a_last = pt.specify_shape(filtered_states[-1], (k,))
        P_last = pt.specify_shape(filtered_covariances[-1], (k, k))

        sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
            [T, R, Q], ["T", "R", "Q"]
        )

        # Recorded so that unpack_args can reorder the inner-function arguments.
        self.seq_names = seq_names
        self.non_seq_names = non_seq_names

        smoother_result, _ = pytensor.scan(
            self.smoother_step,
            sequences=[filtered_states[:-1], filtered_covariances[:-1], *sequences],
            outputs_info=[a_last, P_last],
            non_sequences=non_sequences,
            go_backwards=True,
            name="kalman_smoother",
            mode=get_mode(self.mode),
        )

        # Scan emits results in processing (reversed) order; flip them back and append the
        # initial (last-time-step) estimate.
        smoothed_states, smoothed_covariances = smoother_result
        smoothed_states = pt.concatenate(
            [smoothed_states[::-1], pt.expand_dims(a_last, axis=(0,))], axis=0
        )
        smoothed_covariances = pt.concatenate(
            [smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0
        )

        smoothed_states.name = "smoothed_states"
        smoothed_covariances.name = "smoothed_covariances"

        return smoothed_states, smoothed_covariances

    def smoother_step(self, *args):
        """
        One backward step of the smoother (inner function of the scan).

        Given the filtered estimate ``(a, P)`` at time t and the smoothed estimate
        ``(a_smooth, P_smooth)`` at time t + 1, return the smoothed estimate at time t.
        """
        a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args)
        a_hat, P_hat = self.predict(a, P, T, R, Q)

        # Smoother gain, computed as (pinv(P_hat) T P)^T; this equals P T^T pinv(P_hat) when P is
        # symmetric (P_hat is symmetrized in predict).
        # Use pinv, otherwise P_hat is singular when there is missing data
        smoother_gain = matrix_dot(pt.linalg.pinv(P_hat), T, P).T
        a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)

        P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat)
        # Apply the configured jitter exactly once, so ``cov_jitter`` is honored.
        P_smooth_next = stabilize(P_smooth_next, self.cov_jitter)
        P_smooth_next = pt.specify_shape(P_smooth_next, P_smooth.type.shape)

        return a_smooth_next, P_smooth_next

    def predict(self, a, P, T, R, Q):
        """
        One-step-ahead prediction of the state mean and covariance.

        Returns ``a_hat = T a`` and ``P_hat = T P T^T + R Q R^T`` (each quadratic form
        symmetrized, with ``self.cov_jitter`` added to the diagonal).
        """
        a_hat = T.dot(a)
        P_hat = quad_form_sym(T, P) + quad_form_sym(R, Q)
        P_hat = stabilize(P_hat, self.cov_jitter)

        return a_hat, P_hat
@@ -0,0 +1,59 @@
1
+ import pytensor.tensor as pt
2
+
3
+ from pytensor.tensor.nlinalg import matrix_dot
4
+
5
+ from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, NEVER_TIME_VARYING, VECTOR_VALUED
6
+
7
+
8
def decide_if_x_time_varies(x, name):
    """
    Decide whether a statespace input is time-varying, based on its number of dimensions.

    Parameters
    ----------
    x : TensorVariable or array-like
        The input; only its ``ndim`` attribute is read.
    name : str
        Name of the input. Names in ``NEVER_TIME_VARYING`` are always treated as static; names
        in ``VECTOR_VALUED`` are checked as vectors, all other names as matrices.

    Returns
    -------
    bool
        True if ``x`` has one dimension more than its static form (2 for vectors, 3 for
        matrices), False otherwise.

    Raises
    ------
    ValueError
        If ``x`` has a number of dimensions that is neither the static nor the time-varying
        form.
    """
    if name in NEVER_TIME_VARYING:
        return False

    ndim = x.ndim

    if name in VECTOR_VALUED:
        if ndim not in (1, 2):
            raise ValueError(
                f"Vector {name} has {ndim} dimensions; it should have either 1 (static),"
                " or 2 (time varying)."
            )
        return ndim == 2

    if ndim not in (2, 3):
        raise ValueError(
            f"Matrix {name} has {ndim} dimensions; it should have either"
            " 2 (static), or 3 (time varying)."
        )

    return ndim == 3
30
+
31
+
32
+ def split_vars_into_seq_and_nonseq(params, param_names):
33
+ """
34
+ Split inputs into those that are time varying and those that are not. This division is required by scan.
35
+ """
36
+ sequences, non_sequences = [], []
37
+ seq_names, non_seq_names = [], []
38
+
39
+ for param, name in zip(params, param_names):
40
+ if decide_if_x_time_varies(param, name):
41
+ sequences.append(param)
42
+ seq_names.append(name)
43
+ else:
44
+ non_sequences.append(param)
45
+ non_seq_names.append(name)
46
+
47
+ return sequences, non_sequences, seq_names, non_seq_names
48
+
49
+
50
def stabilize(cov, jitter=JITTER_DEFAULT):
    """
    Return ``cov`` with ``jitter`` added to every diagonal element.

    Keeps the diagonal of a covariance matrix non-zero.

    Parameters
    ----------
    cov : TensorVariable
        Matrix to stabilize.
    jitter : float
        Amount added to each diagonal entry.

    Returns
    -------
    TensorVariable
        ``cov + jitter * I``.
    """
    diagonal_jitter = pt.identity_like(cov) * jitter
    return cov + diagonal_jitter
55
+
56
+
57
def quad_form_sym(A, B):
    """
    Compute the quadratic form ``A @ B @ A.T`` and symmetrize it.

    Averaging the product with its own transpose removes small asymmetries introduced by
    floating-point round-off.
    """
    quadratic_form = matrix_dot(A, B, A.T)
    return 0.5 * (quadratic_form + quadratic_form.T)