pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
File without changes
@@ -0,0 +1,17 @@
1
+ import pandas as pd
2
+
3
+ from sklearn.base import BaseEstimator, TransformerMixin
4
+ from sklearn.preprocessing import StandardScaler
5
+
6
+
7
class StandardScalerDF(StandardScaler, TransformerMixin, BaseEstimator):
    """``StandardScaler`` whose transforms return a ``pandas.DataFrame``.

    The scaled values are re-labelled with the ``index`` and ``columns`` of the
    input, so ``X`` must expose both attributes (e.g. be a ``DataFrame``).
    """

    def __init__(self, with_mean=True, with_std=True):
        super().__init__(with_mean=with_mean, with_std=with_std)

    @staticmethod
    def _relabel(values, like):
        """Wrap ``values`` in a DataFrame carrying the labels of ``like``."""
        return pd.DataFrame(values, index=like.index, columns=like.columns)

    def transform(self, X, y=None):
        """Scale ``X`` with the fitted statistics; ``y`` is accepted but unused."""
        scaled = super().transform(X)
        return self._relabel(scaled, X)

    def fit_transform(self, X, y=None):
        """Fit on ``X`` and return the scaled frame; ``y`` is accepted but unused."""
        scaled = super().fit_transform(X)
        return self._relabel(scaled, X)
@@ -0,0 +1,182 @@
1
+ import numpy as np
2
+
3
+ from pymc import Model
4
+ from pymc.printing import str_for_dist, str_for_potential_or_deterministic
5
+ from pytensor import Mode
6
+ from pytensor.compile.sharedvalue import SharedVariable
7
+ from pytensor.graph.type import Constant, Variable
8
+ from rich.box import SIMPLE_HEAD
9
+ from rich.table import Table
10
+
11
+
12
def variable_expression(
    model: Model,
    var: Variable,
    truncate_deterministic: int | None,
) -> str:
    """Get the expression of a variable in a human-readable format.

    Parameters
    ----------
    model : Model
        Model that ``var`` belongs to; its ``data_vars``, ``deterministics`` and
        ``potentials`` collections decide how ``var`` is rendered.
    var : Variable
        The variable to describe.
    truncate_deterministic : int | None
        If not None, deterministic expressions longer than this many characters
        are shortened to their leading arguments followed by ``...``.

    Returns
    -------
    str
        ``"Data"`` for data variables, ``f(...)`` for deterministics, and the
        right-hand side of the PyMC string representation otherwise.
    """
    if var in model.data_vars:
        return "Data"

    if var in model.deterministics:
        str_repr = str_for_potential_or_deterministic(var, dist_name="")
        _, var_expr = str_repr.split(" ~ ")
        var_expr = var_expr[1:-1]  # Remove outer parentheses (f(...))
        if truncate_deterministic is not None and len(var_expr) > truncate_deterministic:
            contents = var_expr[2:-1].split(", ")
            str_len = 0
            show_n = 0
            for show_n, content in enumerate(contents):
                # Each argument costs its own length plus the ", " separator.
                str_len += len(content) + 2
                if str_len > truncate_deterministic:
                    break
            # When not even the first argument fits, ``show_n`` is 0; joining the
            # ellipsis with the shown arguments avoids the malformed "f(, ...)".
            var_expr = f"f({', '.join([*contents[:show_n], '...'])})"
        return var_expr

    if var in model.potentials:
        return str_for_potential_or_deterministic(var, dist_name="Potential").split(" ~ ")[1]

    # basic_RVs
    return str_for_dist(var).split(" ~ ")[1]
37
+
38
+
39
def _extract_dim_value(var: SharedVariable | Constant) -> np.ndarray:
    """Return the value stored in ``var``.

    Shared variables are read through ``get_value(borrow=True)``; anything else
    is read from its ``data`` attribute.
    """
    return var.get_value(borrow=True) if isinstance(var, SharedVariable) else var.data
44
+
45
+
46
def dims_expression(model: Model, var: Variable) -> str:
    """Get the dimensions of a variable in a human-readable format.

    Variables registered in ``model.named_vars_to_dims`` are rendered as
    ``name[size]`` entries joined by `` × ``. Other variables fall back to their
    evaluated shape, rendered as ``[n, m, ...]`` (empty string for scalars).
    """
    named_dims = model.named_vars_to_dims.get(var.name)

    if named_dims is None:
        # No named dims: evaluate the symbolic shape with the pure-Python linker.
        shape = list(var.shape.eval(mode=Mode(linker="py", optimizer="fast_compile")))
        if not shape:
            return ""
        return f"[{', '.join(map(str, shape))}]"

    sizes = {dim: _extract_dim_value(model.dim_lengths[dim]) for dim in named_dims}
    return " × ".join(f"{dim}[{size}]" for dim, size in sizes.items())
54
+
55
+
56
def model_parameter_count(model: Model) -> int:
    """Count the number of parameters in the model.

    The count is the sum, over ``model.free_RVs``, of the number of elements in
    each variable's evaluated shape (a scalar contributes 1).

    Returns
    -------
    int
        Total number of free parameters as a built-in ``int``; 0 when the model
        has no free random variables.
    """
    rv_shapes = model.eval_rv_shapes()  # Includes transformed variables
    # Summing Python ints (instead of ``np.sum`` over a list) keeps the result an
    # ``int`` even for an empty model, where ``np.sum([])`` would yield ``0.0``.
    return sum(int(np.prod(rv_shapes[free_rv.name])) for free_rv in model.free_RVs)
60
+
61
+
62
def model_table(
    model: Model,
    *,
    split_groups: bool = True,
    truncate_deterministic: int | None = None,
    parameter_count: bool = True,
) -> Table:
    """Create a rich table with a summary of the model's variables and their expressions.

    Parameters
    ----------
    model : Model
        The PyMC model to summarize.
    split_groups : bool
        If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs)
        will be separated by a section.
    truncate_deterministic : int | None
        If not None, truncate the expression of deterministic variables that go beyond this length.
    parameter_count : bool
        If True, add a row with the total number of parameters in the model.

    Returns
    -------
    Table
        A rich table with the model's variables, their expressions and dims.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        import pymc as pm

        from pymc_extras.printing import model_table

        coords = {"subject": range(20), "param": ["a", "b"]}
        with pm.Model(coords=coords) as m:
            x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
            y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")

            beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
            mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
            sigma = pm.HalfNormal("sigma", sigma=1)

            y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")

        table = model_table(m)
        table  # Displays the following table in an interactive environment
        '''
         Variable  Expression         Dimensions
        ─────────────────────────────────────────────────────
              x =  Data               subject[20] × param[2]
              y =  Data               subject[20]

           beta ~  Normal(0, 1)       param[2]
          sigma ~  HalfNormal(0, 1)
                                      Parameter count = 3

             mu =  f(beta)            subject[20]

          y_obs ~  Normal(mu, sigma)  subject[20]
        '''

    Output can be explicitly rendered in a rich console or exported to text, html or svg.

    .. code-block:: python

        from rich.console import Console

        console = Console(record=True)
        console.print(table)
        text_export = console.export_text()
        html_export = console.export_html()
        svg_export = console.export_svg()

    """
    table = Table(
        show_header=True,
        show_edge=False,
        box=SIMPLE_HEAD,
        highlight=False,
        collapse_padding=True,
    )
    table.add_column("Variable", justify="right")
    table.add_column("Expression", justify="left")
    table.add_column("Dimensions")

    if split_groups:
        groups = (
            model.data_vars,
            model.free_RVs,
            model.deterministics,
            model.potentials,
            model.observed_RVs,
        )
    else:
        # Show variables in the order they were defined
        groups = (model.named_vars.values(),)

    # Loop-invariant: read once instead of once per table row.
    basic_rvs = model.basic_RVs

    for group in groups:
        if not group:
            continue

        for var in group:
            # Random variables are rendered with "~", everything else with "=".
            sep = f'[b]{" ~" if (var in basic_rvs) else " ="}[/b]'
            var_expr = variable_expression(model, var, truncate_deterministic)
            dims_expr = dims_expression(model, var)
            if dims_expr == "[]":
                # Never display an empty shape for scalar variables.
                dims_expr = ""
            table.add_row(var.name + sep, var_expr, dims_expr)

        # With split groups the count goes right under the free RVs; otherwise
        # it closes the single combined section.
        if parameter_count and (not split_groups or group == model.free_RVs):
            n_parameters = model_parameter_count(model)
            table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]")

        table.add_section()

    return table
@@ -0,0 +1,13 @@
1
+ from pymc_extras.statespace.core.compile import compile_statespace
2
+ from pymc_extras.statespace.models import structural
3
+ from pymc_extras.statespace.models.ETS import BayesianETS
4
+ from pymc_extras.statespace.models.SARIMAX import BayesianSARIMA
5
+ from pymc_extras.statespace.models.VARMAX import BayesianVARMAX
6
+
7
+ __all__ = [
8
+ "compile_statespace",
9
+ "structural",
10
+ "BayesianETS",
11
+ "BayesianSARIMA",
12
+ "BayesianVARMAX",
13
+ ]
@@ -0,0 +1,7 @@
1
+ # ruff: noqa: I001
2
+
3
+ from pymc_extras.statespace.core.representation import PytensorRepresentation
4
+ from pymc_extras.statespace.core.statespace import PyMCStateSpace
5
+ from pymc_extras.statespace.core.compile import compile_statespace
6
+
7
+ __all__ = ["PytensorRepresentation", "PyMCStateSpace", "compile_statespace"]
@@ -0,0 +1,48 @@
1
+ import numpy as np
2
+ import pymc as pm
3
+ import pytensor
4
+ import pytensor.tensor as pt
5
+
6
+ from pymc_extras.statespace.core import PyMCStateSpace
7
+ from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace
8
+ from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG
9
+
10
+
11
def compile_statespace(
    statespace_model: PyMCStateSpace, steps: int | None = None, **compile_kwargs
):
    """Compile a function that draws from the statespace model's unconditional distribution.

    The model's placeholder matrices are wired into a ``LinearGaussianStateSpace``
    distribution (with an all-zero initial state covariance) and the resulting
    graph is compiled with ``pm.compile_pymc``.

    Parameters
    ----------
    statespace_model : PyMCStateSpace
        Model providing the placeholder matrices via
        ``_unpack_statespace_with_placeholders``.
    steps : int | None
        Number of steps to simulate. If None, a symbolic scalar input named
        ``"steps"`` is created and the value is chosen at call time instead.
    **compile_kwargs
        Forwarded to ``pm.compile_pymc``.

    Returns
    -------
    Callable
        ``f(*, draws=1, **params)``. ``params`` are passed by name to the compiled
        function. When ``steps`` was left as None, ``params`` may include
        ``steps`` (default 100). The result is a list with one squeezed array per
        output, each allocated as ``(draws, steps + 1, n)`` before squeezing.
    """
    # Only a placeholder created here is guaranteed to be an input named "steps".
    owns_steps_input = steps is None
    if owns_steps_input:
        steps = pt.iscalar("steps")

    x0, _, c, d, T, Z, R, H, Q = statespace_model._unpack_statespace_with_placeholders()

    # An extra leading dimension (2d vectors, 3d matrices) marks an input as a sequence.
    sequence_names = [x.name for x in [c, d] if x.ndim == 2]
    sequence_names += [x.name for x in [T, Z, R, H, Q] if x.ndim == 3]

    rename_dict = {v: k for k, v in SHORT_NAME_TO_LONG.items()}
    sequence_names = list(map(rename_dict.get, sequence_names))

    P0 = pt.zeros((x0.shape[0], x0.shape[0]))

    outputs = LinearGaussianStateSpace.dist(
        x0, P0, c, d, T, Z, R, H, Q, steps=steps, sequence_names=sequence_names
    )

    inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))

    _f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)

    def f(*, draws=1, **params):
        if isinstance(steps, pt.Variable):
            inner_steps = params.get("steps", 100)
            if owns_steps_input:
                # The default must also reach the compiled function; otherwise
                # its "steps" input is missing whenever the caller omits it.
                params.setdefault("steps", inner_steps)
        else:
            inner_steps = steps

        output = [np.empty((draws, inner_steps + 1, x.type.shape[-1])) for x in outputs]
        for i in range(draws):
            draw = _f(**params)
            for j, x in enumerate(draw):
                output[j][i] = x
        return [x.squeeze() for x in output]

    return f