pymc-extras 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. pymc_extras/inference/laplace_approx/laplace.py +2 -2
  2. pymc_extras/inference/pathfinder/pathfinder.py +1 -1
  3. pymc_extras/prior.py +3 -3
  4. pymc_extras/statespace/core/properties.py +276 -0
  5. pymc_extras/statespace/core/statespace.py +180 -44
  6. pymc_extras/statespace/filters/distributions.py +12 -29
  7. pymc_extras/statespace/filters/kalman_filter.py +1 -1
  8. pymc_extras/statespace/models/DFM.py +179 -168
  9. pymc_extras/statespace/models/ETS.py +177 -151
  10. pymc_extras/statespace/models/SARIMAX.py +149 -152
  11. pymc_extras/statespace/models/VARMAX.py +134 -145
  12. pymc_extras/statespace/models/__init__.py +8 -1
  13. pymc_extras/statespace/models/structural/__init__.py +30 -8
  14. pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
  15. pymc_extras/statespace/models/structural/components/cycle.py +119 -80
  16. pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
  17. pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
  18. pymc_extras/statespace/models/structural/components/regression.py +105 -68
  19. pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
  20. pymc_extras/statespace/models/structural/core.py +397 -286
  21. pymc_extras/statespace/models/utilities.py +5 -20
  22. {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +3 -3
  23. {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +25 -24
  24. {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
  25. {pymc_extras-0.7.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -137,7 +137,7 @@ def get_conditional_gaussian_approximation(
137
137
  hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
138
138
 
139
139
  # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
140
- _, logdetQ = pt.nlinalg.slogdet(Q)
140
+ _, logdetQ = pt.linalg.slogdet(Q)
141
141
  conditional_gaussian_approx = (
142
142
  -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
143
143
  )
@@ -153,7 +153,7 @@ def unpack_last_axis(packed_input, packed_shapes):
153
153
  return [pt.split_dims(packed_input, packed_shapes[0], axis=-1)]
154
154
 
155
155
  keep_axes = tuple(range(packed_input.ndim))[:-1]
156
- return pt.unpack(packed_input, axes=keep_axes, packed_shapes=packed_shapes)
156
+ return pt.unpack(packed_input, keep_axes=keep_axes, packed_shapes=packed_shapes)
157
157
 
158
158
 
159
159
  def draws_from_laplace_approx(
@@ -385,7 +385,7 @@ def inverse_hessian_factors(
385
385
  # more performant and numerically precise to use solve than inverse: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html
386
386
 
387
387
  # E_inv: (L, J, J)
388
- E_inv = pt.slinalg.solve_triangular(E, Ij, check_finite=False)
388
+ E_inv = pt.linalg.solve_triangular(E, Ij, check_finite=False)
389
389
  eta_diag = pytensor.scan(pt.diag, sequences=[eta], return_updates=False)
390
390
 
391
391
  # block_dd: (L, J, J)
pymc_extras/prior.py CHANGED
@@ -1575,9 +1575,9 @@ def __getattr__(name: str):
1575
1575
  samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]})
1576
1576
 
1577
1577
  """
1578
- # Protect against doctest
1579
- if name == "__wrapped__":
1580
- return
1578
+ # Ignore Python internal attributes needed for introspection
1579
+ if name.startswith("__"):
1580
+ raise AttributeError(name)
1581
1581
 
1582
1582
  _get_pymc_distribution(name)
1583
1583
  return partial(Prior, distribution=name)
pymc_extras/statespace/core/properties.py ADDED
@@ -0,0 +1,276 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterator
4
+ from copy import deepcopy
5
+ from dataclasses import dataclass, fields
6
+ from typing import Generic, Protocol, Self, TypeVar
7
+
8
+ from pytensor.tensor.variable import TensorVariable
9
+
10
+ from pymc_extras.statespace.utils.constants import (
11
+ ALL_STATE_AUX_DIM,
12
+ ALL_STATE_DIM,
13
+ OBS_STATE_AUX_DIM,
14
+ OBS_STATE_DIM,
15
+ SHOCK_AUX_DIM,
16
+ SHOCK_DIM,
17
+ )
18
+
19
+
20
class StateSpaceLike(Protocol):
    """Structural interface for objects that expose state-space name metadata.

    Any object providing the three read-only tuples below satisfies this
    protocol; ``CoordInfo.default_coords_from_model`` relies on exactly these.
    """

    @property
    def state_names(self) -> tuple[str, ...]:
        """Names of all states."""

    @property
    def observed_states(self) -> tuple[str, ...]:
        """Names of the observed states."""

    @property
    def shock_names(self) -> tuple[str, ...]:
        """Names of the shocks."""
29
+
30
+
31
+ @dataclass(frozen=True)
32
+ class Property:
33
+ def __str__(self) -> str:
34
+ return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self))
35
+
36
+
37
+ T = TypeVar("T", bound=Property)
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class Info(Generic[T]):
42
+ items: tuple[T, ...] | None
43
+ key_field: str | tuple[str, ...] = "name"
44
+ _index: dict[str | tuple, T] | None = None
45
+
46
+ def __post_init__(self):
47
+ index = {}
48
+ if self.items is None:
49
+ object.__setattr__(self, "items", ())
50
+ else:
51
+ object.__setattr__(self, "items", tuple(self.items))
52
+
53
+ for item in self.items:
54
+ key = self._key(item)
55
+ if key in index:
56
+ raise ValueError(f"Duplicate {self.key_field} '{key}' detected.")
57
+ index[key] = item
58
+ object.__setattr__(self, "_index", index)
59
+
60
+ def _key(self, item: T) -> str | tuple:
61
+ if isinstance(self.key_field, tuple):
62
+ return tuple(getattr(item, f) for f in self.key_field)
63
+ return getattr(item, self.key_field)
64
+
65
+ def get(self, key: str | tuple, default=None) -> T | None:
66
+ return self._index.get(key, default)
67
+
68
+ def __getitem__(self, key: str | tuple) -> T:
69
+ try:
70
+ return self._index[key]
71
+ except KeyError as e:
72
+ available = ", ".join(str(k) for k in self._index.keys())
73
+ raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e
74
+
75
+ def __contains__(self, key: object) -> bool:
76
+ return key in self._index
77
+
78
+ def __iter__(self) -> Iterator[T]:
79
+ return iter(self.items)
80
+
81
+ def __len__(self) -> int:
82
+ return len(self.items)
83
+
84
+ def __str__(self) -> str:
85
+ return f"{self.key_field}s: {tuple(self._index.keys())}"
86
+
87
+ def add(self, new_item: T) -> Self:
88
+ return type(self)((*self.items, new_item))
89
+
90
+ def merge(self, other: Self, overwrite_duplicates: bool = False) -> Self:
91
+ if not isinstance(other, type(self)):
92
+ raise TypeError(f"Cannot merge {type(other).__name__} with {type(self).__name__}")
93
+
94
+ overlapping = set(self._index.keys()) & set(other._index.keys())
95
+ if overlapping and overwrite_duplicates:
96
+ return type(self)(
97
+ (
98
+ *self.items,
99
+ *(item for item in other.items if self._key(item) not in overlapping),
100
+ )
101
+ )
102
+
103
+ return type(self)(self.items + other.items)
104
+
105
+ @property
106
+ def names(self) -> tuple[str, ...]:
107
+ if isinstance(self.key_field, tuple):
108
+ return tuple(item.name for item in self.items)
109
+ return tuple(self._index.keys())
110
+
111
+ def copy(self) -> Info[T]:
112
+ return deepcopy(self)
113
+
114
+
115
@dataclass(frozen=True)
class Parameter(Property):
    """Metadata describing a single model parameter."""

    name: str
    # Shape of the parameter, if known.
    shape: tuple[int, ...] | None = None
    # Dimension names of the parameter, if any.
    dims: tuple[str, ...] | None = None
    # Optional constraint description, stored as a string.
    constraints: str | None = None


@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
    """Collection of ``Parameter`` records keyed by ``name``."""

    def __init__(self, parameters: tuple[Parameter, ...] | None):
        super().__init__(items=parameters, key_field="name")

    def to_dict(self):
        """Map each parameter name to a dict of its shape, constraints and dims."""
        result = {}
        for parameter in self.items:
            result[parameter.name] = {
                "shape": parameter.shape,
                "constraints": parameter.constraints,
                "dims": parameter.dims,
            }
        return result
133
+
134
+
135
@dataclass(frozen=True)
class Data(Property):
    """Metadata describing a named data input."""

    name: str
    # Shape of the data; individual entries may be None.
    shape: tuple[int | None, ...] | None = None
    # Dimension names of the data, if any.
    dims: tuple[str, ...] | None = None
    # Marks the data as exogenous (aggregated by DataInfo.needs_exogenous_data).
    is_exogenous: bool = False


@dataclass(frozen=True)
class DataInfo(Info[Data]):
    """Collection of ``Data`` records keyed by ``name``."""

    def __init__(self, data: tuple[Data, ...] | None):
        super().__init__(items=data, key_field="name")

    @property
    def needs_exogenous_data(self) -> bool:
        """True if at least one stored record is flagged as exogenous."""
        for record in self.items:
            if record.is_exogenous:
                return True
        return False

    @property
    def exogenous_names(self) -> tuple[str, ...]:
        """Names of the exogenous records, in insertion order."""
        selected = [record.name for record in self.items if record.is_exogenous]
        return tuple(selected)

    def __str__(self) -> str:
        names = [record.name for record in self.items]
        return f"data: {names}\nneeds exogenous data: {self.needs_exogenous_data}"

    def to_dict(self):
        """Map each data name to a dict of its shape, dims and exogenous flag."""
        result = {}
        for record in self.items:
            result[record.name] = {
                "shape": record.shape,
                "dims": record.dims,
                "exogenous": record.is_exogenous,
            }
        return result
164
+
165
+
166
@dataclass(frozen=True)
class Coord(Property):
    """Labels attached to a single named dimension."""

    dimension: str
    labels: tuple[str | int, ...]


@dataclass(frozen=True)
class CoordInfo(Info[Coord]):
    """Collection of ``Coord`` records keyed by ``dimension``."""

    def __init__(self, coords: tuple[Coord, ...] | None = None):
        super().__init__(items=coords, key_field="dimension")

    def __str__(self) -> str:
        # Header line followed by each coordinate's field lines, indented, with
        # every coordinate wrapped in leading and trailing newlines.
        chunks = ["coordinates:"]
        for coord in self.items:
            indented = "\n".join(" " + line for line in str(coord).splitlines())
            chunks.append("\n" + indented + "\n")
        return "".join(chunks)

    @classmethod
    def default_coords_from_model(cls, model: StateSpaceLike) -> CoordInfo:
        """Build default coordinates from the names exposed by ``model``.

        The state, observed-state and shock dimensions, together with their
        auxiliary counterparts, are labelled with ``model.state_names``,
        ``model.observed_states`` and ``model.shock_names`` respectively.
        """
        groups = (
            ((ALL_STATE_DIM, ALL_STATE_AUX_DIM), tuple(model.state_names)),
            ((OBS_STATE_DIM, OBS_STATE_AUX_DIM), tuple(model.observed_states)),
            ((SHOCK_DIM, SHOCK_AUX_DIM), tuple(model.shock_names)),
        )
        coords = tuple(
            Coord(dimension=dim, labels=labels) for dims, labels in groups for dim in dims
        )
        return cls(coords=coords)

    def to_dict(self):
        """Map each dimension name to a tuple of its labels."""
        result = {}
        for coord in self.items:
            result[coord.dimension] = tuple(coord.labels)
        return result
205
+
206
+
207
@dataclass(frozen=True)
class State(Property):
    """A single named state, flagged as observed or not."""

    name: str
    observed: bool
    # NOTE(review): the meaning of `shared` is defined by callers; not visible here.
    shared: bool = False


@dataclass(frozen=True)
class StateInfo(Info[State]):
    """Collection of ``State`` records keyed by the pair ``(name, observed)``.

    The composite key lets one name appear once as observed and once as
    unobserved without being reported as a duplicate.
    """

    def __init__(self, states: tuple[State, ...] | None):
        super().__init__(items=states, key_field=("name", "observed"))

    def __contains__(self, key: object) -> bool:
        """Membership test.

        A plain string matches any state with that name, whatever its
        ``observed`` flag; any other key is looked up as a ``(name, observed)``
        tuple. NOTE: ``__getitem__`` and ``get`` accept only the tuple form, so
        ``"x" in info`` being True does not imply ``info["x"]`` succeeds.
        """
        if not isinstance(key, str):
            return key in self._index
        return any(state.name == key for state in self.items)

    def __str__(self) -> str:
        names = [state.name for state in self.items]
        flags = [state.observed for state in self.items]
        return f"states: {names}\nobserved: {flags}"

    def _names_with_observed(self, flag: bool) -> tuple[str, ...]:
        """Names of the states whose ``observed`` truthiness equals ``flag``."""
        return tuple(state.name for state in self.items if bool(state.observed) is flag)

    @property
    def observed_state_names(self) -> tuple[str, ...]:
        """Names of the observed states, in insertion order."""
        return self._names_with_observed(True)

    @property
    def unobserved_state_names(self) -> tuple[str, ...]:
        """Names of the unobserved states, in insertion order."""
        return self._names_with_observed(False)
236
+
237
+
238
@dataclass(frozen=True)
class Shock(Property):
    """A single shock, identified only by its name."""

    name: str


@dataclass(frozen=True)
class ShockInfo(Info[Shock]):
    """Collection of ``Shock`` records keyed by ``name``."""

    def __init__(self, shocks: tuple[Shock, ...] | None):
        # Shocks are always indexed by their name.
        super().__init__(items=shocks, key_field="name")
247
+
248
+
249
@dataclass(frozen=True)
class SymbolicVariable(Property):
    """A named symbolic PyTensor variable."""

    name: str
    symbolic_variable: TensorVariable


@dataclass(frozen=True)
class SymbolicVariableInfo(Info[SymbolicVariable]):
    """Collection of ``SymbolicVariable`` records keyed by ``name``."""

    def __init__(self, symbolic_variables: tuple[SymbolicVariable, ...] | None = None):
        super().__init__(items=symbolic_variables, key_field="name")

    def to_dict(self):
        """Map each name to its symbolic variable."""
        mapping = {}
        for entry in self.items:
            mapping[entry.name] = entry.symbolic_variable
        return mapping
262
+
263
+
264
@dataclass(frozen=True)
class SymbolicData(Property):
    """A named symbolic PyTensor data variable."""

    name: str
    symbolic_data: TensorVariable


@dataclass(frozen=True)
class SymbolicDataInfo(Info[SymbolicData]):
    """Collection of ``SymbolicData`` records keyed by ``name``."""

    def __init__(self, symbolic_data: tuple[SymbolicData, ...] | None = None):
        super().__init__(items=symbolic_data, key_field="name")

    def to_dict(self):
        """Map each name to its symbolic data variable."""
        mapping = {}
        for entry in self.items:
            mapping[entry.name] = entry.symbolic_data
        return mapping