pymc-extras 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/inference/laplace_approx/laplace.py +2 -2
- pymc_extras/inference/pathfinder/pathfinder.py +1 -1
- pymc_extras/prior.py +3 -3
- pymc_extras/statespace/core/properties.py +276 -0
- pymc_extras/statespace/core/statespace.py +180 -44
- pymc_extras/statespace/filters/distributions.py +12 -29
- pymc_extras/statespace/filters/kalman_filter.py +1 -1
- pymc_extras/statespace/models/DFM.py +179 -168
- pymc_extras/statespace/models/ETS.py +177 -151
- pymc_extras/statespace/models/SARIMAX.py +149 -152
- pymc_extras/statespace/models/VARMAX.py +134 -145
- pymc_extras/statespace/models/__init__.py +8 -1
- pymc_extras/statespace/models/structural/__init__.py +30 -8
- pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
- pymc_extras/statespace/models/structural/components/cycle.py +119 -80
- pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
- pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
- pymc_extras/statespace/models/structural/components/regression.py +105 -68
- pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
- pymc_extras/statespace/models/structural/core.py +397 -286
- pymc_extras/statespace/models/utilities.py +5 -20
- {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +3 -3
- {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +25 -24
- {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
- {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0
|
pymc_extras/inference/laplace_approx/laplace.py
CHANGED
@@ -137,7 +137,7 @@ def get_conditional_gaussian_approximation(
|
|
|
137
137
|
hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
|
|
138
138
|
|
|
139
139
|
# Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
|
|
140
|
-
_, logdetQ = pt.
|
|
140
|
+
_, logdetQ = pt.linalg.slogdet(Q)
|
|
141
141
|
conditional_gaussian_approx = (
|
|
142
142
|
-0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
|
|
143
143
|
)
|
|
@@ -153,7 +153,7 @@ def unpack_last_axis(packed_input, packed_shapes):
|
|
|
153
153
|
return [pt.split_dims(packed_input, packed_shapes[0], axis=-1)]
|
|
154
154
|
|
|
155
155
|
keep_axes = tuple(range(packed_input.ndim))[:-1]
|
|
156
|
-
return pt.unpack(packed_input,
|
|
156
|
+
return pt.unpack(packed_input, keep_axes=keep_axes, packed_shapes=packed_shapes)
|
|
157
157
|
|
|
158
158
|
|
|
159
159
|
def draws_from_laplace_approx(
|
|
pymc_extras/inference/pathfinder/pathfinder.py
CHANGED
@@ -385,7 +385,7 @@ def inverse_hessian_factors(
|
|
|
385
385
|
# more performant and numerically precise to use solve than inverse: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html
|
|
386
386
|
|
|
387
387
|
# E_inv: (L, J, J)
|
|
388
|
-
E_inv = pt.
|
|
388
|
+
E_inv = pt.linalg.solve_triangular(E, Ij, check_finite=False)
|
|
389
389
|
eta_diag = pytensor.scan(pt.diag, sequences=[eta], return_updates=False)
|
|
390
390
|
|
|
391
391
|
# block_dd: (L, J, J)
|
pymc_extras/prior.py
CHANGED
|
@@ -1575,9 +1575,9 @@ def __getattr__(name: str):
|
|
|
1575
1575
|
samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]})
|
|
1576
1576
|
|
|
1577
1577
|
"""
|
|
1578
|
-
#
|
|
1579
|
-
if name
|
|
1580
|
-
|
|
1578
|
+
# Ignore Python internal attributes needed for introspection
|
|
1579
|
+
if name.startswith("__"):
|
|
1580
|
+
raise AttributeError(name)
|
|
1581
1581
|
|
|
1582
1582
|
_get_pymc_distribution(name)
|
|
1583
1583
|
return partial(Prior, distribution=name)
|
|
pymc_extras/statespace/core/properties.py
@@ -0,0 +1,276 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterator
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from dataclasses import dataclass, fields
|
|
6
|
+
from typing import Generic, Protocol, Self, TypeVar
|
|
7
|
+
|
|
8
|
+
from pytensor.tensor.variable import TensorVariable
|
|
9
|
+
|
|
10
|
+
from pymc_extras.statespace.utils.constants import (
|
|
11
|
+
ALL_STATE_AUX_DIM,
|
|
12
|
+
ALL_STATE_DIM,
|
|
13
|
+
OBS_STATE_AUX_DIM,
|
|
14
|
+
OBS_STATE_DIM,
|
|
15
|
+
SHOCK_AUX_DIM,
|
|
16
|
+
SHOCK_DIM,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class StateSpaceLike(Protocol):
    """Structural (duck-typed) interface for objects exposing state-space label metadata.

    Any object providing these three read-only properties satisfies the protocol; no
    inheritance is required.
    """

    @property
    def state_names(self) -> tuple[str, ...]:
        """Names of all states of the model."""

    @property
    def observed_states(self) -> tuple[str, ...]:
        """Names of the observed states."""

    @property
    def shock_names(self) -> tuple[str, ...]:
        """Names of the shocks."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
class Property:
    """Immutable base class for the metadata records stored in ``Info`` containers.

    ``str()`` renders one ``field: value`` line per dataclass field.
    """

    def __str__(self) -> str:
        return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self))


T = TypeVar("T", bound=Property)


@dataclass(frozen=True)
class Info(Generic[T]):
    """Immutable, ordered collection of ``Property`` records indexed by a key.

    Parameters
    ----------
    items : tuple of T or None
        Records to store. ``None`` is treated as empty; any other iterable is
        coerced to a tuple.
    key_field : str or tuple of str, default "name"
        Attribute name(s) of each item used to build its lookup key. With a tuple,
        the key is the tuple of those attribute values.
    _index : dict, optional
        Internal key -> item mapping. Any value passed in is discarded and rebuilt
        in ``__post_init__``.

    Raises
    ------
    ValueError
        If two items produce the same key.
    """

    items: tuple[T, ...] | None
    key_field: str | tuple[str, ...] = "name"
    # NOTE(review): _index holds a dict and, with dataclass's default eq=True plus
    # frozen=True, participates in the generated __eq__/__hash__/__repr__, so
    # hash() on an instance raises TypeError. Consider
    # ``field(default=None, compare=False, hash=False, repr=False)``.
    _index: dict[str | tuple, T] | None = None

    def __post_init__(self):
        # The dataclass is frozen, so attributes must be assigned via object.__setattr__.
        if self.items is None:
            object.__setattr__(self, "items", ())
        else:
            object.__setattr__(self, "items", tuple(self.items))

        index = {}
        for item in self.items:
            key = self._key(item)
            if key in index:
                raise ValueError(f"Duplicate {self.key_field} '{key}' detected.")
            index[key] = item
        object.__setattr__(self, "_index", index)

    def _key(self, item: T) -> str | tuple:
        """Return the lookup key of ``item`` according to ``key_field``."""
        if isinstance(self.key_field, tuple):
            return tuple(getattr(item, f) for f in self.key_field)
        return getattr(item, self.key_field)

    def _rebuild(self, items: tuple[T, ...]) -> Self:
        """Construct a new container of the same type holding ``items``.

        Subclasses such as ``ParameterInfo`` or ``CoordInfo`` fix ``key_field``
        inside their own single-argument ``__init__`` and are called with the
        items only. A constructor that itself accepts ``key_field`` (e.g. the
        plain ``Info`` dataclass ``__init__``) receives the current value
        explicitly. Otherwise it would silently fall back to the default
        ``"name"`` and index the items on the wrong attribute.
        """
        import inspect

        cls = type(self)
        if "key_field" in inspect.signature(cls).parameters:
            return cls(items, key_field=self.key_field)
        return cls(items)

    def get(self, key: str | tuple, default=None) -> T | None:
        """Return the item stored under ``key``, or ``default`` if absent."""
        return self._index.get(key, default)

    def __getitem__(self, key: str | tuple) -> T:
        """Return the item stored under ``key``.

        Raises
        ------
        KeyError
            If ``key`` is absent; the message lists the available keys.
        """
        try:
            return self._index[key]
        except KeyError as e:
            available = ", ".join(str(k) for k in self._index.keys())
            raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e

    def __contains__(self, key: object) -> bool:
        return key in self._index

    def __iter__(self) -> Iterator[T]:
        return iter(self.items)

    def __len__(self) -> int:
        return len(self.items)

    def __str__(self) -> str:
        return f"{self.key_field}s: {tuple(self._index.keys())}"

    def add(self, new_item: T) -> Self:
        """Return a new container with ``new_item`` appended.

        Raises ``ValueError`` if ``new_item``'s key already exists.
        """
        return self._rebuild((*self.items, new_item))

    def merge(self, other: Self, overwrite_duplicates: bool = False) -> Self:
        """Return a new container holding this container's items followed by ``other``'s.

        Parameters
        ----------
        other : Info
            Container to merge in; must be an instance of ``type(self)``.
        overwrite_duplicates : bool, default False
            If True, items of ``other`` whose key already exists here are dropped,
            so this container's version wins. If False, a shared key raises
            ``ValueError``.

        Raises
        ------
        TypeError
            If ``other`` is not an instance of ``type(self)``.
        ValueError
            On duplicate keys when ``overwrite_duplicates`` is False.
        """
        if not isinstance(other, type(self)):
            raise TypeError(f"Cannot merge {type(other).__name__} with {type(self).__name__}")

        overlapping = set(self._index.keys()) & set(other._index.keys())
        if overlapping and overwrite_duplicates:
            return self._rebuild(
                (
                    *self.items,
                    *(item for item in other.items if self._key(item) not in overlapping),
                )
            )

        # Duplicate keys (if any) make the constructor raise ValueError here.
        return self._rebuild(self.items + other.items)

    @property
    def names(self) -> tuple[str, ...]:
        """Item names.

        For a tuple ``key_field`` this is each item's ``name`` attribute. Otherwise
        it is the keys themselves (for example, dimensions for a ``dimension``-keyed
        container).
        """
        if isinstance(self.key_field, tuple):
            return tuple(item.name for item in self.items)
        return tuple(self._index.keys())

    def copy(self) -> Info[T]:
        """Return a deep copy of this container (items are deep-copied too)."""
        return deepcopy(self)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass(frozen=True)
class Parameter(Property):
    """Metadata describing one model parameter.

    Attributes
    ----------
    name : str
        Parameter name; the lookup key in :class:`ParameterInfo`.
    shape : tuple of int, optional
        Shape of the parameter, if known.
    dims : tuple of str, optional
        Dimension names for the parameter (presumably one per axis of ``shape``).
    constraints : str, optional
        Free-form description of the parameter's constraints.
    """

    name: str
    shape: tuple[int, ...] | None = None
    dims: tuple[str, ...] | None = None
    constraints: str | None = None


@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
    """Collection of :class:`Parameter` records keyed by ``name``.

    ``parameters`` defaults to ``None`` (an empty collection), matching the other
    ``*Info`` constructors that already allow omitting their items.
    """

    def __init__(self, parameters: tuple[Parameter, ...] | None = None):
        super().__init__(items=parameters, key_field="name")

    def to_dict(self):
        """Return ``{name: {"shape", "constraints", "dims"}}`` in item order."""
        return {
            param.name: {"shape": param.shape, "constraints": param.constraints, "dims": param.dims}
            for param in self.items
        }
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@dataclass(frozen=True)
class Data(Property):
    """Metadata describing one data input of the model.

    Attributes
    ----------
    name : str
        Data name; the lookup key in :class:`DataInfo`.
    shape : tuple of (int or None), optional
        Shape of the data; ``None`` entries presumably mark axes of unknown length.
    dims : tuple of str, optional
        Dimension names for the data.
    is_exogenous : bool, default False
        Flag counted by ``DataInfo.needs_exogenous_data`` / ``exogenous_names``.
    """

    name: str
    shape: tuple[int | None, ...] | None = None
    dims: tuple[str, ...] | None = None
    is_exogenous: bool = False


@dataclass(frozen=True)
class DataInfo(Info[Data]):
    """Collection of :class:`Data` records keyed by ``name``.

    ``data`` defaults to ``None`` (an empty collection), matching the other
    ``*Info`` constructors that already allow omitting their items.
    """

    def __init__(self, data: tuple[Data, ...] | None = None):
        super().__init__(items=data, key_field="name")

    @property
    def needs_exogenous_data(self) -> bool:
        """True if any stored item has ``is_exogenous`` set."""
        return any(d.is_exogenous for d in self.items)

    @property
    def exogenous_names(self) -> tuple[str, ...]:
        """Names of the items flagged ``is_exogenous``, in item order."""
        return tuple(d.name for d in self.items if d.is_exogenous)

    def __str__(self) -> str:
        return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}"

    def to_dict(self):
        """Return ``{name: {"shape", "dims", "exogenous"}}`` in item order."""
        return {
            data.name: {"shape": data.shape, "dims": data.dims, "exogenous": data.is_exogenous}
            for data in self.items
        }
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@dataclass(frozen=True)
class Coord(Property):
    """A named dimension together with its labels.

    Attributes
    ----------
    dimension : str
        Dimension name; the lookup key in :class:`CoordInfo`.
    labels : tuple of (str or int)
        Labels along that dimension.
    """

    dimension: str
    labels: tuple[str | int, ...]


@dataclass(frozen=True)
class CoordInfo(Info[Coord]):
    """Collection of :class:`Coord` records keyed by ``dimension``."""

    def __init__(self, coords: tuple[Coord, ...] | None = None):
        super().__init__(items=coords, key_field="dimension")

    def __str__(self) -> str:
        # Each coord is rendered indented and framed by newlines, so consecutive
        # coords are separated by a blank line and the text ends with a newline.
        sections = []
        for coord in self.items:
            body = "\n".join(" " + line for line in str(coord).splitlines())
            sections.append("\n" + body + "\n")
        return "coordinates:" + "".join(sections)

    @classmethod
    def default_coords_from_model(cls, model: StateSpaceLike) -> CoordInfo:
        """Build the default coordinates for ``model``.

        The state, observed-state and shock labels are each registered under two
        dimensions: the primary one (``ALL_STATE_DIM``, ``OBS_STATE_DIM``,
        ``SHOCK_DIM``) and its auxiliary counterpart (``*_AUX_DIM``).
        """
        state_labels = tuple(model.state_names)
        observed_labels = tuple(model.observed_states)
        shock_labels = tuple(model.shock_names)

        pairs = [
            (ALL_STATE_DIM, state_labels),
            (ALL_STATE_AUX_DIM, state_labels),
            (OBS_STATE_DIM, observed_labels),
            (OBS_STATE_AUX_DIM, observed_labels),
            (SHOCK_DIM, shock_labels),
            (SHOCK_AUX_DIM, shock_labels),
        ]
        built = tuple(Coord(dimension=dim, labels=labels) for dim, labels in pairs)
        return cls(coords=built)

    def to_dict(self):
        """Return ``{dimension: labels_tuple}`` in item order."""
        mapping = {}
        for coord in self.items:
            mapping[coord.dimension] = tuple(coord.labels)
        return mapping
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@dataclass(frozen=True)
class State(Property):
    """Metadata for one state of the model.

    Attributes
    ----------
    name : str
        State name.
    observed : bool
        Whether the state is observed. Together with ``name`` it forms the key in
        :class:`StateInfo`, so one name may appear once as observed and once as
        unobserved.
    shared : bool, default False
        Additional flag; its meaning is not defined in this module (TODO document).
    """

    name: str
    observed: bool
    shared: bool = False


@dataclass(frozen=True)
class StateInfo(Info[State]):
    """Collection of :class:`State` records keyed by ``(name, observed)``.

    ``states`` defaults to ``None`` (an empty collection), matching the other
    ``*Info`` constructors that already allow omitting their items.

    Note: ``in`` also accepts a bare state name, but ``get``/``[]`` require the
    full ``(name, observed)`` tuple key.
    """

    def __init__(self, states: tuple[State, ...] | None = None):
        super().__init__(items=states, key_field=("name", "observed"))

    def __contains__(self, key: object) -> bool:
        # A plain string matches any state with that name, observed or not.
        if isinstance(key, str):
            return any(s.name == key for s in self.items)
        return key in self._index

    def __str__(self) -> str:
        return (
            f"states: {[s.name for s in self.items]}\nobserved: {[s.observed for s in self.items]}"
        )

    @property
    def observed_state_names(self) -> tuple[str, ...]:
        """Names of observed states, in item order."""
        return tuple(s.name for s in self.items if s.observed)

    @property
    def unobserved_state_names(self) -> tuple[str, ...]:
        """Names of unobserved states, in item order."""
        return tuple(s.name for s in self.items if not s.observed)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@dataclass(frozen=True)
class Shock(Property):
    """Metadata for one shock of the model.

    Attributes
    ----------
    name : str
        Shock name; the lookup key in :class:`ShockInfo`.
    """

    name: str


@dataclass(frozen=True)
class ShockInfo(Info[Shock]):
    """Collection of :class:`Shock` records keyed by ``name``.

    ``shocks`` defaults to ``None`` (an empty collection), matching the other
    ``*Info`` constructors that already allow omitting their items.
    """

    def __init__(self, shocks: tuple[Shock, ...] | None = None):
        super().__init__(items=shocks, key_field="name")
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
@dataclass(frozen=True)
class SymbolicVariable(Property):
    """A named symbolic tensor variable.

    Attributes
    ----------
    name : str
        Lookup key in :class:`SymbolicVariableInfo`.
    symbolic_variable : TensorVariable
        The symbolic variable itself.
    """

    name: str
    symbolic_variable: TensorVariable


@dataclass(frozen=True)
class SymbolicVariableInfo(Info[SymbolicVariable]):
    """Collection of :class:`SymbolicVariable` records keyed by ``name``."""

    def __init__(self, symbolic_variables: tuple[SymbolicVariable, ...] | None = None):
        super().__init__(items=symbolic_variables, key_field="name")

    def to_dict(self):
        """Map each name to its symbolic variable, in item order."""
        mapping = {}
        for entry in self.items:
            mapping[entry.name] = entry.symbolic_variable
        return mapping
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@dataclass(frozen=True)
class SymbolicData(Property):
    """A named symbolic tensor holding data.

    Attributes
    ----------
    name : str
        Lookup key in :class:`SymbolicDataInfo`.
    symbolic_data : TensorVariable
        The symbolic data tensor itself.
    """

    name: str
    symbolic_data: TensorVariable


@dataclass(frozen=True)
class SymbolicDataInfo(Info[SymbolicData]):
    """Collection of :class:`SymbolicData` records keyed by ``name``."""

    def __init__(self, symbolic_data: tuple[SymbolicData, ...] | None = None):
        super().__init__(items=symbolic_data, key_field="name")

    def to_dict(self):
        """Map each name to its symbolic data tensor, in item order."""
        mapping = {}
        for entry in self.items:
            mapping[entry.name] = entry.symbolic_data
        return mapping
|