pymc-extras 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. pymc_extras/distributions/timeseries.py +10 -10
  2. pymc_extras/inference/dadvi/dadvi.py +14 -83
  3. pymc_extras/inference/laplace_approx/laplace.py +187 -159
  4. pymc_extras/inference/pathfinder/pathfinder.py +12 -7
  5. pymc_extras/inference/smc/sampling.py +2 -2
  6. pymc_extras/model/marginal/distributions.py +4 -2
  7. pymc_extras/model/marginal/marginal_model.py +12 -2
  8. pymc_extras/prior.py +3 -3
  9. pymc_extras/statespace/core/properties.py +276 -0
  10. pymc_extras/statespace/core/statespace.py +182 -45
  11. pymc_extras/statespace/filters/distributions.py +19 -34
  12. pymc_extras/statespace/filters/kalman_filter.py +13 -12
  13. pymc_extras/statespace/filters/kalman_smoother.py +2 -2
  14. pymc_extras/statespace/models/DFM.py +179 -168
  15. pymc_extras/statespace/models/ETS.py +177 -151
  16. pymc_extras/statespace/models/SARIMAX.py +149 -152
  17. pymc_extras/statespace/models/VARMAX.py +134 -145
  18. pymc_extras/statespace/models/__init__.py +8 -1
  19. pymc_extras/statespace/models/structural/__init__.py +30 -8
  20. pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
  21. pymc_extras/statespace/models/structural/components/cycle.py +119 -80
  22. pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
  23. pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
  24. pymc_extras/statespace/models/structural/components/regression.py +105 -68
  25. pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
  26. pymc_extras/statespace/models/structural/core.py +397 -286
  27. pymc_extras/statespace/models/utilities.py +5 -20
  28. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +4 -4
  29. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +31 -30
  30. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
  31. {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -278,12 +278,13 @@ def alpha_recover(
278
278
  z = pt.diff(g, axis=0)
279
279
  alpha_l_init = pt.ones(N)
280
280
 
281
- alpha, _ = pytensor.scan(
281
+ alpha = pytensor.scan(
282
282
  fn=compute_alpha_l,
283
283
  outputs_info=alpha_l_init,
284
284
  sequences=[s, z],
285
285
  n_steps=Lp1 - 1,
286
286
  allow_gc=False,
287
+ return_updates=False,
287
288
  )
288
289
 
289
290
  # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
@@ -334,11 +335,12 @@ def inverse_hessian_factors(
334
335
  return pt.set_subtensor(chi_l[j_last], diff_l)
335
336
 
336
337
  chi_init = pt.zeros((J, N))
337
- chi_mat, _ = pytensor.scan(
338
+ chi_mat = pytensor.scan(
338
339
  fn=chi_update,
339
340
  outputs_info=chi_init,
340
341
  sequences=[diff],
341
342
  allow_gc=False,
343
+ return_updates=False,
342
344
  )
343
345
 
344
346
  chi_mat = pt.matrix_transpose(chi_mat)
@@ -377,14 +379,14 @@ def inverse_hessian_factors(
377
379
  eta = pt.diagonal(E, axis1=-2, axis2=-1)
378
380
 
379
381
  # beta: (L, N, 2J)
380
- alpha_diag, _ = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha])
382
+ alpha_diag = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha], return_updates=False)
381
383
  beta = pt.concatenate([alpha_diag @ Z, S], axis=-1)
382
384
 
383
385
  # more performant and numerically precise to use solve than inverse: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html
384
386
 
385
387
  # E_inv: (L, J, J)
386
- E_inv = pt.slinalg.solve_triangular(E, Ij, check_finite=False)
387
- eta_diag, _ = pytensor.scan(pt.diag, sequences=[eta])
388
+ E_inv = pt.linalg.solve_triangular(E, Ij, check_finite=False)
389
+ eta_diag = pytensor.scan(pt.diag, sequences=[eta], return_updates=False)
388
390
 
389
391
  # block_dd: (L, J, J)
390
392
  block_dd = (
@@ -530,7 +532,9 @@ def bfgs_sample_sparse(
530
532
 
531
533
  # qr_input: (L, N, 2J)
532
534
  qr_input = inv_sqrt_alpha_diag @ beta
533
- (Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)
535
+ Q, R = pytensor.scan(
536
+ fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False, return_updates=False
537
+ )
534
538
 
535
539
  IdN = pt.eye(R.shape[1])[None, ...]
536
540
  IdN += IdN * REGULARISATION_TERM
@@ -623,10 +627,11 @@ def bfgs_sample(
623
627
 
624
628
  L, N, JJ = beta.shape
625
629
 
626
- (alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag), _ = pytensor.scan(
630
+ alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag = pytensor.scan(
627
631
  lambda a: [pt.diag(a), pt.diag(pt.sqrt(1.0 / a)), pt.diag(pt.sqrt(a))],
628
632
  sequences=[alpha],
629
633
  allow_gc=False,
634
+ return_updates=False,
630
635
  )
631
636
 
632
637
  u = pt.random.normal(size=(L, num_samples, N))
@@ -238,7 +238,7 @@ class SMCDiagnostics(NamedTuple):
238
238
  def update_diagnosis(i, history, info, state):
239
239
  le, lli, ancestors, weights_evolution = history
240
240
  return SMCDiagnostics(
241
- le.at[i].set(state.lmbda),
241
+ le.at[i].set(state.tempering_param),
242
242
  lli.at[i].set(info.log_likelihood_increment),
243
243
  ancestors.at[i].set(info.ancestors),
244
244
  weights_evolution.at[i].set(state.weights),
@@ -265,7 +265,7 @@ def inference_loop(rng_key, initial_state, kernel, iterations_to_diagnose, n_par
265
265
 
266
266
  def cond(carry):
267
267
  i, state, _, _ = carry
268
- return state.lmbda < 1
268
+ return state.tempering_param < 1
269
269
 
270
270
  def one_step(carry):
271
271
  i, state, k, previous_info = carry
@@ -282,11 +282,12 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inpu
282
282
  def logp_fn(marginalized_rv_const, *non_sequences):
283
283
  return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})
284
284
 
285
- joint_logps, _ = scan_map(
285
+ joint_logps = scan_map(
286
286
  fn=logp_fn,
287
287
  sequences=marginalized_rv_domain_tensor,
288
288
  non_sequences=[*values, *inputs],
289
289
  mode=Mode().including("local_remove_check_parameter"),
290
+ return_updates=False,
290
291
  )
291
292
 
292
293
  joint_logp = pt.logsumexp(joint_logps, axis=0)
@@ -350,12 +351,13 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
350
351
 
351
352
  P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
352
353
  log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
353
- log_alpha_seq, _ = scan(
354
+ log_alpha_seq = scan(
354
355
  step_alpha,
355
356
  non_sequences=[log_P],
356
357
  outputs_info=[log_alpha_init],
357
358
  # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
358
359
  sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
360
+ return_updates=False,
359
361
  )
360
362
  # Final logp is just the sum of the last scan state
361
363
  joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
@@ -11,7 +11,7 @@ from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_po
11
11
  from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
12
12
  from pymc.distributions.transforms import Chain
13
13
  from pymc.logprob.transforms import IntervalTransform
14
- from pymc.model import Model
14
+ from pymc.model import Model, modelcontext
15
15
  from pymc.model.fgraph import (
16
16
  ModelFreeRV,
17
17
  ModelValuedVar,
@@ -337,8 +337,9 @@ def transform_posterior_pts(model, posterior_pts):
337
337
 
338
338
 
339
339
  def recover_marginals(
340
- model: Model,
341
340
  idata: InferenceData,
341
+ *,
342
+ model: Model | None = None,
342
343
  var_names: Sequence[str] | None = None,
343
344
  return_samples: bool = True,
344
345
  extend_inferencedata: bool = True,
@@ -389,6 +390,15 @@ def recover_marginals(
389
390
 
390
391
 
391
392
  """
393
+ # Temporary error message for helping with migration
394
+ # Will be removed in a future release
395
+ if isinstance(idata, Model):
396
+ raise TypeError(
397
+ "The order of arguments of `recover_marginals` changed. The first input must be an idata"
398
+ )
399
+
400
+ model = modelcontext(model)
401
+
392
402
  unmarginal_model = unmarginalize(model)
393
403
 
394
404
  # Find the names of the marginalized variables
pymc_extras/prior.py CHANGED
@@ -1575,9 +1575,9 @@ def __getattr__(name: str):
1575
1575
  samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]})
1576
1576
 
1577
1577
  """
1578
- # Protect against doctest
1579
- if name == "__wrapped__":
1580
- return
1578
+ # Ignore Python internal attributes needed for introspection
1579
+ if name.startswith("__"):
1580
+ raise AttributeError(name)
1581
1581
 
1582
1582
  _get_pymc_distribution(name)
1583
1583
  return partial(Prior, distribution=name)
@@ -0,0 +1,276 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterator
4
+ from copy import deepcopy
5
+ from dataclasses import dataclass, fields
6
+ from typing import Generic, Protocol, Self, TypeVar
7
+
8
+ from pytensor.tensor.variable import TensorVariable
9
+
10
+ from pymc_extras.statespace.utils.constants import (
11
+ ALL_STATE_AUX_DIM,
12
+ ALL_STATE_DIM,
13
+ OBS_STATE_AUX_DIM,
14
+ OBS_STATE_DIM,
15
+ SHOCK_AUX_DIM,
16
+ SHOCK_DIM,
17
+ )
18
+
19
+
20
class StateSpaceLike(Protocol):
    """Structural type for objects exposing state, observed-state and shock names.

    Any object providing the three read-only attributes below satisfies this
    protocol; no inheritance is required.
    """

    @property
    def state_names(self) -> tuple[str, ...]:
        """Tuple of state names."""
        ...

    @property
    def observed_states(self) -> tuple[str, ...]:
        """Tuple of observed-state names."""
        ...

    @property
    def shock_names(self) -> tuple[str, ...]:
        """Tuple of shock names."""
        ...
29
+
30
+
31
+ @dataclass(frozen=True)
32
+ class Property:
33
+ def __str__(self) -> str:
34
+ return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self))
35
+
36
+
37
+ T = TypeVar("T", bound=Property)
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class Info(Generic[T]):
42
+ items: tuple[T, ...] | None
43
+ key_field: str | tuple[str, ...] = "name"
44
+ _index: dict[str | tuple, T] | None = None
45
+
46
+ def __post_init__(self):
47
+ index = {}
48
+ if self.items is None:
49
+ object.__setattr__(self, "items", ())
50
+ else:
51
+ object.__setattr__(self, "items", tuple(self.items))
52
+
53
+ for item in self.items:
54
+ key = self._key(item)
55
+ if key in index:
56
+ raise ValueError(f"Duplicate {self.key_field} '{key}' detected.")
57
+ index[key] = item
58
+ object.__setattr__(self, "_index", index)
59
+
60
+ def _key(self, item: T) -> str | tuple:
61
+ if isinstance(self.key_field, tuple):
62
+ return tuple(getattr(item, f) for f in self.key_field)
63
+ return getattr(item, self.key_field)
64
+
65
+ def get(self, key: str | tuple, default=None) -> T | None:
66
+ return self._index.get(key, default)
67
+
68
+ def __getitem__(self, key: str | tuple) -> T:
69
+ try:
70
+ return self._index[key]
71
+ except KeyError as e:
72
+ available = ", ".join(str(k) for k in self._index.keys())
73
+ raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e
74
+
75
+ def __contains__(self, key: object) -> bool:
76
+ return key in self._index
77
+
78
+ def __iter__(self) -> Iterator[T]:
79
+ return iter(self.items)
80
+
81
+ def __len__(self) -> int:
82
+ return len(self.items)
83
+
84
+ def __str__(self) -> str:
85
+ return f"{self.key_field}s: {tuple(self._index.keys())}"
86
+
87
+ def add(self, new_item: T) -> Self:
88
+ return type(self)((*self.items, new_item))
89
+
90
+ def merge(self, other: Self, overwrite_duplicates: bool = False) -> Self:
91
+ if not isinstance(other, type(self)):
92
+ raise TypeError(f"Cannot merge {type(other).__name__} with {type(self).__name__}")
93
+
94
+ overlapping = set(self._index.keys()) & set(other._index.keys())
95
+ if overlapping and overwrite_duplicates:
96
+ return type(self)(
97
+ (
98
+ *self.items,
99
+ *(item for item in other.items if self._key(item) not in overlapping),
100
+ )
101
+ )
102
+
103
+ return type(self)(self.items + other.items)
104
+
105
+ @property
106
+ def names(self) -> tuple[str, ...]:
107
+ if isinstance(self.key_field, tuple):
108
+ return tuple(item.name for item in self.items)
109
+ return tuple(self._index.keys())
110
+
111
+ def copy(self) -> Info[T]:
112
+ return deepcopy(self)
113
+
114
+
115
@dataclass(frozen=True)
class Parameter(Property):
    """Immutable record describing one named parameter.

    Only ``name`` is required; every other field defaults to ``None``.
    """

    # Identifier; ``ParameterInfo`` uses it as the lookup key.
    name: str
    # Optional metadata.
    shape: tuple[int, ...] | None = None
    dims: tuple[str, ...] | None = None
    constraints: str | None = None
121
+
122
+
123
@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
    """Collection of ``Parameter`` records keyed by ``name``."""

    def __init__(self, parameters: tuple[Parameter, ...] | None):
        super().__init__(items=parameters, key_field="name")

    def to_dict(self):
        """Return a ``{name: {"shape", "constraints", "dims"}}`` mapping of all parameters."""
        summary = {}
        for param in self.items:
            summary[param.name] = {
                "shape": param.shape,
                "constraints": param.constraints,
                "dims": param.dims,
            }
        return summary
133
+
134
+
135
@dataclass(frozen=True)
class Data(Property):
    """Immutable record describing one named data variable.

    Only ``name`` is required; ``shape`` and ``dims`` default to ``None`` and
    ``is_exogenous`` defaults to ``False``.
    """

    # Identifier; ``DataInfo`` uses it as the lookup key.
    name: str
    # Individual shape entries may themselves be ``None``.
    shape: tuple[int | None, ...] | None = None
    dims: tuple[str, ...] | None = None
    # Flag read by ``DataInfo.needs_exogenous_data`` / ``exogenous_names``.
    is_exogenous: bool = False
141
+
142
+
143
@dataclass(frozen=True)
class DataInfo(Info[Data]):
    """Collection of ``Data`` records keyed by ``name``."""

    def __init__(self, data: tuple[Data, ...] | None):
        super().__init__(items=data, key_field="name")

    @property
    def needs_exogenous_data(self) -> bool:
        """True if at least one stored record is flagged ``is_exogenous``."""
        for entry in self.items:
            if entry.is_exogenous:
                return True
        return False

    @property
    def exogenous_names(self) -> tuple[str, ...]:
        """Names of the records flagged ``is_exogenous``, in stored order."""
        return tuple(entry.name for entry in self.items if entry.is_exogenous)

    def __str__(self) -> str:
        data_names = [entry.name for entry in self.items]
        return f"data: {data_names}\nneeds exogenous data: {self.needs_exogenous_data}"

    def to_dict(self):
        """Return a ``{name: {"shape", "dims", "exogenous"}}`` mapping of all records."""
        summary = {}
        for entry in self.items:
            summary[entry.name] = {
                "shape": entry.shape,
                "dims": entry.dims,
                "exogenous": entry.is_exogenous,
            }
        return summary
164
+
165
+
166
@dataclass(frozen=True)
class Coord(Property):
    """Immutable record pairing a dimension name with its labels."""

    # Dimension name; ``CoordInfo`` uses it as the lookup key.
    dimension: str
    # Labels along that dimension.
    labels: tuple[str | int, ...]
170
+
171
+
172
@dataclass(frozen=True)
class CoordInfo(Info[Coord]):
    """Collection of ``Coord`` records keyed by ``dimension``."""

    def __init__(self, coords: tuple[Coord, ...] | None = None):
        super().__init__(items=coords, key_field="dimension")

    def __str__(self) -> str:
        # Header line followed by one indented, blank-line-terminated section per coord.
        sections = ["coordinates:"]
        for coord in self.items:
            body = "\n".join(" " + line for line in str(coord).splitlines())
            sections.append("\n" + body + "\n")
        return "".join(sections)

    @classmethod
    def default_coords_from_model(cls, model: StateSpaceLike) -> CoordInfo:
        """Build the default coordinates from a model's state, observed-state and shock names.

        Each group of names is used as the labels of two dimensions: the
        primary dimension and its auxiliary counterpart.
        """
        label_groups = (
            ((ALL_STATE_DIM, ALL_STATE_AUX_DIM), tuple(model.state_names)),
            ((OBS_STATE_DIM, OBS_STATE_AUX_DIM), tuple(model.observed_states)),
            ((SHOCK_DIM, SHOCK_AUX_DIM), tuple(model.shock_names)),
        )

        coords = tuple(
            Coord(dimension=dim, labels=labels)
            for dim_pair, labels in label_groups
            for dim in dim_pair
        )
        return cls(coords=coords)

    def to_dict(self):
        """Return a ``{dimension: labels}`` mapping with labels as tuples."""
        mapping = {}
        for coord in self.items:
            mapping[coord.dimension] = tuple(coord.labels)
        return mapping
205
+
206
+
207
@dataclass(frozen=True)
class State(Property):
    """Immutable record describing one named state."""

    name: str
    # Together with ``name`` this forms the composite key used by ``StateInfo``.
    observed: bool
    shared: bool = False
212
+
213
+
214
@dataclass(frozen=True)
class StateInfo(Info[State]):
    """Collection of ``State`` records keyed by the ``(name, observed)`` pair."""

    def __init__(self, states: tuple[State, ...] | None):
        super().__init__(items=states, key_field=("name", "observed"))

    def __contains__(self, key: object) -> bool:
        # A bare string matches any state with that name, regardless of its
        # observed flag; anything else is looked up as a composite key.
        if not isinstance(key, str):
            return key in self._index
        return any(state.name == key for state in self.items)

    def __str__(self) -> str:
        state_names = [state.name for state in self.items]
        observed_flags = [state.observed for state in self.items]
        return f"states: {state_names}\nobserved: {observed_flags}"

    @property
    def observed_state_names(self) -> tuple[str, ...]:
        """Names of the states whose ``observed`` flag is truthy, in stored order."""
        return tuple(state.name for state in self.items if state.observed)

    @property
    def unobserved_state_names(self) -> tuple[str, ...]:
        """Names of the states whose ``observed`` flag is falsy, in stored order."""
        return tuple(state.name for state in self.items if not state.observed)
236
+
237
+
238
@dataclass(frozen=True)
class Shock(Property):
    """Immutable record holding the name of one shock."""

    # Identifier; ``ShockInfo`` uses it as the lookup key.
    name: str
241
+
242
+
243
@dataclass(frozen=True)
class ShockInfo(Info[Shock]):
    """Collection of ``Shock`` records keyed by ``name``."""

    def __init__(self, shocks: tuple[Shock, ...] | None):
        # The lookup key is fixed to the record's ``name`` attribute.
        super().__init__(key_field="name", items=shocks)
247
+
248
+
249
@dataclass(frozen=True)
class SymbolicVariable(Property):
    """Immutable record pairing a name with a symbolic tensor variable."""

    # Identifier; ``SymbolicVariableInfo`` uses it as the lookup key.
    name: str
    symbolic_variable: TensorVariable
253
+
254
+
255
@dataclass(frozen=True)
class SymbolicVariableInfo(Info[SymbolicVariable]):
    """Collection of ``SymbolicVariable`` records keyed by ``name``."""

    def __init__(self, symbolic_variables: tuple[SymbolicVariable, ...] | None = None):
        super().__init__(items=symbolic_variables, key_field="name")

    def to_dict(self):
        """Return a ``{name: symbolic_variable}`` mapping of all records."""
        mapping = {}
        for variable in self.items:
            mapping[variable.name] = variable.symbolic_variable
        return mapping
262
+
263
+
264
@dataclass(frozen=True)
class SymbolicData(Property):
    """Immutable record pairing a name with a symbolic data tensor."""

    # Identifier; ``SymbolicDataInfo`` uses it as the lookup key.
    name: str
    symbolic_data: TensorVariable
268
+
269
+
270
@dataclass(frozen=True)
class SymbolicDataInfo(Info[SymbolicData]):
    """Collection of ``SymbolicData`` records keyed by ``name``."""

    def __init__(self, symbolic_data: tuple[SymbolicData, ...] | None = None):
        super().__init__(items=symbolic_data, key_field="name")

    def to_dict(self):
        """Return a ``{name: symbolic_data}`` mapping of all records."""
        mapping = {}
        for data in self.items:
            mapping[data.name] = data.symbolic_data
        return mapping