pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,175 @@
1
+ import unittest
2
+
3
+ import numpy as np
4
+ import pytensor
5
+ import pytensor.tensor as pt
6
+
7
+ from numpy.testing import assert_allclose
8
+
9
+ from pymc_extras.statespace.core.representation import PytensorRepresentation
10
+ from tests.statespace.utilities.shared_fixtures import TEST_SEED
11
+ from tests.statespace.utilities.test_helpers import fast_eval, make_test_inputs
12
+
13
floatX = pytensor.config.floatX

# Absolute tolerance for numeric comparisons: tight when running in float64,
# looser for any other floatX setting (e.g. float32).
if floatX == "float64":
    atol = 1e-12
else:
    atol = 1e-6
15
+
16
+
17
def unpack_ssm_dims(ssm):
    """Return the core dimensions of a state space representation.

    Parameters
    ----------
    ssm
        Any object exposing ``k_endog``, ``k_states`` and ``k_posdef`` attributes
        (in this module, a ``PytensorRepresentation``).

    Returns
    -------
    tuple
        ``(k_endog, k_states, k_posdef)``, unpacked by the tests as ``(p, m, r)``.
    """
    return ssm.k_endog, ssm.k_states, ssm.k_posdef
23
+
24
+
25
class BasicFunctionality(unittest.TestCase):
    """Core behaviour of ``PytensorRepresentation``: shapes, indexing, assignment and errors."""

    def setUp(self):
        # Seeded generator so randomly generated test inputs are reproducible across runs.
        self.rng = np.random.default_rng(TEST_SEED)

    def _assert_default_shapes(self, ssm):
        """Check the static shapes of the default system matrices against ``(p, m, r)``."""
        p, m, r = unpack_ssm_dims(ssm)
        expected_shapes = {
            "design": (p, m),
            "transition": (m, m),
            "selection": (m, r),
            "state_cov": (r, r),
            "obs_cov": (p, p),
        }
        for name, shape in expected_shapes.items():
            assert_allclose(ssm[name].type.shape, shape, err_msg=f"{name} shape test")

    def test_numpy_to_pytensor(self):
        """Converting a NumPy array yields a TensorVariable; ``transition`` keeps shape and name."""
        ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1)
        X = np.eye(5)
        X_pt = ssm._numpy_to_pytensor("transition", X)
        self.assertIsInstance(X_pt, pt.TensorVariable)

        # NOTE(review): ``X_pt`` is never stored on ``ssm``, so the checks below inspect the
        # representation's own ``transition`` matrix rather than the converted array.
        # Confirm whether ``X_pt``'s shape/name were meant to be asserted here instead.
        assert_allclose(ssm["transition"].type.shape, X.shape)
        assert ssm["transition"].name == "transition"

    def test_default_shapes_full_rank(self):
        """Default matrices have the expected shapes when ``k_posdef == k_states``."""
        ssm = PytensorRepresentation(k_endog=5, k_states=5, k_posdef=5)
        self._assert_default_shapes(ssm)

    def test_default_shapes_low_rank(self):
        """Default matrices have the expected shapes when ``k_posdef < k_states``."""
        ssm = PytensorRepresentation(k_endog=5, k_states=5, k_posdef=2)
        self._assert_default_shapes(ssm)

    def test_matrix_assignment(self):
        """Assigning through ``(name, *index)`` keys updates values and preserves matrix names."""
        ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=2)

        ssm["design", 0, 0] = 3.0
        ssm["transition", 0, :] = 2.7
        ssm["selection", -1, -1] = 9.9

        assert_allclose(fast_eval(ssm["design"][0, 0]), 3.0, atol=atol)
        assert_allclose(fast_eval(ssm["transition"][0, :]), 2.7, atol=atol)
        assert_allclose(fast_eval(ssm["selection"][-1, -1]), 9.9, atol=atol)

        assert ssm["design"].name == "design"
        assert ssm["transition"].name == "transition"
        assert ssm["selection"].name == "selection"

    def test_build_representation_from_data(self):
        """Matrices passed to the constructor are stored with the right values, names and shapes."""
        p, m, r, n = 3, 6, 1, 10
        inputs = make_test_inputs(p, m, r, n, self.rng, missing_data=0)
        # The first returned item is not passed to the representation; only the rest are checked.
        _data, a0, P0, c, d, T, Z, R, H, Q = inputs

        ssm = PytensorRepresentation(
            k_endog=p,
            k_states=m,
            k_posdef=r,
            design=Z,
            transition=T,
            selection=R,
            state_cov=Q,
            obs_cov=H,
            initial_state=a0,
            initial_state_cov=P0,
            state_intercept=c,
            obs_intercept=d,
        )

        # Order must match ``inputs[1:]`` (a0, P0, c, d, T, Z, R, H, Q).
        names = [
            "initial_state",
            "initial_state_cov",
            "state_intercept",
            "obs_intercept",
            "transition",
            "design",
            "selection",
            "obs_cov",
            "state_cov",
        ]

        for name, X in zip(names, inputs[1:]):
            # subTest reports every mismatching matrix instead of stopping at the first one.
            with self.subTest(matrix=name):
                assert_allclose(X, fast_eval(ssm[name]), err_msg=name)
                assert ssm[name].name == name
                assert_allclose(ssm[name].type.shape, X.shape, err_msg=f"{name} shape test")

    def test_assign_time_varying_matrices(self):
        """An ``(n, 5)`` state intercept can be assigned and then updated column-wise."""
        ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=2)
        n = 10

        ssm["design", 0, 0] = 3.0
        ssm["transition", 0, :] = 2.7
        ssm["selection", -1, -1] = 9.9

        ssm["state_intercept"] = np.zeros((n, 5))
        ssm["state_intercept", :, 0] = np.arange(n)

        assert_allclose(fast_eval(ssm["design"][0, 0]), 3.0, atol=atol)
        assert_allclose(fast_eval(ssm["transition"][0, :]), 2.7, atol=atol)
        assert_allclose(fast_eval(ssm["selection"][-1, -1]), 9.9, atol=atol)
        assert_allclose(fast_eval(ssm["state_intercept"][:, 0]), np.arange(n), atol=atol)

    def test_invalid_key_name_raises(self):
        """Indexing with an unknown matrix name raises IndexError."""
        ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1)
        with self.assertRaises(IndexError) as e:
            ssm["invalid_key"]
            # NOTE(review): unreachable -- the line above raises, so this message check never
            # runs. Confirm the exact message raised by PytensorRepresentation, then move the
            # two lines below out of the ``with`` block.
            msg = str(e.exception)
            self.assertEqual(msg, "invalid_key is an invalid state space matrix name")

    def test_non_string_key_raises(self):
        """Indexing with a bare non-string key raises IndexError."""
        ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1)
        with self.assertRaises(IndexError) as e:
            ssm[0]
            # NOTE(review): unreachable -- the line above raises, so this message check never
            # runs. Confirm the exact message before moving it out of the ``with`` block.
            msg = str(e.exception)
            self.assertEqual(msg, "First index must the name of a valid state space matrix.")

    def test_invalid_key_tuple_raises(self):
        """A tuple key whose first element is not a matrix name raises IndexError."""
        ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1)
        with self.assertRaises(IndexError) as e:
            ssm[0, 1, 1]
            # NOTE(review): unreachable -- the line above raises, so this message check never
            # runs. The expected string is identical to the bare non-string-key case; confirm
            # which message a tuple key actually produces before moving it out of the block.
            msg = str(e.exception)
            self.assertEqual(msg, "First index must the name of a valid state space matrix.")

    def test_slice_statespace_matrix(self):
        """``ssm[name, *slices]`` returns the corresponding slice of the stored matrix."""
        T = np.eye(5)
        ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1, transition=T)
        T_out = ssm["transition", :3, :]
        assert_allclose(T[:3], fast_eval(T_out))

    def test_update_matrix_via_key(self):
        """Assigning a whole array by matrix name replaces that matrix."""
        T = np.eye(5)
        ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1)
        ssm["transition"] = T

        assert_allclose(T, fast_eval(ssm["transition"]))

    def test_update_matrix_with_invalid_shape_raises(self):
        """Assigning an array of the wrong shape raises ValueError."""
        T = np.eye(10)
        ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1)
        with self.assertRaises(ValueError) as e:
            ssm["transition"] = T
            # NOTE(review): unreachable -- the assignment above raises, so this message check
            # never runs. Confirm the exact message before moving it out of the ``with`` block.
            msg = str(e.exception)
            self.assertEqual(
                msg, "The last two dimensions of transition must be (5, 5), found (10, 10)"
            )
172
+
173
+
174
if __name__ == "__main__":
    # Allow running this test module directly with the standard unittest runner.
    unittest.main()