numpyro-forecast 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,308 @@
1
+ """Forecasting model base class and SVI/MCMC forecasters.
2
+
3
+ This is the JAX/NumPyro port of Pyro's ``pyro.contrib.forecast.forecaster``.
4
+ The classes here are thin object-oriented shims over the functional core in
5
+ :mod:`numpyro_forecast.functional`: :class:`ForecastingModel` threads the
6
+ train/forecast :class:`~numpyro_forecast.functional.Horizon` for you, and the
7
+ forecaster classes wrap :func:`~numpyro_forecast.functional.fit_svi` /
8
+ :func:`~numpyro_forecast.functional.fit_mcmc` plus
9
+ :func:`~numpyro_forecast.functional.forecast`. The two styles are fully
10
+ interchangeable: both consume a model callable ``(covariates, data=None)`` and a
11
+ posterior dict of latent draws.
12
+ """
13
+
14
+ import abc
15
+ from collections.abc import Callable
16
+
17
+ import numpyro.distributions as dist
18
+ from jax import random
19
+ from jaxtyping import Float
20
+ from numpyro.infer.autoguide import AutoGuide
21
+ from numpyro.infer.reparam import Reparam
22
+ from numpyro.optim import _NumPyroOptim
23
+
24
+ from numpyro_forecast.functional import (
25
+ Horizon,
26
+ draw_posterior,
27
+ fit_mcmc,
28
+ fit_svi,
29
+ )
30
+ from numpyro_forecast.functional import forecast as _forecast
31
+ from numpyro_forecast.functional import predict as _predict
32
+ from numpyro_forecast.functional import time_series as _time_series
33
+ from numpyro_forecast.typing import Array, ForecastModel
34
+
35
+
36
class ForecastingModel(abc.ABC):
    """Abstract base class for forecasting models.

    Subclasses implement :meth:`model`, which must call :meth:`predict` exactly
    once. The instance itself is the (pure) NumPyro model function with signature
    ``model_instance(covariates, data=None)``: the forecast horizon is inferred
    from the shapes (``future = covariates.shape[-2] - data.shape[-2]``).

    This is the object-oriented façade over the functional API: :meth:`time_series`
    and :meth:`predict` delegate to the free functions in
    :mod:`numpyro_forecast.functional`, passing the current
    :class:`~numpyro_forecast.functional.Horizon`.

    Notes
    -----
    The current horizon is stored on the instance only while :meth:`__call__`
    is running. Outside of a model call, the :attr:`duration`, :attr:`t_obs`
    and :attr:`future` properties and the :meth:`time_series` / :meth:`predict`
    helpers raise :class:`RuntimeError`.
    """

    # Class-level default: subclasses that override ``__init__`` without calling
    # ``super().__init__()`` still hit the intended RuntimeError in
    # ``_require_horizon`` rather than an AttributeError.
    _horizon: Horizon | None = None

    def __init__(self) -> None:
        self._horizon = None

    @abc.abstractmethod
    def model(self, zero_data: Array | None, covariates: Array) -> None:
        """Define the generative model and call :meth:`predict` exactly once.

        Parameters
        ----------
        zero_data
            Zeros shaped like the data extended to the covariate duration
            (shape/dtype only; ``None`` during pure prior sampling).
        covariates
            Covariates with time at axis ``-2`` and shape
            ``(*batch, duration, cov)``.
        """
        raise NotImplementedError

    def _require_horizon(self) -> Horizon:
        """Return the horizon of the in-progress model call.

        Raises
        ------
        RuntimeError
            If no model call is currently running on this instance.
        """
        if self._horizon is None:
            msg = "model state is only available during a model call"
            raise RuntimeError(msg)
        return self._horizon

    @property
    def duration(self) -> int:
        """Total horizon length ``t + future`` (in time steps).

        Raises ``RuntimeError`` outside of a model call.
        """
        return self._require_horizon().duration

    @property
    def t_obs(self) -> int:
        """Number of observed (in-sample) time steps ``t``.

        Raises ``RuntimeError`` outside of a model call.
        """
        return self._require_horizon().t_obs

    @property
    def future(self) -> int:
        """Number of forecast time steps ``f`` (``0`` while training).

        Raises ``RuntimeError`` outside of a model call.
        """
        return self._require_horizon().future

    def time_series(
        self,
        name: str,
        dist_fn: Callable[[], dist.Distribution],
        *,
        reparam: Reparam | None = None,
    ) -> Array:
        """Sample a time-varying latent over the full horizon.

        Thin wrapper over :func:`numpyro_forecast.functional.time_series`.

        Parameters
        ----------
        name
            Base sample-site name for the in-sample latent.
        dist_fn
            Zero-argument callable returning the per-step prior distribution.
        reparam
            Optional reparameterization (e.g. ``LocScaleReparam``) applied to
            both the in-sample and forecast sites.

        Returns
        -------
        Array
            The latent over the full horizon with time at axis ``-2``.

        Raises
        ------
        RuntimeError
            If called outside of a model call.
        """
        return _time_series(self._require_horizon(), name, dist_fn, reparam=reparam)

    def predict(self, noise_dist: dist.Distribution, prediction: Array) -> None:
        """Register the observation/forecast sites for the model.

        Thin wrapper over :func:`numpyro_forecast.functional.predict`.

        Parameters
        ----------
        noise_dist
            Zero-centered observation noise (e.g. ``Normal(0, sigma)``).
        prediction
            Deterministic mean with time at axis ``-2``, shape
            ``(*batch, duration, obs)``.

        Raises
        ------
        RuntimeError
            If called outside of a model call.
        """
        _predict(self._require_horizon(), noise_dist, prediction)

    def __call__(self, covariates: Array, data: Array | None = None) -> None:
        """Run the model as a NumPyro model function.

        The horizon derived from ``covariates``/``data`` is installed on the
        instance for the duration of :meth:`model`. It is restored to its
        previous value afterwards, even if :meth:`model` raises.

        Parameters
        ----------
        covariates
            Covariates with time at axis ``-2`` spanning the full horizon.
        data
            Observed data with time at axis ``-2`` (``None`` for prior sampling).
        """
        horizon = Horizon.from_data(covariates, data)
        # Save and restore (rather than clear) so that a re-entrant call on the
        # same instance cannot wipe the horizon of the outer call.
        previous = self._horizon
        self._horizon = horizon
        try:
            self.model(horizon.zero_data, covariates)
        finally:
            self._horizon = previous
148
+
149
+
150
class _BaseForecaster(abc.ABC):
    """Shared forecasting logic over a fitted posterior.

    Subclasses fit a posterior at construction time and implement
    :meth:`_draw_posterior`. Calling the instance draws latent samples from
    that posterior and forwards them, together with the data and covariates,
    to :func:`numpyro_forecast.functional.forecast`.
    """

    def __init__(self, model: ForecastModel) -> None:
        # Model callable ``(covariates, data=None)`` used when forecasting.
        self.model = model

    @abc.abstractmethod
    def _draw_posterior(self, rng_key: Array, num_samples: int) -> dict[str, Array]:
        """Return ``num_samples`` posterior draws of the latent sites."""
        raise NotImplementedError

    def __call__(
        self,
        rng_key: Array,
        data: Array,
        covariates: Array,
        num_samples: int,
        *,
        batch_size: int | None = None,
    ) -> Float[Array, " sample *batch future obs"]:
        """Sample forecasts for the steps in ``[t, duration)``.

        Parameters
        ----------
        rng_key
            PRNG key. It is split into one key for the posterior draws and
            one for the predictive sampling.
        data
            Observed data with time at axis ``-2`` and length ``t``.
        covariates
            Covariates with time at axis ``-2`` and length ``duration > t``.
        num_samples
            Number of forecast samples to draw.
        batch_size
            Optional chunk size for sampling (caps peak memory).

        Returns
        -------
        Float[Array, " sample *batch future obs"]
            Forecast samples over the ``future = duration - t`` horizon.

        Raises
        ------
        ValueError
            If ``data`` or ``covariates`` has fewer than two dimensions, if
            ``covariates`` does not extend beyond ``data`` along the time
            axis, or if ``num_samples`` is not positive.
        """
        # Validate rank first: indexing ``shape[-2]`` on a 0-D/1-D array would
        # otherwise surface as an opaque IndexError.
        if data.ndim < 2 or covariates.ndim < 2:
            msg = "data and covariates must have a time axis at -2 (ndim >= 2)"
            raise ValueError(msg)
        if data.shape[-2] >= covariates.shape[-2]:
            msg = "covariates must extend beyond data along the time axis"
            raise ValueError(msg)
        if num_samples <= 0:
            msg = "num_samples must be positive"
            raise ValueError(msg)
        key_post, key_pred = random.split(rng_key)
        posterior = self._draw_posterior(key_post, num_samples)
        return _forecast(key_pred, self.model, posterior, data, covariates, batch_size=batch_size)
199
+
200
+
201
class Forecaster(_BaseForecaster):
    """Forecaster whose posterior is fitted with stochastic variational inference.

    Fitting happens eagerly in the constructor via
    :func:`numpyro_forecast.functional.fit_svi`. Calling the instance then
    draws latent samples from the fitted guide and forecasts forward.

    Parameters
    ----------
    rng_key
        PRNG key used for inference.
    model
        Forecasting model to fit; either a ``ForecastingModel`` instance or a
        functional model callable.
    data
        In-sample observations, time at axis ``-2``.
    covariates
        Covariates, time at axis ``-2``, covering the same duration as ``data``.
    guide
        Variational guide; ``AutoNormal(model)`` when omitted.
    optim
        NumPyro optimizer; ``Adam(0.01)`` when omitted.
    num_steps
        How many SVI steps to run.
    num_particles
        How many particles to use for the ELBO estimate.
    progress_bar
        Show the SVI progress bar when ``True``.

    Attributes
    ----------
    guide
        The guide taken from the fit result.
    params
        Guide parameters taken from the fit result.
    losses
        Loss values recorded during fitting (as returned by ``fit_svi``).
    """

    def __init__(
        self,
        rng_key: Array,
        model: ForecastModel,
        data: Array,
        covariates: Array,
        *,
        guide: AutoGuide | None = None,
        optim: _NumPyroOptim | None = None,
        num_steps: int = 1_001,
        num_particles: int = 1,
        progress_bar: bool = False,
    ) -> None:
        super().__init__(model)
        fit = fit_svi(
            rng_key,
            model,
            data,
            covariates,
            progress_bar=progress_bar,
            num_particles=num_particles,
            num_steps=num_steps,
            optim=optim,
            guide=guide,
        )
        # Keep the full fit result for posterior sampling; expose its pieces.
        self._fit = fit
        self.guide: AutoGuide = fit.guide
        self.params: dict[str, Array] = fit.params
        self.losses: Array = fit.losses

    def _draw_posterior(self, rng_key: Array, num_samples: int) -> dict[str, Array]:
        """Delegate to ``draw_posterior`` using the stored SVI fit."""
        fitted = self._fit
        return draw_posterior(rng_key, fitted, num_samples)
257
+
258
+
259
class HMCForecaster(_BaseForecaster):
    """Forecaster whose posterior is sampled with NUTS (Hamiltonian Monte Carlo).

    Sampling happens eagerly in the constructor via
    :func:`numpyro_forecast.functional.fit_mcmc`. Calling the instance then
    draws from those posterior samples and forecasts forward.

    Parameters
    ----------
    rng_key
        PRNG key used for inference.
    model
        Forecasting model to fit; either a ``ForecastingModel`` instance or a
        functional model callable.
    data
        In-sample observations, time at axis ``-2``.
    covariates
        Covariates, time at axis ``-2``, covering the same duration as ``data``.
    num_warmup
        How many warmup steps to run.
    num_samples
        How many posterior samples to collect.
    num_chains
        How many MCMC chains to run.
    progress_bar
        Show the MCMC progress bar when ``True``.

    Attributes
    ----------
    posterior_samples
        Latent-site samples taken from the fit result.
    """

    def __init__(
        self,
        rng_key: Array,
        model: ForecastModel,
        data: Array,
        covariates: Array,
        *,
        num_warmup: int = 1_000,
        num_samples: int = 1_000,
        num_chains: int = 1,
        progress_bar: bool = False,
    ) -> None:
        super().__init__(model)
        fit = fit_mcmc(
            rng_key,
            model,
            data,
            covariates,
            progress_bar=progress_bar,
            num_chains=num_chains,
            num_samples=num_samples,
            num_warmup=num_warmup,
        )
        # Keep the full fit result for posterior sampling; expose the samples.
        self._fit = fit
        self.posterior_samples: dict[str, Array] = fit.samples

    def _draw_posterior(self, rng_key: Array, num_samples: int) -> dict[str, Array]:
        """Delegate to ``draw_posterior`` using the stored MCMC fit.

        NOTE(review): how ``draw_posterior`` maps ``num_samples`` onto the finite
        set of MCMC samples (resampling vs. truncation) is not visible here.
        """
        fitted = self._fit
        return draw_posterior(rng_key, fitted, num_samples)