pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
|
File without changes
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
|
|
3
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
4
|
+
from sklearn.preprocessing import StandardScaler
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class StandardScalerDF(StandardScaler, TransformerMixin, BaseEstimator):
    """``StandardScaler`` variant whose transforms return a ``pandas.DataFrame``.

    The scaled values are wrapped in a DataFrame that reuses the ``index`` and
    ``columns`` of the input, so ``X`` must expose both attributes.
    """

    def __init__(self, with_mean=True, with_std=True):
        # Only the centering/scaling switches are forwarded to the parent scaler.
        super().__init__(with_mean=with_mean, with_std=with_std)

    def transform(self, X, y=None):
        """Scale ``X`` with the parent scaler and keep its index and columns.

        ``y`` is accepted but not used.
        """
        scaled = super().transform(X)
        return pd.DataFrame(data=scaled, index=X.index, columns=X.columns)

    def fit_transform(self, X, y=None):
        """Fit on ``X``, scale it, and keep its index and columns.

        ``y`` is accepted but not forwarded to the parent scaler.
        """
        scaled = super().fit_transform(X)
        return pd.DataFrame(data=scaled, index=X.index, columns=X.columns)
|
pymc_extras/printing.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from pymc import Model
|
|
4
|
+
from pymc.printing import str_for_dist, str_for_potential_or_deterministic
|
|
5
|
+
from pytensor import Mode
|
|
6
|
+
from pytensor.compile.sharedvalue import SharedVariable
|
|
7
|
+
from pytensor.graph.type import Constant, Variable
|
|
8
|
+
from rich.box import SIMPLE_HEAD
|
|
9
|
+
from rich.table import Table
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def variable_expression(
    model: Model,
    var: Variable,
    truncate_deterministic: int | None,
) -> str:
    """Get the expression of a variable in a human-readable format.

    Parameters
    ----------
    model : Model
        Model that ``var`` belongs to. Membership of ``var`` in ``model.data_vars``,
        ``model.deterministics`` and ``model.potentials`` (checked in that order)
        decides how it is rendered; anything else is rendered as a distribution.
    var : Variable
        The model variable to describe.
    truncate_deterministic : int | None
        If not None, a deterministic expression longer than this many characters is
        shortened to its leading arguments followed by ``...``.

    Returns
    -------
    str
        ``"Data"`` for data variables, otherwise the text to the right of ``" ~ "``
        in the string representation produced by ``pymc.printing``.
    """
    if var in model.data_vars:
        return "Data"

    if var in model.deterministics:
        str_repr = str_for_potential_or_deterministic(var, dist_name="")
        _, var_expr = str_repr.split(" ~ ")
        var_expr = var_expr[1:-1]  # Remove outer parentheses (f(...))
        if truncate_deterministic is not None and len(var_expr) > truncate_deterministic:
            # Strip the leading "f(" and trailing ")" to get the argument list.
            contents = var_expr[2:-1].split(", ")
            str_len = 0
            for show_n, content in enumerate(contents):
                str_len += len(content) + 2  # +2 accounts for the ", " separator
                if str_len > truncate_deterministic:
                    break
            shown = contents[:show_n]
            # When not even the first argument fits, joining an empty list used to
            # produce the malformed "f(, ...)"; emit "f(...)" instead.
            var_expr = f"f({', '.join(shown)}, ...)" if shown else "f(...)"
        return var_expr

    if var in model.potentials:
        return str_for_potential_or_deterministic(var, dist_name="Potential").split(" ~ ")[1]

    # basic_RVs
    return str_for_dist(var).split(" ~ ")[1]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _extract_dim_value(var: SharedVariable | Constant) -> np.ndarray:
    """Return the concrete value stored in a dimension-length variable.

    Shared variables are read with ``get_value(borrow=True)``; any other input is
    read through its ``data`` attribute.
    """
    if not isinstance(var, SharedVariable):
        return var.data
    return var.get_value(borrow=True)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def dims_expression(model: Model, var: Variable) -> str:
    """Get the dimensions of a variable in a human-readable format.

    Parameters
    ----------
    model : Model
        Model providing ``named_vars_to_dims`` and ``dim_lengths``.
    var : Variable
        The variable whose dimensions are described.

    Returns
    -------
    str
        ``"dim[size] × dim[size]"`` when ``var.name`` has an entry in
        ``model.named_vars_to_dims``. Otherwise the evaluated shape as
        ``"[a, b]"``, or an empty string when the shape is empty.
    """
    dims = model.named_vars_to_dims.get(var.name)
    if dims is not None:
        # Walk ``dims`` itself instead of a dict keyed by dim name: a dict would
        # collapse a dim that labels more than one axis into a single entry.
        return " × ".join(
            f"{dim}[{_extract_dim_value(model.dim_lengths[dim])}]" for dim in dims
        )

    # No named dims: evaluate the symbolic shape with the Python linker.
    dim_sizes = list(var.shape.eval(mode=Mode(linker="py", optimizer="fast_compile")))
    return f"[{', '.join(map(str, dim_sizes))}]" if dim_sizes else ""
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def model_parameter_count(model: Model) -> int:
    """Count the number of parameters in the model.

    The count is the total number of elements over all ``model.free_RVs``, using the
    shapes reported by ``model.eval_rv_shapes()`` under each free RV's own name.

    Returns
    -------
    int
        A builtin ``int``; ``0`` when the model has no free RVs.
    """
    rv_shapes = model.eval_rv_shapes()  # Includes transformed variables
    # Use the builtin ``sum`` over builtin ints: ``np.sum`` of an empty list is the
    # float ``0.0``, which would be displayed as "0.0" for a model without free RVs.
    return sum(int(np.prod(rv_shapes[free_rv.name])) for free_rv in model.free_RVs)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def model_table(
    model: Model,
    *,
    split_groups: bool = True,
    truncate_deterministic: int | None = None,
    parameter_count: bool = True,
) -> Table:
    """Create a rich table with a summary of the model's variables and their expressions.

    Parameters
    ----------
    model : Model
        The PyMC model to summarize.
    split_groups : bool
        If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs)
        will be separated by a section.
    truncate_deterministic : int | None
        If not None, truncate the expression of deterministic variables that go beyond this length.
    parameter_count : bool
        If True, add a row with the total number of parameters in the model. The row is
        placed after the free RVs when ``split_groups`` is True, and after all variables
        otherwise.

    Returns
    -------
    Table
        A rich table with the model's variables, their expressions and dims.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        import pymc as pm

        from pymc_extras.printing import model_table

        coords = {"subject": range(20), "param": ["a", "b"]}
        with pm.Model(coords=coords) as m:
            x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
            y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")

            beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
            mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
            sigma = pm.HalfNormal("sigma", sigma=1)

            y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")

        table = model_table(m)
        table  # Displays the following table in an interactive environment
        '''
         Variable  Expression         Dimensions
        ─────────────────────────────────────────────────────
              x =  Data               subject[20] × param[2]
              y =  Data               subject[20]

           beta ~  Normal(0, 1)       param[2]
          sigma ~  HalfNormal(0, 1)
                                      Parameter count = 3

             mu =  f(beta)            subject[20]

          y_obs ~  Normal(mu, sigma)  subject[20]
        '''

    Output can be explicitly rendered in a rich console or exported to text, html or svg.

    .. code-block:: python

        from rich.console import Console

        console = Console(record=True)
        console.print(table)
        text_export = console.export_text()
        html_export = console.export_html()
        svg_export = console.export_svg()

    """
    table = Table(
        show_header=True,
        show_edge=False,
        box=SIMPLE_HEAD,
        highlight=False,
        collapse_padding=True,
    )
    table.add_column("Variable", justify="right")
    table.add_column("Expression", justify="left")
    table.add_column("Dimensions")

    if split_groups:
        groups = (
            model.data_vars,
            model.free_RVs,
            model.deterministics,
            model.potentials,
            model.observed_RVs,
        )
    else:
        # Show variables in the order they were defined
        groups = (model.named_vars.values(),)

    # Read once: the model is not modified while the rows are built.
    basic_rvs = model.basic_RVs

    for group in groups:
        if not group:
            continue

        for var in group:
            # Random variables are shown with "~", everything else with "=".
            marker = " ~" if var in basic_rvs else " ="
            var_expr = variable_expression(model, var, truncate_deterministic)
            dims_expr = dims_expression(model, var)
            if dims_expr == "[]":
                dims_expr = ""
            table.add_row(var.name + f"[b]{marker}[/b]", var_expr, dims_expr)

        if parameter_count and (not split_groups or group == model.free_RVs):
            n_parameters = model_parameter_count(model)
            table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]")

        table.add_section()

    return table
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from pymc_extras.statespace.core.compile import compile_statespace
|
|
2
|
+
from pymc_extras.statespace.models import structural
|
|
3
|
+
from pymc_extras.statespace.models.ETS import BayesianETS
|
|
4
|
+
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMA
|
|
5
|
+
from pymc_extras.statespace.models.VARMAX import BayesianVARMAX
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"compile_statespace",
|
|
9
|
+
"structural",
|
|
10
|
+
"BayesianETS",
|
|
11
|
+
"BayesianSARIMA",
|
|
12
|
+
"BayesianVARMAX",
|
|
13
|
+
]
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
# ruff: noqa: I001
|
|
2
|
+
|
|
3
|
+
from pymc_extras.statespace.core.representation import PytensorRepresentation
|
|
4
|
+
from pymc_extras.statespace.core.statespace import PyMCStateSpace
|
|
5
|
+
from pymc_extras.statespace.core.compile import compile_statespace
|
|
6
|
+
|
|
7
|
+
__all__ = ["PytensorRepresentation", "PyMCStateSpace", "compile_statespace"]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pymc as pm
|
|
3
|
+
import pytensor
|
|
4
|
+
import pytensor.tensor as pt
|
|
5
|
+
|
|
6
|
+
from pymc_extras.statespace.core import PyMCStateSpace
|
|
7
|
+
from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace
|
|
8
|
+
from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def compile_statespace(
    statespace_model: PyMCStateSpace, steps: int | None = None, **compile_kwargs
):
    """Compile a function that draws from a statespace model's ``LinearGaussianStateSpace``.

    Parameters
    ----------
    statespace_model : PyMCStateSpace
        Model whose placeholder matrices are obtained through
        ``_unpack_statespace_with_placeholders``.
    steps : int | None
        Number of steps to simulate. When None, a symbolic integer scalar named
        ``"steps"`` is used instead, and the returned function accepts a ``steps``
        keyword (default 100).
    **compile_kwargs
        Forwarded to ``pm.compile_pymc``.

    Returns
    -------
    callable
        ``f(*, draws=1, **params)``. It calls the compiled function ``draws`` times with
        ``params`` as keyword inputs and returns one array per output, each allocated as
        ``(draws, steps + 1, last_static_dim)`` and then squeezed.
    """
    if steps is None:
        steps = pt.iscalar("steps")

    x0, _, c, d, T, Z, R, H, Q = statespace_model._unpack_statespace_with_placeholders()

    # c/d with 2 dims and T/Z/R/H/Q with 3 dims are passed on as sequence names.
    sequence_names = [x.name for x in [c, d] if x.ndim == 2]
    sequence_names += [x.name for x in [T, Z, R, H, Q] if x.ndim == 3]

    # Map the placeholder names through the inverse of SHORT_NAME_TO_LONG.
    rename_dict = {v: k for k, v in SHORT_NAME_TO_LONG.items()}
    sequence_names = list(map(rename_dict.get, sequence_names))

    P0 = pt.zeros((x0.shape[0], x0.shape[0]))

    outputs = LinearGaussianStateSpace.dist(
        x0, P0, c, d, T, Z, R, H, Q, steps=steps, sequence_names=sequence_names
    )

    inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))

    _f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)

    def f(*, draws=1, **params):
        if isinstance(steps, pt.Variable):
            # ``steps`` is symbolic here, so the default must also reach ``_f``.
            # Previously it only sized the buffers below and was never passed on.
            # NOTE(review): assumes "steps" is among the compiled inputs - confirm.
            inner_steps = params.setdefault("steps", 100)
        else:
            inner_steps = steps

        output = [np.empty((draws, inner_steps + 1, x.type.shape[-1])) for x in outputs]
        for i in range(draws):
            draw = _f(**params)
            for j, x in enumerate(draw):
                output[j][i] = x
        return [x.squeeze() for x in output]

    return f
|